Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: remove large models from model meta commands #257

Merged
merged 1 commit into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions horde_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def _dev_env_var_warnings() -> None: # pragma: no cover
if len(dev_key) != 10 and len(dev_key) != 22:
raise ValueError("AI_HORDE_DEV_APIKEY must be the anon key or 22 characters long.")

AI_HORDE_MODEL_META_LARGE_MODELS = os.getenv("AI_HORDE_MODEL_META_LARGE_MODELS")
if AI_HORDE_MODEL_META_LARGE_MODELS:
logger.debug(
f"AI_HORDE_MODEL_META_LARGE_MODELS is {AI_HORDE_MODEL_META_LARGE_MODELS}.",
)


_dev_env_var_warnings()

Expand Down
28 changes: 22 additions & 6 deletions horde_sdk/ai_horde_worker/model_meta.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import re

from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY
from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY, STABLE_DIFFUSION_BASELINE_CATEGORY
from horde_model_reference.model_reference_manager import ModelReferenceManager
from horde_model_reference.model_reference_records import StableDiffusion_ModelRecord
from loguru import logger
Expand Down Expand Up @@ -128,7 +129,7 @@ def resolve_meta_instructions(
return_list.extend(self.resolve_all_nsfw_model_names())

# If no valid meta instruction were found, return None
return set(return_list)
return self.remove_large_models(set(return_list))

@staticmethod
def meta_instruction_regex_match(instruction: str, target_string: str) -> re.Match[str] | None:
Expand All @@ -140,9 +141,22 @@ def meta_instruction_regex_match(instruction: str, target_string: str) -> re.Mat
Returns:
A Match object if the target string matches the regex pattern, otherwise None.
"""
return re.match(instruction, target_string, re.IGNORECASE)

def remove_large_models(self, models: set[str]) -> set[str]:
"""Remove large models from the input set of models."""
AI_HORDE_MODEL_META_LARGE_MODELS = os.getenv("AI_HORDE_MODEL_META_LARGE_MODELS")
if not AI_HORDE_MODEL_META_LARGE_MODELS:
cascade_models = self.resolve_all_models_of_baseline(STABLE_DIFFUSION_BASELINE_CATEGORY.stable_cascade)
flux_models = self.resolve_all_models_of_baseline(STABLE_DIFFUSION_BASELINE_CATEGORY.flux_1)

logger.debug(f"Removing cascade models: {cascade_models}")
logger.debug(f"Removing flux models: {flux_models}")
models = models - cascade_models - flux_models
return models

def resolve_all_model_names(self) -> set[str]:
"""Get the names of all models defined in the model reference.
Expand All @@ -153,11 +167,13 @@ def resolve_all_model_names(self) -> set[str]:

sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion]

if sd_model_references:
return set(sd_model_references.root.keys())
all_models = set(sd_model_references.root.keys()) if sd_model_references is not None else set()

logger.error("No stable diffusion models found in model reference.")
return set()
all_models = self.remove_large_models(all_models)

if not all_models:
logger.error("No stable diffusion models found in model reference.")
return all_models

def _resolve_sfw_nsfw_model_names(self, nsfw: bool) -> set[str]:
"""Get the names of all SFW or NSFW models defined in the model reference.
Expand Down
23 changes: 23 additions & 0 deletions tests/ai_horde_worker/test_model_meta_api_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ def test_image_model_load_resolver_all(image_model_load_resolver: ImageModelLoad

assert len(all_model_names) > 0

import os

os.environ["AI_HORDE_MODEL_META_LARGE_MODELS"] = "true"

all_model_names_with_large = image_model_load_resolver.resolve_all_model_names()

del os.environ["AI_HORDE_MODEL_META_LARGE_MODELS"]

assert len(all_model_names_with_large) > len(all_model_names)


def test_image_model_load_resolver_top_n(
image_model_load_resolver: ImageModelLoadResolver,
Expand Down Expand Up @@ -179,6 +189,19 @@ def test_image_models_unique_results_only(

assert len(resolved_model_names) >= (len(all_model_names) - 1) # FIXME: -1 is to account for SDXL beta

import os

os.environ["AI_HORDE_MODEL_META_LARGE_MODELS"] = "true"

resolved_models_names_with_large = image_model_load_resolver.resolve_meta_instructions(
["top 1000", "bottom 1000"],
AIHordeAPIManualClient(),
)

del os.environ["AI_HORDE_MODEL_META_LARGE_MODELS"]

assert len(resolved_models_names_with_large) >= len(resolved_model_names)


def test_resolve_all_models_of_baseline(
image_model_load_resolver: ImageModelLoadResolver,
Expand Down
Loading