websockets: pass connection to WebRequest

This gives handlers direct access to a websocket client connection.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Arksine 2020-11-09 08:41:38 -05:00
parent de1575f757
commit 8d1239c316
1 changed files with 21 additions and 15 deletions

View File

@ -15,10 +15,11 @@ class Sentinel:
pass pass
class WebRequest: class WebRequest:
def __init__(self, endpoint, args, action=""): def __init__(self, endpoint, args, action="", conn=None):
self.endpoint = endpoint self.endpoint = endpoint
self.action = action self.action = action
self.args = args self.args = args
self.conn = conn
def get_endpoint(self): def get_endpoint(self):
return self.endpoint return self.endpoint
@ -29,6 +30,9 @@ class WebRequest:
def get_args(self): def get_args(self):
return self.args return self.args
def get_connection(self):
return self.conn
def _get_converted_arg(self, key, default=Sentinel, dtype=str): def _get_converted_arg(self, key, default=Sentinel, dtype=str):
if key not in self.args: if key not in self.args:
if default == Sentinel: if default == Sentinel:
@ -79,7 +83,7 @@ class JsonRPC:
def remove_method(self, name): def remove_method(self, name):
self.methods.pop(name, None) self.methods.pop(name, None)
async def dispatch(self, data): async def dispatch(self, data, ws):
response = None response = None
try: try:
request = json.loads(data) request = json.loads(data)
@ -92,19 +96,19 @@ class JsonRPC:
if isinstance(request, list): if isinstance(request, list):
response = [] response = []
for req in request: for req in request:
resp = await self.process_request(req) resp = await self.process_request(req, ws)
if resp is not None: if resp is not None:
response.append(resp) response.append(resp)
if not response: if not response:
response = None response = None
else: else:
response = await self.process_request(request) response = await self.process_request(request, ws)
if response is not None: if response is not None:
response = json.dumps(response) response = json.dumps(response)
logging.debug("Websocket Response::" + response) logging.debug("Websocket Response::" + response)
return response return response
async def process_request(self, request): async def process_request(self, request, ws):
req_id = request.get('id', None) req_id = request.get('id', None)
rpc_version = request.get('jsonrpc', "") rpc_version = request.get('jsonrpc', "")
method_name = request.get('method', None) method_name = request.get('method', None)
@ -116,18 +120,20 @@ class JsonRPC:
if 'params' in request: if 'params' in request:
params = request['params'] params = request['params']
if isinstance(params, list): if isinstance(params, list):
response = await self.execute_method(method, req_id, *params) response = await self.execute_method(
method, req_id, ws, *params)
elif isinstance(params, dict): elif isinstance(params, dict):
response = await self.execute_method(method, req_id, **params) response = await self.execute_method(
method, req_id, ws, **params)
else: else:
return self.build_error(-32600, "Invalid Request", req_id) return self.build_error(-32600, "Invalid Request", req_id)
else: else:
response = await self.execute_method(method, req_id) response = await self.execute_method(method, req_id, ws)
return response return response
async def execute_method(self, method, req_id, *args, **kwargs): async def execute_method(self, method, req_id, ws, *args, **kwargs):
try: try:
result = await method(*args, **kwargs) result = await method(ws, *args, **kwargs)
except TypeError as e: except TypeError as e:
return self.build_error(-32603, f"Invalid params:\n{e}", req_id) return self.build_error(-32603, f"Invalid params:\n{e}", req_id)
except ServerError as e: except ServerError as e:
@ -209,16 +215,16 @@ class WebsocketManager:
self.rpc.remove_method(ws_method) self.rpc.remove_method(ws_method)
def _generate_callback(self, endpoint): def _generate_callback(self, endpoint):
async def func(**kwargs): async def func(ws, **kwargs):
result = await self.server.make_request( result = await self.server.make_request(
WebRequest(endpoint, kwargs)) WebRequest(endpoint, kwargs, conn=ws))
return result return result
return func return func
def _generate_local_callback(self, endpoint, request_method, callback): def _generate_local_callback(self, endpoint, request_method, callback):
async def func(**kwargs): async def func(ws, **kwargs):
result = await callback( result = await callback(
WebRequest(endpoint, kwargs, request_method)) WebRequest(endpoint, kwargs, request_method, ws))
return result return result
return func return func
@ -274,7 +280,7 @@ class WebSocket(WebSocketHandler):
async def _process_message(self, message): async def _process_message(self, message):
try: try:
response = await self.rpc.dispatch(message) response = await self.rpc.dispatch(message, self)
if response is not None: if response is not None:
self.write_message(response) self.write_message(response)
except Exception: except Exception: