Skip to content

Commit

Permalink
feat: support for openai API (#324)
Browse files Browse the repository at this point in the history
* feat: support for openai API
  • Loading branch information
db0 authored Sep 14, 2024
1 parent 8d8d778 commit 502d91a
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 21 deletions.
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ This repository allows you to set up an AI Horde Worker to generate or alchemize

## Important Note:

- As of January 2024, the official worker is now [horde-worker-reGen](https://github.com/Haidra-Org/horde-worker-reGen).
- As of January 2024, the official worker for image generation is now [horde-worker-reGen](https://github.com/Haidra-Org/horde-worker-reGen).
- You should use `reGen` if you are a new worker and are looking to do *image generation*.
- If you are looking to do *alchemy* (post-processing, interrogation, captioning, etc), you should continue to use `AI-Horde-Worker`.
- If you are looking to do *text generation*, or *alchemy* (post-processing, interrogation, captioning, etc), you should continue to use `AI-Horde-Worker`.


# Legacy information:
Expand Down Expand Up @@ -70,7 +70,9 @@ The latter option will allow you to see errors in case of a crash, so it's recom

## Update runtime

If you have just installed or updated your worker code run the `update-runtime` script. This will ensure the dependencies needed for your worker to run are up to date
If you have just installed or updated your worker code run the `update-runtime` script. This will ensure the dependencies needed for your worker to run are up to date.

For a scribe (i.e. text generation), run `update-runtime --scribe` instead, as it has far fewer requirements.

This script can take 10-15 minutes to complete.

Expand Down
6 changes: 6 additions & 0 deletions bridgeData_template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ max_context_length: 1024
# This will prevent the model from being used from the shared pool, but will ensure that no other worker
# can pretend to serve it
branded_model: false
# Set to true to use an OpenAI API compatible backend. If the backend is unknown, please set the backend_engine manually.
openai_api: false
# Set this to the name of your backend, unless it is one of the backends listed below, which are auto-detected.
# * aphrodite
backend_engine: unknown


## Alchemist (Image interrogation and post-processing)

Expand Down
6 changes: 6 additions & 0 deletions worker/argparser/scribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
required=False,
help="Set to true if you do not want this worker generating NSFW images.",
)
arg_parser.add_argument(
"--openai_api",
action="store_true",
required=False,
help="Set to true to expect OpenAI API from the backend.",
)
arg_parser.add_argument(
"--blacklist",
nargs="+",
Expand Down
1 change: 1 addition & 0 deletions worker/bridge_data/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(self, args):
self.models_reloading = False
self.max_models_to_download = 10
self.suppress_speed_warnings = False
self.backend_engine = os.environ.get("HORDE_BACKEND_ENGINE", "unknown")

def load_config(self):
# YAML config
Expand Down
30 changes: 23 additions & 7 deletions worker/bridge_data/scribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self):
self.branded_model = os.environ.get("HORDE_BRANDED_MODEL", "false") == "true"
self.softprompts = {}
self.current_softprompt = None
self.openai_api = os.environ.get("HORDE_BACKEND_OPENAI_API", "false") == "true"

self.nsfw = os.environ.get("HORDE_NSFW", "true") == "true"
self.blacklist = list(filter(lambda a: a, os.environ.get("HORDE_BLACKLIST", "").split(",")))
Expand All @@ -38,6 +39,8 @@ def reload_data(self):
self.nsfw = False
if args.blacklist:
self.blacklist = args.blacklist
if args.openai_api:
self.openai_api = args.openai_api
self.validate_kai()
if self.kai_available and not self.initialized and previous_url != self.horde_url:
logger.init(
Expand All @@ -53,8 +56,20 @@ def reload_data(self):
def validate_kai(self):
logger.debug("Retrieving settings from KoboldAI Client...")
try:
req = requests.get(self.kai_url + "/api/latest/model")
self.model = req.json()["result"]
version_req = requests.get(self.kai_url + "/version")
if version_req.ok:
self.backend_engine = f"aphrodite"
else:
logger.warning("Unable to determine OpenAI API compatible backend engine. Will report it as unknown to the Horde which will lead to less kudos rewards.")
if self.openai_api:
req = requests.get(self.kai_url + "/v1/models")
self.model = req.json()["data"][0]['id']
logger.debug([self.model,self.backend_engine])
self.backend_engine += '~oai'
else:
req = requests.get(self.kai_url + "/api/latest/model")
self.model = req.json()["result"]
self.backend_engine += '~kai'
# Normalize huggingface and local downloaded model names
if "/" not in self.model:
self.model = self.model.replace("_", "/", 1)
Expand All @@ -63,11 +78,12 @@ def validate_kai(self):
# self.max_context_length = req.json()["value"]
# req = requests.get(self.kai_url + "/api/latest/config/max_length")
# self.max_length = req.json()["value"]
if self.model not in self.softprompts:
req = requests.get(self.kai_url + "/api/latest/config/soft_prompts_list")
self.softprompts[self.model] = [sp["value"] for sp in req.json()["values"]]
req = requests.get(self.kai_url + "/api/latest/config/soft_prompt")
self.current_softprompt = req.json()["value"]
if not self.openai_api:
if self.model not in self.softprompts:
req = requests.get(self.kai_url + "/api/latest/config/soft_prompts_list")
self.softprompts[self.model] = [sp["value"] for sp in req.json()["values"]]
req = requests.get(self.kai_url + "/api/latest/config/soft_prompt")
self.current_softprompt = req.json()["value"]
except requests.exceptions.JSONDecodeError:
logger.error(f"Server {self.kai_url} is up but does not appear to be a KoboldAI server.")
self.kai_available = False
Expand Down
12 changes: 7 additions & 5 deletions worker/jobs/poppers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

class JobPopper:
retry_interval = 1
BRIDGE_AGENT = f"AI Horde Worker:{BRIDGE_VERSION}:https://github.com/db0/AI-Horde-Worker"

def __init__(self, mm, bd):
self.model_manager = mm
self.bridge_data = copy.deepcopy(bd)
self.bridge_agent = f"AI Horde Worker:{BRIDGE_VERSION}:https://github.com/db0/AI-Horde-Worker"
if self.bridge_data.backend_engine:
self.bridge_agent = f"AI Horde Worker~{self.bridge_data.backend_engine}:{BRIDGE_VERSION}:https://github.com/db0/AI-Horde-Worker"
self.pop = None
self.headers = {"apikey": self.bridge_data.api_key}
# This should be set by the extending class
Expand Down Expand Up @@ -161,7 +163,7 @@ def __init__(self, mm, bd):
"allow_lora": self.bridge_data.allow_lora if self.model_manager.lora.are_downloads_complete() else False,
"require_upfront_kudos": self.bridge_data.require_upfront_kudos,
"bridge_version": BRIDGE_VERSION,
"bridge_agent": self.BRIDGE_AGENT,
"bridge_agent": self.bridge_agent,
}
# logger.debug("Cron: End constructing pop payload")

Expand Down Expand Up @@ -216,8 +218,8 @@ def __init__(self, mm, bd):
"max_length": self.bridge_data.max_length,
"max_context_length": self.bridge_data.max_context_length,
"priority_usernames": self.bridge_data.priority_usernames,
"softprompts": self.bridge_data.softprompts[self.bridge_data.model],
"bridge_agent": self.BRIDGE_AGENT,
"softprompts": self.bridge_data.softprompts[self.bridge_data.model] if bd.openai_api is False else [],
"bridge_agent": self.bridge_agent,
"threads": self.bridge_data.max_threads,
}

Expand Down Expand Up @@ -252,7 +254,7 @@ def __init__(self, mm, bd):
"priority_usernames": self.bridge_data.priority_usernames,
"threads": self.bridge_data.max_threads,
"bridge_version": BRIDGE_VERSION,
"bridge_agent": self.BRIDGE_AGENT,
"bridge_agent": self.bridge_agent,
"max_tiles": self.bridge_data.max_power,
}
logger.debug(self.pop_payload)
Expand Down
62 changes: 56 additions & 6 deletions worker/jobs/scribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self, mm, bd, pop):
self.requested_softprompt = self.current_payload.get("softprompt")
self.censored = None
self.max_seconds = None
self.openai_api = self.bridge_data.openai_api

