websockets: pass connection to WebRequest
This gives handlers direct access to a websocket client connection. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
de1575f757
commit
8d1239c316
|
@ -15,10 +15,11 @@ class Sentinel:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class WebRequest:
|
class WebRequest:
|
||||||
def __init__(self, endpoint, args, action=""):
|
def __init__(self, endpoint, args, action="", conn=None):
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.action = action
|
self.action = action
|
||||||
self.args = args
|
self.args = args
|
||||||
|
self.conn = conn
|
||||||
|
|
||||||
def get_endpoint(self):
|
def get_endpoint(self):
|
||||||
return self.endpoint
|
return self.endpoint
|
||||||
|
@ -29,6 +30,9 @@ class WebRequest:
|
||||||
def get_args(self):
|
def get_args(self):
|
||||||
return self.args
|
return self.args
|
||||||
|
|
||||||
|
def get_connection(self):
|
||||||
|
return self.conn
|
||||||
|
|
||||||
def _get_converted_arg(self, key, default=Sentinel, dtype=str):
|
def _get_converted_arg(self, key, default=Sentinel, dtype=str):
|
||||||
if key not in self.args:
|
if key not in self.args:
|
||||||
if default == Sentinel:
|
if default == Sentinel:
|
||||||
|
@ -79,7 +83,7 @@ class JsonRPC:
|
||||||
def remove_method(self, name):
|
def remove_method(self, name):
|
||||||
self.methods.pop(name, None)
|
self.methods.pop(name, None)
|
||||||
|
|
||||||
async def dispatch(self, data):
|
async def dispatch(self, data, ws):
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
request = json.loads(data)
|
request = json.loads(data)
|
||||||
|
@ -92,19 +96,19 @@ class JsonRPC:
|
||||||
if isinstance(request, list):
|
if isinstance(request, list):
|
||||||
response = []
|
response = []
|
||||||
for req in request:
|
for req in request:
|
||||||
resp = await self.process_request(req)
|
resp = await self.process_request(req, ws)
|
||||||
if resp is not None:
|
if resp is not None:
|
||||||
response.append(resp)
|
response.append(resp)
|
||||||
if not response:
|
if not response:
|
||||||
response = None
|
response = None
|
||||||
else:
|
else:
|
||||||
response = await self.process_request(request)
|
response = await self.process_request(request, ws)
|
||||||
if response is not None:
|
if response is not None:
|
||||||
response = json.dumps(response)
|
response = json.dumps(response)
|
||||||
logging.debug("Websocket Response::" + response)
|
logging.debug("Websocket Response::" + response)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def process_request(self, request):
|
async def process_request(self, request, ws):
|
||||||
req_id = request.get('id', None)
|
req_id = request.get('id', None)
|
||||||
rpc_version = request.get('jsonrpc', "")
|
rpc_version = request.get('jsonrpc', "")
|
||||||
method_name = request.get('method', None)
|
method_name = request.get('method', None)
|
||||||
|
@ -116,18 +120,20 @@ class JsonRPC:
|
||||||
if 'params' in request:
|
if 'params' in request:
|
||||||
params = request['params']
|
params = request['params']
|
||||||
if isinstance(params, list):
|
if isinstance(params, list):
|
||||||
response = await self.execute_method(method, req_id, *params)
|
response = await self.execute_method(
|
||||||
|
method, req_id, ws, *params)
|
||||||
elif isinstance(params, dict):
|
elif isinstance(params, dict):
|
||||||
response = await self.execute_method(method, req_id, **params)
|
response = await self.execute_method(
|
||||||
|
method, req_id, ws, **params)
|
||||||
else:
|
else:
|
||||||
return self.build_error(-32600, "Invalid Request", req_id)
|
return self.build_error(-32600, "Invalid Request", req_id)
|
||||||
else:
|
else:
|
||||||
response = await self.execute_method(method, req_id)
|
response = await self.execute_method(method, req_id, ws)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def execute_method(self, method, req_id, *args, **kwargs):
|
async def execute_method(self, method, req_id, ws, *args, **kwargs):
|
||||||
try:
|
try:
|
||||||
result = await method(*args, **kwargs)
|
result = await method(ws, *args, **kwargs)
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
return self.build_error(-32603, f"Invalid params:\n{e}", req_id)
|
return self.build_error(-32603, f"Invalid params:\n{e}", req_id)
|
||||||
except ServerError as e:
|
except ServerError as e:
|
||||||
|
@ -209,16 +215,16 @@ class WebsocketManager:
|
||||||
self.rpc.remove_method(ws_method)
|
self.rpc.remove_method(ws_method)
|
||||||
|
|
||||||
def _generate_callback(self, endpoint):
|
def _generate_callback(self, endpoint):
|
||||||
async def func(**kwargs):
|
async def func(ws, **kwargs):
|
||||||
result = await self.server.make_request(
|
result = await self.server.make_request(
|
||||||
WebRequest(endpoint, kwargs))
|
WebRequest(endpoint, kwargs, conn=ws))
|
||||||
return result
|
return result
|
||||||
return func
|
return func
|
||||||
|
|
||||||
def _generate_local_callback(self, endpoint, request_method, callback):
|
def _generate_local_callback(self, endpoint, request_method, callback):
|
||||||
async def func(**kwargs):
|
async def func(ws, **kwargs):
|
||||||
result = await callback(
|
result = await callback(
|
||||||
WebRequest(endpoint, kwargs, request_method))
|
WebRequest(endpoint, kwargs, request_method, ws))
|
||||||
return result
|
return result
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
@ -274,7 +280,7 @@ class WebSocket(WebSocketHandler):
|
||||||
|
|
||||||
async def _process_message(self, message):
|
async def _process_message(self, message):
|
||||||
try:
|
try:
|
||||||
response = await self.rpc.dispatch(message)
|
response = await self.rpc.dispatch(message, self)
|
||||||
if response is not None:
|
if response is not None:
|
||||||
self.write_message(response)
|
self.write_message(response)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
Loading…
Reference in New Issue