Staging
v0.5.1
https://github.com/python/cpython
Raw File
Tip revision: b0ef5c979b0c4e1df2286ac734160d7560fcce2c authored by Ned Deily on 28 February 2018, 00:49:18 UTC
Update NEWS, docs, and patchlevel for 3.7.0b2
Tip revision: b0ef5c9
test__xxsubinterpreters.py
import contextlib
import os
import pickle
from textwrap import dedent, indent
import threading
import time
import unittest

from test import support
from test.support import script_helper

interpreters = support.import_module('_xxsubinterpreters')


def _captured_script(script):
    r, w = os.pipe()
    indented = script.replace('\n', '\n                ')
    wrapped = dedent(f"""
        import contextlib
        with open({w}, 'w') as chan:
            with contextlib.redirect_stdout(chan):
                {indented}
        """)
    return wrapped, open(r)


def _run_output(interp, request, shared=None):
    """Run *request* in interpreter *interp* and return what it printed.

    *shared* is passed through unchanged to ``interpreters.run_string()``.
    The pipe's read end is closed before returning, even if the run fails.
    """
    wrapped, reader = _captured_script(request)
    with reader:
        interpreters.run_string(interp, wrapped, shared)
        output = reader.read()
    return output


@contextlib.contextmanager
def _running(interp):
    """Keep *interp* busy running a script for the duration of a ``with`` block.

    A helper thread calls ``interpreters.run_string()`` with a script that
    blocks reading the read end of a pipe until EOF.  When the ``with`` block
    exits, whether normally or via an exception, the write end is written to
    and closed, which lets the script finish.  The helper thread is then
    joined.
    """
    r, w = os.pipe()
    def run():
        interpreters.run_string(interp, dedent(f"""
            # wait for "signal"
            with open({r}) as chan:
                chan.read()
            """))

    t = threading.Thread(target=run)
    t.start()
    # NOTE(review): t.start() only waits for the thread to start, not for
    # run_string() to begin executing, so the body may briefly observe the
    # interpreter as not yet running.  TODO confirm whether callers need a
    # readiness wait here.

    try:
        yield
    finally:
        # Always release the blocked script, even if the body raised.
        # Otherwise the helper thread would never finish: the interpreter
        # would stay "running" and the process would hang at exit.
        with open(w, 'w') as chan:
            chan.write('done')
        t.join()


class IsShareableTests(unittest.TestCase):
    """Check which objects interpreters.is_shareable() accepts."""

    def test_default_shareables(self):
        # Singletons first, then builtin objects.
        for candidate in (None, b'spam'):
            with self.subTest(candidate):
                self.assertTrue(interpreters.is_shareable(candidate))

    def test_not_shareable(self):
        class Cheese:
            def __init__(self, name):
                self.name = name

            def __str__(self):
                return self.name

        class SubBytes(bytes):
            """A subclass of a shareable type."""

        singletons = [True, False, NotImplemented, ...]
        builtin_objects = [
            type,
            object,
            object(),
            Exception(),
            42,
            100.0,
            'spam',
        ]
        user_defined = [
            Cheese,
            Cheese('Wensleydale'),
            SubBytes(b'spam'),
        ]
        for candidate in singletons + builtin_objects + user_defined:
            with self.subTest(candidate):
                self.assertFalse(interpreters.is_shareable(candidate))


class TestBase(unittest.TestCase):
    """Base class that destroys leftover subinterpreters and channels.

    Every subinterpreter other than the main one, and every channel that
    still exists when a test finishes, is destroyed in tearDown().
    """

    def tearDown(self):
        # Ask for the main interpreter's ID rather than assuming it is 0.
        main = interpreters.get_main()
        for interp_id in interpreters.list_all():
            if interp_id == main:
                continue
            try:
                interpreters.destroy(interp_id)
            except RuntimeError:
                pass  # already destroyed (or still running)

        for cid in interpreters.channel_list_all():
            try:
                interpreters.channel_destroy(cid)
            except interpreters.ChannelNotFoundError:
                pass  # already destroyed


class ListAllTests(TestBase):
    """Tests for interpreters.list_all()."""

    def test_initial(self):
        main = interpreters.get_main()
        self.assertEqual(interpreters.list_all(), [main])

    def test_after_creating(self):
        main = interpreters.get_main()
        created = [interpreters.create() for _ in range(2)]
        self.assertEqual(interpreters.list_all(), [main, *created])

    def test_after_destroying(self):
        main = interpreters.get_main()
        first, second = interpreters.create(), interpreters.create()
        interpreters.destroy(first)
        self.assertEqual(interpreters.list_all(), [main, second])


