diff --git a/e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java b/e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java index cb97d700..88e82cd0 100644 --- a/e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java +++ b/e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java @@ -66,42 +66,44 @@ public static ResponseEntity streamChatCompletion() { .addText( "Can you give me the first 100 numbers of the Fibonacci sequence?"))); - var emitter = new ResponseBodyEmitter(); - var stream = OpenAiClient.forModel(GPT_35_TURBO).stream(request); + final var emitter = new ResponseBodyEmitter(); // Cloud SDK's ThreadContext is vital for the request to successfully execute. ThreadContextExecutors.getExecutor() .submit( () -> { - stream - .getDeltaStream() - .map(OpenAiDeltaChatCompletion::getDeltaContent) - // The first two and the last delta do not contain any message content - .filter(Objects::nonNull) - .forEach(content -> send(emitter, content)); - - String indentedJson = objectToJson(stream.getTotalOutput()); - send(emitter, "\n\n-----Total Output-----\n\n" + indentedJson); - emitter.complete(); - stream.close(); + // try-with-resources ensures that the stream is closed after the response is sent. + try (final var result = OpenAiClient.forModel(GPT_35_TURBO).stream(request)) { + result + .getDeltaStream() + .map(OpenAiDeltaChatCompletion::getDeltaContent) + // The first two and the last deltaStream do not contain any message content + .filter(Objects::nonNull) + .forEach(content -> send(emitter, content)); + + final String indentedJson = objectToJson(result.getTotalOutput()); + send(emitter, "\n\n-----Total Output-----\n\n" + indentedJson); + emitter.complete(); + } }); // MediaType.TEXT_EVENT_STREAM allows the browser to display the content as it is streamed return ResponseEntity.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(emitter); } - private static void send(ResponseBodyEmitter emitter, String chunk) { + private static void send( + @Nonnull final ResponseBodyEmitter emitter, @Nonnull final String chunk) { try { emitter.send(chunk); - } catch (IOException e) { + } catch (final IOException e) { log.error(Arrays.toString(e.getStackTrace())); emitter.completeWithError(e); } } - private static String objectToJson(Object obj) { + private static String objectToJson(@Nonnull final Object obj) { try { return new ObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(obj); - } catch (JsonProcessingException ignored) { + } catch (final JsonProcessingException ignored) { return "Could not parse object to JSON"; } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiResponseHandler.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiResponseHandler.java index 6ceb5859..0946c405 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiResponseHandler.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiResponseHandler.java @@ -103,7 +103,8 @@ static void buildExceptionAndThrow(@Nonnull final ClassicHttpResponse response) * @param errorResponse the error response, most likely a JSON of {@link OpenAiError}. * @param baseException a base exception to add the error message to. */ - static void parseErrorAndThrow(String errorResponse, OpenAiClientException baseException) + static void parseErrorAndThrow( + @Nonnull final String errorResponse, @Nonnull final OpenAiClientException baseException) throws OpenAiClientException { final var maybeError = Try.of(() -> JACKSON.readValue(errorResponse, OpenAiError.class)); if (maybeError.isFailure()) { diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiStream.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiStream.java index 1b43f4c2..aa684f14 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiStream.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiStream.java @@ -25,14 +25,21 @@ public class OpenAiStream> implements AutoCloseable { - @Getter @Nonnull private Stream deltaStream; + @Getter(onMethod_ = @Nonnull) + private Stream deltaStream; + @Nonnull private T totalOutput; void addDelta(D delta) { totalOutput.addDelta(delta); } - /** Get the total aggregated output. */ + /** + * Get the total aggregated output from all deltas. Closes the delta stream. + * + * @return the total output until now. + */ + @Nonnull public T getTotalOutput() { deltaStream.close(); return totalOutput; diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiStreamingHandler.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiStreamingHandler.java index 117401a3..3b227f5a 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiStreamingHandler.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiStreamingHandler.java @@ -47,13 +47,15 @@ public OpenAiStream handleResponse(@Nonnull final ClassicHttpResponse resp * @return A {@link OpenAiStream} of a model class instantiated from the response * @author stippi */ + // The stream is closed by the user of the OpenAiStream + @SuppressWarnings("PMD.CloseResource") private OpenAiStream parseResponse(@Nonnull final ClassicHttpResponse response) throws OpenAiClientException { final HttpEntity responseEntity = response.getEntity(); if (responseEntity == null) { throw new OpenAiClientException("Response from OpenAI model was empty."); } - InputStream inputStream; + final InputStream inputStream; try { inputStream = responseEntity.getContent(); } catch (IOException e) { @@ -61,7 +63,7 @@ private OpenAiStream parseResponse(@Nonnull final ClassicHttpResponse resp } final var br = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)); - OpenAiStream output = new OpenAiStream(); + final OpenAiStream output = new OpenAiStream<>(); try { output.setTotalOutput(totalType.getDeclaredConstructor().newInstance()); } catch (InstantiationException @@ -72,7 +74,7 @@ private OpenAiStream parseResponse(@Nonnull final ClassicHttpResponse resp } // https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format - Stream deltaStream = + final Stream deltaStream = br.lines() .filter( responseLine -> @@ -88,11 +90,11 @@ private OpenAiStream parseResponse(@Nonnull final ClassicHttpResponse resp responseLine -> { String data = responseLine.substring(5).replace("delta", "message"); try { - D delta = JACKSON.readValue(data, deltaType); + final D delta = JACKSON.readValue(data, deltaType); output.addDelta(delta); return delta; - } catch (IOException e) { - throw new RuntimeException(e); + } catch (final IOException e) { + throw new OpenAiClientException("Failed to parse delta message: " + data, e); } }); return output.setDeltaStream(deltaStream); diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/DeltaAggregatable.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/DeltaAggregatable.java index 8159bd9c..58ac00ea 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/DeltaAggregatable.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/DeltaAggregatable.java @@ -1,5 +1,7 @@ package com.sap.ai.sdk.foundationmodels.openai.model; +import javax.annotation.Nonnull; + /** * Interface for model classes that can be created from aggregated streamed deltas. * @@ -14,5 +16,5 @@ public interface DeltaAggregatable { * * @param delta the delta to add. */ - void addDelta(D delta); + void addDelta(@Nonnull final D delta); } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionChoice.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionChoice.java index 2c1f125a..91a6ff65 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionChoice.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionChoice.java @@ -21,7 +21,7 @@ public class OpenAiChatCompletionChoice extends OpenAiCompletionChoice { @Setter(onMethod_ = @Nonnull, value = AccessLevel.PACKAGE) private OpenAiChatAssistantMessage message; - void addDelta(OpenAiDeltaChatCompletionChoice delta) { + void addDelta(@Nonnull final OpenAiDeltaChatCompletionChoice delta) { super.addDelta(delta); if (delta.getMessage() != null) { diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java index 165f3434..2469d0cf 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiChatCompletionOutput.java @@ -33,7 +33,7 @@ public class OpenAiChatCompletionOutput extends OpenAiCompletionOutput * * @param delta the delta to add. */ - public void addDelta(OpenAiDeltaChatCompletion delta) { + public void addDelta(@Nonnull final OpenAiDeltaChatCompletion delta) { super.addDelta(delta); if (delta.getSystemFingerprint() != null) { @@ -46,7 +46,7 @@ public void addDelta(OpenAiDeltaChatCompletion delta) { } // Multiple choices are spread out on multiple deltas // A delta only contains one choice with a variable index - int index = delta.getChoices().get(0).getIndex(); + final int index = delta.getChoices().get(0).getIndex(); for (int i = choices.size(); i < index + 1; i++) { choices.add(new OpenAiChatCompletionChoice()); } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionChoice.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionChoice.java index e0ca972e..2b2b4f3b 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionChoice.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionChoice.java @@ -1,6 +1,7 @@ package com.sap.ai.sdk.foundationmodels.openai.model; import com.fasterxml.jackson.annotation.JsonProperty; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -30,7 +31,7 @@ public class OpenAiCompletionChoice { @Getter(onMethod_ = @Nullable) private OpenAiContentFilterPromptResults contentFilterResults; - void addDelta(OpenAiCompletionChoice delta) { + void addDelta(@Nonnull final OpenAiCompletionChoice delta) { if (delta.getFinishReason() != null) { finishReason = delta.getFinishReason(); diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionOutput.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionOutput.java index 5cca1543..35e401d9 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionOutput.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiCompletionOutput.java @@ -44,7 +44,7 @@ public class OpenAiCompletionOutput { @Getter(onMethod_ = @Nullable) private List promptFilterResults; - void addDelta(OpenAiDeltaChatCompletion delta) { + void addDelta(@Nonnull final OpenAiDeltaChatCompletion delta) { created = delta.getCreated(); id = delta.getId(); model = delta.getModel(); @@ -57,11 +57,10 @@ void addDelta(OpenAiDeltaChatCompletion delta) { usage.addDelta(delta.getUsage()); } - if (delta.getPromptFilterResults() != null) { - if (promptFilterResults == null) { - promptFilterResults = delta.getPromptFilterResults(); - } - // prompt_filter_results is only present once in the first delta + if (delta.getPromptFilterResults() != null && promptFilterResults == null) { + promptFilterResults = delta.getPromptFilterResults(); + // prompt_filter_results is overriden instead of updated because it is only present once in + // the first delta } } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterPromptResults.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterPromptResults.java index 57e5a79a..342914cb 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterPromptResults.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterPromptResults.java @@ -1,6 +1,7 @@ package com.sap.ai.sdk.foundationmodels.openai.model; import com.fasterxml.jackson.annotation.JsonProperty; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -16,7 +17,7 @@ public class OpenAiContentFilterPromptResults extends OpenAiContentFilterResults @Getter(onMethod_ = @Nullable) private OpenAiContentFilterDetectedResult jailbreak; - void addDelta(OpenAiContentFilterPromptResults delta) { + void addDelta(@Nonnull final OpenAiContentFilterPromptResults delta) { super.addDelta(delta); if (delta.getJailbreak() != null) { diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterResultsBase.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterResultsBase.java index 99b1e70a..2c454db3 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterResultsBase.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiContentFilterResultsBase.java @@ -1,6 +1,7 @@ package com.sap.ai.sdk.foundationmodels.openai.model; import com.fasterxml.jackson.annotation.JsonProperty; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -41,13 +42,9 @@ public class OpenAiContentFilterResultsBase { @Getter(onMethod_ = @Nullable) private OpenAiErrorBase error; - void addDelta(OpenAiContentFilterPromptResults delta) { + void addDelta(@Nonnull final OpenAiContentFilterPromptResults delta) { if (delta.getSexual() != null) { sexual = delta.getSexual(); - System.out.println(sexual.getSeverity()); - System.out.println(sexual.isFiltered()); - } else { - System.out.println("Sexual is null"); } if (delta.getViolence() != null) { violence = delta.getViolence(); diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiUsage.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiUsage.java index 65bf22eb..7691caed 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiUsage.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/model/OpenAiUsage.java @@ -28,7 +28,7 @@ public class OpenAiUsage { @Getter(onMethod_ = @Nonnull) private Integer totalTokens; - void addDelta(OpenAiUsage delta) { + void addDelta(@Nonnull final OpenAiUsage delta) { if (delta.getCompletionTokens() != null) { completionTokens = delta.getCompletionTokens(); }