diff --git a/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/OpenaiClient.java b/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/OpenaiClient.java index f2e99c37..4f9402de 100644 --- a/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/OpenaiClient.java +++ b/connectors/OpenaiConnector/src/main/java/de/l3s/interweb/connector/openai/OpenaiClient.java @@ -20,6 +20,10 @@ @ClientHeaderParam(name = "api-key", value = "${connector.openai.apikey}") public interface OpenaiClient { + /** + * OpenAI Completion API + * https://learn.microsoft.com/en-us/azure/ai-services/openai/reference + */ @POST @Path("/{model}/chat/completions") Uni chatCompletions(@PathParam("model") String model, @QueryParam("api-version") String apiVersion, CompletionsBody body); diff --git a/interweb-client/src/main/java/de/l3s/interweb/client/Interweb.java b/interweb-client/src/main/java/de/l3s/interweb/client/Interweb.java index e4fdf45c..ab4d0dea 100644 --- a/interweb-client/src/main/java/de/l3s/interweb/client/Interweb.java +++ b/interweb-client/src/main/java/de/l3s/interweb/client/Interweb.java @@ -46,11 +46,11 @@ public Interweb(String serverUrl, String apikey) { } public SearchResults search(SearchQuery query) throws InterwebException { - return sendRequest("/search", query, SearchResults.class); + return sendPostRequest("/search", query, SearchResults.class); } public SuggestResults suggest(SuggestQuery query) throws InterwebException { - return sendRequest("/suggest", query, SuggestResults.class); + return sendPostRequest("/suggest", query, SuggestResults.class); } public SuggestResults suggest(String query, String language) throws InterwebException { @@ -58,30 +58,34 @@ public SuggestResults suggest(String query, String language) throws InterwebExce params.setQuery(query); params.setLanguage(language); - return sendRequest("/suggest", params, SuggestResults.class); + return sendPostRequest("/suggest", params, SuggestResults.class); } public DescribeResults describe(DescribeQuery query) throws InterwebException { - return sendRequest("/describe", query, DescribeResults.class); + return sendPostRequest("/describe", query, DescribeResults.class); } public DescribeResults describe(String link) throws InterwebException { final DescribeQuery params = new DescribeQuery(); params.setLink(link); - return sendRequest("/describe", params, DescribeResults.class); + return sendPostRequest("/describe", params, DescribeResults.class); } - public List conversations(String user) throws InterwebException { - return sendRequest("/chat", Map.of("user", user), new TypeReference<>() {}); + public CompletionResults completions(CompletionQuery query) throws InterwebException { + return sendPostRequest("/chat/completions", query, CompletionResults.class); } - public CompletionResults completion(CompletionQuery query) throws InterwebException { - return sendRequest("/chat/completions", query, CompletionResults.class); + public List chatAll(String user) throws InterwebException { + return sendGetRequest("/chat", Map.of("user", user), new TypeReference<>() {}); } - public void completion(Conversation conversation) throws InterwebException { - CompletionResults results = sendRequest("/chat/completions", conversation, CompletionResults.class); + public Conversation chatById(String uuid) throws InterwebException { + return sendGetRequest("/chat/" + uuid, null, new TypeReference<>() {}); + } + + public void chatComplete(Conversation conversation) throws InterwebException { + CompletionResults results = sendPostRequest("/chat/completions", conversation, CompletionResults.class); if (results.getLastMessage() != null) { conversation.addMessage(results.getLastMessage()); } @@ -116,7 +120,7 @@ private URI createRequestUri(final String apiPath, final Map par return URI.create(sb.toString()); } - public T sendRequest(final String apiPath, final Map params, TypeReference valueType) throws InterwebException { + public T sendGetRequest(final String apiPath, final Map params, TypeReference valueType) throws InterwebException { try { final URI uri = createRequestUri(apiPath, params); HttpRequest.Builder builder = HttpRequest.newBuilder().uri(uri).GET(); @@ -128,7 +132,7 @@ public T sendRequest(final String apiPath, final Map params, } } - public T sendRequest(final String apiPath, final Object query, Class valueType) throws InterwebException { + public T sendPostRequest(final String apiPath, final Object query, Class valueType) throws InterwebException { try { String body = mapper.writeValueAsString(query); diff --git a/interweb-client/src/test/java/de/l3s/interweb/client/InterwebCompletionTest.java b/interweb-client/src/test/java/de/l3s/interweb/client/InterwebCompletionTest.java index fff35930..8697ecf0 100644 --- a/interweb-client/src/test/java/de/l3s/interweb/client/InterwebCompletionTest.java +++ b/interweb-client/src/test/java/de/l3s/interweb/client/InterwebCompletionTest.java @@ -25,15 +25,7 @@ public InterwebCompletionTest(@ConfigProperty(name = "interweb.server") String s } @Test - void conversationsTest() throws InterwebException { - List response = interweb.conversations("user1"); - - assertFalse(response.isEmpty()); - assertFalse(response.get(0).getTitle().isEmpty()); - } - - @Test - void chatCompletionsTest() throws InterwebException { + void completionsTest() throws InterwebException { CompletionQuery query = new CompletionQuery(); query.setUser("user1"); query.setGenerateTitle(true); @@ -41,7 +33,7 @@ void chatCompletionsTest() throws InterwebException { query.addMessage("You are Interweb Assistant, a helpful chat bot.", Message.Role.system); query.addMessage("What is your name?.", Message.Role.user); - CompletionResults response = interweb.completion(query); + CompletionResults response = interweb.completions(query); assertFalse(response.getResults().isEmpty()); for (Choice result : response.getResults()) { @@ -51,30 +43,50 @@ void chatCompletionsTest() throws InterwebException { } @Test - void conversationTest() throws InterwebException { - Conversation query = new Conversation(); - query.setUser("user1"); - query.setGenerateTitle(true); - query.setModel("gpt-35-turbo"); - query.addMessage("You are Interweb Assistant, a helpful chat bot.", Message.Role.system); - query.addMessage("What is your name?.", Message.Role.user); + void chatAllTest() throws InterwebException { + List response = interweb.chatAll("user1"); + + assertFalse(response.isEmpty()); + assertFalse(response.get(0).getTitle().isEmpty()); + } + + @Test + void chatByIdTest() throws InterwebException { + Conversation conversation = interweb.chatById("ef235b94-09a0-4b0e-b1cb-c06b6a3adf6c"); + assertNotNull(conversation.getTitle()); + assertNotNull(conversation.getModel()); + + for (Message message : conversation.getMessages()) { + assertNotNull(message.getContent()); + System.out.println(message.getContent()); + } + } + + @Test + void chatCompleteTest() throws InterwebException { + Conversation conversation = new Conversation(); + conversation.setUser("user1"); + conversation.setGenerateTitle(true); + conversation.setModel("gpt-35-turbo"); + conversation.addMessage("You are Interweb Assistant, a helpful chat bot.", Message.Role.system); + conversation.addMessage("What is your name?.", Message.Role.user); - assertNull(query.getTitle()); - assertNull(query.getEstimatedCost()); - assertEquals(2, query.getMessages().size()); + assertNull(conversation.getTitle()); + assertNull(conversation.getEstimatedCost()); + assertEquals(2, conversation.getMessages().size()); - interweb.completion(query); + interweb.chatComplete(conversation); - assertNotNull(query.getTitle()); - assertNotNull(query.getEstimatedCost()); - assertEquals(3, query.getMessages().size()); + assertNotNull(conversation.getTitle()); + assertNotNull(conversation.getEstimatedCost()); + assertEquals(3, conversation.getMessages().size()); - query.addMessage("That's time now?", Message.Role.user); - interweb.completion(query); + conversation.addMessage("That's time now?", Message.Role.user); + interweb.chatComplete(conversation); - assertEquals(5, query.getMessages().size()); + assertEquals(5, conversation.getMessages().size()); - for (Message result : query.getMessages()) { + for (Message result : conversation.getMessages()) { System.out.println(result.getContent()); } } diff --git a/interweb-core/src/main/java/de/l3s/interweb/core/completion/CompletionQuery.java b/interweb-core/src/main/java/de/l3s/interweb/core/completion/CompletionQuery.java index acb5a4fb..20d4cb6c 100644 --- a/interweb-core/src/main/java/de/l3s/interweb/core/completion/CompletionQuery.java +++ b/interweb-core/src/main/java/de/l3s/interweb/core/completion/CompletionQuery.java @@ -53,7 +53,7 @@ public class CompletionQuery { @Min(0) @Max(2) @JsonProperty("temperature") - private Double temperature = 1.0; + private Double temperature; /** * An alternative to sampling with temperature, called nucleus sampling, where the model considers the results @@ -64,7 +64,7 @@ public class CompletionQuery { @Min(0) @Max(1) @JsonProperty("top_p") - private Double topP = 1.0; + private Double topP; /** * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, @@ -73,7 +73,7 @@ public class CompletionQuery { @Min(-2) @Max(2) @JsonProperty("frequency_penalty") - private Double frequencyPenalty = 0.0; + private Double frequencyPenalty; /** * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, @@ -82,7 +82,7 @@ public class CompletionQuery { @Min(-2) @Max(2) @JsonProperty("presence_penalty") - private Double presencePenalty = 0.0; + private Double presencePenalty; /** * The maximum number of tokens to generate in the chat completion. Defaults to 800. @@ -90,7 +90,7 @@ public class CompletionQuery { * The total length of input tokens and generated tokens is limited by the model's context length. */ @JsonProperty("max_tokens") - private Integer maxTokens = 800; + private Integer maxTokens; /** * Whether the conversation should be summarized into a title. Defaults to false. diff --git a/interweb-core/src/main/java/de/l3s/interweb/core/completion/Conversation.java b/interweb-core/src/main/java/de/l3s/interweb/core/completion/Conversation.java index 5b9b7efa..56a0d4d6 100644 --- a/interweb-core/src/main/java/de/l3s/interweb/core/completion/Conversation.java +++ b/interweb-core/src/main/java/de/l3s/interweb/core/completion/Conversation.java @@ -3,7 +3,9 @@ import java.time.Instant; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyOrder; +@JsonPropertyOrder({"id", "title", "model", "messages", "used_tokens", "estimated_cost", "created"}) public class Conversation extends CompletionQuery { @JsonProperty("title") private String title; diff --git a/interweb-server/src/main/java/de/l3s/interweb/server/ObjectMapperConfig.java b/interweb-server/src/main/java/de/l3s/interweb/server/ObjectMapperConfig.java index 1b77ad45..f2e8e32b 100644 --- a/interweb-server/src/main/java/de/l3s/interweb/server/ObjectMapperConfig.java +++ b/interweb-server/src/main/java/de/l3s/interweb/server/ObjectMapperConfig.java @@ -13,7 +13,7 @@ public class ObjectMapperConfig implements ObjectMapperCustomizer { public void customize(ObjectMapper config) { config.setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE); - config.setSerializationInclusion(JsonInclude.Include.NON_EMPTY); + config.setSerializationInclusion(JsonInclude.Include.NON_DEFAULT); config.registerModule(new JavaTimeModule()); } } diff --git a/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatMessage.java b/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatMessage.java index 7d1e7828..e7b01c85 100644 --- a/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatMessage.java +++ b/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatMessage.java @@ -45,12 +45,17 @@ public ChatMessage() { } public ChatMessage(final Message message) { + this.id = message.getId(); this.role = message.getRole(); this.content = message.getContent(); + this.created = message.getCreated(); } public Message toMessage() { - return new Message(role, content); + Message message = new Message(role, content); + message.setCreated(created); + message.setId(id); + return message; } public static Uni> listByChat(UUID id) { diff --git a/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatResource.java b/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatResource.java index 7d1aded3..d7ca78a5 100644 --- a/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatResource.java +++ b/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatResource.java @@ -16,10 +16,7 @@ import io.smallrye.mutiny.Uni; import org.hibernate.reactive.mutiny.Mutiny; -import de.l3s.interweb.core.completion.CompletionQuery; -import de.l3s.interweb.core.completion.CompletionResults; -import de.l3s.interweb.core.completion.Message; -import de.l3s.interweb.core.completion.UsagePrice; +import de.l3s.interweb.core.completion.*; import de.l3s.interweb.server.principal.Consumer; @Path("/chat") @@ -36,7 +33,7 @@ public class ChatResource { public Uni> chats(@QueryParam("user") String user) { Consumer consumer = securityIdentity.getCredential(Consumer.class); if (user == null) { - return Chat.list("consumer.id = ?1 AND user = NULL ORDER BY created DESC LIMIT 20", consumer.id); + return Chat.list("consumer.id = ?1 AND user IS NULL ORDER BY created DESC LIMIT 20", consumer.id); } else { return Chat.list("consumer.id = ?1 AND user = ?2 ORDER BY created DESC LIMIT 20", consumer.id, user); } @@ -45,8 +42,17 @@ public Uni> chats(@QueryParam("user") String user) { @GET @Authenticated @Path("{uuid}") - public Uni> chat(@PathParam("uuid") UUID id) { - return ChatMessage.listByChat(id); + public Uni chat(@PathParam("uuid") UUID id) { + return Chat.findById(id).call(chat -> Mutiny.fetch(chat.getMessages())).map(chat -> { + Conversation conversation = new Conversation(); + conversation.setId(chat.id); + conversation.setTitle(chat.title); + conversation.setUsedTokens(chat.usedTokens); + conversation.setEstimatedCost(chat.estimatedCost); + conversation.setCreated(chat.created); + conversation.setMessages(chat.getMessages().stream().map(ChatMessage::toMessage).collect(Collectors.toList())); + return conversation; + }); } @GET