Last active: March 10, 2023 18:09
Save dankrause/96c944abb6d636c35fa67c6cfd0928d2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import builtins
import collections
import contextlib
import copy
import dis
import inspect
import json
import socket
import struct
import types
_COMPILER_FLAG_MAP = {val: key for key, val in dis.COMPILER_FLAG_NAMES.items()} | |
_HEADER_STRUCT = struct.Struct('!L') | |
_ANNOTATION_MAP = { | |
i: getattr(builtins, i) | |
for i in dir(builtins) | |
if type(getattr(builtins, i)) is type and i[0].islower() | |
} | |
def _get_annotation_type(type_name):
    """Translate a serialized annotation name into a type when one is known.

    Names present in ``_ANNOTATION_MAP`` resolve to the mapped builtin class;
    any other value is handed back untouched.
    """
    return _ANNOTATION_MAP.get(type_name, type_name)
class RPCClientException(Exception):
    """Raised when an RPC request comes back with an error instead of a result.

    ``args[0]`` holds the ``"error"`` payload of the reply, or the string
    ``"unknown error"`` when the reply carried none.
    """
class RPCClientNamespace:
    """Client-side proxy for one namespace exposed by an RPC server.

    Every ``"func"`` entry of the namespace manifest becomes an attribute
    holding a stub function whose signature mirrors the server function;
    calling it sends the call over the socket and returns the result.  Every
    ``"manifest"`` entry becomes a nested namespace, which is itself callable
    when its ``_meta`` carries a ``"_call"`` description.

    Args:
        sock: connected socket shared by all namespaces of one session.
        manifest: manifest dict for this namespace; when ``None`` the root
            manifest is requested from the server.
        name: list of namespace names leading to this namespace.
        func: path prefix of a server-side object instance (as returned by the
            server alongside an instance manifest); prepended to every call made
            through this namespace and the namespaces nested in it.
    """

    def __init__(self, sock, manifest=None, name=None, func=None):
        self._name = name or []
        self._sock = sock
        self._manifest = manifest if manifest is not None else self._send_raw_message("manifest")
        self._func = func
        for entry_name, entry in self._manifest.items():
            item_type = entry.get("_meta", {}).get("type", None)
            # Skip this namespace's own metadata and any unrecognised entry.
            if entry_name == "_meta" or item_type is None:
                continue
            if item_type == "func":
                setattr(self, entry_name, self._deserialize_function(entry))
            elif item_type == "manifest":
                nested = self._build_nested_namespace(sock, entry, [*self._name, entry_name])
                setattr(self, entry_name, nested)

    def _recv_exact(self, size):
        """Read exactly ``size`` bytes from the socket.

        ``socket.recv`` may legally return fewer bytes than requested, so keep
        reading until the whole frame has arrived.

        Raises:
            ConnectionError: if the peer closes the connection first.
        """
        chunks = []
        remaining = size
        while remaining:
            chunk = self._sock.recv(remaining)
            if not chunk:
                raise ConnectionError(
                    f"connection closed with {remaining} of {size} bytes unread")
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def _send_to_socket(self, msg):
        """Send one length-prefixed frame containing ``msg`` and return the reply.

        The header counts encoded bytes, not characters, so non-ASCII text
        is framed correctly.
        """
        payload = msg.encode("utf-8")
        self._sock.sendall(_HEADER_STRUCT.pack(len(payload)) + payload)
        (msg_len,) = _HEADER_STRUCT.unpack(self._recv_exact(_HEADER_STRUCT.size))
        return str(self._recv_exact(msg_len), "utf-8")

    def _send_raw_message(self, msg):
        """Send ``msg`` as JSON and decode the server's reply.

        Returns:
            the ``"response"`` value of an ordinary reply, or a new
            ``RPCClientNamespace`` bound to a server-side instance when the
            reply carries a ``"manifest"``.

        Raises:
            RPCClientException: carrying the reply's ``"error"`` payload, or a
                description of a reply that is not a JSON object.
        """
        raw_response = self._send_to_socket(json.dumps(msg))
        try:
            response = json.loads(raw_response)
        except ValueError:  # JSONDecodeError, or undecodable bytes
            response = {"error": {"msg": "Response is not valid json", "invalid_response": raw_response}}
        if not isinstance(response, dict):
            # Valid JSON that is not an object cannot follow the protocol.
            response = {"error": {"msg": "Response is not a JSON object", "invalid_response": raw_response}}
        if "response" in response:
            return response["response"]
        if "manifest" in response:
            return RPCClientNamespace(self._sock, response["manifest"], func=response["func"])
        raise RPCClientException(response.get("error", "unknown error"))

    def _dispatch(self, args_dict, func_def):
        """Turn bound arguments into an RPC call message and send it.

        Args:
            args_dict: mapping of parameter name to value, as produced by
                ``inspect.BoundArguments.arguments``.
            func_def: serialized function definition from the manifest.
        """
        params = func_def.get("params", {})
        # With *args present, parameters before it must travel positionally or
        # the server would see them twice.
        has_var_positional = any(p["kind"] == "VAR_POSITIONAL" for p in params.values())
        args = []
        kwargs = {}
        for param_name, param in params.items():
            if param_name not in args_dict:
                continue
            value = args_dict[param_name]
            kind = param["kind"]
            if kind == "POSITIONAL_ONLY" or (kind == "POSITIONAL_OR_KEYWORD" and has_var_positional):
                args.append(value)
            elif kind == "VAR_POSITIONAL":
                args.extend(value)
            elif kind == "VAR_KEYWORD":
                kwargs.update(value)
            else:
                kwargs[param_name] = value
        msg = {
            "func": [*(self._func or []), *self._name, func_def["name"]],
            "args": args,
            "kwargs": kwargs,
        }
        return self._send_raw_message(msg)

    def _build_nested_namespace(self, _sock, _manifest, _name):
        """Create the proxy for a nested namespace, callable if the server says so.

        The instance prefix ``self._func`` is inherited so calls through a
        namespace nested inside an instance reach that instance.
        """
        if "_call" in _manifest["_meta"]:
            # A fresh subclass per namespace, because __call__ must live on the type.
            class CallableRPCClientNamespace(RPCClientNamespace):
                pass

            CallableRPCClientNamespace.__call__ = self._deserialize_function(
                _manifest["_meta"]["_call"], _has_self=True)
            return CallableRPCClientNamespace(_sock, _manifest, _name, func=self._func)
        return RPCClientNamespace(_sock, _manifest, _name, func=self._func)

    @staticmethod
    def _build_signature(func_def):
        """Rebuild an ``inspect.Signature`` from a serialized function definition."""
        parameters = []
        for param_name, param in func_def.get("params", {}).items():
            parameters.append(inspect.Parameter(
                param_name,
                getattr(inspect.Parameter, param["kind"]),
                default=param.get("default", inspect.Parameter.empty),
                annotation=(_get_annotation_type(param["annotation"])
                            if "annotation" in param else inspect.Parameter.empty),
            ))
        return_annotation = (_get_annotation_type(func_def["return"])
                             if "return" in func_def else inspect.Signature.empty)
        return inspect.Signature(parameters, return_annotation=return_annotation)

    def _deserialize_function(self, _func_def, _has_self=False):
        """Build a local stub that forwards calls to the remote function.

        Arguments are validated and defaults applied with ``inspect.Signature``
        rather than by rewriting code objects. The ``types.CodeType`` constructor
        and the closure layout changed in Python 3.11, which broke that approach.

        Args:
            _func_def: serialized function definition from the manifest.
            _has_self: the stub will be installed as ``__call__`` and so
                receives (and discards) the namespace instance first.
        """
        _func_def = copy.deepcopy(_func_def)
        signature = self._build_signature(_func_def)

        def wrapper(*args, **kwargs):
            if _has_self:
                args = args[1:]
            bound = signature.bind(*args, **kwargs)  # TypeError like a real call
            bound.apply_defaults()
            return self._dispatch(bound.arguments, _func_def)

        exposed = signature
        if _has_self:
            exposed = signature.replace(parameters=[
                inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY),
                *signature.parameters.values(),
            ])
        wrapper.__signature__ = exposed
        wrapper.__name__ = wrapper.__qualname__ = str(_func_def["name"])
        wrapper.__annotations__ = {
            param_name: param.annotation
            for param_name, param in signature.parameters.items()
            if param.annotation is not inspect.Parameter.empty
        }
        if signature.return_annotation is not inspect.Signature.empty:
            wrapper.__annotations__["return"] = signature.return_annotation
        if "doc" in _func_def:
            wrapper.__doc__ = _func_def["doc"]
        return wrapper
