diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py
index f3af1ec99..d43372f67 100644
--- a/src/saml2/sigver.py
+++ b/src/saml2/sigver.py
@@ -1,5 +1,5 @@
""" Functions connected to signing and verifying.
-Based on the use of xmlsec1 binaries and not the python xmlsec module.
+Based on the use of xmlsec1 binaries and/or the python xmlsec module.
"""
import base64
@@ -878,16 +878,20 @@ class CryptoBackendXMLSecurity(CryptoBackend):
CryptoBackend implementation using pyXMLSecurity to sign and verify
XML documents.
- Encrypt and decrypt is currently unsupported by pyXMLSecurity.
-
- pyXMLSecurity uses lxml (libxml2) to parse XML data, but otherwise
- try to get by with native Python code. It does native Python RSA
- signatures, or alternatively PyKCS11 to offload cryptographic work
- to an external PKCS#11 module.
+ This implementation is hypothetically more efficient than the CryptoBackendXmlSec1
+ implementation, but is less tested and as such is not yet the default option. To
+ enable it, set .crypto_backend = "XMLSecurity" in your Saml2Client initializer's
+ first argument.
"""
- def __init__(self):
+ __DEBUG = 0
+
+ def __init__(self, **kwargs):
CryptoBackend.__init__(self)
+ try:
+ self.non_xml_crypto = RSACrypto(kwargs["rsa_key"])
+ except KeyError:
+ pass
@property
def version(self):
@@ -897,25 +901,155 @@ def version(self):
except (ImportError, AttributeError):
return "0.0.0"
+ def encrypt(self, text, recv_key, template, session_key_type, xpath=""):
+ """
+
+ :param text: The text to be compiled
+ :param recv_key: Filename of a file where the key resides
+ :param template: Filename of a file with the pre-encryption part
+ :param session_key_type: Type and size of a new session key
+ 'des-192' generates a new 192 bits DES key for DES3 encryption
+ :param xpath: What should be encrypted
+ :return:
+ """
+ logger.debug("Encryption input len: %d", len(text))
+
+ import lxml.etree
+ import xmlsec
+
+ manager = xmlsec.KeysManager()
+ key = xmlsec.Key.from_file(
+ recv_key,
+ xmlsec.constants.KeyDataFormatCertPem,
+ None
+ )
+ manager.add_key(key)
+
+ template = lxml.etree.parse(template).getroot()
+ enc_ctx = xmlsec.EncryptionContext(manager)
+ if session_key_type == "des-192": # TODO: Will need to be expanded when additional key type support is added
+ enc_ctx.key = xmlsec.Key.generate(
+ xmlsec.constants.KeyDataDes,
+ 192,
+ xmlsec.constants.KeyDataTypeSession
+ )
+ data = lxml.etree.fromstring(text)
+ if xpath:
+ data = data.xpath(xpath)[0]
+ enc_data = enc_ctx.encrypt_xml(template, data)
+
+ if xpath:
+ # Hack to fix deletion of duplicated xmlns:ns1 entry
+ # Could potentially be fixed in CryptoBackendXmlSec1, as duplicated namespace attributes are not recommended
+ result = lxml.etree.fromstring(text)
+ result.replace(result.xpath(xpath)[0], lxml.etree.fromstring(""))
+ result = str(lxml.etree.tostring(result), encoding="utf-8")
+ result = str(lxml.etree.tostring(enc_data), encoding="utf-8").join(result.split(""))
+ else:
+ result = str(lxml.etree.tostring(enc_data), encoding="utf-8")
+
+ return "" + result # Hack to keep version tags identical, otherwise would have encoding attribute
+
+ def encrypt_assertion(self, statement, enc_key, template, key_type="des-192", node_xpath=None, node_id=None):
+ """
+ Will encrypt an assertion
+
+ :param statement: A XML document that contains the assertion to encrypt
+ :param enc_key: File name of a file containing the encryption key
+ :param template: A template for the encryption part to be added.
+ :param key_type: The type of session key to use.
+ :return: The encrypted text
+ """
+ import lxml.etree
+ import xmlsec
+
+ if isinstance(statement, SamlBase):
+ statement = pre_encrypt_assertion(statement)
+
+ if not node_xpath:
+ node_xpath = ASSERT_XPATH
+
+ manager = xmlsec.KeysManager()
+ key = xmlsec.Key.from_file(
+ enc_key,
+ xmlsec.constants.KeyDataFormatCertPem,
+ None
+ )
+ manager.add_key(key)
+
+ template = lxml.etree.parse(template).getroot()
+ enc_ctx = xmlsec.EncryptionContext(manager)
+ enc_ctx.key = xmlsec.Key.generate(
+ xmlsec.constants.KeyDataAes if key_type.startswith("aes") else xmlsec.constants.KeyDataDes,
+ int(key_type[-3:]) if len(key_type) >= 3 and key_type[-3:].isdigit() else 192,
+ xmlsec.constants.KeyDataTypeSession
+ )
+ data = lxml.etree.fromstring(statement).xpath(node_xpath)[0]
+ enc_data = enc_ctx.encrypt_xml(template, data)
+
+ # Hack to fix deletion of duplicated xmlns:ns1 entry
+ # Could potentially be fixed in CryptoBackendXmlSec1, as duplicated namespace attributes are not recommended
+ result = lxml.etree.fromstring(statement)
+ result.replace(result.xpath(node_xpath)[0], lxml.etree.fromstring(""))
+ result = str(lxml.etree.tostring(result), encoding="utf-8")
+ result = str(lxml.etree.tostring(enc_data), encoding="utf-8").join(result.split(""))
+
+ return "" + result # Hack to keep version tags identical, otherwise would have encoding attribute
+
+ def decrypt(self, enctext, key_file):
+ """
+
+ :param enctext: XML document containing an encrypted part
+ :param key_file: The key to use for the decryption
+ :return: The decrypted document
+ """
+ logger.debug("Decrypt input len: %d", len(enctext))
+
+ import lxml.etree
+ import xmlsec
+
+ manager = xmlsec.KeysManager()
+ key = xmlsec.Key.from_file(
+ key_file,
+ xmlsec.constants.KeyDataFormatPem,
+ None
+ )
+ manager.add_key(key)
+
+ enc_ctx = xmlsec.EncryptionContext(manager)
+ enc_data = xmlsec.tree.find_child(lxml.etree.fromstring(enctext), xmlsec.constants.NodeEncryptedData, xmlsec.constants.EncNs)
+ decrypted = enc_ctx.decrypt(enc_data)
+ result = lxml.etree.fromstring(enctext)
+ result.replace(xmlsec.tree.find_child(result, xmlsec.constants.NodeEncryptedData, xmlsec.constants.EncNs), decrypted)
+ result = str(lxml.etree.tostring(result), encoding="utf-8")
+
+ return "" + result # Hack to keep version tags identical, otherwise would have encoding attribute
+
def sign_statement(self, statement, node_name, key_file, node_id):
"""
Sign an XML statement.
- The parameters actually used in this CryptoBackend
- implementation are :
-
- :param statement: XML as string
- :param node_name: Name of the node to sign
- :param key_file: xmlsec key_spec string(), filename,
- 'pkcs11://' URI or PEM data
- :returns: Signed XML as string
+ :param statement: The statement to be signed
+ :param node_name: string like 'urn:oasis:names:...:Assertion'
+ :param key_file: The file where the key can be found
+ :param node_id: (not needed given xmlsec.tree.find_node)
+ :return: The signed statement
"""
import lxml.etree
import xmlsec
- xml = xmlsec.parse_xml(statement)
- signed = xmlsec.sign(xml, key_file)
- signed_str = lxml.etree.tostring(signed, xml_declaration=False, encoding="UTF-8")
+ if isinstance(statement, SamlBase):
+ statement = str(statement)
+ template = lxml.etree.fromstring(statement)
+
+ source_node = xmlsec.tree.find_node(template, node_name.split(':')[-1], ":".join(node_name.split(":")[:-1]))
+ signature_node = xmlsec.tree.find_node(source_node, xmlsec.constants.NodeSignature)
+ ctx = xmlsec.SignatureContext()
+ ctx.key = xmlsec.Key.from_file(key_file, xmlsec.constants.KeyDataFormatPem)
+ ctx.register_id(source_node, "ID")
+
+ ctx.sign(signature_node)
+ signed_str = lxml.etree.tostring(template, xml_declaration=False, encoding="UTF-8")
if not isinstance(signed_str, str):
signed_str = signed_str.decode("utf-8")
return signed_str
@@ -924,24 +1058,35 @@ def validate_signature(self, signedtext, cert_file, cert_type, node_name, node_i
"""
Validate signature on XML document.
- The parameters actually used in this CryptoBackend
- implementation are :
-
- :param signedtext: The signed XML data as string
- :param cert_file: xmlsec key_spec string(), filename,
- 'pkcs11://' URI or PEM data
- :param cert_type: string, must be 'pem' for now
- :returns: True on successful validation, False otherwise
+ :param signedtext: The XML document as a string
+ :param cert_file: The public key that was used to sign the document
+ :param cert_type: The file type of the certificate
+ :param node_name: The name of the class that is signed
+ :param node_id: The identifier of the node (not needed given xmlsec.tree.find_node)
+ :return: Boolean True if the signature was correct otherwise False.
"""
- if cert_type != "pem":
- raise Unsupported("Only PEM certs supported here")
-
+ import lxml.etree
import xmlsec
- xml = xmlsec.parse_xml(signedtext)
+ if not isinstance(signedtext, bytes):
+ signedtext = signedtext.encode("utf-8")
+ template = lxml.etree.fromstring(signedtext)
+ xmlsec.tree.add_ids(template, ["ID"])
+ source_node = xmlsec.tree.find_node(template, node_name.split(':')[-1], ":".join(node_name.split(":")[:-1]))
+ signature_node = xmlsec.tree.find_node(source_node, xmlsec.constants.NodeSignature)
+
+ ctx = xmlsec.SignatureContext()
+ if cert_type == "pem":
+ ctx.key = xmlsec.Key.from_file(cert_file, xmlsec.constants.KeyDataFormatCertPem)
+ elif cert_type == "der":
+ ctx.key = xmlsec.Key.from_file(cert_file, xmlsec.constants.KeyDataFormatCertDer)
+ else:
+ ctx.key = xmlsec.Key.from_file(cert_file, xmlsec.constants.KeyDataFormatUnknown)
+ ctx.set_enabled_key_data([xmlsec.constants.KeyDataX509])
try:
- return xmlsec.verify(xml, cert_file)
+ ctx.verify(signature_node)
+ return True
except xmlsec.XMLSigException:
return False
@@ -992,7 +1137,23 @@ def security_context(conf):
sec_backend = RSACrypto(rsa_key)
elif conf.crypto_backend == "XMLSecurity":
# new and somewhat untested pyXMLSecurity crypto backend.
+ try:
+ import xmlsec
+ except ImportError:
+ logger.error(f"Python xmlsec library not found")
+ raise
+
crypto = CryptoBackendXMLSecurity()
+
+ _file_name = conf.getattr("key_file", "")
+ if _file_name:
+ try:
+ rsa_key = import_rsa_key_from_file(_file_name)
+ except Exception as err:
+ logger.error(f"Cannot import key from {_file_name}: {err}")
+ raise
+ else:
+ sec_backend = RSACrypto(rsa_key)
else:
err_msg = "Unknown crypto_backend {backend}"
err_msg = err_msg.format(backend=conf.crypto_backend)
@@ -1814,32 +1975,20 @@ def pre_signature_part(
return signature
-#
-#
-#
-#
-#
-#
-#
-# my-rsa-key
-#
-#
-#
-#
-#
-#
-#
-#
-#
-#
-#
-#
-#
-#
-#
+#
+#
+#
+#
+#
+#
+#
+#
+#
+#
+#
+#
+#
+#
def pre_encryption_part(
diff --git a/tests/test_95_pyxmlsec.py b/tests/test_95_pyxmlsec.py
new file mode 100644
index 000000000..e8bb8b477
--- /dev/null
+++ b/tests/test_95_pyxmlsec.py
@@ -0,0 +1,117 @@
+from pathutils import full_path
+
+from saml2 import class_name
+from saml2 import config
+from saml2 import extension_elements_to_elements
+from saml2 import saml
+from saml2 import samlp
+from saml2 import sigver
+from saml2.mdstore import MetadataStore
+from saml2.s_utils import do_attribute_statement
+from saml2.s_utils import factory
+from saml2.saml import EncryptedAssertion
+
+SIGNED = full_path("saml_signed.xml")
+UNSIGNED = full_path("saml_unsigned.xml")
+SIMPLE_SAML_PHP_RESPONSE = full_path("simplesamlphp_authnresponse.xml")
+OKTA_RESPONSE = full_path("okta_response.xml")
+OKTA_ASSERTION = full_path("okta_assertion")
+
+PUB_KEY = full_path("test.pem")
+PRIV_KEY = full_path("test.key")
+
+ENC_PUB_KEY = full_path("pki/test_1.crt")
+ENC_PRIV_KEY = full_path("pki/test.key")
+
+INVALID_KEY = full_path("non-existent.key")
+
+IDP_EXAMPLE = full_path("idp_example.xml")
+METADATA_CERT = full_path("metadata_cert.xml")
+
+def make_sec(crypto_backend = None):
+ conf = config.SPConfig()
+ conf.load_file("server_conf")
+ md = MetadataStore([saml, samlp], None, conf)
+ md.load("local", IDP_EXAMPLE)
+
+ conf.metadata = md
+ conf.only_use_keys_in_metadata = False
+ if crypto_backend is not None:
+ conf.crypto_backend = crypto_backend
+ return sigver.security_context(conf)
+
+def make_assertion(sec):
+ return factory(
+ saml.Assertion,
+ version="2.0",
+ id="id-11111",
+ issue_instant="2009-10-30T13:20:28Z",
+ signature=sigver.pre_signature_part("id-11111", sec.my_cert, 1),
+ attribute_statement=do_attribute_statement(
+ {
+ ("", "", "surName"): ("Foo", ""),
+ ("", "", "givenName"): ("Bar", ""),
+ }
+ ),
+ )
+
+def sign_assertion(assertion, sec):
+ sigass = sec.sign_statement(
+ assertion,
+ class_name(assertion),
+ key_file=PRIV_KEY,
+ node_id=assertion.id,
+ )
+
+ _ass0 = saml.assertion_from_string(sigass)
+ encrypted_assertion = EncryptedAssertion()
+ encrypted_assertion.add_extension_element(_ass0)
+
+ return encrypted_assertion
+
+def encrypt_assertion(signed_assertion, sec):
+ template = str(sigver.pre_encryption_part(encrypted_key_id="EK_TEST", encrypted_data_id="ED_TEST"))
+ tmp = sigver.make_temp(template.encode("utf-8"), decode=False)
+ return sec.crypto.encrypt(
+ str(signed_assertion),
+ sec.cert_file,
+ tmp.name,
+ "des-192",
+ '/*[local-name()="EncryptedAssertion"]/*[local-name()="Assertion"]',
+ )
+
+def decrypt_assertion(enctext, node_name, sec):
+ decr_text = sec.decrypt(enctext, key_file=PRIV_KEY)
+ _seass = saml.encrypted_assertion_from_string(decr_text)
+ assers = extension_elements_to_elements(_seass.extension_elements, [saml, samlp])
+
+ for ass in assers:
+ _txt = sec.verify_signature(str(ass), PUB_KEY, node_name=node_name)
+ if _txt:
+ return ass
+
+def test_libxmlsec():
+ sec = make_sec()
+ assertion = make_assertion(sec)
+ signed_assertion = sign_assertion(assertion, sec)
+ enctext = encrypt_assertion(signed_assertion, sec)
+ decrypted_assertion = decrypt_assertion(enctext, class_name(assertion), sec)
+
+ assert decrypted_assertion
+ return decrypted_assertion
+
+def test_pyxmlsec():
+ sec = make_sec("XMLSecurity")
+ assertion = make_assertion(sec)
+ signed_assertion = sign_assertion(assertion, sec)
+ enctext = encrypt_assertion(signed_assertion, sec)
+ decrypted_assertion = decrypt_assertion(enctext, class_name(assertion), sec)
+
+ assert decrypted_assertion
+ return decrypted_assertion
+
+def compare_tests():
+ assert test_libxmlsec() == test_pyxmlsec()
+
+if __name__ == "__main__":
+ compare_tests()
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 000000000..87deb2d6a
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,65 @@
+ELIDED_TAGS = [
+ "X509Certificate",
+ "SignatureValue",
+ "DigestValue",
+ "CipherValue"
+]
+
+def tag_name(tag: str, include_namespace: bool = True) -> str:
+ name = ""
+ start = 0
+ if tag[start] == "<":
+ start += 1
+ if tag[start] == "/":
+ start += 1
+ for char in tag[start:]:
+ if char in " \n/>":
+ break
+ name += char
+ if not include_namespace:
+ name = name.split(":")[-1]
+ return name
+
+def pretty_print_xml(xml: str):
+ if isinstance(xml, bytes):
+ xml = str(xml, "utf-8")
+
+ tag_groups = {}
+ tags = []
+ tag = ""
+ closed = False
+ istag = True
+ for char in xml:
+ if char == "\n":
+ continue
+ if tag == "" and char != "<":
+ istag = False
+ if not istag and char == "<":
+ tags.append(tag)
+ tag = ""
+ istag = True
+ tag += char
+ if char == ">":
+ tag_groups[tag_name(tag)] = closed
+ tags.append(tag)
+ tag = ""
+ closed = False
+ continue
+ if tag == "":
+ closed = True
+ continue
+
+ space_count = 0
+ prev = ""
+ for tag in tags:
+ istag = tag.startswith("<")
+ closing = istag and tag[1] == "/"
+ if closing and tag_groups[tag_name(tag)]:
+ space_count -= 4
+ if istag and prev.startswith("<"): # istag and prev_istag
+ print("\n" + (" " * space_count), end="")
+ print(tag if (istag or tag_name(prev).split(":")[-1] not in ELIDED_TAGS) else "...", end="")
+ if istag and tag_groups[tag_name(tag)] and not closing:
+ space_count += 4
+ prev = tag
+ print()