diff --git a/horde_sdk/__init__.py b/horde_sdk/__init__.py index a0ac6f6..c7d0f6b 100644 --- a/horde_sdk/__init__.py +++ b/horde_sdk/__init__.py @@ -51,6 +51,12 @@ def _dev_env_var_warnings() -> None: # pragma: no cover if len(dev_key) != 10 and len(dev_key) != 22: raise ValueError("AI_HORDE_DEV_APIKEY must be the anon key or 22 characters long.") + AI_HORDE_MODEL_META_LARGE_MODELS = os.getenv("AI_HORDE_MODEL_META_LARGE_MODELS") + if AI_HORDE_MODEL_META_LARGE_MODELS: + logger.debug( + f"AI_HORDE_MODEL_META_LARGE_MODELS is {AI_HORDE_MODEL_META_LARGE_MODELS}.", + ) + _dev_env_var_warnings() diff --git a/horde_sdk/ai_horde_worker/model_meta.py b/horde_sdk/ai_horde_worker/model_meta.py index adaaeb6..e45f273 100644 --- a/horde_sdk/ai_horde_worker/model_meta.py +++ b/horde_sdk/ai_horde_worker/model_meta.py @@ -1,6 +1,7 @@ +import os import re -from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY +from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY, STABLE_DIFFUSION_BASELINE_CATEGORY from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_model_reference.model_reference_records import StableDiffusion_ModelRecord from loguru import logger @@ -128,7 +129,7 @@ def resolve_meta_instructions( return_list.extend(self.resolve_all_nsfw_model_names()) # If no valid meta instruction were found, return None - return set(return_list) + return self.remove_large_models(set(return_list)) @staticmethod def meta_instruction_regex_match(instruction: str, target_string: str) -> re.Match[str] | None: @@ -140,9 +141,22 @@ def meta_instruction_regex_match(instruction: str, target_string: str) -> re.Mat Returns: A Match object if the target string matches the regex pattern, otherwise None. + """ return re.match(instruction, target_string, re.IGNORECASE) + def remove_large_models(self, models: set[str]) -> set[str]: + """Remove large models from the input set of models.""" + AI_HORDE_MODEL_META_LARGE_MODELS = os.getenv("AI_HORDE_MODEL_META_LARGE_MODELS") + if not AI_HORDE_MODEL_META_LARGE_MODELS: + cascade_models = self.resolve_all_models_of_baseline(STABLE_DIFFUSION_BASELINE_CATEGORY.stable_cascade) + flux_models = self.resolve_all_models_of_baseline(STABLE_DIFFUSION_BASELINE_CATEGORY.flux_1) + + logger.debug(f"Removing cascade models: {cascade_models}") + logger.debug(f"Removing flux models: {flux_models}") + models = models - cascade_models - flux_models + return models + def resolve_all_model_names(self) -> set[str]: """Get the names of all models defined in the model reference. @@ -153,11 +167,13 @@ def resolve_all_model_names(self) -> set[str]: sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion] - if sd_model_references: - return set(sd_model_references.root.keys()) + all_models = set(sd_model_references.root.keys()) if sd_model_references is not None else set() - logger.error("No stable diffusion models found in model reference.") - return set() + all_models = self.remove_large_models(all_models) + + if not all_models: + logger.error("No stable diffusion models found in model reference.") + return all_models def _resolve_sfw_nsfw_model_names(self, nsfw: bool) -> set[str]: """Get the names of all SFW or NSFW models defined in the model reference. diff --git a/tests/ai_horde_worker/test_model_meta_api_calls.py b/tests/ai_horde_worker/test_model_meta_api_calls.py index 85d4b8e..7a86961 100644 --- a/tests/ai_horde_worker/test_model_meta_api_calls.py +++ b/tests/ai_horde_worker/test_model_meta_api_calls.py @@ -29,6 +29,16 @@ def test_image_model_load_resolver_all(image_model_load_resolver: ImageModelLoad assert len(all_model_names) > 0 + import os + + os.environ["AI_HORDE_MODEL_META_LARGE_MODELS"] = "true" + + all_model_names_with_large = image_model_load_resolver.resolve_all_model_names() + + del os.environ["AI_HORDE_MODEL_META_LARGE_MODELS"] + + assert len(all_model_names_with_large) > len(all_model_names) + def test_image_model_load_resolver_top_n( image_model_load_resolver: ImageModelLoadResolver, @@ -179,6 +189,19 @@ def test_image_models_unique_results_only( assert len(resolved_model_names) >= (len(all_model_names) - 1) # FIXME: -1 is to account for SDXL beta + import os + + os.environ["AI_HORDE_MODEL_META_LARGE_MODELS"] = "true" + + resolved_models_names_with_large = image_model_load_resolver.resolve_meta_instructions( + ["top 1000", "bottom 1000"], + AIHordeAPIManualClient(), + ) + + del os.environ["AI_HORDE_MODEL_META_LARGE_MODELS"] + + assert len(resolved_models_names_with_large) >= len(resolved_model_names) + def test_resolve_all_models_of_baseline( image_model_load_resolver: ImageModelLoadResolver,