From 5878cbde5e91d87d873dd9d5cd4693229fd2ed25 Mon Sep 17 00:00:00 2001 From: Kristof Bajnok Date: Mon, 12 Jun 2023 20:02:14 +0200 Subject: [PATCH] PR #405: apply changes from review Add base_path and endpoint_basepath to backend and micro_services Co-authored-by: Ivan Kanakarakis --- src/satosa/backends/base.py | 6 ++++ src/satosa/frontends/base.py | 1 + src/satosa/frontends/saml2.py | 29 ++++++++++++-------- src/satosa/micro_services/account_linking.py | 4 +-- src/satosa/micro_services/base.py | 4 +++ src/satosa/micro_services/consent.py | 4 +-- src/satosa/routing.py | 4 +-- 7 files changed, 32 insertions(+), 20 deletions(-) diff --git a/src/satosa/backends/base.py b/src/satosa/backends/base.py index 381902eee..d18dfc4d6 100644 --- a/src/satosa/backends/base.py +++ b/src/satosa/backends/base.py @@ -3,6 +3,9 @@ """ from ..attribute_mapping import AttributeMapper +from ..util import join_paths + +from urllib.parse import urlparse class BackendModule(object): @@ -30,7 +33,10 @@ def __init__(self, auth_callback_func, internal_attributes, base_url, name): self.internal_attributes = internal_attributes self.converter = AttributeMapper(internal_attributes) self.base_url = base_url.rstrip("/") if base_url else "" + self.base_path = urlparse(self.base_url).path.lstrip("/") self.name = name + self.endpoint_baseurl = join_paths(self.base_url, self.name) + self.endpoint_basepath = urlparse(self.endpoint_baseurl).path.lstrip("/") def start_auth(self, context, internal_request): """ diff --git a/src/satosa/frontends/base.py b/src/satosa/frontends/base.py index 08c38b79c..49224c48f 100644 --- a/src/satosa/frontends/base.py +++ b/src/satosa/frontends/base.py @@ -30,6 +30,7 @@ def __init__(self, auth_req_callback_func, internal_attributes, base_url, name): self.internal_attributes = internal_attributes self.converter = AttributeMapper(internal_attributes) self.base_url = base_url or "" + self.base_path = urlparse(self.base_url).path.lstrip("/") self.name = name self.endpoint_baseurl = join_paths(self.base_url, self.name) self.endpoint_basepath = urlparse(self.endpoint_baseurl).path.lstrip("/") diff --git a/src/satosa/frontends/saml2.py b/src/satosa/frontends/saml2.py index 41be65799..2b8a1149a 100644 --- a/src/satosa/frontends/saml2.py +++ b/src/satosa/frontends/saml2.py @@ -33,6 +33,7 @@ from ..response import Response from ..response import ServiceError from ..saml_util import make_saml_response +from ..util import join_paths from satosa.exception import SATOSAError import satosa.util as util @@ -511,17 +512,18 @@ def _register_endpoints(self, providers): """ url_map = [] - backend_providers = "|".join(providers) - base_path = urlparse(self.base_url).path.lstrip("/") - if base_path: - base_path = base_path + "/" + backend_providers = "(" + "|".join(providers) + ")" for endp_category in self.endpoints: for binding, endp in self.endpoints[endp_category].items(): endp_path = urlparse(endp).path url_map.append( ( - "^{}({})/{}$".format(base_path, backend_providers, endp_path), - functools.partial(self.handle_authn_request, binding_in=binding) + "^{}$".format( + join_paths(self.base_path, backend_providers, endp_path) + ), + functools.partial( + self.handle_authn_request, binding_in=binding + ), ) ) @@ -769,17 +771,20 @@ def _register_endpoints(self, providers): """ url_map = [] - backend_providers = "|".join(providers) - base_path = urlparse(self.base_url).path.lstrip("/") - if base_path: - base_path = base_path + "/" + backend_providers = "(" + "|".join(providers) + ")" for endp_category in self.endpoints: for binding, endp in self.endpoints[endp_category].items(): endp_path = urlparse(endp).path url_map.append( ( - "^{}({})/\S+/{}$".format(base_path, backend_providers, endp_path), - functools.partial(self.handle_authn_request, binding_in=binding) + "^{}$".format( + join_paths( + self.base_path, backend_providers, "\S+", endp_path + ) + ), + functools.partial( + self.handle_authn_request, binding_in=binding + ), ) ) diff --git a/src/satosa/micro_services/account_linking.py b/src/satosa/micro_services/account_linking.py index 4cfd72c99..bebd77b7f 100644 --- a/src/satosa/micro_services/account_linking.py +++ b/src/satosa/micro_services/account_linking.py @@ -165,9 +165,7 @@ def register_endpoints(self): return [ ( "^{}$".format( - join_paths( - self.base_path, "account_linking", self.endpoint - ) + join_paths(self.base_path, "account_linking", self.endpoint) ), self._handle_al_response, ) diff --git a/src/satosa/micro_services/base.py b/src/satosa/micro_services/base.py index 97271b013..b31baf9b4 100644 --- a/src/satosa/micro_services/base.py +++ b/src/satosa/micro_services/base.py @@ -4,6 +4,8 @@ import logging from urllib.parse import urlparse +from ..util import join_paths + logger = logging.getLogger(__name__) @@ -16,6 +18,8 @@ def __init__(self, name, base_url, **kwargs): self.name = name self.base_url = base_url self.base_path = urlparse(base_url).path.lstrip("/") + self.endpoint_baseurl = join_paths(self.base_url, self.name) + self.endpoint_basepath = urlparse(self.endpoint_baseurl).path.lstrip("/") self.next = None def process(self, context, data): diff --git a/src/satosa/micro_services/consent.py b/src/satosa/micro_services/consent.py index c273116d5..fbe8e75dc 100644 --- a/src/satosa/micro_services/consent.py +++ b/src/satosa/micro_services/consent.py @@ -242,9 +242,7 @@ def register_endpoints(self): return [ ( "^{}$".format( - join_paths( - self.base_path, "consent", self.endpoint - ) + join_paths(self.base_path, "consent", self.endpoint) ), self._handle_consent_response, ) diff --git a/src/satosa/routing.py b/src/satosa/routing.py index c9fa8ab8e..d739273ac 100644 --- a/src/satosa/routing.py +++ b/src/satosa/routing.py @@ -38,7 +38,7 @@ class UnknownEndpoint(ValueError): and handles the internal routing between frontends and backends. """ - def __init__(self, frontends, backends, micro_services, base_path=""): + def __init__(self, frontends, backends, micro_services, base_path=None): """ :type frontends: dict[str, satosa.frontends.base.FrontendModule] :type backends: dict[str, satosa.backends.base.BackendModule] @@ -70,7 +70,7 @@ def __init__(self, frontends, backends, micro_services, base_path=""): else: self.micro_services = {} - self.base_path = base_path + self.base_path = base_path if base_path else "" logger.debug("Loaded backends with endpoints: {}".format(backends)) logger.debug("Loaded frontends with endpoints: {}".format(frontends))