utils: add support for msgspec with stdlib json fallback

Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2023-06-26 19:59:04 -04:00
parent 3ccf02c156
commit f99e5b0bea
No known key found for this signature in database
GPG Key ID: 5A1EB336DFB4C71B
23 changed files with 137 additions and 100 deletions

View File

@ -8,7 +8,6 @@ from __future__ import annotations
import os
import mimetypes
import logging
import json
import traceback
import ssl
import pathlib
@ -24,6 +23,7 @@ from tornado.http1connection import HTTP1Connection
from tornado.log import access_log
from .common import WebRequest, APIDefinition, APITransport
from .utils import ServerError, source_info
from .utils import json_wrapper as jsonw
from .websockets import (
WebsocketManager,
WebSocket,
@ -545,7 +545,8 @@ class AuthorizedRequestHandler(tornado.web.RequestHandler):
if 'exc_info' in kwargs:
err['traceback'] = "\n".join(
traceback.format_exception(*kwargs['exc_info']))
self.finish({'error': err})
self.set_header("Content-Type", "application/json; charset=UTF-8")
self.finish(jsonw.dumps({'error': err}))
# Due to the way Python treats multiple inheritance its best
# to create a separate authorized handler for serving files
@ -588,7 +589,8 @@ class AuthorizedFileHandler(tornado.web.StaticFileHandler):
if 'exc_info' in kwargs:
err['traceback'] = "\n".join(
traceback.format_exception(*kwargs['exc_info']))
self.finish({'error': err})
self.set_header("Content-Type", "application/json; charset=UTF-8")
self.finish(jsonw.dumps({'error': err}))
def _check_need_auth(self) -> bool:
if self.request.method != "GET":
@ -623,7 +625,7 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
type_funcs: Dict[str, Callable] = {
"int": int, "float": float,
"bool": lambda x: x.lower() == "true",
"json": json.loads}
"json": jsonw.loads}
if hint not in type_funcs:
logging.info(f"No conversion method for type hint {hint}")
return value
@ -672,8 +674,8 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
content_type = self.request.headers.get('Content-Type', "").strip()
if content_type.startswith("application/json"):
try:
args.update(json.loads(self.request.body))
except json.JSONDecodeError:
args.update(jsonw.loads(self.request.body))
except jsonw.JSONDecodeError:
pass
for key, value in self.path_kwargs.items():
if value is not None:
@ -738,11 +740,14 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
e.status_code, reason=str(e)) from e
if self.wrap_result:
result = {'result': result}
elif self.content_type is not None:
self.set_header("Content-Type", self.content_type)
self._log_debug(f"HTTP Response::{req}", result)
if result is None:
self.set_status(204)
self._log_debug(f"HTTP Response::{req}", result)
elif isinstance(result, dict):
self.set_header("Content-Type", "application/json; charset=UTF-8")
result = jsonw.dumps(result)
elif self.content_type is not None:
self.set_header("Content-Type", self.content_type)
self.finish(result)
class FileRequestHandler(AuthorizedFileHandler):
@ -768,7 +773,8 @@ class FileRequestHandler(AuthorizedFileHandler):
filename = await file_manager.delete_file(path)
except self.server.error as e:
raise tornado.web.HTTPError(e.status_code, str(e))
self.finish({'result': filename})
self.set_header("Content-Type", "application/json; charset=UTF-8")
self.finish(jsonw.dumps({'result': filename}))
async def get(self, path: str, include_body: bool = True) -> None:
# Set up our path instance variables.
@ -998,7 +1004,8 @@ class FileUploadHandler(AuthorizedRequestHandler):
self.set_header("Location", location)
logging.debug(f"Upload Location header set: {location}")
self.set_status(201)
self.finish(result)
self.set_header("Content-Type", "application/json; charset=UTF-8")
self.finish(jsonw.dumps(result))
# Default Handler for unregistered endpoints
class AuthorizedErrorHandler(AuthorizedRequestHandler):
@ -1015,15 +1022,16 @@ class AuthorizedErrorHandler(AuthorizedRequestHandler):
if 'exc_info' in kwargs:
err['traceback'] = "\n".join(
traceback.format_exception(*kwargs['exc_info']))
self.finish({'error': err})
self.set_header("Content-Type", "application/json; charset=UTF-8")
self.finish(jsonw.dumps({'error': err}))
class RedirectHandler(AuthorizedRequestHandler):
def get(self, *args, **kwargs) -> None:
url: Optional[str] = self.get_argument('url', None)
if url is None:
try:
body_args: Dict[str, Any] = json.loads(self.request.body)
except json.JSONDecodeError:
body_args: Dict[str, Any] = jsonw.loads(self.request.body)
except jsonw.JSONDecodeError:
body_args = {}
if 'url' not in body_args:
raise tornado.web.HTTPError(

View File

@ -8,8 +8,8 @@ from __future__ import annotations
import ipaddress
import logging
import copy
import json
from .utils import ServerError, Sentinel
from .utils import json_wrapper as jsonw
# Annotation imports
from typing import (
@ -83,7 +83,7 @@ class BaseRemoteConnection(Subscribable):
self.is_closed: bool = False
self.queue_busy: bool = False
self.pending_responses: Dict[int, Future] = {}
self.message_buf: List[Union[str, Dict[str, Any]]] = []
self.message_buf: List[Union[bytes, str]] = []
self._connected_time: float = 0.
self._identified: bool = False
self._client_data: Dict[str, str] = {
@ -141,7 +141,9 @@ class BaseRemoteConnection(Subscribable):
except Exception:
logging.exception("Websocket Command Error")
def queue_message(self, message: Union[str, Dict[str, Any]]):
def queue_message(self, message: Union[bytes, str, Dict[str, Any]]):
if isinstance(message, dict):
message = jsonw.dumps(message)
self.message_buf.append(message)
if self.queue_busy:
return
@ -190,9 +192,7 @@ class BaseRemoteConnection(Subscribable):
await self.write_to_socket(msg)
self.queue_busy = False
async def write_to_socket(
self, message: Union[str, Dict[str, Any]]
) -> None:
async def write_to_socket(self, message: Union[bytes, str]) -> None:
raise NotImplementedError("Children must implement write_to_socket")
def send_status(self,
@ -426,7 +426,7 @@ class JsonRPC:
for field in ["access_token", "api_key"]:
if field in params:
output["params"][field] = "<sanitized>"
logging.debug(f"{self.transport} Received::{json.dumps(output)}")
logging.debug(f"{self.transport} Received::{jsonw.dumps(output).decode()}")
def _log_response(self, resp_obj: Optional[Dict[str, Any]]) -> None:
if not self.verbose:
@ -438,7 +438,7 @@ class JsonRPC:
output = copy.deepcopy(resp_obj)
output["result"] = "<sanitized>"
self.sanitize_response = False
logging.debug(f"{self.transport} Response::{json.dumps(output)}")
logging.debug(f"{self.transport} Response::{jsonw.dumps(output).decode()}")
def register_method(self,
name: str,
@ -452,14 +452,14 @@ class JsonRPC:
async def dispatch(self,
data: str,
conn: Optional[BaseRemoteConnection] = None
) -> Optional[str]:
) -> Optional[bytes]:
try:
obj: Union[Dict[str, Any], List[dict]] = json.loads(data)
obj: Union[Dict[str, Any], List[dict]] = jsonw.loads(data)
except Exception:
msg = f"{self.transport} data not json: {data}"
logging.exception(msg)
err = self.build_error(-32700, "Parse error")
return json.dumps(err)
return jsonw.dumps(err)
if isinstance(obj, list):
responses: List[Dict[str, Any]] = []
for item in obj:
@ -469,13 +469,13 @@ class JsonRPC:
self._log_response(resp)
responses.append(resp)
if responses:
return json.dumps(responses)
return jsonw.dumps(responses)
else:
self._log_request(obj)
response = await self.process_object(obj, conn)
if response is not None:
self._log_response(response)
return json.dumps(response)
return jsonw.dumps(response)
return None
async def process_object(self,

View File

@ -17,9 +17,9 @@ import ipaddress
import re
import socket
import logging
import json
from tornado.web import HTTPError
from libnacl.sign import Signer, Verifier
from ..utils import json_wrapper as jsonw
# Annotation imports
from typing import (
@ -570,8 +570,8 @@ class Authorization:
}
header = {'kid': jwk_id}
header.update(JWT_HEADER)
jwt_header = base64url_encode(json.dumps(header).encode())
jwt_payload = base64url_encode(json.dumps(payload).encode())
jwt_header = base64url_encode(jsonw.dumps(header))
jwt_payload = base64url_encode(jsonw.dumps(payload))
jwt_msg = b".".join([jwt_header, jwt_payload])
sig = private_key.signature(jwt_msg)
jwt_sig = base64url_encode(sig)
@ -582,7 +582,7 @@ class Authorization:
) -> Dict[str, Any]:
message, sig = token.rsplit('.', maxsplit=1)
enc_header, enc_payload = message.split('.')
header: Dict[str, Any] = json.loads(base64url_decode(enc_header))
header: Dict[str, Any] = jsonw.loads(base64url_decode(enc_header))
sig_bytes = base64url_decode(sig)
# verify header
@ -597,7 +597,7 @@ class Authorization:
public_key.verify(sig_bytes + message.encode())
# validate claims
payload: Dict[str, Any] = json.loads(base64url_decode(enc_payload))
payload: Dict[str, Any] = jsonw.loads(base64url_decode(enc_payload))
if payload['token_type'] != token_type:
raise self.server.error(
f"JWT Token type mismatch: Expected {token_type}, "

View File

@ -6,7 +6,6 @@
from __future__ import annotations
import pathlib
import json
import struct
import operator
import logging
@ -15,6 +14,7 @@ from functools import reduce
from threading import Lock as ThreadLock
import lmdb
from ..utils import Sentinel, ServerError
from ..utils import json_wrapper as jsonw
# Annotation imports
from typing import (
@ -47,8 +47,8 @@ RECORD_ENCODE_FUNCS = {
float: lambda x: b"d" + struct.pack("d", x),
bool: lambda x: b"?" + struct.pack("?", x),
str: lambda x: b"s" + x.encode(),
list: lambda x: json.dumps(x).encode(),
dict: lambda x: json.dumps(x).encode(),
list: lambda x: jsonw.dumps(x),
dict: lambda x: jsonw.dumps(x),
}
RECORD_DECODE_FUNCS = {
@ -56,8 +56,8 @@ RECORD_DECODE_FUNCS = {
ord("d"): lambda x: struct.unpack("d", x[1:])[0],
ord("?"): lambda x: struct.unpack("?", x[1:])[0],
ord("s"): lambda x: bytes(x[1:]).decode(),
ord("["): lambda x: json.loads(bytes(x)),
ord("{"): lambda x: json.loads(bytes(x)),
ord("["): lambda x: jsonw.loads(bytes(x)),
ord("{"): lambda x: jsonw.loads(bytes(x)),
}
def getitem_with_default(item: Dict, field: Any) -> Any:

View File

@ -7,7 +7,6 @@ from __future__ import annotations
import asyncio
import pathlib
import logging
import json
from ..common import BaseRemoteConnection
from ..utils import get_unix_peer_credentials
@ -182,13 +181,11 @@ class UnixSocketClient(BaseRemoteConnection):
logging.debug("Unix Socket Disconnection From _read_messages()")
await self._on_close(reason="Read Exit")
async def write_to_socket(
self, message: Union[str, Dict[str, Any]]
) -> None:
if isinstance(message, dict):
data = json.dumps(message).encode() + b"\x03"
else:
async def write_to_socket(self, message: Union[bytes, str]) -> None:
if isinstance(message, str):
data = message.encode() + b"\x03"
else:
data = message + b"\x03"
try:
self.writer.write(data)
await self.writer.drain()

View File

@ -10,7 +10,6 @@ import sys
import pathlib
import shutil
import logging
import json
import tempfile
import asyncio
import zipfile
@ -20,6 +19,7 @@ from copy import deepcopy
from inotify_simple import INotify
from inotify_simple import flags as iFlags
from ...utils import source_info
from ...utils import json_wrapper as jsonw
# Annotation imports
from typing import (
@ -2496,7 +2496,7 @@ class MetadataStorage:
if not await scmd.run(timeout=timeout):
raise self.server.error("Extract Metadata returned with error")
try:
decoded_resp: Dict[str, Any] = json.loads(result.strip())
decoded_resp: Dict[str, Any] = jsonw.loads(result.strip())
except Exception:
logging.debug(f"Invalid metadata response:\n{result}")
raise

View File

@ -6,7 +6,6 @@
from __future__ import annotations
import re
import json
import time
import asyncio
import pathlib
@ -14,6 +13,7 @@ import tempfile
import logging
import copy
from ..utils import ServerError
from ..utils import json_wrapper as jsonw
from tornado.escape import url_unescape
from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPError
from tornado.httputil import HTTPHeaders
@ -72,7 +72,7 @@ class HttpClient:
self,
method: str,
url: str,
body: Optional[Union[str, List[Any], Dict[str, Any]]] = None,
body: Optional[Union[bytes, str, List[Any], Dict[str, Any]]] = None,
headers: Optional[Dict[str, Any]] = None,
connect_timeout: float = 5.,
request_timeout: float = 10.,
@ -87,7 +87,7 @@ class HttpClient:
# prepare the body if required
req_headers: Dict[str, Any] = {}
if isinstance(body, (list, dict)):
body = json.dumps(body)
body = jsonw.dumps(body)
req_headers["Content-Type"] = "application/json"
cached: Optional[HttpResponse] = None
if enable_cache:
@ -341,8 +341,8 @@ class HttpResponse:
self._last_modified: Optional[str] = response_headers.get(
"last-modified", None)
def json(self, **kwargs) -> Union[List[Any], Dict[str, Any]]:
return json.loads(self._result, **kwargs)
def json(self) -> Union[List[Any], Dict[str, Any]]:
return jsonw.loads(self._result)
def is_cachable(self) -> bool:
return self._last_modified is not None or self._etag is not None

View File

@ -8,7 +8,6 @@ from __future__ import annotations
import sys
import os
import re
import json
import pathlib
import logging
import asyncio
@ -23,6 +22,7 @@ import getpass
import configparser
from ..confighelper import FileSourceWrapper
from ..utils import source_info
from ..utils import json_wrapper as jsonw
# Annotation imports
from typing import (
@ -104,7 +104,7 @@ class Machine:
self._public_ip = ""
self.system_info: Dict[str, Any] = {
'python': {
"version": sys.version_info,
"version": tuple(sys.version_info),
"version_string": sys.version.replace("\n", " ")
},
'cpu_info': self._get_cpu_info(),
@ -625,7 +625,7 @@ class Machine:
try:
# get network interfaces
resp = await self.addr_cmd.run_with_response(log_complete=False)
decoded: List[Dict[str, Any]] = json.loads(resp)
decoded: List[Dict[str, Any]] = jsonw.loads(resp)
for interface in decoded:
if interface['operstate'] != "UP":
continue

View File

@ -8,12 +8,12 @@ from __future__ import annotations
import socket
import asyncio
import logging
import json
import pathlib
import ssl
from collections import deque
import paho.mqtt.client as paho_mqtt
from ..common import Subscribable, WebRequest, APITransport, JsonRPC
from ..utils import json_wrapper as jsonw
# Annotation imports
from typing import (
@ -354,7 +354,7 @@ class MQTTClient(APITransport, Subscribable):
if self.user_name is not None:
self.client.username_pw_set(self.user_name, self.password)
self.client.will_set(self.moonraker_status_topic,
payload=json.dumps({'server': 'offline'}),
payload=jsonw.dumps({'server': 'offline'}),
qos=self.qos, retain=True)
self.client.connect_async(self.address, self.port)
self.connect_task = self.event_loop.create_task(
@ -558,8 +558,8 @@ class MQTTClient(APITransport, Subscribable):
pub_fut: asyncio.Future = asyncio.Future()
if isinstance(payload, (dict, list)):
try:
payload = json.dumps(payload)
except json.JSONDecodeError:
payload = jsonw.dumps(payload)
except jsonw.JSONDecodeError:
raise self.server.error(
"Dict or List is not json encodable") from None
elif isinstance(payload, bool):
@ -661,8 +661,8 @@ class MQTTClient(APITransport, Subscribable):
if hdl is not None:
self.unsubscribe(hdl)
try:
payload = json.loads(ret)
except json.JSONDecodeError:
payload = jsonw.loads(ret)
except jsonw.JSONDecodeError:
payload = ret.decode()
return {
'topic': topic,

View File

@ -8,12 +8,12 @@ from __future__ import annotations
import serial
import os
import time
import json
import errno
import logging
import asyncio
from collections import deque
from ..utils import ServerError
from ..utils import json_wrapper as jsonw
# Annotation imports
from typing import (
@ -536,8 +536,8 @@ class PanelDue:
return
def write_response(self, response: Dict[str, Any]) -> None:
byte_resp = json.dumps(response) + "\r\n"
self.ser_conn.send(byte_resp.encode())
byte_resp = jsonw.dumps(response) + b"\r\n"
self.ser_conn.send(byte_resp)
def _get_printer_status(self) -> str:
# PanelDue States applicable to Klipper:

View File

@ -6,12 +6,12 @@
from __future__ import annotations
import logging
import json
import struct
import socket
import asyncio
import time
from urllib.parse import quote, urlencode
from ..utils import json_wrapper as jsonw
# Annotation imports
from typing import (
@ -845,14 +845,14 @@ class TPLinkSmartPlug(PowerDevice):
finally:
writer.close()
await writer.wait_closed()
return json.loads(self._decrypt(data))
return jsonw.loads(self._decrypt(data))
def _encrypt(self, outdata: Dict[str, Any]) -> bytes:
data = json.dumps(outdata)
data = jsonw.dumps(outdata)
key = self.START_KEY
res = struct.pack(">I", len(data))
for c in data:
val = key ^ ord(c)
val = key ^ c
key = val
res += bytes([val])
return res

View File

@ -7,7 +7,7 @@ from __future__ import annotations
import pathlib
import logging
import configparser
import json
from ..utils import json_wrapper as jsonw
from typing import (
TYPE_CHECKING,
Dict,
@ -73,8 +73,8 @@ class Secrets:
def _parse_json(self, data: str) -> Optional[Dict[str, Any]]:
try:
return json.loads(data)
except json.JSONDecodeError:
return jsonw.loads(data)
except jsonw.JSONDecodeError:
return None
def get_type(self) -> str:

View File

@ -7,7 +7,6 @@
from __future__ import annotations
import os
import asyncio
import json
import logging
import time
import pathlib
@ -19,6 +18,7 @@ import tempfile
from queue import SimpleQueue
from ..loghelper import LocalQueueHandler
from ..common import Subscribable, WebRequest
from ..utils import json_wrapper as jsonw
from typing import (
TYPE_CHECKING,
@ -261,8 +261,8 @@ class SimplyPrint(Subscribable):
def _process_message(self, msg: str) -> None:
self._logger.info(f"received: {msg}")
try:
packet: Dict[str, Any] = json.loads(msg)
except json.JSONDecodeError:
packet: Dict[str, Any] = jsonw.loads(msg)
except jsonw.JSONDecodeError:
logging.debug(f"Invalid message, not JSON: {msg}")
return
event: str = packet.get("type", "")
@ -1085,7 +1085,7 @@ class SimplyPrint(Subscribable):
async def _send_wrapper(self, packet: Dict[str, Any]) -> bool:
try:
assert self.ws is not None
await self.ws.write_message(json.dumps(packet))
await self.ws.write_message(jsonw.dumps(packet))
except Exception:
return False
else:

View File

@ -7,7 +7,7 @@ from __future__ import annotations
import logging
import asyncio
import jinja2
import json
from ..utils import json_wrapper as jsonw
# Annotation imports
from typing import (
@ -31,11 +31,11 @@ class TemplateFactory:
)
self.ui_env = jinja2.Environment(enable_async=True)
self.jenv.add_extension("jinja2.ext.do")
self.jenv.filters['fromjson'] = json.loads
self.jenv.filters['fromjson'] = jsonw.loads
self.async_env.add_extension("jinja2.ext.do")
self.async_env.filters['fromjson'] = json.loads
self.async_env.filters['fromjson'] = jsonw.loads
self.ui_env.add_extension("jinja2.ext.do")
self.ui_env.filters['fromjson'] = json.loads
self.ui_env.filters['fromjson'] = jsonw.loads
self.add_environment_global('raise_error', self._raise_error)
self.add_environment_global('secrets', secrets)

View File

@ -11,11 +11,11 @@ import shutil
import hashlib
import logging
import re
import json
import distro
import asyncio
from .common import AppType, Channel
from .base_deploy import BaseDeploy
from ...utils import json_wrapper as jsonw
# Annotation imports
from typing import (
@ -278,7 +278,7 @@ class AppDeploy(BaseDeploy):
deps_json = self.system_deps_json
try:
ret = await eventloop.run_in_thread(deps_json.read_bytes)
dep_info: Dict[str, List[str]] = json.loads(ret)
dep_info: Dict[str, List[str]] = jsonw.loads(ret)
except asyncio.CancelledError:
raise
except Exception:

View File

@ -10,8 +10,8 @@ import pathlib
import logging
import shutil
import zipfile
import json
from ...utils import source_info
from ...utils import json_wrapper as jsonw
from .common import AppType, Channel
from .base_deploy import BaseDeploy
@ -94,7 +94,7 @@ class WebClientDeploy(BaseDeploy):
if rinfo.is_file():
try:
data = await eventloop.run_in_thread(rinfo.read_text)
uinfo: Dict[str, str] = json.loads(data)
uinfo: Dict[str, str] = jsonw.loads(data)
project_name = uinfo["project_name"]
owner = uinfo["project_owner"]
self.version = uinfo["version"]
@ -134,7 +134,7 @@ class WebClientDeploy(BaseDeploy):
if manifest.is_file():
try:
mtext = await eventloop.run_in_thread(manifest.read_text)
mdata: Dict[str, Any] = json.loads(mtext)
mdata: Dict[str, Any] = jsonw.loads(mtext)
proj_name: str = mdata["name"].lower()
except Exception:
self.log_exc(f"Failed to load json from {manifest}")

View File

@ -7,7 +7,6 @@
from __future__ import annotations
import os
import pathlib
import json
import shutil
import re
import time
@ -15,6 +14,7 @@ import zipfile
from .app_deploy import AppDeploy
from .common import Channel
from ...utils import verify_source
from ...utils import json_wrapper as jsonw
# Annotation imports
from typing import (
@ -103,7 +103,7 @@ class ZipDeploy(AppDeploy):
try:
event_loop = self.server.get_event_loop()
info_bytes = await event_loop.run_in_thread(info_file.read_text)
info: Dict[str, Any] = json.loads(info_bytes)
info: Dict[str, Any] = jsonw.loads(info_bytes)
except Exception:
self.log_exc(f"Unable to parse info file {file_name}")
info = {}
@ -225,7 +225,7 @@ class ZipDeploy(AppDeploy):
info_url, content_type, size = asset_info['RELEASE_INFO']
client = self.cmd_helper.get_http_client()
rinfo_bytes = await client.get_file(info_url, content_type)
github_rinfo: Dict[str, Any] = json.loads(rinfo_bytes)
github_rinfo: Dict[str, Any] = jsonw.loads(rinfo_bytes)
if github_rinfo.get(self.name, {}) != release_info:
self._add_error(
"Local release info does not match the remote")
@ -243,7 +243,7 @@ class ZipDeploy(AppDeploy):
asset_url, content_type, size = asset_info['RELEASE_INFO']
client = self.cmd_helper.get_http_client()
rinfo_bytes = await client.get_file(asset_url, content_type)
update_release_info: Dict[str, Any] = json.loads(rinfo_bytes)
update_release_info: Dict[str, Any] = jsonw.loads(rinfo_bytes)
update_info = update_release_info.get(self.name, {})
self.lastest_hash = update_info.get('commit_hash', "?")
self.latest_checksum = update_info.get('source_checksum', "?")
@ -260,7 +260,7 @@ class ZipDeploy(AppDeploy):
asset_url, content_type, size = asset_info['COMMIT_LOG']
client = self.cmd_helper.get_http_client()
commit_bytes = await client.get_file(asset_url, content_type)
commit_info: Dict[str, Any] = json.loads(commit_bytes)
commit_info: Dict[str, Any] = jsonw.loads(commit_bytes)
self.commit_log = commit_info.get(self.name, [])
if zip_file_name in asset_info:
self.release_download_info = asset_info[zip_file_name]

View File

@ -11,11 +11,11 @@
from __future__ import annotations
from enum import Enum
import logging
import json
import asyncio
import serial_asyncio
from tornado.httpclient import AsyncHTTPClient
from tornado.httpclient import HTTPRequest
from ..utils import json_wrapper as jsonw
# Annotation imports
from typing import (
@ -293,7 +293,7 @@ class StripHttp(Strip):
request = HTTPRequest(url=self.url,
method="POST",
headers=headers,
body=json.dumps(state),
body=jsonw.dumps(state),
connect_timeout=self.timeout,
request_timeout=self.timeout)
for i in range(retries):
@ -329,7 +329,7 @@ class StripSerial(Strip):
logging.debug(f"WLED: serial:{self.serialport} json:{state}")
self.ser.write(json.dumps(state).encode())
self.ser.write(jsonw.dumps(state))
def close(self: StripSerial):
if hasattr(self, 'ser'):

View File

@ -9,11 +9,11 @@ from __future__ import annotations
import os
import time
import logging
import json
import getpass
import asyncio
import pathlib
from .utils import ServerError, get_unix_peer_credentials
from .utils import json_wrapper as jsonw
# Annotation imports
from typing import (
@ -180,7 +180,7 @@ class KlippyConnection:
continue
errors_remaining = 10
try:
decoded_cmd = json.loads(data[:-1])
decoded_cmd = jsonw.loads(data[:-1])
self._process_command(decoded_cmd)
except Exception:
logging.exception(
@ -193,7 +193,7 @@ class KlippyConnection:
if self.writer is None or self.closing:
request.set_exception(ServerError("Klippy Host not connected", 503))
return
data = json.dumps(request.to_dict()).encode() + b"\x03"
data = jsonw.dumps(request.to_dict()) + b"\x03"
try:
self.writer.write(data)
await self.writer.drain()

View File

@ -23,7 +23,7 @@ from . import confighelper
from .eventloop import EventLoop
from .app import MoonrakerApp
from .klippy_connection import KlippyConnection
from .utils import ServerError, Sentinel, get_software_info
from .utils import ServerError, Sentinel, get_software_info, json_wrapper
from .loghelper import LogManager
# Annotation imports
@ -585,6 +585,7 @@ def main(from_package: bool = True) -> None:
else:
app_args["log_file"] = str(data_path.joinpath("logs/moonraker.log"))
app_args["python_version"] = sys.version.replace("\n", " ")
app_args["msgspec_enabled"] = json_wrapper.MSGSPEC_ENABLED
log_manager = LogManager(app_args, startup_warnings)
# Start asyncio event loop and server

View File

@ -14,13 +14,13 @@ import sys
import subprocess
import asyncio
import hashlib
import json
import shlex
import re
import struct
import socket
import enum
from . import source_info
from . import json_wrapper
# Annotation imports
from typing import (
@ -190,7 +190,7 @@ def verify_source(
if not rfile.exists():
return None
try:
rinfo = json.loads(rfile.read_text())
rinfo = json_wrapper.loads(rfile.read_text())
except Exception:
return None
orig_chksum = rinfo['source_checksum']

View File

@ -0,0 +1,33 @@
# Wrapper for msgspec with stdlib fallback
#
# Copyright (C) 2023 Eric Callahan <arksine.code@gmail.com>
#
# This file may be distributed under the terms of the GNU GPLv3 license
from __future__ import annotations
import os
import contextlib
from typing import Any, Union, TYPE_CHECKING
if TYPE_CHECKING:
def dumps(obj: Any) -> bytes: ... # type: ignore
def loads(data: Union[str, bytes, bytearray]) -> Any: ...
MSGSPEC_ENABLED = False
_msgspc_var = os.getenv("MOONRAKER_ENABLE_MSGSPEC", "y").lower()
if _msgspc_var in ["y", "yes", "true"]:
with contextlib.suppress(ImportError):
import msgspec
from msgspec import DecodeError as JSONDecodeError
encoder = msgspec.json.Encoder()
decoder = msgspec.json.Decoder()
dumps = encoder.encode
loads = decoder.decode
MSGSPEC_ENABLED = True
if not MSGSPEC_ENABLED:
import json
from json import JSONDecodeError # type: ignore
loads = json.loads # type: ignore
def dumps(obj) -> bytes: # type: ignore
return json.dumps(obj).encode("utf-8")

View File

@ -327,9 +327,7 @@ class WebSocket(WebSocketHandler, BaseRemoteConnection):
extensions.remove_agent(self)
self.wsm.remove_client(self)
async def write_to_socket(
self, message: Union[str, Dict[str, Any]]
) -> None:
async def write_to_socket(self, message: Union[bytes, str]) -> None:
try:
await self.write_message(message)
except WebSocketClosedError: