diff --git a/private_gpt/server/chat/chat_service.py b/private_gpt/server/chat/chat_service.py index 11352a7cf4..fbc413cbf0 100644 --- a/private_gpt/server/chat/chat_service.py +++ b/private_gpt/server/chat/chat_service.py @@ -32,10 +32,12 @@ class CompletionGen(BaseModel): response: TokenGen sources: list[Chunk] | None = None + class SqlQueryResponse(BaseModel): response: str sources: None = None + @dataclass class ChatEngineInput: system_message: ChatMessage | None = None diff --git a/private_gpt/ui/ui.py b/private_gpt/ui/ui.py index 0fbd5e7e36..1efef0b355 100644 --- a/private_gpt/ui/ui.py +++ b/private_gpt/ui/ui.py @@ -14,7 +14,11 @@ from private_gpt.constants import PROJECT_ROOT_PATH from private_gpt.di import global_injector -from private_gpt.server.chat.chat_service import ChatService, CompletionGen, SqlQueryResponse +from private_gpt.server.chat.chat_service import ( + ChatService, + CompletionGen, + SqlQueryResponse, +) from private_gpt.server.chunks.chunks_service import Chunk, ChunksService from private_gpt.server.ingest.ingest_service import IngestService from private_gpt.settings.settings import settings @@ -79,7 +83,7 @@ def __init__( def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any: def yield_deltas( - completion_gen: CompletionGen|SqlQueryResponse, sources: bool = True + completion_gen: CompletionGen | SqlQueryResponse, sources: bool = True ) -> Iterable[str]: full_response: str = "" stream = completion_gen.response