class RPCClient:
    """Opens RPC sessions against a server at a fixed TCP address.

    Args:
        address: ``(host, port)`` tuple passed to ``socket.connect``.
    """

    def __init__(self, address):
        self._address = address

    def get_session(self):
        """Connect and return the root ``RPCClientNamespace`` for a new session.

        The socket is closed again if connecting or fetching the manifest
        fails, so a failed attempt does not leak a file descriptor.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect(self._address)
            return RPCClientNamespace(sock)
        except BaseException:
            sock.close()
            raise

    @staticmethod
    def close_session(session):
        """Close the socket behind ``session``."""
        session._sock.close()

    @contextlib.contextmanager
    def session(self):
        """Context manager yielding a session that is closed even if the body raises."""
        session = self.get_session()
        try:
            yield session
        finally:
            self.close_session(session)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import collections | |
import functools | |
import inspect | |
import json | |
import socketserver | |
import struct | |
import types | |
_HEADER_STRUCT = struct.Struct('!L') | |
class RPCServerException(Exception): | |
pass | |
def _annotation_name(annotation):
    """Return a JSON-serializable name for a parameter or return annotation.

    Classes become their ``__name__`` (``int`` -> ``"int"``) and strings
    (forward references) are kept as-is. Anything else, such as ``None`` from
    ``-> None`` or a ``typing`` construct, is converted with ``str()`` so the
    manifest can always be encoded as JSON.
    """
    if isinstance(annotation, type):
        return annotation.__name__
    if isinstance(annotation, str):
        return annotation
    return str(annotation)


def serialize_function(func, name=None):
    """Describe ``func``'s signature as a JSON-serializable dict.

    Args:
        func: any callable accepted by ``inspect.signature``.
        name: name advertised to clients, stored under ``"name"``.

    Returns:
        ``{"_meta": {"type": "func"}, "name": name}``, plus these keys when present:
        - ``"doc"``: the docstring.
        - ``"return"``: the return annotation name.
        - ``"params"``: an ordered mapping from parameter name to a dict with
          optional ``"annotation"`` and ``"default"`` keys and the ``"kind"``
          name (e.g. ``"POSITIONAL_OR_KEYWORD"``).
    """
    response = {
        "_meta": {"type": "func"},
        "name": name,
    }
    if func.__doc__:
        response["doc"] = func.__doc__
    spec = inspect.signature(func)
    if spec.return_annotation is not inspect.Signature.empty:
        # Return annotations need not be classes (``-> None``, ``-> "Widget"``),
        # so ``__name__`` cannot be read unconditionally.
        response["return"] = _annotation_name(spec.return_annotation)
    if spec.parameters:
        params = response["params"] = {}
        for param_name, param in spec.parameters.items():
            entry = {}
            if param.annotation is not inspect.Parameter.empty:
                entry["annotation"] = _annotation_name(param.annotation)
            if param.default is not inspect.Parameter.empty:
                # NOTE(review): defaults are sent as-is and must be JSON-serializable.
                default = param.default
                entry["default"] = default.__name__ if isinstance(default, type) else default
            entry["kind"] = param.kind.name
            params[param_name] = entry
    return response
class RPCServerNamespace:
    """A named collection of RPC-callable functions, classes and sub-namespaces.

    ``_registry`` maps public names to callables or nested
    ``RPCServerNamespace`` objects, plus a ``"_meta"`` entry describing this
    namespace. ``_obj_registry`` maps a client address to the instance
    namespaces that client created by calling a registered class. Those are
    keyed by ``id()`` of the instance namespace and dropped by ``clear_objects``.

    Args:
        name: name reported in this namespace's manifest.
        wrapped_obj: optional object used when the namespace itself is called.
            A class is instantiated; any other callable is simply called.
    """

    def __init__(self, name, wrapped_obj=None):
        self._registry = {"_meta": {"type": "manifest", "name": name}}
        self._obj_registry = collections.defaultdict(dict)
        self.wrapped_obj = wrapped_obj

    def _call_wrapped(self, args, kwargs, _client_address):
        """Call the wrapped object; a wrapped class yields a new instance namespace.

        Raises:
            RPCServerException: if the namespace wraps nothing callable.
        """
        # isinstance rather than ``type(x) is type`` so classes with a custom
        # metaclass (e.g. abc.ABC subclasses) are instantiated too.
        if isinstance(self.wrapped_obj, type):
            return self._instance_init(self.wrapped_obj, args, kwargs, _client_address)
        if callable(self.wrapped_obj):
            return self.wrapped_obj(*args, **kwargs)
        raise RPCServerException("Namespace not callable")

    def _instance_init(self, func_obj, args, kwargs, _client_address):
        """Instantiate ``func_obj`` and expose the new object as a namespace.

        An attribute is registered when all of these hold:
        - it has a ``__name__``;
        - its RPC name (``__rpcname__`` or ``__name__``) does not start with ``_``;
        - it has not been marked by ``ignore``.
        A ``_delete`` function is also registered that removes the namespace from
        this namespace's per-client object registry.
        """
        obj = func_obj(*args, **kwargs)
        instance_namespace = RPCServerNamespace(id(obj), wrapped_obj=obj)
        for attr_name in dir(obj):
            member = getattr(obj, attr_name)
            if member is type(obj):
                continue
            # Bound methods carry the RPC markers on their underlying function.
            func = getattr(member, "__func__", member)
            if not hasattr(func, "__name__"):
                continue
            rpcname = getattr(func, "__rpcname__", func.__name__)
            if not rpcname.startswith("_") and not hasattr(func, "__rpcignore__"):
                instance_namespace.register(rpcname, func=member)

        @instance_namespace.register()
        def _delete():
            # Forget this instance for the client; later calls addressing it
            # by id fail with "Invalid namespace".
            del self._obj_registry[_client_address][id(instance_namespace)]
            return True

        return instance_namespace

    def _build_class_namespace(self, cls):
        """Build a callable namespace for ``cls`` exposing its class/static methods.

        Uses the same filtering as ``_instance_init``. The RPC name is the
        function's ``__rpcname__`` (set by an earlier ``register``) or its
        ``__name__``. Underscore-prefixed names and ``ignore``-marked functions
        are skipped.
        """
        class_namespace = RPCServerNamespace(cls.__name__, wrapped_obj=cls)
        for attr_name, attr in cls.__dict__.items():
            if not isinstance(attr, (classmethod, staticmethod)):
                continue
            func = attr.__func__
            rpcname = getattr(func, "__rpcname__", getattr(func, "__name__", attr_name))
            if rpcname.startswith("_") or hasattr(func, "__rpcignore__"):
                continue
            class_namespace.register(rpcname, func=getattr(cls, attr_name))
        return class_namespace

    def ignore(self, obj):
        """Decorator marking ``obj`` so it is not exposed on class or instance namespaces.

        Works on functions, bound methods and ``classmethod``/``staticmethod``
        objects (the marker is placed on the underlying function).
        """
        target = getattr(obj, "__func__", obj)
        target.__rpcignore__ = True
        return obj

    def register(self, name=None, namespace=None, func=None):
        """Register ``func`` under ``name``; usable directly or as a decorator.

        Args:
            name: public name; defaults to the object's ``__name__``. Must start
                with a letter.
            namespace: sub-namespace name, or list of names, to register into.
                Missing namespaces are created. A plain string is one name, not
                a path of characters.
            func: function, bound method, staticmethod, class or named callable
                instance. When omitted, a decorator is returned.

        Returns:
            ``func`` unchanged, or a decorator when ``func`` is ``None``.

        Raises:
            ValueError: for a name not starting with a letter, or a callable
                instance registered without a name.
            TypeError: for built-in functions or raw ``classmethod`` objects.
        """
        if name is not None and not name[:1].isalpha():
            raise ValueError("Namespaces must start with a letter")
        if func is None:
            return functools.partial(self.register, name, namespace)
        if isinstance(namespace, str):
            namespace = [namespace]
        if namespace is not None:
            namespace, remaining_path = self._get_next_namespace(namespace, create=True)
            if namespace is not self:
                return namespace.register(name, remaining_path, func)
        if isinstance(func, types.BuiltinFunctionType):
            raise TypeError("Built-in functions not supported")
        if isinstance(func, classmethod):
            raise TypeError("Cannot register classmethods directly")
        if isinstance(func, staticmethod):
            func.__func__.__rpcname__ = name or func.__func__.__name__
            self._registry[func.__func__.__rpcname__] = func.__func__
        elif isinstance(func, type):
            func.__rpcname__ = name or func.__name__
            self._registry[func.__rpcname__] = self._build_class_namespace(func)
        elif hasattr(func, "__qualname__"):
            if hasattr(func, "__self__"):
                # bound method: the marker goes on the underlying function
                func.__func__.__rpcname__ = name or func.__func__.__name__
                self._registry[func.__func__.__rpcname__] = func
            else:
                # plain function or unbound method
                func.__rpcname__ = name or func.__name__
                self._registry[func.__rpcname__] = func
        elif callable(func):
            # callable instance: it has no __name__ to fall back on
            if name is None:
                raise ValueError("Callable instances must be registered with an explicit name")
            func.__rpcname__ = name
            self._registry[name] = func
        return func

    def register_instance(self, instance, name, methods=None):
        """Expose bound methods of ``instance`` as a new sub-namespace ``name``.

        Args:
            instance: object whose methods are exposed. If callable, the
                namespace itself becomes callable.
            name: name of the new sub-namespace.
            methods: list of bound methods (registered under their own names),
                or a dict mapping public name to bound method. When empty, every
                bound method whose name starts with a letter is exposed.
        """
        wrapped_obj = instance if callable(instance) else None
        instance_namespace = self.add_namespace(name, wrapped_obj)
        if not methods:
            methods = [
                getattr(instance, attr_name)
                for attr_name in dir(instance)
                if isinstance(getattr(instance, attr_name), types.MethodType)
                and attr_name[0].isalpha()
            ]
        if hasattr(methods, "items"):
            for method_name, method in methods.items():
                instance_namespace.register(method_name, func=method)
        else:
            for method in methods:
                instance_namespace.register(func=method)
        return instance_namespace

    def _get_next_namespace(self, path, create=False, client_address=None):
        """Resolve the first element of ``path``.

        Returns:
            ``(namespace, rest)`` when the head names a sub-namespace or a
            client-owned instance (``rest`` is ``None`` when exhausted).
            ``(self, head)`` when the path is a single name resolved here.
            ``(self, None)`` for an empty path.

        Raises:
            RPCServerException: if the head cannot be resolved.
        """
        if not path:
            return self, None
        head, rest = path[0], path[1:] or None
        if isinstance(head, int):
            # .get avoids inserting an empty per-client dict on every lookup.
            instance_namespace = self._obj_registry.get(client_address, {}).get(head)
            if instance_namespace is not None:
                return instance_namespace, rest
        elif head in self._registry:
            entry = self._registry[head]
            if isinstance(entry, RPCServerNamespace):
                return entry, rest
            if len(path) == 1:
                return self, head
        elif create:
            namespace = RPCServerNamespace(head)
            self._registry[head] = namespace
            return namespace, rest
        elif len(path) == 1:
            return self, head
        raise RPCServerException(f"Invalid namespace: {head}")

    def call(self, func, args, kwargs, _client_address):
        """Invoke the function at path ``func`` (``None`` calls this namespace).

        A returned ``RPCServerNamespace`` (a new instance) is remembered in
        this namespace's object registry for ``_client_address``.

        Raises:
            RPCServerException: for unresolvable paths or unknown function names.
        """
        if isinstance(func, str):
            func = [func]
        if func is not None:
            namespace, func = self._get_next_namespace(func, client_address=_client_address)
            if namespace is not self:
                return namespace.call(func, args, kwargs, _client_address)
        if func is None:
            func_obj = self
        else:
            if func == "_meta" or func not in self._registry:
                raise RPCServerException(f"Invalid function: {func}")
            func_obj = self._registry[func]
        if isinstance(func_obj, type):
            result = self._instance_init(func_obj, args, kwargs, _client_address)
        elif isinstance(func_obj, RPCServerNamespace):
            result = func_obj._call_wrapped(args, kwargs, _client_address)
        else:
            result = func_obj(*args, **kwargs)
        if isinstance(result, RPCServerNamespace):
            self._obj_registry[_client_address][id(result)] = result
        return result

    def add_namespace(self, name, wrapped_obj=None):
        """Create, register and return a sub-namespace called ``name``."""
        new_ns = RPCServerNamespace(name, wrapped_obj)
        self._registry[name] = new_ns
        return new_ns

    def remove_namespace(self, name):
        """Remove the entry ``name``, raising RPCServerException if it is absent."""
        if name not in self._registry:
            raise RPCServerException("Invalid namespace name")
        del self._registry[name]

    def clear_objects(self, client_address):
        """Drop every instance namespace owned by ``client_address``, recursively."""
        self._obj_registry.pop(client_address, None)
        for entry in self._registry.values():
            if isinstance(entry, RPCServerNamespace):
                entry.clear_objects(client_address)

    @property
    def manifest(self):
        """Describe this namespace for clients.

        Returns a dict with a ``_meta`` copy (plus a ``_call`` signature when
        the wrapped object is callable), a serialized signature for every
        callable entry, and the nested manifest of every sub-namespace.
        """
        # Copy so the registry's own _meta dict is never mutated.
        meta = dict(self._registry["_meta"])
        if callable(self.wrapped_obj):
            meta["_call"] = serialize_function(self.wrapped_obj, meta["name"])
        manifest = {"_meta": meta}
        manifest.update({
            key: serialize_function(val, key)
            for key, val in self._registry.items()
            if callable(val)
        })
        manifest.update({
            key: val.manifest
            for key, val in self._registry.items()
            if isinstance(val, RPCServerNamespace)
        })
        return manifest
class RPCServer(socketserver.ThreadingTCPServer, RPCServerNamespace):
    """Threaded TCP server whose root namespace is callable over the wire.

    Each request and reply is a ``_HEADER_STRUCT`` length prefix followed by
    that many bytes of UTF-8 JSON. The objects a client creates are
    discarded when its connection ends.

    Args:
        address: ``(host, port)`` to bind.
    """

    def __init__(self, address):
        RPCServerNamespace.__init__(self, "_root")

        class _RPCServerHandler(socketserver.StreamRequestHandler):
            def handle(self):
                try:
                    while True:
                        # rfile.read(n) keeps reading until n bytes or EOF,
                        # unlike a single socket.recv(n).
                        header = self.rfile.read(_HEADER_STRUCT.size)
                        if len(header) < _HEADER_STRUCT.size:
                            break  # connection closed (possibly mid-header)
                        (msg_len,) = _HEADER_STRUCT.unpack(header)
                        raw_msg = self.rfile.read(msg_len)
                        if len(raw_msg) < msg_len:
                            break  # connection closed mid-message
                        response_json = self.server._handle_message(raw_msg, self.client_address)
                        self.wfile.write(_HEADER_STRUCT.pack(len(response_json)) + response_json)
                finally:
                    # Runs even when the handler fails, so objects never leak.
                    self.server.clear_objects(self.client_address)

        socketserver.ThreadingTCPServer.__init__(self, address, _RPCServerHandler)

    def _handle_message(self, raw_msg, client_address):
        """Decode one JSON request and return the encoded JSON reply.

        ``"manifest"`` returns the root manifest. Any other message is passed to
        ``call`` as keyword arguments. A returned namespace is sent back as its
        manifest plus the ``func`` path that addresses it. Any exception becomes
        an ``{"error": ...}`` reply so the connection stays usable.
        """
        try:
            msg = json.loads(str(raw_msg, 'utf-8'))
            if msg == "manifest":
                raw_response = self.manifest
            else:
                raw_response = self.call(**msg, _client_address=client_address)
            if isinstance(raw_response, RPCServerNamespace):
                payload = {"manifest": raw_response.manifest, "func": [*msg['func'], id(raw_response)]}
            else:
                payload = {"response": raw_response}
            return bytes(json.dumps(payload), 'utf-8')
        except Exception as e:
            return bytes(json.dumps({"error": str(e)}), 'utf-8')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from server import RPCServer, RPCServerNamespace | |
from client import RPCClientNamespace, RPCClientException | |
address = ("127.0.0.1", 4512)
server = RPCServer(address)  # binds the address on construction
dumb_store = dict()  # backing storage for the store/get/remove RPCs
ns, broken_ns = [server.add_namespace(label) for label in ("x", "y")]
@server.register()
def store(name: str, value: str) -> bool:
    """Save ``value`` under ``name`` in the in-memory store and report success."""
    dumb_store.update({name: value})
    return True
@ns.register("remove")
def delete(name: str) -> str:
    """Remove ``name`` from the store; return its old value, or "" if absent."""
    return dumb_store.pop(name, "")
@server.register()
def get(name: str) -> str:
    """Look up ``name`` in the store, returning "" when it is missing."""
    return dumb_store.get(name, "")
@ns.register()
def error():
    """Always raise, so the client's error propagation can be checked."""
    message = "error"
    raise Exception(message)
