Skip to content

Commit

Permalink
feat: add embeddings from Ollama
Browse files Browse the repository at this point in the history
  • Loading branch information
astappiev committed Nov 29, 2024
1 parent 0a32046 commit d262949
Show file tree
Hide file tree
Showing 18 changed files with 467 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@
import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;
import org.jboss.resteasy.reactive.common.util.RestMediaType;

import de.l3s.interweb.connector.ollama.entity.ChatBody;
import de.l3s.interweb.connector.ollama.entity.ChatResponse;
import de.l3s.interweb.connector.ollama.entity.ChatStreamBody;
import de.l3s.interweb.connector.ollama.entity.TagsResponse;
import de.l3s.interweb.connector.ollama.entity.*;
import de.l3s.interweb.core.ConnectorException;

@Path("")
Expand All @@ -35,6 +32,10 @@ public interface OllamaClient {
@Produces(RestMediaType.APPLICATION_NDJSON)
Multi<ChatResponse> chatStream(ChatStreamBody body);

@POST
@Path("/api/embed")
Uni<EmbedResponse> embed(EmbedBody body);

@GET
@Path("/api/tags")
Uni<TagsResponse> tags();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@
import org.eclipse.microprofile.rest.client.inject.RestClient;
import org.jboss.logging.Logger;

import de.l3s.interweb.connector.ollama.entity.ChatBody;
import de.l3s.interweb.connector.ollama.entity.ChatResponse;
import de.l3s.interweb.connector.ollama.entity.ChatStreamBody;
import de.l3s.interweb.connector.ollama.entity.*;
import de.l3s.interweb.core.ConnectorException;
import de.l3s.interweb.core.chat.ChatConnector;
import de.l3s.interweb.core.chat.CompletionsQuery;
import de.l3s.interweb.core.chat.CompletionsResults;
import de.l3s.interweb.core.embeddings.EmbeddingConnector;
import de.l3s.interweb.core.embeddings.EmbeddingsQuery;
import de.l3s.interweb.core.embeddings.EmbeddingsResults;
import de.l3s.interweb.core.models.Model;
import de.l3s.interweb.core.models.UsagePrice;

@Dependent
public class OllamaConnector implements ChatConnector {
public class OllamaConnector implements ChatConnector, EmbeddingConnector {
private static final Logger log = Logger.getLogger(OllamaConnector.class);

@RestClient
Expand Down Expand Up @@ -65,6 +66,12 @@ public Multi<CompletionsResults> completionsStream(CompletionsQuery query) throw
return ollama.chatStream(body).map(ChatResponse::toCompletionResults);
}

@Override
public Uni<EmbeddingsResults> embeddings(EmbeddingsQuery query) throws ConnectorException {
final EmbedBody body = new EmbedBody(query);
return ollama.embed(body).map(EmbedResponse::toCompletionResults);
}

@Override
public boolean validate() {
Optional<String> apikey = ConfigProvider.getConfig().getOptionalValue("connector.ollama.url", String.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package de.l3s.interweb.connector.ollama.entity;

import java.util.List;

import io.quarkus.runtime.annotations.RegisterForReflection;

import com.fasterxml.jackson.annotation.JsonInclude;

import de.l3s.interweb.core.embeddings.EmbeddingsQuery;

@RegisterForReflection
@JsonInclude(JsonInclude.Include.NON_NULL)
public class EmbedBody {
private String model;
private List<String> input;
private Boolean truncate;

public EmbedBody(EmbeddingsQuery query) {
this.model = query.getModel();
this.input = query.getInput();
this.truncate = query.getTruncate();
}

public String getModel() {
return model;
}

public List<String> getInput() {
return input;
}

public Boolean getTruncate() {
return truncate;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package de.l3s.interweb.connector.ollama.entity;

import java.util.List;

import io.quarkus.runtime.annotations.RegisterForReflection;

import com.fasterxml.jackson.annotation.JsonProperty;

import de.l3s.interweb.core.chat.Usage;
import de.l3s.interweb.core.embeddings.Embedding;
import de.l3s.interweb.core.embeddings.EmbeddingsResults;

@RegisterForReflection
public class EmbedResponse {
private String model;
private List<List<Double>> embeddings;
@JsonProperty("total_duration")
private Long totalDuration;
@JsonProperty("load_duration")
private Long loadDuration;
@JsonProperty("prompt_eval_count")
private Integer promptEvalCount;

public String getModel() {
return model;
}

public void setModel(String model) {
this.model = model;
}

public List<List<Double>> getEmbeddings() {
return embeddings;
}

public void setEmbeddings(List<List<Double>> embeddings) {
this.embeddings = embeddings;
}

public Long getTotalDuration() {
return totalDuration;
}

public void setTotalDuration(Long totalDuration) {
this.totalDuration = totalDuration;
}

public Long getLoadDuration() {
return loadDuration;
}

public void setLoadDuration(Long loadDuration) {
this.loadDuration = loadDuration;
}

public Integer getPromptEvalCount() {
return promptEvalCount;
}

public void setPromptEvalCount(Integer promptEvalCount) {
this.promptEvalCount = promptEvalCount;
}

public EmbeddingsResults toCompletionResults() {
EmbeddingsResults results = new EmbeddingsResults();
results.setModel(model);
results.setElapsedTime(totalDuration);

for (int i = 0; i < embeddings.size(); i++) {
Embedding embedding = new Embedding();
embedding.setEmbedding(embeddings.get(i));
embedding.setIndex(i);
results.addData(embedding);
}

Usage usage = new Usage();
usage.setPromptTokens(promptEvalCount);
results.setUsage(usage);
return results;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

import io.quarkus.runtime.annotations.RegisterForReflection;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;

@RegisterForReflection
public class UsageCost {
private double prompt;
private double completion;
@JsonInclude(JsonInclude.Include.ALWAYS)
private double total;
@JsonProperty("chat_total")
private double chatTotal;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package de.l3s.interweb.core.embeddings;

import java.util.List;

import io.quarkus.runtime.annotations.RegisterForReflection;

@RegisterForReflection
public class Embedding {
private final String object = "embedding";
private Integer index;
private List<Double> embedding;

public String getObject() {
return object;
}

public Integer getIndex() {
return index;
}

public void setIndex(Integer index) {
this.index = index;
}

public List<Double> getEmbedding() {
return embedding;
}

public void setEmbedding(List<Double> embedding) {
this.embedding = embedding;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package de.l3s.interweb.core.embeddings;

import io.smallrye.mutiny.Uni;

import de.l3s.interweb.core.Connector;
import de.l3s.interweb.core.ConnectorException;
import de.l3s.interweb.core.models.ModelsConnector;

public interface EmbeddingConnector extends ModelsConnector, Connector {

Uni<EmbeddingsResults> embeddings(EmbeddingsQuery query) throws ConnectorException;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package de.l3s.interweb.core.embeddings;

import java.util.List;

import jakarta.validation.constraints.NotEmpty;

import io.quarkus.runtime.annotations.RegisterForReflection;

import com.fasterxml.jackson.annotation.JsonFormat;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;

import de.l3s.interweb.core.util.StringOrArrayDeserializer;

@RegisterForReflection
@JsonInclude(JsonInclude.Include.NON_NULL)
public class EmbeddingsQuery {

/**
* ID of the model to use. Use `GET /models` to retrieve available models.
*/
@NotEmpty
@JsonProperty("model")
private String model;

/**
* Input text to embed, encoded as a string or array of tokens.
* To embed multiple inputs in a single request, pass an array of strings or array of token arrays.
*/
@NotEmpty
@JsonProperty("input")
@JsonDeserialize(using = StringOrArrayDeserializer.class)
@JsonFormat(shape = JsonFormat.Shape.ARRAY)
private List<String> input;

/**
* The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
*/
@JsonProperty("dimensions")
private Integer dimensions;

/**
* Truncates the end of each input to fit within context length. Returns error if false and context length is exceeded. Defaults to true
*/
@JsonProperty(value = "truncate")
private Boolean truncate;

public String getModel() {
return model;
}

public void setModel(String model) {
this.model = model;
}

public List<String> getInput() {
return input;
}

public void setInput(List<String> input) {
this.input = input;
}

public Integer getDimensions() {
return dimensions;
}

public void setDimensions(Integer dimensions) {
this.dimensions = dimensions;
}

public Boolean getTruncate() {
return truncate;
}

public void setTruncate(Boolean truncate) {
this.truncate = truncate;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package de.l3s.interweb.core.embeddings;

import java.util.ArrayList;
import java.util.List;

import io.quarkus.runtime.annotations.RegisterForReflection;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyOrder;

import de.l3s.interweb.core.chat.Usage;
import de.l3s.interweb.core.chat.UsageCost;

@RegisterForReflection
@JsonPropertyOrder({"object", "model", "data", "usage", "estimated_cost", "elapsed_time"})
public class EmbeddingsResults {
private final String object = "list";
private String model;
private Usage usage;
private List<Embedding> data;
@JsonProperty(value = "estimated_cost")
private UsageCost cost;
@JsonProperty("elapsed_time")
private Long elapsedTime;

public String getObject() {
return this.object;
}

public String getModel() {
return model;
}

public void setModel(String model) {
this.model = model;
}

public List<Embedding> getData() {
return data;
}

public void setData(List<Embedding> data) {
this.data = data;
}

public void addData(Embedding embedding) {
if (this.data == null) {
this.data = new ArrayList<>();
}

this.data.add(embedding);
}

public Usage getUsage() {
return usage;
}

public void setUsage(Usage usage) {
this.usage = usage;
}

public UsageCost getCost() {
return cost;
}

public void setCost(UsageCost cost) {
this.cost = cost;
}

public Long getElapsedTime() {
return elapsedTime;
}

public void setElapsedTime(Long elapsedTime) {
this.elapsedTime = elapsedTime;
}
}
Loading

0 comments on commit d262949

Please sign in to comment.