authorization: Add wildcards to cors_domians option

Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Arksine 2020-11-15 17:13:21 -05:00
parent 2d2f8bfbcd
commit ac1d798a36
2 changed files with 17 additions and 15 deletions

View File

@ -8,6 +8,7 @@ import uuid
import os import os
import time import time
import ipaddress import ipaddress
import re
import logging import logging
import tornado import tornado
from tornado.ioloop import IOLoop, PeriodicCallback from tornado.ioloop import IOLoop, PeriodicCallback
@ -28,8 +29,8 @@ class Authorization:
# Get allowed cors domains # Get allowed cors domains
cors_cfg = config.get('cors_domains', "").strip() cors_cfg = config.get('cors_domains', "").strip()
self.cors_domains = [d.strip() for d in cors_cfg.split('\n') self.cors_domains = [d.strip().replace(".", "\\.").replace("*", ".*")
if d.strip()] for d in cors_cfg.split('\n')if d.strip()]
# Get Trusted Clients # Get Trusted Clients
self.trusted_ips = [] self.trusted_ips = []
@ -182,14 +183,18 @@ class Authorization:
return False return False
def check_cors(self, origin, request=None): def check_cors(self, origin, request=None):
if origin in self.cors_domains: if origin is None:
logging.debug(f"CORS Domain Allowed: {origin}")
self._set_cors_headers(origin, request)
elif "*" in self.cors_domains:
self._set_cors_headers("*", request)
else:
return False return False
for regex in self.cors_domains:
match = re.match(regex, origin)
if match is not None and match.group() == origin:
logging.debug(f"CORS Pattern Matched, origin: {origin} "
f" | pattern: {regex}")
self._set_cors_headers(origin, request)
return True return True
else:
logging.debug(f"No CORS match for origin: {origin}")
return False
def _set_cors_headers(self, origin, request): def _set_cors_headers(self, origin, request):
if request is None: if request is None:

View File

@ -315,12 +315,9 @@ class WebSocket(WebSocketHandler):
io_loop.spawn_callback(self.wsm.remove_websocket, self) io_loop.spawn_callback(self.wsm.remove_websocket, self)
def check_origin(self, origin): def check_origin(self, origin):
if self.auth.check_cors(origin): if not super(WebSocket, self).check_origin(origin):
# allow CORS return self.auth.check_cors(origin)
return True return True
else:
return super(WebSocket, self).check_origin(origin)
# Check Authorized User # Check Authorized User
def prepare(self): def prepare(self):
if not self.auth.check_authorized(self.request): if not self.auth.check_authorized(self.request):