Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simple OpenAI API #47

Merged
merged 15 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,18 +181,29 @@ See [an example pom in our Spring Boot application](e2e-test-app/pom.xml)

### Simple chat completion

```java
final OpenAiChatCompletionOutput result =
OpenAiClient.forModel(GPT_35_TURBO)
.withSystemPrompt("You are a helpful AI")
.chatCompletion("Hello World! Why is this phrase so famous?");

final String resultMessage = result.getContent();
```

### Message history

```java
final var systemMessage =
new OpenAiChatSystemMessage().setContent("You are a helpful assistant");
final var userMessage =
new OpenAiChatUserMessage().addText("Hello World! Why is this phrase so famous?");
final var request =
new OpenAiChatCompletionParameters().setMessages(List.of(systemMessage, userMessage));
new OpenAiChatCompletionParameters().addMessages(systemMessage, userMessage);

final OpenAiChatCompletionOutput result =
OpenAiClient.forModel(GPT_35_TURBO).chatCompletion(request);

final String resultMessage = result.getChoices().get(0).getMessage().getContent();
final String resultMessage = result.getContent();
```

See [an example in our Spring Boot application](e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java)
Expand All @@ -213,14 +224,10 @@ This is a blocking example for streaming and printing directly to the console:
```java
String msg = "Can you give me the first 100 numbers of the Fibonacci sequence?";

OpenAiChatCompletionParameters request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText(msg)));

OpenAiClient client = OpenAiClient.forModel(GPT_35_TURBO);

// try-with-resources on stream ensures the connection will be closed
try( Stream<String> stream = client.streamChatCompletion(request)) {
try( Stream<String> stream = client.streamChatCompletion(msg)) {
stream.forEach(deltaString -> {
System.out.print(deltaString);
System.out.flush();
Expand All @@ -239,7 +246,7 @@ String msg = "Can you give me the first 100 numbers of the Fibonacci sequence?";

OpenAiChatCompletionParameters request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText(msg)));
.addMessages(new OpenAiChatUserMessage().addText(msg));

OpenAiChatCompletionOutput totalOutput = new OpenAiChatCompletionOutput();
OpenAiClient client = OpenAiClient.forModel(GPT_35_TURBO);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@ class OpenAiController {
@GetMapping("/chatCompletion")
@Nonnull
public static OpenAiChatCompletionOutput chatCompletion() {
final var request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText("Who is the prettiest")));

return OpenAiClient.forModel(GPT_35_TURBO).chatCompletion(request);
return OpenAiClient.forModel(GPT_35_TURBO).chatCompletion("Who is the prettiest");
CharlesDuboisSAP marked this conversation as resolved.
Show resolved Hide resolved
}

/**
Expand All @@ -59,8 +55,7 @@ public static OpenAiChatCompletionOutput chatCompletion() {
public static ResponseEntity<ResponseBodyEmitter> streamChatCompletionDeltas() {
final var msg = "Can you give me the first 100 numbers of the Fibonacci sequence?";
final var request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText(msg)));
new OpenAiChatCompletionParameters().addMessages(new OpenAiChatUserMessage().addText(msg));

final var stream = OpenAiClient.forModel(GPT_35_TURBO).streamChatCompletionDeltas(request);

Expand Down Expand Up @@ -103,15 +98,11 @@ private static String objectToJson(@Nonnull final Object obj) {
@GetMapping("/streamChatCompletion")
@Nonnull
public static ResponseEntity<ResponseBodyEmitter> streamChatCompletion() {
final var request =
new OpenAiChatCompletionParameters()
.setMessages(
List.of(
new OpenAiChatUserMessage()
.addText(
"Can you give me the first 100 numbers of the Fibonacci sequence?")));

final var stream = OpenAiClient.forModel(GPT_35_TURBO).streamChatCompletion(request);
final var stream =
OpenAiClient.forModel(GPT_35_TURBO)
.withSystemPrompt("Be a good, honest AI and answer the following question:")
.streamChatCompletion(
"Can you give me the first 100 numbers of the Fibonacci sequence?");

final var emitter = new ResponseBodyEmitter();

Expand Down Expand Up @@ -150,13 +141,12 @@ private static void send(
public static OpenAiChatCompletionOutput chatCompletionImage() {
final var request =
new OpenAiChatCompletionParameters()
.setMessages(
List.of(
new OpenAiChatUserMessage()
.addText("Describe the following image.")
.addImage(
"https://upload.wikimedia.org/wikipedia/commons/thumb/5/59/SAP_2011_logo.svg/440px-SAP_2011_logo.svg.png",
ImageDetailLevel.HIGH)));
.addMessages(
new OpenAiChatUserMessage()
.addText("Describe the following image.")
.addImage(
"https://upload.wikimedia.org/wikipedia/commons/thumb/5/59/SAP_2011_logo.svg/440px-SAP_2011_logo.svg.png",
ImageDetailLevel.HIGH));

return OpenAiClient.forModel(GPT_4O).chatCompletion(request);
}
Expand All @@ -180,7 +170,7 @@ public static OpenAiChatCompletionOutput chatCompletionTools() {
final var tool = new OpenAiChatCompletionTool().setType(FUNCTION).setFunction(function);
final var request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText(question)))
.addMessages(new OpenAiChatUserMessage().addText(question))
.setTools(List.of(tool))
.setToolChoiceFunction("fibonacci");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -36,7 +35,7 @@ void chatCompletionImage() {
void streamChatCompletion() {
final var request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText("Who is the prettiest?")));
.addMessages(new OpenAiChatUserMessage().addText("Who is the prettiest?"));

final var totalOutput = new OpenAiChatCompletionOutput();
final var filledDeltaCount = new AtomicInteger(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatSystemMessage;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.StreamedDelta;
Expand All @@ -34,6 +36,7 @@
public final class OpenAiClient {
private static final String DEFAULT_API_VERSION = "2024-02-01";
static final ObjectMapper JACKSON;
private String systemPrompt = null;

static {
JACKSON =
Expand Down Expand Up @@ -95,6 +98,36 @@ public static OpenAiClient withCustomDestination(@Nonnull final Destination dest
return new OpenAiClient(destination);
}

/**
* Add a system prompt before user prompts.
*
* @param systemPrompt the system prompt
* @return the client
*/
@Nonnull
public OpenAiClient withSystemPrompt(@Nonnull final String systemPrompt) {
this.systemPrompt = systemPrompt;
return this;
}

/**
* Generate a completion for the given user prompt.
*
* @param prompt a text message.
* @return the completion output
* @throws OpenAiClientException if the request fails
*/
@Nonnull
public OpenAiChatCompletionOutput chatCompletion(@Nonnull final String prompt)
throws OpenAiClientException {
final OpenAiChatCompletionParameters parameters = new OpenAiChatCompletionParameters();
if (systemPrompt != null) {
parameters.addMessages(new OpenAiChatSystemMessage().setContent(systemPrompt));
}
parameters.addMessages(new OpenAiChatUserMessage().addText(prompt));
CharlesDuboisSAP marked this conversation as resolved.
Show resolved Hide resolved
return chatCompletion(parameters);
}

/**
CharlesDuboisSAP marked this conversation as resolved.
Show resolved Hide resolved
* Generate a completion for the given prompt.
*
Expand All @@ -105,19 +138,25 @@ public static OpenAiClient withCustomDestination(@Nonnull final Destination dest
@Nonnull
public OpenAiChatCompletionOutput chatCompletion(
@Nonnull final OpenAiChatCompletionParameters parameters) throws OpenAiClientException {
warnIfUnsupportedUsage();
return execute("/chat/completions", parameters, OpenAiChatCompletionOutput.class);
}

/**
* Generate a completion for the given prompt.
*
* @param parameters the prompt, including messages and other parameters.
* @param prompt a text message.
* @return A stream of message deltas
* @throws OpenAiClientException if the request fails or if the finish reason is content_filter
*/
@Nonnull
public Stream<String> streamChatCompletion(
@Nonnull final OpenAiChatCompletionParameters parameters) throws OpenAiClientException {
public Stream<String> streamChatCompletion(@Nonnull final String prompt)
throws OpenAiClientException {
final OpenAiChatCompletionParameters parameters = new OpenAiChatCompletionParameters();
if (systemPrompt != null) {
parameters.addMessages(new OpenAiChatSystemMessage().setContent(systemPrompt));
}
parameters.addMessages(new OpenAiChatUserMessage().addText(prompt));
return streamChatCompletionDeltas(parameters)
.peek(OpenAiClient::throwOnContentFilter)
.map(OpenAiChatCompletionDelta::getDeltaContent);
Expand All @@ -140,10 +179,18 @@ private static void throwOnContentFilter(@Nonnull final OpenAiChatCompletionDelt
@Nonnull
public Stream<OpenAiChatCompletionDelta> streamChatCompletionDeltas(
@Nonnull final OpenAiChatCompletionParameters parameters) throws OpenAiClientException {
warnIfUnsupportedUsage();
parameters.enableStreaming();
return executeStream("/chat/completions", parameters, OpenAiChatCompletionDelta.class);
}

private void warnIfUnsupportedUsage() {
if (systemPrompt != null) {
log.warn(
"Previously set messages will be ignored, set it as an argument of this method instead.");
}
}

/**
* Get a vector representation of a given input that can be easily consumed by machine learning
* models and algorithms.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package com.sap.ai.sdk.foundationmodels.openai.model;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClientException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import javax.annotation.Nonnull;
import lombok.EqualsAndHashCode;
import lombok.Getter;
Expand All @@ -28,6 +30,25 @@ public class OpenAiChatCompletionOutput extends OpenAiCompletionOutput
@Getter(onMethod_ = @Nonnull)
private String systemFingerprint;

/**
* Get the message content from the output.
*
* <p>Note: If there are multiple choices only the first one is returned
*
* @return the message content or empty string.
* @throws OpenAiClientException if the content filter filtered the output.
*/
@Nonnull
public String getContent() throws OpenAiClientException {
if (getChoices().isEmpty()) {
return "";
}
if ("content_filter".equals(getChoices().get(0).getFinishReason())) {
throw new OpenAiClientException("Content filter filtered the output.");
}
return Objects.requireNonNullElse(getChoices().get(0).getMessage().getContent(), "");
CharlesDuboisSAP marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Add a streamed delta to the total output.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.JsonValue;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.annotation.Nonnull;
Expand All @@ -25,7 +26,6 @@
public class OpenAiChatCompletionParameters extends OpenAiCompletionParameters {
/** A list of messages comprising the conversation so far. */
@JsonProperty("messages")
@Setter(onParam_ = @Nonnull)
private List<OpenAiChatMessage> messages;

/**
Expand Down Expand Up @@ -197,4 +197,19 @@ private record Function(@JsonProperty("name") @Nonnull String name) {}
public OpenAiChatCompletionParameters setStop(@Nullable final String... values) {
return (OpenAiChatCompletionParameters) super.setStop(values);
}

/**
* Add messages to the conversation.
*
* @param messages The messages to add.
* @return this instance for chaining.
*/
@Nonnull
public OpenAiChatCompletionParameters addMessages(@Nonnull final OpenAiChatMessage... messages) {
if (this.messages == null) {
this.messages = new ArrayList<>();
}
this.messages.addAll(Arrays.asList(messages));
return this;
}
}
Loading