@server.register(name="existing_namespace", namespace="x")
def bar():
    """Return "bar"; exposed as existing_namespace inside namespace x."""
    result = "bar"
    return result
@server.register(None, "z")
def baz():
    """Return "bar"; registering it creates namespace z on demand."""
    result = "bar"
    return result
@broken_ns.register()
def broken():
    """No-op living in namespace y, which the script later removes server-side."""
    return None
# NOTE(review): not referenced anywhere in this script.
widget_registry = {}


@server.register()
class Widget:
    """Sample class registered with the RPC server.

    Clients create instances with ``session.Widget(name)``. Public methods
    can then be called on the returned instance proxy.
    """

    def __init__(self, name):
        self.name = name

    def get_name(self):
        return self.name

    def set_name(self, name):
        self.name = name

    # Class method: callable on the class proxy and on instance proxies.
    @classmethod
    def get_class_name(cls):
        return cls.__name__

    # Registered as root-level "return_true" while the class body executes,
    # and also exposed on the class and instance proxies.
    @server.register("return_true")
    @staticmethod
    def return_true():
        return True

    # Marked by server.ignore, so it is not exposed on instance proxies.
    @server.ignore
    def ignore_me(self):
        pass

    # Leading underscore: skipped when instance proxies are built.
    def _also_ignored(self):
        pass
