Skip to content

Commit

Permalink
tool use working
Browse files Browse the repository at this point in the history
  • Loading branch information
monofuel committed Jul 27, 2024
1 parent 476c5cd commit 888d094
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 36 deletions.
2 changes: 1 addition & 1 deletion llama_leap.nimble
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version = "1.0.1"
version = "1.1.0"
author = "Andrew Brower"
description = "Ollama API for Nim"
license = "MIT"
Expand Down
59 changes: 29 additions & 30 deletions src/llama_leap.nim
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import curly, jsony, std/[strutils, json, options, strformat, os]
import curly, jsony, std/[strutils, tables, json, options, strformat, os]

## ollama API Interface
## https://github.com/jmorganca/ollama/blob/main/docs/api.md
Expand Down Expand Up @@ -28,14 +28,31 @@ type
num_predict*: Option[int]
top_k*: Option[int]
top_p*: Option[float32]
# ToolFunctionParameter* = object
# `type`*: string
# description*: string
ToolFunctionParameters* = object
`type`*: string # object
# had serialization issues when properties was a table
# it was also kind of confusing to work with
#properties*: Table[string, ToolFunctionParameter]
properties*: JsonNode
required*: seq[string]
ToolFunction* = ref object
name*: string
description*: string
parameters*: ToolFunctionParameters
Tool* = ref object
`type`*: string
function*: ToolFunction
GenerateReq* = ref object
model*: string
prompt*: string
images*: Option[seq[string]] # list of base64 encoded images
format*: Option[string] # optional format=json for a structured response
options*: Option[ModelParameters] # bag of model parameters
system*: Option[string] # override modelfile system prompt
template_str*: Option[string] # override modelfile template
`template`*: Option[string] # override modelfile template
context*: Option[seq[int]] # conversation encoding from a previous response
stream: Option[bool] # stream=false to get a single response
raw*: Option[bool] # use raw=true if you are specifying a fully templated prompt
Expand All @@ -51,16 +68,23 @@ type
prompt_eval_duration*: int
eval_count*: int
eval_duration*: int
ToolCallFunction* = ref object
name*: string
arguments*: JsonNode # map of [string]: any
ToolCall* = ref object
function*: ToolCallFunction
ChatMessage* = ref object
role*: string # "system" "user" or "assistant"
role*: string # "system" "user" "tool" or "assistant"
content*: string
images*: Option[seq[string]] # list of base64 encoded images
tool_calls*: seq[ToolCall]
ChatReq* = ref object
model*: string
tools*: seq[Tool] # requires stream=false currently
messages*: seq[ChatMessage]
format*: Option[string] # optional format=json for a structured response
options*: Option[ModelParameters] # bag of model parameters
template_str*: Option[string] # override modelfile template
`template`*: Option[string] # override modelfile template
stream: Option[bool] # stream=false to get a single response
ChatResp* = ref object
model*: string
Expand Down Expand Up @@ -95,7 +119,7 @@ type
ShowModel* = ref object
modelfile*: string
parameters*: string
template_str*: string
`template`*: string
details*: ModelDetails
EmbeddingReq* = ref object
model*: string
Expand All @@ -104,31 +128,6 @@ type
EmbeddingResp* = ref object
embedding*: seq[float64]

proc renameHook(v: var ChatReq, fieldName: var string) =
## `template` is a special keyword in nim, so we need to rename it during serialization
if fieldName == "template":
fieldName = "template_str"
proc dumpHook(v: var ChatReq, fieldName: var string) =
if fieldName == "template_str":
fieldName = "template"

proc renameHook(v: var GenerateReq, fieldName: var string) =
## `template` is a special keyword in nim, so we need to rename it during serialization
if fieldName == "template":
fieldName = "template_str"
proc dumpHook(v: var GenerateReq, fieldName: var string) =
if fieldName == "template_str":
fieldName = "template"

proc renameHook(v: var ShowModel, fieldName: var string) =
## `template` is a special keyword in nim, so we need to rename it during serialization
if fieldName == "template":
fieldName = "template_str"
proc dumpHook(v: var ShowModel, fieldName: var string) =
if fieldName == "template_str":
fieldName = "template"


