Skip to content

Commit

Permalink
Merge pull request #66 from fumito-ito/feature/validate-messages
Browse files Browse the repository at this point in the history
validate message and batch
  • Loading branch information
fumito-ito authored Nov 28, 2024
2 parents 2cbf664 + 2b2a0ae commit dcfedf3
Show file tree
Hide file tree
Showing 7 changed files with 320 additions and 3 deletions.
8 changes: 8 additions & 0 deletions Sources/AnthropicSwiftSDK/ClientError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ public enum ClientError: Error {
case failedToMakeEncodableToolUseInput([String: Any])
/// SDK failed to encode `SystemPrompt` object
case failedToEncodeSystemPrompt
/// These messages are not supported by the model.
case unsupportedMessageContentContained(model: Model, messages: [Message])
/// Some unsupported features are used.
case unsupportedFeatureUsed(description: String)

/// Description of sdk internal errors.
public var localizedDescription: String {
Expand Down Expand Up @@ -63,6 +67,10 @@ public enum ClientError: Error {
return "Failed to make ToolUse.input object Encodable"
case .failedToEncodeSystemPrompt:
return "Failed to encode `SystemPrompt` object"
case let .unsupportedMessageContentContained(model, messages):
return "The model \(model.stringfy) does not support these messages: \(messages)"
case let .unsupportedFeatureUsed(description):
return "Some unsupported features are used. For more detail, see \(description)."
}
}
}
44 changes: 44 additions & 0 deletions Sources/AnthropicSwiftSDK/Entity/Model.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,50 @@ public enum Model {
}
}

extension Model {
/// Whether this model supports Message Batches API or not.
///
/// `Claude 3.0 Sonnet` does not support it.
var isSupportBatches: Bool {
switch self {
case
.claude_3_Opus,
.claude_3_Haiku,
.claude_3_5_Sonnet,
.claude_3_5_Haiku,
.custom:
return true
case .claude_3_Sonnet:
return false
}
}

/// Whether this model supports Vision feature or not.
///
/// `Claude 3.5 Haiku` does not support it.
var isSupportVision: Bool {
switch self {
case
.claude_3_Opus,
.claude_3_Haiku,
.claude_3_Sonnet,
.claude_3_5_Sonnet,
.custom:
return true
case .claude_3_5_Haiku:
return false
}
}

func isValid(for message: Message) -> Bool {
if isSupportVision {
return true
}

return message.content.allSatisfy { $0.contentType != .image }
}
}

