Skip to content

Commit

Permalink
[azure] Fix fetching user delegation key when custom domain is enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
jschneier committed Jul 6, 2024
1 parent e268efa commit a124f32
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 119 deletions.
49 changes: 16 additions & 33 deletions storages/backends/azure_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from datetime import datetime
from datetime import timedelta
from tempfile import SpooledTemporaryFile
from urllib.parse import urlparse
from urllib.parse import urlunparse

from azure.core.exceptions import ResourceNotFoundError
from azure.core.utils import parse_connection_string
Expand Down Expand Up @@ -118,9 +120,7 @@ class AzureStorage(BaseStorage):
def __init__(self, **settings):
super().__init__(**settings)
self._service_client = None
self._custom_service_client = None
self._client = None
self._custom_client = None
self._user_delegation_key = None
self._user_delegation_key_expiry = datetime.utcnow()
if self.connection_string and (not self.account_name or not self.account_key):
Expand Down Expand Up @@ -155,18 +155,11 @@ def get_default_settings(self):
"api_version": setting("AZURE_API_VERSION", None),
}

def _get_service_client(self, use_custom_domain):
def _get_service_client(self):
if self.connection_string is not None:
return BlobServiceClient.from_connection_string(self.connection_string)

account_domain = (
self.custom_domain
if self.custom_domain and use_custom_domain
else "{}.blob.{}".format(
self.account_name,
self.endpoint_suffix,
)
)
account_domain = "{}.blob.{}".format(self.account_name, self.endpoint_suffix)
account_url = "{}://{}".format(self.azure_protocol, account_domain)

credential = None
Expand All @@ -187,17 +180,9 @@ def _get_service_client(self, use_custom_domain):
@property
def service_client(self):
if self._service_client is None:
self._service_client = self._get_service_client(use_custom_domain=False)
self._service_client = self._get_service_client()
return self._service_client

@property
def custom_service_client(self):
if self._custom_service_client is None:
self._custom_service_client = self._get_service_client(
use_custom_domain=True
)
return self._custom_service_client

@property
def client(self):
if self._client is None:
Expand All @@ -206,14 +191,6 @@ def client(self):
)
return self._client

@property
def custom_client(self):
if self._custom_client is None:
self._custom_client = self.custom_service_client.get_container_client(
self.azure_container
)
return self._custom_client

def get_user_delegation_key(self, expiry):
# We'll only be able to get a user delegation key if we've authenticated with a
# token credential.
Expand All @@ -228,10 +205,8 @@ def get_user_delegation_key(self, expiry):
):
now = datetime.utcnow()
key_expiry_time = now + timedelta(days=7)
self._user_delegation_key = (
self.custom_service_client.get_user_delegation_key(
key_start_time=now, key_expiry_time=key_expiry_time
)
self._user_delegation_key = self.service_client.get_user_delegation_key(
key_start_time=now, key_expiry_time=key_expiry_time
)
self._user_delegation_key_expiry = key_expiry_time

Expand Down Expand Up @@ -333,7 +308,15 @@ def url(self, name, expire=None, parameters=None, mode="r"):
)
credential = sas_token

container_blob_url = self.custom_client.get_blob_client(name).url
container_blob_url = self.client.get_blob_client(name).url

if self.custom_domain:
# Replace the account name with the custom domain
parsed_url = urlparse(container_blob_url)
container_blob_url = urlunparse(
parsed_url._replace(netloc=self.custom_domain)
)

return BlobClient.from_blob_url(container_blob_url, credential=credential).url

def _get_content_settings_parameters(self, name, content=None):
Expand Down
103 changes: 17 additions & 86 deletions tests/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ class AzureStorageTest(TestCase):
def setUp(self, *args):
self.storage = azure_storage.AzureStorage()
self.storage._client = mock.MagicMock()
self.storage._custom_client = mock.MagicMock()
self.storage.overwrite_files = True
self.account_name = "test"
self.account_key = "key"
Expand Down Expand Up @@ -75,15 +74,12 @@ def test_get_available_name(self):
self.storage.overwrite_files = False
client_mock = mock.MagicMock()
client_mock.exists.side_effect = [True, False]
custom_client_mock = mock.MagicMock()
self.storage._client.get_blob_client.return_value = client_mock
self.storage._custom_client.get_blob_client.return_value = custom_client_mock
name = self.storage.get_available_name("foo.txt")
self.assertTrue(name.startswith("foo_"))
self.assertTrue(name.endswith(".txt"))
self.assertTrue(len(name) > len("foo.txt"))
self.assertEqual(client_mock.exists.call_count, 2)
self.assertEqual(custom_client_mock.exists.call_count, 0)

def test_get_available_name_first(self):
self.storage.overwrite_files = False
Expand Down Expand Up @@ -136,29 +132,27 @@ def test_get_available_invalid(self):
def test_url(self):
blob_mock = mock.MagicMock()
blob_mock.url = "https://ret_foo.blob.core.windows.net/test/some%20blob"
self.storage._custom_client.get_blob_client.return_value = blob_mock
self.storage._client.get_blob_client.return_value = blob_mock
self.assertEqual(self.storage.url("some blob"), blob_mock.url)
self.storage.custom_client.get_blob_client.assert_called_once_with("some blob")
self.storage._client.get_blob_client.assert_not_called()
self.storage._client.get_blob_client.assert_called_once_with("some blob")

def test_url_unsafe_chars(self):
blob_mock = mock.MagicMock()
blob_mock.url = "https://ret_foo.blob.core.windows.net/test/some%20blob"
self.storage._custom_client.get_blob_client.return_value = blob_mock
self.storage._client.get_blob_client.return_value = blob_mock
self.assertEqual(
self.storage.url("foo;?:@=&\"<>#%{}|^~[]`bar/~!*()'"), blob_mock.url
)
self.storage.custom_client.get_blob_client.assert_called_once_with(
self.storage._client.get_blob_client.assert_called_once_with(
"foo;?:@=&\"<>#%{}|^~[]`bar/~!*()'"
)
self.storage._client.get_blob_client.assert_not_called()

@mock.patch("storages.backends.azure_storage.generate_blob_sas")
def test_url_expire(self, generate_blob_sas_mocked):
generate_blob_sas_mocked.return_value = "foo_token"
blob_mock = mock.MagicMock()
blob_mock.url = "https://ret_foo.blob.core.windows.net/test/some%20blob"
self.storage._custom_client.get_blob_client.return_value = blob_mock
self.storage._client.get_blob_client.return_value = blob_mock
self.storage.account_name = self.account_name

fixed_time = make_aware(
Expand Down Expand Up @@ -202,25 +196,31 @@ def test_url_expire(self, generate_blob_sas_mocked):
called_args, called_kwargs = generate_blob_sas_mocked.call_args
self.assertEqual(str(called_kwargs["permission"]), "w")

def test_url_custom_domain(self):
self.storage.custom_domain = "foo_domain"
blob_mock = mock.MagicMock()
blob_mock.url = "https://ret_foo.blob.core.windows.net/test/foo_name"
self.storage._client.get_blob_client.return_value = blob_mock
url = self.storage.url("foo_name")
self.assertEqual(url, "https://foo_domain/test/foo_name")

@mock.patch("storages.backends.azure_storage.generate_blob_sas")
def test_url_expire_user_delegation_key(self, generate_blob_sas_mocked):
generate_blob_sas_mocked.return_value = "foo_token"
blob_mock = mock.MagicMock()
blob_mock.url = "https://ret_foo.blob.core.windows.net/test/some%20blob"
self.storage._custom_client.get_blob_client.return_value = blob_mock
self.storage._client.get_blob_client.return_value = blob_mock
self.storage.account_name = self.account_name
custom_service_client = mock.MagicMock()
self.storage._custom_service_client = custom_service_client
service_client = mock.MagicMock()
self.storage._service_client = service_client
self.storage.token_credential = "token_credential"

fixed_time = make_aware(
datetime.datetime(2016, 11, 6, 4), datetime.timezone.utc
)
with mock.patch("storages.backends.azure_storage.datetime") as d_mocked:
d_mocked.utcnow.return_value = fixed_time
custom_service_client.get_user_delegation_key.return_value = (
"user delegation key"
)
service_client.get_user_delegation_key.return_value = "user delegation key"
self.assertEqual(
self.storage.url("some blob", 100),
"https://ret_foo.blob.core.windows.net/test/some%20blob",
Expand Down Expand Up @@ -251,55 +251,34 @@ def test_container_client_default_params(self):
def test_container_client_params_account_key(self):
storage = azure_storage.AzureStorage()
storage.account_name = "foo_name"
storage.azure_ssl = True
storage.custom_domain = "foo_domain"
storage.account_key = "foo_key"
with mock.patch(
"storages.backends.azure_storage.BlobServiceClient", autospec=True
) as bsc_mocked:
client_mock = mock.MagicMock()
custom_client_mock = mock.MagicMock()
bsc_mocked.return_value.get_container_client.return_value = client_mock
self.assertEqual(storage.client, client_mock)
bsc_mocked.assert_called_once_with(
"https://foo_name.blob.core.windows.net",
credential={"account_name": "foo_name", "account_key": "foo_key"},
)

bsc_mocked.return_value.get_container_client.return_value = (
custom_client_mock
)
self.assertEqual(storage.custom_client, custom_client_mock)
self.assertEqual(bsc_mocked.call_count, 2)
bsc_mocked.assert_called_with(
"https://foo_domain",
credential={"account_name": "foo_name", "account_key": "foo_key"},
)

def test_container_client_params_sas_token(self):
storage = azure_storage.AzureStorage()
storage.account_name = "foo_name"
storage.azure_ssl = False
storage.custom_domain = "foo_domain"
storage.sas_token = "foo_token"
with mock.patch(
"storages.backends.azure_storage.BlobServiceClient", autospec=True
) as bsc_mocked:
client_mock = mock.MagicMock()
custom_client_mock = mock.MagicMock()
bsc_mocked.return_value.get_container_client.return_value = client_mock
self.assertEqual(storage.client, client_mock)
bsc_mocked.assert_called_once_with(
"http://foo_name.blob.core.windows.net", credential="foo_token"
)

bsc_mocked.return_value.get_container_client.return_value = (
custom_client_mock
)
self.assertEqual(storage.custom_client, custom_client_mock)
self.assertEqual(bsc_mocked.call_count, 2)
bsc_mocked.assert_called_with("http://foo_domain", credential="foo_token")

def test_container_client_params_token_credential(self):
storage = azure_storage.AzureStorage()
storage.account_name = self.account_name
Expand Down Expand Up @@ -356,7 +335,6 @@ def test_storage_save(self):
c_mocked.assert_called_once_with(
content_type="text/plain", content_encoding=None, cache_control=None
)
self.storage._custom_client.upload_blob.assert_not_called()

def test_storage_open_write(self):
"""
Expand All @@ -377,22 +355,17 @@ def test_storage_open_write(self):
timeout=20,
overwrite=True,
)
self.storage._custom_client.upload_blob.assert_not_called()

def test_storage_exists(self):
blob_name = "blob"
client_mock = mock.MagicMock()
custom_client_mock = mock.MagicMock()
self.storage._client.get_blob_client.return_value = client_mock
self.storage._custom_client.get_blob_client.return_value = client_mock
self.assertTrue(self.storage.exists(blob_name))
self.assertEqual(client_mock.exists.call_count, 1)
self.assertEqual(custom_client_mock.exists.call_count, 0)

def test_delete_blob(self):
self.storage.delete("name")
self.storage._client.delete_blob.assert_called_once_with("name", timeout=20)
self.storage._custom_client.delete_blob.assert_not_called()

def test_storage_listdir_base(self):
file_names = ["some/path/1.txt", "2.txt", "other/path/3.txt", "4.txt"]
Expand All @@ -408,8 +381,6 @@ def test_storage_listdir_base(self):
self.storage._client.list_blobs.assert_called_with(
name_starts_with="", timeout=20
)
self.storage._custom_client.list_blobs.assert_not_called()

self.assertEqual(len(dirs), 0)

self.assertEqual(len(files), 4)
Expand Down Expand Up @@ -453,43 +424,3 @@ def test_override_init_argument(self):
self.assertEqual(storage.azure_container, "foo1")
storage = azure_storage.AzureStorage(azure_container="foo2")
self.assertEqual(storage.azure_container, "foo2")

@mock.patch(
"storages.backends.azure_storage.AzureStorage._get_service_client",
)
def test_get_service_client_use_custom_domain(self, gsc_mocked):
storage = azure_storage.AzureStorage()
storage.account_name = self.account_name

_ = storage.service_client
gsc_mocked.assert_called_once_with(use_custom_domain=False)

_ = storage.custom_service_client
gsc_mocked.assert_called_with(use_custom_domain=True)

def test_blobserviceclient_no_custom_domain(self):
storage = azure_storage.AzureStorage()
storage.account_name = "foo_name"
storage.custom_domain = None
storage.account_key = "foo_key"
with mock.patch(
"storages.backends.azure_storage.BlobServiceClient", autospec=True
) as bsc_mocked:
client_mock = mock.MagicMock()
custom_client_mock = mock.MagicMock()
bsc_mocked.return_value.get_container_client.return_value = client_mock
self.assertEqual(storage.client, client_mock)
bsc_mocked.assert_called_once_with(
"https://foo_name.blob.core.windows.net",
credential={"account_name": "foo_name", "account_key": "foo_key"},
)

bsc_mocked.return_value.get_container_client.return_value = (
custom_client_mock
)
self.assertEqual(storage.custom_client, custom_client_mock)
self.assertEqual(bsc_mocked.call_count, 2)
bsc_mocked.assert_called_with(
"https://foo_name.blob.core.windows.net",
credential={"account_name": "foo_name", "account_key": "foo_key"},
)

0 comments on commit a124f32

Please sign in to comment.