class CallMe:
    """Callable object that prefixes text with a mutable name."""

    def __init__(self, name):
        self._name = name

    def __call__(self, text):
        return "{}: {}".format(self._name, text)

    def set_name(self, name):
        self._name = name

    def get_name(self):
        return self._name

    def _reverse_name(self):
        """Return the name spelled backwards."""
        return "".join(self._name[::-1])
# One CallMe object exposed four ways: directly, and through three instance
# namespaces with different method selections.
call_me = CallMe("foo")
server.register("call_me", func=call_me)
# No method list: every bound method whose name starts with a letter is exposed.
call_me_1 = server.register_instance(call_me, "call_me_1")
# Only get_name is exposed.
call_me_2 = server.register_instance(call_me, "call_me_2", [call_me.get_name])
# Methods exposed under the dict's keys, including the private _reverse_name.
call_me_3 = server.register_instance(call_me, "call_me_3", {"get": call_me.get_name, "set": call_me.set_name, "reverse": call_me._reverse_name})
# Bypass the network: client frames go straight into the server's message
# handler, with "" standing in for the client address.
RPCClientNamespace._send_to_socket = lambda self, msg: server._handle_message(bytes(msg, 'utf-8'), "")
session = RPCClientNamespace(None)
# Remove namespace "y" on the server after the client has seen it in the
# manifest, so calls through session.y hit a server-side error.
server._registry.pop("y")
# NOTE(review): created at import time and never deleted; the __main__ block
# rebinds widget_foo.
widget_foo = session.Widget("fiz")
# NOTE(review): these checks use assert (including `assert False` inside the
# try blocks), so running under `python -O` silently disables them.
if __name__ == "__main__":
    # Root-level functions backed by dumb_store.
    assert session.store("foo", "bar") == True
    assert session.get("foo") == "bar"
    assert session.store("foo", "baz") == True
    assert session.get("foo") == "baz"
    # Functions inside namespaces: delete is exposed as x.remove, and
    # namespace z was created by register(namespace="z").
    assert session.x.remove("foo") == "baz"
    assert session.x.existing_namespace() == "bar"
    assert session.z.baz() == "bar"
    # Each session.Widget(...) call creates a separate server-side object.
    widget_foo = session.Widget("foo")
    widget_bar = session.Widget("bar")
    widget_foo.set_name("Foo")
    assert widget_foo.get_class_name() == "Widget"
    assert session.Widget.get_class_name() == "Widget"
    assert session.return_true() is True
    assert session.Widget.return_true() is True
    assert widget_foo.return_true() is True
    assert widget_foo.get_name() == "Foo"
    assert widget_bar.get_name() == "bar"
    # Ignored and underscore-prefixed methods are not exposed.
    assert not hasattr(widget_foo, "ignore_me")
    assert not hasattr(widget_foo, "_also_ignored")
    # Every call_me* entry wraps the same CallMe object, so renaming through
    # one namespace is visible through all of them.
    assert session.call_me("bar") == "foo: bar"
    assert session.call_me_1("bar") == "foo: bar"
    session.call_me_1.set_name("baz")
    assert session.call_me("bar") == "baz: bar"
    assert session.call_me_1("bar") == "baz: bar"
    assert session.call_me_2("bar") == "baz: bar"
    assert session.call_me_3("bar") == "baz: bar"
    assert not hasattr(session.call_me_2, "set_name")
    assert not hasattr(session.call_me_3, "set_name")
    assert session.call_me_2.get_name() == "baz"
    assert session.call_me_3.get() == "baz"
    assert session.call_me_3.reverse() == "zab"
    session.call_me_3.set("baz 2")
    assert session.call_me("bar") == "baz 2: bar"
    assert session.call_me_1("bar") == "baz 2: bar"
    assert session.call_me_2("bar") == "baz 2: bar"
    assert session.call_me_3("bar") == "baz 2: bar"
    # After _delete the server forgets the instance, so the stale proxy fails.
    widget_foo._delete()
    try:
        widget_foo.get_name()
        assert False
    except RPCClientException as e:
        assert e.args[0].startswith("Invalid namespace: ")
    # Server-side exceptions come back as RPCClientException with the message.
    try:
        session.x.error()
        assert False
    except RPCClientException as e:
        assert e.args[0] == "error"
    # Namespace "y" was removed on the server after the manifest was fetched.
    try:
        session.y.broken()
        assert False
    except RPCClientException as e:
        assert e.args[0] == "Invalid namespace: y"
    print("\nAll tests passed.\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.