Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gemma2:2b test and ollama pull format as cli #599

Merged
merged 4 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/cli/src/nodehost.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class ModelManager implements ModelService {
if (provider === MODEL_PROVIDER_OLLAMA) {
if (this.pulled.includes(modelid)) return { ok: true }

logVerbose(`ollama: pulling ${modelid}...`)
logVerbose(`ollama pull ${model}`)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable model is not defined in the current scope. It seems like you meant to use modelid instead. πŸ€”

generated by pr-review-commit variable_name

const conn = await this.getModelToken(modelid)
const res = await fetch(`${conn.base}/api/pull`, {
method: "POST",
Expand Down
26 changes: 17 additions & 9 deletions packages/core/src/models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,37 @@ import {
// generate unit tests for parseModelIdentifier
describe("parseModelIdentifier", () => {
test("aici:gpt-3.5:en", () => {
const { provider, model, tag, modelId } =
const { provider, model, tag, family } =
parseModelIdentifier("aici:gpt-3.5:en")
assert(provider === MODEL_PROVIDER_AICI)
assert(model === "gpt-3.5")
assert(family === "gpt-3.5")
assert(tag === "en")
assert(modelId === "gpt-3.5:en")
assert(model === "gpt-3.5:en")
})
test("ollama:phi3", () => {
const { provider, model, tag, modelId } =
const { provider, model, tag, family } =
parseModelIdentifier("ollama:phi3")
assert(provider === MODEL_PROVIDER_OLLAMA)
assert(model === "phi3")
assert(modelId === "phi3")
assert(family === "phi3")
})
test("ollama:gemma2:2b", () => {
const { provider, model, tag, family } =
parseModelIdentifier("ollama:gemma2:2b")
assert(provider === MODEL_PROVIDER_OLLAMA)
assert(model === "gemma2:2b")
assert(family === "gemma2")
})
test("llamafile", () => {
const { provider, model } = parseModelIdentifier("llamafile")
const { provider, model, family } = parseModelIdentifier("llamafile")
assert(provider === MODEL_PROVIDER_LLAMAFILE)
assert(model === "*")
assert(family === "*")
assert(model === "llamafile")
})
test("gpt4", () => {
const { provider, model, modelId } = parseModelIdentifier("gpt4")
const { provider, model, family } = parseModelIdentifier("gpt4")
assert(provider === MODEL_PROVIDER_OPENAI)
assert(model === "gpt4")
assert(modelId === "gpt4")
assert(family === "gpt4")
})
})
17 changes: 11 additions & 6 deletions packages/core/src/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,27 @@ import { assert } from "./util"
* provider:model
* provider:model:tag where modelId model:tag
*/
export function parseModelIdentifier(id: string) {
export function parseModelIdentifier(id: string): {
provider: string
family: string
model: string
tag?: string
} {
assert(!!id)
id = id.replace("-35-", "-3.5-")
const parts = id.split(":")
if (parts.length >= 3)
return {
provider: parts[0],
model: parts[1],
family: parts[1],
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable model has been changed to family which might not be defined in this scope. Please ensure that family is defined and holds the correct value.

generated by pr-review-commit variable_name

tag: parts.slice(2).join(":"),
modelId: parts.slice(1).join(":"),
model: parts.slice(1).join(":"),
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tag field is not being returned when the length of parts is greater than or equal to 3. This could lead to unexpected behavior if the tag field is expected in the returned object. πŸ˜•

generated by pr-review-commit missing_tag

}
else if (parts.length === 2)
return { provider: parts[0], model: parts[1], modelId: parts[1] }
return { provider: parts[0], family: parts[1], model: parts[1] }
else if (id === MODEL_PROVIDER_LLAMAFILE)
return { provider: MODEL_PROVIDER_LLAMAFILE, model: "*", modelId: id }
else return { provider: MODEL_PROVIDER_OPENAI, model: id, modelId: id }
return { provider: MODEL_PROVIDER_LLAMAFILE, family: "*", model: id }
else return { provider: MODEL_PROVIDER_OPENAI, family: id, model: id }
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable name change from model to family and modelId to model could potentially cause confusion and bugs in the future. It's important to ensure that variable names accurately represent the data they hold. 😊

generated by pr-review-commit variable_name_change

}

export interface ModelConnectionInfo
Expand Down
10 changes: 5 additions & 5 deletions packages/core/src/ollama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@ export const OllamaCompletion: ChatCompletionHandler = async (
return await OpenAIChatCompletion(req, cfg, options, trace)
} catch (e) {
if (isRequestError(e)) {
const { modelId } = parseModelIdentifier(req.model)
const { model } = parseModelIdentifier(req.model)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable modelId has been changed to model which might not be defined in this scope. Please ensure that model is defined and holds the correct value.

generated by pr-review-commit variable_name

if (
e.status === 404 &&
e.body?.type === "api_error" &&
e.body?.message?.includes(`model '${modelId}' not found`)
e.body?.message?.includes(`model '${model}' not found`)
) {
trace.log(`model ${modelId} not found, trying to pull it`)
trace.log(`model ${model} not found, trying to pull it`)
// model not installed locally
// trim v1
const fetch = await createFetch({ trace })
const res = await fetch(cfg.base.replace("/v1", "/api/pull"), {
method: "POST",
body: JSON.stringify({ name: modelId, stream: false }),
body: JSON.stringify({ name: model, stream: false }),
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The API endpoint has been changed from "/v1" to "/api/pull". Please ensure that this new endpoint is correct and the server is configured to handle requests at this endpoint. πŸ€”

generated by pr-review-commit api_endpoint_change

})
if (!res.ok) {
throw new Error(
`Failed to pull model ${modelId}: ${res.status} ${res.statusText}`
`Failed to pull model ${model}: ${res.status} ${res.statusText}`
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable name change from modelId to model could potentially cause confusion and bugs in the future. It's important to ensure that variable names accurately represent the data they hold. 😊

generated by pr-review-commit variable_name_change

)
}
trace.log(`model pulled`)
Expand Down
15 changes: 15 additions & 0 deletions packages/sample/genaisrc/summarize-ollama-gemma2.genai.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
script({
model: "ollama:gemma2:2b",
title: "summarize with ollama gemma 2 2b",
system: [],
files: "src/rag/markdown.md",
tests: {
files: "src/rag/markdown.md",
keywords: "markdown",
},
})

const file = def("FILE", env.files)

$`Summarize ${file} in a sentence. Make it short.
`
Loading