diff --git a/README.md b/README.md
index 09dba116..9121fa2b 100644
--- a/README.md
+++ b/README.md
@@ -197,13 +197,78 @@ final String resultMessage = result.getChoices().get(0).getMessage().getContent(
 
 See [an example in our Spring Boot application](e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java)
 
-### Chat completion with a model not in defined `OpenAiModel`
+### Chat completion with a model not defined in `OpenAiModel`
 
 ```java
 final OpenAiChatCompletionOutput result =
     OpenAiClient.forModel(new OpenAiModel("model")).chatCompletion(request);
 ```
 
+### Stream chat completion
+
+The chat completion can be consumed as a stream of delta elements, e.g. to pass them from the application backend to the frontend in real time.
+
+#### Consume the chat completion stream
+This example blocks while it streams the response and prints each delta directly to the console:
+```java
+String msg = "Can you give me the first 100 numbers of the Fibonacci sequence?";
+
+OpenAiChatCompletionParameters request =
+    new OpenAiChatCompletionParameters()
+        .setMessages(List.of(new OpenAiChatUserMessage().addText(msg)));
+
+OpenAiClient client = OpenAiClient.forModel(GPT_35_TURBO);
+
+// try-with-resources on the stream ensures the connection will be closed
+try (Stream<String> stream = client.streamChatCompletion(request)) {
+  stream.forEach(deltaString -> {
+    System.out.print(deltaString);
+    System.out.flush();
+  });
+}
+```
+
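+If only the full text is needed, the deltas can also be collected into a single `String` once the stream is exhausted. This is a minimal sketch reusing `client` and `request` from above; it relies only on plain JDK `Collectors.joining()`, not on an SDK-specific API:
+
+```java
+// Collect all streamed deltas into one String (requires java.util.stream.Collectors)
+String total;
+try (Stream<String> stream = client.streamChatCompletion(request)) {
+  total = stream.collect(Collectors.joining());
+}
+System.out.println(total);
+```
+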
+It's also possible to aggregate the deltas into a total output.
+
+The following example is non-blocking; any asynchronous library can be used, e.g. the classic Thread API.
+
+```java
+String msg = "Can you give me the first 100 numbers of the Fibonacci sequence?";
+
+OpenAiChatCompletionParameters request =
+    new OpenAiChatCompletionParameters()
+        .setMessages(List.of(new OpenAiChatUserMessage().addText(msg)));
+
+OpenAiChatCompletionOutput totalOutput = new OpenAiChatCompletionOutput();
+OpenAiClient client = OpenAiClient.forModel(GPT_35_TURBO);
+
+// Perform the request before the thread starts to handle exceptions during request initialization
+Stream<OpenAiChatCompletionDelta> stream = client.streamChatCompletionDeltas(request);
+
+Thread thread = new Thread(() -> {
+  // try-with-resources ensures the stream is closed
+  try (stream) {
+    stream.peek(totalOutput::addDelta).forEach(delta -> System.out.println(delta));
+  }
+});
+thread.start(); // non-blocking
+
+thread.join(); // blocking
+
+// access aggregated information from the total output, e.g.
+Integer tokens = totalOutput.getUsage().getCompletionTokens();
+System.out.println("Tokens: " + tokens);
+```
+
+
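+Note that consumption of the stream can fail at any point, e.g. when the request fails, a delta cannot be parsed, or the content filter stops the output; in those cases an `OpenAiClientException` is thrown while the stream is read. A minimal sketch of guarding the consumption (the error handling shown here is only illustrative):
+
+```java
+try (Stream<String> stream = client.streamChatCompletion(request)) {
+  stream.forEach(System.out::print);
+} catch (OpenAiClientException e) {
+  // e.g. request failure, unparsable delta, or finish reason "content_filter"
+  System.err.println("Streaming failed: " + e.getMessage());
+}
+```
+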
+ +#### Spring Boot example + +Please find [an example in our Spring Boot application](e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java). +It shows the usage of Spring Boot's `ResponseBodyEmitter` to stream the chat completion delta messages to the frontend in real-time. + ## Orchestration chat completion ### Prerequisites diff --git a/e2e-test-app/pom.xml b/e2e-test-app/pom.xml index ca47c77a..1a0a56b8 100644 --- a/e2e-test-app/pom.xml +++ b/e2e-test-app/pom.xml @@ -87,6 +87,11 @@ org.springframework spring-web + + org.springframework + spring-webmvc + ${springframework.version} + com.google.code.findbugs jsr305 @@ -95,6 +100,14 @@ org.slf4j slf4j-api + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-core + ch.qos.logback diff --git a/e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java b/e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java index 08a02b5a..2e563a50 100644 --- a/e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java +++ b/e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java @@ -5,6 +5,8 @@ import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.TEXT_EMBEDDING_ADA_002; import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionTool.ToolType.FUNCTION; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionFunction; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput; @@ -14,13 +16,21 @@ import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage.ImageDetailLevel; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingOutput; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters; +import com.sap.cloud.sdk.cloudplatform.thread.ThreadContextExecutors; +import java.io.IOException; +import java.util.Arrays; import java.util.List; import java.util.Map; import javax.annotation.Nonnull; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter; /** Endpoints for OpenAI operations */ +@Slf4j @RestController class OpenAiController { /** @@ -38,6 +48,98 @@ public static OpenAiChatCompletionOutput chatCompletion() { return OpenAiClient.forModel(GPT_35_TURBO).chatCompletion(request); } + /** + * Asynchronous stream of an OpenAI chat request + * + * @return the emitter that streams the assistant message response + */ + @SuppressWarnings("unused") // The end-to-end test doesn't use this method + @GetMapping("/streamChatCompletionDeltas") + @Nonnull + public static ResponseEntity streamChatCompletionDeltas() { + final var msg = "Can you give me the first 100 numbers of the Fibonacci sequence?"; + final var request = + new OpenAiChatCompletionParameters() + .setMessages(List.of(new OpenAiChatUserMessage().addText(msg))); + + final var stream = OpenAiClient.forModel(GPT_35_TURBO).streamChatCompletionDeltas(request); + + final var emitter = new ResponseBodyEmitter(); + + final Runnable consumeStream = + () -> { + final var totalOutput = new 
OpenAiChatCompletionOutput(); + // try-with-resources ensures the stream is closed + try (stream) { + stream + .peek(totalOutput::addDelta) + .forEach(delta -> send(emitter, delta.getDeltaContent())); + } finally { + send(emitter, "\n\n-----Total Output-----\n\n" + objectToJson(totalOutput)); + emitter.complete(); + } + }; + + ThreadContextExecutors.getExecutor().execute(consumeStream); + + // TEXT_EVENT_STREAM allows the browser to display the content as it is streamed + return ResponseEntity.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(emitter); + } + + private static String objectToJson(@Nonnull final Object obj) { + try { + return new ObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(obj); + } catch (final JsonProcessingException ignored) { + return "Could not parse object to JSON"; + } + } + + /** + * Asynchronous stream of an OpenAI chat request + * + * @return the emitter that streams the assistant message response + */ + @SuppressWarnings("unused") // The end-to-end test doesn't use this method + @GetMapping("/streamChatCompletion") + @Nonnull + public static ResponseEntity streamChatCompletion() { + final var request = + new OpenAiChatCompletionParameters() + .setMessages( + List.of( + new OpenAiChatUserMessage() + .addText( + "Can you give me the first 100 numbers of the Fibonacci sequence?"))); + + final var stream = OpenAiClient.forModel(GPT_35_TURBO).streamChatCompletion(request); + + final var emitter = new ResponseBodyEmitter(); + + final Runnable consumeStream = + () -> { + try (stream) { + stream.forEach(deltaMessage -> send(emitter, deltaMessage)); + } finally { + emitter.complete(); + } + }; + + ThreadContextExecutors.getExecutor().execute(consumeStream); + + // TEXT_EVENT_STREAM allows the browser to display the content as it is streamed + return ResponseEntity.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(emitter); + } + + private static void send( + @Nonnull final ResponseBodyEmitter emitter, @Nonnull final String chunk) { + try { + emitter.send(chunk); + } catch (final IOException e) { + log.error(Arrays.toString(e.getStackTrace())); + emitter.completeWithError(e); + } + } + /** * Chat request to OpenAI with an image * diff --git a/e2e-test-app/src/main/resources/static/index.html b/e2e-test-app/src/main/resources/static/index.html index f601884e..b7a2d904 100644 --- a/e2e-test-app/src/main/resources/static/index.html +++ b/e2e-test-app/src/main/resources/static/index.html @@ -71,6 +71,8 @@

 Endpoints

   OpenAI

     /chatCompletion
+    /streamChatCompletion
+    /streamChatCompletionDeltas
     /chatCompletionTool
     /chatCompletionImage
     /embedding
    • diff --git a/e2e-test-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiTest.java b/e2e-test-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiTest.java index ea826aa7..c37aff53 100644 --- a/e2e-test-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiTest.java +++ b/e2e-test-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiTest.java @@ -1,9 +1,18 @@ package com.sap.ai.sdk.app.controllers; +import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.GPT_35_TURBO; import static org.assertj.core.api.Assertions.assertThat; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient; +import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput; +import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters; +import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; +@Slf4j class OpenAiTest { @Test void chatCompletion() { @@ -23,12 +32,44 @@ void chatCompletionImage() { assertThat(message.getContent()).isNotEmpty(); } + @Test + void streamChatCompletion() { + final var request = + new OpenAiChatCompletionParameters() + .setMessages(List.of(new OpenAiChatUserMessage().addText("Who is the prettiest?"))); + + final var totalOutput = new OpenAiChatCompletionOutput(); + final var emptyDeltaCount = new AtomicInteger(0); + OpenAiClient.forModel(GPT_35_TURBO) + .streamChatCompletionDeltas(request) + .peek(totalOutput::addDelta) + // foreach consumes all elements, closing the stream at the end + .forEach( + delta -> { + final String deltaContent = delta.getDeltaContent(); + log.info("deltaContent: {}", deltaContent); + if (deltaContent.isEmpty()) { + emptyDeltaCount.incrementAndGet(); + } + }); + + // the first two and the last delta don't have any content + // see OpenAiChatCompletionDelta#getDeltaContent + assertThat(emptyDeltaCount.get()).isLessThanOrEqualTo(3); + + assertThat(totalOutput.getChoices()).isNotEmpty(); + assertThat(totalOutput.getChoices().get(0).getMessage().getContent()).isNotEmpty(); + assertThat(totalOutput.getPromptFilterResults()).isNotNull(); + assertThat(totalOutput.getChoices().get(0).getContentFilterResults()).isNotNull(); + } + @Test void chatCompletionTools() { final var completion = OpenAiController.chatCompletionTools(); final var message = completion.getChoices().get(0).getMessage(); assertThat(message.getRole()).isEqualTo("assistant"); + assertThat(message.getTool_calls()).isNotNull(); assertThat(message.getTool_calls().get(0).getFunction().getName()).isEqualTo("fibonacci"); } diff --git a/foundation-models/openai/pom.xml b/foundation-models/openai/pom.xml index b1319ac0..37721f18 100644 --- a/foundation-models/openai/pom.xml +++ b/foundation-models/openai/pom.xml @@ -97,6 +97,11 @@ junit-jupiter-api test + + org.junit.jupiter + junit-jupiter-params + test + org.wiremock wiremock @@ -107,5 +112,10 @@ assertj-core test + + org.mockito + mockito-core + test + diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java index af85738a..333ffbed 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java @@ -7,14 +7,17 @@ import 
com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import com.sap.ai.sdk.core.Core; +import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingOutput; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters; +import com.sap.ai.sdk.foundationmodels.openai.model.StreamedDelta; import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; import com.sap.cloud.sdk.cloudplatform.connectivity.Destination; import java.io.IOException; +import java.util.stream.Stream; import javax.annotation.Nonnull; import lombok.AccessLevel; import lombok.RequiredArgsConstructor; @@ -30,7 +33,7 @@ @RequiredArgsConstructor(access = AccessLevel.PRIVATE) public final class OpenAiClient { private static final String DEFAULT_API_VERSION = "2024-02-01"; - private static final ObjectMapper JACKSON; + static final ObjectMapper JACKSON; static { JACKSON = @@ -105,6 +108,42 @@ public OpenAiChatCompletionOutput chatCompletion( return execute("/chat/completions", parameters, OpenAiChatCompletionOutput.class); } + /** + * Generate a completion for the given prompt. + * + * @param parameters the prompt, including messages and other parameters. + * @return A stream of message deltas + * @throws OpenAiClientException if the request fails or if the finish reason is content_filter + */ + @Nonnull + public Stream streamChatCompletion( + @Nonnull final OpenAiChatCompletionParameters parameters) throws OpenAiClientException { + return streamChatCompletionDeltas(parameters) + .peek(OpenAiClient::throwOnContentFilter) + .map(OpenAiChatCompletionDelta::getDeltaContent); + } + + private static void throwOnContentFilter(@Nonnull final OpenAiChatCompletionDelta delta) { + final String finishReason = delta.getFinishReason(); + if (finishReason != null && finishReason.equals("content_filter")) { + throw new OpenAiClientException("Content filter filtered the output."); + } + } + + /** + * Generate a completion for the given prompt. + * + * @param parameters the prompt, including messages and other parameters. + * @return A stream of chat completion delta elements. + * @throws OpenAiClientException if the request fails + */ + @Nonnull + public Stream streamChatCompletionDeltas( + @Nonnull final OpenAiChatCompletionParameters parameters) throws OpenAiClientException { + parameters.enableStreaming(); + return executeStream("/chat/completions", parameters, OpenAiChatCompletionDelta.class); + } + /** * Get a vector representation of a given input that can be easily consumed by machine learning * models and algorithms. 
@@ -129,6 +168,16 @@ private T execute( return executeRequest(request, responseType); } + @Nonnull + private Stream executeStream( + @Nonnull final String path, + @Nonnull final Object payload, + @Nonnull final Class deltaType) { + final var request = new HttpPost(path); + serializeAndSetHttpEntity(request, payload); + return streamRequest(request, deltaType); + } + private static void serializeAndSetHttpEntity( @Nonnull final BasicClassicHttpRequest request, @Nonnull final Object payload) { try { @@ -145,9 +194,22 @@ private T executeRequest( try { @SuppressWarnings("UnstableApiUsage") final var client = ApacheHttpClient5Accessor.getHttpClient(destination); - return client.execute(request, new OpenAiResponseHandler<>(responseType, JACKSON)); + return client.execute(request, new OpenAiResponseHandler<>(responseType)); + } catch (final IOException e) { + throw new OpenAiClientException("Request to OpenAI model failed", e); + } + } + + @Nonnull + private Stream streamRequest( + final BasicClassicHttpRequest request, @Nonnull final Class deltaType) { + try { + @SuppressWarnings("UnstableApiUsage") + final var client = ApacheHttpClient5Accessor.getHttpClient(destination); + return new OpenAiStreamingHandler<>(deltaType) + .handleResponse(client.executeOpen(null, request, null)); } catch (final IOException e) { - throw new OpenAiClientException("Request to OpenAI model failed.", e); + throw new OpenAiClientException("Request to OpenAI model failed", e); } } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiResponseHandler.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiResponseHandler.java index f91eed79..9d7dd0bd 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiResponseHandler.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiResponseHandler.java @@ -1,7 +1,8 @@ package com.sap.ai.sdk.foundationmodels.openai; +import static com.sap.ai.sdk.foundationmodels.openai.OpenAiClient.JACKSON; + import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiError; import io.vavr.control.Try; import java.io.IOException; @@ -21,8 +22,14 @@ class OpenAiResponseHandler implements HttpClientResponseHandler { @Nonnull private final Class responseType; - @Nonnull private final ObjectMapper jackson; + /** + * Processes a {@link ClassicHttpResponse} and returns some value corresponding to that response. 
+ * + * @param response The response to process + * @return A model class instantiated from the response + * @throws OpenAiClientException in case of a problem or the connection was aborted + */ @Override public T handleResponse(@Nonnull final ClassicHttpResponse response) throws OpenAiClientException { @@ -35,14 +42,15 @@ public T handleResponse(@Nonnull final ClassicHttpResponse response) // The InputStream of the HTTP entity is closed by EntityUtils.toString @SuppressWarnings("PMD.CloseResource") @Nonnull - private T parseResponse(@Nonnull final ClassicHttpResponse response) { + private T parseResponse(@Nonnull final ClassicHttpResponse response) + throws OpenAiClientException { final HttpEntity responseEntity = response.getEntity(); if (responseEntity == null) { throw new OpenAiClientException("Response from OpenAI model was empty."); } final var content = getContent(responseEntity); try { - return jackson.readValue(content, responseType); + return JACKSON.readValue(content, responseType); } catch (final JsonProcessingException e) { log.error("Failed to parse the following response from OpenAI model: {}", content); throw new OpenAiClientException("Failed to parse response from OpenAI model", e); @@ -60,10 +68,11 @@ private static String getContent(@Nonnull final HttpEntity entity) { // The InputStream of the HTTP entity is closed by EntityUtils.toString @SuppressWarnings("PMD.CloseResource") - private void buildExceptionAndThrow(@Nonnull final ClassicHttpResponse response) { + static void buildExceptionAndThrow(@Nonnull final ClassicHttpResponse response) + throws OpenAiClientException { final var exception = new OpenAiClientException( - "Request to OpenAI model failed with status %s %s " + "Request to OpenAI model failed with status %s %s" .formatted(response.getCode(), response.getReasonPhrase())); final var entity = response.getEntity(); if (entity == null) { @@ -85,19 +94,29 @@ private void buildExceptionAndThrow(@Nonnull final ClassicHttpResponse response) throw exception; } - final var maybeError = Try.of(() -> jackson.readValue(content, OpenAiError.class)); + parseErrorAndThrow(content, exception); + } + + /** + * Parse the error response and throw an exception. + * + * @param errorResponse the error response, most likely a JSON of {@link OpenAiError}. + * @param baseException a base exception to add the error message to. 
+ */ + static void parseErrorAndThrow( + @Nonnull final String errorResponse, @Nonnull final OpenAiClientException baseException) + throws OpenAiClientException { + final var maybeError = Try.of(() -> JACKSON.readValue(errorResponse, OpenAiError.class)); if (maybeError.isFailure()) { - exception.addSuppressed(maybeError.getCause()); - throw exception; + baseException.addSuppressed(maybeError.getCause()); + throw baseException; } final var error = maybeError.get().getError(); if (error == null) { - throw exception; + throw baseException; } - final var message = - "Request to OpenAI model failed with %s %s and error message: '%s'" - .formatted(response.getCode(), response.getReasonPhrase(), error.getMessage()); - throw new OpenAiClientException(message); + throw new OpenAiClientException( + "%s and error message: '%s'".formatted(baseException.getMessage(), error.getMessage())); } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiStreamingHandler.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiStreamingHandler.java new file mode 100644 index 00000000..3a264705 --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiStreamingHandler.java @@ -0,0 +1,83 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import static com.sap.ai.sdk.foundationmodels.openai.OpenAiClient.JACKSON; +import static com.sap.ai.sdk.foundationmodels.openai.OpenAiResponseHandler.buildExceptionAndThrow; +import static com.sap.ai.sdk.foundationmodels.openai.OpenAiResponseHandler.parseErrorAndThrow; + +import com.sap.ai.sdk.foundationmodels.openai.model.StreamedDelta; +import io.vavr.control.Try; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.util.stream.Stream; +import javax.annotation.Nonnull; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.apache.hc.core5.http.ClassicHttpResponse; +import org.apache.hc.core5.http.HttpEntity; + +@Slf4j +@RequiredArgsConstructor +class OpenAiStreamingHandler { + + @Nonnull private final Class deltaType; + + @Nonnull + Stream handleResponse(@Nonnull final ClassicHttpResponse response) + throws OpenAiClientException { + if (response.getCode() >= 300) { + buildExceptionAndThrow(response); + } + return parseResponse(response); + } + + /** + * @param response The response to process + * @return A {@link Stream} of a model class instantiated from the response + * @author stippi + */ + // The stream is closed by the user of the Stream + @SuppressWarnings("PMD.CloseResource") + private Stream parseResponse(@Nonnull final ClassicHttpResponse response) + throws OpenAiClientException { + final HttpEntity responseEntity = response.getEntity(); + if (responseEntity == null) { + throw new OpenAiClientException("Response from OpenAI model was empty."); + } + final InputStream inputStream; + try { + inputStream = responseEntity.getContent(); + } catch (IOException e) { + throw new OpenAiClientException("Failed to read response content.", e); + } + final var br = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)); + + // https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format + return br.lines() + // half of the lines are empty newlines, the last line is "data: [DONE]" + .filter(line -> !line.isEmpty() && !"data: 
[DONE]".equals(line.trim())) + .peek( + line -> { + if (!line.startsWith("data: ")) { + final String msg = "Failed to parse response from OpenAI model"; + parseErrorAndThrow(line, new OpenAiClientException(msg)); + } + }) + .map( + line -> { + final String data = line.substring(5); // remove "data: " + try { + return JACKSON.readValue(data, deltaType); + } catch (final IOException e) { + log.error("Failed to parse the following response from OpenAI model: {}", line); + throw new OpenAiClientException("Failed to parse delta message: " + line, e); + } + }) + .onClose( + () -> + Try.run(inputStream::close) + .onFailure(e -> log.error("Could not close HTTP input stream", e))); + } +} diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/DeltaAggregatable.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/DeltaAggregatable.java new file mode 100644 index 00000000..58ac00ea --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/DeltaAggregatable.java @@ -0,0 +1,20 @@ +package com.sap.ai.sdk.foundationmodels.openai.model; + +import javax.annotation.Nonnull; + +/** + * Interface for model classes that can be created from aggregated streamed deltas. + * + *

      For example aggregating chat completions deltas into a single chat completion output. + * + * @param the delta type. + */ +public interface DeltaAggregatable { + + /** + * Add a streamed delta to the total output. + * + * @param delta the delta to add. + */ + void addDelta(@Nonnull final D delta); +} diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionChoice.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionChoice.java index 3b18ff92..91a6ff65 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionChoice.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionChoice.java @@ -3,18 +3,32 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatAssistantMessage; import javax.annotation.Nonnull; +import lombok.AccessLevel; import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.Setter; import lombok.ToString; import lombok.experimental.Accessors; /** Result candidates for OpenAI chat completion output. */ @Accessors(chain = true) @EqualsAndHashCode(callSuper = true) -@ToString +@ToString(callSuper = true) public class OpenAiChatCompletionChoice extends OpenAiCompletionChoice { /** Completion chat message. */ @JsonProperty("message") @Getter(onMethod_ = @Nonnull) + @Setter(onMethod_ = @Nonnull, value = AccessLevel.PACKAGE) private OpenAiChatAssistantMessage message; + + void addDelta(@Nonnull final OpenAiDeltaChatCompletionChoice delta) { + super.addDelta(delta); + + if (delta.getMessage() != null) { + if (message == null) { + message = new OpenAiChatAssistantMessage(); + } + message.addDelta(delta.getMessage()); + } + } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionDelta.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionDelta.java new file mode 100644 index 00000000..1136e066 --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionDelta.java @@ -0,0 +1,60 @@ +package com.sap.ai.sdk.foundationmodels.openai.model; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.Accessors; + +/** OpenAI chat completion output delta for streaming. */ +@Accessors(chain = true) +@EqualsAndHashCode(callSuper = true) +@ToString(callSuper = true) +public class OpenAiChatCompletionDelta extends OpenAiCompletionOutput implements StreamedDelta { + /** List of result candidates. */ + @JsonProperty("choices") + @Getter(onMethod_ = @Nonnull) + private List choices; + + /** + * Can be used in conjunction with the seed request parameter to understand when backend changes + * have been made that might impact determinism. 
+ */ + @JsonProperty("system_fingerprint") + @Getter(onMethod_ = @Nullable) + private String systemFingerprint; + + @Nonnull + @Override + public String getDeltaContent() { + // Avoid the first delta: "choices":[] + if (!getChoices().isEmpty() + // Multiple choices are spread out on multiple deltas + // A delta only contains one choice with a variable index + && getChoices().get(0).getIndex() == 0) { + + final var message = getChoices().get(0).getMessage(); + // Avoid the second delta: "choices":[{"delta":{"content":"","role":"assistant"}}] + if (message != null && message.getContent() != null) { + return message.getContent(); + } + } + return ""; + } + + @Nullable + @Override + public String getFinishReason() { + // Avoid the first delta: "choices":[] + if (!getChoices().isEmpty() + // Multiple choices are spread out on multiple deltas + // A delta only contains one choice with a variable index + && getChoices().get(0).getIndex() == 0) { + return getChoices().get(0).getFinishReason(); + } + return null; + } +} diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java index 0254aee2..449f3e8f 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java @@ -1,6 +1,7 @@ package com.sap.ai.sdk.foundationmodels.openai.model; import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.ArrayList; import java.util.List; import javax.annotation.Nonnull; import lombok.EqualsAndHashCode; @@ -11,8 +12,9 @@ /** OpenAI chat completion output. */ @Accessors(chain = true) @EqualsAndHashCode(callSuper = true) -@ToString -public class OpenAiChatCompletionOutput extends OpenAiCompletionOutput { +@ToString(callSuper = true) +public class OpenAiChatCompletionOutput extends OpenAiCompletionOutput + implements DeltaAggregatable { /** List of result candidates. */ @JsonProperty("choices") @Getter(onMethod_ = @Nonnull) @@ -25,4 +27,30 @@ public class OpenAiChatCompletionOutput extends OpenAiCompletionOutput { @JsonProperty("system_fingerprint") @Getter(onMethod_ = @Nonnull) private String systemFingerprint; + + /** + * Add a streamed delta to the total output. + * + * @param delta the delta to add. 
+ */ + public void addDelta(@Nonnull final OpenAiChatCompletionDelta delta) { + super.addDelta(delta); + + if (delta.getSystemFingerprint() != null) { + systemFingerprint = delta.getSystemFingerprint(); + } + + if (!delta.getChoices().isEmpty()) { + if (choices == null) { + choices = new ArrayList<>(); + } + // Multiple choices are spread out on multiple deltas + // A delta only contains one choice with a variable index + final int index = delta.getChoices().get(0).getIndex(); + for (int i = choices.size(); i < index + 1; i++) { + choices.add(new OpenAiChatCompletionChoice()); + } + choices.get(index).addDelta(delta.getChoices().get(0)); + } + } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionParameters.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionParameters.java index 8dccfe55..19a3c425 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionParameters.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionParameters.java @@ -21,7 +21,7 @@ /** OpenAI chat completion input parameters. */ @Accessors(chain = true) @EqualsAndHashCode(callSuper = true) -@ToString +@ToString(callSuper = true) public class OpenAiChatCompletionParameters extends OpenAiCompletionParameters { /** A list of messages comprising the conversation so far. */ @JsonProperty("messages") diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatMessage.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatMessage.java index 181938c5..a22e5db4 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatMessage.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatMessage.java @@ -18,6 +18,7 @@ import java.util.List; import javax.annotation.Nonnull; import javax.annotation.Nullable; +import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -28,7 +29,9 @@ import lombok.experimental.Accessors; /** OpenAI chat message types. */ -@JsonTypeInfo(use = Id.NAME, property = "role") // This is the field that determines the class type +@JsonTypeInfo(use = Id.NAME, property = "role", defaultImpl = OpenAiChatAssistantMessage.class) +// role is the field that determines the class type +// if role is missing we default to OpenAiChatAssistantMessage @JsonSubTypes({ @Type(value = OpenAiChatSystemMessage.class, name = "system"), @Type(value = OpenAiChatUserMessage.class, name = "user"), @@ -241,7 +244,8 @@ class OpenAiChatAssistantMessage implements OpenAiChatMessage { /** Message content. */ @JsonProperty("content") @Getter(onMethod_ = @Nullable) - private String content; // must be String or null + @Setter(onParam_ = @Nullable, value = AccessLevel.PACKAGE) + private String content; /** The tool calls generated by the model, such as function calls. 
*/ @JsonProperty("tool_calls") @@ -250,6 +254,24 @@ class OpenAiChatAssistantMessage implements OpenAiChatMessage { // TODO: add context // https://github.com/Azure/azure-rest-api-specs/blob/07d286359f828bbc7901e86288a5d62b48ae2052/specification/cognitiveservices/data-plane/AzureOpenAI/inference/stable/2024-02-01/inference.json#L1599 + + void addDelta(@Nonnull final OpenAiChatAssistantMessage delta) { + + if (delta.getContent() != null) { + if (content == null) { + content = ""; + } + content += delta.getContent(); + } + + if (delta.getTool_calls() != null) { + if (tool_calls == null) { + tool_calls = new ArrayList<>(); + } + // TODO: camel case + tool_calls.addAll(delta.getTool_calls()); + } + } } /** OpenAI tool message. */ diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionChoice.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionChoice.java index 701f86d8..80c5baa3 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionChoice.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionChoice.java @@ -1,6 +1,7 @@ package com.sap.ai.sdk.foundationmodels.openai.model; import com.fasterxml.jackson.annotation.JsonProperty; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -12,7 +13,20 @@ @EqualsAndHashCode @ToString public class OpenAiCompletionChoice { - /** Reason for finish. */ + /** + * Reason for finish. The possible values are: + * + *

{@code stop}: API returned complete message, or a message terminated by one of the stop
+   * sequences provided via the stop parameter
+   *
+   * <p>{@code length}: Incomplete model output due to max_tokens parameter or token limit
+   *
+   * <p>{@code function_call}: The model decided to call a function
+   *
+   * <p>{@code content_filter}: Omitted content due to a flag from our content filters
+   *
+   *

      {@code null}: API response still in progress or incomplete + */ @JsonProperty("finish_reason") @Getter(onMethod_ = @Nullable) private String finishReason; @@ -29,4 +43,22 @@ public class OpenAiCompletionChoice { @JsonProperty("content_filter_results") @Getter(onMethod_ = @Nullable) private OpenAiContentFilterPromptResults contentFilterResults; + + void addDelta(@Nonnull final OpenAiCompletionChoice delta) { + + if (delta.getFinishReason() != null) { + finishReason = delta.getFinishReason(); + } + + if (delta.getIndex() != null) { + index = delta.getIndex(); + } + + if (delta.getContentFilterResults() != null) { + if (contentFilterResults == null) { + contentFilterResults = new OpenAiContentFilterPromptResults(); + } + contentFilterResults.addDelta(delta.getContentFilterResults()); + } + } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionOutput.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionOutput.java index ee70e23a..366e4137 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionOutput.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionOutput.java @@ -43,4 +43,24 @@ public class OpenAiCompletionOutput { @JsonProperty("prompt_filter_results") @Getter(onMethod_ = @Nullable) private List promptFilterResults; + + void addDelta(@Nonnull final OpenAiChatCompletionDelta delta) { + created = delta.getCreated(); + id = delta.getId(); + model = delta.getModel(); + object = delta.getObject(); + + if (delta.getUsage() != null) { + if (usage == null) { + usage = new OpenAiUsage(); + } + usage.addDelta(delta.getUsage()); + } + + if (delta.getPromptFilterResults() != null && promptFilterResults == null) { + promptFilterResults = delta.getPromptFilterResults(); + // prompt_filter_results is overriden instead of updated because it is only present once in + // the first delta + } + } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionParameters.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionParameters.java index 76bf6417..c2bb9373 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionParameters.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionParameters.java @@ -6,8 +6,8 @@ import java.util.Map; import javax.annotation.Nonnull; import javax.annotation.Nullable; -import lombok.AccessLevel; import lombok.EqualsAndHashCode; +import lombok.RequiredArgsConstructor; import lombok.Setter; import lombok.ToString; import lombok.experimental.Accessors; @@ -80,7 +80,6 @@ public class OpenAiCompletionParameters { * contain the stop sequence. */ @JsonProperty("stop") - @Setter(value = AccessLevel.NONE) @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) @Nullable private List stop; @@ -103,18 +102,51 @@ public class OpenAiCompletionParameters { private Double frequencyPenalty; /** - * NOTE: This method is currently not supported. Therefore, it stays protected.
      - *
      * If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only * server-sent events as they become available, with the stream terminated by a `data: [DONE]` * message. Default: false. */ @JsonProperty("stream") - @Setter(value = AccessLevel.NONE) - @Nullable - private Boolean stream; // TODO for implementation details, please find + private Boolean stream; + + /** + * If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only + * server-sent + * events as they become available, with the stream terminated by a {@code data: [DONE]} + * message. Only set this when you set {@code stream: true}. + */ + @JsonProperty("stream_options") + private OpenAiStreamOptions streamOptions; - // https://github.com/Azure/azure-rest-api-specs/blob/3cb1b51638616435470fc10ea00de92512186ece/specification/cognitiveservices/data-plane/AzureOpenAI/inference/stable/2024-02-01/inference.json#L1149 + /** "stream_options": { "include_usage": "true" } */ + @RequiredArgsConstructor + @Setter + @JsonFormat(shape = JsonFormat.Shape.OBJECT) + @EqualsAndHashCode + @ToString + public static class OpenAiStreamOptions { + /** + * If set, an additional chunk will be streamed before the {@code data: [DONE]} message. The + * usage field on this chunk shows the token usage statistics for the entire request, and the + * choices field will always be an empty array. All other chunks will also include a {@code + * usage} field, but with a null value. + */ + @JsonProperty("include_usage") + private Boolean include_usage; + } + + /** + * Please use {@link + * com.sap.ai.sdk.foundationmodels.openai.OpenAiClient#streamChatCompletion(OpenAiChatCompletionParameters)} + * instead. + * + *

      Enable streaming of the completion. If enabled, partial message deltas will be sent. + */ + public void enableStreaming() { + this.stream = true; + this.streamOptions = new OpenAiStreamOptions().setInclude_usage(true); + } /** * Up to four sequences where the API will stop generating further tokens. The returned text won't diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterDetectedResult.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterDetectedResult.java index 16bee568..c8a0b1d5 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterDetectedResult.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterDetectedResult.java @@ -10,7 +10,7 @@ /** OpenAI content filter detected result. */ @Accessors(chain = true) @EqualsAndHashCode(callSuper = true) -@ToString +@ToString(callSuper = true) public class OpenAiContentFilterDetectedResult extends OpenAiContentFilterResultBase { /** Whether the content was detected. */ @JsonProperty("detected") diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterPromptResults.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterPromptResults.java index a7c07dc7..342914cb 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterPromptResults.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterPromptResults.java @@ -1,6 +1,7 @@ package com.sap.ai.sdk.foundationmodels.openai.model; import com.fasterxml.jackson.annotation.JsonProperty; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -10,9 +11,17 @@ /** Content filtering results for a prompt in the request. 
*/ @Accessors(chain = true) @EqualsAndHashCode(callSuper = true) -@ToString +@ToString(callSuper = true) public class OpenAiContentFilterPromptResults extends OpenAiContentFilterResultsBase { @JsonProperty("jailbreak") @Getter(onMethod_ = @Nullable) private OpenAiContentFilterDetectedResult jailbreak; + + void addDelta(@Nonnull final OpenAiContentFilterPromptResults delta) { + super.addDelta(delta); + + if (delta.getJailbreak() != null) { + jailbreak = delta.getJailbreak(); + } + } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterResultsBase.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterResultsBase.java index f2cab6fd..78670475 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterResultsBase.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterResultsBase.java @@ -1,6 +1,7 @@ package com.sap.ai.sdk.foundationmodels.openai.model; import com.fasterxml.jackson.annotation.JsonProperty; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -40,4 +41,25 @@ public class OpenAiContentFilterResultsBase { @JsonProperty("error") @Getter(onMethod_ = @Nullable) private OpenAiErrorBase error; + + void addDelta(@Nonnull final OpenAiContentFilterPromptResults delta) { + if (delta.getSexual() != null) { + sexual = delta.getSexual(); + } + if (delta.getViolence() != null) { + violence = delta.getViolence(); + } + if (delta.getHate() != null) { + hate = delta.getHate(); + } + if (delta.getSelfHarm() != null) { + selfHarm = delta.getSelfHarm(); + } + if (delta.getProfanity() != null) { + profanity = delta.getProfanity(); + } + if (delta.getError() != null) { + error = delta.getError(); + } + } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterSeverityResult.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterSeverityResult.java index 0d008e29..fd158aad 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterSeverityResult.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterSeverityResult.java @@ -12,7 +12,7 @@ /** Information about the content filtering results. */ @Accessors(chain = true) @EqualsAndHashCode(callSuper = true) -@ToString +@ToString(callSuper = true) public class OpenAiContentFilterSeverityResult extends OpenAiContentFilterResultBase { /** Severity of the content. 
*/ @JsonProperty("severity") diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiDeltaChatCompletionChoice.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiDeltaChatCompletionChoice.java new file mode 100644 index 00000000..dc9ed2e9 --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiDeltaChatCompletionChoice.java @@ -0,0 +1,23 @@ +package com.sap.ai.sdk.foundationmodels.openai.model; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatAssistantMessage; +import javax.annotation.Nullable; +import lombok.AccessLevel; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; +import lombok.ToString; +import lombok.experimental.Accessors; + +/** Result candidates for OpenAI chat completion output. */ +@Accessors(chain = true) +@EqualsAndHashCode(callSuper = true) +@ToString(callSuper = true) +public class OpenAiDeltaChatCompletionChoice extends OpenAiCompletionChoice { + /** Completion chat message. */ + @JsonProperty("delta") + @Getter(onMethod_ = @Nullable) + @Setter(onMethod_ = @Nullable, value = AccessLevel.PACKAGE) + private OpenAiChatAssistantMessage message; +} diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiUsage.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiUsage.java index 132f3071..7884b8c2 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiUsage.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiUsage.java @@ -27,4 +27,12 @@ public class OpenAiUsage { @JsonProperty("total_tokens") @Getter(onMethod_ = @Nonnull) private Integer totalTokens; + + void addDelta(@Nonnull final OpenAiUsage delta) { + if (delta.getCompletionTokens() != null) { + completionTokens = delta.getCompletionTokens(); + } + promptTokens = delta.getPromptTokens(); + totalTokens = delta.getTotalTokens(); + } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/StreamedDelta.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/StreamedDelta.java new file mode 100644 index 00000000..ac790901 --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/StreamedDelta.java @@ -0,0 +1,43 @@ +package com.sap.ai.sdk.foundationmodels.openai.model; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * Interface for streamed delta classes. + * + *

      This interface defines a method to retrieve the content from a delta, which is a chunk in a + * stream of data. Implementations of this interface should provide the logic to extract the + * relevant content from the delta. + */ +public interface StreamedDelta { + + /** + * Get the message content from the delta. + * + *

Note: If there are multiple choices, only the first one is returned
+   *
+   *

      Note: The first two and the last delta do not contain any content + * + * @return the message content or empty string. + */ + @Nonnull + String getDeltaContent(); + + /** + * Reason for finish. The possible values are: + * + *

{@code stop}: API returned complete message, or a message terminated by one of the stop
+   * sequences provided via the stop parameter
+   *
+   * <p>{@code length}: Incomplete model output due to max_tokens parameter or token limit
+   *
+   * <p>{@code function_call}: The model decided to call a function
+   *
+   * <p>{@code content_filter}: Omitted content due to a flag from our content filters
+   *
+   *

      {@code null}: API response still in progress or incomplete + */ + @Nullable + String getFinishReason(); +} diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java index 35beb7b9..aebc3718 100644 --- a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java +++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java @@ -3,39 +3,70 @@ import static com.github.tomakehurst.wiremock.client.WireMock.*; import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiContentFilterSeverityResult.Severity.SAFE; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; import com.fasterxml.jackson.core.JsonParseException; import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; import com.github.tomakehurst.wiremock.junit5.WireMockTest; import com.github.tomakehurst.wiremock.stubbing.Scenario; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionChoice; +import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta; +import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiContentFilterPromptResults; import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters; -import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiPromptFilterResult; +import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; +import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Cache; import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; import io.vavr.control.Try; import java.io.IOException; +import java.io.InputStream; import java.util.List; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Stream; +import javax.annotation.Nonnull; +import org.apache.hc.client5.http.classic.HttpClient; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.io.entity.InputStreamEntity; +import org.apache.hc.core5.http.message.BasicClassicHttpResponse; import org.assertj.core.api.SoftAssertions; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; @WireMockTest class OpenAiClientTest { - private OpenAiClient client; + private static OpenAiClient client; + private final Function TEST_FILE_LOADER = + filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename)); @BeforeEach void setup(WireMockRuntimeInfo server) { final DefaultHttpDestination destination = DefaultHttpDestination.builder(server.getHttpBaseUrl()).build(); client = OpenAiClient.withCustomDestination(destination); + 
ApacheHttpClient5Accessor.setHttpClientCache(ApacheHttpClient5Cache.DISABLED); + } + + @AfterEach + void reset() { + ApacheHttpClient5Accessor.setHttpClientCache(null); + ApacheHttpClient5Accessor.setHttpClientFactory(null); } @Test - void testApiVersion() { + void apiVersion() { stubFor(post(anyUrl()).willReturn(okJson("{}"))); Try.of(() -> client.chatCompletion(new OpenAiChatCompletionParameters())); @@ -53,8 +84,20 @@ void testApiVersion() { verify(exactly(2), postRequestedFor(anyUrl()).withoutQueryParam("api-version")); } - @Test - void testErrorHandling() { + private static Runnable[] chatCompletionCalls() { + return new Runnable[] { + () -> client.chatCompletion(new OpenAiChatCompletionParameters()), + () -> + client + .streamChatCompletionDeltas(new OpenAiChatCompletionParameters()) + // the stream needs to be consumed to parse the response + .forEach(System.out::println) + }; + } + + @ParameterizedTest + @MethodSource("chatCompletionCalls") + void chatCompletionErrorHandling(@Nonnull final Runnable request) { final var errorJson = """ { "error": { "code": null, "message": "foo", "type": "invalid stuff" } } @@ -89,7 +132,6 @@ void testErrorHandling() { .willSetStateTo("4")); stubFor(post(anyUrl()).inScenario("Errors").whenScenarioStateIs("4").willReturn(noContent())); - final Runnable request = () -> client.chatCompletion(new OpenAiChatCompletionParameters()); final var softly = new SoftAssertions(); softly @@ -128,122 +170,311 @@ void testErrorHandling() { } @Test - void testChatCompletion() throws IOException { - final String response = - new String( - getClass() - .getClassLoader() - .getResourceAsStream("chatCompletionResponse.json") - .readAllBytes()); - stubFor(post(anyUrl()).willReturn(okJson(response))); - - final var systemMessage = - new OpenAiChatMessage.OpenAiChatSystemMessage() - .setContent("You are a helpful, friendly and sometimes slightly snarky AI assistant!"); - final var userMessage = - new OpenAiChatUserMessage().addText("Hello World! 
Why is this phrase so famous?"); - final var request = - new OpenAiChatCompletionParameters().setMessages(List.of(systemMessage, userMessage)); - - final var result = client.chatCompletion(request); - - assertThat(result).isNotNull(); - assertThat(result.getCreated()).isEqualTo(1719300073); - assertThat(result.getId()).isEqualTo("chatcmpl-9dumHtDEyysGFnknk17n4Lt37tg7T"); - assertThat(result.getModel()).isEqualTo("gpt-4-32k"); - assertThat(result.getObject()).isEqualTo("chat.completion"); - assertThat(result.getSystemFingerprint()).isNull(); - - assertThat(result.getUsage().getCompletionTokens()).isEqualTo(54); - assertThat(result.getUsage().getPromptTokens()).isEqualTo(13); - assertThat(result.getUsage().getTotalTokens()).isEqualTo(67); - - assertThat(result.getPromptFilterResults()).hasSize(1); - OpenAiPromptFilterResult promptFilterResults = result.getPromptFilterResults().get(0); - assertThat(promptFilterResults.getPromptIndex()).isEqualTo(0); - assertThat(promptFilterResults.getContentFilterResults().getSexual().isFiltered()).isFalse(); - assertThat(promptFilterResults.getContentFilterResults().getSexual().getSeverity()) - .isEqualTo(SAFE); - assertThat(promptFilterResults.getContentFilterResults().getViolence().isFiltered()).isFalse(); - assertThat(promptFilterResults.getContentFilterResults().getViolence().getSeverity()) - .isEqualTo(SAFE); - assertThat(promptFilterResults.getContentFilterResults().getHate().isFiltered()).isFalse(); - assertThat(promptFilterResults.getContentFilterResults().getHate().getSeverity()) - .isEqualTo(SAFE); - assertThat(promptFilterResults.getContentFilterResults().getSelfHarm()).isNull(); - assertThat(promptFilterResults.getContentFilterResults().getProfanity()).isNull(); - assertThat(promptFilterResults.getContentFilterResults().getError()).isNull(); - assertThat(promptFilterResults.getContentFilterResults().getJailbreak().isFiltered()).isFalse(); - assertThat(promptFilterResults.getContentFilterResults().getJailbreak().isDetected()).isFalse(); - - assertThat(result.getChoices()).hasSize(1); - OpenAiChatCompletionChoice choice = result.getChoices().get(0); - assertThat(choice.getFinishReason()).isEqualTo("stop"); - assertThat(choice.getIndex()).isEqualTo(0); - assertThat(choice.getMessage().getContent()) - .isEqualTo( - """ + void chatCompletion() throws IOException { + try (var inputStream = TEST_FILE_LOADER.apply("chatCompletionResponse.json")) { + + final String response = new String(inputStream.readAllBytes()); + stubFor(post("/chat/completions").willReturn(okJson(response))); + + final var systemMessage = + new OpenAiChatMessage.OpenAiChatSystemMessage() + .setContent( + "You are a helpful, friendly and sometimes slightly snarky AI assistant!"); + final var userMessage = + new OpenAiChatUserMessage().addText("Hello World! 
Why is this phrase so famous?"); + final var request = + new OpenAiChatCompletionParameters().setMessages(List.of(systemMessage, userMessage)); + + final var result = client.chatCompletion(request); + + assertThat(result).isNotNull(); + assertThat(result.getCreated()).isEqualTo(1719300073); + assertThat(result.getId()).isEqualTo("chatcmpl-9dumHtDEyysGFnknk17n4Lt37tg7T"); + assertThat(result.getModel()).isEqualTo("gpt-4-32k"); + assertThat(result.getObject()).isEqualTo("chat.completion"); + assertThat(result.getSystemFingerprint()).isEqualTo("fp_e49e4201a9"); + + assertThat(result.getUsage()).isNotNull(); + assertThat(result.getUsage().getCompletionTokens()).isEqualTo(54); + assertThat(result.getUsage().getPromptTokens()).isEqualTo(13); + assertThat(result.getUsage().getTotalTokens()).isEqualTo(67); + + assertThat(result.getPromptFilterResults()).hasSize(1); + assertThat(result.getPromptFilterResults().get(0).getPromptIndex()).isEqualTo(0); + OpenAiContentFilterPromptResults promptFilterResults = + result.getPromptFilterResults().get(0).getContentFilterResults(); + assertThat(promptFilterResults).isNotNull(); + assertThat(promptFilterResults.getSexual()).isNotNull(); + assertThat(promptFilterResults.getSexual().isFiltered()).isFalse(); + assertThat(promptFilterResults.getSexual().getSeverity()).isEqualTo(SAFE); + assertThat(promptFilterResults.getViolence()).isNotNull(); + assertThat(promptFilterResults.getViolence().isFiltered()).isFalse(); + assertThat(promptFilterResults.getViolence().getSeverity()).isEqualTo(SAFE); + assertThat(promptFilterResults.getHate()).isNotNull(); + assertThat(promptFilterResults.getHate().isFiltered()).isFalse(); + assertThat(promptFilterResults.getHate().getSeverity()).isEqualTo(SAFE); + // TODO: update the JSON response and those assertions + assertThat(promptFilterResults.getSelfHarm()).isNull(); + assertThat(promptFilterResults.getProfanity()).isNull(); + assertThat(promptFilterResults.getError()).isNull(); + assertThat(promptFilterResults.getJailbreak()).isNotNull(); + assertThat(promptFilterResults.getJailbreak().isFiltered()).isFalse(); + assertThat(promptFilterResults.getJailbreak().isDetected()).isFalse(); + + assertThat(result.getChoices()).hasSize(1); + OpenAiChatCompletionChoice choice = result.getChoices().get(0); + assertThat(choice.getFinishReason()).isEqualTo("stop"); + assertThat(choice.getIndex()).isEqualTo(0); + assertThat(choice.getMessage().getContent()) + .isEqualTo( + """ This is a highly subjective question as the concept of beauty differs from one person to another. It's based on personal preferences and cultural standards. 
There are attractive people in all walks of life and industries, making it impossible to universally determine who is the "prettiest"."""); - assertThat(choice.getMessage().getRole()).isEqualTo("assistant"); - - OpenAiContentFilterPromptResults contentFilterResults = choice.getContentFilterResults(); - assertThat(contentFilterResults.getSexual().isFiltered()).isFalse(); - assertThat(contentFilterResults.getSexual().getSeverity()).isEqualTo(SAFE); - assertThat(contentFilterResults.getViolence().isFiltered()).isFalse(); - assertThat(contentFilterResults.getViolence().getSeverity()).isEqualTo(SAFE); - assertThat(contentFilterResults.getHate().isFiltered()).isFalse(); - assertThat(contentFilterResults.getHate().getSeverity()).isEqualTo(SAFE); - assertThat(contentFilterResults.getSelfHarm()).isNull(); - assertThat(contentFilterResults.getProfanity()).isNull(); - assertThat(contentFilterResults.getError()).isNull(); - assertThat(contentFilterResults.getJailbreak()).isNull(); - - verify( - postRequestedFor(urlPathEqualTo("/chat/completions")) - .withRequestBody( - equalToJson( - """ - {"messages":[{"role":"system","content":"You are a helpful, friendly and sometimes slightly snarky AI assistant!"},{"role":"user","content":[{"type":"text","text":"Hello World! Why is this phrase so famous?"}]}]}"""))); + assertThat(choice.getMessage().getRole()).isEqualTo("assistant"); + + OpenAiContentFilterPromptResults contentFilterResults = choice.getContentFilterResults(); + assertThat(contentFilterResults).isNotNull(); + assertThat(contentFilterResults.getSexual()).isNotNull(); + assertThat(contentFilterResults.getSexual().isFiltered()).isFalse(); + assertThat(contentFilterResults.getSexual().getSeverity()).isEqualTo(SAFE); + assertThat(contentFilterResults.getViolence()).isNotNull(); + assertThat(contentFilterResults.getViolence().isFiltered()).isFalse(); + assertThat(contentFilterResults.getViolence().getSeverity()).isEqualTo(SAFE); + assertThat(contentFilterResults.getHate()).isNotNull(); + assertThat(contentFilterResults.getHate().isFiltered()).isFalse(); + assertThat(contentFilterResults.getHate().getSeverity()).isEqualTo(SAFE); + assertThat(contentFilterResults.getSelfHarm()).isNull(); + assertThat(contentFilterResults.getProfanity()).isNull(); + assertThat(contentFilterResults.getError()).isNull(); + assertThat(contentFilterResults.getJailbreak()).isNull(); + + verify( + postRequestedFor(urlPathEqualTo("/chat/completions")) + .withRequestBody( + equalToJson( + """ + {"messages":[{"role":"system","content":"You are a helpful, friendly and sometimes slightly snarky AI assistant!"},{"role":"user","content":[{"type":"text","text":"Hello World! 
Why is this phrase so famous?"}]}]}"""))); + } + } + + @Test + void embedding() throws IOException { + try (var inputStream = TEST_FILE_LOADER.apply("embeddingResponse.json")) { + + final String response = new String(inputStream.readAllBytes()); + stubFor(post("/embeddings").willReturn(okJson(response))); + + final var request = new OpenAiEmbeddingParameters().setInput("Hello World"); + final var result = client.embedding(request); + + assertThat(result).isNotNull(); + assertThat(result.getModel()).isEqualTo("ada"); + assertThat(result.getObject()).isEqualTo("list"); + + assertThat(result.getUsage()).isNotNull(); + assertThat(result.getUsage().getCompletionTokens()).isNull(); + assertThat(result.getUsage().getPromptTokens()).isEqualTo(2); + assertThat(result.getUsage().getTotalTokens()).isEqualTo(2); + + assertThat(result.getData()).isNotNull().hasSize(1); + var embeddingData = result.getData().get(0); + assertThat(embeddingData).isNotNull(); + assertThat(embeddingData.getObject()).isEqualTo("embedding"); + assertThat(embeddingData.getIndex()).isEqualTo(0); + assertThat(embeddingData.getEmbedding()) + .isNotNull() + .isNotEmpty() + .containsExactly(-0.0000000070958645d, 2.123e-300d, -0.0069813123d, -3.385849e-05d) + // ensure double precision + .hasToString("[-7.0958645E-9, 2.123E-300, -0.0069813123, -3.385849E-5]"); + + verify( + postRequestedFor(urlPathEqualTo("/embeddings")) + .withRequestBody( + equalToJson(""" + {"input":["Hello World"]}"""))); + } } @Test - void testEmbedding() throws IOException { - final String response = - new String( - getClass() - .getClassLoader() - .getResourceAsStream("embeddingResponse.json") - .readAllBytes()); - stubFor(post(anyUrl()).willReturn(okJson(response))); - - final var request = new OpenAiEmbeddingParameters().setInput("Hello World"); - final var result = client.embedding(request); - - assertThat(result).isNotNull(); - assertThat(result.getModel()).isEqualTo("ada"); - assertThat(result.getObject()).isEqualTo("list"); - - assertThat(result.getUsage()).isNotNull(); - assertThat(result.getUsage().getCompletionTokens()).isNull(); - assertThat(result.getUsage().getPromptTokens()).isEqualTo(2); - assertThat(result.getUsage().getTotalTokens()).isEqualTo(2); - - assertThat(result.getData()).isNotNull().hasSize(1); - var embeddingData = result.getData().get(0); - assertThat(embeddingData).isNotNull(); - assertThat(embeddingData.getObject()).isEqualTo("embedding"); - assertThat(embeddingData.getIndex()).isEqualTo(0); - assertThat(embeddingData.getEmbedding()) - .isNotNull() - .isNotEmpty() - .containsExactly(-0.0000000070958645d, 2.123e-300d, -0.0069813123d, -3.385849e-05d) - .hasToString( - "[-7.0958645E-9, 2.123E-300, -0.0069813123, -3.385849E-5]"); // ensure double precision - - verify( - postRequestedFor(urlPathEqualTo("/embeddings")) - .withRequestBody( - equalToJson(""" - {"input":["Hello World"]}"""))); + void streamChatCompletionDeltasErrorHandling() throws IOException { + try (var inputStream = spy(TEST_FILE_LOADER.apply("streamChatCompletionError.txt"))) { + + final var httpClient = mock(HttpClient.class); + ApacheHttpClient5Accessor.setHttpClientFactory(destination -> httpClient); + + // Create a mock response + final var mockResponse = new BasicClassicHttpResponse(200, "OK"); + final var inputStreamEntity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN); + mockResponse.setEntity(inputStreamEntity); + mockResponse.setHeader("Content-Type", "text/event-stream"); + + // Configure the HttpClient mock to return the mock response + 
+      doReturn(mockResponse).when(httpClient).executeOpen(any(), any(), any());
+
+      final var request =
+          new OpenAiChatCompletionParameters()
+              .setMessages(
+                  List.of(
+                      new OpenAiChatUserMessage()
+                          .addText(
+                              "Can you give me the first 100 numbers of the Fibonacci sequence?")));
+
+      try (Stream stream = client.streamChatCompletionDeltas(request)) {
+        assertThatThrownBy(() -> stream.forEach(System.out::println))
+            .isInstanceOf(OpenAiClientException.class)
+            .hasMessage(
+                "Failed to parse response from OpenAI model and error message: 'exceeded token rate limit'");
+      }
+
+      Mockito.verify(inputStream, times(1)).close();
+    }
+  }
+
+  @Test
+  void streamChatCompletionDeltas() throws IOException {
+    try (var inputStream = spy(TEST_FILE_LOADER.apply("streamChatCompletion.txt"))) {
+
+      final var httpClient = mock(HttpClient.class);
+      ApacheHttpClient5Accessor.setHttpClientFactory(destination -> httpClient);
+
+      // Create a mock response
+      final var mockResponse = new BasicClassicHttpResponse(200, "OK");
+      final var inputStreamEntity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN);
+      mockResponse.setEntity(inputStreamEntity);
+      mockResponse.setHeader("Content-Type", "text/event-stream");
+
+      // Configure the HttpClient mock to return the mock response
+      doReturn(mockResponse).when(httpClient).executeOpen(any(), any(), any());
+
+      final var request =
+          new OpenAiChatCompletionParameters()
+              .setMessages(
+                  List.of(
+                      new OpenAiChatUserMessage()
+                          .addText(
+                              "Can you give me the first 100 numbers of the Fibonacci sequence?")));
+
+      try (Stream stream = client.streamChatCompletionDeltas(request)) {
+
+        OpenAiChatCompletionOutput totalOutput = new OpenAiChatCompletionOutput();
+        final List deltaList =
+            stream.peek(totalOutput::addDelta).toList();
+
+        assertThat(deltaList).hasSize(5);
+        // the first two and the last delta don't have any content
+        assertThat(deltaList.get(0).getDeltaContent()).isEqualTo("");
+        assertThat(deltaList.get(1).getDeltaContent()).isEqualTo("");
+        assertThat(deltaList.get(2).getDeltaContent()).isEqualTo("Sure");
+        assertThat(deltaList.get(3).getDeltaContent()).isEqualTo("!");
+        assertThat(deltaList.get(4).getDeltaContent()).isEqualTo("");
+
+        assertThat(deltaList.get(0).getSystemFingerprint()).isNull();
+        assertThat(deltaList.get(1).getSystemFingerprint()).isEqualTo("fp_e49e4201a9");
+        assertThat(deltaList.get(2).getSystemFingerprint()).isEqualTo("fp_e49e4201a9");
+        assertThat(deltaList.get(3).getSystemFingerprint()).isEqualTo("fp_e49e4201a9");
+        assertThat(deltaList.get(4).getSystemFingerprint()).isEqualTo("fp_e49e4201a9");
+
+        assertThat(deltaList.get(0).getUsage()).isNull();
+        assertThat(deltaList.get(1).getUsage()).isNull();
+        assertThat(deltaList.get(2).getUsage()).isNull();
+        assertThat(deltaList.get(3).getUsage()).isNull();
+        final var usage = deltaList.get(4).getUsage();
+        assertThat(usage).isNotNull();
+        assertThat(usage.getCompletionTokens()).isEqualTo(607);
+        assertThat(usage.getPromptTokens()).isEqualTo(21);
+        assertThat(usage.getTotalTokens()).isEqualTo(628);
+
+        assertThat(deltaList.get(0).getChoices()).isEmpty();
+        assertThat(deltaList.get(1).getChoices()).hasSize(1);
+        assertThat(deltaList.get(2).getChoices()).hasSize(1);
+        assertThat(deltaList.get(3).getChoices()).hasSize(1);
+        assertThat(deltaList.get(4).getChoices()).hasSize(1);
+
+        final var delta0 = deltaList.get(0);
+        assertThat(delta0.getId()).isEqualTo("");
+        assertThat(delta0.getCreated()).isEqualTo(0);
+        assertThat(delta0.getModel()).isEqualTo("");
+        assertThat(delta0.getObject()).isEqualTo("");
+        assertThat(delta0.getUsage()).isNull();
+        assertThat(delta0.getChoices()).isEmpty();
+        // prompt filter results are only present in delta 0
+        assertThat(delta0.getPromptFilterResults()).isNotNull();
+        assertThat(delta0.getPromptFilterResults().get(0).getPromptIndex()).isEqualTo(0);
+        final var promptFilter0 = delta0.getPromptFilterResults().get(0).getContentFilterResults();
+        assertThat(promptFilter0).isNotNull();
+        assertFilter(promptFilter0);
+
+        final var delta2 = deltaList.get(2);
+        assertThat(delta2.getId()).isEqualTo("chatcmpl-A16EvnkgEm6AdxY0NoOmGPjsJucQ1");
+        assertThat(delta2.getCreated()).isEqualTo(1724825677);
+        assertThat(delta2.getModel()).isEqualTo("gpt-35-turbo");
+        assertThat(delta2.getObject()).isEqualTo("chat.completion.chunk");
+        assertThat(delta2.getUsage()).isNull();
+        assertThat(delta2.getPromptFilterResults()).isNull();
+        final var choices2 = delta2.getChoices().get(0);
+        assertThat(choices2.getIndex()).isEqualTo(0);
+        assertThat(choices2.getFinishReason()).isNull();
+        assertThat(choices2.getMessage()).isNotNull();
+        // the role is only defined in delta 1, but it defaults to "assistant" for all deltas
+        assertThat(choices2.getMessage().getRole()).isEqualTo("assistant");
+        assertThat(choices2.getMessage().getContent()).isEqualTo("Sure");
+        assertThat(choices2.getMessage().getTool_calls()).isNull();
+        final var filter2 = choices2.getContentFilterResults();
+        assertFilter(filter2);
+
+        final var delta3 = deltaList.get(3);
+        assertThat(delta3.getDeltaContent()).isEqualTo("!");
+
+        final var delta4Choice = deltaList.get(4).getChoices().get(0);
+        assertThat(delta4Choice.getFinishReason()).isEqualTo("stop");
+        assertThat(delta4Choice.getMessage()).isNotNull();
+        // the role is only defined in delta 1, but it defaults to "assistant" for all deltas
+        assertThat(delta4Choice.getMessage().getRole()).isEqualTo("assistant");
+        assertThat(delta4Choice.getMessage().getContent()).isNull();
+        assertThat(delta4Choice.getMessage().getTool_calls()).isNull();
+
+        assertThat(totalOutput.getChoices()).hasSize(1);
+        final var choice = totalOutput.getChoices().get(0);
+        assertThat(choice.getFinishReason()).isEqualTo("stop");
+        assertFilter(choice.getContentFilterResults());
+        assertThat(choice.getMessage()).isNotNull();
+        assertThat(choice.getMessage().getRole()).isEqualTo("assistant");
+        assertThat(choice.getMessage().getContent()).isEqualTo("Sure!");
+        assertThat(choice.getMessage().getTool_calls()).isNull();
+        assertThat(totalOutput.getId()).isEqualTo("chatcmpl-A16EvnkgEm6AdxY0NoOmGPjsJucQ1");
+        assertThat(totalOutput.getCreated()).isEqualTo(1724825677);
+        assertThat(totalOutput.getModel()).isEqualTo("gpt-35-turbo");
+        assertThat(totalOutput.getObject()).isEqualTo("chat.completion.chunk");
+        final var totalUsage = totalOutput.getUsage();
+        assertThat(totalUsage).isNotNull();
+        assertThat(totalUsage.getCompletionTokens()).isEqualTo(607);
+        assertThat(totalUsage.getPromptTokens()).isEqualTo(21);
+        assertThat(totalUsage.getTotalTokens()).isEqualTo(628);
+        assertThat(totalOutput.getSystemFingerprint()).isEqualTo("fp_e49e4201a9");
+        assertThat(totalOutput.getPromptFilterResults()).isNotNull();
+        assertFilter(totalOutput.getPromptFilterResults().get(0).getContentFilterResults());
+      }
+
+      Mockito.verify(inputStream, times(1)).close();
+    }
+  }
+
+  void assertFilter(OpenAiContentFilterPromptResults filter) {
+    assertThat(filter).isNotNull();
+    assertThat(filter.getHate()).isNotNull();
+    assertThat(filter.getHate().isFiltered()).isFalse();
+    assertThat(filter.getHate().getSeverity()).isEqualTo(SAFE);
+    assertThat(filter.getSelfHarm()).isNotNull();
+    assertThat(filter.getSelfHarm().isFiltered()).isFalse();
+    assertThat(filter.getSelfHarm().getSeverity()).isEqualTo(SAFE);
+    assertThat(filter.getSexual()).isNotNull();
+    assertThat(filter.getSexual().isFiltered()).isFalse();
+    assertThat(filter.getSexual().getSeverity()).isEqualTo(SAFE);
+    assertThat(filter.getViolence()).isNotNull();
+    assertThat(filter.getViolence().isFiltered()).isFalse();
+    assertThat(filter.getViolence().getSeverity()).isEqualTo(SAFE);
+    assertThat(filter.getJailbreak()).isNull();
+    assertThat(filter.getProfanity()).isNull();
+    assertThat(filter.getError()).isNull();
   }
 }
diff --git a/foundation-models/openai/src/test/resources/chatCompletionResponse.json b/foundation-models/openai/src/test/resources/chatCompletionResponse.json
index 2bbeb89c..056576d4 100644
--- a/foundation-models/openai/src/test/resources/chatCompletionResponse.json
+++ b/foundation-models/openai/src/test/resources/chatCompletionResponse.json
@@ -63,5 +63,5 @@
       }
     }
   ],
-  "system_fingerprint": null
-}
\ No newline at end of file
+  "system_fingerprint": "fp_e49e4201a9"
+}
diff --git a/foundation-models/openai/src/test/resources/streamChatCompletion.txt b/foundation-models/openai/src/test/resources/streamChatCompletion.txt
new file mode 100644
index 00000000..6a4f849c
--- /dev/null
+++ b/foundation-models/openai/src/test/resources/streamChatCompletion.txt
@@ -0,0 +1,6 @@
+data: {"choices":[],"created":0,"id":"","model":"","object":"","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}]}
+data: {"choices":[{"content_filter_results":{},"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0}],"created":1724825677,"id":"chatcmpl-A16EvnkgEm6AdxY0NoOmGPjsJucQ1","model":"gpt-35-turbo","object":"chat.completion.chunk","system_fingerprint":"fp_e49e4201a9","usage":null}
+data: {"choices":[{"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}},"delta":{"content":"Sure"},"finish_reason":null,"index":0}],"created":1724825677,"id":"chatcmpl-A16EvnkgEm6AdxY0NoOmGPjsJucQ1","model":"gpt-35-turbo","object":"chat.completion.chunk","system_fingerprint":"fp_e49e4201a9","usage":null}
+data: {"choices":[{"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}},"delta":{"content":"!"},"finish_reason":null,"index":0}],"created":1724825677,"id":"chatcmpl-A16EvnkgEm6AdxY0NoOmGPjsJucQ1","model":"gpt-35-turbo","object":"chat.completion.chunk","system_fingerprint":"fp_e49e4201a9","usage":null}
+data: {"choices":[{"content_filter_results":{},"delta":{},"finish_reason":"stop","index":0}],"created":1724825677,"id":"chatcmpl-A16EvnkgEm6AdxY0NoOmGPjsJucQ1","model":"gpt-35-turbo","object":"chat.completion.chunk","system_fingerprint":"fp_e49e4201a9","usage":{"completion_tokens":607,"prompt_tokens":21,"total_tokens":628}}
+data: [DONE]
diff --git a/foundation-models/openai/src/test/resources/streamChatCompletionError.txt b/foundation-models/openai/src/test/resources/streamChatCompletionError.txt
new file mode 100644
index 00000000..174d47c0
--- /dev/null
+++ b/foundation-models/openai/src/test/resources/streamChatCompletionError.txt
@@ -0,0 +1,2 @@
+data: {"choices":[],"created":0,"id":"","model":"","object":"","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}]}
+{"error":{"code":"429","message":"exceeded token rate limit"}}
diff --git a/pom.xml b/pom.xml
index 9bfa9e94..6fd85817 100644
--- a/pom.xml
+++ b/pom.xml
@@ -57,6 +57,7 @@
     3.4.0
     2.1.3
     6.1.12
+    5.12.0
     false
     false
@@ -91,6 +92,11 @@
       ${junit-jupiter.version}
       test
+
+      org.junit.jupiter
+      junit-jupiter-params
+      ${junit-jupiter.version}
+
       org.wiremock
       wiremock
@@ -103,6 +109,12 @@
       ${assertj-core.version}
       test
+
+      org.mockito
+      mockito-core
+      ${mockito.version}
+      test
+
       com.sap.ai.sdk