Skip to content

Commit

Permalink
feat(completion): some bugfixes and improvements to openai & anthropic
Browse files Browse the repository at this point in the history
  • Loading branch information
astappiev committed Jun 13, 2024
1 parent 0b5a29b commit 4dc15c3
Show file tree
Hide file tree
Showing 12 changed files with 108 additions and 83 deletions.
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
package de.l3s.interweb.connector.anthropic;

import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;

import jakarta.enterprise.context.Dependent;

import io.smallrye.mutiny.Uni;
import org.eclipse.microprofile.config.ConfigProvider;
import org.eclipse.microprofile.rest.client.inject.RestClient;

import de.l3s.interweb.connector.anthropic.entity.AnthropicContent;
import de.l3s.interweb.connector.anthropic.entity.AnthropicUsage;
import de.l3s.interweb.connector.anthropic.entity.CompletionBody;
import de.l3s.interweb.core.ConnectorException;
Expand All @@ -20,13 +23,20 @@
import de.l3s.interweb.core.completion.Usage;
import de.l3s.interweb.core.completion.Choice;

import org.jboss.logging.Logger;

@Dependent
public class AnthropicConnector implements CompletionConnector {
private static final Logger log = Logger.getLogger(AnthropicConnector.class);

private static final Map<String, UsagePrice> models = Map.of(
"claude-3-opus-20240229", new UsagePrice(0.015, 0.075),
"claude-3-sonnet-20240229", new UsagePrice(0.003, 0.015),
"claude-3-haiku-20240307", new UsagePrice(0.00025, 0.00125)
"claude-3-haiku-20240307", new UsagePrice(0.00025, 0.00125),
// legacy
"claude-2.1", new UsagePrice(0.008, 0.024),
"claude-2.0", new UsagePrice(0.008, 0.024),
"claude-instant-1.2", new UsagePrice(0.0008, 0.0024)
);

@Override
Expand Down Expand Up @@ -60,21 +70,29 @@ public Uni<CompletionResults> complete(CompletionQuery query) throws ConnectorEx
anthropicUsage.getInputTokens(),
anthropicUsage.getOutputTokens()
);


AnthropicContent content = response.getContent().get(0);
Message message = new Message(Message.Role.user, content.getText());
Choice choice = new Choice(0, response.getStopReason(), message);


CompletionResults results = new CompletionResults(
query.getModel(),
usage,
choice,
Instant.now()
);


AtomicInteger index = new AtomicInteger();
List<Choice> choices = response.getContent().stream().map(content -> {
Message message = new Message(Message.Role.assistant, content.getText());
return new Choice(index.getAndIncrement(), response.getStopReason(), message);
}).toList();

CompletionResults results = new CompletionResults();
results.setModel(query.getModel());
results.setUsage(usage);
results.setChoices(choices);
results.setCreated(Instant.now());
return results;
});
}

/**
 * Checks that an Anthropic API key has been configured.
 *
 * <p>Reads {@code connector.anthropic.apikey} from the active configuration and
 * logs a warning when the value is missing or an empty string.
 *
 * @return {@code true} if a non-empty API key is present, {@code false} otherwise
 */
