Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use new default ssl context in all aiohttp requests #271

Merged
merged 2 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/ai_horde_client/image/async_manual_client_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import aiohttp
from loguru import logger

from horde_sdk import ANON_API_KEY
from horde_sdk import ANON_API_KEY, _default_sslcontext
from horde_sdk.ai_horde_api import AIHordeAPIAsyncManualClient
from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest, ImageGenerateStatusRequest
from horde_sdk.generic_api.apimodels import RequestErrorResponse
Expand Down Expand Up @@ -90,7 +90,7 @@ async def main(apikey: str = ANON_API_KEY) -> None:

image_bytes = None
# image_gen.img is a url, download it using aiohttp.
async with aiohttp.ClientSession() as session, session.get(image_gen.img) as resp:
async with aiohttp.ClientSession() as session, session.get(image_gen.img, ssl=_default_sslcontext) as resp:
image_bytes = await resp.read()

if image_bytes is None:
Expand Down
5 changes: 4 additions & 1 deletion horde_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

# isort: off
# We import dotenv first so that we can use it to load environment variables before importing anything else.
import ssl
import certifi
import dotenv

# If the current working directory contains a `.env` file, import the environment variables from it.
Expand Down Expand Up @@ -59,7 +61,7 @@ def _dev_env_var_warnings() -> None: # pragma: no cover


_dev_env_var_warnings()

_default_sslcontext = ssl.create_default_context(cafile=certifi.where())

from horde_sdk.consts import (
PAYLOAD_HTTP_METHODS,
Expand Down Expand Up @@ -109,4 +111,5 @@ def _dev_env_var_warnings() -> None: # pragma: no cover
"PROGRESS_LOGGER_LABEL",
"COMPLETE_LOGGER_LABEL",
"HordeException",
"_default_sslcontext",
]
7 changes: 3 additions & 4 deletions horde_sdk/ai_horde_api/ai_horde_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import requests
from loguru import logger

from horde_sdk import COMPLETE_LOGGER_LABEL, PROGRESS_LOGGER_LABEL
from horde_sdk import COMPLETE_LOGGER_LABEL, PROGRESS_LOGGER_LABEL, _default_sslcontext
from horde_sdk.ai_horde_api.apimodels import (
AIHordeHeartbeatRequest,
AIHordeHeartbeatResponse,
Expand Down Expand Up @@ -79,7 +79,6 @@
GenericAsyncHordeAPISession,
GenericHordeAPIManualClient,
GenericHordeAPISession,
_default_sslcontext,
)


Expand Down Expand Up @@ -1290,7 +1289,7 @@ async def download_image_from_generation(self, generation: ImageGeneration) -> t

image_bytes: bytes | None = None
if urllib.parse.urlparse(generation.img).scheme in ["http", "https"]:
async with self._aiohttp_session.get(generation.img) as response:
async with self._aiohttp_session.get(generation.img, ssl=_default_sslcontext) as response:
if response.status != 200: # pragma: no cover
logger.error(f"Error downloading image: {response.status}")
response.raise_for_status()
Expand Down Expand Up @@ -1326,7 +1325,7 @@ async def download_image_from_url(self, url: str) -> PIL.Image.Image:
if self._aiohttp_session is None:
raise RuntimeError("No aiohttp session provided but an async request was made.")

async with self._aiohttp_session.get(url) as response:
async with self._aiohttp_session.get(url, ssl=_default_sslcontext) as response:
if response.status != 200: # pragma: no cover
logger.error(f"Error downloading image: {response.status}")
response.raise_for_status()
Expand Down
5 changes: 3 additions & 2 deletions horde_sdk/generic_api/apimodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import override

from horde_sdk import _default_sslcontext
from horde_sdk.consts import HTTPMethod, HTTPStatusCode
from horde_sdk.generic_api.consts import ANON_API_KEY
from horde_sdk.generic_api.endpoints import GENERIC_API_ENDPOINT_SUBPATH, url_with_path
Expand Down Expand Up @@ -256,7 +257,7 @@ class ResponseRequiringDownloadMixin(HordeAPIDataObject):

async def download_file_as_base64(self, client_session: aiohttp.ClientSession, url: str) -> str:
"""Download a file and return the value as a base64 string."""
async with client_session.get(url) as response:
async with client_session.get(url, ssl=_default_sslcontext) as response:
response.raise_for_status()
return base64.b64encode(await response.read()).decode("utf-8")

Expand All @@ -273,7 +274,7 @@ async def download_file_to_field_as_base64(
url (str): The URL to download the file from.
field_name (str): The name of the field to save the file to.
"""
async with client_session.get(url) as response:
async with client_session.get(url, ssl=_default_sslcontext) as response:
response.raise_for_status()
setattr(self, field_name, base64.b64encode(await response.read()).decode("utf-8"))

Expand Down
4 changes: 1 addition & 3 deletions horde_sdk/generic_api/generic_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,18 @@

import asyncio
import os
import ssl
from abc import ABC
from ssl import SSLContext
from typing import Any, TypeVar

import aiohttp
import certifi
import requests
from loguru import logger
from pydantic import BaseModel, ValidationError
from strenum import StrEnum
from typing_extensions import override

from horde_sdk import _default_sslcontext
from horde_sdk.ai_horde_api.exceptions import AIHordePayloadValidationError
from horde_sdk.consts import HTTPMethod
from horde_sdk.generic_api.apimodels import (
Expand All @@ -35,7 +34,6 @@
GenericQueryFields,
)

_default_sslcontext = ssl.create_default_context(cafile=certifi.where())
"""The default SSL context to use for aiohttp requests."""


Expand Down
Loading