Skip to content

Commit

Permalink
Base API URL added for embedding models
Browse files Browse the repository at this point in the history
Jupyter AI currently allows the user to call a model at a URL (location) different from the default one by specifying a selected Base API URL. This can be done for Ollama, OpenAI provider models. However, for these providers, there is no way to change the API URL for embedding models when using the `/learn` command in RAG mode. This PR adds an extra field to make this feasible.

Tested as follows for Ollama:
[1] Start the Ollama system from port 11435 instead 11434 (the default):
`OLLAMA_HOST=127.0.0.1:11435 ollama serve`
[2] Set the Base API URL:

[3] Check that the new API URL works:
  • Loading branch information
srdas committed Dec 4, 2024
1 parent 342bb7b commit db026c6
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from langchain_ollama import ChatOllama, OllamaEmbeddings

from ..embedding_providers import BaseEmbeddingsProvider
from ..providers import BaseProvider, EnvAuthStrategy, TextField
from ..providers import BaseProvider, TextField


class OllamaProvider(BaseProvider, ChatOllama):
Expand All @@ -23,10 +23,14 @@ class OllamaEmbeddingsProvider(BaseEmbeddingsProvider, OllamaEmbeddings):
id = "ollama"
name = "Ollama"
# source: https://ollama.com/library
model_id_key = "model"
models = [
"nomic-embed-text",
"mxbai-embed-large",
"all-minilm",
"snowflake-arctic-embed",
]
model_id_key = "model"
registry = True
fields = [
TextField(key="base_url", label="Base API URL (optional)", format="text"),
]
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings):
model_id_key = "model"
pypi_package_deps = ["langchain_openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
registry = True
fields = [
TextField(key="openai_api_base", label="Base API URL (optional)", format="text"),
]


class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbeddings):
Expand All @@ -122,5 +126,7 @@ class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbedding
auth_strategy = EnvAuthStrategy(
name="AZURE_OPENAI_API_KEY", keyword_param="openai_api_key"
)

registry = True
fields = [
TextField(key="azure_endpoint", label="Base API URL (optional)", format="text"),
]
47 changes: 28 additions & 19 deletions packages/jupyter-ai/src/components/chat-settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -376,26 +376,35 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {
{/* Embedding model section */}
<h2 className="jp-ai-ChatSettings-header">Embedding model</h2>
{server.emProviders.providers.length > 0 ? (
<Select
value={emGlobalId}
label="Embedding model"
onChange={e => {
const emGid = e.target.value === 'null' ? null : e.target.value;
setEmGlobalId(emGid);
}}
MenuProps={{ sx: { maxHeight: '50%', minHeight: 400 } }}
>
<MenuItem value="null">None</MenuItem>
{server.emProviders.providers.map(emp =>
emp.models
.filter(em => em !== '*') // TODO: support registry providers
.map(em => (
<MenuItem value={`${emp.id}:${em}`}>
{emp.name} :: {em}
</MenuItem>
))
<Box>
<Select
value={emGlobalId}
label="Embedding model"
onChange={e => {
const emGid = e.target.value === 'null' ? null : e.target.value;
setEmGlobalId(emGid);
}}
MenuProps={{ sx: { maxHeight: '50%', minHeight: 400 } }}
>
<MenuItem value="null">None</MenuItem>
{server.emProviders.providers.map(emp =>
emp.models
.filter(em => em !== '*') // TODO: support registry providers
.map(em => (
<MenuItem value={`${emp.id}:${em}`}>
{emp.name} :: {em}
</MenuItem>
))
)}
</Select>
{emGlobalId && (
<ModelFields
fields={emProvider?.fields}
values={fields}
onChange={setFields}
/>
)}
</Select>
</Box>
) : (
<p>No embedding models available.</p>
)}
Expand Down

0 comments on commit db026c6

Please sign in to comment.