commit 1425e8e33c43642232a2bc142b9200452cacf1c9 Author: JimZhang Date: Wed Mar 6 14:12:56 2024 +0800 feat: init project diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e05ce42 --- /dev/null +++ b/.gitignore @@ -0,0 +1,163 @@ +# Byte-compiled / optimized / DLL files +.idea +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..e4a8d16 --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +# wrpc diff --git a/pdm.lock b/pdm.lock new file mode 100644 index 0000000..966b850 --- /dev/null +++ b/pdm.lock @@ -0,0 +1,95 @@ +# This file is @generated by PDM. +# It is not intended for manual editing. + +[metadata] +groups = ["default"] +strategy = ["cross_platform"] +lock_version = "4.4" +content_hash = "sha256:0bb15134896c44e2b834128890e9e548258516fbc4ecc657d1355daccccfdff4" + +[[package]] +name = "colorama" +version = "0.4.6" +requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +summary = "Cross-platform colored terminal text." +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + +[[package]] +name = "loguru" +version = "0.7.2" +requires_python = ">=3.5" +summary = "Python logging made (stupidly) simple" +dependencies = [ + "colorama>=0.3.4; sys_platform == \"win32\"", + "win32-setctime>=1.0.0; sys_platform == \"win32\"", +] +files = [ + {file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"}, + {file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"}, +] + +[[package]] +name = "msgpack" +version = "1.0.7" +requires_python = ">=3.8" +summary = "MessagePack serializer" +files = [ + {file = "msgpack-1.0.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:04ad6069c86e531682f9e1e71b71c1c3937d6014a7c3e9edd2aa81ad58842862"}, + {file = "msgpack-1.0.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cca1b62fe70d761a282496b96a5e51c44c213e410a964bdffe0928e611368329"}, + {file = "msgpack-1.0.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e50ebce52f41370707f1e21a59514e3375e3edd6e1832f5e5235237db933c98b"}, + {file = "msgpack-1.0.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7b4f35de6a304b5533c238bee86b670b75b03d31b7797929caa7a624b5dda6"}, + {file = "msgpack-1.0.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28efb066cde83c479dfe5a48141a53bc7e5f13f785b92ddde336c716663039ee"}, + {file = "msgpack-1.0.7-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4cb14ce54d9b857be9591ac364cb08dc2d6a5c4318c1182cb1d02274029d590d"}, + {file = "msgpack-1.0.7-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b573a43ef7c368ba4ea06050a957c2a7550f729c31f11dd616d2ac4aba99888d"}, + {file = "msgpack-1.0.7-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ccf9a39706b604d884d2cb1e27fe973bc55f2890c52f38df742bc1d79ab9f5e1"}, + {file = "msgpack-1.0.7-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:cb70766519500281815dfd7a87d3a178acf7ce95390544b8c90587d76b227681"}, + {file = "msgpack-1.0.7-cp310-cp310-win32.whl", hash = "sha256:b610ff0f24e9f11c9ae653c67ff8cc03c075131401b3e5ef4b82570d1728f8a9"}, + {file = "msgpack-1.0.7-cp310-cp310-win_amd64.whl", hash = "sha256:a40821a89dc373d6427e2b44b572efc36a2778d3f543299e2f24eb1a5de65415"}, + {file = "msgpack-1.0.7-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:576eb384292b139821c41995523654ad82d1916da6a60cff129c715a6223ea84"}, + {file = "msgpack-1.0.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:730076207cb816138cf1af7f7237b208340a2c5e749707457d70705715c93b93"}, + {file = "msgpack-1.0.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:85765fdf4b27eb5086f05ac0491090fc76f4f2b28e09d9350c31aac25a5aaff8"}, + {file = "msgpack-1.0.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3476fae43db72bd11f29a5147ae2f3cb22e2f1a91d575ef130d2bf49afd21c46"}, + {file = "msgpack-1.0.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d4c80667de2e36970ebf74f42d1088cc9ee7ef5f4e8c35eee1b40eafd33ca5b"}, + {file = "msgpack-1.0.7-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b0bf0effb196ed76b7ad883848143427a73c355ae8e569fa538365064188b8e"}, + {file = "msgpack-1.0.7-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f9a7c509542db4eceed3dcf21ee5267ab565a83555c9b88a8109dcecc4709002"}, + {file = "msgpack-1.0.7-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:84b0daf226913133f899ea9b30618722d45feffa67e4fe867b0b5ae83a34060c"}, + {file = "msgpack-1.0.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ec79ff6159dffcc30853b2ad612ed572af86c92b5168aa3fc01a67b0fa40665e"}, + {file = "msgpack-1.0.7-cp311-cp311-win32.whl", hash = "sha256:3e7bf4442b310ff154b7bb9d81eb2c016b7d597e364f97d72b1acc3817a0fdc1"}, + {file = "msgpack-1.0.7-cp311-cp311-win_amd64.whl", hash = "sha256:3f0c8c6dfa6605ab8ff0611995ee30d4f9fcff89966cf562733b4008a3d60d82"}, + {file = "msgpack-1.0.7-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f0936e08e0003f66bfd97e74ee530427707297b0d0361247e9b4f59ab78ddc8b"}, + {file = "msgpack-1.0.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:98bbd754a422a0b123c66a4c341de0474cad4a5c10c164ceed6ea090f3563db4"}, + {file = "msgpack-1.0.7-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b291f0ee7961a597cbbcc77709374087fa2a9afe7bdb6a40dbbd9b127e79afee"}, + {file = "msgpack-1.0.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ebbbba226f0a108a7366bf4b59bf0f30a12fd5e75100c630267d94d7f0ad20e5"}, + {file = "msgpack-1.0.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e2d69948e4132813b8d1131f29f9101bc2c915f26089a6d632001a5c1349672"}, + {file = "msgpack-1.0.7-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bdf38ba2d393c7911ae989c3bbba510ebbcdf4ecbdbfec36272abe350c454075"}, + {file = "msgpack-1.0.7-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:993584fc821c58d5993521bfdcd31a4adf025c7d745bbd4d12ccfecf695af5ba"}, + {file = "msgpack-1.0.7-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:52700dc63a4676669b341ba33520f4d6e43d3ca58d422e22ba66d1736b0a6e4c"}, + {file = "msgpack-1.0.7-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e45ae4927759289c30ccba8d9fdce62bb414977ba158286b5ddaf8df2cddb5c5"}, + {file = "msgpack-1.0.7-cp312-cp312-win32.whl", hash = "sha256:27dcd6f46a21c18fa5e5deed92a43d4554e3df8d8ca5a47bf0615d6a5f39dbc9"}, + {file = "msgpack-1.0.7-cp312-cp312-win_amd64.whl", hash = "sha256:7687e22a31e976a0e7fc99c2f4d11ca45eff652a81eb8c8085e9609298916dcf"}, + {file = "msgpack-1.0.7-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bfef2bb6ef068827bbd021017a107194956918ab43ce4d6dc945ffa13efbc25f"}, + {file = "msgpack-1.0.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:484ae3240666ad34cfa31eea7b8c6cd2f1fdaae21d73ce2974211df099a95d81"}, + {file = "msgpack-1.0.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3967e4ad1aa9da62fd53e346ed17d7b2e922cba5ab93bdd46febcac39be636fc"}, + {file = "msgpack-1.0.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8dd178c4c80706546702c59529ffc005681bd6dc2ea234c450661b205445a34d"}, + {file = "msgpack-1.0.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6ffbc252eb0d229aeb2f9ad051200668fc3a9aaa8994e49f0cb2ffe2b7867e7"}, + {file = "msgpack-1.0.7-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:822ea70dc4018c7e6223f13affd1c5c30c0f5c12ac1f96cd8e9949acddb48a61"}, + {file = "msgpack-1.0.7-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:384d779f0d6f1b110eae74cb0659d9aa6ff35aaf547b3955abf2ab4c901c4819"}, + {file = "msgpack-1.0.7-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f64e376cd20d3f030190e8c32e1c64582eba56ac6dc7d5b0b49a9d44021b52fd"}, + {file = "msgpack-1.0.7-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5ed82f5a7af3697b1c4786053736f24a0efd0a1b8a130d4c7bfee4b9ded0f08f"}, + {file = "msgpack-1.0.7-cp39-cp39-win32.whl", hash = "sha256:f26a07a6e877c76a88e3cecac8531908d980d3d5067ff69213653649ec0f60ad"}, + {file = "msgpack-1.0.7-cp39-cp39-win_amd64.whl", hash = "sha256:1dc93e8e4653bdb5910aed79f11e165c85732067614f180f70534f056da97db3"}, + {file = "msgpack-1.0.7.tar.gz", hash = "sha256:572efc93db7a4d27e404501975ca6d2d9775705c2d922390d878fcf768d92c87"}, +] + +[[package]] +name = "win32-setctime" +version = "1.1.0" +requires_python = ">=3.5" +summary = "A small Python utility to set file creation time on Windows" +files = [ + {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"}, + {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"}, +] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..90a7c85 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,18 @@ +[project] +name = "wrpc" +version = "0.1.0" +description = "winhc rpc tools" +authors = [ + {name = "JimZhang", email = "zzl22100048@gmail.com"}, +] +dependencies = [ + "msgpack>=1.0.7", + "loguru>=0.7.2", +] +requires-python = ">=3.9" +readme = "README.md" +license = {text = "MIT"} + +[build-system] +requires = ["pdm-backend"] +build-backend = "pdm.backend" diff --git a/src/wrpc/__init__.py b/src/wrpc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/wrpc/asyncio/__init__.py b/src/wrpc/asyncio/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/wrpc/asyncio/client.py b/src/wrpc/asyncio/client.py new file mode 100644 index 0000000..8529ada --- /dev/null +++ b/src/wrpc/asyncio/client.py @@ -0,0 +1,247 @@ +import asyncio +from loguru import logger +from ..base import GREETINGS, PROTOCOL_VERSION, OP_NOP, OP_PUBLISH, OP_SUBSCRIBE, OP_UNSUBSCRIBE, OP_MESSAGE, \ + OP_ACK, RESPONSE_OK, PING_FRAME + + +async def on_frame_default(frame): + pass + + +class Client: + + def __init__(self, path, name): + self.path = path + self.writer = None + self.reader_fut = None + self.pinger_fut = None + self.buf_size = 8192 + self.name = name + self.frame_id = 0 + self.ping_interval = 1 + self.on_frame = on_frame_default + self.socket_lock = asyncio.Lock() + self.mgmt_lock = asyncio.Lock() + self.connected = False + self.frames = {} + self.timeout = 1 + + async def connect(self): + async with self.mgmt_lock: + if self.path.endswith('.sock') or self.path.endswith( + '.socket') or self.path.endswith( + '.ipc') or self.path.startswith('/'): + reader, writer = await asyncio.open_unix_connection( + self.path, limit=self.buf_size) + else: + host, port = self.path.rsplit(':', maxsplit=2) + reader, writer = await asyncio.open_connection( + host, int(port), limit=self.buf_size) + buf = await asyncio.wait_for(self.readexactly(reader, 3), + timeout=self.timeout) + if buf[0] != GREETINGS: + raise RuntimeError('Unsupported protocol') + if int.from_bytes(buf[1:3], 'little') != PROTOCOL_VERSION: + raise RuntimeError('Unsupported protocol version') + writer.write(buf) + await asyncio.wait_for(writer.drain(), timeout=self.timeout) + buf = await asyncio.wait_for(self.readexactly(reader, 1), + timeout=self.timeout) + if buf[0] != RESPONSE_OK: + raise RuntimeError(f'Server response: {hex(buf[0])}') + name = self.name.encode() + writer.write(len(name).to_bytes(2, 'little') + name) + await asyncio.wait_for(writer.drain(), timeout=self.timeout) + buf = await asyncio.wait_for(self.readexactly(reader, 1), + timeout=self.timeout) + if buf[0] != RESPONSE_OK: + raise RuntimeError(f'Server response: {hex(buf[0])}') + self.writer = writer + self.connected = True + self.pinger_fut = asyncio.ensure_future(self._t_pinger()) + self.reader_fut = asyncio.ensure_future(self._t_reader(reader)) + + async def handle_daemon_exception(self, e): + async with self.mgmt_lock: + if self.connected: + await self._disconnect() + import traceback + logger.error(traceback.format_exc()) + + async def _t_pinger(self): + try: + while True: + await asyncio.sleep(self.ping_interval) + async with self.socket_lock: + self.writer.write(PING_FRAME) + await asyncio.wait_for(self.writer.drain(), + timeout=self.timeout) + except Exception as e: + asyncio.ensure_future(self.handle_daemon_exception(e)) + + async def _t_reader(self, reader): + try: + while True: + buf = await self.readexactly(reader, 6) + if buf[0] == OP_NOP: + continue + elif buf[0] == OP_ACK: + op_id = int.from_bytes(buf[1:5], 'little') + try: + o = self.frames.pop(op_id) + o.result = buf[5] + o.completed.set() + except KeyError: + logger.warning(f'orphaned BUS/RT frame ack {op_id}') + else: + + async def read_frame(tp, data_len): + frame = Frame() + frame.type = tp + sender = await reader.readuntil(b'\x00') + data_len -= len(sender) + frame.sender = sender[:-1].decode() + frame.primary_sender = frame.sender.split('%%', 1)[0] + if buf[0] == OP_PUBLISH: + topic = await reader.readuntil(b'\x00') + data_len -= len(topic) + frame.topic = topic[:-1].decode() + else: + frame.topic = None + data = b'' + while len(data) < data_len: + buf_size = data_len - len(data) + prev_len = len(data) + try: + data += await reader.readexactly( + buf_size if buf_size < self.buf_size else + self.buf_size) + except asyncio.IncompleteReadError: + pass + frame.payload = data + return frame + + try: + data_len = int.from_bytes(buf[1:5], 'little') + frame = await read_frame(buf[0], data_len) + except Exception as e: + logger.error(f'Invalid frame from the server: {e}') + raise + asyncio.ensure_future(self.on_frame(frame)) + except Exception as e: + asyncio.ensure_future(self.handle_daemon_exception(e)) + + async def readexactly(self, reader, data_len): + data = b'' + while len(data) < data_len: + buf_size = data_len - len(data) + try: + chunk = await reader.readexactly( + buf_size if buf_size < self.buf_size else self.buf_size) + data += chunk + except asyncio.IncompleteReadError: + await asyncio.sleep(0.01) + return data + + async def disconnect(self): + async with self.mgmt_lock: + await self._disconnect() + + async def _disconnect(self): + self.connected = False + self.writer.close() + if self.reader_fut is not None: + self.reader_fut.cancel() + if self.pinger_fut is not None: + self.pinger_fut.cancel() + + async def send(self, target=None, frame=None): + try: + async with self.socket_lock: + self.frame_id += 1 + if self.frame_id > 0xffff_ffff: + self.frame_id = 1 + frame_id = self.frame_id + o = ClientFrame(frame.qos) + if frame.qos & 0b1 != 0: + self.frames[frame_id] = o + flags = frame.type | frame.qos << 6 + if frame.type == OP_SUBSCRIBE or frame.type == OP_UNSUBSCRIBE: + topics = frame.topic if isinstance(frame.topic, + list) else [frame.topic] + payload = b'\x00'.join(t.encode() for t in topics) + self.writer.write( + frame_id.to_bytes(4, 'little') + + flags.to_bytes(1, 'little') + + len(payload).to_bytes(4, 'little') + payload) + else: + frame_len = len(target) + len(frame.payload) + 1 + if frame.header is not None: + frame_len += len(frame.header) + if frame_len > 0xffff_ffff: + raise ValueError('frame too large') + self.writer.write( + frame_id.to_bytes(4, 'little') + + flags.to_bytes(1, 'little') + + frame_len.to_bytes(4, 'little') + target.encode() + + b'\x00' + + (frame.header if frame.header is not None else b'')) + self.writer.write(frame.payload.encode( + ) if isinstance(frame.payload, str) else frame.payload) + await self.writer.drain() + return o + except: + try: + del self.frames[frame_id] + except KeyError: + pass + raise + + def subscribe(self, topics): + frame = Frame(tp=OP_SUBSCRIBE) + frame.topic = topics + return self.send(None, frame) + + def unsubscribe(self, topics): + frame = Frame(tp=OP_UNSUBSCRIBE) + frame.topic = topics + return self.send(None, frame) + + def is_connected(self): + return self.connected + + +class ClientFrame: + + def __init__(self, qos): + self.qos = qos + self.result = 0 + if qos & 0b1 != 0: + self.completed = asyncio.Event() + else: + self.completed = None + + def is_completed(self): + if self.qos & 0b1 != 0: + return self.completed.is_set() + else: + return True + + async def wait_completed(self, timeout=None): + if self.qos & 0b1 == 0: + return RESPONSE_OK + elif timeout: + await asyncio.wait_for(self.completed.wait(), timeout=timeout) + else: + await self.completed.wait() + return self.result + + +class Frame: + + def __init__(self, payload=None, tp=OP_MESSAGE, qos=0): + self.payload = payload + # used for zero-copy + self.header = None + self.type = tp + self.qos = qos diff --git a/src/wrpc/asyncio/rpc.py b/src/wrpc/asyncio/rpc.py new file mode 100644 index 0000000..06d3ad8 --- /dev/null +++ b/src/wrpc/asyncio/rpc.py @@ -0,0 +1,237 @@ +import asyncio +import functools + +from loguru import logger + +from .client import OP_MESSAGE +from ..base import ERR_IO, ERR_TIMEOUT, RESPONSE_OK, RPC_NOTIFICATION_HEADER, RPC_REQUEST_HEADER, RPC_REPLY_HEADER, \ + RPC_ERROR_REPLY_HEADER, RPC_NOTIFICATION, RPC_REQUEST, RPC_REPLY, RPC_ERROR, RPC_ERROR_CODE_METHOD_NOT_FOUND, \ + RPC_ERROR_CODE_INTERNAL, serialize, deserialize,RpcException + +from .client import on_frame_default + + +async def on_call_default(event): + raise RpcException('RPC Engine not initialized', + RPC_ERROR_CODE_METHOD_NOT_FOUND) + + +async def on_notification_default(event): + pass + + +def format_rpc_e_msg(e): + if isinstance(e, RpcException): + return e.rpc_error_payload + else: + return str(e) + + + +class Rpc: + + def __init__(self, client): + self.client = client + self.client.on_frame = self._handle_frame + self.call_id = 0 + self.call_lock = asyncio.Lock() + self.calls = {} + self.on_frame = on_frame_default + self.on_call = on_call_default + self.on_notification = on_notification_default + + def is_connected(self): + return self.client.connected + + def notify(self, target, notification): + return self.client.send(target, notification) + + def call0(self, target, request): + request.header = RPC_REQUEST_HEADER + b'\x00\x00\x00\x00' + \ + request.method + b'\x00' + return self.client.send(target, request) + + async def call(self, target, request): + async with self.call_lock: + call_id = self.call_id + 1 + if call_id == 0xffff_ffff: + self.call_id = 0 + else: + self.call_id = call_id + call_event = RpcCallEvent() + self.calls[call_id] = call_event + request.header = RPC_REQUEST_HEADER + call_id.to_bytes( + 4, 'little') + request.method + b'\x00' + try: + try: + code = await (await self.client.send( + target, + request)).wait_completed(timeout=self.client.timeout) + if code != RESPONSE_OK: + try: + del self.calls[call_id] + except KeyError: + pass + err_code = -32000 - code + call_event.error = RpcException('RPC error', code=err_code) + call_event.completed.set() + except asyncio.TimeoutError: + try: + del self.calls[call_id] + except KeyError: + pass + err_code = -32000 - ERR_TIMEOUT + call_event.error = RpcException('RPC timeout', code=err_code) + call_event.completed.set() + except Exception as e: + try: + del self.calls[call_id] + except KeyError: + pass + call_event.error = RpcException(str(e), code=-32000 - ERR_IO) + call_event.completed.set() + return call_event + + def __getattr__(self, method): + return functools.partial(self.params_call, method=method) + + async def params_call(self, target: str, method: str, **kwargs): + params = kwargs + if '_raw' in kwargs: + params = kwargs.pop('_raw') + params = params or None + c0 = False + if method.endswith('0'): + c0 = True + method = method[:-1] + request = Request(method, serialize(params)) + if c0: + await self.call0(target, request) + return None + result = await self.call(target, request) + return deserialize((await result.wait_completed()).get_payload()) + + async def _handle_frame(self, frame): + try: + if frame.type == OP_MESSAGE: + if frame.payload[0] == RPC_NOTIFICATION: + event = Event(RPC_NOTIFICATION, frame, 1) + await self.on_notification(event) + elif frame.payload[0] == RPC_REQUEST: + sender = frame.sender + call_id_b = frame.payload[1:5] + call_id = int.from_bytes(call_id_b, 'little') + method = frame.payload[5:5 + + frame.payload[5:].index(b'\x00')] + event = Event(RPC_REQUEST, frame, 6 + len(method)) + event.call_id = call_id + event.method = method + if call_id == 0: + await self.on_call(event) + else: + reply = Reply() + try: + reply.payload = await self.on_call(event) + if reply.payload is None: + reply.payload = b'' + reply.header = RPC_REPLY_HEADER + call_id_b + except Exception as e: + try: + code = e.rpc_error_code + except AttributeError: + code = RPC_ERROR_CODE_INTERNAL + reply.header = ( + RPC_ERROR_REPLY_HEADER + call_id_b + + code.to_bytes(2, 'little', signed=True)) + reply.payload = format_rpc_e_msg(e) + await self.client.send(sender, reply) + elif frame.payload[0] == RPC_REPLY or frame.payload[ + 0] == RPC_ERROR: + call_id = int.from_bytes(frame.payload[1:5], 'little') + try: + call_event = self.calls.pop(call_id) + call_event.frame = frame + if frame.payload[0] == RPC_ERROR: + err_code = int.from_bytes(frame.payload[5:7], + 'little', + signed=True) + call_event.error = RpcException(frame.payload[7:], + code=err_code) + call_event.completed.set() + except KeyError: + logger.warning(f'orphaned RPC response: {call_id}') + else: + logger.error(f'Invalid RPC frame code: {frame.payload[0]}') + else: + await self.on_frame(frame) + except Exception: + import traceback + logger.error(traceback.format_exc()) + + +class RpcCallEvent: + + def __init__(self): + self.frame = None + self.error = None + self.completed = asyncio.Event() + + def is_completed(self): + return self.completed.is_set() + + async def wait_completed(self, timeout=None): + if timeout: + if not await asyncio.wait_for(self.completed.wait(), + timeout=timeout): + raise TimeoutError + else: + await self.completed.wait() + if self.error is None: + return self + else: + raise self.error + + def get_payload(self): + return self.frame.payload[5:] + + def is_empty(self): + return len(self.frame.payload) <= 5 + + +class Event: + + def __init__(self, tp, frame, payload_pos): + self.tp = tp + self.frame = frame + self._payload_pos = payload_pos + + def get_payload(self): + return self.frame.payload[self._payload_pos:] + + +class Notification: + + def __init__(self, payload=b''): + self.payload = payload + self.type = OP_MESSAGE + self.qos = 1 + self.header = RPC_NOTIFICATION_HEADER + + +class Request: + + def __init__(self, method, params=b''): + self.payload = b'' if params is None else params + self.type = OP_MESSAGE + self.qos = 1 + self.method = method.encode() if isinstance(method, str) else method + self.header = None + + +class Reply: + + def __init__(self, result=b''): + self.payload = result + self.type = OP_MESSAGE + self.qos = 1 + self.header = None diff --git a/src/wrpc/base.py b/src/wrpc/base.py new file mode 100644 index 0000000..1f7a2f8 --- /dev/null +++ b/src/wrpc/base.py @@ -0,0 +1,64 @@ +import msgpack + +GREETINGS = 0xEB + +PROTOCOL_VERSION = 1 + +OP_NOP = 0 +OP_PUBLISH = 1 +OP_SUBSCRIBE = 2 +OP_UNSUBSCRIBE = 3 +OP_MESSAGE = 0x12 +OP_BROADCAST = 0x13 +OP_ACK = 0xFE + +RESPONSE_OK = 0x01 +ERR_CLIENT_NOT_REGISTERED = 0x71 +ERR_DATA = 0x72 +ERR_IO = 0x73 +ERR_OTHER = 0x74 +ERR_NOT_SUPPORTED = 0x75 +ERR_BUSY = 0x76 +ERR_NOT_DELIVERED = 0x77 +ERR_TIMEOUT = 0x78 + +PING_FRAME = b'\x00' * 9 + +RPC_NOTIFICATION_HEADER = b'\x00' +RPC_REQUEST_HEADER = b'\x01' +RPC_REPLY_HEADER = b'\x11' +RPC_ERROR_REPLY_HEADER = b'\x12' + +RPC_NOTIFICATION = 0x00 +RPC_REQUEST = 0x01 +RPC_REPLY = 0x11 +RPC_ERROR = 0x12 + +RPC_ERROR_CODE_PARSE = -32700 +RPC_ERROR_CODE_INVALID_REQUEST = -32600 +RPC_ERROR_CODE_METHOD_NOT_FOUND = -32601 +RPC_ERROR_CODE_INVALID_METHOD_PARAMS = -32602 +RPC_ERROR_CODE_INTERNAL = -32603 + + +def serialize(data): + if data is None: + return None + return msgpack.dumps(data) + + +def deserialize(data): + if data is None: + return None + return msgpack.unpackb(data) + + +class RpcException(Exception): + + def __init__(self, msg='', code=RPC_ERROR_CODE_INTERNAL): + self.rpc_error_code = code + self.rpc_error_payload = msg + super().__init__(msg if isinstance(msg, str) else msg.decode()) + + def __str__(self): + return super().__str__() + f' (code: {self.rpc_error_code})' diff --git a/src/wrpc/client.py b/src/wrpc/client.py new file mode 100644 index 0000000..c626e8e --- /dev/null +++ b/src/wrpc/client.py @@ -0,0 +1,234 @@ +import socket +import threading +import time +import traceback +from loguru import logger +from .base import GREETINGS, PROTOCOL_VERSION, OP_NOP, OP_PUBLISH, OP_SUBSCRIBE, OP_UNSUBSCRIBE, OP_MESSAGE, \ + OP_ACK, RESPONSE_OK, PING_FRAME + + +def on_frame_default(frame): + pass + + +class Client: + def __init__(self, path, name): + self.path = path + self.socket = None + self.buf_size = 8192 + self.name = name + self.frame_id = 0 + self.ping_interval = 1 + self.on_frame = on_frame_default + self.socket_lock = threading.Lock() + self.mgmt_lock = threading.Lock() + self.connected = False + self.frames = {} + self.timeout = 1 + + def connect(self): + with self.mgmt_lock: + if self.path.endswith('.sock') or self.path.endswith( + '.socket') or self.path.endswith( + '.ipc') or self.path.startswith('/'): + self.socket = socket.socket(socket.AF_UNIX) + path = self.path + else: + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, + 1) + path = self.path.rsplit(':', maxsplit=2) + path[1] = int(path[1]) + path = tuple(path) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, + self.buf_size) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, + self.buf_size) + self.socket.settimeout(self.timeout) + self.socket.connect(path) + buf = self.read_exact(3) + if buf[0] != GREETINGS: + raise RuntimeError('Unsupported protocol') + if int.from_bytes(buf[1:3], 'little') != PROTOCOL_VERSION: + raise RuntimeError('Unsupported protocol version') + self.socket.sendall(buf) + buf = self.socket.recv(1) + if buf[0] != RESPONSE_OK: + raise RuntimeError(f'Server response: {hex(buf[0])}') + name = self.name.encode() + self.socket.sendall(len(name).to_bytes(2, 'little') + name) + buf = self.socket.recv(1) + if buf[0] != RESPONSE_OK: + raise RuntimeError(f'Server response: {hex(buf[0])}') + self.connected = True + threading.Thread(target=self._t_reader, daemon=True).start() + threading.Thread(target=self._t_pinger, daemon=True).start() + + def _handle_daemon_exception(self): + with self.mgmt_lock: + if self.connected: + try: + self.socket.close() + except: + pass + self.connected = False + logger.error(traceback.format_exc()) + + def _t_pinger(self): + try: + while True: + time.sleep(self.ping_interval) + with self.socket_lock: + self.socket.sendall(PING_FRAME) + except: + self._handle_daemon_exception() + + def _t_reader(self): + try: + while True: + buf = self.read_exact(6) + if buf[0] == OP_NOP: + continue + elif buf[0] == OP_ACK: + op_id = int.from_bytes(buf[1:5], 'little') + try: + o = self.frames.pop(op_id) + o.result = buf[5] + o.completed.set() + except KeyError: + logger.warning(f'orphaned BUS/RT frame ack {op_id}') + else: + data_len = int.from_bytes(buf[1:5], 'little') + # do not use read_exact for max zero-copy + data = b'' + while len(data) < data_len: + buf_size = data_len - len(data) + data += self.socket.recv(buf_size if buf_size < self. + buf_size else self.buf_size) + frame = Frame() + try: + frame.type = buf[0] + if buf[0] == OP_PUBLISH: + sender, topic, frame.payload = data.split( + b'\x00', maxsplit=2) + frame.topic = topic.decode() + else: + sender, frame.payload = data.split(b'\x00', + maxsplit=1) + frame.topic = None + frame.sender = sender.decode() + frame.primary_sender = frame.sender.split('%%', 1)[0] + except Exception as e: + logger.error(f'Invalid frame from the server: {e}') + raise + try: + self.on_frame(frame) + except: + + logger.error(traceback.format_exc()) + except: + self._handle_daemon_exception() + + def disconnect(self): + with self.mgmt_lock: + self.socket.close() + self.connected = False + + def read_exact(self, data_len): + data = b'' + while len(data) < data_len: + buf_size = data_len - len(data) + try: + data += self.socket.recv( + buf_size if buf_size < self.buf_size else self.buf_size) + except socket.timeout: + if not self.connected: + break + return data + + def send(self, target=None, frame=None): + try: + with self.socket_lock: + self.frame_id += 1 + if self.frame_id > 0xffff_ffff: + self.frame_id = 1 + frame_id = self.frame_id + o = ClientFrame(frame.qos) + if frame.qos & 0b1 != 0: + self.frames[frame_id] = o + flags = frame.type | frame.qos << 6 + if frame.type == OP_SUBSCRIBE or frame.type == OP_UNSUBSCRIBE: + topics = frame.topic if isinstance(frame.topic, + list) else [frame.topic] + payload = b'\x00'.join(t.encode() for t in topics) + self.socket.sendall( + frame_id.to_bytes(4, 'little') + + flags.to_bytes(1, 'little') + + len(payload).to_bytes(4, 'little') + payload) + else: + frame_len = len(target) + len(frame.payload) + 1 + if frame.header is not None: + frame_len += len(frame.header) + if frame_len > 0xffff_ffff: + raise ValueError('frame too large') + self.socket.sendall( + frame_id.to_bytes(4, 'little') + + flags.to_bytes(1, 'little') + + frame_len.to_bytes(4, 'little') + target.encode() + + b'\x00' + + (frame.header if frame.header is not None else b'')) + self.socket.sendall(frame.payload.encode( + ) if isinstance(frame.payload, str) else frame.payload) + return o + except: + try: + del self.frames[frame_id] + except KeyError: + pass + raise + + def subscribe(self, topics): + frame = Frame(tp=OP_SUBSCRIBE) + frame.topic = topics + return self.send(None, frame) + + def unsubscribe(self, topics): + frame = Frame(tp=OP_UNSUBSCRIBE) + frame.topic = topics + return self.send(None, frame) + + def is_connected(self): + return self.connected + + +class ClientFrame: + + def __init__(self, qos): + self.qos = qos + self.result = 0 + if qos & 0b1 != 0: + self.completed = threading.Event() + + def is_completed(self): + if self.qos & 0b1 != 0: + return self.completed.is_set() + else: + return True + + def wait_completed(self, *args, **kwargs): + if self.qos & 0b1 == 0: + return RESPONSE_OK + elif not self.completed.wait(*args, **kwargs): + raise TimeoutError + else: + return self.result + + +class Frame: + + def __init__(self, payload=None, tp=OP_MESSAGE, qos=0): + self.payload = payload + # used for zero-copy + self.header = None + self.type = tp + self.qos = qos diff --git a/src/wrpc/helpers.py b/src/wrpc/helpers.py new file mode 100644 index 0000000..4c20f39 --- /dev/null +++ b/src/wrpc/helpers.py @@ -0,0 +1,71 @@ +import asyncio +import functools +import os +import random +import threading + +from .asyncio.client import Client as AClient +from .asyncio.rpc import Rpc as ARpc +from .client import Client +from .rpc import Rpc + + +class RpcEndpoint: + def __init__(self, rpc, endpoints, get_endpoints, target): + self.__rpc = rpc + self.__endpoints = endpoints + self.__get_endpoints = get_endpoints + self.__target = target + + def refresh_endpoint(self): + return self.__get_endpoints(self.__target) + + def __getattr__(self, name): + return functools.partial(getattr(self.__rpc, name), target=random.choice(self.__endpoints)) + + def set_endpoints(self, endpoints): + self.__endpoints = endpoints + + +class AsyncHelper: + def __init__(self, uri): + self.bus = AClient(uri, f'python.dynamic.caller.{os.urandom(3).hex()}') + self.rpc = ARpc(self.bus) + self.locker = asyncio.Lock() + + async def get_endpoints(self, target): + async with self.locker: + endpoints = {await self.rpc.get_worker(f'js.{target}.api') for _ in range(10)} + return endpoints + + async def enter(self, target): + if not self.bus.connected: + async with self.locker: + if not self.bus.connected: + await self.bus.connect() + return RpcEndpoint(self.rpc, await self.get_endpoints(target), self.get_endpoints, target) + + async def close(self): + await self.bus.disconnect() + + +class Helper: + def __init__(self, uri): + self.bus = Client(uri, f'python.dynamic.caller.{os.urandom(3).hex()}') + self.rpc = Rpc(self.bus) + self.locker = threading.Lock() + + def get_endpoints(self, target): + with self.locker: + endpoints = {self.rpc.get_worker(f'js.{target}.api') for _ in range(10)} + return endpoints + + def enter(self, target): + if not self.bus.connected: + with self.locker: + if not self.bus.connected: + self.bus.connect() + return RpcEndpoint(self.rpc, self.get_endpoints(target), self.get_endpoints, target) + + def close(self): + self.bus.disconnect() diff --git a/src/wrpc/rpc.py b/src/wrpc/rpc.py new file mode 100644 index 0000000..5dfc62c --- /dev/null +++ b/src/wrpc/rpc.py @@ -0,0 +1,236 @@ +import functools +import threading +from loguru import logger + +from .client import OP_MESSAGE +from .base import ERR_IO, ERR_TIMEOUT, RESPONSE_OK, RPC_NOTIFICATION_HEADER, RPC_REQUEST_HEADER, RPC_REPLY_HEADER, \ + RPC_ERROR_REPLY_HEADER, RPC_NOTIFICATION, RPC_REQUEST, RPC_REPLY, RPC_ERROR, RPC_ERROR_CODE_METHOD_NOT_FOUND, \ + RPC_ERROR_CODE_INTERNAL, serialize, deserialize,RpcException +from .client import on_frame_default + + +def on_call_default(event): + raise RpcException('RPC Engine not initialized', + RPC_ERROR_CODE_METHOD_NOT_FOUND) + + +def on_notification_default(event): + pass + + +def format_rpc_e_msg(e): + if isinstance(e, RpcException): + return e.rpc_error_payload + else: + return str(e) + + + +class Rpc: + + def __init__(self, client): + self.client = client + self.client.on_frame = self._handle_frame + self.call_id = 0 + self.call_lock = threading.Lock() + self.calls = {} + self.on_frame = on_frame_default + self.on_call = on_call_default + self.on_notification = on_notification_default + + def is_connected(self): + return self.client.connected + + def notify(self, target, notification): + return self.client.send(target, notification) + + def call0(self, target, request): + request.header = RPC_REQUEST_HEADER + b'\x00\x00\x00\x00' + \ + request.method + b'\x00' + return self.client.send(target, request) + + def call(self, target, request): + with self.call_lock: + call_id = self.call_id + 1 + if call_id == 0xffff_ffff: + self.call_id = 0 + else: + self.call_id = call_id + call_event = RpcCallEvent() + self.calls[call_id] = call_event + request.header = RPC_REQUEST_HEADER + call_id.to_bytes( + 4, 'little') + request.method + b'\x00' + try: + try: + code = self.client.send( + target, request).wait_completed(timeout=self.client.timeout) + if code != RESPONSE_OK: + try: + del self.calls[call_id] + except KeyError: + pass + err_code = -32000 - code + call_event.error = RpcException('RPC error', code=err_code) + call_event.completed.set() + except TimeoutError: + try: + del self.calls[call_id] + except KeyError: + pass + err_code = -32000 - ERR_TIMEOUT + call_event.error = RpcException('RPC timeout', code=err_code) + call_event.completed.set() + except Exception as e: + try: + del self.calls[call_id] + except KeyError: + pass + call_event.error = RpcException(str(e), code=-32000 - ERR_IO) + call_event.completed.set() + return call_event + + def __getattr__(self, method): + return functools.partial(self.params_call, method=method) + + def params_call(self, target: str, method: str, **kwargs): + params = kwargs + if '_raw' in kwargs: + params = kwargs.pop('_raw') + params = params or None + c0 = False + if method.endswith('0'): + c0 = True + method = method[:-1] + request = Request(method, serialize(params)) + if c0: + self.call0(target, request) + return None + result = self.call(target, request) + return deserialize((result.wait_completed()).get_payload()) + + def _handle_frame(self, frame): + self.spawn(self._t_handler, frame) + + def spawn(self, f, *args, **kwargs): + threading.Thread(target=f, args=args, kwargs=kwargs).start() + + def _t_handler(self, frame): + try: + if frame.type == OP_MESSAGE: + if frame.payload[0] == RPC_NOTIFICATION: + event = Event(RPC_NOTIFICATION, frame, 1) + self.on_notification(event) + elif frame.payload[0] == RPC_REQUEST: + sender = frame.sender + call_id_b = frame.payload[1:5] + call_id = int.from_bytes(call_id_b, 'little') + method = frame.payload[5:5 + + frame.payload[5:].index(b'\x00')] + event = Event(RPC_REQUEST, frame, 6 + len(method)) + event.call_id = call_id + event.method = method + if call_id == 0: + self.on_call(event) + else: + reply = Reply() + try: + reply.payload = self.on_call(event) + if reply.payload is None: + reply.payload = b'' + reply.header = RPC_REPLY_HEADER + call_id_b + except Exception as e: + try: + code = e.rpc_error_code + except AttributeError: + code = RPC_ERROR_CODE_INTERNAL + reply.header = ( + RPC_ERROR_REPLY_HEADER + call_id_b + + code.to_bytes(2, 'little', signed=True)) + reply.payload = format_rpc_e_msg(e) + self.client.send(sender, reply) + elif frame.payload[0] == RPC_REPLY or frame.payload[ + 0] == RPC_ERROR: + call_id = int.from_bytes(frame.payload[1:5], 'little') + try: + call_event = self.calls.pop(call_id) + call_event.frame = frame + if frame.payload[0] == RPC_ERROR: + err_code = int.from_bytes(frame.payload[5:7], + 'little', + signed=True) + call_event.error = RpcException(frame.payload[7:], + code=err_code) + call_event.completed.set() + except KeyError: + logger.warning(f'orphaned RPC response: {call_id}') + else: + logger.error(f'Invalid RPC frame code: {frame.payload[0]}') + else: + self.on_frame(frame) + except Exception: + import traceback + logger.error(traceback.format_exc()) + + +class RpcCallEvent: + + def __init__(self): + self.frame = None + self.error = None + self.completed = threading.Event() + + def is_completed(self): + return self.completed.is_set() + + def wait_completed(self, *args, **kwargs): + if not self.completed.wait(*args, **kwargs): + raise TimeoutError + if self.error is None: + return self + else: + raise self.error + + def get_payload(self): + return self.frame.payload[5:] + + def is_empty(self): + return len(self.frame.payload) <= 5 + + +class Event: + + def __init__(self, tp, frame, payload_pos): + self.tp = tp + self.frame = frame + self._payload_pos = payload_pos + + def get_payload(self): + return self.frame.payload[self._payload_pos:] + + +class Notification: + + def __init__(self, payload=b''): + self.payload = payload + self.type = OP_MESSAGE + self.qos = 1 + self.header = RPC_NOTIFICATION_HEADER + + +class Request: + + def __init__(self, method, params=b''): + self.payload = b'' if params is None else params + self.type = OP_MESSAGE + self.qos = 1 + self.method = method.encode() if isinstance(method, str) else method + self.header = None + + +class Reply: + + def __init__(self, result=b''): + self.payload = result + self.type = OP_MESSAGE + self.qos = 1 + self.header = None diff --git a/src/wrpc/zxgk.py b/src/wrpc/zxgk.py new file mode 100644 index 0000000..b96bd00 --- /dev/null +++ b/src/wrpc/zxgk.py @@ -0,0 +1,61 @@ +from typing import Union, Awaitable, Dict + +from .helpers import AsyncHelper, Helper, RpcEndpoint +from .rpc import RpcException + + +class _Caller: + def __init__(self, caller: RpcEndpoint): + self.caller = caller + + def get_url(self, js_code): + try: + return self.caller.get_url(js_code=js_code) + except RpcException as e: + if e.rpc_error_code == -32113: + endpoints = self.caller.refresh_endpoints() + self.caller.set_endpoints(endpoints) + return self.get_url(js_code) + raise e + + async def async_get_url(self, js_code): + try: + return await self.caller.get_url(js_code=js_code) + except RpcException as e: + if e.rpc_error_code == -32113: + endpoints = await self.caller.refresh_endpoints() + self.caller.set_endpoints(endpoints) + return self.get_url(js_code) + raise e + + +class _ACaller: + def __init__(self, caller: RpcEndpoint): + self.caller = caller + + async def get_url(self, js_code): + try: + return await self.caller.get_url(js_code=js_code) + except RpcException as e: + if e.rpc_error_code == -32113: + endpoints = await self.caller.refresh_endpoints() + self.caller.set_endpoints(endpoints) + return self.get_url(js_code) + raise e + + +class ZXGKCaller: + callers: Dict[str, Union[_Caller, _ACaller]] = {} + + def __init__(self, uri, async_: bool): + self.async_ = async_ + self.helper = AsyncHelper(uri) if async_ else Helper(uri) + + def get_endpoint_caller(self, target) -> Union[RpcEndpoint, Awaitable[RpcEndpoint]]: + return self.helper.enter(target) + + def set_endpoint_caller(self, name: str, caller: RpcEndpoint): + self.callers[name] = _ACaller(caller) if self.async_ else _Caller(caller) + + def __getattr__(self, item): + return self.callers[item] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29