From 7b95542551c490a438d1bb7d76a88024a929cd26 Mon Sep 17 00:00:00 2001 From: Jaycee Li Date: Fri, 6 Sep 2024 13:14:31 -0700 Subject: [PATCH] fix: GenAI - Fixed ValueError when passing response_schema as a dict in generate_content() PiperOrigin-RevId: 671868501 --- tests/unit/vertexai/test_generative_models.py | 36 +++++++++++++++++++ .../generative_models/_generative_models.py | 9 ++--- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/tests/unit/vertexai/test_generative_models.py b/tests/unit/vertexai/test_generative_models.py index e6e75fb34a..4d795be593 100644 --- a/tests/unit/vertexai/test_generative_models.py +++ b/tests/unit/vertexai/test_generative_models.py @@ -616,6 +616,42 @@ def test_generate_content_streaming(self, generative_models: generative_models): for chunk in stream: assert chunk.text + @mock.patch.object( + target=prediction_service.PredictionServiceClient, + attribute="generate_content", + new=mock_generate_content, + ) + @pytest.mark.parametrize( + "generative_models", + [generative_models, preview_generative_models], + ) + def test_generate_content_with_dict_configs( + self, generative_models: generative_models + ): + model = generative_models.GenerativeModel("gemini-pro") + response = model.generate_content( + "Why is sky blue?", + generation_config={ + "temperature": 0.2, + "top_p": 0.9, + "top_k": 20, + "response_mime_type": "application/json", + "response_schema": { + "type_": "OBJECT", + "properties": { + "description": {"type": "STRING"}, + "steps": {"type": "BOOLEAN"}, + }, + "required": ["description"], + }, + }, + safety_settings={ + generative_models.SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + generative_models.SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, + }, + ) + assert response.text + @mock.patch.object( target=prediction_service.PredictionServiceClient, attribute="generate_content", diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index e41a7f4145..b5eea235cb 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -465,9 +465,9 @@ def _prepare_request( elif isinstance(generation_config, GenerationConfig): gapic_generation_config = generation_config._raw_generation_config elif isinstance(generation_config, Dict): - gapic_generation_config = gapic_content_types.GenerationConfig( + gapic_generation_config = GenerationConfig( **generation_config - ) + )._raw_generation_config gapic_safety_settings = None if safety_settings: @@ -1603,10 +1603,7 @@ def _from_gapic( @classmethod def from_dict(cls, generation_config_dict: Dict[str, Any]) -> "GenerationConfig": - raw_generation_config = gapic_content_types.GenerationConfig( - generation_config_dict - ) - return cls._from_gapic(raw_generation_config=raw_generation_config) + return cls(**generation_config_dict) def to_dict(self) -> Dict[str, Any]: return _proto_to_dict(self._raw_generation_config)