@Override
public boolean validate() {
    final boolean keyConfigured = ConfigProvider.getConfig()
        .getOptionalValue("connector.anthropic.apikey", String.class)
        .filter(key -> !key.isEmpty())
        .isPresent();

    if (keyConfigured) {
        return true;
    }

    log.warn("API key is empty, please provide a valid API key in the configuration.");
    return false;
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Required properties, recommended to set via environment variables (for tests, create .env in the root of the module)
connector.anthropic.url=
connector.anthropic.url=https://api.anthropic.com
connector.anthropic.apikey=
quarkus.rest-client.anthropic.url=${connector.anthropic.url}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package de.l3s.interweb.connector.anthropic;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.*;

import jakarta.inject.Inject;

Expand All @@ -9,16 +9,16 @@
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;

import de.l3s.interweb.connector.anthropic.entity.CompletionBody;
import de.l3s.interweb.core.ConnectorException;
import de.l3s.interweb.core.completion.Choice;
import de.l3s.interweb.core.completion.CompletionQuery;
import de.l3s.interweb.core.completion.CompletionResults;
import de.l3s.interweb.core.completion.Message;

import de.l3s.interweb.connector.anthropic.entity.CompletionBody;

@Disabled
@QuarkusTest
class AnthropicConnectorTest {
Expand All @@ -27,6 +27,11 @@ class AnthropicConnectorTest {
@Inject
AnthropicConnector connector;

/**
 * Asserts that the injected {@link AnthropicConnector} reports a valid configuration,
 * i.e. that {@code connector.validate()} returns {@code true}.
 */
@Test
void validate() throws ConnectorException {
assertTrue(connector.validate());
}

@Test
void complete() throws ConnectorException {
CompletionQuery query = new CompletionQuery();
Expand All @@ -35,19 +40,17 @@ void complete() throws ConnectorException {
query.setMaxTokens(100);
query.setModel("claude-3-haiku-20240307");


CompletionResults results = connector.complete(query).await().indefinitely();


assertEquals(1, results.getChoices().size());
System.out.println("Results for '" + query.getMessages().get(query.getMessages().size() - 1).getContent() + "':");
log.infov("user: {0}", query.getMessages().getLast().getContent());
for (Choice result : results.getChoices()) {
System.out.println(result.getMessage().getContent());
log.infov("assistant: {0}", result.getMessage().getContent());
}
}

@Test
void jsonBody() {
void jsonBody() throws JsonProcessingException {
CompletionQuery query = new CompletionQuery();
query.setModel("claude-3-haiku-20240307");
query.addMessage("You are Interweb Assistant, a helpful chat bot.", Message.Role.system);
Expand All @@ -60,14 +63,7 @@ void jsonBody() {
// Print body as json
ObjectMapper mapper = new ObjectMapper();

try {
String jsonString = mapper.writeValueAsString(body);
System.out.println(jsonString);
} catch (Exception e) {
e.printStackTrace();
}

String jsonString = mapper.writeValueAsString(body);
assertEquals("{\"messages\":[{\"role\":\"user\",\"content\":\"What is your name?.\"},{\"role\":\"assistant\",\"content\":\"My name is Interweb Assistant.\"},{\"role\":\"user\",\"content\":\"Hi Interweb Assistant, I am a user.\"}],\"model\":\"claude-3-haiku-20240307\",\"system\":\"You are Interweb Assistant, a helpful chat bot.\",\"max_tokens\":800}", jsonString);
}


}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package de.l3s.interweb.connector.openai;

import io.quarkus.rest.client.reactive.ClientQueryParam;

import jakarta.ws.rs.*;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
Expand All @@ -18,6 +20,7 @@
@Produces(MediaType.APPLICATION_JSON)
@RegisterRestClient(configKey = "openai")
@ClientHeaderParam(name = "api-key", value = "${connector.openai.apikey}")
@ClientQueryParam(name = "api-version", value = "2024-02-01")
public interface OpenaiClient {

/**
Expand All @@ -26,7 +29,7 @@ public interface OpenaiClient {
*/
@POST
@Path("/{model}/chat/completions")
Uni<CompletionResponse> chatCompletions(@PathParam("model") String model, @QueryParam("api-version") String apiVersion, CompletionBody body);
Uni<CompletionResponse> chatCompletions(@PathParam("model") String model, CompletionBody body);

@ClientExceptionMapper
static RuntimeException toException(Response response) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package de.l3s.interweb.connector.openai;

import java.util.Map;
import java.util.Optional;

import jakarta.enterprise.context.Dependent;

import io.smallrye.mutiny.Uni;
import org.eclipse.microprofile.config.ConfigProvider;
import org.eclipse.microprofile.rest.client.inject.RestClient;
import org.jboss.logging.Logger;

import de.l3s.interweb.connector.openai.entity.CompletionBody;
import de.l3s.interweb.core.ConnectorException;
Expand All @@ -16,15 +19,19 @@

@Dependent
public class OpenaiConnector implements CompletionConnector {
private static final Logger log = Logger.getLogger(OpenaiConnector.class);

private static final String version = "2023-05-15";
/**
* Prices are in US dollars for the UK South region (EUR prices are converted automatically from USD, so they fluctuate slightly).
* https://azure.microsoft.com/de-de/pricing/details/cognitive-services/openai-service/
*/
private static final Map<String, UsagePrice> models = Map.of(
"gpt-35-turbo", new UsagePrice(0.0014, 0.0019),
"gpt-35-turbo-16k", new UsagePrice(0.003, 0.004),
"gpt-35-turbo-1106", new UsagePrice(0.001, 0.002),
"gpt-4-turbo", new UsagePrice(0.010, 0.028),
"gpt-4", new UsagePrice(0.028, 0.055),
"gpt-4-32k", new UsagePrice(0.055, 0.109)
"gpt-35-turbo", new UsagePrice(0.002, 0.002),
"gpt-35-turbo-16k", new UsagePrice(0.003, 0.004),
"gpt-35-turbo-1106", new UsagePrice(0.001, 0.002),
"gpt-4-turbo", new UsagePrice(0.01, 0.03),
"gpt-4", new UsagePrice(0.03, 0.06),
"gpt-4-32k", new UsagePrice(0.06, 0.12)
);

@Override
Expand Down Expand Up @@ -52,15 +59,23 @@ public UsagePrice getPrice(String model) {

@Override
public Uni<CompletionResults> complete(CompletionQuery query) throws ConnectorException {
return openai.chatCompletions(query.getModel(), version, new CompletionBody(query)).map(response -> {
CompletionResults results = new CompletionResults(
query.getModel(),
response.getUsage(),
response.getChoices(),
response.getCreated()
);

return openai.chatCompletions(query.getModel(), new CompletionBody(query)).map(response -> {
CompletionResults results = new CompletionResults();
results.setModel(query.getModel());
results.setCreated(response.getCreated());
results.setChoices(response.getChoices());
results.setUsage(response.getUsage());
return results;
});
}

/**
 * Checks that a plausible OpenAI API key has been configured.
 *
 * <p>Reads {@code connector.openai.apikey} from the active configuration. The key is
 * rejected when it is missing, empty, or shorter than 32 characters; each case logs
 * its own warning so that a key which is set but too short is not reported as "empty".
 * The key value itself is never logged.
 *
 * @return {@code true} if a key of at least 32 characters is present, {@code false} otherwise
 */
@Override
public boolean validate() {
    Optional<String> apikey = ConfigProvider.getConfig().getOptionalValue("connector.openai.apikey", String.class);
    if (apikey.isEmpty() || apikey.get().isEmpty()) {
        log.warn("API key is empty, please provide a valid API key in the configuration.");
        return false;
    }

    // A key that is present but too short is a different misconfiguration than a missing one.
    if (apikey.get().length() < 32) {
        log.warn("API key is too short (expected at least 32 characters), please provide a valid API key in the configuration.");
        return false;
    }
    return true;
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Required properties, recommended to set via environment variables (for tests, create .env in the root of the module)
connector.openai.url=
connector.openai.url=https://YOUR_RESOURCE_NAME.openai.azure.com
connector.openai.apikey=
quarkus.rest-client.openai.url=${connector.openai.url}
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package de.l3s.interweb.connector.openai;

import static org.junit.jupiter.api.Assertions.assertEquals;

import jakarta.inject.Inject;

import io.quarkus.test.junit.QuarkusTest;
Expand All @@ -15,6 +13,8 @@
import de.l3s.interweb.core.completion.CompletionResults;
import de.l3s.interweb.core.completion.Message;

import static org.junit.jupiter.api.Assertions.*;

@Disabled
@QuarkusTest
class OpenaiConnectorTest {
Expand All @@ -23,6 +23,11 @@ class OpenaiConnectorTest {
@Inject
OpenaiConnector connector;

/**
 * Asserts that the injected {@link OpenaiConnector} reports a valid configuration,
 * i.e. that {@code connector.validate()} returns {@code true}.
 */
@Test
void validate() throws ConnectorException {
assertTrue(connector.validate());
}

@Test
void complete() throws ConnectorException {
CompletionQuery query = new CompletionQuery();
Expand All @@ -32,9 +37,9 @@ void complete() throws ConnectorException {
CompletionResults results = connector.complete(query).await().indefinitely();

assertEquals(1, results.getChoices().size());
System.out.println("Results for '" + query.getMessages().get(query.getMessages().size() - 1).getContent() + "':");
log.infov("user: {0}", query.getMessages().getLast().getContent());
for (Choice result : results.getChoices()) {
System.out.println(result.getMessage().getContent());
log.infov("assistant: {0}", result.getMessage().getContent());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,8 @@ default void fillResult(ConnectorResults results, long elapsedTime) {
results.setServiceUrl(getBaseUrl());
results.setElapsedTime(elapsedTime);
}

/**
 * Reports whether this connector considers itself usable.
 *
 * <p>This default implementation performs no checks and always returns {@code true};
 * implementations may override it to verify their own configuration.
 *
 * @return {@code true} in this default implementation
 */
default boolean validate() {
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ public class Choice extends ConnectorResults {
private String finishReason;
private Message message;

/**
 * Creates a choice without assigning any fields.
 * NOTE(review): presumably added so the class can be instantiated reflectively (e.g. by a JSON mapper) — confirm.
 */
public Choice() {
}

public Choice(int index, String finishReason, Message message) {
this.index = index;
this.finishReason = finishReason;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,32 +26,6 @@ public class CompletionResults extends Results<Choice> {
private UsageCost cost;
private Instant created;

public CompletionResults(
String model,
Usage usage,
Choice choice,
Instant created
) {
this.model = model;
this.usage = usage;
this.created = created;

add(choice);
}

public CompletionResults(
String model,
Usage usage,
List<Choice> choices,
Instant created
) {
this.model = model;
this.usage = usage;
this.created = created;

add(choices);
}

public UUID getChatId() {
return chatId;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ public class Usage {
@JsonProperty("total_tokens")
private int totalTokens;

/**
 * Creates a usage record without assigning any token counts.
 * NOTE(review): presumably added so the class can be instantiated reflectively (e.g. by a JSON mapper) — confirm.
 */
public Usage() {
}

public Usage(int promptTokens, int completionTokens) {
this.promptTokens = promptTokens;
this.completionTokens = completionTokens;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@ public class ChatService {
public ChatService(@All List<CompletionConnector> connectors) {
services = new HashMap<>();
connectors.forEach(connector -> {
for (String model : connector.getModels()) {
services.put(model, connector);
if (connector.validate()) {
for (String model : connector.getModels()) {
services.put(model, connector);
}
} else {
log.error("Connector skipped due to failed validation: " + connector.getClass().getName());
}
});

Expand Down

0 comments on commit 4dc15c3

Please sign in to comment.