class GetCurrentTests(TestBase):
    """Tests for interpreters.get_current()."""

    def test_main(self):
        main = interpreters.get_main()
        self.assertEqual(interpreters.get_current(), main)

    def test_subinterpreter(self):
        main = interpreters.get_main()
        interp = interpreters.create()
        request = dedent("""
            import _xxsubinterpreters as _interpreters
            print(int(_interpreters.get_current()))
            """)
        current = int(_run_output(interp, request).strip())
        # Exactly two interpreters exist: main and the one created above.
        _, expected = interpreters.list_all()
        self.assertEqual(current, expected)
        self.assertNotEqual(current, main)


class GetMainTests(TestBase):
    """Tests for interpreters.get_main()."""

    def test_from_main(self):
        only, = interpreters.list_all()
        self.assertEqual(interpreters.get_main(), only)

    def test_from_subinterpreter(self):
        only, = interpreters.list_all()
        interp = interpreters.create()
        output = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            print(int(_interpreters.get_main()))
            """))
        self.assertEqual(int(output.strip()), only)


class IsRunningTests(TestBase):
    """Tests for interpreters.is_running()."""

    def test_main(self):
        main = interpreters.get_main()
        self.assertTrue(interpreters.is_running(main))

    def test_subinterpreter(self):
        interp = interpreters.create()
        self.assertFalse(interpreters.is_running(interp))

        with _running(interp):
            self.assertTrue(interpreters.is_running(interp))
        self.assertFalse(interpreters.is_running(interp))

    def test_from_subinterpreter(self):
        interp = interpreters.create()
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            if _interpreters.is_running({int(interp)}):
                print(True)
            else:
                print(False)
            """)
        output = _run_output(interp, script)
        self.assertEqual(output.strip(), 'True')

    def test_already_destroyed(self):
        doomed = interpreters.create()
        interpreters.destroy(doomed)
        self.assertRaises(RuntimeError, interpreters.is_running, doomed)

    def test_does_not_exist(self):
        self.assertRaises(RuntimeError, interpreters.is_running, 1_000_000)

    def test_bad_id(self):
        self.assertRaises(RuntimeError, interpreters.is_running, -1)


class InterpreterIDTests(TestBase):
    """Tests for the interpreters.InterpreterID type."""

    def test_with_int(self):
        interp_id = interpreters.InterpreterID(10, force=True)

        self.assertEqual(int(interp_id), 10)

    def test_coerce_id(self):
        interp_id = interpreters.InterpreterID('10', force=True)
        self.assertEqual(int(interp_id), 10)

        interp_id = interpreters.InterpreterID(10.0, force=True)
        self.assertEqual(int(interp_id), 10)

        class Int(str):
            # A str subclass that also supports int() conversion.
            def __init__(self, value):
                self._value = value
            def __int__(self):
                return self._value

        interp_id = interpreters.InterpreterID(Int(10), force=True)
        self.assertEqual(int(interp_id), 10)

    def test_bad_id(self):
        for interp_id in [-1, 'spam']:
            with self.subTest(interp_id):
                with self.assertRaises(ValueError):
                    interpreters.InterpreterID(interp_id)
        with self.assertRaises(OverflowError):
            interpreters.InterpreterID(2**64)
        with self.assertRaises(TypeError):
            interpreters.InterpreterID(object())

    def test_does_not_exist(self):
        # Use a real *interpreter* ID as the reference point; a channel ID
        # says nothing about which interpreter IDs exist.  This assumes
        # interpreter IDs are handed out in increasing order, so the ID just
        # after the newest interpreter has not been allocated yet.
        interp_id = interpreters.create()
        with self.assertRaises(RuntimeError):
            interpreters.InterpreterID(int(interp_id) + 1)  # unforced

    def test_repr(self):
        interp_id = interpreters.InterpreterID(10, force=True)
        self.assertEqual(repr(interp_id), 'InterpreterID(10)')

    def test_equality(self):
        id1 = interpreters.create()
        id2 = interpreters.InterpreterID(int(id1))
        id3 = interpreters.create()

        self.assertTrue(id1 == id1)
        self.assertTrue(id1 == id2)
        self.assertTrue(id1 == int(id1))
        self.assertFalse(id1 == id3)

        self.assertFalse(id1 != id1)
        self.assertFalse(id1 != id2)
        self.assertTrue(id1 != id3)


