From 70cb4d9d6abd7d3b63bd0155ea358dc3deee3792 Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Fri, 25 Feb 2022 16:59:53 +0100 Subject: [PATCH 1/9] create shallow copies using slicing --- src/cryptojwt/key_bundle.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index bfd8150..2324c43 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -1,4 +1,5 @@ """Implementation of a Key Bundle.""" + import copy import json import logging @@ -580,7 +581,7 @@ def get(self, typ="", only_active=True): _typs = [typ.lower(), typ.upper()] _keys = [k for k in self._keys if k.kty in _typs] else: - _keys = copy.copy(self._keys) + _keys = self._keys[:] if only_active: return [k for k in _keys if not k.inactive_since] @@ -596,7 +597,7 @@ def keys(self, update: bool = True): if update: self._uptodate() with self._lock_reader: - return copy.copy(self._keys) + return self._keys[:] def active_keys(self): """Return the set of active keys.""" From 2ba55897ab8d287ec37a49fcdbb042f21205aff4 Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Fri, 25 Feb 2022 17:04:05 +0100 Subject: [PATCH 2/9] no need to remove and append an object stored by reference --- src/cryptojwt/key_bundle.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 2324c43..5840cb4 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -724,9 +724,7 @@ def mark_as_inactive(self, kid): """ k = self._get_key_with_kid(kid) if k: - self._keys.remove(k) k.inactive_since = time.time() - self._keys.append(k) return True else: return False From 874becd56824dad5113e2629d462e09245e686ff Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Fri, 25 Feb 2022 17:04:46 +0100 Subject: [PATCH 3/9] simplify --- src/cryptojwt/key_bundle.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 5840cb4..1c74c78 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -752,10 +752,7 @@ def remove_outdated(self, after, when=0): before it should be removed. :param when: To make it easier to test """ - if when: - now = when - else: - now = time.time() + now = when or time.time() if not isinstance(after, float): after = float(after) From b3d119153ad918c192bda0802361e49a651f11db Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Fri, 25 Feb 2022 17:47:23 +0100 Subject: [PATCH 4/9] optimize remove_outdated() --- src/cryptojwt/key_bundle.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 1c74c78..8df7ede 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -757,17 +757,9 @@ def remove_outdated(self, after, when=0): if not isinstance(after, float): after = float(after) - _kl = [] - changed = False - for k in self._keys: - if k.inactive_since and k.inactive_since + after < now: - changed = True - continue - - _kl.append(k) - - self._keys = _kl - return changed + self._keys = [ + k for k in self._keys if not k.inactive_since or k.inactive_since + after > now + ] def __contains__(self, key): return key in self.keys() From 73ee3751cff6f39755ffd55632aa0e5d78719c8c Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Fri, 25 Feb 2022 17:56:12 +0100 Subject: [PATCH 5/9] rename do_keys (and keep) to better reflect purpose --- src/cryptojwt/key_bundle.py | 49 ++++++++++++++++++++++++------------- tests/test_03_key_bundle.py | 6 ++--- 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 8df7ede..4cfa2b0 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -261,11 +261,11 @@ def __init__( self.source = None if isinstance(keys, dict): if "keys" in keys: - self._do_keys(keys["keys"]) + self._add_jwk_dicts(keys["keys"]) else: - self._do_keys([keys]) + self._add_jwk_dicts([keys]) else: - self._do_keys(keys) + self._add_jwk_dicts(keys) else: self._set_source(source, fileformat) if self.local: @@ -306,18 +306,34 @@ def _local_update_required(self) -> bool: self.last_local = stat.st_mtime return True - @keys_writer def do_keys(self, keys): - return self._do_keys(keys) + """Compatibility function for add_jwk_dicts()""" + self.add_jwk_dicts(keys) - def _do_keys(self, keys): + @keys_writer + def add_jwk_dicts(self, keys): """ - Go from JWK description to binary keys + Add JWK dictionaries - :param keys: + :param keys: List of JWK dictionaries :return: """ - _new_key = [] + self._add_jwk_dicts(keys) + + def _add_jwk_dicts(self, keys): + _new_keys = self._jwk_dicts_to_keys(keys) + if _new_keys: + self._keys.extend(_new_keys) + self.last_updated = time.time() + + def _jwk_dicts_to_keys(self, keys): + """ + Return JWK dictionaries as list of JWK objects + + :param keys: List of JWK dictionaries + :return: List of JWK objects + """ + _new_keys = [] for inst in keys: if inst["kty"].lower() in K2C: @@ -361,14 +377,13 @@ def _do_keys(self, keys): if _key not in self._keys: if not _key.kid: _key.add_kid() - _new_key.append(_key) + _new_keys.append(_key) _error = "" if _error: LOGGER.warning("While loading keys, %s", _error) - if _new_key: - self._keys.extend(_new_key) + return _new_keys self.last_updated = time.time() @@ -386,9 +401,9 @@ def _do_local_jwk(self, filename): with open(filename) as input_file: _info = json.load(input_file) if "keys" in _info: - self._do_keys(_info["keys"]) + self._add_jwk_dicts(_info["keys"]) else: - self._do_keys([_info]) + self._add_jwk_dicts([_info]) self.last_local = time.time() self.time_out = self.last_local + self.cache_time return True @@ -424,7 +439,7 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""): if kid: key_args["kid"] = kid - self._do_keys([key_args]) + self._add_jwk_dicts([key_args]) self.last_local = time.time() self.time_out = self.last_local + self.cache_time return True @@ -471,7 +486,7 @@ def do_remote(self): LOGGER.debug("Loaded JWKS: %s from %s", _http_resp.text, self.source) try: - self._do_keys(self.imp_jwks["keys"]) + self._add_jwk_dicts(self.imp_jwks["keys"]) except KeyError: LOGGER.error("No 'keys' keyword in JWKS") self.ignore_errors_until = time.time() + self.ignore_errors_period @@ -834,7 +849,7 @@ def load(self, spec): """ _keys = spec.get("keys", []) if _keys: - self._do_keys(_keys) + self._add_jwk_dicts(_keys) for attr, default in self.params.items(): val = spec.get(attr) diff --git a/tests/test_03_key_bundle.py b/tests/test_03_key_bundle.py index 11e6918..b01c378 100755 --- a/tests/test_03_key_bundle.py +++ b/tests/test_03_key_bundle.py @@ -410,7 +410,7 @@ def test_mark_as_inactive(): for k in kb.keys(): kb.mark_as_inactive(k.kid) desc = {"kty": "oct", "key": "highestsupersecret", "use": "enc"} - kb.do_keys([desc]) + kb.add_jwk_dicts([desc]) assert len(kb.keys()) == 2 assert len(kb.active_keys()) == 1 @@ -422,7 +422,7 @@ def test_copy(): for k in kb.keys(): kb.mark_as_inactive(k.kid) desc = {"kty": "oct", "key": "highestsupersecret", "use": "enc"} - kb.do_keys([desc]) + kb.add_jwk_dicts([desc]) kbc = kb.copy() assert len(kbc.keys()) == 2 @@ -891,7 +891,7 @@ def test_export_inactive(): for k in kb.keys(): kb.mark_as_inactive(k.kid) desc = {"kty": "oct", "key": "highestsupersecret", "use": "enc"} - kb.do_keys([desc]) + kb.add_jwk_dicts([desc]) res = kb.dump() assert set(res.keys()) == { "cache_time", From ae778107217609ce4109ae32b00fb2c722eebcec Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Fri, 25 Feb 2022 18:02:29 +0100 Subject: [PATCH 6/9] more renames, isolate updates --- src/cryptojwt/key_bundle.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 4cfa2b0..94a5735 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -321,12 +321,12 @@ def add_jwk_dicts(self, keys): self._add_jwk_dicts(keys) def _add_jwk_dicts(self, keys): - _new_keys = self._jwk_dicts_to_keys(keys) + _new_keys = self.jwk_dicts_as_keys(keys) if _new_keys: self._keys.extend(_new_keys) self.last_updated = time.time() - def _jwk_dicts_to_keys(self, keys): + def jwk_dicts_as_keys(self, keys): """ Return JWK dictionaries as list of JWK objects @@ -385,8 +385,6 @@ def _jwk_dicts_to_keys(self, keys): return _new_keys - self.last_updated = time.time() - def _do_local_jwk(self, filename): """ Load a JWKS from a local file @@ -474,6 +472,7 @@ def do_remote(self): LOGGER.error(err) raise UpdateFailed(REMOTE_FAILED.format(self.source, str(err))) + new_keys = None load_successful = _http_resp.status_code == 200 not_modified = _http_resp.status_code == 304 @@ -486,7 +485,7 @@ def do_remote(self): LOGGER.debug("Loaded JWKS: %s from %s", _http_resp.text, self.source) try: - self._add_jwk_dicts(self.imp_jwks["keys"]) + new_keys = self.jwk_dicts_as_keys(self.imp_jwks["keys"]) except KeyError: LOGGER.error("No 'keys' keyword in JWKS") self.ignore_errors_until = time.time() + self.ignore_errors_period @@ -507,6 +506,8 @@ def do_remote(self): self.ignore_errors_until = time.time() + self.ignore_errors_period raise UpdateFailed(REMOTE_FAILED.format(self.source, _http_resp.status_code)) + if new_keys is not None: + self._keys = new_keys self.last_updated = time.time() self.ignore_errors_until = None return load_successful From a77d427ed8db25020ccbab781e13d7e82635ff74 Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Fri, 25 Feb 2022 18:03:18 +0100 Subject: [PATCH 7/9] do_remote is private --- src/cryptojwt/key_bundle.py | 4 ++-- tests/test_03_key_bundle.py | 20 ++++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 94a5735..bad54a0 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -442,7 +442,7 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""): self.time_out = self.last_local + self.cache_time return True - def do_remote(self): + def _do_remote(self): """ Load a JWKS from a webpage. @@ -564,7 +564,7 @@ def update(self): elif self.fileformat == "der": updated = self._do_local_der(self.source, self.keytype, self.keyusage) elif self.remote: - updated = self.do_remote() + updated = self._do_remote() except Exception as err: LOGGER.error("Key bundle update failed: %s", err) self._keys = _old_keys # restore diff --git a/tests/test_03_key_bundle.py b/tests/test_03_key_bundle.py index b01c378..95f83c0 100755 --- a/tests/test_03_key_bundle.py +++ b/tests/test_03_key_bundle.py @@ -480,7 +480,7 @@ def test_httpc_params_1(): rsps.add(method=responses.GET, url=source, json=JWKS_DICT, status=200) httpc_params = {"timeout": (2, 2)} # connect, read timeouts in seconds kb = KeyBundle(source=source, httpc=requests.request, httpc_params=httpc_params) - assert kb.do_remote() + assert kb._do_remote() @pytest.mark.network @@ -926,7 +926,7 @@ def test_remote(): rsps.add(method="GET", url=source, json=JWKS_DICT, status=200) httpc_params = {"timeout": (2, 2)} # connect, read timeouts in seconds kb = KeyBundle(source=source, httpc=requests.request, httpc_params=httpc_params) - kb.do_remote() + kb._do_remote() exp = kb.dump() kb2 = KeyBundle().load(exp) @@ -954,13 +954,13 @@ def test_remote_not_modified(): with responses.RequestsMock() as rsps: rsps.add(method="GET", url=source, json=JWKS_DICT, status=200, headers=headers) - assert kb.do_remote() + assert kb._do_remote() assert kb.last_remote == headers.get("Last-Modified") timeout1 = kb.time_out with responses.RequestsMock() as rsps: rsps.add(method="GET", url=source, status=304, headers=headers) - assert not kb.do_remote() + assert not kb._do_remote() assert kb.last_remote == headers.get("Last-Modified") timeout2 = kb.time_out @@ -994,19 +994,19 @@ def test_ignore_errors_period(): httpc_params=httpc_params, ignore_errors_period=ignore_errors_period, ) - res = kb.do_remote() + res = kb._do_remote() assert res == True assert kb.ignore_errors_until is None # refetch, but fail by using a bad source kb.source = source_bad try: - res = kb.do_remote() + res = kb._do_remote() except UpdateFailed: pass # retry should fail silently as we're in holddown - res = kb.do_remote() + res = kb._do_remote() assert kb.ignore_errors_until is not None assert res == False @@ -1015,7 +1015,7 @@ def test_ignore_errors_period(): # try again kb.source = source_good - res = kb.do_remote() + res = kb._do_remote() assert res == True @@ -1037,7 +1037,7 @@ def test_exclude_attributes(): rsps.add(method="GET", url=source, json=JWKS_DICT, status=200) httpc_params = {"timeout": (2, 2)} # connect, read timeouts in seconds kb = KeyBundle(source=source, httpc=requests.request, httpc_params=httpc_params) - kb.do_remote() + kb._do_remote() exp = kb.dump(exclude_attributes=["cache_time", "ignore_invalid_keys"]) kb2 = KeyBundle(cache_time=600, ignore_invalid_keys=False).load(exp) @@ -1052,7 +1052,7 @@ def test_remote_dump_json(): rsps.add(method="GET", url=source, json=JWKS_DICT, status=200) httpc_params = {"timeout": (2, 2)} # connect, read timeouts in seconds kb = KeyBundle(source=source, httpc=requests.request, httpc_params=httpc_params) - kb.do_remote() + kb._do_remote() exp = kb.dump() assert json.dumps(exp) From cac6b2ffbe96edaeab9b91ee8e233a13e532a1d7 Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Fri, 25 Feb 2022 18:13:58 +0100 Subject: [PATCH 8/9] fix bad merge(?) --- src/cryptojwt/key_bundle.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index bad54a0..f2b0be7 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -714,11 +714,6 @@ def _get_key_with_kid(self, kid): for key in self._keys: if key.kid == kid: return key - - for key in self._keys: - if key.kid == kid: - return key - return None def kids(self): From e997039c06a21f0a8de19a87b8317a292143166f Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Fri, 25 Feb 2022 21:11:43 +0100 Subject: [PATCH 9/9] remove read locks since list access is atomic --- src/cryptojwt/key_bundle.py | 33 +++++++++------------------------ 1 file changed, 9 insertions(+), 24 deletions(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index f2b0be7..fee6c62 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -4,6 +4,7 @@ import json import logging import os +import threading import time from datetime import datetime from functools import cmp_to_key @@ -11,7 +12,6 @@ from typing import Optional import requests -from readerwriterlock import rwlock from cryptojwt.jwk.ec import NIST2SEC from cryptojwt.jwk.hmac import new_sym_key @@ -153,14 +153,6 @@ def ec_init(spec): return _kb -def keys_reader(func): - def wrapper(self, *args, **kwargs): - with self._lock_reader: - return func(self, *args, **kwargs) - - return wrapper - - def keys_writer(func): def wrapper(self, *args, **kwargs): with self._lock_writer: @@ -246,9 +238,7 @@ def __init__( self.source = None self.time_out = 0 - self._lock = rwlock.RWLockFairD() - self._lock_reader = self._lock.gen_rlock() - self._lock_writer = self._lock.gen_wlock() + self._lock_writer = threading.Lock() if httpc: self.httpc = httpc @@ -592,12 +582,11 @@ def get(self, typ="", only_active=True): """ self._uptodate() - with self._lock_reader: - if typ: - _typs = [typ.lower(), typ.upper()] - _keys = [k for k in self._keys if k.kty in _typs] - else: - _keys = self._keys[:] + if typ: + _typs = [typ.lower(), typ.upper()] + _keys = [k for k in self._keys[:] if k.kty in _typs] + else: + _keys = self._keys[:] if only_active: return [k for k in _keys if not k.inactive_since] @@ -612,8 +601,7 @@ def keys(self, update: bool = True): """ if update: self._uptodate() - with self._lock_reader: - return self._keys[:] + return self._keys[:] def active_keys(self): """Return the set of active keys.""" @@ -685,7 +673,6 @@ def remove(self, key): except ValueError: pass - @keys_reader def __len__(self): """ The number of keys. @@ -707,8 +694,7 @@ def get_key_with_kid(self, kid): :return: The key or None """ self._uptodate() - with self._lock_reader: - return self._get_key_with_kid(kid) + return self._get_key_with_kid(kid) def _get_key_with_kid(self, kid): for key in self._keys: @@ -775,7 +761,6 @@ def remove_outdated(self, after, when=0): def __contains__(self, key): return key in self.keys() - @keys_reader def copy(self): """ Make deep copy of this KeyBundle