Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simple OpenAI API #47

Merged
merged 15 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,15 @@ See [an example pom in our Spring Boot application](e2e-test-app/pom.xml)

### Simple chat completion

```java
final OpenAiChatCompletionOutput result =
OpenAiClient.forModel(GPT_35_TURBO).chatCompletion("Hello World! Why is this phrase so famous?");
CharlesDuboisSAP marked this conversation as resolved.
Show resolved Hide resolved

final String resultMessage = result.getContent();
```

### Message history

```java
final var systemMessage =
new OpenAiChatSystemMessage().setContent("You are a helpful assistant");
Expand All @@ -192,7 +201,7 @@ final var request =
final OpenAiChatCompletionOutput result =
OpenAiClient.forModel(GPT_35_TURBO).chatCompletion(request);

final String resultMessage = result.getChoices().get(0).getMessage().getContent();
final String resultMessage = result.getContent();
```

See [an example in our Spring Boot application](e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@ class OpenAiController {
@GetMapping("/chatCompletion")
@Nonnull
public static OpenAiChatCompletionOutput chatCompletion() {
final var request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText("Who is the prettiest")));

return OpenAiClient.forModel(GPT_35_TURBO).chatCompletion(request);
return OpenAiClient.forModel(GPT_35_TURBO).chatCompletion("Who is the prettiest");
CharlesDuboisSAP marked this conversation as resolved.
Show resolved Hide resolved
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.StreamedDelta;
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
import com.sap.cloud.sdk.cloudplatform.connectivity.Destination;
import java.io.IOException;
import java.util.List;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import lombok.AccessLevel;
Expand Down Expand Up @@ -95,6 +97,22 @@ public static OpenAiClient withCustomDestination(@Nonnull final Destination dest
return new OpenAiClient(destination);
}

/**
* Generate a completion for the given prompt.
CharlesDuboisSAP marked this conversation as resolved.
Show resolved Hide resolved
*
* @param prompt a text message.
* @return the completion output
* @throws OpenAiClientException if the request fails
*/
@Nonnull
public OpenAiChatCompletionOutput chatCompletion(@Nonnull final String prompt)
throws OpenAiClientException {
final var parameters =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText(prompt)));
return chatCompletion(parameters);
}

/**
CharlesDuboisSAP marked this conversation as resolved.
Show resolved Hide resolved
* Generate a completion for the given prompt.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import javax.annotation.Nonnull;
import lombok.EqualsAndHashCode;
import lombok.Getter;
Expand All @@ -28,6 +29,21 @@ public class OpenAiChatCompletionOutput extends OpenAiCompletionOutput
@Getter(onMethod_ = @Nonnull)
private String systemFingerprint;

/**
* Get the message content from the output.
*
* <p>Note: If there are multiple choices only the first one is returned
*
* @return the message content or empty string.
*/
@Nonnull
public String getContent() {
if (getChoices().isEmpty()) {
return "";
}
return Objects.requireNonNullElse(getChoices().get(0).getMessage().getContent(), "");
CharlesDuboisSAP marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Add a streamed delta to the total output.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiContentFilterPromptResults;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters;
Expand All @@ -30,9 +29,11 @@
import java.io.InputStream;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.function.Function;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import lombok.SneakyThrows;
import org.apache.hc.client5.http.classic.HttpClient;
import org.apache.hc.core5.http.ContentType;
import org.apache.hc.core5.http.io.entity.InputStreamEntity;
Expand Down Expand Up @@ -84,7 +85,7 @@ void apiVersion() {
verify(exactly(2), postRequestedFor(anyUrl()).withoutQueryParam("api-version"));
}

private static Runnable[] chatCompletionCalls() {
private static Runnable[] errorHandlingCalls() {
MatKuhr marked this conversation as resolved.
Show resolved Hide resolved
return new Runnable[] {
() -> client.chatCompletion(new OpenAiChatCompletionParameters()),
() ->
Expand All @@ -96,7 +97,7 @@ private static Runnable[] chatCompletionCalls() {
}

@ParameterizedTest
@MethodSource("chatCompletionCalls")
@MethodSource("errorHandlingCalls")
void chatCompletionErrorHandling(@Nonnull final Runnable request) {
final var errorJson =
"""
Expand Down Expand Up @@ -169,23 +170,28 @@ void chatCompletionErrorHandling(@Nonnull final Runnable request) {
softly.assertAll();
}

@Test
void chatCompletion() throws IOException {
private static Callable<?>[] chatCompletionCalls() {
return new Callable[] {
() -> {
final var userMessage =
new OpenAiChatUserMessage().addText("Hello World! Why is this phrase so famous?");
final var request = new OpenAiChatCompletionParameters().setMessages(List.of(userMessage));
return client.chatCompletion(request);
},
() -> client.chatCompletion("Hello World! Why is this phrase so famous?")
};
}

@SneakyThrows
@ParameterizedTest
@MethodSource("chatCompletionCalls")
void chatCompletion(@Nonnull final Callable<OpenAiChatCompletionOutput> request) {
try (var inputStream = TEST_FILE_LOADER.apply("chatCompletionResponse.json")) {

final String response = new String(inputStream.readAllBytes());
stubFor(post("/chat/completions").willReturn(okJson(response)));

final var systemMessage =
new OpenAiChatMessage.OpenAiChatSystemMessage()
.setContent(
"You are a helpful, friendly and sometimes slightly snarky AI assistant!");
final var userMessage =
new OpenAiChatUserMessage().addText("Hello World! Why is this phrase so famous?");
final var request =
new OpenAiChatCompletionParameters().setMessages(List.of(systemMessage, userMessage));

final var result = client.chatCompletion(request);
final OpenAiChatCompletionOutput result = request.call();

assertThat(result).isNotNull();
assertThat(result.getCreated()).isEqualTo(1719300073);
Expand Down Expand Up @@ -252,7 +258,7 @@ void chatCompletion() throws IOException {
.withRequestBody(
equalToJson(
"""
{"messages":[{"role":"system","content":"You are a helpful, friendly and sometimes slightly snarky AI assistant!"},{"role":"user","content":[{"type":"text","text":"Hello World! Why is this phrase so famous?"}]}]}""")));
{"messages":[{"role":"user","content":[{"type":"text","text":"Hello World! Why is this phrase so famous?"}]}]}""")));
}
}

Expand Down