diff --git a/src/vanna/vllm/vllm.py b/src/vanna/vllm/vllm.py index 53990821..3227d83a 100644 --- a/src/vanna/vllm/vllm.py +++ b/src/vanna/vllm/vllm.py @@ -22,6 +22,12 @@ def __init__(self, config=None): else: self.auth_key = None + if "temperature" in config: + self.temperature = config["temperature"] + else: + # default temperature - can be overrided using config + self.temperature = 0.7 + def system_message(self, message: str) -> any: return {"role": "system", "content": message} @@ -68,6 +74,7 @@ def submit_prompt(self, prompt, **kwargs) -> str: url = f"{self.host}/v1/chat/completions" data = { "model": self.model, + "temperature": self.temperature, "stream": False, "messages": prompt, }