Skip to content

Commit

Permalink
ModuleRouter: support paths in BASE
Browse files Browse the repository at this point in the history
If Satosa is installed under a path which is not the root of the
webserver (ie. "https://example.com/satosa"), then endpoint routing must
take the base path into consideration.

Some modules registered some of their endpoints with the base path
included, but other times the base path was omitted, thus it made the
routing fail. Now all endpoint registrations include the base path in
their endpoint map.

Provide a simple implementation for joining path components, since we
don't want to add the separator for empty strings and when any of the
path components already have it.

Additionally, DEBUG logging was configured for the tests so that the
debug logs are accessible during testing.
  • Loading branch information
bajnokk committed Nov 24, 2023
1 parent 4e8d27c commit f558ead
Show file tree
Hide file tree
Showing 18 changed files with 281 additions and 65 deletions.
2 changes: 1 addition & 1 deletion src/satosa/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, auth_callback_func, internal_attributes, base_url, name):
self.auth_callback_func = auth_callback_func
self.internal_attributes = internal_attributes
self.converter = AttributeMapper(internal_attributes)
self.base_url = base_url
self.base_url = base_url.rstrip("/") if base_url else ""
self.name = name

def start_auth(self, context, internal_request):
Expand Down
9 changes: 7 additions & 2 deletions src/satosa/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import uuid

from saml2.s_utils import UnknownSystemEntity
from urllib.parse import urlparse

from satosa import util
from .context import Context
Expand Down Expand Up @@ -38,6 +39,8 @@ def __init__(self, config):
"""
self.config = config

base_path = urlparse(self.config["BASE"]).path.lstrip("/")

logger.info("Loading backend modules...")
backends = load_backends(self.config, self._auth_resp_callback_func,
self.config["INTERNAL_ATTRIBUTES"])
Expand All @@ -63,8 +66,10 @@ def __init__(self, config):
self.config["BASE"]))
self._link_micro_services(self.response_micro_services, self._auth_resp_finish)

self.module_router = ModuleRouter(frontends, backends,
self.request_micro_services + self.response_micro_services)
self.module_router = ModuleRouter(frontends,
backends,
self.request_micro_services + self.response_micro_services,
base_path)

def _link_micro_services(self, micro_services, finisher):
if not micro_services:
Expand Down
4 changes: 0 additions & 4 deletions src/satosa/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,6 @@ def path(self, p):
raise ValueError("path can't start with '/'")
self._path = p

def target_entity_id_from_path(self):
target_entity_id = self.path.split("/")[1]
return target_entity_id

def decorate(self, key, value):
"""
Add information to the context
Expand Down
10 changes: 9 additions & 1 deletion src/satosa/frontends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
Holds a base class for frontend modules used in the SATOSA proxy.
"""
from ..attribute_mapping import AttributeMapper
from ..util import join_paths

from urllib.parse import urlparse


class FrontendModule(object):
Expand All @@ -14,17 +17,22 @@ def __init__(self, auth_req_callback_func, internal_attributes, base_url, name):
:type auth_req_callback_func:
(satosa.context.Context, satosa.internal.InternalData) -> satosa.response.Response
:type internal_attributes: dict[str, dict[str, str | list[str]]]
:type base_url: str
:type name: str
:param auth_req_callback_func: Callback should be called by the module after the
authorization response has been processed.
:param internal_attributes: attribute mapping
:param base_url: base url of the proxy
:param name: name of the plugin
"""
self.auth_req_callback_func = auth_req_callback_func
self.internal_attributes = internal_attributes
self.converter = AttributeMapper(internal_attributes)
self.base_url = base_url
self.base_url = base_url or ""
self.name = name
self.endpoint_baseurl = join_paths(self.base_url, self.name)
self.endpoint_basepath = urlparse(self.endpoint_baseurl).path.lstrip("/")

def handle_authn_response(self, context, internal_resp):
"""
Expand Down
47 changes: 34 additions & 13 deletions src/satosa/frontends/openid_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from ..response import BadRequest, Created
from ..response import SeeOther, Response
from ..response import Unauthorized
from ..util import rndstr
from ..util import join_paths, rndstr

import satosa.logging_util as lu
from satosa.internal import InternalData
Expand Down Expand Up @@ -97,7 +97,6 @@ def __init__(self, auth_req_callback_func, internal_attributes, conf, base_url,
else:
cdb = {}

self.endpoint_baseurl = "{}/{}".format(self.base_url, self.name)
self.provider = _create_provider(
provider_config,
self.endpoint_baseurl,
Expand Down Expand Up @@ -173,6 +172,19 @@ def register_endpoints(self, backend_names):
:rtype: list[(str, ((satosa.context.Context, Any) -> satosa.response.Response, Any))]
:raise ValueError: if more than one backend is configured
"""
# See https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig
#
# We skip the scheme + host + port of the issuer URL, because we can only map the
# path for the provider config endpoint. We are safe to use urlparse().path here,
# because for issuer OIDC allows only https URLs without query and fragment parts.
issuer = self.provider.configuration_information["issuer"]
autoconf_path = ".well-known/openid-configuration"
provider_config = (
"^{}$".format(join_paths(urlparse(issuer).path.lstrip("/"), autoconf_path)),
self.provider_config,
)
jwks_uri = ("^{}/jwks$".format(self.endpoint_basepath), self.jwks)

