From 6d18edda2b72e90caecd2ba4f207c25ba2b2ab8b Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Tue, 12 Dec 2023 18:41:39 -0500 Subject: [PATCH] Add assertions for unary failure RPC error tests --- Sources/GoogleAI/Errors.swift | 8 +-- .../GoogleAITests/GenerativeModelTests.swift | 65 ++++++++++++------- 2 files changed, 44 insertions(+), 29 deletions(-) diff --git a/Sources/GoogleAI/Errors.swift b/Sources/GoogleAI/Errors.swift index 0809589..4fca53c 100644 --- a/Sources/GoogleAI/Errors.swift +++ b/Sources/GoogleAI/Errors.swift @@ -15,11 +15,11 @@ import Foundation struct RPCError: Error { - let httpResponseCode: Int32 + let httpResponseCode: Int let message: String let status: RPCStatus - init(httpResponseCode: Int32, message: String, status: RPCStatus) { + init(httpResponseCode: Int, message: String, status: RPCStatus) { self.httpResponseCode = httpResponseCode self.message = message self.status = status @@ -56,7 +56,7 @@ extension RPCError: Decodable { } struct ErrorStatus { - let code: Int32? + let code: Int? let message: String? let status: RPCStatus? } @@ -70,7 +70,7 @@ extension ErrorStatus: Decodable { init(from decoder: Decoder) throws { let container = try decoder.container(keyedBy: CodingKeys.self) - code = try container.decodeIfPresent(Int32.self, forKey: .code) + code = try container.decodeIfPresent(Int.self, forKey: .code) message = try container.decodeIfPresent(String.self, forKey: .message) do { status = try container.decodeIfPresent(RPCStatus.self, forKey: .status) diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift index 765d964..2af18a7 100644 --- a/Tests/GoogleAITests/GenerativeModelTests.swift +++ b/Tests/GoogleAITests/GenerativeModelTests.swift @@ -163,24 +163,29 @@ final class GenerativeModelTests: XCTestCase { } func testGenerateContent_failure_invalidAPIKey() async throws { + let expectedStatusCode = 400 MockURLProtocol .requestHandler = try httpRequestHandler( forResource: "unary-failure-api-key", withExtension: "json", - statusCode: 400 + statusCode: expectedStatusCode ) - var responseError: Error? - var content: GenerateContentResponse? do { - content = try await model.generateContent(testPrompt) + _ = try await model.generateContent(testPrompt) + XCTFail("Should throw GenerateContentError.internalError; no error thrown.") + } catch let GenerateContentError.internalError(underlying: underlyingError) { + guard let rpcError = underlyingError as? RPCError else { + XCTFail("Not an RPCError: \(underlyingError)") + return + } + + XCTAssertEqual(rpcError.status, .invalidArgument) + XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode) + XCTAssertTrue(rpcError.message.hasPrefix("API key not valid")) } catch { - responseError = error + XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)") } - - XCTAssertNotNil(responseError) - XCTAssertNil(content) - // TODO: Add assertions about `responseError`. } func testGenerateContent_failure_emptyContent() async throws { @@ -243,6 +248,7 @@ final class GenerativeModelTests: XCTestCase { } func testGenerateContent_failure_imageRejected() async throws { + let expectedStatusCode = 400 MockURLProtocol .requestHandler = try httpRequestHandler( forResource: "unary-failure-image-rejected", @@ -250,17 +256,21 @@ final class GenerativeModelTests: XCTestCase { statusCode: 400 ) - var responseError: Error? - var content: GenerateContentResponse? do { - content = try await model.generateContent(testPrompt) + _ = try await model.generateContent(testPrompt) + XCTFail("Should throw GenerateContentError.internalError; no error thrown.") + } catch let GenerateContentError.internalError(underlying: underlyingError) { + guard let rpcError = underlyingError as? RPCError else { + XCTFail("Not an RPCError: \(underlyingError)") + return + } + + XCTAssertEqual(rpcError.status, .invalidArgument) + XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode) + XCTAssertEqual(rpcError.message, "Request contains an invalid argument.") } catch { - responseError = error + XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)") } - - XCTAssertNotNil(responseError) - XCTAssertNil(content) - // TODO: Add assertions about `responseError`. } func testGenerateContent_failure_promptBlockedSafety() async throws { @@ -281,6 +291,7 @@ final class GenerativeModelTests: XCTestCase { } func testGenerateContent_failure_unknownModel() async throws { + let expectedStatusCode = 404 MockURLProtocol .requestHandler = try httpRequestHandler( forResource: "unary-failure-unknown-model", @@ -288,17 +299,21 @@ final class GenerativeModelTests: XCTestCase { statusCode: 404 ) - var responseError: Error? - var content: GenerateContentResponse? do { - content = try await model.generateContent(testPrompt) + _ = try await model.generateContent(testPrompt) + XCTFail("Should throw GenerateContentError.internalError; no error thrown.") + } catch let GenerateContentError.internalError(underlying: underlyingError) { + guard let rpcError = underlyingError as? RPCError else { + XCTFail("Not an RPCError: \(underlyingError)") + return + } + + XCTAssertEqual(rpcError.status, .notFound) + XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode) + XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found")) } catch { - responseError = error + XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)") } - - XCTAssertNotNil(responseError) - XCTAssertNil(content) - // TODO: Add assertions about `responseError`. } func testGenerateContent_failure_nonHTTPResponse() async throws {