diff --git a/packages/cli/src/nodehost.ts b/packages/cli/src/nodehost.ts
index 7623577bac..6a7c6d6be2 100644
--- a/packages/cli/src/nodehost.ts
+++ b/packages/cli/src/nodehost.ts
@@ -5,7 +5,7 @@ import { readFile, unlink, writeFile } from "node:fs/promises"
 import { ensureDir, existsSync, remove } from "fs-extra"
 import { resolve, dirname } from "node:path"
 import { glob } from "glob"
-import { debug, error, info, warn } from "./log"
+import { debug, error, info, isQuiet, warn } from "./log"
 import { execa } from "execa"
 import { join } from "node:path"
 import { createNodePath } from "./nodepath"
@@ -70,7 +70,8 @@ class ModelManager implements ModelService {
         if (provider === MODEL_PROVIDER_OLLAMA) {
             if (this.pulled.includes(modelid)) return { ok: true }
-            logVerbose(`ollama: pulling ${modelid}...`)
+            if (!isQuiet)
+                logVerbose(`ollama pull ${model}`)
             const conn = await this.getModelToken(modelid)
             const res = await fetch(`${conn.base}/api/pull`, {
                 method: "POST",
diff --git a/packages/core/src/models.test.ts b/packages/core/src/models.test.ts
index e384d25cf6..9edb8429f3 100644
--- a/packages/core/src/models.test.ts
+++ b/packages/core/src/models.test.ts
@@ -11,29 +11,37 @@ import {
 // generate unit tests for parseModelIdentifier
 describe("parseModelIdentifier", () => {
     test("aici:gpt-3.5:en", () => {
-        const { provider, model, tag, modelId } =
+        const { provider, model, tag, family } =
             parseModelIdentifier("aici:gpt-3.5:en")
         assert(provider === MODEL_PROVIDER_AICI)
-        assert(model === "gpt-3.5")
+        assert(family === "gpt-3.5")
         assert(tag === "en")
-        assert(modelId === "gpt-3.5:en")
+        assert(model === "gpt-3.5:en")
     })
     test("ollama:phi3", () => {
-        const { provider, model, tag, modelId } =
+        const { provider, model, tag, family } =
             parseModelIdentifier("ollama:phi3")
         assert(provider === MODEL_PROVIDER_OLLAMA)
         assert(model === "phi3")
-        assert(modelId === "phi3")
+        assert(family === "phi3")
+    })
+    test("ollama:gemma2:2b", () => {
+        const { provider, model, tag, family } =
+            parseModelIdentifier("ollama:gemma2:2b")
+        assert(provider === MODEL_PROVIDER_OLLAMA)
+        assert(model === "gemma2:2b")
+        assert(family === "gemma2")
     })
     test("llamafile", () => {
-        const { provider, model } = parseModelIdentifier("llamafile")
+        const { provider, model, family } = parseModelIdentifier("llamafile")
         assert(provider === MODEL_PROVIDER_LLAMAFILE)
-        assert(model === "*")
+        assert(family === "*")
+        assert(model === "llamafile")
     })
     test("gpt4", () => {
-        const { provider, model, modelId } = parseModelIdentifier("gpt4")
+        const { provider, model, family } = parseModelIdentifier("gpt4")
         assert(provider === MODEL_PROVIDER_OPENAI)
         assert(model === "gpt4")
-        assert(modelId === "gpt4")
+        assert(family === "gpt4")
     })
 })
diff --git a/packages/core/src/models.ts b/packages/core/src/models.ts
index 5c7ce7e823..39dcfb96af 100644
--- a/packages/core/src/models.ts
+++ b/packages/core/src/models.ts
@@ -9,22 +9,27 @@ import { assert } from "./util"
  * provider:model
  * provider:model:tag where modelId model:tag
  */
-export function parseModelIdentifier(id: string) {
+export function parseModelIdentifier(id: string): {
+    provider: string
+    family: string
+    model: string
+    tag?: string
+} {
     assert(!!id)
     id = id.replace("-35-", "-3.5-")
     const parts = id.split(":")
     if (parts.length >= 3)
         return {
             provider: parts[0],
-            model: parts[1],
+            family: parts[1],
             tag: parts.slice(2).join(":"),
-            modelId: parts.slice(1).join(":"),
+            model: parts.slice(1).join(":"),
         }
     else if (parts.length === 2)
-        return { provider: parts[0], model: parts[1], modelId: parts[1] }
+        return { provider: parts[0], family: parts[1], model: parts[1] }
     else if (id === MODEL_PROVIDER_LLAMAFILE)
-        return { provider: MODEL_PROVIDER_LLAMAFILE, model: "*", modelId: id }
-    else return { provider: MODEL_PROVIDER_OPENAI, model: id, modelId: id }
+        return { provider: MODEL_PROVIDER_LLAMAFILE, family: "*", model: id }
+    else return { provider: MODEL_PROVIDER_OPENAI, family: id, model: id }
 }
 
 export interface ModelConnectionInfo
diff --git a/packages/core/src/ollama.ts b/packages/core/src/ollama.ts
index ff5c50921e..c224d4114f 100644
--- a/packages/core/src/ollama.ts
+++ b/packages/core/src/ollama.ts
@@ -16,23 +16,23 @@ export const OllamaCompletion: ChatCompletionHandler = async (
         return await OpenAIChatCompletion(req, cfg, options, trace)
     } catch (e) {
         if (isRequestError(e)) {
-            const { modelId } = parseModelIdentifier(req.model)
+            const { model } = parseModelIdentifier(req.model)
             if (
                 e.status === 404 &&
                 e.body?.type === "api_error" &&
-                e.body?.message?.includes(`model '${modelId}' not found`)
+                e.body?.message?.includes(`model '${model}' not found`)
             ) {
-                trace.log(`model ${modelId} not found, trying to pull it`)
+                trace.log(`model ${model} not found, trying to pull it`)
                 // model not installed locally
                 // trim v1
                 const fetch = await createFetch({ trace })
                 const res = await fetch(cfg.base.replace("/v1", "/api/pull"), {
                     method: "POST",
-                    body: JSON.stringify({ name: modelId, stream: false }),
+                    body: JSON.stringify({ name: model, stream: false }),
                 })
                 if (!res.ok) {
                     throw new Error(
-                        `Failed to pull model ${modelId}: ${res.status} ${res.statusText}`
+                        `Failed to pull model ${model}: ${res.status} ${res.statusText}`
                     )
                 }
                 trace.log(`model pulled`)
diff --git a/packages/sample/genaisrc/summarize-ollama-gemma2.genai.js b/packages/sample/genaisrc/summarize-ollama-gemma2.genai.js
new file mode 100644
index 0000000000..7283f0ec81
--- /dev/null
+++ b/packages/sample/genaisrc/summarize-ollama-gemma2.genai.js
@@ -0,0 +1,15 @@
+script({
+    model: "ollama:gemma2:2b",
+    title: "summarize with ollama gemma 2 2b",
+    system: [],
+    files: "src/rag/markdown.md",
+    tests: {
+        files: "src/rag/markdown.md",
+        keywords: "markdown",
+    },
+})
+
+const file = def("FILE", env.files)
+
+$`Summarize ${file} in a sentence. Make it short.
+`