websockets: pass connection to WebRequest

This gives handlers direct access to a websocket client connection.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Arksine 2020-11-09 08:41:38 -05:00
parent de1575f757
commit 8d1239c316
1 changed files with 21 additions and 15 deletions

View File

@ -15,10 +15,11 @@ class Sentinel:
pass
class WebRequest:
def __init__(self, endpoint, args, action=""):
def __init__(self, endpoint, args, action="", conn=None):
self.endpoint = endpoint
self.action = action
self.args = args
self.conn = conn
def get_endpoint(self):
return self.endpoint
@ -29,6 +30,9 @@ class WebRequest:
def get_args(self):
return self.args
def get_connection(self):
return self.conn
def _get_converted_arg(self, key, default=Sentinel, dtype=str):
if key not in self.args:
if default == Sentinel:
@ -79,7 +83,7 @@ class JsonRPC:
def remove_method(self, name):
self.methods.pop(name, None)
async def dispatch(self, data):
async def dispatch(self, data, ws):
response = None
try:
request = json.loads(data)
@ -92,19 +96,19 @@ class JsonRPC:
if isinstance(request, list):
response = []
for req in request:
resp = await self.process_request(req)
resp = await self.process_request(req, ws)
if resp is not None:
response.append(resp)
if not response:
response = None
else:
response = await self.process_request(request)
response = await self.process_request(request, ws)
if response is not None:
response = json.dumps(response)
logging.debug("Websocket Response::" + response)
return response
async def process_request(self, request):
async def process_request(self, request, ws):
req_id = request.get('id', None)
rpc_version = request.get('jsonrpc', "")
method_name = request.get('method', None)
@ -116,18 +120,20 @@ class JsonRPC:
if 'params' in request:
params = request['params']
if isinstance(params, list):
response = await self.execute_method(method, req_id, *params)
response = await self.execute_method(
method, req_id, ws, *params)
elif isinstance(params, dict):
response = await self.execute_method(method, req_id, **params)
response = await self.execute_method(
method, req_id, ws, **params)
else:
return self.build_error(-32600, "Invalid Request", req_id)
else:
response = await self.execute_method(method, req_id)
response = await self.execute_method(method, req_id, ws)
return response
async def execute_method(self, method, req_id, *args, **kwargs):
async def execute_method(self, method, req_id, ws, *args, **kwargs):
try:
result = await method(*args, **kwargs)
result = await method(ws, *args, **kwargs)
except TypeError as e:
return self.build_error(-32603, f"Invalid params:\n{e}", req_id)
except ServerError as e:
@ -209,16 +215,16 @@ class WebsocketManager:
self.rpc.remove_method(ws_method)
def _generate_callback(self, endpoint):
async def func(**kwargs):
async def func(ws, **kwargs):
result = await self.server.make_request(
WebRequest(endpoint, kwargs))
WebRequest(endpoint, kwargs, conn=ws))
return result
return func
def _generate_local_callback(self, endpoint, request_method, callback):
async def func(**kwargs):
async def func(ws, **kwargs):
result = await callback(
WebRequest(endpoint, kwargs, request_method))
WebRequest(endpoint, kwargs, request_method, ws))
return result
return func
@ -274,7 +280,7 @@ class WebSocket(WebSocketHandler):
async def _process_message(self, message):
try:
response = await self.rpc.dispatch(message)
response = await self.rpc.dispatch(message, self)
if response is not None:
self.write_message(response)
except Exception: