Skip to content

Commit

Permalink
fix: use new default ssl context in all aiohttp requests (#271)
Browse files Browse the repository at this point in the history
Adds the missing ssl kwarg usage in other parts of the code, missed in #268.
  • Loading branch information
tazlin authored Oct 4, 2024
1 parent 63a3917 commit 7912776
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 12 deletions.
4 changes: 2 additions & 2 deletions examples/ai_horde_client/image/async_manual_client_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import aiohttp
from loguru import logger

from horde_sdk import ANON_API_KEY
from horde_sdk import ANON_API_KEY, _default_sslcontext
from horde_sdk.ai_horde_api import AIHordeAPIAsyncManualClient
from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest, ImageGenerateStatusRequest
from horde_sdk.generic_api.apimodels import RequestErrorResponse
Expand Down Expand Up @@ -90,7 +90,7 @@ async def main(apikey: str = ANON_API_KEY) -> None:

image_bytes = None
# image_gen.img is a url, download it using aiohttp.
async with aiohttp.ClientSession() as session, session.get(image_gen.img) as resp:
async with aiohttp.ClientSession() as session, session.get(image_gen.img, ssl=_default_sslcontext) as resp:
image_bytes = await resp.read()

if image_bytes is None:
Expand Down
5 changes: 4 additions & 1 deletion horde_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

# isort: off
# We import dotenv first so that we can use it to load environment variables before importing anything else.
import ssl
import certifi
import dotenv

# If the current working directory contains a `.env` file, import the environment variables from it.
Expand Down Expand Up @@ -59,7 +61,7 @@ def _dev_env_var_warnings() -> None: # pragma: no cover


_dev_env_var_warnings()

_default_sslcontext = ssl.create_default_context(cafile=certifi.where())

from horde_sdk.consts import (
PAYLOAD_HTTP_METHODS,
Expand Down Expand Up @@ -109,4 +111,5 @@ def _dev_env_var_warnings() -> None: # pragma: no cover
"PROGRESS_LOGGER_LABEL",
"COMPLETE_LOGGER_LABEL",
"HordeException",
"_default_sslcontext",
]
7 changes: 3 additions & 4 deletions horde_sdk/ai_horde_api/ai_horde_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import requests
from loguru import logger

from horde_sdk import COMPLETE_LOGGER_LABEL, PROGRESS_LOGGER_LABEL
from horde_sdk import COMPLETE_LOGGER_LABEL, PROGRESS_LOGGER_LABEL, _default_sslcontext
from horde_sdk.ai_horde_api.apimodels import (
AIHordeHeartbeatRequest,
AIHordeHeartbeatResponse,
Expand Down Expand Up @@ -79,7 +79,6 @@
GenericAsyncHordeAPISession,
GenericHordeAPIManualClient,
GenericHordeAPISession,
_default_sslcontext,
)


Expand Down Expand Up @@ -1290,7 +1289,7 @@ async def download_image_from_generation(self, generation: ImageGeneration) -> t

image_bytes: bytes | None = None
if urllib.parse.urlparse(generation.img).scheme in ["http", "https"]:
async with self._aiohttp_session.get(generation.img) as response:
async with self._aiohttp_session.get(generation.img, ssl=_default_sslcontext) as response:
if response.status != 200: # pragma: no cover
logger.error(f"Error downloading image: {response.status}")
response.raise_for_status()
Expand Down Expand Up @@ -1326,7 +1325,7 @@ async def download_image_from_url(self, url: str) -> PIL.Image.Image:
if self._aiohttp_session is None:
raise RuntimeError("No aiohttp session provided but an async request was made.")

async with self._aiohttp_session.get(url) as response:
async with self._aiohttp_session.get(url, ssl=_default_sslcontext) as response:
if response.status != 200: # pragma: no cover
logger.error(f"Error downloading image: {response.status}")
response.raise_for_status()
Expand Down
5 changes: 3 additions & 2 deletions horde_sdk/generic_api/apimodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import override

from horde_sdk import _default_sslcontext
from horde_sdk.consts import HTTPMethod, HTTPStatusCode
from horde_sdk.generic_api.consts import ANON_API_KEY
from horde_sdk.generic_api.endpoints import GENERIC_API_ENDPOINT_SUBPATH, url_with_path
Expand Down Expand Up @@ -256,7 +257,7 @@ class ResponseRequiringDownloadMixin(HordeAPIDataObject):

async def download_file_as_base64(self, client_session: aiohttp.ClientSession, url: str) -> str:
"""Download a file and return the value as a base64 string."""
async with client_session.get(url) as response:
async with client_session.get(url, ssl=_default_sslcontext) as response:
response.raise_for_status()
return base64.b64encode(await response.read()).decode("utf-8")

Expand All @@ -273,7 +274,7 @@ async def download_file_to_field_as_base64(
url (str): The URL to download the file from.
field_name (str): The name of the field to save the file to.
"""
async with client_session.get(url) as response:
async with client_session.get(url, ssl=_default_sslcontext) as response:
response.raise_for_status()
setattr(self, field_name, base64.b64encode(await response.read()).decode("utf-8"))

Expand Down
4 changes: 1 addition & 3 deletions horde_sdk/generic_api/generic_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,18 @@

import asyncio
import os
import ssl
from abc import ABC
from ssl import SSLContext
from typing import Any, TypeVar

import aiohttp
import certifi
import requests
from loguru import logger
from pydantic import BaseModel, ValidationError
from strenum import StrEnum
from typing_extensions import override

from horde_sdk import _default_sslcontext
from horde_sdk.ai_horde_api.exceptions import AIHordePayloadValidationError
from horde_sdk.consts import HTTPMethod
from horde_sdk.generic_api.apimodels import (
Expand All @@ -35,7 +34,6 @@
GenericQueryFields,
)

_default_sslcontext = ssl.create_default_context(cafile=certifi.where())
"""The default SSL context to use for aiohttp requests."""


Expand Down

0 comments on commit 7912776

Please sign in to comment.