extension Model {
var stringfy: String {
switch self {
Expand Down
23 changes: 22 additions & 1 deletion Sources/AnthropicSwiftSDK/MessageBatches.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ public struct MessageBatches {
/// - Returns: A `BatchResponse` containing the details of the created batches.
/// - Throws: An error if the request fails.
public func createBatches(batches: [MessageBatch]) async throws -> BatchResponse {
try await createBatches(
try validate(batches: batches)

return try await createBatches(
batches: batches,
anthropicHeaderProvider: DefaultAnthropicHeaderProvider(),
authenticationHeaderProvider: APIKeyAuthenticationHeaderProvider(apiKey: apiKey)
Expand Down Expand Up @@ -336,3 +338,22 @@ public struct MessageBatches {
return try anthropicJSONDecoder.decode(BatchResponse.self, from: data)
}
}

extension MessageBatches {
func validate(batches: [MessageBatch]) throws {
try batches.forEach { batch in
let model = batch.parameter.model
guard model.isSupportBatches else {
throw ClientError.unsupportedFeatureUsed(description: "The model: \(model.stringfy) does not support Message Batches API")
}

let messages = batch.parameter.messages
guard (messages.allSatisfy { model.isValid(for: $0) }) else {
throw ClientError.unsupportedMessageContentContained(
model: model,
messages: messages.filter { model.isValid(for: $0) == false }
)
}
}
}
}
19 changes: 17 additions & 2 deletions Sources/AnthropicSwiftSDK/Messages.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ public struct Messages {
tools: [Tool]? = nil,
toolChoice: ToolChoice = .auto
) async throws -> MessagesResponse {
try await createMessage(
try validate(model, for: messages)

return try await createMessage(
messages,
model: model,
system: system,
Expand Down Expand Up @@ -160,7 +162,9 @@ public struct Messages {
tools: [Tool]? = nil,
toolChoice: ToolChoice = .auto
) async throws -> AsyncThrowingStream<StreamingResponse, Error> {
try await streamMessage(
try validate(model, for: messages)

return try await streamMessage(
messages,
model: model,
system: system,
Expand Down Expand Up @@ -246,3 +250,14 @@ public struct Messages {
return try await AnthropicStreamingParser.parse(stream: data.lines).accumulated()
}
}

extension Messages {
func validate(_ model: Model, for messages: [Message]) throws {
guard (messages.allSatisfy { model.isValid(for: $0) }) else {
throw ClientError.unsupportedMessageContentContained(
model: model,
messages: messages.filter { model.isValid(for: $0) == false }
)
}
}
}
62 changes: 62 additions & 0 deletions Tests/AnthropicSwiftSDKTests/Entity/ModelTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//
// ModelTests.swift
// AnthropicSwiftSDK
//
// Created by 伊藤史 on 2024/11/27.
//

import XCTest
@testable import AnthropicSwiftSDK

final class ModelTests: XCTestCase {

func testIsSupportBatches() {
XCTAssertTrue(Model.claude_3_Opus.isSupportBatches, "claude_3_Opus should support batches.")
XCTAssertFalse(Model.claude_3_Sonnet.isSupportBatches, "claude_3_Sonnet should not support batches.")
XCTAssertTrue(Model.claude_3_Haiku.isSupportBatches, "claude_3_Haiku should support batches.")
XCTAssertTrue(Model.claude_3_5_Sonnet.isSupportBatches, "claude_3_5_Sonnet should support batches.")
XCTAssertTrue(Model.claude_3_5_Haiku.isSupportBatches, "claude_3_5_Haiku should support batches.")
XCTAssertTrue(Model.custom("custom-model").isSupportBatches, "Custom models should support batches.")
}

func testIsSupportVision() {
XCTAssertTrue(Model.claude_3_Opus.isSupportVision, "claude_3_Opus should support vision.")
XCTAssertTrue(Model.claude_3_Sonnet.isSupportVision, "claude_3_Sonnet should support vision.")
XCTAssertTrue(Model.claude_3_Haiku.isSupportVision, "claude_3_Haiku should support vision.")
XCTAssertTrue(Model.claude_3_5_Sonnet.isSupportVision, "claude_3_5_Sonnet should support vision.")
XCTAssertFalse(Model.claude_3_5_Haiku.isSupportVision, "claude_3_5_Haiku should not support vision.")
XCTAssertTrue(Model.custom("custom-model").isSupportVision, "Custom models should support vision.")
}

func testIsValid() {

let textMessage = Message(role: .user, content: [.text("")])
let imageMessage = Message(role: .user, content: [.image(.init(type: .base64, mediaType: .gif, data: Data()))])
let documentMessage = Message(role: .user, content: [.document(.init(type: .base64, mediaType: .pdf, data: Data()))])

// Models that support vision
XCTAssertTrue(Model.claude_3_Opus.isValid(for: textMessage), "claude_3_Opus should validate text messages.")
XCTAssertTrue(Model.claude_3_Opus.isValid(for: imageMessage), "claude_3_Opus should validate image messages.")
XCTAssertTrue(Model.claude_3_Opus.isValid(for: documentMessage), "claude_3_Opus should validate document messages.")

XCTAssertTrue(Model.claude_3_Sonnet.isValid(for: textMessage), "claude_3_Opus should validate text messages.")
XCTAssertTrue(Model.claude_3_Sonnet.isValid(for: imageMessage), "claude_3_Opus should validate image messages.")
XCTAssertTrue(Model.claude_3_Sonnet.isValid(for: documentMessage), "claude_3_Opus should validate document messages.")

XCTAssertTrue(Model.claude_3_Haiku.isValid(for: textMessage), "claude_3_Opus should validate text messages.")
XCTAssertTrue(Model.claude_3_Haiku.isValid(for: imageMessage), "claude_3_Opus should validate image messages.")
XCTAssertTrue(Model.claude_3_Haiku.isValid(for: documentMessage), "claude_3_Opus should validate document messages.")

XCTAssertTrue(Model.claude_3_5_Sonnet.isValid(for: textMessage), "claude_3_Opus should validate text messages.")
XCTAssertTrue(Model.claude_3_5_Sonnet.isValid(for: imageMessage), "claude_3_Opus should validate image messages.")
XCTAssertTrue(Model.claude_3_5_Sonnet.isValid(for: documentMessage), "claude_3_Opus should validate document messages.")

XCTAssertTrue(Model.claude_3_5_Haiku.isValid(for: textMessage), "claude_3_Opus should validate text messages.")
XCTAssertFalse(Model.claude_3_5_Haiku.isValid(for: imageMessage), "claude_3_Opus should validate image messages.")
XCTAssertTrue(Model.claude_3_5_Haiku.isValid(for: documentMessage), "claude_3_Opus should validate document messages.")

XCTAssertTrue(Model.custom("custom-model").isValid(for: textMessage), "claude_3_Opus should validate text messages.")
XCTAssertTrue(Model.custom("custom-model").isValid(for: imageMessage), "claude_3_Opus should validate image messages.")
XCTAssertTrue(Model.custom("custom-model").isValid(for: documentMessage), "claude_3_Opus should validate document messages.")
}
}
109 changes: 109 additions & 0 deletions Tests/AnthropicSwiftSDKTests/MessageBatchesTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -240,4 +240,113 @@ final class MessageBatchesTests: XCTestCase {
XCTAssertEqual(error, .invalidRequestError)
}
}

func testValidate_Success() {
let batch = MessageBatch(
customId: "",
parameter: .init(
messages: [
.init(
role: .user,
content: [
.text("")
]
),
.init(
role: .user,
content: [
.text("")
]
)
],
model: .claude_3_Opus,
maxTokens: 1
)
)
let batches = [batch]
let messageBatches = MessageBatches(apiKey: "", session: .shared)

XCTAssertNoThrow(try messageBatches.validate(batches: batches))
}

func testValidate_ModelDoesNotSupportBatches() {
let batch = MessageBatch(
customId: "",
parameter: .init(
messages: [
.init(
role: .user,
content: [
.text("")
]
),
.init(
role: .user,
content: [
.text("")
]
)
],
model: .claude_3_Sonnet,
maxTokens: 1
)
)
let batches = [batch]
let messageBatches = MessageBatches(apiKey: "", session: .shared)

XCTAssertThrowsError(try messageBatches.validate(batches: batches)) { error in
guard let clientError = error as? ClientError else {
XCTFail("Expected ClientError but got \(error)")
return
}
switch clientError {
case .unsupportedFeatureUsed(let description):
XCTAssertEqual(description, "The model: \(Model.claude_3_Sonnet.stringfy) does not support Message Batches API")
default:
XCTFail("Unexpected ClientError: \(clientError)")
}
}
}

func testValidate_UnsupportedMessageContentContained() {
// Arrange
let batch = MessageBatch(
customId: "",
parameter: .init(
messages: [
.init(
role: .user,
content: [
.image(.init(type: .base64, mediaType: .png, data: Data()))
]
),
.init(
role: .user,
content: [
.text("")
]
)
],
model: .claude_3_5_Haiku,
maxTokens: 1
)
)
let batches = [batch]
let messageBatches = MessageBatches(apiKey: "", session: .shared)

XCTAssertThrowsError(try messageBatches.validate(batches: batches)) { error in
guard let clientError = error as? ClientError else {
XCTFail("Expected ClientError but got \(error)")
return
}
switch clientError {
case .unsupportedMessageContentContained(let model, let messages):
XCTAssertEqual(model.stringfy, Model.claude_3_5_Haiku.stringfy)
XCTAssertEqual(messages.count, 1)
XCTAssertEqual(messages.first?.content.first?.contentType, .image)
default:
XCTFail("Unexpected ClientError: \(clientError)")
}
}
}
}
58 changes: 58 additions & 0 deletions Tests/AnthropicSwiftSDKTests/MessagesTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,62 @@ final class MessagesTests: XCTestCase {
XCTAssertEqual(response.message.usage.outputTokens, 1)
}
}

func testValidate_Success() {
let messages: [Message] = [
.init(role: .user, content: [.text("Valid message")]),
.init(role: .user, content: [.text("Another valid message")])
]
let messagesHandler = Messages(apiKey: "", session: .shared)

XCTAssertNoThrow(try messagesHandler.validate(.claude_3_Opus, for: messages))
}

func testValidate_UnsupportedMessageContentContained() {
let messages: [Message] = [
.init(role: .user, content: [.text("Valid text message")]),
.init(role: .user, content: [.image(.init(type: .base64, mediaType: .png, data: Data()))]) // Unsupported image
]
let messagesHandler = Messages(apiKey: "", session: .shared)

// Act & Assert
XCTAssertThrowsError(try messagesHandler.validate(.claude_3_5_Haiku, for: messages)) { error in
guard let clientError = error as? ClientError else {
XCTFail("Expected ClientError but got \(error)")
return
}
switch clientError {
case .unsupportedMessageContentContained(let invalidModel, let invalidMessages):
XCTAssertEqual(invalidModel.stringfy, Model.claude_3_5_Haiku.stringfy)
XCTAssertEqual(invalidMessages.count, 1)
XCTAssertEqual(invalidMessages.first?.content.first?.contentType, .image)
default:
XCTFail("Unexpected ClientError: \(clientError)")
}
}
}

func testValidate_AllUnsupportedMessages() {
let messages: [Message] = [
.init(role: .user, content: [.image(.init(type: .base64, mediaType: .png, data: Data()))]), // Unsupported image
.init(role: .user, content: [.image(.init(type: .base64, mediaType: .png, data: Data()))]) // Unsupported image
]
let messagesHandler = Messages(apiKey: "", session: .shared)

// Act & Assert
XCTAssertThrowsError(try messagesHandler.validate(.claude_3_5_Haiku, for: messages)) { error in
guard let clientError = error as? ClientError else {
XCTFail("Expected ClientError but got \(error)")
return
}
switch clientError {
case .unsupportedMessageContentContained(let invalidModel, let invalidMessages):
XCTAssertEqual(invalidModel.stringfy, Model.claude_3_5_Haiku.stringfy)
XCTAssertEqual(invalidMessages.count, 2)
XCTAssertTrue(invalidMessages.allSatisfy { $0.content.first?.contentType == .image })
default:
XCTFail("Unexpected ClientError: \(clientError)")
}
}
}
}

0 comments on commit dcfedf3

Please sign in to comment.