Skip to content

Commit

Permalink
add tests for request object
Browse files Browse the repository at this point in the history
  • Loading branch information
fumito-ito committed Oct 22, 2024
1 parent 914bee6 commit 7de3a5f
Show file tree
Hide file tree
Showing 14 changed files with 353 additions and 18 deletions.
2 changes: 1 addition & 1 deletion Sources/AnthropicSwiftSDK/Anthropic.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Foundation
public final class Anthropic {
/// Messages API Interface
public let messages: Messages

/// MessageBatches API Interface
public let messageBatches: MessageBatches

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,3 @@ public struct BatchParameter {
self.toolChoice = toolChoice
}
}

2 changes: 1 addition & 1 deletion Sources/AnthropicSwiftSDK/Entity/Batch/MessageBatch.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
public struct MessageBatch {
public let customId: String
public let parameter: BatchParameter

public init(customId: String, parameter: BatchParameter) {
self.customId = customId
self.parameter = parameter
Expand Down
21 changes: 10 additions & 11 deletions Sources/AnthropicSwiftSDK/MessageBatches.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public struct MessageBatches {
authenticationHeaderProvider: APIKeyAuthenticationHeaderProvider(apiKey: apiKey)
)
}

public func createBatches(
batches: [MessageBatch],
anthropicHeaderProvider: AnthropicHeaderProvider,
Expand Down Expand Up @@ -57,7 +57,7 @@ public struct MessageBatches {
authenticationHeaderProvider: APIKeyAuthenticationHeaderProvider(apiKey: apiKey)
)
}

public func retrieve(
batchId: String,
anthropicHeaderProvider: AnthropicHeaderProvider,
Expand Down Expand Up @@ -90,7 +90,7 @@ public struct MessageBatches {
authenticationHeaderProvider: APIKeyAuthenticationHeaderProvider(apiKey: apiKey)
)
}

public func results(
of batchId: String,
anthropicHeaderProvider: AnthropicHeaderProvider,
Expand All @@ -115,15 +115,15 @@ public struct MessageBatches {

return try anthropicJSONDecoder.decode([BatchResultResponse].self, from: data)
}

public func results(streamOf batchId: String) async throws -> AsyncThrowingStream<BatchResultResponse, Error> {
try await results(
streamOf: batchId,
anthropicHeaderProvider: DefaultAnthropicHeaderProvider(),
authenticationHeaderProvider: APIKeyAuthenticationHeaderProvider(apiKey: apiKey)
)
}

public func results(
streamOf batchId: String,
anthropicHeaderProvider: AnthropicHeaderProvider,
Expand All @@ -145,14 +145,14 @@ public struct MessageBatches {
guard httpResponse.statusCode == 200 else {
throw AnthropicAPIError(fromHttpStatusCode: httpResponse.statusCode)
}

return AsyncThrowingStream.init { continuation in
let task = Task {
for try await line in data.lines {
guard let data = line.data(using: .utf8) else {
return
}

continuation.yield(try anthropicJSONDecoder.decode(BatchResultResponse.self, from: data))
}
continuation.finish()
Expand All @@ -172,7 +172,7 @@ public struct MessageBatches {
authenticationHeaderProvider: APIKeyAuthenticationHeaderProvider(apiKey: apiKey)
)
}

public func list(
beforeId: String?,
afterId: String?,
Expand All @@ -194,7 +194,7 @@ public struct MessageBatches {
if let afterId {
queries[ListMessageBatchesRequest.Parameter.afterId.rawValue] = afterId
}

return queries
}()

Expand All @@ -219,7 +219,7 @@ public struct MessageBatches {
authenticationHeaderProvider: APIKeyAuthenticationHeaderProvider(apiKey: apiKey)
)
}

public func cancel(
batchId: String,
anthropicHeaderProvider: AnthropicHeaderProvider,
Expand All @@ -243,6 +243,5 @@ public struct MessageBatches {
}

return try anthropicJSONDecoder.decode(BatchResponse.self, from: data)

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct ListMessageBatchesRequest: Request {
let path: String = RequestType.batches.basePath
let queries: [String: CustomStringConvertible]?
let body: Never? = nil

enum Parameter: String {
case beforeId = "before_id"
case afterId = "after_id"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct MessageBatchesRequest: Request {
struct MessageBatchesRequestBody: Encodable {
/// List of requests for prompt completion. Each is an individual request to create a Message.
let requests: [Batch]

init(from batches: [MessageBatch]) {
self.requests = batches.map { Batch(from: $0) }
}
Expand All @@ -36,7 +36,7 @@ struct Batch: Encodable {
///
/// See the [Messages API reference](https://docs.anthropic.com/en/api/messages) for full documentation on available parameters.
let params: MessagesRequestBody

init(from batch: MessageBatch) {
self.customId = batch.customId
self.params = MessagesRequestBody(from: batch.parameter)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ struct MessagesRequestBody: Encodable {
self.tools = tools
self.toolChoice = tools == nil ? nil : toolChoice // ToolChoice should be set if tools are specified.
}

init(from parameter: BatchParameter) {
self.model = parameter.model
self.messages = parameter.messages
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//
// CancelMessageBatchRequestTests.swift
// AnthropicSwiftSDK
//
// Created by 伊藤史 on 2024/10/23.
//

import XCTest
@testable import AnthropicSwiftSDK

final class CancelMessageBatchRequestTests: XCTestCase {

func testCancelMessageBatchRequest() {
let testBatchId = "test_batch_123"

let request = CancelMessageBatchRequest(batchId: testBatchId)

XCTAssertEqual(request.method, .post)
XCTAssertEqual(request.path, "\(RequestType.batches.basePath)/\(testBatchId)/cancel")
XCTAssertNil(request.queries)
XCTAssertNil(request.body)
XCTAssertEqual(request.batchId, testBatchId)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//
// ListMessageBatchesRequestTests.swift
// AnthropicSwiftSDK
//
// Created by 伊藤史 on 2024/10/23.
//

import XCTest
@testable import AnthropicSwiftSDK

final class ListMessageBatchesRequestTests: XCTestCase {

func testListMessageBatchesRequestProperties() {
let queries: [String: CustomStringConvertible] = [
"before_id": "batch123",
"after_id": "batch456",
"limit": 10
]
let request = ListMessageBatchesRequest(queries: queries)

XCTAssertEqual(request.method, .get)
XCTAssertEqual(request.path, RequestType.batches.basePath)
XCTAssertEqual(request.queries?["before_id"] as? String, "batch123")
XCTAssertEqual(request.queries?["after_id"] as? String, "batch456")
XCTAssertEqual(request.queries?["limit"] as? Int, 10)
XCTAssertNil(request.body)
}

func testParameterRawValues() {
XCTAssertEqual(ListMessageBatchesRequest.Parameter.beforeId.rawValue, "before_id")
XCTAssertEqual(ListMessageBatchesRequest.Parameter.afterId.rawValue, "after_id")
XCTAssertEqual(ListMessageBatchesRequest.Parameter.limit.rawValue, "limit")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//
// MessageBatchesRequestTests.swift
// AnthropicSwiftSDK
//
// Created by 伊藤史 on 2024/10/22.
//

import XCTest
@testable import AnthropicSwiftSDK

final class MessageBatchesRequestTests: XCTestCase {

func testMessageBatchesRequest() throws {
let batchParameter1 = BatchParameter(
messages: [Message(role: .user, content: [.text("こんにちは")])],
model: .claude_3_Opus,
maxTokens: 100
)
let batchParameter2 = BatchParameter(
messages: [Message(role: .user, content: [.text("お元気ですか?")])],
model: .claude_3_Sonnet,
maxTokens: 200
)

let messageBatch1 = MessageBatch(customId: "test1", parameter: batchParameter1)
let messageBatch2 = MessageBatch(customId: "test2", parameter: batchParameter2)

let request = MessageBatchesRequest(body: .init(from: [messageBatch1, messageBatch2]))

XCTAssertEqual(request.method, .post)
XCTAssertEqual(request.path, RequestType.batches.basePath)
XCTAssertNil(request.queries)
XCTAssertNotNil(request.body)

XCTAssertEqual(request.body?.requests.count, 2)

XCTAssertEqual(request.body?.requests[0].customId, "test1")
XCTAssertEqual(request.body?.requests[0].params.model.stringfy, Model.claude_3_Opus.stringfy)
XCTAssertEqual(request.body?.requests[0].params.maxTokens, 100)
XCTAssertEqual(request.body?.requests[0].params.messages.count, 1)
XCTAssertEqual(request.body?.requests[0].params.messages[0].role.rawValue, "user")
let content1 = try XCTUnwrap(request.body?.requests[0].params.messages[0].content)
guard case let .text(text1) = content1[0] else {
XCTFail("content1[0] is not .text")
return
}
XCTAssertEqual(text1, "こんにちは")

XCTAssertEqual(request.body?.requests[1].customId, "test2")
XCTAssertEqual(request.body?.requests[1].params.model.stringfy, Model.claude_3_Sonnet.stringfy)
XCTAssertEqual(request.body?.requests[1].params.maxTokens, 200)
XCTAssertEqual(request.body?.requests[1].params.messages.count, 1)
XCTAssertEqual(request.body?.requests[1].params.messages[0].role.rawValue, "user")
let content2 = try XCTUnwrap(request.body?.requests[1].params.messages[0].content)
guard case let .text(text2) = content2[0] else {
XCTFail("content2[0] is not .text")
return
}
XCTAssertEqual(text2, "お元気ですか?")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
//
// MessagesRequestTests.swift
// AnthropicSwiftSDK
//
// Created by 伊藤史 on 2024/10/22.
//

import XCTest
@testable import AnthropicSwiftSDK

final class MessagesRequestTests: XCTestCase {
func testMessagesRequest() throws {
let messages = [Message(role: .user, content: [.text("こんにちは")])]
let systemPrompt: [SystemPrompt] = [.text("あなたは親切なアシスタントです。", .ephemeral)]

let requestBody = MessagesRequestBody(
model: .claude_3_Opus,
messages: messages,
system: systemPrompt,
maxTokens: 100,
metaData: MetaData(userId: "test-user"),
stopSequences: ["END"],
stream: false,
temperature: 0.7,
topP: 0.9,
topK: 10,
tools: nil,
toolChoice: .auto
)

let request = MessagesRequest(body: requestBody)

XCTAssertEqual(request.method, .post)
XCTAssertEqual(request.path, RequestType.messages.basePath)
XCTAssertNil(request.queries)

XCTAssertEqual(request.body?.model.stringfy, Model.claude_3_Opus.stringfy)
XCTAssertEqual(request.body?.messages.count, 1)
XCTAssertEqual(request.body?.messages[0].role.rawValue, "user")
let content1 = try XCTUnwrap(request.body?.messages[0].content)
guard case let .text(text1) = content1[0] else {
XCTFail("content1[0] is not .text")
return
}
XCTAssertEqual(text1, "こんにちは")
let system1 = try XCTUnwrap(request.body?.system)
guard case let .text(text1, _) = system1[0] else {
XCTFail("system1[0] is not .text")
return
}
XCTAssertEqual(text1, "あなたは親切なアシスタントです。")
XCTAssertEqual(request.body?.maxTokens, 100)
XCTAssertEqual(request.body?.metaData?.userId, "test-user")
XCTAssertEqual(request.body?.stopSequences, ["END"])
XCTAssertFalse(request.body?.stream ?? true)
XCTAssertEqual(request.body?.temperature, 0.7)
XCTAssertEqual(request.body?.topP, 0.9)
XCTAssertEqual(request.body?.topK, 10)
XCTAssertNil(request.body?.tools)
XCTAssertNil(request.body?.toolChoice)
}

func testMessagesRequestWithBatchParameter() throws {
let messages = [Message(role: .user, content: [.text("こんにちは")])]
let systemPrompt: [SystemPrompt] = [.text("あなたは親切なアシスタントです。", .ephemeral)]

let batchParameter = BatchParameter(
messages: messages,
model: .claude_3_Opus,
system: systemPrompt,
maxTokens: 100,
metaData: MetaData(userId: "test-user"),
stopSequence: ["END"],
temperature: 0.7,
topP: 0.9,
topK: 10,
toolContainer: nil,
toolChoice: .auto
)

let requestBody = MessagesRequestBody(from: batchParameter)
let request = MessagesRequest(body: requestBody)

XCTAssertEqual(request.body?.model.stringfy, Model.claude_3_Opus.stringfy)
XCTAssertEqual(request.body?.messages.count, 1)
XCTAssertEqual(request.body?.messages[0].role.rawValue, "user")
let content1 = try XCTUnwrap(request.body?.messages[0].content)
guard case let .text(text1) = content1[0] else {
XCTFail("content1[0] is not .text")
return
}
XCTAssertEqual(text1, "こんにちは")
let system1 = try XCTUnwrap(request.body?.system)
guard case let .text(text1, _) = system1[0] else {
XCTFail("system1[0] is not .text")
return
}
XCTAssertEqual(text1, "あなたは親切なアシスタントです。")
XCTAssertEqual(request.body?.maxTokens, 100)
XCTAssertEqual(request.body?.metaData?.userId, "test-user")
XCTAssertEqual(request.body?.stopSequences, ["END"])
XCTAssertFalse(request.body?.stream ?? true)
XCTAssertEqual(request.body?.temperature, 0.7)
XCTAssertEqual(request.body?.topP, 0.9)
XCTAssertEqual(request.body?.topK, 10)
XCTAssertNil(request.body?.tools)
XCTAssertNil(request.body?.toolChoice)
}
}
Loading

0 comments on commit 7de3a5f

Please sign in to comment.