Skip to content

Commit

Permalink
applied come code suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
Zicchio committed Sep 25, 2024
1 parent 5fa54ed commit 5fd5100
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 49 deletions.
54 changes: 22 additions & 32 deletions pyeudiw/openid4vp/vp_sd_jwt_kb.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from copy import deepcopy
from dataclasses import dataclass
from typing import Callable, Optional, Union
from typing import Callable, Union

from cryptojwt.jws.exception import JWSException
from jwcrypto.common import base64url_decode, json_decode
import jwcrypto.jwk
from sd_jwt.common import SDJWTCommon
from sd_jwt.verifier import SDJWTVerifier

from pyeudiw.jwk import JWK
Expand All @@ -17,8 +17,6 @@
from pyeudiw.tools.utils import iat_now


_SD_JWT_DELIMITER = '~'

_CLOCK_SKEW = 0


Expand All @@ -29,15 +27,8 @@ class VerifierChallenge:


class VpVcSdJwtKbVerifier(VpVerifier):
_DEFAULT_DISCLOSABLE_CLAIMS = (
"given_name",
"family_name",
"birth_date",
"unique_id",
"tax_id_code"
)

def __init__(self, sdjwtkb: str, verifier_id: str, verifier_nonce: str, jwk_by_kid: dict[str, dict], accepted_claims: Optional[list[str]] = None):