backend_name = None
if len(backend_names) != 1:
# only supports one backend since there currently is no way to publish multiple authorization endpoints
Expand All @@ -189,40 +201,49 @@ def register_endpoints(self, backend_names):
else:
backend_name = backend_names[0]

provider_config = ("^.well-known/openid-configuration$", self.provider_config)
jwks_uri = ("^{}/jwks$".format(self.name), self.jwks)

if backend_name:
# if there is only one backend, include its name in the path so the default routing can work
auth_endpoint = "{}/{}/{}/{}".format(self.base_url, backend_name, self.name, AuthorizationEndpoint.url)
auth_endpoint = join_paths(
self.base_url,
backend_name,
self.name,
AuthorizationEndpoint.url,
)
self.provider.configuration_information["authorization_endpoint"] = auth_endpoint
auth_path = urlparse(auth_endpoint).path.lstrip("/")
else:
auth_path = "{}/{}".format(self.name, AuthorizationEndpoint.url)
auth_path = join_paths(self.endpoint_basepath, AuthorizationRequest.url)

authentication = ("^{}$".format(auth_path), self.handle_authn_request)
url_map = [provider_config, jwks_uri, authentication]

if any("code" in v for v in self.provider.configuration_information["response_types_supported"]):
self.provider.configuration_information["token_endpoint"] = "{}/{}".format(
self.endpoint_baseurl, TokenEndpoint.url
self.provider.configuration_information["token_endpoint"] = join_paths(
self.endpoint_baseurl,
TokenEndpoint.url,
)
token_endpoint = (
"^{}/{}".format(self.name, TokenEndpoint.url), self.token_endpoint
"^{}".format(join_paths(self.endpoint_basepath, TokenEndpoint.url)),
self.token_endpoint,
)
url_map.append(token_endpoint)

self.provider.configuration_information["userinfo_endpoint"] = (
"{}/{}".format(self.endpoint_baseurl, UserinfoEndpoint.url)
join_paths(self.endpoint_baseurl, UserinfoEndpoint.url)
)
userinfo_endpoint = (
"^{}/{}".format(self.name, UserinfoEndpoint.url), self.userinfo_endpoint
"^{}".format(
join_paths(self.endpoint_basepath, UserinfoEndpoint.url)
),
self.userinfo_endpoint,
)
url_map.append(userinfo_endpoint)

if "registration_endpoint" in self.provider.configuration_information:
client_registration = (
"^{}/{}".format(self.name, RegistrationEndpoint.url),
"^{}".format(
join_paths(self.endpoint_basepath, RegistrationEndpoint.url)
),
self.client_registration,
)
url_map.append(client_registration)
Expand Down
3 changes: 2 additions & 1 deletion src/satosa/frontends/ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import satosa.logging_util as lu
from satosa.frontends.base import FrontendModule
from satosa.response import Response
from satosa.util import join_paths


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -43,7 +44,7 @@ def register_endpoints(self, backend_names):
:rtype: list[(str, ((satosa.context.Context, Any) -> satosa.response.Response, Any))]
:raise ValueError: if more than one backend is configured
"""
url_map = [("^{}".format(self.name), self.ping_endpoint)]
url_map = [("^{}".format(join_paths(self.endpoint_basepath, self.name)), self.ping_endpoint)]

return url_map

Expand Down
42 changes: 28 additions & 14 deletions src/satosa/frontends/saml2.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def register_endpoints(self, backend_names):

if self.enable_metadata_reload():
url_map.append(
("^%s/%s$" % (self.name, "reload-metadata"), self._reload_metadata))
("^%s/%s$" % (self.endpoint_basepath, "reload-metadata"), self._reload_metadata))

self.idp_config = self._build_idp_config_endpoints(
self.config[self.KEY_IDP_CONFIG], backend_names)
Expand Down Expand Up @@ -511,15 +511,19 @@ def _register_endpoints(self, providers):
"""
url_map = []

backend_providers = "|".join(providers)
base_path = urlparse(self.base_url).path.lstrip("/")
if base_path:
base_path = base_path + "/"
for endp_category in self.endpoints:
for binding, endp in self.endpoints[endp_category].items():
valid_providers = ""
for provider in providers:
valid_providers = "{}|^{}".format(valid_providers, provider)
valid_providers = valid_providers.lstrip("|")
parsed_endp = urlparse(endp)
url_map.append(("(%s)/%s$" % (valid_providers, parsed_endp.path),
functools.partial(self.handle_authn_request, binding_in=binding)))
endp_path = urlparse(endp).path
url_map.append(
(
"^{}({})/{}$".format(base_path, backend_providers, endp_path),
functools.partial(self.handle_authn_request, binding_in=binding)
)
)

if self.expose_entityid_endpoint():
logger.debug("Exposing frontend entity endpoint = {}".format(self.idp.config.entityid))
Expand Down Expand Up @@ -675,11 +679,18 @@ def _load_idp_dynamic_endpoints(self, context):
:param context:
:return: An idp server
"""
target_entity_id = context.target_entity_id_from_path()
target_entity_id = self._target_entity_id_from_path(context.path)
idp_conf_file = self._load_endpoints_to_config(context.target_backend, target_entity_id)
idp_config = IdPConfig().load(idp_conf_file)
return Server(config=idp_config)

def _target_entity_id_from_path(self, request_path):
path = request_path.lstrip("/")
base_path = urlparse(self.base_url).path.lstrip("/")
if base_path and path.startswith(base_path):
path = path[len(base_path):].lstrip("/")
return path.split("/")[1]

def _load_idp_dynamic_entity_id(self, state):
"""
Loads an idp server with the entity id saved in state
Expand All @@ -705,7 +716,7 @@ def handle_authn_request(self, context, binding_in):
:type binding_in: str
:rtype: satosa.response.Response
"""
target_entity_id = context.target_entity_id_from_path()
target_entity_id = self._target_entity_id_from_path(context.path)
target_entity_id = urlsafe_b64decode(target_entity_id).decode()
context.decorate(Context.KEY_TARGET_ENTITYID, target_entity_id)

Expand All @@ -723,7 +734,7 @@ def _create_state_data(self, context, resp_args, relay_state):
:rtype: dict[str, dict[str, str] | str]
"""
state = super()._create_state_data(context, resp_args, relay_state)
state["target_entity_id"] = context.target_entity_id_from_path()
state["target_entity_id"] = self._target_entity_id_from_path(context.path)
return state

def handle_backend_error(self, exception):
Expand Down Expand Up @@ -758,13 +769,16 @@ def _register_endpoints(self, providers):
"""
url_map = []

backend_providers = "|".join(providers)
base_path = urlparse(self.base_url).path.lstrip("/")
if base_path:
base_path = base_path + "/"
for endp_category in self.endpoints:
for binding, endp in self.endpoints[endp_category].items():
valid_providers = "|^".join(providers)
parsed_endp = urlparse(endp)
endp_path = urlparse(endp).path
url_map.append(
(
r"(^{})/\S+/{}".format(valid_providers, parsed_endp.path),
"^{}({})/\S+/{}$".format(base_path, backend_providers, endp_path),
functools.partial(self.handle_authn_request, binding_in=binding)
)
)
Expand Down
12 changes: 11 additions & 1 deletion src/satosa/micro_services/account_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..exception import SATOSAAuthenticationError
from ..micro_services.base import ResponseMicroService
from ..response import Redirect
from ..util import join_paths

import satosa.logging_util as lu
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -161,4 +162,13 @@ def register_endpoints(self):
:return: A list of endpoints bound to a function
"""
return [("^account_linking%s$" % self.endpoint, self._handle_al_response)]
return [
(
"^{}$".format(
join_paths(
self.base_path, "account_linking", self.endpoint
)
),
self._handle_al_response,
)
]
2 changes: 2 additions & 0 deletions src/satosa/micro_services/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Micro service for SATOSA
"""
import logging
from urllib.parse import urlparse

logger = logging.getLogger(__name__)

Expand All @@ -14,6 +15,7 @@ class MicroService(object):
def __init__(self, name, base_url, **kwargs):
self.name = name
self.base_url = base_url
self.base_path = urlparse(base_url).path.lstrip("/")
self.next = None

def process(self, context, data):
Expand Down
12 changes: 11 additions & 1 deletion src/satosa/micro_services/consent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from satosa.internal import InternalData
from satosa.micro_services.base import ResponseMicroService
from satosa.response import Redirect
from satosa.util import join_paths


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -238,4 +239,13 @@ def register_endpoints(self):
:return: A list of endpoints bound to a function
"""
return [("^consent%s$" % self.endpoint, self._handle_consent_response)]
return [
(
"^{}$".format(
join_paths(
self.base_path, "consent", self.endpoint
)
),
self._handle_consent_response,
)
]
Loading

0 comments on commit f558ead

Please sign in to comment.