class CreateTests(TestBase):
    """Tests for interpreters.create()."""

    def test_in_main(self):
        new_id = interpreters.create()

        self.assertIn(new_id, interpreters.list_all())

    @unittest.skip('enable this test when working on pystate.c')
    def test_unique_id(self):
        seen = set()
        for _ in range(100):
            new_id = interpreters.create()
            interpreters.destroy(new_id)
            seen.add(new_id)

        self.assertEqual(len(seen), 100)

    def test_in_thread(self):
        lock = threading.Lock()
        created = None

        def f():
            nonlocal created
            created = interpreters.create()
            # Wait until the main thread has released the lock.
            with lock:
                pass

        t = threading.Thread(target=f)
        with lock:
            t.start()
        t.join()
        self.assertIn(created, interpreters.list_all())

    def test_in_subinterpreter(self):
        main, = interpreters.list_all()
        outer = interpreters.create()
        output = _run_output(outer, dedent("""
            import _xxsubinterpreters as _interpreters
            id = _interpreters.create()
            print(int(id))
            """))
        inner = int(output.strip())

        self.assertEqual(set(interpreters.list_all()), {main, outer, inner})

    def test_in_threaded_subinterpreter(self):
        main, = interpreters.list_all()
        outer = interpreters.create()
        inner = None

        def f():
            nonlocal inner
            output = _run_output(outer, dedent("""
                import _xxsubinterpreters as _interpreters
                id = _interpreters.create()
                print(int(id))
                """))
            inner = int(output.strip())

        t = threading.Thread(target=f)
        t.start()
        t.join()

        self.assertEqual(set(interpreters.list_all()), {main, outer, inner})

    def test_after_destroy_all(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters, then destroy every one of them.
        created = [interpreters.create() for _ in range(3)]
        for interp_id in created:
            interpreters.destroy(interp_id)
        # Finally, create another.
        latest = interpreters.create()
        self.assertEqual(set(interpreters.list_all()), before | {latest})

    def test_after_destroy_some(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters and destroy the first and the last.
        id1, id2, id3 = (interpreters.create() for _ in range(3))
        for doomed in (id1, id3):
            interpreters.destroy(doomed)
        # Finally, create another.
        latest = interpreters.create()
        self.assertEqual(set(interpreters.list_all()), before | {latest, id2})


class DestroyTests(TestBase):
    """Tests for interpreters.destroy()."""

    @staticmethod
    def _call_in_thread(func):
        """Call *func* in a new thread, then re-raise anything it raised.

        An exception raised inside a plain ``threading.Thread`` (including a
        failed assertion) is only reported and never propagated.  A test that
        ran the check there would still pass.  Re-raising it in the calling
        thread makes the failure count.
        """
        caught = []

        def target():
            try:
                func()
            except BaseException as exc:
                caught.append(exc)

        t = threading.Thread(target=target)
        t.start()
        t.join()
        if caught:
            raise caught[0]

    def test_one(self):
        id1 = interpreters.create()
        id2 = interpreters.create()
        id3 = interpreters.create()
        self.assertIn(id2, interpreters.list_all())
        interpreters.destroy(id2)
        self.assertNotIn(id2, interpreters.list_all())
        self.assertIn(id1, interpreters.list_all())
        self.assertIn(id3, interpreters.list_all())

    def test_all(self):
        before = set(interpreters.list_all())
        ids = set()
        for _ in range(3):
            interp_id = interpreters.create()
            ids.add(interp_id)
        self.assertEqual(set(interpreters.list_all()), before | ids)
        for interp_id in ids:
            interpreters.destroy(interp_id)
        self.assertEqual(set(interpreters.list_all()), before)

    def test_main(self):
        main, = interpreters.list_all()
        with self.assertRaises(RuntimeError):
            interpreters.destroy(main)

        def f():
            with self.assertRaises(RuntimeError):
                interpreters.destroy(main)

        # Propagate a failed assertion from the thread to this test.
        self._call_in_thread(f)

    def test_already_destroyed(self):
        interp_id = interpreters.create()
        interpreters.destroy(interp_id)
        with self.assertRaises(RuntimeError):
            interpreters.destroy(interp_id)

    def test_does_not_exist(self):
        with self.assertRaises(RuntimeError):
            interpreters.destroy(1_000_000)

    def test_bad_id(self):
        with self.assertRaises(RuntimeError):
            interpreters.destroy(-1)

    def test_from_current(self):
        # An interpreter cannot destroy itself; it must still exist afterwards.
        main, = interpreters.list_all()
        interp_id = interpreters.create()
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            try:
                _interpreters.destroy({int(interp_id)})
            except RuntimeError:
                pass
            """)

        interpreters.run_string(interp_id, script)
        self.assertEqual(set(interpreters.list_all()), {main, interp_id})

    def test_from_sibling(self):
        main, = interpreters.list_all()
        id1 = interpreters.create()
        id2 = interpreters.create()
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.destroy({int(id2)})
            """)
        interpreters.run_string(id1, script)

        self.assertEqual(set(interpreters.list_all()), {main, id1})

    def test_from_other_thread(self):
        interp_id = interpreters.create()
        def f():
            interpreters.destroy(interp_id)

        # An error from destroy() in the thread must fail this test, and the
        # interpreter must really be gone afterwards.
        self._call_in_thread(f)
        self.assertNotIn(interp_id, interpreters.list_all())

    def test_still_running(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        with _running(interp):
            with self.assertRaises(RuntimeError):
                interpreters.destroy(interp)
            self.assertTrue(interpreters.is_running(interp))


class RunStringTests(TestBase):
    """Tests for interpreters.run_string(), run in a fresh subinterpreter."""

    # NOTE(review): SCRIPT, FILENAME and the ``fs`` property below are not
    # referenced by any test visible in this class.  TODO confirm they are
    # still needed.
    SCRIPT = dedent("""
        with open('{}', 'w') as out:
            out.write('{}')
        """)
    FILENAME = 'spam'

    def setUp(self):
        """Create a new subinterpreter for this test and store it in self.id."""
        super().setUp()
        self.id = interpreters.create()
        self._fs = None

    def tearDown(self):
        """Close the ``fs`` fixture if one was created, then run TestBase cleanup."""
        if self._fs is not None:
            self._fs.close()
        super().tearDown()

    @property
    def fs(self):
        """Lazily create and cache an ``FSFixture`` for this test."""
        # NOTE(review): FSFixture is not defined in the visible part of this
        # file; accessing this property would raise NameError unless it is
        # defined elsewhere.  TODO confirm or remove.
        if self._fs is None:
            self._fs = FSFixture(self)
        return self._fs

    def test_success(self):
        """A simple script runs and its stdout output can be captured."""
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            interpreters.run_string(self.id, script)
            out = file.read()

        self.assertEqual(out, 'it worked!')

    def test_in_thread(self):
        """run_string() works when called from a non-main thread."""
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            def f():
                interpreters.run_string(self.id, script)

            t = threading.Thread(target=f)
            t.start()
            t.join()
            out = file.read()

        self.assertEqual(out, 'it worked!')

    def test_create_thread(self):
        """A script may start and join its own thread inside the subinterpreter."""
        script, file = _captured_script("""
            import threading
            def f():
                print('it worked!', end='')

            t = threading.Thread(target=f)
            t.start()
            t.join()
            """)
        with file:
            interpreters.run_string(self.id, script)
            out = file.read()

        self.assertEqual(out, 'it worked!')

    @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
    def test_fork(self):
        """os.fork() inside a subinterpreter; the child writes to a temp file."""
        import tempfile
        with tempfile.NamedTemporaryFile('w+') as file:
            file.write('')
            file.flush()

            expected = 'spam spam spam spam spam'
            # The child writes *expected* to the temp file and exits with
            # os._exit(1).  The parent polls waitpid(WNOHANG) up to 10 times,
            # sleeping 0.1s between polls, and asserts the child was reaped.
            script = dedent(f"""
                # (inspired by Lib/test/test_fork.py)
                import os
                pid = os.fork()
                if pid == 0:  # child
                    with open('{file.name}', 'w') as out:
                        out.write('{expected}')
                    # Kill the unittest runner in the child process.
                    os._exit(1)
                else:
                    SHORT_SLEEP = 0.1
                    import time
                    for _ in range(10):
                        spid, status = os.waitpid(pid, os.WNOHANG)
                        if spid == pid:
                            break
                        time.sleep(SHORT_SLEEP)
                    assert(spid == pid)
                """)
            interpreters.run_string(self.id, script)

            file.seek(0)
            content = file.read()
            self.assertEqual(content, expected)

    def test_already_running(self):
        """run_string() raises RuntimeError while the interpreter is busy."""
        with _running(self.id):
            with self.assertRaises(RuntimeError):
                interpreters.run_string(self.id, 'print("spam")')

    def test_does_not_exist(self):
        """run_string() raises RuntimeError for an unused interpreter ID."""
        # Find the lowest ID not currently in use.
        id = 0
        while id in interpreters.list_all():
            id += 1
        with self.assertRaises(RuntimeError):
            interpreters.run_string(id, 'print("spam")')

    def test_error_id(self):
        """A negative interpreter ID raises RuntimeError."""
        with self.assertRaises(RuntimeError):
            interpreters.run_string(-1, 'print("spam")')

    def test_bad_id(self):
        """A non-integer interpreter ID raises TypeError."""
        with self.assertRaises(TypeError):
            interpreters.run_string('spam', 'print("spam")')

    def test_bad_script(self):
        """A non-string script raises TypeError."""
        with self.assertRaises(TypeError):
            interpreters.run_string(self.id, 10)

    def test_bytes_for_script(self):
        """A bytes script is rejected with TypeError; only str is accepted."""
        with self.assertRaises(TypeError):
            interpreters.run_string(self.id, b'print("spam")')

    @contextlib.contextmanager
    def assert_run_failed(self, exctype, msg=None):
        """Assert that the body raises interpreters.RunFailedError.

        The error message must be ``"{exctype}: {msg}"``.  When *msg* is
        None, only the text before the first ':' is compared, against
        ``str(exctype)``.
        """
        with self.assertRaises(interpreters.RunFailedError) as caught:
            yield
        if msg is None:
            self.assertEqual(str(caught.exception).split(':')[0],
                             str(exctype))
        else:
            self.assertEqual(str(caught.exception),
                             "{}: {}".format(exctype, msg))

    def test_invalid_syntax(self):
        """A SyntaxError in the script surfaces as RunFailedError."""
        with self.assert_run_failed(SyntaxError):
            # missing close paren
            interpreters.run_string(self.id, 'print("spam"')

    def test_failure(self):
        """An uncaught exception in the script surfaces as RunFailedError."""
        with self.assert_run_failed(Exception, 'spam'):
            interpreters.run_string(self.id, 'raise Exception("spam")')

    def test_SystemExit(self):
        """Raising SystemExit surfaces as RunFailedError."""
        with self.assert_run_failed(SystemExit, '42'):
            interpreters.run_string(self.id, 'raise SystemExit(42)')

    def test_sys_exit(self):
        """sys.exit(), with or without a code, surfaces as RunFailedError."""
        with self.assert_run_failed(SystemExit):
            interpreters.run_string(self.id, dedent("""
                import sys
                sys.exit()
                """))

        with self.assert_run_failed(SystemExit, '42'):
            interpreters.run_string(self.id, dedent("""
                import sys
                sys.exit(42)
                """))

    def test_with_shared(self):
        """Values in *shared* become variables in the script's namespace."""
        r, w = os.pipe()

        shared = {
                'spam': b'ham',
                'eggs': b'-1',
                'cheddar': None,
                }
        # The script rebinds spam, converts eggs with int(), and pickles its
        # whole namespace (minus __builtins__) back through the pipe.
        script = dedent(f"""
            eggs = int(eggs)
            spam = 42
            result = spam + eggs

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script, shared)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['spam'], 42)
        self.assertEqual(ns['eggs'], -1)
        self.assertEqual(ns['result'], 41)
        self.assertIsNone(ns['cheddar'])

    def test_shared_overwrites(self):
        """*shared* values replace existing names and persist across runs."""
        interpreters.run_string(self.id, dedent("""
            spam = 'eggs'
            ns1 = dict(vars())
            del ns1['__builtins__']
            """))

        shared = {'spam': b'ham'}
        script = dedent(f"""
            ns2 = dict(vars())
            del ns2['__builtins__']
        """)
        interpreters.run_string(self.id, script, shared)

        # A third run, with no shared values, dumps the accumulated namespace.
        r, w = os.pipe()
        script = dedent(f"""
            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['ns1']['spam'], 'eggs')
        self.assertEqual(ns['ns2']['spam'], b'ham')
        self.assertEqual(ns['spam'], b'ham')

    def test_shared_overwrites_default_vars(self):
        """*shared* can replace default module variables such as __name__."""
        r, w = os.pipe()

        shared = {'__name__': b'not __main__'}
        script = dedent(f"""
            spam = 42

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script, shared)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['__name__'], b'not __main__')

    def test_main_reused(self):
        """Names bound by one run_string() call are visible to the next."""
        r, w = os.pipe()
        interpreters.run_string(self.id, dedent(f"""
            spam = True

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            del ns, pickle, chan
            """))
        with open(r, 'rb') as chan:
            ns1 = pickle.load(chan)

        r, w = os.pipe()
        interpreters.run_string(self.id, dedent(f"""
            eggs = False

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """))
        with open(r, 'rb') as chan:
            ns2 = pickle.load(chan)

        self.assertIn('spam', ns1)
        self.assertNotIn('eggs', ns1)
        self.assertIn('eggs', ns2)
        self.assertIn('spam', ns2)

    def test_execution_namespace_is_main(self):
        """Scripts run in a namespace that looks like a fresh __main__ module."""
        r, w = os.pipe()

        # __builtins__ is stringified because the module object itself
        # cannot be pickled back through the pipe.
        script = dedent(f"""
            spam = 42

            ns = dict(vars())
            ns['__builtins__'] = str(ns['__builtins__'])
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        ns.pop('__builtins__')
        ns.pop('__loader__')
        self.assertEqual(ns, {
            '__name__': '__main__',
            '__annotations__': {},
            '__doc__': None,
            '__package__': None,
            '__spec__': None,
            'spam': 42,
            })

    # XXX Fix this test!
    @unittest.skip('blocking forever')
    def test_still_running_at_exit(self):
        """A child Python process exits cleanly while a subinterpreter sleeps."""
        # The spawned script starts a thread that runs a very long sleep in
        # a subinterpreter, then returns without joining the thread.
        script = dedent(f"""
        from textwrap import dedent
        import threading
        import _xxsubinterpreters as _interpreters
        id = _interpreters.create()
        def f():
            _interpreters.run_string(id, dedent('''
                import time
                # Give plenty of time for the main interpreter to finish.
                time.sleep(1_000_000)
                '''))

        t = threading.Thread(target=f)
        t.start()
        """)
        with support.temp_dir() as dirname:
            filename = script_helper.make_script(dirname, 'interp', script)
            with script_helper.spawn_python(filename) as proc:
                retcode = proc.wait()

        self.assertEqual(retcode, 0)


class ChannelIDTests(TestBase):
    """Tests for the channel ID objects returned by interpreters._channel_id()."""

    def test_default_kwargs(self):
        cid = interpreters._channel_id(10, force=True)

        self.assertEqual(int(cid), 10)
        self.assertEqual(cid.end, 'both')

    def test_with_kwargs(self):
        # (keyword arguments, expected value of .end)
        cases = [
            ({'send': True}, 'send'),
            ({'send': True, 'recv': False}, 'send'),
            ({'recv': True}, 'recv'),
            ({'recv': True, 'send': False}, 'recv'),
            ({'send': True, 'recv': True}, 'both'),
        ]
        for kwargs, expected_end in cases:
            cid = interpreters._channel_id(10, force=True, **kwargs)
            self.assertEqual(cid.end, expected_end)

    def test_coerce_id(self):
        class Int(str):
            # A str subclass that also supports int() conversion.
            def __init__(self, value):
                self._value = value
            def __int__(self):
                return self._value

        for raw in ('10', 10.0, Int(10)):
            cid = interpreters._channel_id(raw, force=True)
            self.assertEqual(int(cid), 10)

    def test_bad_id(self):
        for bad in [-1, 'spam']:
            with self.subTest(bad):
                self.assertRaises(ValueError, interpreters._channel_id, bad)
        self.assertRaises(OverflowError, interpreters._channel_id, 2**64)
        self.assertRaises(TypeError, interpreters._channel_id, object())

    def test_bad_kwargs(self):
        # At least one end must be selected.
        with self.assertRaises(ValueError):
            interpreters._channel_id(10, send=False, recv=False)

    def test_does_not_exist(self):
        missing = int(interpreters.channel_create()) + 1
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters._channel_id(missing)  # unforced

    def test_repr(self):
        # (keyword arguments, expected repr)
        cases = [
            ({}, 'ChannelID(10)'),
            ({'send': True}, 'ChannelID(10, send=True)'),
            ({'recv': True}, 'ChannelID(10, recv=True)'),
            ({'send': True, 'recv': True}, 'ChannelID(10)'),
        ]
        for kwargs, expected_repr in cases:
            cid = interpreters._channel_id(10, force=True, **kwargs)
            self.assertEqual(repr(cid), expected_repr)

    def test_equality(self):
        first = interpreters.channel_create()
        same = interpreters._channel_id(int(first))
        other = interpreters.channel_create()

        # Exercise __eq__ explicitly.
        self.assertTrue(first == first)
        self.assertTrue(first == same)
        self.assertTrue(first == int(first))
        self.assertFalse(first == other)

        # Exercise __ne__ explicitly.
        self.assertFalse(first != first)
        self.assertFalse(first != same)
        self.assertTrue(first != other)


class ChannelTests(TestBase):
    """Tests for the channel API of _xxsubinterpreters.

    Covers channel ID allocation, dropping an interpreter's use of a channel
    (channel_drop_interpreter), closing a channel (channel_close), and
    sending/receiving objects across interpreters and threads.
    """

    def test_sequential_ids(self):
        """Newly created channels get consecutive IDs and are listed."""
        before = interpreters.channel_list_all()
        id1 = interpreters.channel_create()
        id2 = interpreters.channel_create()
        id3 = interpreters.channel_create()
        after = interpreters.channel_list_all()

        self.assertEqual(id2, int(id1) + 1)
        self.assertEqual(id3, int(id2) + 1)
        self.assertEqual(set(after) - set(before), {id1, id2, id3})

    def test_ids_global(self):
        """Channel IDs keep increasing across interpreters, not per-interp."""
        id1 = interpreters.create()
        out = _run_output(id1, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            print(int(cid))
            """))
        cid1 = int(out.strip())

        id2 = interpreters.create()
        out = _run_output(id2, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            print(int(cid))
            """))
        cid2 = int(out.strip())

        # cid1 is already a plain int parsed from the subinterpreter output.
        self.assertEqual(cid2, cid1 + 1)

    # ---------- channel_drop_interpreter() ----------

    def test_drop_single_user(self):
        """After the sole user drops both ends, send and recv both fail."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_multiple_users(self):
        """Each of several interpreters can drop its own use of a channel."""
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        id2 = interpreters.create()
        interpreters.run_string(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({int(cid)}, b'spam')
            """))
        out = _run_output(id2, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({int(cid)})
            _interpreters.channel_drop_interpreter({int(cid)})
            print(repr(obj))
            """))
        # No import here: this reuses `_interpreters` from the earlier
        # run_string() call in the same interpreter (id1).
        interpreters.run_string(id1, dedent(f"""
            _interpreters.channel_drop_interpreter({int(cid)})
            """))

        self.assertEqual(out.strip(), "b'spam'")

    def test_drop_no_kwargs(self):
        """Dropping with no send/recv arguments behaves like dropping both."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_multiple_times(self):
        """A second drop of an already-dropped channel raises."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_drop_interpreter(cid, send=True, recv=True)

    def test_drop_with_unused_items(self):
        """Items still queued when the channel is dropped cannot be received."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_never_used(self):
        """Dropping a channel nobody used leaves it closed."""
        cid = interpreters.channel_create()
        interpreters.channel_drop_interpreter(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'spam')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_by_unassociated_interp(self):
        """A drop from an interpreter that never used the channel is harmless."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interp = interpreters.create()
        interpreters.run_string(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_drop_interpreter({int(cid)})
            """))
        # The main interpreter can still receive what it queued earlier.
        obj = interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        self.assertEqual(obj, b'spam')

    def test_drop_close_if_unassociated(self):
        """When the only interpreter that used the channel drops, it closes."""
        cid = interpreters.channel_create()
        interp = interpreters.create()
        interpreters.run_string(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({int(cid)}, b'spam')
            _interpreters.channel_drop_interpreter({int(cid)})
            """))

        # The queued b'spam' is not retrievable by the (creating) main interp.
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_partially(self):
        """Dropping only the send end still allows receiving queued items."""
        # XXX Is partial close too weird/confusing?
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, None)
        interpreters.channel_recv(cid)
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_drop_interpreter(cid, send=True)
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')

    def test_drop_used_multiple_times_by_single_user(self):
        """Dropping closes the channel even with items left unreceived."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    # ---------- channel_close() ----------

    def test_close_single_user(self):
        """After close, both send and recv fail."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_multiple_users(self):
        """Closing from main shuts the channel for every interpreter."""
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        id2 = interpreters.create()
        interpreters.run_string(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({int(cid)}, b'spam')
            """))
        interpreters.run_string(id2, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_recv({int(cid)})
            """))
        interpreters.channel_close(cid)

        # Errors inside run_string() surface as RunFailedError; the message
        # names the original exception. `_interpreters` was imported by the
        # earlier run_string() call in each interpreter.
        with self.assertRaises(interpreters.RunFailedError) as cm:
            interpreters.run_string(id1, dedent(f"""
                _interpreters.channel_send({int(cid)}, b'spam')
                """))
        self.assertIn('ChannelClosedError', str(cm.exception))
        # id2 only ever received, so verify the receiving side is closed too.
        with self.assertRaises(interpreters.RunFailedError) as cm:
            interpreters.run_string(id2, dedent(f"""
                _interpreters.channel_recv({int(cid)})
                """))
        self.assertIn('ChannelClosedError', str(cm.exception))

    def test_close_multiple_times(self):
        """Closing an already-closed channel raises."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_close(cid)

    def test_close_with_unused_items(self):
        """Items still queued at close time cannot be received."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_never_used(self):
        """A channel closed before any use rejects send and recv."""
        cid = interpreters.channel_create()
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'spam')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_by_unassociated_interp(self):
        """Any interpreter may close a channel, even one that never used it."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interp = interpreters.create()
        interpreters.run_string(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_close({int(cid)})
            """))
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_close(cid)

    def test_close_used_multiple_times_by_single_user(self):
        """Close discards remaining items and blocks further use."""
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    # ---------- channel_send() / channel_recv() ----------

    def test_send_recv_main(self):
        """Round-trip in main yields an equal but distinct object."""
        cid = interpreters.channel_create()
        orig = b'spam'
        interpreters.channel_send(cid, orig)
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, orig)
        self.assertIsNot(obj, orig)

    def test_send_recv_same_interpreter(self):
        """Round-trip entirely inside one subinterpreter."""
        id1 = interpreters.create()
        # The asserts run in the subinterpreter; a failure makes run_string()
        # raise RunFailedError, which fails this test.
        interpreters.run_string(id1, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            orig = b'spam'
            _interpreters.channel_send(cid, orig)
            obj = _interpreters.channel_recv(cid)
            assert obj is not orig
            assert obj == orig
            """))

    def test_send_recv_different_interpreters(self):
        """An object sent from a subinterpreter is received in main."""
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        interpreters.run_string(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({int(cid)}, b'spam')
            """))
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')

    def test_send_recv_different_threads(self):
        """A worker thread echoes back what main sends."""
        cid = interpreters.channel_create()

        def f():
            # channel_recv() raises ChannelEmptyError instead of blocking,
            # so poll until main's item arrives.
            while True:
                try:
                    obj = interpreters.channel_recv(cid)
                    break
                except interpreters.ChannelEmptyError:
                    time.sleep(0.1)
            interpreters.channel_send(cid, obj)
        t = threading.Thread(target=f)
        t.start()

        interpreters.channel_send(cid, b'spam')
        t.join()
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')

    def test_send_recv_different_interpreters_and_threads(self):
        """A subinterpreter running in a worker thread replies to main."""
        cid = interpreters.channel_create()
        id1 = interpreters.create()

        def f():
            # Poll (recv does not block) for main's b'spam', then reply.
            # If this script fails, nothing is sent back and the final
            # channel_recv() in main raises ChannelEmptyError.
            interpreters.run_string(id1, dedent(f"""
                import time
                import _xxsubinterpreters as _interpreters
                while True:
                    try:
                        obj = _interpreters.channel_recv({int(cid)})
                        break
                    except _interpreters.ChannelEmptyError:
                        time.sleep(0.1)
                assert obj == b'spam'
                _interpreters.channel_send({int(cid)}, b'eggs')
                """))
        t = threading.Thread(target=f)
        t.start()

        interpreters.channel_send(cid, b'spam')
        t.join()
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'eggs')

    def test_send_not_found(self):
        """Sending on an unknown channel ID raises ChannelNotFoundError."""
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_send(10, b'spam')

    def test_recv_not_found(self):
        """Receiving on an unknown channel ID raises ChannelNotFoundError."""
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_recv(10)

    def test_recv_empty(self):
        """Receiving from an empty channel raises rather than blocks."""
        cid = interpreters.channel_create()
        with self.assertRaises(interpreters.ChannelEmptyError):
            interpreters.channel_recv(cid)

    def test_run_string_arg(self):
        """A send-end channel ID can be shared into run_string()."""
        cid = interpreters.channel_create()
        interp = interpreters.create()

        out = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            print(cid.end)
            _interpreters.channel_send(cid, b'spam')
            """),
            dict(cid=cid.send))
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')
        self.assertEqual(out.strip(), 'send')


if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
back to top