Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Locklesser #110

Merged
merged 9 commits into from
Feb 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 51 additions & 67 deletions src/cryptojwt/key_bundle.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
"""Implementation of a Key Bundle."""

import copy
import json
import logging
import os
import threading
import time
from datetime import datetime
from functools import cmp_to_key
from typing import List
from typing import Optional

import requests
from readerwriterlock import rwlock

from cryptojwt.jwk.ec import NIST2SEC
from cryptojwt.jwk.hmac import new_sym_key
Expand Down Expand Up @@ -152,14 +153,6 @@ def ec_init(spec):
return _kb


def keys_reader(func):
def wrapper(self, *args, **kwargs):
with self._lock_reader:
return func(self, *args, **kwargs)

return wrapper


def keys_writer(func):
def wrapper(self, *args, **kwargs):
with self._lock_writer:
Expand Down Expand Up @@ -245,9 +238,7 @@ def __init__(
self.source = None
self.time_out = 0

self._lock = rwlock.RWLockFairD()
self._lock_reader = self._lock.gen_rlock()
self._lock_writer = self._lock.gen_wlock()
self._lock_writer = threading.Lock()

if httpc:
self.httpc = httpc
Expand All @@ -260,11 +251,11 @@ def __init__(
self.source = None
if isinstance(keys, dict):
if "keys" in keys:
self._do_keys(keys["keys"])
self._add_jwk_dicts(keys["keys"])
else:
self._do_keys([keys])
self._add_jwk_dicts([keys])
else:
self._do_keys(keys)
self._add_jwk_dicts(keys)
else:
self._set_source(source, fileformat)
if self.local:
Expand Down Expand Up @@ -305,18 +296,34 @@ def _local_update_required(self) -> bool:
self.last_local = stat.st_mtime
return True

@keys_writer
def do_keys(self, keys):
return self._do_keys(keys)
"""Compatibility function for add_jwk_dicts()"""
self.add_jwk_dicts(keys)

def _do_keys(self, keys):
@keys_writer
def add_jwk_dicts(self, keys):
"""
Go from JWK description to binary keys
Add JWK dictionaries

:param keys:
:param keys: List of JWK dictionaries
:return:
"""
_new_key = []
self._add_jwk_dicts(keys)

def _add_jwk_dicts(self, keys):
_new_keys = self.jwk_dicts_as_keys(keys)
if _new_keys:
self._keys.extend(_new_keys)
self.last_updated = time.time()

def jwk_dicts_as_keys(self, keys):
"""
Return JWK dictionaries as list of JWK objects

:param keys: List of JWK dictionaries
:return: List of JWK objects
"""
_new_keys = []

for inst in keys:
if inst["kty"].lower() in K2C:
Expand Down Expand Up @@ -360,16 +367,13 @@ def _do_keys(self, keys):
if _key not in self._keys:
if not _key.kid:
_key.add_kid()
_new_key.append(_key)
_new_keys.append(_key)
_error = ""

if _error:
LOGGER.warning("While loading keys, %s", _error)

if _new_key:
self._keys.extend(_new_key)

self.last_updated = time.time()
return _new_keys

def _do_local_jwk(self, filename):
"""
Expand All @@ -385,9 +389,9 @@ def _do_local_jwk(self, filename):
with open(filename) as input_file:
_info = json.load(input_file)
if "keys" in _info:
self._do_keys(_info["keys"])
self._add_jwk_dicts(_info["keys"])
else:
self._do_keys([_info])
self._add_jwk_dicts([_info])
self.last_local = time.time()
self.time_out = self.last_local + self.cache_time
return True
Expand Down Expand Up @@ -423,12 +427,12 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
if kid:
key_args["kid"] = kid

self._do_keys([key_args])
self._add_jwk_dicts([key_args])
self.last_local = time.time()
self.time_out = self.last_local + self.cache_time
return True

def do_remote(self):
def _do_remote(self):
"""
Load a JWKS from a webpage.

Expand Down Expand Up @@ -458,6 +462,7 @@ def do_remote(self):
LOGGER.error(err)
raise UpdateFailed(REMOTE_FAILED.format(self.source, str(err)))

new_keys = None
load_successful = _http_resp.status_code == 200
not_modified = _http_resp.status_code == 304

Expand All @@ -470,7 +475,7 @@ def do_remote(self):

LOGGER.debug("Loaded JWKS: %s from %s", _http_resp.text, self.source)
try:
self._do_keys(self.imp_jwks["keys"])
new_keys = self.jwk_dicts_as_keys(self.imp_jwks["keys"])
except KeyError:
LOGGER.error("No 'keys' keyword in JWKS")
self.ignore_errors_until = time.time() + self.ignore_errors_period
Expand All @@ -491,6 +496,8 @@ def do_remote(self):
self.ignore_errors_until = time.time() + self.ignore_errors_period
raise UpdateFailed(REMOTE_FAILED.format(self.source, _http_resp.status_code))

if new_keys is not None:
self._keys = new_keys
self.last_updated = time.time()
self.ignore_errors_until = None
return load_successful
Expand Down Expand Up @@ -547,7 +554,7 @@ def update(self):
elif self.fileformat == "der":
updated = self._do_local_der(self.source, self.keytype, self.keyusage)
elif self.remote:
updated = self.do_remote()
updated = self._do_remote()
except Exception as err:
LOGGER.error("Key bundle update failed: %s", err)
self._keys = _old_keys # restore
Expand Down Expand Up @@ -575,12 +582,11 @@ def get(self, typ="", only_active=True):
"""
self._uptodate()

with self._lock_reader:
if typ:
_typs = [typ.lower(), typ.upper()]
_keys = [k for k in self._keys if k.kty in _typs]
else:
_keys = copy.copy(self._keys)
if typ:
_typs = [typ.lower(), typ.upper()]
_keys = [k for k in self._keys[:] if k.kty in _typs]
else:
_keys = self._keys[:]

if only_active:
return [k for k in _keys if not k.inactive_since]
Expand All @@ -595,8 +601,7 @@ def keys(self, update: bool = True):
"""
if update:
self._uptodate()
with self._lock_reader:
return copy.copy(self._keys)
return self._keys[:]

def active_keys(self):
"""Return the set of active keys."""
Expand Down Expand Up @@ -668,7 +673,6 @@ def remove(self, key):
except ValueError:
pass

@keys_reader
def __len__(self):
"""
The number of keys.
Expand All @@ -690,18 +694,12 @@ def get_key_with_kid(self, kid):
:return: The key or None
"""
self._uptodate()
with self._lock_reader:
return self._get_key_with_kid(kid)
return self._get_key_with_kid(kid)

def _get_key_with_kid(self, kid):
for key in self._keys:
if key.kid == kid:
return key

for key in self._keys:
if key.kid == kid:
return key

return None

def kids(self):
Expand All @@ -723,9 +721,7 @@ def mark_as_inactive(self, kid):
"""
k = self._get_key_with_kid(kid)
if k:
self._keys.remove(k)
k.inactive_since = time.time()
self._keys.append(k)
return True
else:
return False
Expand Down Expand Up @@ -753,30 +749,18 @@ def remove_outdated(self, after, when=0):
before it should be removed.
:param when: To make it easier to test
"""
if when:
now = when
else:
now = time.time()
now = when or time.time()

if not isinstance(after, float):
after = float(after)

_kl = []
changed = False
for k in self._keys:
if k.inactive_since and k.inactive_since + after < now:
changed = True
continue

_kl.append(k)

self._keys = _kl
return changed
self._keys = [
k for k in self._keys if not k.inactive_since or k.inactive_since + after > now
]

def __contains__(self, key):
return key in self.keys()

@keys_reader
def copy(self):
"""
Make deep copy of this KeyBundle
Expand Down Expand Up @@ -846,7 +830,7 @@ def load(self, spec):
"""
_keys = spec.get("keys", [])
if _keys:
self._do_keys(_keys)
self._add_jwk_dicts(_keys)

for attr, default in self.params.items():
val = spec.get(attr)
Expand Down
Loading