Skip to content

Commit

Permalink
Add gemma2:2b test and ollama pull format as cli (#599)
Browse files Browse the repository at this point in the history
  • Loading branch information
pelikhan authored Aug 1, 2024
1 parent e5c0e75 commit 2070fed
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 22 deletions.
5 changes: 3 additions & 2 deletions packages/cli/src/nodehost.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { readFile, unlink, writeFile } from "node:fs/promises"
import { ensureDir, existsSync, remove } from "fs-extra"
import { resolve, dirname } from "node:path"
import { glob } from "glob"
import { debug, error, info, warn } from "./log"
import { debug, error, info, isQuiet, warn } from "./log"
import { execa } from "execa"
import { join } from "node:path"
import { createNodePath } from "./nodepath"
Expand Down Expand Up @@ -70,7 +70,8 @@ class ModelManager implements ModelService {
if (provider === MODEL_PROVIDER_OLLAMA) {
if (this.pulled.includes(modelid)) return { ok: true }

logVerbose(`ollama: pulling ${modelid}...`)
if (!isQuiet)
logVerbose(`ollama pull ${model}`)
const conn = await this.getModelToken(modelid)
const res = await fetch(`${conn.base}/api/pull`, {
method: "POST",
Expand Down
26 changes: 17 additions & 9 deletions packages/core/src/models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,37 @@ import {
// generate unit tests for parseModelIdentifier
describe("parseModelIdentifier", () => {
test("aici:gpt-3.5:en", () => {
const { provider, model, tag, modelId } =
const { provider, model, tag, family } =
parseModelIdentifier("aici:gpt-3.5:en")
assert(provider === MODEL_PROVIDER_AICI)
assert(model === "gpt-3.5")
assert(family === "gpt-3.5")
assert(tag === "en")
assert(modelId === "gpt-3.5:en")
assert(model === "gpt-3.5:en")
})
test("ollama:phi3", () => {
const { provider, model, tag, modelId } =
const { provider, model, tag, family } =
parseModelIdentifier("ollama:phi3")
assert(provider === MODEL_PROVIDER_OLLAMA)
assert(model === "phi3")
assert(modelId === "phi3")
assert(family === "phi3")
})
test("ollama:gemma2:2b", () => {
const { provider, model, tag, family } =
parseModelIdentifier("ollama:gemma2:2b")
assert(provider === MODEL_PROVIDER_OLLAMA)
assert(model === "gemma2:2b")
assert(family === "gemma2")
})
test("llamafile", () => {
const { provider, model } = parseModelIdentifier("llamafile")
const { provider, model, family } = parseModelIdentifier("llamafile")
assert(provider === MODEL_PROVIDER_LLAMAFILE)
assert(model === "*")
assert(family === "*")
assert(model === "llamafile")
})
test("gpt4", () => {
const { provider, model, modelId } = parseModelIdentifier("gpt4")
const { provider, model, family } = parseModelIdentifier("gpt4")
assert(provider === MODEL_PROVIDER_OPENAI)
assert(model === "gpt4")
assert(modelId === "gpt4")
assert(family === "gpt4")
})
})
17 changes: 11 additions & 6 deletions packages/core/src/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,27 @@ import { assert } from "./util"
* provider:model
* provider:model:tag where modelId model:tag
*/
export function parseModelIdentifier(id: string) {
export function parseModelIdentifier(id: string): {
provider: string
family: string
model: string
tag?: string
} {
assert(!!id)
id = id.replace("-35-", "-3.5-")
const parts = id.split(":")
if (parts.length >= 3)
return {
provider: parts[0],
model: parts[1],
family: parts[1],
tag: parts.slice(2).join(":"),
modelId: parts.slice(1).join(":"),
model: parts.slice(1).join(":"),
}
else if (parts.length === 2)
return { provider: parts[0], model: parts[1], modelId: parts[1] }
return { provider: parts[0], family: parts[1], model: parts[1] }
else if (id === MODEL_PROVIDER_LLAMAFILE)
return { provider: MODEL_PROVIDER_LLAMAFILE, model: "*", modelId: id }
else return { provider: MODEL_PROVIDER_OPENAI, model: id, modelId: id }
return { provider: MODEL_PROVIDER_LLAMAFILE, family: "*", model: id }
else return { provider: MODEL_PROVIDER_OPENAI, family: id, model: id }
}

export interface ModelConnectionInfo
Expand Down
10 changes: 5 additions & 5 deletions packages/core/src/ollama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@ export const OllamaCompletion: ChatCompletionHandler = async (
return await OpenAIChatCompletion(req, cfg, options, trace)
} catch (e) {
if (isRequestError(e)) {
const { modelId } = parseModelIdentifier(req.model)
const { model } = parseModelIdentifier(req.model)
if (
e.status === 404 &&
e.body?.type === "api_error" &&
e.body?.message?.includes(`model '${modelId}' not found`)
e.body?.message?.includes(`model '${model}' not found`)
) {
trace.log(`model ${modelId} not found, trying to pull it`)
trace.log(`model ${model} not found, trying to pull it`)
// model not installed locally
// trim v1
const fetch = await createFetch({ trace })
const res = await fetch(cfg.base.replace("/v1", "/api/pull"), {
method: "POST",
body: JSON.stringify({ name: modelId, stream: false }),
body: JSON.stringify({ name: model, stream: false }),
})
if (!res.ok) {
throw new Error(
`Failed to pull model ${modelId}: ${res.status} ${res.statusText}`
`Failed to pull model ${model}: ${res.status} ${res.statusText}`
)
}
trace.log(`model pulled`)
Expand Down
15 changes: 15 additions & 0 deletions packages/sample/genaisrc/summarize-ollama-gemma2.genai.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
script({
model: "ollama:gemma2:2b",
title: "summarize with ollama gemma 2 2b",
system: [],
files: "src/rag/markdown.md",
tests: {
files: "src/rag/markdown.md",
keywords: "markdown",
},
})

const file = def("FILE", env.files)

$`Summarize ${file} in a sentence. Make it short.
`

0 comments on commit 2070fed

Please sign in to comment.