websockets: Store IP Address in WebRequest object

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Arksine 2021-04-12 17:07:16 -04:00
parent 24e6fded91
commit 43a8d25619
2 changed files with 13 additions and 5 deletions

View File

@ -430,11 +430,12 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
async def _do_local_request(self, args, conn): async def _do_local_request(self, args, conn):
return await self.callback( return await self.callback(
WebRequest(self.request.path, args, self.request.method, WebRequest(self.request.path, args, self.request.method,
conn=conn)) conn=conn, ip_addr=self.request.remote_ip))
async def _do_remote_request(self, args, conn): async def _do_remote_request(self, args, conn):
return await self.server.make_request( return await self.server.make_request(
WebRequest(self.callback, args, conn=conn)) WebRequest(self.callback, args, conn=conn,
ip_addr=self.request.remote_ip))
async def _process_http_request(self): async def _process_http_request(self):
if self.request.method not in self.methods: if self.request.method not in self.methods:

View File

@ -15,11 +15,13 @@ class Sentinel:
pass pass
class WebRequest: class WebRequest:
def __init__(self, endpoint, args, action="", conn=None): def __init__(self, endpoint, args, action="",
conn=None, ip_addr=""):
self.endpoint = endpoint self.endpoint = endpoint
self.action = action self.action = action
self.args = args self.args = args
self.conn = conn self.conn = conn
self.ip_addr = ip_addr
def get_endpoint(self): def get_endpoint(self):
return self.endpoint return self.endpoint
@ -33,6 +35,9 @@ class WebRequest:
def get_connection(self): def get_connection(self):
return self.conn return self.conn
def get_ip_address(self):
return self.ip_addr
def _get_converted_arg(self, key, default=Sentinel, dtype=str): def _get_converted_arg(self, key, default=Sentinel, dtype=str):
if key not in self.args: if key not in self.args:
if default == Sentinel: if default == Sentinel:
@ -202,14 +207,15 @@ class WebsocketManager:
def _generate_callback(self, endpoint): def _generate_callback(self, endpoint):
async def func(ws, **kwargs): async def func(ws, **kwargs):
result = await self.server.make_request( result = await self.server.make_request(
WebRequest(endpoint, kwargs, conn=ws)) WebRequest(endpoint, kwargs, conn=ws, ip_addr=ws.ip_addr))
return result return result
return func return func
def _generate_local_callback(self, endpoint, request_method, callback): def _generate_local_callback(self, endpoint, request_method, callback):
async def func(ws, **kwargs): async def func(ws, **kwargs):
result = await callback( result = await callback(
WebRequest(endpoint, kwargs, request_method, ws)) WebRequest(endpoint, kwargs, request_method,
ws, ip_addr=ws.ip_addr))
return result return result
return func return func
@ -264,6 +270,7 @@ class WebSocket(WebSocketHandler):
self.rpc = self.wsm.rpc self.rpc = self.wsm.rpc
self.uid = id(self) self.uid = id(self)
self.is_closed = False self.is_closed = False
self.ip_addr = self.request.remote_ip
async def open(self): async def open(self):
await self.wsm.add_websocket(self) await self.wsm.add_websocket(self)