def __init__(self, sdjwtkb: str, verifier_id: str, verifier_nonce: str, jwk_by_kid: dict[str, dict]):
"""
VpVcSdJwtKbVerifier is a utility class for parsing and verifying sd-jwt.
Expand All @@ -60,7 +51,6 @@ def __init__(self, sdjwtkb: str, verifier_id: str, verifier_nonce: str, jwk_by_k
self.verifier_id = verifier_id
self.verifier_nonce = verifier_nonce
self.jwk_by_kid = jwk_by_kid
self.accepted_claims: list[str] = accepted_claims if accepted_claims is not None else deepcopy(VpVcSdJwtKbVerifier._DEFAULT_DISCLOSABLE_CLAIMS)
# precomputed values
self._issuer_jwt: UnverfiedJwt = UnverfiedJwt("", "", "", "")
self._encoded_disclosures: list[str] = []
Expand All @@ -69,7 +59,7 @@ def __init__(self, sdjwtkb: str, verifier_id: str, verifier_nonce: str, jwk_by_k
self._post_init_evaluate_precomputed_values()

def _post_init_evaluate_precomputed_values(self):
iss_jwt, *disclosures, kb_jwt = self.sdjwtkb.split(_SD_JWT_DELIMITER)
iss_jwt, *disclosures, kb_jwt = self.sdjwtkb.split(SDJWTCommon.COMBINED_SERIALIZATION_FORMAT_SEPARATOR)
self._encoded_disclosures = disclosures
self._disclosures = [json_decode(base64url_decode(disc)) for disc in disclosures]
self._issuer_jwt = unsafe_parse_jws(iss_jwt)
Expand Down Expand Up @@ -124,27 +114,24 @@ def parse_digital_credential(self) -> dict:
serialization_format="compact"
)
payload_claims: dict = sdjwt_verifier.get_verified_payload()
# NOTE: if acceptance list is empty, accept everything
# this assumes that an empy acceptance list means nothing, which is an invariant that might not hold in the future
if len(self.accepted_claims) == 0:
return payload_claims
filtered_claims_result = {}
for claim_name in self.accepted_claims:
if claim_name in payload_claims.keys():
filtered_claims_result.update({claim_name: payload_claims[claim_name]})
return filtered_claims_result
return payload_claims

def __str__(self) -> str:
return "VpVcSdJwtKb(" \
f"sdjwt={self.sdjwtkb}" \
f"sdjwt={self.sdjwtkb}, " \
f"verifier_id={self.verifier_id}, " \
f"verifier_nonce={self.verifier_nonce}, " \
f"jwk_by_kid={self.jwk_by_kid}" \
")"


def _verify_jws_with_key(issuer_jwt: str, issuer_key: JWK):
verifier = JWSHelper(issuer_key)
verifier.verify(issuer_jwt)
try:
verifier = JWSHelper(issuer_key)
except Exception as e:
raise InvalidVPSignature(f"failed signature verification of issuer-jwt: invalid issuer key due to cause: {e}")
try:
pass
verifier.verify(issuer_jwt)
except JWSException as e:
raise InvalidVPSignature(f"failed signature verification of issuer-jwt: {e}")
return
Expand All @@ -153,17 +140,17 @@ def _verify_jws_with_key(issuer_jwt: str, issuer_key: JWK):
def _verify_kb_jwt(kbjwt: UnverfiedJwt, cnf_jwk: JWK, challenge: VerifierChallenge) -> None:
_verify_kb_jwt_payload_challenge(kbjwt.payload, challenge)
_verify_kb_jwt_payload_iat(kbjwt.payload)
# TODO: sd-jwt-python already does this check, however it would be space for us to have it more explicit in our code
# _verify_kb_jwt_payload_sd_hash(sdjwt)
_verify_kb_jwt_signature(kbjwt.jwt, cnf_jwk)


# def _verify_kb_jwt_payload_sd_hash(sdjwt: VpVcSdJwtKbVerifier):
# def _verify_kb_jwt_payload_sd_hash(sdjwt):
# hash_alg: str | None = sdjwt._issuer_jwt.payload.get("_sd_alg", None)
# if hash_alg is None:
# raise ValueError("missing parameter [_sd_alg] in issuer signet JWT payload")
# *parts, _ = sdjwt.sdjwtkb.split(_SD_JWT_DELIMITER)
# iss_jwt_disclosed = ''.join(parts)
# # TODO
# TODO: go on
# pass


Expand All @@ -190,7 +177,10 @@ def _verify_kb_jwt_payload_iat(kb_jwt_payload: dict) -> None:


def _verify_kb_jwt_signature(kbjwt: str, verification_jwk: JWK) -> None:
verifier = JWSHelper(verification_jwk)
try:
verifier = JWSHelper(verification_jwk)
except Exception as e:
raise InvalidVPKeyBinding(f"failed signature verification of kb-jwt: invalid cnf key to cause: {e}")
try:
verifier.verify(kbjwt)
except JWSException as e:
Expand Down
2 changes: 1 addition & 1 deletion pyeudiw/sd_jwt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
holder_key,
sign_alg,
add_decoy_claims,
serialization_format,
serialization_format
)

def _create_signed_jws(self):
Expand Down
12 changes: 2 additions & 10 deletions pyeudiw/sd_jwt/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Any, Dict, Literal, Optional, TypeVar
from typing import Dict, Literal, Optional, TypeVar

from pydantic import BaseModel, HttpUrl, field_validator

Expand Down Expand Up @@ -30,7 +30,7 @@ def is_sd_jwt_kb_format(sd_jwt_kb: str) -> bool:
class VcSdJwtHeaderSchema(BaseModel):
typ: str
alg: str
kid: str
kid: Optional[str] = None
trust_chain: Optional[list[str]] = None
x5c: Optional[str] = None
vctm: Optional[list[str]] = None
Expand Down Expand Up @@ -90,14 +90,6 @@ def validate_verification(cls, v: dict) -> dict:
return v


class PidVcSdJwtPayloadSchema(VcSdJwtPayloadSchema):
given_name: Optional[str] = None
family_name: Optional[str] = None
birth_date: Optional[Any] = None # TODO: date is dd-mm-yyyy but I'm not sure if libraries parses them as str or a native format
unique_id: Optional[str] = None
tax_id_code: Optional[str] = None


class KeyBindingJwtHeader(BaseModel):
typ: str
alg: str
Expand Down
8 changes: 2 additions & 6 deletions pyeudiw/tests/openid4vp/test_vp_sd_jwt_kb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@ def test_VpVcSdJwtKbVerifier():
"y": "Xv5zWwuoaTgdS6hV43yI6gBwTnjukmFQQnJ_kCxzqk8"
}
}
claims = [
"address", # discolsed
"family_name" # NOT disclosed
]
verifier = VpVcSdJwtKbVerifier(token, aud, nonce, jwk_d, claims)
verifier = VpVcSdJwtKbVerifier(token, aud, nonce, jwk_d)
try:
verifier.validate_schema()
except VPSchemaException:
Expand All @@ -28,4 +24,4 @@ def test_VpVcSdJwtKbVerifier():
verifier.verify()
expected_credentials = {"address": {"street_address": "123 Main St", "locality": "Anytown", "region": "Anystate", "country": "US"}}
credentials = verifier.parse_digital_credential()
assert credentials == expected_credentials, f"failed to parse credentials: expected {expected_credentials}, obtained {credentials}"
assert expected_credentials.items() <= credentials.items(), f"failed to parse credentials: expected {expected_credentials}, obtained {credentials}"

0 comments on commit 5fd5100

Please sign in to comment.