diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py index f3af1ec99..d43372f67 100644 --- a/src/saml2/sigver.py +++ b/src/saml2/sigver.py @@ -1,5 +1,5 @@ """ Functions connected to signing and verifying. -Based on the use of xmlsec1 binaries and not the python xmlsec module. +Based on the use of xmlsec1 binaries and/or the python xmlsec module. """ import base64 @@ -878,16 +878,20 @@ class CryptoBackendXMLSecurity(CryptoBackend): CryptoBackend implementation using pyXMLSecurity to sign and verify XML documents. - Encrypt and decrypt is currently unsupported by pyXMLSecurity. - - pyXMLSecurity uses lxml (libxml2) to parse XML data, but otherwise - try to get by with native Python code. It does native Python RSA - signatures, or alternatively PyKCS11 to offload cryptographic work - to an external PKCS#11 module. + This implementation is hypothetically more efficient than the CryptoBackendXmlSec1 + implementation, but is less tested and as such is not yet the default option. To + enable it, set .crypto_backend = "XMLSecurity" in your Saml2Client initializer's + first argument. """ - def __init__(self): + __DEBUG = 0 + + def __init__(self, **kwargs): CryptoBackend.__init__(self) + try: + self.non_xml_crypto = RSACrypto(kwargs["rsa_key"]) + except KeyError: + pass @property def version(self): @@ -897,25 +901,155 @@ def version(self): except (ImportError, AttributeError): return "0.0.0" + def encrypt(self, text, recv_key, template, session_key_type, xpath=""): + """ + + :param text: The text to be compiled + :param recv_key: Filename of a file where the key resides + :param template: Filename of a file with the pre-encryption part + :param session_key_type: Type and size of a new session key + 'des-192' generates a new 192 bits DES key for DES3 encryption + :param xpath: What should be encrypted + :return: + """ + logger.debug("Encryption input len: %d", len(text)) + + import lxml.etree + import xmlsec + + manager = xmlsec.KeysManager() + key = xmlsec.Key.from_file( + recv_key, + xmlsec.constants.KeyDataFormatCertPem, + None + ) + manager.add_key(key) + + template = lxml.etree.parse(template).getroot() + enc_ctx = xmlsec.EncryptionContext(manager) + if session_key_type == "des-192": # TODO: Will need to be expanded when additional key type support is added + enc_ctx.key = xmlsec.Key.generate( + xmlsec.constants.KeyDataDes, + 192, + xmlsec.constants.KeyDataTypeSession + ) + data = lxml.etree.fromstring(text) + if xpath: + data = data.xpath(xpath)[0] + enc_data = enc_ctx.encrypt_xml(template, data) + + if xpath: + # Hack to fix deletion of duplicated xmlns:ns1 entry + # Could potentially be fixed in CryptoBackendXmlSec1, as duplicated namespace attributes are not recommended + result = lxml.etree.fromstring(text) + result.replace(result.xpath(xpath)[0], lxml.etree.fromstring("")) + result = str(lxml.etree.tostring(result), encoding="utf-8") + result = str(lxml.etree.tostring(enc_data), encoding="utf-8").join(result.split("")) + else: + result = str(lxml.etree.tostring(enc_data), encoding="utf-8") + + return "" + result # Hack to keep version tags identical, otherwise would have encoding attribute + + def encrypt_assertion(self, statement, enc_key, template, key_type="des-192", node_xpath=None, node_id=None): + """ + Will encrypt an assertion + + :param statement: A XML document that contains the assertion to encrypt + :param enc_key: File name of a file containing the encryption key + :param template: A template for the encryption part to be added. + :param key_type: The type of session key to use. + :return: The encrypted text + """ + import lxml.etree + import xmlsec + + if isinstance(statement, SamlBase): + statement = pre_encrypt_assertion(statement) + + if not node_xpath: + node_xpath = ASSERT_XPATH + + manager = xmlsec.KeysManager() + key = xmlsec.Key.from_file( + enc_key, + xmlsec.constants.KeyDataFormatCertPem, + None + ) + manager.add_key(key) + + template = lxml.etree.parse(template).getroot() + enc_ctx = xmlsec.EncryptionContext(manager) + enc_ctx.key = xmlsec.Key.generate( + xmlsec.constants.KeyDataAes if key_type.startswith("aes") else xmlsec.constants.KeyDataDes, + int(key_type[-3:]) if len(key_type) >= 3 and key_type[-3:].isdigit() else 192, + xmlsec.constants.KeyDataTypeSession + ) + data = lxml.etree.fromstring(statement).xpath(node_xpath)[0] + enc_data = enc_ctx.encrypt_xml(template, data) + + # Hack to fix deletion of duplicated xmlns:ns1 entry + # Could potentially be fixed in CryptoBackendXmlSec1, as duplicated namespace attributes are not recommended + result = lxml.etree.fromstring(statement) + result.replace(result.xpath(node_xpath)[0], lxml.etree.fromstring("")) + result = str(lxml.etree.tostring(result), encoding="utf-8") + result = str(lxml.etree.tostring(enc_data), encoding="utf-8").join(result.split("")) + + return "" + result # Hack to keep version tags identical, otherwise would have encoding attribute + + def decrypt(self, enctext, key_file): + """ + + :param enctext: XML document containing an encrypted part + :param key_file: The key to use for the decryption + :return: The decrypted document + """ + logger.debug("Decrypt input len: %d", len(enctext)) + + import lxml.etree + import xmlsec + + manager = xmlsec.KeysManager() + key = xmlsec.Key.from_file( + key_file, + xmlsec.constants.KeyDataFormatPem, + None + ) + manager.add_key(key) + + enc_ctx = xmlsec.EncryptionContext(manager) + enc_data = xmlsec.tree.find_child(lxml.etree.fromstring(enctext), xmlsec.constants.NodeEncryptedData, xmlsec.constants.EncNs) + decrypted = enc_ctx.decrypt(enc_data) + result = lxml.etree.fromstring(enctext) + result.replace(xmlsec.tree.find_child(result, xmlsec.constants.NodeEncryptedData, xmlsec.constants.EncNs), decrypted) + result = str(lxml.etree.tostring(result), encoding="utf-8") + + return "" + result # Hack to keep version tags identical, otherwise would have encoding attribute + def sign_statement(self, statement, node_name, key_file, node_id): """ Sign an XML statement. - The parameters actually used in this CryptoBackend - implementation are : - - :param statement: XML as string - :param node_name: Name of the node to sign - :param key_file: xmlsec key_spec string(), filename, - 'pkcs11://' URI or PEM data - :returns: Signed XML as string + :param statement: The statement to be signed + :param node_name: string like 'urn:oasis:names:...:Assertion' + :param key_file: The file where the key can be found + :param node_id: (not needed given xmlsec.tree.find_node) + :return: The signed statement """ import lxml.etree import xmlsec - xml = xmlsec.parse_xml(statement) - signed = xmlsec.sign(xml, key_file) - signed_str = lxml.etree.tostring(signed, xml_declaration=False, encoding="UTF-8") + if isinstance(statement, SamlBase): + statement = str(statement) + template = lxml.etree.fromstring(statement) + + source_node = xmlsec.tree.find_node(template, node_name.split(':')[-1], ":".join(node_name.split(":")[:-1])) + signature_node = xmlsec.tree.find_node(source_node, xmlsec.constants.NodeSignature) + ctx = xmlsec.SignatureContext() + ctx.key = xmlsec.Key.from_file(key_file, xmlsec.constants.KeyDataFormatPem) + ctx.register_id(source_node, "ID") + + ctx.sign(signature_node) + signed_str = lxml.etree.tostring(template, xml_declaration=False, encoding="UTF-8") if not isinstance(signed_str, str): signed_str = signed_str.decode("utf-8") return signed_str @@ -924,24 +1058,35 @@ def validate_signature(self, signedtext, cert_file, cert_type, node_name, node_i """ Validate signature on XML document. - The parameters actually used in this CryptoBackend - implementation are : - - :param signedtext: The signed XML data as string - :param cert_file: xmlsec key_spec string(), filename, - 'pkcs11://' URI or PEM data - :param cert_type: string, must be 'pem' for now - :returns: True on successful validation, False otherwise + :param signedtext: The XML document as a string + :param cert_file: The public key that was used to sign the document + :param cert_type: The file type of the certificate + :param node_name: The name of the class that is signed + :param node_id: The identifier of the node (not needed given xmlsec.tree.find_node) + :return: Boolean True if the signature was correct otherwise False. """ - if cert_type != "pem": - raise Unsupported("Only PEM certs supported here") - + import lxml.etree import xmlsec - xml = xmlsec.parse_xml(signedtext) + if not isinstance(signedtext, bytes): + signedtext = signedtext.encode("utf-8") + template = lxml.etree.fromstring(signedtext) + xmlsec.tree.add_ids(template, ["ID"]) + source_node = xmlsec.tree.find_node(template, node_name.split(':')[-1], ":".join(node_name.split(":")[:-1])) + signature_node = xmlsec.tree.find_node(source_node, xmlsec.constants.NodeSignature) + + ctx = xmlsec.SignatureContext() + if cert_type == "pem": + ctx.key = xmlsec.Key.from_file(cert_file, xmlsec.constants.KeyDataFormatCertPem) + elif cert_type == "der": + ctx.key = xmlsec.Key.from_file(cert_file, xmlsec.constants.KeyDataFormatCertDer) + else: + ctx.key = xmlsec.Key.from_file(cert_file, xmlsec.constants.KeyDataFormatUnknown) + ctx.set_enabled_key_data([xmlsec.constants.KeyDataX509]) try: - return xmlsec.verify(xml, cert_file) + ctx.verify(signature_node) + return True except xmlsec.XMLSigException: return False @@ -992,7 +1137,23 @@ def security_context(conf): sec_backend = RSACrypto(rsa_key) elif conf.crypto_backend == "XMLSecurity": # new and somewhat untested pyXMLSecurity crypto backend. + try: + import xmlsec + except ImportError: + logger.error(f"Python xmlsec library not found") + raise + crypto = CryptoBackendXMLSecurity() + + _file_name = conf.getattr("key_file", "") + if _file_name: + try: + rsa_key = import_rsa_key_from_file(_file_name) + except Exception as err: + logger.error(f"Cannot import key from {_file_name}: {err}") + raise + else: + sec_backend = RSACrypto(rsa_key) else: err_msg = "Unknown crypto_backend {backend}" err_msg = err_msg.format(backend=conf.crypto_backend) @@ -1814,32 +1975,20 @@ def pre_signature_part( return signature -# -# -# -# -# -# -# -# my-rsa-key -# -# -# -# -# -# -# -# -# -# -# -# -# -# -# +# +# +# +# +# +# +# +# +# +# +# +# +# +# def pre_encryption_part( diff --git a/tests/test_95_pyxmlsec.py b/tests/test_95_pyxmlsec.py new file mode 100644 index 000000000..e8bb8b477 --- /dev/null +++ b/tests/test_95_pyxmlsec.py @@ -0,0 +1,117 @@ +from pathutils import full_path + +from saml2 import class_name +from saml2 import config +from saml2 import extension_elements_to_elements +from saml2 import saml +from saml2 import samlp +from saml2 import sigver +from saml2.mdstore import MetadataStore +from saml2.s_utils import do_attribute_statement +from saml2.s_utils import factory +from saml2.saml import EncryptedAssertion + +SIGNED = full_path("saml_signed.xml") +UNSIGNED = full_path("saml_unsigned.xml") +SIMPLE_SAML_PHP_RESPONSE = full_path("simplesamlphp_authnresponse.xml") +OKTA_RESPONSE = full_path("okta_response.xml") +OKTA_ASSERTION = full_path("okta_assertion") + +PUB_KEY = full_path("test.pem") +PRIV_KEY = full_path("test.key") + +ENC_PUB_KEY = full_path("pki/test_1.crt") +ENC_PRIV_KEY = full_path("pki/test.key") + +INVALID_KEY = full_path("non-existent.key") + +IDP_EXAMPLE = full_path("idp_example.xml") +METADATA_CERT = full_path("metadata_cert.xml") + +def make_sec(crypto_backend = None): + conf = config.SPConfig() + conf.load_file("server_conf") + md = MetadataStore([saml, samlp], None, conf) + md.load("local", IDP_EXAMPLE) + + conf.metadata = md + conf.only_use_keys_in_metadata = False + if crypto_backend is not None: + conf.crypto_backend = crypto_backend + return sigver.security_context(conf) + +def make_assertion(sec): + return factory( + saml.Assertion, + version="2.0", + id="id-11111", + issue_instant="2009-10-30T13:20:28Z", + signature=sigver.pre_signature_part("id-11111", sec.my_cert, 1), + attribute_statement=do_attribute_statement( + { + ("", "", "surName"): ("Foo", ""), + ("", "", "givenName"): ("Bar", ""), + } + ), + ) + +def sign_assertion(assertion, sec): + sigass = sec.sign_statement( + assertion, + class_name(assertion), + key_file=PRIV_KEY, + node_id=assertion.id, + ) + + _ass0 = saml.assertion_from_string(sigass) + encrypted_assertion = EncryptedAssertion() + encrypted_assertion.add_extension_element(_ass0) + + return encrypted_assertion + +def encrypt_assertion(signed_assertion, sec): + template = str(sigver.pre_encryption_part(encrypted_key_id="EK_TEST", encrypted_data_id="ED_TEST")) + tmp = sigver.make_temp(template.encode("utf-8"), decode=False) + return sec.crypto.encrypt( + str(signed_assertion), + sec.cert_file, + tmp.name, + "des-192", + '/*[local-name()="EncryptedAssertion"]/*[local-name()="Assertion"]', + ) + +def decrypt_assertion(enctext, node_name, sec): + decr_text = sec.decrypt(enctext, key_file=PRIV_KEY) + _seass = saml.encrypted_assertion_from_string(decr_text) + assers = extension_elements_to_elements(_seass.extension_elements, [saml, samlp]) + + for ass in assers: + _txt = sec.verify_signature(str(ass), PUB_KEY, node_name=node_name) + if _txt: + return ass + +def test_libxmlsec(): + sec = make_sec() + assertion = make_assertion(sec) + signed_assertion = sign_assertion(assertion, sec) + enctext = encrypt_assertion(signed_assertion, sec) + decrypted_assertion = decrypt_assertion(enctext, class_name(assertion), sec) + + assert decrypted_assertion + return decrypted_assertion + +def test_pyxmlsec(): + sec = make_sec("XMLSecurity") + assertion = make_assertion(sec) + signed_assertion = sign_assertion(assertion, sec) + enctext = encrypt_assertion(signed_assertion, sec) + decrypted_assertion = decrypt_assertion(enctext, class_name(assertion), sec) + + assert decrypted_assertion + return decrypted_assertion + +def compare_tests(): + assert test_libxmlsec() == test_pyxmlsec() + +if __name__ == "__main__": + compare_tests() diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..87deb2d6a --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,65 @@ +ELIDED_TAGS = [ + "X509Certificate", + "SignatureValue", + "DigestValue", + "CipherValue" +] + +def tag_name(tag: str, include_namespace: bool = True) -> str: + name = "" + start = 0 + if tag[start] == "<": + start += 1 + if tag[start] == "/": + start += 1 + for char in tag[start:]: + if char in " \n/>": + break + name += char + if not include_namespace: + name = name.split(":")[-1] + return name + +def pretty_print_xml(xml: str): + if isinstance(xml, bytes): + xml = str(xml, "utf-8") + + tag_groups = {} + tags = [] + tag = "" + closed = False + istag = True + for char in xml: + if char == "\n": + continue + if tag == "" and char != "<": + istag = False + if not istag and char == "<": + tags.append(tag) + tag = "" + istag = True + tag += char + if char == ">": + tag_groups[tag_name(tag)] = closed + tags.append(tag) + tag = "" + closed = False + continue + if tag == "