@logger.catch(reraise=True)
def start_job(self):
Expand Down Expand Up @@ -63,11 +64,57 @@ def start_job(self):
gen_success = False
while not gen_success and loop_retry < 5:
try:
gen_req = requests.post(
self.bridge_data.kai_url + "/api/latest/generate",
json=self.current_payload,
timeout=self.max_seconds,
)
if not self.openai_api:
gen_req = requests.post(
self.bridge_data.kai_url + "/api/latest/generate",
json=self.current_payload,
timeout=self.max_seconds,
)
else:
oai_payload = {
"model": self.current_model,
"prompt": self.current_payload["prompt"],
}
oai_kai_translations = {
"rep_pen": "repetition_penalty",
"temperature": "temperature",
"dynatemp_exponent": "dynatemp_exponent",
"dynatemp_max": "dynatemp_max",
"dynatemp_min": "dynatemp_min",
"tfs": "tfs",
"top_k": "top_k",
"top_p": "top_p",
"top_a": "top_a",
"min_p": "min_p",
"typical": "typical_p",
"eta_cutoff": "eta_cutoff",
"eps_cutoff": "epsilon_cutoff",
"mirostat": "mirostat_mode",
"mirostat": "mirostat_mode",
"mirostat_tau": "mirostat_tau",
"mirostat_eta": "mirostat_eta",
"stop_sequence": "stop",
"include_stop_str_in_output": "include_stop_str_in_output",
"badwordsids": "custom_token_bans",
"use_default_badwordsids": "use_default_badwordsids",
"max_length": "max_tokens",
"max_context_length": "truncate_prompt_tokens",
"sampler_seed": "seed",
}
for kai,oai in oai_kai_translations.items():
if kai in self.current_payload:
oai_payload[oai] = self.current_payload[kai]
if "top_k" in oai_payload:
oai_payload["top_k"] = oai_payload["top_k"] if oai_payload["top_k"] != 0.0 else -1
oai_payload["truncate_prompt_tokens"] = max(
1, self.current_payload.get("max_context_length", 2048) - self.current_payload.get("max_length", 256))
logger.debug("Attempting OpenAI API...")
gen_req = requests.post(
self.bridge_data.kai_url + "/v1/completions",
json=oai_payload,
timeout=self.max_seconds,
)

except requests.exceptions.ConnectionError:
logger.error(f"Worker {self.bridge_data.kai_url} unavailable. Retrying in 3 seconds...")
loop_retry += 1
Expand Down Expand Up @@ -115,7 +162,10 @@ def start_job(self):
time.sleep(3)
continue
try:
self.text = req_json["results"][0]["text"]
if self.bridge_data.openai_api:
self.text = req_json["choices"][0]["text"]
else:
self.text = req_json["results"][0]["text"]
except KeyError:
logger.error(
(
Expand Down

0 comments on commit 502d91a

Please sign in to comment.