From 987910731149966258a4fa118821055bd9d7e098 Mon Sep 17 00:00:00 2001 From: Fumito Ito Date: Wed, 20 Nov 2024 15:56:09 +0900 Subject: [PATCH 1/2] validate message and batch --- Sources/AnthropicSwiftSDK/ClientError.swift | 8 ++++ Sources/AnthropicSwiftSDK/Entity/Model.swift | 44 +++++++++++++++++++ .../AnthropicSwiftSDK/MessageBatches.swift | 23 +++++++++- Sources/AnthropicSwiftSDK/Messages.swift | 19 +++++++- 4 files changed, 91 insertions(+), 3 deletions(-) diff --git a/Sources/AnthropicSwiftSDK/ClientError.swift b/Sources/AnthropicSwiftSDK/ClientError.swift index 55e6cb3..f7720ea 100644 --- a/Sources/AnthropicSwiftSDK/ClientError.swift +++ b/Sources/AnthropicSwiftSDK/ClientError.swift @@ -35,6 +35,10 @@ public enum ClientError: Error { case failedToMakeEncodableToolUseInput([String: Any]) /// SDK failed to encode `SystemPrompt` object case failedToEncodeSystemPrompt + /// These messages are not supported by the model. + case unsupportedMessageContentContained(model: Model, messages: [Message]) + /// Some unsupported features are used. + case unsupportedFeatureUsed(description: String) /// Description of sdk internal errors. public var localizedDescription: String { @@ -63,6 +67,10 @@ public enum ClientError: Error { return "Failed to make ToolUse.input object Encodable" case .failedToEncodeSystemPrompt: return "Failed to encode `SystemPrompt` object" + case let .unsupportedMessageContentContained(model, messages): + return "The model \(model.stringfy) does not support these messages: \(messages)" + case let .unsupportedFeatureUsed(description): + return "Some unsupported features are used. For more detail, see \(description)." } } } diff --git a/Sources/AnthropicSwiftSDK/Entity/Model.swift b/Sources/AnthropicSwiftSDK/Entity/Model.swift index 246b328..8e3f193 100644 --- a/Sources/AnthropicSwiftSDK/Entity/Model.swift +++ b/Sources/AnthropicSwiftSDK/Entity/Model.swift @@ -47,6 +47,50 @@ public enum Model { } } +extension Model { + /// Whether this model supports Message Batches API or not. + /// + /// `Claude 3.0 Sonnet` does not support it. + var isSupportBatches: Bool { + switch self { + case + .claude_3_Opus, + .claude_3_Haiku, + .claude_3_5_Sonnet, + .claude_3_5_Haiku, + .custom: + return true + case .claude_3_Sonnet: + return false + } + } + + /// Whether this model supports Vision feature or not. + /// + /// `Claude 3.5 Haiku` does not support it. + var isSupportVision: Bool { + switch self { + case + .claude_3_Opus, + .claude_3_Haiku, + .claude_3_Sonnet, + .claude_3_5_Sonnet, + .custom: + return true + case .claude_3_5_Haiku: + return false + } + } + + func isValid(for message: Message) -> Bool { + if isSupportVision { + return true + } + + return message.content.allSatisfy { $0.contentType != .image } + } +} + extension Model { var stringfy: String { switch self { diff --git a/Sources/AnthropicSwiftSDK/MessageBatches.swift b/Sources/AnthropicSwiftSDK/MessageBatches.swift index 7c6d67b..2f745f0 100644 --- a/Sources/AnthropicSwiftSDK/MessageBatches.swift +++ b/Sources/AnthropicSwiftSDK/MessageBatches.swift @@ -31,7 +31,9 @@ public struct MessageBatches { /// - Returns: A `BatchResponse` containing the details of the created batches. /// - Throws: An error if the request fails. public func createBatches(batches: [MessageBatch]) async throws -> BatchResponse { - try await createBatches( + try validate(batches: batches) + + return try await createBatches( batches: batches, anthropicHeaderProvider: DefaultAnthropicHeaderProvider(), authenticationHeaderProvider: APIKeyAuthenticationHeaderProvider(apiKey: apiKey) @@ -336,3 +338,22 @@ public struct MessageBatches { return try anthropicJSONDecoder.decode(BatchResponse.self, from: data) } } + +extension MessageBatches { + func validate(batches: [MessageBatch]) throws { + try batches.forEach { batch in + let model = batch.parameter.model + guard model.isSupportBatches else { + throw ClientError.unsupportedFeatureUsed(description: "The model: \(model.stringfy) does not support Message Batches API") + } + + let messages = batch.parameter.messages + guard (messages.allSatisfy { model.isValid(for: $0) }) else { + throw ClientError.unsupportedMessageContentContained( + model: model, + messages: messages.filter { model.isValid(for: $0) == false } + ) + } + } + } +} diff --git a/Sources/AnthropicSwiftSDK/Messages.swift b/Sources/AnthropicSwiftSDK/Messages.swift index 47ee92f..536e941 100644 --- a/Sources/AnthropicSwiftSDK/Messages.swift +++ b/Sources/AnthropicSwiftSDK/Messages.swift @@ -45,7 +45,9 @@ public struct Messages { tools: [Tool]? = nil, toolChoice: ToolChoice = .auto ) async throws -> MessagesResponse { - try await createMessage( + try validate(model, for: messages) + + return try await createMessage( messages, model: model, system: system, @@ -160,7 +162,9 @@ public struct Messages { tools: [Tool]? = nil, toolChoice: ToolChoice = .auto ) async throws -> AsyncThrowingStream { - try await streamMessage( + try validate(model, for: messages) + + return try await streamMessage( messages, model: model, system: system, @@ -246,3 +250,14 @@ public struct Messages { return try await AnthropicStreamingParser.parse(stream: data.lines).accumulated() } } + +extension Messages { + func validate(_ model: Model, for messages: [Message]) throws { + guard (messages.allSatisfy { model.isValid(for: $0) }) else { + throw ClientError.unsupportedMessageContentContained( + model: model, + messages: messages.filter { model.isValid(for: $0) == false } + ) + } + } +} From 2b2a0aeb6a8cbbeae739a26303bac6bf2a0133f6 Mon Sep 17 00:00:00 2001 From: Fumito Ito Date: Wed, 27 Nov 2024 17:42:51 +0900 Subject: [PATCH 2/2] add tests --- .../Entity/ModelTests.swift | 62 ++++++++++ .../MessageBatchesTests.swift | 109 ++++++++++++++++++ .../MessagesTests.swift | 58 ++++++++++ 3 files changed, 229 insertions(+) create mode 100644 Tests/AnthropicSwiftSDKTests/Entity/ModelTests.swift diff --git a/Tests/AnthropicSwiftSDKTests/Entity/ModelTests.swift b/Tests/AnthropicSwiftSDKTests/Entity/ModelTests.swift new file mode 100644 index 0000000..7cce76f --- /dev/null +++ b/Tests/AnthropicSwiftSDKTests/Entity/ModelTests.swift @@ -0,0 +1,62 @@ +// +// ModelTests.swift +// AnthropicSwiftSDK +// +// Created by 伊藤史 on 2024/11/27. +// + +import XCTest +@testable import AnthropicSwiftSDK + +final class ModelTests: XCTestCase { + + func testIsSupportBatches() { + XCTAssertTrue(Model.claude_3_Opus.isSupportBatches, "claude_3_Opus should support batches.") + XCTAssertFalse(Model.claude_3_Sonnet.isSupportBatches, "claude_3_Sonnet should not support batches.") + XCTAssertTrue(Model.claude_3_Haiku.isSupportBatches, "claude_3_Haiku should support batches.") + XCTAssertTrue(Model.claude_3_5_Sonnet.isSupportBatches, "claude_3_5_Sonnet should support batches.") + XCTAssertTrue(Model.claude_3_5_Haiku.isSupportBatches, "claude_3_5_Haiku should support batches.") + XCTAssertTrue(Model.custom("custom-model").isSupportBatches, "Custom models should support batches.") + } + + func testIsSupportVision() { + XCTAssertTrue(Model.claude_3_Opus.isSupportVision, "claude_3_Opus should support vision.") + XCTAssertTrue(Model.claude_3_Sonnet.isSupportVision, "claude_3_Sonnet should support vision.") + XCTAssertTrue(Model.claude_3_Haiku.isSupportVision, "claude_3_Haiku should support vision.") + XCTAssertTrue(Model.claude_3_5_Sonnet.isSupportVision, "claude_3_5_Sonnet should support vision.") + XCTAssertFalse(Model.claude_3_5_Haiku.isSupportVision, "claude_3_5_Haiku should not support vision.") + XCTAssertTrue(Model.custom("custom-model").isSupportVision, "Custom models should support vision.") + } + + func testIsValid() { + + let textMessage = Message(role: .user, content: [.text("")]) + let imageMessage = Message(role: .user, content: [.image(.init(type: .base64, mediaType: .gif, data: Data()))]) + let documentMessage = Message(role: .user, content: [.document(.init(type: .base64, mediaType: .pdf, data: Data()))]) + + // Models that support vision + XCTAssertTrue(Model.claude_3_Opus.isValid(for: textMessage), "claude_3_Opus should validate text messages.") + XCTAssertTrue(Model.claude_3_Opus.isValid(for: imageMessage), "claude_3_Opus should validate image messages.") + XCTAssertTrue(Model.claude_3_Opus.isValid(for: documentMessage), "claude_3_Opus should validate document messages.") + + XCTAssertTrue(Model.claude_3_Sonnet.isValid(for: textMessage), "claude_3_Opus should validate text messages.") + XCTAssertTrue(Model.claude_3_Sonnet.isValid(for: imageMessage), "claude_3_Opus should validate image messages.") + XCTAssertTrue(Model.claude_3_Sonnet.isValid(for: documentMessage), "claude_3_Opus should validate document messages.") + + XCTAssertTrue(Model.claude_3_Haiku.isValid(for: textMessage), "claude_3_Opus should validate text messages.") + XCTAssertTrue(Model.claude_3_Haiku.isValid(for: imageMessage), "claude_3_Opus should validate image messages.") + XCTAssertTrue(Model.claude_3_Haiku.isValid(for: documentMessage), "claude_3_Opus should validate document messages.") + + XCTAssertTrue(Model.claude_3_5_Sonnet.isValid(for: textMessage), "claude_3_Opus should validate text messages.") + XCTAssertTrue(Model.claude_3_5_Sonnet.isValid(for: imageMessage), "claude_3_Opus should validate image messages.") + XCTAssertTrue(Model.claude_3_5_Sonnet.isValid(for: documentMessage), "claude_3_Opus should validate document messages.") + + XCTAssertTrue(Model.claude_3_5_Haiku.isValid(for: textMessage), "claude_3_Opus should validate text messages.") + XCTAssertFalse(Model.claude_3_5_Haiku.isValid(for: imageMessage), "claude_3_Opus should validate image messages.") + XCTAssertTrue(Model.claude_3_5_Haiku.isValid(for: documentMessage), "claude_3_Opus should validate document messages.") + + XCTAssertTrue(Model.custom("custom-model").isValid(for: textMessage), "claude_3_Opus should validate text messages.") + XCTAssertTrue(Model.custom("custom-model").isValid(for: imageMessage), "claude_3_Opus should validate image messages.") + XCTAssertTrue(Model.custom("custom-model").isValid(for: documentMessage), "claude_3_Opus should validate document messages.") + } +} diff --git a/Tests/AnthropicSwiftSDKTests/MessageBatchesTests.swift b/Tests/AnthropicSwiftSDKTests/MessageBatchesTests.swift index 0b798b6..25967a1 100644 --- a/Tests/AnthropicSwiftSDKTests/MessageBatchesTests.swift +++ b/Tests/AnthropicSwiftSDKTests/MessageBatchesTests.swift @@ -240,4 +240,113 @@ final class MessageBatchesTests: XCTestCase { XCTAssertEqual(error, .invalidRequestError) } } + + func testValidate_Success() { + let batch = MessageBatch( + customId: "", + parameter: .init( + messages: [ + .init( + role: .user, + content: [ + .text("") + ] + ), + .init( + role: .user, + content: [ + .text("") + ] + ) + ], + model: .claude_3_Opus, + maxTokens: 1 + ) + ) + let batches = [batch] + let messageBatches = MessageBatches(apiKey: "", session: .shared) + + XCTAssertNoThrow(try messageBatches.validate(batches: batches)) + } + + func testValidate_ModelDoesNotSupportBatches() { + let batch = MessageBatch( + customId: "", + parameter: .init( + messages: [ + .init( + role: .user, + content: [ + .text("") + ] + ), + .init( + role: .user, + content: [ + .text("") + ] + ) + ], + model: .claude_3_Sonnet, + maxTokens: 1 + ) + ) + let batches = [batch] + let messageBatches = MessageBatches(apiKey: "", session: .shared) + + XCTAssertThrowsError(try messageBatches.validate(batches: batches)) { error in + guard let clientError = error as? ClientError else { + XCTFail("Expected ClientError but got \(error)") + return + } + switch clientError { + case .unsupportedFeatureUsed(let description): + XCTAssertEqual(description, "The model: \(Model.claude_3_Sonnet.stringfy) does not support Message Batches API") + default: + XCTFail("Unexpected ClientError: \(clientError)") + } + } + } + + func testValidate_UnsupportedMessageContentContained() { + // Arrange + let batch = MessageBatch( + customId: "", + parameter: .init( + messages: [ + .init( + role: .user, + content: [ + .image(.init(type: .base64, mediaType: .png, data: Data())) + ] + ), + .init( + role: .user, + content: [ + .text("") + ] + ) + ], + model: .claude_3_5_Haiku, + maxTokens: 1 + ) + ) + let batches = [batch] + let messageBatches = MessageBatches(apiKey: "", session: .shared) + + XCTAssertThrowsError(try messageBatches.validate(batches: batches)) { error in + guard let clientError = error as? ClientError else { + XCTFail("Expected ClientError but got \(error)") + return + } + switch clientError { + case .unsupportedMessageContentContained(let model, let messages): + XCTAssertEqual(model.stringfy, Model.claude_3_5_Haiku.stringfy) + XCTAssertEqual(messages.count, 1) + XCTAssertEqual(messages.first?.content.first?.contentType, .image) + default: + XCTFail("Unexpected ClientError: \(clientError)") + } + } + } } diff --git a/Tests/AnthropicSwiftSDKTests/MessagesTests.swift b/Tests/AnthropicSwiftSDKTests/MessagesTests.swift index da64c2a..d4d4c63 100644 --- a/Tests/AnthropicSwiftSDKTests/MessagesTests.swift +++ b/Tests/AnthropicSwiftSDKTests/MessagesTests.swift @@ -121,4 +121,62 @@ final class MessagesTests: XCTestCase { XCTAssertEqual(response.message.usage.outputTokens, 1) } } + + func testValidate_Success() { + let messages: [Message] = [ + .init(role: .user, content: [.text("Valid message")]), + .init(role: .user, content: [.text("Another valid message")]) + ] + let messagesHandler = Messages(apiKey: "", session: .shared) + + XCTAssertNoThrow(try messagesHandler.validate(.claude_3_Opus, for: messages)) + } + + func testValidate_UnsupportedMessageContentContained() { + let messages: [Message] = [ + .init(role: .user, content: [.text("Valid text message")]), + .init(role: .user, content: [.image(.init(type: .base64, mediaType: .png, data: Data()))]) // Unsupported image + ] + let messagesHandler = Messages(apiKey: "", session: .shared) + + // Act & Assert + XCTAssertThrowsError(try messagesHandler.validate(.claude_3_5_Haiku, for: messages)) { error in + guard let clientError = error as? ClientError else { + XCTFail("Expected ClientError but got \(error)") + return + } + switch clientError { + case .unsupportedMessageContentContained(let invalidModel, let invalidMessages): + XCTAssertEqual(invalidModel.stringfy, Model.claude_3_5_Haiku.stringfy) + XCTAssertEqual(invalidMessages.count, 1) + XCTAssertEqual(invalidMessages.first?.content.first?.contentType, .image) + default: + XCTFail("Unexpected ClientError: \(clientError)") + } + } + } + + func testValidate_AllUnsupportedMessages() { + let messages: [Message] = [ + .init(role: .user, content: [.image(.init(type: .base64, mediaType: .png, data: Data()))]), // Unsupported image + .init(role: .user, content: [.image(.init(type: .base64, mediaType: .png, data: Data()))]) // Unsupported image + ] + let messagesHandler = Messages(apiKey: "", session: .shared) + + // Act & Assert + XCTAssertThrowsError(try messagesHandler.validate(.claude_3_5_Haiku, for: messages)) { error in + guard let clientError = error as? ClientError else { + XCTFail("Expected ClientError but got \(error)") + return + } + switch clientError { + case .unsupportedMessageContentContained(let invalidModel, let invalidMessages): + XCTAssertEqual(invalidModel.stringfy, Model.claude_3_5_Haiku.stringfy) + XCTAssertEqual(invalidMessages.count, 2) + XCTAssertTrue(invalidMessages.allSatisfy { $0.content.first?.contentType == .image }) + default: + XCTFail("Unexpected ClientError: \(clientError)") + } + } + } }