websockets: create a client base class
Separate out code that applies to both standard websockets and the future unix socket implementation. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
5a3b1b6e5c
commit
f089794adc
|
@ -501,7 +501,9 @@ class AuthorizedRequestHandler(tornado.web.RequestHandler):
|
||||||
else:
|
else:
|
||||||
wsm: WebsocketManager = self.server.lookup_component(
|
wsm: WebsocketManager = self.server.lookup_component(
|
||||||
"websockets")
|
"websockets")
|
||||||
conn = wsm.get_websocket(conn_id)
|
conn = wsm.get_client(conn_id)
|
||||||
|
if not isinstance(conn, WebSocket):
|
||||||
|
return None
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
def write_error(self, status_code: int, **kwargs) -> None:
|
def write_error(self, status_code: int, **kwargs) -> None:
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
#
|
#
|
||||||
# This file may be distributed under the terms of the GNU GPLv3 license.
|
# This file may be distributed under the terms of the GNU GPLv3 license.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from websockets import WebSocket
|
from websockets import BaseSocketClient
|
||||||
|
|
||||||
|
|
||||||
# Annotation imports
|
# Annotation imports
|
||||||
|
@ -24,7 +24,7 @@ if TYPE_CHECKING:
|
||||||
class ExtensionManager:
|
class ExtensionManager:
|
||||||
def __init__(self, config: ConfigHelper) -> None:
|
def __init__(self, config: ConfigHelper) -> None:
|
||||||
self.server = config.get_server()
|
self.server = config.get_server()
|
||||||
self.agents: Dict[str, WebSocket] = {}
|
self.agents: Dict[str, BaseSocketClient] = {}
|
||||||
self.server.register_endpoint(
|
self.server.register_endpoint(
|
||||||
"/connection/send_event", ["POST"], self._handle_agent_event,
|
"/connection/send_event", ["POST"], self._handle_agent_event,
|
||||||
transports=["websocket"]
|
transports=["websocket"]
|
||||||
|
@ -36,7 +36,7 @@ class ExtensionManager:
|
||||||
"/server/extensions/request", ["POST"], self._handle_call_agent
|
"/server/extensions/request", ["POST"], self._handle_call_agent
|
||||||
)
|
)
|
||||||
|
|
||||||
def register_agent(self, connection: WebSocket) -> None:
|
def register_agent(self, connection: BaseSocketClient) -> None:
|
||||||
data = connection.client_data
|
data = connection.client_data
|
||||||
name = data["name"]
|
name = data["name"]
|
||||||
client_type = data["type"]
|
client_type = data["type"]
|
||||||
|
@ -55,7 +55,7 @@ class ExtensionManager:
|
||||||
}
|
}
|
||||||
connection.send_notification("agent_event", [evt])
|
connection.send_notification("agent_event", [evt])
|
||||||
|
|
||||||
def remove_agent(self, connection: WebSocket) -> None:
|
def remove_agent(self, connection: BaseSocketClient) -> None:
|
||||||
name = connection.client_data["name"]
|
name = connection.client_data["name"]
|
||||||
if name in self.agents:
|
if name in self.agents:
|
||||||
del self.agents[name]
|
del self.agents[name]
|
||||||
|
@ -64,7 +64,7 @@ class ExtensionManager:
|
||||||
|
|
||||||
async def _handle_agent_event(self, web_request: WebRequest) -> str:
|
async def _handle_agent_event(self, web_request: WebRequest) -> str:
|
||||||
conn = web_request.get_connection()
|
conn = web_request.get_connection()
|
||||||
if not isinstance(conn, WebSocket):
|
if not isinstance(conn, BaseSocketClient):
|
||||||
raise self.server.error("No connection detected")
|
raise self.server.error("No connection detected")
|
||||||
if conn.client_data["type"] != "agent":
|
if conn.client_data["type"] != "agent":
|
||||||
raise self.server.error(
|
raise self.server.error(
|
||||||
|
|
|
@ -33,7 +33,7 @@ from typing import (
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app import InternalTransport
|
from app import InternalTransport
|
||||||
from confighelper import ConfigHelper
|
from confighelper import ConfigHelper
|
||||||
from websockets import WebsocketManager, WebSocket
|
from websockets import WebsocketManager, BaseSocketClient
|
||||||
from tornado.websocket import WebSocketClientConnection
|
from tornado.websocket import WebSocketClientConnection
|
||||||
from components.database import MoonrakerDatabase
|
from components.database import MoonrakerDatabase
|
||||||
from components.klippy_apis import KlippyAPI
|
from components.klippy_apis import KlippyAPI
|
||||||
|
@ -183,11 +183,9 @@ class SimplyPrint(Subscribable):
|
||||||
"proc_stats:cpu_throttled", self._on_cpu_throttled
|
"proc_stats:cpu_throttled", self._on_cpu_throttled
|
||||||
)
|
)
|
||||||
self.server.register_event_handler(
|
self.server.register_event_handler(
|
||||||
"websockets:websocket_identified",
|
"websockets:client_identified", self._on_websocket_identified)
|
||||||
self._on_websocket_identified)
|
|
||||||
self.server.register_event_handler(
|
self.server.register_event_handler(
|
||||||
"websockets:websocket_removed",
|
"websockets:client_removed", self._on_websocket_removed)
|
||||||
self._on_websocket_removed)
|
|
||||||
self.server.register_event_handler(
|
self.server.register_event_handler(
|
||||||
"server:gcode_response", self._on_gcode_response)
|
"server:gcode_response", self._on_gcode_response)
|
||||||
self.server.register_event_handler(
|
self.server.register_event_handler(
|
||||||
|
@ -614,7 +612,7 @@ class SimplyPrint(Subscribable):
|
||||||
is_on = device_info["status"] == "on"
|
is_on = device_info["status"] == "on"
|
||||||
self.send_sp("power_controller", {"on": is_on})
|
self.send_sp("power_controller", {"on": is_on})
|
||||||
|
|
||||||
def _on_websocket_identified(self, ws: WebSocket) -> None:
|
def _on_websocket_identified(self, ws: BaseSocketClient) -> None:
|
||||||
if (
|
if (
|
||||||
self.cache.current_wsid is None and
|
self.cache.current_wsid is None and
|
||||||
ws.client_data.get("type", "") == "web"
|
ws.client_data.get("type", "") == "web"
|
||||||
|
@ -627,7 +625,7 @@ class SimplyPrint(Subscribable):
|
||||||
self.cache.current_wsid = ws.uid
|
self.cache.current_wsid = ws.uid
|
||||||
self.send_sp("machine_data", ui_data)
|
self.send_sp("machine_data", ui_data)
|
||||||
|
|
||||||
def _on_websocket_removed(self, ws: WebSocket) -> None:
|
def _on_websocket_removed(self, ws: BaseSocketClient) -> None:
|
||||||
if self.cache.current_wsid is None or self.cache.current_wsid != ws.uid:
|
if self.cache.current_wsid is None or self.cache.current_wsid != ws.uid:
|
||||||
return
|
return
|
||||||
ui_data = self._get_ui_info()
|
ui_data = self._get_ui_info()
|
||||||
|
@ -952,7 +950,7 @@ class SimplyPrint(Subscribable):
|
||||||
self.cache.current_wsid = None
|
self.cache.current_wsid = None
|
||||||
websockets: WebsocketManager
|
websockets: WebsocketManager
|
||||||
websockets = self.server.lookup_component("websockets")
|
websockets = self.server.lookup_component("websockets")
|
||||||
conns = websockets.get_websockets_by_type("web")
|
conns = websockets.get_clients_by_type("web")
|
||||||
if conns:
|
if conns:
|
||||||
longest = conns[0]
|
longest = conns[0]
|
||||||
ui_data["ui"] = longest.client_data["name"]
|
ui_data["ui"] = longest.client_data["name"]
|
||||||
|
|
|
@ -164,7 +164,7 @@ class JsonRPC:
|
||||||
|
|
||||||
async def dispatch(self,
|
async def dispatch(self,
|
||||||
data: str,
|
data: str,
|
||||||
conn: Optional[WebSocket] = None
|
conn: Optional[BaseSocketClient] = None
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
response: Any = None
|
response: Any = None
|
||||||
try:
|
try:
|
||||||
|
@ -192,7 +192,7 @@ class JsonRPC:
|
||||||
|
|
||||||
async def process_object(self,
|
async def process_object(self,
|
||||||
obj: Dict[str, Any],
|
obj: Dict[str, Any],
|
||||||
conn: Optional[WebSocket]
|
conn: Optional[BaseSocketClient]
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
req_id: Optional[int] = obj.get('id', None)
|
req_id: Optional[int] = obj.get('id', None)
|
||||||
rpc_version: str = obj.get('jsonrpc', "")
|
rpc_version: str = obj.get('jsonrpc', "")
|
||||||
|
@ -217,7 +217,7 @@ class JsonRPC:
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def process_response(
|
def process_response(
|
||||||
self, obj: Dict[str, Any], conn: Optional[WebSocket]
|
self, obj: Dict[str, Any], conn: Optional[BaseSocketClient]
|
||||||
) -> None:
|
) -> None:
|
||||||
if conn is None:
|
if conn is None:
|
||||||
logging.debug(f"RPC Response to non-socket request: {obj}")
|
logging.debug(f"RPC Response to non-socket request: {obj}")
|
||||||
|
@ -244,7 +244,7 @@ class JsonRPC:
|
||||||
async def execute_method(self,
|
async def execute_method(self,
|
||||||
callback: RPCCallback,
|
callback: RPCCallback,
|
||||||
req_id: Optional[int],
|
req_id: Optional[int],
|
||||||
conn: Optional[WebSocket],
|
conn: Optional[BaseSocketClient],
|
||||||
params: Dict[str, Any]
|
params: Dict[str, Any]
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
if conn is not None:
|
if conn is not None:
|
||||||
|
@ -302,7 +302,7 @@ class WebsocketManager(APITransport):
|
||||||
def __init__(self, server: Server) -> None:
|
def __init__(self, server: Server) -> None:
|
||||||
self.server = server
|
self.server = server
|
||||||
self.klippy: Klippy = server.lookup_component("klippy_connection")
|
self.klippy: Klippy = server.lookup_component("klippy_connection")
|
||||||
self.websockets: Dict[int, WebSocket] = {}
|
self.clients: Dict[int, BaseSocketClient] = {}
|
||||||
self.rpc = JsonRPC()
|
self.rpc = JsonRPC()
|
||||||
self.closed_event: Optional[asyncio.Event] = None
|
self.closed_event: Optional[asyncio.Event] = None
|
||||||
|
|
||||||
|
@ -318,7 +318,7 @@ class WebsocketManager(APITransport):
|
||||||
notify_name = event_name.split(':')[-1]
|
notify_name = event_name.split(':')[-1]
|
||||||
|
|
||||||
def notify_handler(*args):
|
def notify_handler(*args):
|
||||||
self.notify_websockets(notify_name, args)
|
self.notify_clients(notify_name, args)
|
||||||
self.server.register_event_handler(
|
self.server.register_event_handler(
|
||||||
event_name, notify_handler)
|
event_name, notify_handler)
|
||||||
|
|
||||||
|
@ -345,10 +345,10 @@ class WebsocketManager(APITransport):
|
||||||
|
|
||||||
def _generate_callback(self, endpoint: str) -> RPCCallback:
|
def _generate_callback(self, endpoint: str) -> RPCCallback:
|
||||||
async def func(args: Dict[str, Any]) -> Any:
|
async def func(args: Dict[str, Any]) -> Any:
|
||||||
ws: WebSocket = args.pop("_socket_")
|
sc: BaseSocketClient = args.pop("_socket_")
|
||||||
result = await self.klippy.request(
|
result = await self.klippy.request(
|
||||||
WebRequest(endpoint, args, conn=ws, ip_addr=ws.ip_addr,
|
WebRequest(endpoint, args, conn=sc, ip_addr=sc.ip_addr,
|
||||||
user=ws.current_user))
|
user=sc.user_info))
|
||||||
return result
|
return result
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
@ -358,22 +358,22 @@ class WebsocketManager(APITransport):
|
||||||
callback: Callable[[WebRequest], Coroutine]
|
callback: Callable[[WebRequest], Coroutine]
|
||||||
) -> RPCCallback:
|
) -> RPCCallback:
|
||||||
async def func(args: Dict[str, Any]) -> Any:
|
async def func(args: Dict[str, Any]) -> Any:
|
||||||
ws: WebSocket = args.pop("_socket_")
|
sc: BaseSocketClient = args.pop("_socket_")
|
||||||
result = await callback(
|
result = await callback(
|
||||||
WebRequest(endpoint, args, request_method, ws,
|
WebRequest(endpoint, args, request_method, sc,
|
||||||
ip_addr=ws.ip_addr, user=ws.current_user))
|
ip_addr=sc.ip_addr, user=sc.user_info))
|
||||||
return result
|
return result
|
||||||
return func
|
return func
|
||||||
|
|
||||||
async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]:
|
async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]:
|
||||||
ws: WebSocket = args["_socket_"]
|
sc: BaseSocketClient = args["_socket_"]
|
||||||
return {'websocket_id': ws.uid}
|
return {'websocket_id': sc.uid}
|
||||||
|
|
||||||
async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]:
|
async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]:
|
||||||
ws: WebSocket = args["_socket_"]
|
sc: BaseSocketClient = args["_socket_"]
|
||||||
if ws.identified:
|
if sc.identified:
|
||||||
raise self.server.error(
|
raise self.server.error(
|
||||||
f"Connection already identified: {ws.client_data}"
|
f"Connection already identified: {sc.client_data}"
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
name = str(args["client_name"])
|
name = str(args["client_name"])
|
||||||
|
@ -387,7 +387,7 @@ class WebsocketManager(APITransport):
|
||||||
) from None
|
) from None
|
||||||
if client_type not in CLIENT_TYPES:
|
if client_type not in CLIENT_TYPES:
|
||||||
raise self.server.error(f"Invalid Client Type: {client_type}")
|
raise self.server.error(f"Invalid Client Type: {client_type}")
|
||||||
ws.client_data = {
|
sc.client_data = {
|
||||||
"name": name,
|
"name": name,
|
||||||
"version": version,
|
"version": version,
|
||||||
"type": client_type,
|
"type": client_type,
|
||||||
|
@ -397,103 +397,108 @@ class WebsocketManager(APITransport):
|
||||||
extensions: ExtensionManager
|
extensions: ExtensionManager
|
||||||
extensions = self.server.lookup_component("extensions")
|
extensions = self.server.lookup_component("extensions")
|
||||||
try:
|
try:
|
||||||
extensions.register_agent(ws)
|
extensions.register_agent(sc)
|
||||||
except ServerError:
|
except ServerError:
|
||||||
ws.client_data["type"] = ""
|
sc.client_data["type"] = ""
|
||||||
raise
|
raise
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Websocket {ws.uid} Client Identified - "
|
f"Websocket {sc.uid} Client Identified - "
|
||||||
f"Name: {name}, Version: {version}, Type: {client_type}"
|
f"Name: {name}, Version: {version}, Type: {client_type}"
|
||||||
)
|
)
|
||||||
self.server.send_event("websockets:websocket_identified", ws)
|
self.server.send_event("websockets:client_identified", sc)
|
||||||
return {'connection_id': ws.uid}
|
return {'connection_id': sc.uid}
|
||||||
|
|
||||||
def has_websocket(self, ws_id: int) -> bool:
|
def has_socket(self, ws_id: int) -> bool:
|
||||||
return ws_id in self.websockets
|
return ws_id in self.clients
|
||||||
|
|
||||||
def get_websocket(self, ws_id: int) -> Optional[WebSocket]:
|
def get_client(self, ws_id: int) -> Optional[BaseSocketClient]:
|
||||||
return self.websockets.get(ws_id, None)
|
sc = self.clients.get(ws_id, None)
|
||||||
|
if sc is None or not isinstance(sc, WebSocket):
|
||||||
|
return None
|
||||||
|
return sc
|
||||||
|
|
||||||
def get_websockets_by_type(self, client_type: str) -> List[WebSocket]:
|
def get_clients_by_type(
|
||||||
|
self, client_type: str
|
||||||
|
) -> List[BaseSocketClient]:
|
||||||
if not client_type:
|
if not client_type:
|
||||||
return []
|
return []
|
||||||
ret: List[WebSocket] = []
|
ret: List[BaseSocketClient] = []
|
||||||
for ws in self.websockets.values():
|
for sc in self.clients.values():
|
||||||
if ws.client_data.get("type", "") == client_type.lower():
|
if sc.client_data.get("type", "") == client_type.lower():
|
||||||
ret.append(ws)
|
ret.append(sc)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def get_websockets_by_name(self, name: str) -> List[WebSocket]:
|
def get_clients_by_name(self, name: str) -> List[BaseSocketClient]:
|
||||||
if not name:
|
if not name:
|
||||||
return []
|
return []
|
||||||
ret: List[WebSocket] = []
|
ret: List[BaseSocketClient] = []
|
||||||
for ws in self.websockets.values():
|
for sc in self.clients.values():
|
||||||
if ws.client_data.get("name", "").lower() == name.lower():
|
if sc.client_data.get("name", "").lower() == name.lower():
|
||||||
ret.append(ws)
|
ret.append(sc)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def get_unidentified_websockets(self) -> List[WebSocket]:
|
def get_unidentified_clients(self) -> List[BaseSocketClient]:
|
||||||
ret: List[WebSocket] = []
|
ret: List[BaseSocketClient] = []
|
||||||
for ws in self.websockets.values():
|
for sc in self.clients.values():
|
||||||
if not ws.client_data:
|
if not sc.client_data:
|
||||||
ret.append(ws)
|
ret.append(sc)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def add_websocket(self, ws: WebSocket) -> None:
|
def add_client(self, sc: BaseSocketClient) -> None:
|
||||||
self.websockets[ws.uid] = ws
|
self.clients[sc.uid] = sc
|
||||||
self.server.send_event("websockets:websocked_added", ws)
|
self.server.send_event("websockets:client_added", sc)
|
||||||
logging.debug(f"New Websocket Added: {ws.uid}")
|
logging.debug(f"New Websocket Added: {sc.uid}")
|
||||||
|
|
||||||
def remove_websocket(self, ws: WebSocket) -> None:
|
def remove_client(self, sc: BaseSocketClient) -> None:
|
||||||
old_ws = self.websockets.pop(ws.uid, None)
|
old_sc = self.clients.pop(sc.uid, None)
|
||||||
if old_ws is not None:
|
if old_sc is not None:
|
||||||
self.klippy.remove_subscription(old_ws)
|
self.klippy.remove_subscription(old_sc)
|
||||||
self.server.send_event("websockets:websocket_removed", ws)
|
self.server.send_event("websockets:client_removed", sc)
|
||||||
logging.debug(f"Websocket Removed: {ws.uid}")
|
logging.debug(f"Websocket Removed: {sc.uid}")
|
||||||
if self.closed_event is not None and not self.websockets:
|
if self.closed_event is not None and not self.clients:
|
||||||
self.closed_event.set()
|
self.closed_event.set()
|
||||||
|
|
||||||
def notify_websockets(self,
|
def notify_clients(
|
||||||
name: str,
|
self,
|
||||||
data: Union[List, Tuple] = [],
|
name: str,
|
||||||
mask: List[int] = []
|
data: Union[List, Tuple] = [],
|
||||||
) -> None:
|
mask: List[int] = []
|
||||||
|
) -> None:
|
||||||
msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name}
|
msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name}
|
||||||
if data:
|
if data:
|
||||||
msg['params'] = data
|
msg['params'] = data
|
||||||
for ws in list(self.websockets.values()):
|
for sc in list(self.clients.values()):
|
||||||
if ws.uid in mask:
|
if sc.uid in mask:
|
||||||
continue
|
continue
|
||||||
ws.queue_message(msg)
|
sc.queue_message(msg)
|
||||||
|
|
||||||
def get_count(self) -> int:
|
def get_count(self) -> int:
|
||||||
return len(self.websockets)
|
return len(self.clients)
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
if not self.websockets:
|
if not self.clients:
|
||||||
return
|
return
|
||||||
self.closed_event = asyncio.Event()
|
self.closed_event = asyncio.Event()
|
||||||
for ws in list(self.websockets.values()):
|
for sc in list(self.clients.values()):
|
||||||
ws.close(1001, "Server Shutdown")
|
sc.close_socket(1001, "Server Shutdown")
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self.closed_event.wait(), 2.)
|
await asyncio.wait_for(self.closed_event.wait(), 2.)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
pass
|
pass
|
||||||
self.closed_event = None
|
self.closed_event = None
|
||||||
|
|
||||||
class WebSocket(WebSocketHandler, Subscribable):
|
class BaseSocketClient(Subscribable):
|
||||||
def initialize(self) -> None:
|
def on_create(self, server: Server) -> None:
|
||||||
self.server: Server = self.settings['server']
|
self.server = server
|
||||||
self.event_loop = self.server.get_event_loop()
|
self.eventloop = server.get_event_loop()
|
||||||
self.wsm: WebsocketManager = self.server.lookup_component("websockets")
|
self.wsm: WebsocketManager = self.server.lookup_component("websockets")
|
||||||
self.rpc = self.wsm.rpc
|
self.rpc = self.wsm.rpc
|
||||||
self._uid = id(self)
|
self._uid = id(self)
|
||||||
|
self.ip_addr = ""
|
||||||
self.is_closed: bool = False
|
self.is_closed: bool = False
|
||||||
self.ip_addr: str = self.request.remote_ip or ""
|
|
||||||
self.queue_busy: bool = False
|
self.queue_busy: bool = False
|
||||||
self.pending_responses: Dict[int, asyncio.Future] = {}
|
self.pending_responses: Dict[int, asyncio.Future] = {}
|
||||||
self.message_buf: List[Union[str, Dict[str, Any]]] = []
|
self.message_buf: List[Union[str, Dict[str, Any]]] = []
|
||||||
self.last_pong_time: float = self.event_loop.get_loop_time()
|
|
||||||
self._connected_time: float = 0.
|
self._connected_time: float = 0.
|
||||||
self._identified: bool = False
|
self._identified: bool = False
|
||||||
self._client_data: Dict[str, str] = {
|
self._client_data: Dict[str, str] = {
|
||||||
|
@ -503,13 +508,17 @@ class WebSocket(WebSocketHandler, Subscribable):
|
||||||
"url": ""
|
"url": ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user_info(self) -> Optional[Dict[str, Any]]:
|
||||||
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def uid(self) -> int:
|
def uid(self) -> int:
|
||||||
return self._uid
|
return self._uid
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def hostname(self) -> str:
|
def hostname(self) -> str:
|
||||||
return self.request.host_name
|
return ""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def start_time(self) -> float:
|
def start_time(self) -> float:
|
||||||
|
@ -528,28 +537,6 @@ class WebSocket(WebSocketHandler, Subscribable):
|
||||||
self._client_data = data
|
self._client_data = data
|
||||||
self._identified = True
|
self._identified = True
|
||||||
|
|
||||||
def open(self, *args, **kwargs) -> None:
|
|
||||||
self.set_nodelay(True)
|
|
||||||
self._connected_time = self.event_loop.get_loop_time()
|
|
||||||
agent = self.request.headers.get("User-Agent", "")
|
|
||||||
is_proxy = False
|
|
||||||
if (
|
|
||||||
"X-Forwarded-For" in self.request.headers or
|
|
||||||
"X-Real-Ip" in self.request.headers
|
|
||||||
):
|
|
||||||
is_proxy = True
|
|
||||||
logging.info(f"Websocket Opened: ID: {self.uid}, "
|
|
||||||
f"Proxied: {is_proxy}, "
|
|
||||||
f"User Agent: {agent}, "
|
|
||||||
f"Host Name: {self.hostname}")
|
|
||||||
self.wsm.add_websocket(self)
|
|
||||||
|
|
||||||
def on_message(self, message: Union[bytes, str]) -> None:
|
|
||||||
self.event_loop.register_callback(self._process_message, message)
|
|
||||||
|
|
||||||
def on_pong(self, data: bytes) -> None:
|
|
||||||
self.last_pong_time = self.event_loop.get_loop_time()
|
|
||||||
|
|
||||||
async def _process_message(self, message: str) -> None:
|
async def _process_message(self, message: str) -> None:
|
||||||
try:
|
try:
|
||||||
response = await self.rpc.dispatch(message, self)
|
response = await self.rpc.dispatch(message, self)
|
||||||
|
@ -563,27 +550,23 @@ class WebSocket(WebSocketHandler, Subscribable):
|
||||||
if self.queue_busy:
|
if self.queue_busy:
|
||||||
return
|
return
|
||||||
self.queue_busy = True
|
self.queue_busy = True
|
||||||
self.event_loop.register_callback(self._process_messages)
|
self.eventloop.register_callback(self._write_messages)
|
||||||
|
|
||||||
async def _process_messages(self):
|
async def _write_messages(self):
|
||||||
if self.is_closed:
|
if self.is_closed:
|
||||||
self.message_buf = []
|
self.message_buf = []
|
||||||
self.queue_busy = False
|
self.queue_busy = False
|
||||||
return
|
return
|
||||||
while self.message_buf:
|
while self.message_buf:
|
||||||
msg = self.message_buf.pop(0)
|
msg = self.message_buf.pop(0)
|
||||||
try:
|
await self.write_to_socket(msg)
|
||||||
await self.write_message(msg)
|
|
||||||
except WebSocketClosedError:
|
|
||||||
self.is_closed = True
|
|
||||||
logging.info(
|
|
||||||
f"Websocket closed while writing: {self.uid}")
|
|
||||||
break
|
|
||||||
except Exception:
|
|
||||||
logging.exception(
|
|
||||||
f"Error sending data over websocket: {self.uid}")
|
|
||||||
self.queue_busy = False
|
self.queue_busy = False
|
||||||
|
|
||||||
|
async def write_to_socket(
|
||||||
|
self, message: Union[str, Dict[str, Any]]
|
||||||
|
) -> None:
|
||||||
|
raise NotImplementedError("Children must implement write_to_socket")
|
||||||
|
|
||||||
def send_status(self,
|
def send_status(self,
|
||||||
status: Dict[str, Any],
|
status: Dict[str, Any],
|
||||||
eventtime: float
|
eventtime: float
|
||||||
|
@ -600,7 +583,7 @@ class WebSocket(WebSocketHandler, Subscribable):
|
||||||
method: str,
|
method: str,
|
||||||
params: Optional[Union[List, Dict[str, Any]]] = None
|
params: Optional[Union[List, Dict[str, Any]]] = None
|
||||||
) -> Awaitable:
|
) -> Awaitable:
|
||||||
fut = self.event_loop.create_future()
|
fut = self.eventloop.create_future()
|
||||||
msg = {
|
msg = {
|
||||||
'jsonrpc': "2.0",
|
'jsonrpc': "2.0",
|
||||||
'method': method,
|
'method': method,
|
||||||
|
@ -613,7 +596,7 @@ class WebSocket(WebSocketHandler, Subscribable):
|
||||||
return fut
|
return fut
|
||||||
|
|
||||||
def send_notification(self, name: str, data: List) -> None:
|
def send_notification(self, name: str, data: List) -> None:
|
||||||
self.wsm.notify_websockets(name, data, [self._uid])
|
self.wsm.notify_clients(name, data, [self._uid])
|
||||||
|
|
||||||
def resolve_pending_response(
|
def resolve_pending_response(
|
||||||
self, response_id: int, result: Any
|
self, response_id: int, result: Any
|
||||||
|
@ -627,10 +610,49 @@ class WebSocket(WebSocketHandler, Subscribable):
|
||||||
fut.set_result(result)
|
fut.set_result(result)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def close_socket(self, code: int, reason: str) -> None:
|
||||||
|
raise NotImplementedError("Children must implement close_socket()")
|
||||||
|
|
||||||
|
class WebSocket(WebSocketHandler, BaseSocketClient):
|
||||||
|
def initialize(self) -> None:
|
||||||
|
self.on_create(self.settings['server'])
|
||||||
|
self.ip_addr: str = self.request.remote_ip or ""
|
||||||
|
self.last_pong_time: float = self.eventloop.get_loop_time()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user_info(self) -> Optional[Dict[str, Any]]:
|
||||||
|
return self.current_user
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hostname(self) -> str:
|
||||||
|
return self.request.host_name
|
||||||
|
|
||||||
|
def open(self, *args, **kwargs) -> None:
|
||||||
|
self.set_nodelay(True)
|
||||||
|
self._connected_time = self.eventloop.get_loop_time()
|
||||||
|
agent = self.request.headers.get("User-Agent", "")
|
||||||
|
is_proxy = False
|
||||||
|
if (
|
||||||
|
"X-Forwarded-For" in self.request.headers or
|
||||||
|
"X-Real-Ip" in self.request.headers
|
||||||
|
):
|
||||||
|
is_proxy = True
|
||||||
|
logging.info(f"Websocket Opened: ID: {self.uid}, "
|
||||||
|
f"Proxied: {is_proxy}, "
|
||||||
|
f"User Agent: {agent}, "
|
||||||
|
f"Host Name: {self.hostname}")
|
||||||
|
self.wsm.add_client(self)
|
||||||
|
|
||||||
|
def on_message(self, message: Union[bytes, str]) -> None:
|
||||||
|
self.eventloop.register_callback(self._process_message, message)
|
||||||
|
|
||||||
|
def on_pong(self, data: bytes) -> None:
|
||||||
|
self.last_pong_time = self.eventloop.get_loop_time()
|
||||||
|
|
||||||
def on_close(self) -> None:
|
def on_close(self) -> None:
|
||||||
self.is_closed = True
|
self.is_closed = True
|
||||||
self.message_buf = []
|
self.message_buf = []
|
||||||
now = self.event_loop.get_loop_time()
|
now = self.eventloop.get_loop_time()
|
||||||
pong_elapsed = now - self.last_pong_time
|
pong_elapsed = now - self.last_pong_time
|
||||||
for resp in self.pending_responses.values():
|
for resp in self.pending_responses.values():
|
||||||
resp.set_exception(ServerError("Client Socket Disconnected", 500))
|
resp.set_exception(ServerError("Client Socket Disconnected", 500))
|
||||||
|
@ -643,7 +665,21 @@ class WebSocket(WebSocketHandler, Subscribable):
|
||||||
extensions: ExtensionManager
|
extensions: ExtensionManager
|
||||||
extensions = self.server.lookup_component("extensions")
|
extensions = self.server.lookup_component("extensions")
|
||||||
extensions.remove_agent(self)
|
extensions.remove_agent(self)
|
||||||
self.wsm.remove_websocket(self)
|
self.wsm.remove_client(self)
|
||||||
|
|
||||||
|
async def write_to_socket(
|
||||||
|
self, message: Union[str, Dict[str, Any]]
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
await self.write_message(message)
|
||||||
|
except WebSocketClosedError:
|
||||||
|
self.is_closed = True
|
||||||
|
self.message_buf.clear()
|
||||||
|
logging.info(
|
||||||
|
f"Websocket closed while writing: {self.uid}")
|
||||||
|
except Exception:
|
||||||
|
logging.exception(
|
||||||
|
f"Error sending data over websocket: {self.uid}")
|
||||||
|
|
||||||
def check_origin(self, origin: str) -> bool:
|
def check_origin(self, origin: str) -> bool:
|
||||||
if not super(WebSocket, self).check_origin(origin):
|
if not super(WebSocket, self).check_origin(origin):
|
||||||
|
@ -658,3 +694,6 @@ class WebSocket(WebSocketHandler, Subscribable):
|
||||||
auth: AuthComp = self.server.lookup_component('authorization', None)
|
auth: AuthComp = self.server.lookup_component('authorization', None)
|
||||||
if auth is not None:
|
if auth is not None:
|
||||||
self.current_user = auth.check_authorized(self.request)
|
self.current_user = auth.check_authorized(self.request)
|
||||||
|
|
||||||
|
def close_socket(self, code: int, reason: str) -> None:
|
||||||
|
self.close(code, reason)
|
||||||
|
|
Loading…
Reference in New Issue