Commit

update tests
svilupp committed Feb 22, 2024
1 parent 131942c commit 320fdc5
Showing 3 changed files with 38 additions and 6 deletions.
14 changes: 11 additions & 3 deletions src/utils.jl
@@ -306,9 +306,17 @@ function call_cost(msg, model::String)
     end
     return cost
 end
-## dispatch for array -> take last message
-function call_cost(msg::AbstractVector, model::String)
-    call_cost(last(msg), model)
+## dispatch for array -> take unique messages only (eg, for multiple samples we count only once)
+function call_cost(conv::AbstractVector, model::String)
+    sum_ = 0.0
+    visited_runs = Set{Int}()
+    for msg in conv
+        if isnothing(msg.run_id) || (msg.run_id ∉ visited_runs)
+            sum_ += call_cost(msg, model)
+            push!(visited_runs, msg.run_id)
+        end
+    end
+    return sum_
 end
 
 # helper to produce summary message of how many tokens were used and for how much
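
For context, a short usage sketch of the new vector dispatch in src/utils.jl. It assumes AIMessage accepts tokens and run_id keyword arguments (as the tests in this commit suggest); the contents, run_id values, and model name are illustrative only, not taken from the diff.

using PromptingTools: AIMessage, call_cost

# Two samples returned by one API call share a run_id, so their cost is charged once;
# a message from a different run adds its own cost on top.
sample1 = AIMessage(; content = "a", tokens = (1000, 2000), run_id = 1)
sample2 = AIMessage(; content = "b", tokens = (1000, 2000), run_id = 1)
other = AIMessage(; content = "c", tokens = (1000, 2000), run_id = 2)

call_cost([sample1, sample2], "gpt-3.5-turbo")        # one run -> charged once
call_cost([sample1, sample2, other], "gpt-3.5-turbo") # two runs -> charged twice
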
25 changes: 22 additions & 3 deletions test/llm_openai.jl
@@ -596,17 +596,36 @@ end
     @test msg.content == RandomType1235(1)
     @test msg.log_prob ≈ -0.9
 
-    ## Test multiple samples
-    response = Dict(:choices => [mock_choice, mock_choice],
+    ## Test multiple samples -- mock_choice is less probable
+    mock_choice2 = Dict(:message => Dict(:content => "Hello!",
+            :tool_calls => [
+                Dict(:function => Dict(:arguments => JSON3.write(Dict(:x => 1)))),
+            ]),
+        :logprobs => Dict(:content => [Dict(:logprob => -1.2), Dict(:logprob => -0.4)]),
+        :finish_reason => "stop")
+
+    response = Dict(:choices => [mock_choice, mock_choice2],
         :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1))
     schema2 = TestEchoOpenAISchema(; response, status = 200)
     conv = aiextract(schema2, "Extract number 1"; return_type,
         model = "gpt4",
         api_kwargs = (; temperature = 0, n = 2))
     @test conv[1].content == RandomType1235(1)
-    @test conv[1].log_prob ≈ -0.9
+    @test conv[1].log_prob ≈ -1.6 # sorted first, despite sent later
     @test conv[2].content == RandomType1235(1)
     @test conv[2].log_prob ≈ -0.9
+
+    ## Wrong return_type so it returns a Dict
+    struct RandomType1236
+        x::Int
+        y::Int
+    end
+    return_type = RandomType1236
+    conv = aiextract(schema2, "Extract number 1"; return_type,
+        model = "gpt4",
+        api_kwargs = (; temperature = 0, n = 2))
+    @test conv[1].content isa AbstractDict
+    @test conv[2].content isa AbstractDict
 end
 
 @testset "aiscan-OpenAI" begin
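
A quick note on the expected log_prob values, inferred from the mocks above rather than from any documentation: the message-level log_prob appears to be the sum of the per-token logprobs in the choice.

# mock_choice2 carries per-token logprobs -1.2 and -0.4:
-1.2 + -0.4   # ≈ -1.6, matching the assertion on conv[1].log_prob
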
5 changes: 5 additions & 0 deletions test/utils.jl
@@ -137,6 +137,11 @@ end
     @test cost == 0.0
     @test call_cost(msg, "gpt-3.5-turbo") ≈ 1000 * 0.5e-6 + 1.5e-6 * 2000
 
+    # Test vector - same message, count once
+    @test call_cost([msg, msg], "gpt-3.5-turbo") ≈ (1000 * 0.5e-6 + 1.5e-6 * 2000)
+    msg2 = AIMessage(; content = "", tokens = (1000, 2000))
+    @test call_cost([msg, msg2], "gpt-3.5-turbo") ≈ (1000 * 0.5e-6 + 1.5e-6 * 2000) * 2
+
     msg = DataMessage(; content = nothing, tokens = (1000, 1000))
     cost = call_cost(msg, "unknown_model")
     @test cost == 0.0
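
For readers checking the expected values above, a small arithmetic sketch of the single-message cost the test asserts. The per-token rates are the constants used in the test itself, not an authoritative pricing table.

prompt_tokens, completion_tokens = 1000, 2000
prompt_rate, completion_rate = 0.5e-6, 1.5e-6   # USD per token, as assumed by the test constants
prompt_tokens * prompt_rate + completion_tokens * completion_rate   # 0.0005 + 0.003 = 0.0035
# [msg, msg] also costs 0.0035 because the duplicate shares its run_id,
# while [msg, msg2] doubles it: msg2 is constructed separately, so it carries a different run_id.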
