Skip to content

Commit

Permalink
Merge pull request #110 from jschlyter/locklesser
Browse files Browse the repository at this point in the history
Locklesser
  • Loading branch information
jschlyter authored Feb 28, 2022
2 parents 366d889 + e997039 commit 3318c89
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 80 deletions.
118 changes: 51 additions & 67 deletions src/cryptojwt/key_bundle.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
"""Implementation of a Key Bundle."""

import copy
import json
import logging
import os
import threading
import time
from datetime import datetime
from functools import cmp_to_key
from typing import List
from typing import Optional

import requests
from readerwriterlock import rwlock

from cryptojwt.jwk.ec import NIST2SEC
from cryptojwt.jwk.hmac import new_sym_key
Expand Down Expand Up @@ -152,14 +153,6 @@ def ec_init(spec):
return _kb


def keys_reader(func):
def wrapper(self, *args, **kwargs):
with self._lock_reader:
return func(self, *args, **kwargs)

return wrapper


def keys_writer(func):
def wrapper(self, *args, **kwargs):
with self._lock_writer:
Expand Down Expand Up @@ -245,9 +238,7 @@ def __init__(
self.source = None
self.time_out = 0

self._lock = rwlock.RWLockFairD()
self._lock_reader = self._lock.gen_rlock()
self._lock_writer = self._lock.gen_wlock()
self._lock_writer = threading.Lock()

if httpc:
self.httpc = httpc
Expand All @@ -260,11 +251,11 @@ def __init__(
self.source = None
if isinstance(keys, dict):
if "keys" in keys:
self._do_keys(keys["keys"])
self._add_jwk_dicts(keys["keys"])
else:
self._do_keys([keys])
self._add_jwk_dicts([keys])
else:
self._do_keys(keys)
self._add_jwk_dicts(keys)
else:
self._set_source(source, fileformat)
if self.local:
Expand Down Expand Up @@ -305,18 +296,34 @@ def _local_update_required(self) -> bool:
self.last_local = stat.st_mtime
return True

@keys_writer
def do_keys(self, keys):
return self._do_keys(keys)
"""Compatibility function for add_jwk_dicts()"""
self.add_jwk_dicts(keys)

def _do_keys(self, keys):
@keys_writer
def add_jwk_dicts(self, keys):
"""
Go from JWK description to binary keys
Add JWK dictionaries
:param keys:
:param keys: List of JWK dictionaries
:return:
"""
_new_key = []
self._add_jwk_dicts(keys)

def _add_jwk_dicts(self, keys):
_new_keys = self.jwk_dicts_as_keys(keys)
if _new_keys:
self._keys.extend(_new_keys)
self.last_updated = time.time()

def jwk_dicts_as_keys(self, keys):
"""
Return JWK dictionaries as list of JWK objects
:param keys: List of JWK dictionaries
:return: List of JWK objects
"""
_new_keys = []

for inst in keys:
if inst["kty"].lower() in K2C:
Expand Down Expand Up @@ -360,16 +367,13 @@ def _do_keys(self, keys):
if _key not in self._keys:
if not _key.kid:
_key.add_kid()
_new_key.append(_key)
_new_keys.append(_key)
_error = ""

if _error:
LOGGER.warning("While loading keys, %s", _error)

if _new_key:
self._keys.extend(_new_key)

self.last_updated = time.time()
return _new_keys

def _do_local_jwk(self, filename):
"""
Expand All @@ -385,9 +389,9 @@ def _do_local_jwk(self, filename):
with open(filename) as input_file:
_info = json.load(input_file)
if "keys" in _info:
self._do_keys(_info["keys"])
self._add_jwk_dicts(_info["keys"])
else:
self._do_keys([_info])
self._add_jwk_dicts([_info])
self.last_local = time.time()
self.time_out = self.last_local + self.cache_time
return True
Expand Down Expand Up @@ -423,12 +427,12 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
if kid:
key_args["kid"] = kid

self._do_keys([key_args])
self._add_jwk_dicts([key_args])
self.last_local = time.time()
self.time_out = self.last_local + self.cache_time
return True

def do_remote(self):
def _do_remote(self):
"""
Load a JWKS from a webpage.
Expand Down Expand Up @@ -458,6 +462,7 @@ def do_remote(self):
LOGGER.error(err)
raise UpdateFailed(REMOTE_FAILED.format(self.source, str(err)))

new_keys = None
load_successful = _http_resp.status_code == 200
not_modified = _http_resp.status_code == 304

Expand All @@ -470,7 +475,7 @@ def do_remote(self):

LOGGER.debug("Loaded JWKS: %s from %s", _http_resp.text, self.source)
try:
self._do_keys(self.imp_jwks["keys"])
new_keys = self.jwk_dicts_as_keys(self.imp_jwks["keys"])
except KeyError:
LOGGER.error("No 'keys' keyword in JWKS")
self.ignore_errors_until = time.time() + self.ignore_errors_period
Expand All @@ -491,6 +496,8 @@ def do_remote(self):
self.ignore_errors_until = time.time() + self.ignore_errors_period
raise UpdateFailed(REMOTE_FAILED.format(self.source, _http_resp.status_code))

if new_keys is not None:
self._keys = new_keys
self.last_updated = time.time()
self.ignore_errors_until = None
return load_successful
Expand Down Expand Up @@ -547,7 +554,7 @@ def update(self):
elif self.fileformat == "der":
updated = self._do_local_der(self.source, self.keytype, self.keyusage)
elif self.remote:
updated = self.do_remote()
updated = self._do_remote()
except Exception as err:
LOGGER.error("Key bundle update failed: %s", err)
self._keys = _old_keys # restore
Expand Down Expand Up @@ -575,12 +582,11 @@ def get(self, typ="", only_active=True):
"""
self._uptodate()

with self._lock_reader:
if typ:
_typs = [typ.lower(), typ.upper()]
_keys = [k for k in self._keys if k.kty in _typs]
else:
_keys = copy.copy(self._keys)
if typ:
_typs = [typ.lower(), typ.upper()]
_keys = [k for k in self._keys[:] if k.kty in _typs]
else:
_keys = self._keys[:]

if only_active:
return [k for k in _keys if not k.inactive_since]
Expand All @@ -595,8 +601,7 @@ def keys(self, update: bool = True):
"""
if update:
self._uptodate()
with self._lock_reader:
return copy.copy(self._keys)
return self._keys[:]

def active_keys(self):
"""Return the set of active keys."""
Expand Down Expand Up @@ -668,7 +673,6 @@ def remove(self, key):
except ValueError:
pass

@keys_reader
def __len__(self):
"""
The number of keys.
Expand All @@ -690,18 +694,12 @@ def get_key_with_kid(self, kid):
:return: The key or None
"""
self._uptodate()
with self._lock_reader:
return self._get_key_with_kid(kid)
return self._get_key_with_kid(kid)

def _get_key_with_kid(self, kid):
for key in self._keys:
if key.kid == kid:
return key

for key in self._keys:
if key.kid == kid:
return key

return None

def kids(self):
Expand All @@ -723,9 +721,7 @@ def mark_as_inactive(self, kid):
"""
k = self._get_key_with_kid(kid)
if k:
self._keys.remove(k)
k.inactive_since = time.time()
self._keys.append(k)
return True
else:
return False
Expand Down Expand Up @@ -753,30 +749,18 @@ def remove_outdated(self, after, when=0):
before it should be removed.
:param when: To make it easier to test
"""
if when:
now = when
else:
now = time.time()
now = when or time.time()

if not isinstance(after, float):
after = float(after)

_kl = []
changed = False
for k in self._keys:
if k.inactive_since and k.inactive_since + after < now:
changed = True
continue

_kl.append(k)

self._keys = _kl
return changed
self._keys = [
k for k in self._keys if not k.inactive_since or k.inactive_since + after > now
]

def __contains__(self, key):
return key in self.keys()

@keys_reader
def copy(self):
"""
Make deep copy of this KeyBundle
Expand Down Expand Up @@ -846,7 +830,7 @@ def load(self, spec):
"""
_keys = spec.get("keys", [])
if _keys:
self._do_keys(_keys)
self._add_jwk_dicts(_keys)

for attr, default in self.params.items():
val = spec.get(attr)
Expand Down
Loading

0 comments on commit 3318c89

Please sign in to comment.