proc dumpHook(s: var string, v: object) =
## jsony `hack` to skip optional fields that are nil
s.add '{'
Expand Down
4 changes: 2 additions & 2 deletions tests/test_llama_leap.nim
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ SYSTEM Please talk like a pirate. You are Longbeard the llama.
test "show model":
let resp = ollama.showModel(TestModelfileName)
echo "> " & toJson(resp)
# validate that renameHook() is working properly
assert resp.template_str != ""
# assert that special keywords are working properly
assert resp.`template` != ""

suite "embeddings":
test "generate embeddings":
Expand Down
88 changes: 85 additions & 3 deletions tests/test_ollama_tools.nim
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,29 @@
## Ensure that ollama is running!

import
std/[unittest],
llama_leap
std/[unittest, json, tables, strutils],
llama_leap, jsony

# Must use a tools compatible model!
const
TestModel = "llama3.1:8b"

proc getFlightTimes(departure: string, arrival: string): string =
  ## Look up a canned flight record for the departure/arrival pair.
  ## Airport codes are matched case-insensitively; the record is returned
  ## serialized as a compact JSON string. Raises `ValueError` when no
  ## flight is known for the route.
  let schedule = {
    "NYC-LAX": %* {"departure": "08:00 AM", "arrival": "11:30 AM", "duration": "5h 30m"},
    "LAX-NYC": %* {"departure": "02:00 PM", "arrival": "10:30 PM", "duration": "5h 30m"},
    "LHR-JFK": %* {"departure": "10:00 AM", "arrival": "01:00 PM", "duration": "8h 00m"},
    "JFK-LHR": %* {"departure": "09:00 PM", "arrival": "09:00 AM", "duration": "7h 00m"},
    "CDG-DXB": %* {"departure": "11:00 AM", "arrival": "08:00 PM", "duration": "6h 00m"},
    "DXB-CDG": %* {"departure": "03:00 AM", "arrival": "07:30 AM", "duration": "7h 30m"}
  }.toTable

  # Normalize the route key so "nyc"/"NYC" behave the same.
  let route = (departure & "-" & arrival).toUpperAscii()
  if route notin schedule:
    raise newException(ValueError, "No flight found for " & route)
  $schedule[route]

suite "ollama tools":
var ollama: OllamaAPI

Expand All @@ -23,5 +39,71 @@ suite "ollama tools":
suite "pull":
test "pull model":
ollama.pullModel(TestModel)

suite "flight times":
test "getFlightTimes":
echo getFlightTimes("NYC", "LAX")

test "tool call queries":
var messages = @[
ChatMessage(
role: "user",
content: "What is the flight time from New York (NYC) to Los Angeles (LAX)?"
)
]

let firstRequest = ChatReq(
model: TestModel,
messages: messages,
tools: @[
Tool(
`type`: "function",
function: ToolFunction(
name: "get_flight_times",
description: "Get the flight times between two cities",
parameters: ToolFunctionParameters(
`type`: "object",
required: @["departure", "arrival"],
properties: %* {
"departure": {
"type": "string",
"description": "The departure city (airport code)"
},
"arrival": {
"type": "string",
"description": "The arrival city (airport code)"
}
}
)
)
)
]
)

let toolResp = ollama.chat(firstRequest)
# add the model response to conversation history
messages.add(toolResp.message)

assert toolResp.message.tool_calls.len != 0

# process the function call
assert toolResp.message.tool_calls.len == 1
let toolCall = toolResp.message.tool_calls[0]
let toolFunc = toolCall.function
assert toolFunc.name == "get_flight_times"
let toolFuncArgs = toolCall.function.arguments
assert toolFuncArgs["departure"].getStr == "NYC"
assert toolFuncArgs["arrival"].getStr == "LAX"

let toolResult = getFlightTimes(toolFuncArgs["departure"].getStr, toolFuncArgs["arrival"].getStr)
messages.add(ChatMessage(
role: "tool",
content: toolResult
))

# send the updated message history (now including the tool result)
# back to the model so it can compose a final natural-language answer
let finalResponse = ollama.chat(ChatReq(
model: TestModel,
messages: messages
))
echo "RESULT: " & finalResponse.message.content

0 comments on commit 888d094

Please sign in to comment.