
OpenAI Streaming #25

Merged · 62 commits · Sep 4, 2024
Commits
d319aa9
OpenAI streaming
CharlesDuboisSAP Aug 16, 2024
69ae7eb
Added homepage and error handling todo
CharlesDuboisSAP Aug 16, 2024
7870e6d
Renamed vars
CharlesDuboisSAP Aug 19, 2024
652ec1e
Added todos
CharlesDuboisSAP Aug 19, 2024
727b3d4
Made stream generic, try-with resources, TEXT_EVENT_STREAM, exception…
CharlesDuboisSAP Aug 21, 2024
b3190a5
Formatting
bot-sdk-js Aug 21, 2024
f0fa3e6
close stream correctly
CharlesDuboisSAP Aug 21, 2024
09ca6ea
Formatting
bot-sdk-js Aug 21, 2024
d86243a
Created OpenAiStreamOutput
CharlesDuboisSAP Aug 21, 2024
2a4ce7b
Merge remote-tracking branch 'origin/streaming' into streaming
CharlesDuboisSAP Aug 21, 2024
cf6ec46
Formatting
bot-sdk-js Aug 21, 2024
a73f037
Renamed stream to streamChatCompletion, Added comments
CharlesDuboisSAP Aug 22, 2024
eb3f24a
Added total output
CharlesDuboisSAP Aug 23, 2024
fb2cdaf
Total output is printed
CharlesDuboisSAP Aug 23, 2024
fe078c7
Formatting
bot-sdk-js Aug 23, 2024
09e1be0
addDelta is propagated everywhere
CharlesDuboisSAP Aug 23, 2024
42ae946
addDelta is propagated everywhere
CharlesDuboisSAP Aug 23, 2024
e6e009a
forgotten addDeltas
CharlesDuboisSAP Aug 23, 2024
bee8fdc
Added jackson dependencies
CharlesDuboisSAP Aug 23, 2024
5f03c6f
Added Javadoc
CharlesDuboisSAP Aug 23, 2024
e79ca8e
Removed 1 TODO
CharlesDuboisSAP Aug 23, 2024
ba2c5e0
PMD
CharlesDuboisSAP Aug 27, 2024
c10eecb
PMD again
CharlesDuboisSAP Aug 27, 2024
cdae1c6
Merge branch 'refs/heads/main' into streaming
CharlesDuboisSAP Aug 27, 2024
faa3b70
Merge branch 'refs/heads/main' into streaming
CharlesDuboisSAP Aug 27, 2024
0e1a167
Added OpenAiClientTest.streamChatCompletion()
CharlesDuboisSAP Aug 28, 2024
31dbd52
Change return type of stream, added e2e test
CharlesDuboisSAP Aug 29, 2024
de7e7f0
Added documentation
CharlesDuboisSAP Aug 29, 2024
349936f
Added documentation framework-agnostic + throw if finish reason is in…
CharlesDuboisSAP Aug 29, 2024
58b0bc9
Merge branch 'refs/heads/main' into streaming
CharlesDuboisSAP Aug 30, 2024
3366c2e
Added error handling test
CharlesDuboisSAP Aug 30, 2024
c709d31
Updates from pair review / discussion
MatKuhr Aug 30, 2024
73031d1
Cleanup + streamChatCompletion doesn't throw
CharlesDuboisSAP Sep 2, 2024
6b1bfd0
PMD
CharlesDuboisSAP Sep 2, 2024
23474ba
Added errorHandling test
CharlesDuboisSAP Sep 2, 2024
769cd7d
Apply suggestions from code review
CharlesDuboisSAP Sep 3, 2024
118dc69
Dependency analyze
CharlesDuboisSAP Sep 3, 2024
acd21c0
Review comments
CharlesDuboisSAP Sep 3, 2024
28268b2
Make client static
CharlesDuboisSAP Sep 3, 2024
9a9a44b
Formatting
bot-sdk-js Sep 3, 2024
788db03
PMD
CharlesDuboisSAP Sep 3, 2024
0616f55
Fix tests
CharlesDuboisSAP Sep 3, 2024
3446bf0
Removed exception constructors no args
CharlesDuboisSAP Sep 3, 2024
45a20c6
Refactor exception message
CharlesDuboisSAP Sep 3, 2024
f843061
Readme sentences
CharlesDuboisSAP Sep 3, 2024
5edcf71
Remove superfluous call super
CharlesDuboisSAP Sep 3, 2024
7474fb1
reset httpclient-cache and -factory after each test case
newtork Sep 3, 2024
ac6f36c
Very minor code-style improvements in test
newtork Sep 3, 2024
ffa369a
Minor code-style in OpenAIController
newtork Sep 3, 2024
6cfeee9
Reduce README sample code
newtork Sep 3, 2024
6d4fd2f
Update OpenAiStreamingHandler.java (#43)
newtork Sep 3, 2024
a6c566a
Fix import
newtork Sep 3, 2024
f6a4fe6
Added stream_options to model
CharlesDuboisSAP Sep 4, 2024
05dedf9
Change Executor#submit() to #execute()
newtork Sep 4, 2024
2604969
Merge branch 'streaming' of https://github.com/SAP/ai-sdk-java into s…
newtork Sep 4, 2024
9a3bf2f
Merge remote-tracking branch 'origin/main' into streaming
newtork Sep 4, 2024
a0ae779
Added usage testing
CharlesDuboisSAP Sep 4, 2024
2c934f7
Added beautiful Javadoc to enableStreaming
CharlesDuboisSAP Sep 4, 2024
77eb464
typo
CharlesDuboisSAP Sep 4, 2024
488f060
Fix mistake
CharlesDuboisSAP Sep 4, 2024
5e4ff73
Merge branch 'main' into streaming
newtork Sep 4, 2024
59e1382
streaming readme (#48)
newtork Sep 4, 2024
67 changes: 66 additions & 1 deletion README.md
@@ -197,13 +197,78 @@ final String resultMessage = result.getChoices().get(0).getMessage().getContent();

See [an example in our Spring Boot application](e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java)

### Chat completion with a model not in defined `OpenAiModel`
### Chat completion with a model not defined in `OpenAiModel`

```java
final OpenAiChatCompletionOutput result =
OpenAiClient.forModel(new OpenAiModel("model")).chatCompletion(request);
```

### Stream chat completion

It's possible to consume chat completion responses as a stream of delta elements, e.g. to forward them from the application backend to the frontend in real time.

#### Stream the chat completion
The simplest variant blocks while streaming, printing each delta directly to the console:
```java
String msg = "Can you give me the first 100 numbers of the Fibonacci sequence?";

OpenAiChatCompletionParameters request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText(msg)));

OpenAiClient client = OpenAiClient.forModel(GPT_35_TURBO);

// try-with-resources on stream ensures the connection will be closed
try (Stream<String> stream = client.streamChatCompletion(request)) {
stream.forEach(deltaString -> {
System.out.print(deltaString);
System.out.flush();
});
}
```

<details>
<summary>It's also possible to aggregate the total output.</summary>

The following example is non-blocking.
Any asynchronous library can be used, e.g. classic Thread API.

```java
String msg = "Can you give me the first 100 numbers of the Fibonacci sequence?";

OpenAiChatCompletionParameters request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText(msg)));

OpenAiChatCompletionOutput totalOutput = new OpenAiChatCompletionOutput();
OpenAiClient client = OpenAiClient.forModel(GPT_35_TURBO);

// Do the request before the thread starts to handle exceptions during request initialization
Stream<OpenAiChatCompletionDelta> stream = client.streamChatCompletionDeltas(request);

Thread thread = new Thread(() -> {
// try-with-resources ensures the stream is closed
try (stream) {
    stream.peek(totalOutput::addDelta).forEach(System.out::println);
}
});
thread.start(); // non-blocking

thread.join(); // blocking

// access aggregated information from total output, e.g.
Integer tokens = totalOutput.getUsage().getCompletionTokens();
System.out.println("Tokens: " + tokens);
```

</details>

#### Spring Boot example

Please find [an example in our Spring Boot application](e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java).
It shows the usage of Spring Boot's `ResponseBodyEmitter` to stream the chat completion delta messages to the frontend in real-time.
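The emitter pattern described above can be sketched framework-agnostically: a producer thread drains the delta stream into a queue while the consumer forwards each chunk as it arrives, which is essentially what the controller does with `ResponseBodyEmitter`. This is an illustrative sketch using only the JDK (the class name, `aggregate` method, and sentinel value are made up for the example), not SDK code:

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.stream.Stream;

public class EmitterSketch {
  // Sentinel marking end-of-stream, since BlockingQueue rejects nulls
  private static final String POISON = "\u0000";

  /** Produce chunks on a worker thread, consume them on the caller's thread, return the aggregate. */
  static String aggregate(Stream<String> deltas) throws InterruptedException {
    final BlockingQueue<String> queue = new LinkedBlockingQueue<>();

    final Thread producer = new Thread(() -> {
      // try-with-resources closes the stream (in the SDK, this releases the HTTP connection)
      try (deltas) {
        deltas.forEach(queue::add);
      } finally {
        queue.add(POISON);
      }
    });
    producer.start(); // non-blocking, like handing the Runnable to an executor

    // Consumer: drain chunks as they arrive, like the browser reading a text/event-stream response
    final StringBuilder total = new StringBuilder();
    for (String chunk = queue.take(); !POISON.equals(chunk); chunk = queue.take()) {
      total.append(chunk); // a real controller would call emitter.send(chunk) here
    }
    return total.toString();
  }

  public static void main(String[] args) throws InterruptedException {
    // Simulated delta stream; in the SDK this would come from streamChatCompletion(request)
    System.out.println(aggregate(Stream.of("Hel", "lo", ", ", "world", "!")));
  }
}
```

In the real controller, appending to the builder corresponds to `emitter.send(chunk)`, and the `finally` branch corresponds to `emitter.complete()`.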

## Orchestration chat completion

### Prerequisites
13 changes: 13 additions & 0 deletions e2e-test-app/pom.xml
@@ -87,6 +87,11 @@
<groupId>org.springframework</groupId>
<artifactId>spring-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-webmvc</artifactId>
<version>${springframework.version}</version>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
@@ -95,6 +100,14 @@
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
</dependency>
<!-- scope "runtime" -->
<dependency>
<groupId>ch.qos.logback</groupId>
e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java
@@ -5,6 +5,8 @@
import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.TEXT_EMBEDDING_ADA_002;
import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionTool.ToolType.FUNCTION;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionFunction;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
@@ -14,13 +16,21 @@
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage.ImageDetailLevel;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters;
import com.sap.cloud.sdk.cloudplatform.thread.ThreadContextExecutors;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;

/** Endpoints for OpenAI operations */
@Slf4j
@RestController
class OpenAiController {
/**
@@ -38,6 +48,98 @@ public static OpenAiChatCompletionOutput chatCompletion() {
return OpenAiClient.forModel(GPT_35_TURBO).chatCompletion(request);
}

/**
* Asynchronous stream of an OpenAI chat request
*
* @return the emitter that streams the assistant message response
*/
@SuppressWarnings("unused") // The end-to-end test doesn't use this method
@GetMapping("/streamChatCompletionDeltas")
@Nonnull
public static ResponseEntity<ResponseBodyEmitter> streamChatCompletionDeltas() {
final var msg = "Can you give me the first 100 numbers of the Fibonacci sequence?";
final var request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText(msg)));

final var stream = OpenAiClient.forModel(GPT_35_TURBO).streamChatCompletionDeltas(request);

final var emitter = new ResponseBodyEmitter();

final Runnable consumeStream =
() -> {
final var totalOutput = new OpenAiChatCompletionOutput();
// try-with-resources ensures the stream is closed
try (stream) {
stream
.peek(totalOutput::addDelta)
.forEach(delta -> send(emitter, delta.getDeltaContent()));
} finally {
send(emitter, "\n\n-----Total Output-----\n\n" + objectToJson(totalOutput));
newtork (Contributor) commented on Sep 3, 2024:
(Comment/Minor) It could be worthwhile to not just dump the object to string, but instead to extract the relevant information and write it piece by piece. That would also demonstrate the actual benefit of the feature. Right now the interested reader would wonder: "what does total output even mean? what does it contain?"

CharlesDuboisSAP (Author) replied:
There are potentially 30 pieces of relevant information. Which one in particular? You could also say the regular OpenAiClient.chatCompletion() is confusing because it returns the total output. Maybe you have a name suggestion for this var?

emitter.complete();
}
};

ThreadContextExecutors.getExecutor().execute(consumeStream);

// TEXT_EVENT_STREAM allows the browser to display the content as it is streamed
return ResponseEntity.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(emitter);
}

private static String objectToJson(@Nonnull final Object obj) {
try {
return new ObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(obj);
} catch (final JsonProcessingException ignored) {
      return "Could not serialize object to JSON";
}
}

/**
* Asynchronous stream of an OpenAI chat request
*
* @return the emitter that streams the assistant message response
*/
@SuppressWarnings("unused") // The end-to-end test doesn't use this method
@GetMapping("/streamChatCompletion")
@Nonnull
public static ResponseEntity<ResponseBodyEmitter> streamChatCompletion() {
final var request =
new OpenAiChatCompletionParameters()
.setMessages(
List.of(
new OpenAiChatUserMessage()
.addText(
"Can you give me the first 100 numbers of the Fibonacci sequence?")));

final var stream = OpenAiClient.forModel(GPT_35_TURBO).streamChatCompletion(request);

final var emitter = new ResponseBodyEmitter();

final Runnable consumeStream =
() -> {
try (stream) {
stream.forEach(deltaMessage -> send(emitter, deltaMessage));
} finally {
emitter.complete();
}
};

ThreadContextExecutors.getExecutor().execute(consumeStream);

// TEXT_EVENT_STREAM allows the browser to display the content as it is streamed
return ResponseEntity.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(emitter);
}

private static void send(
@Nonnull final ResponseBodyEmitter emitter, @Nonnull final String chunk) {
try {
emitter.send(chunk);
} catch (final IOException e) {
log.error(Arrays.toString(e.getStackTrace()));
emitter.completeWithError(e);
}
}

/**
* Chat request to OpenAI with an image
*
2 changes: 2 additions & 0 deletions e2e-test-app/src/main/resources/static/index.html
@@ -71,6 +71,8 @@ <h2>Endpoints</h2>
<li><h4>OpenAI</h4></li>
<ul>
<li><a href="/chatCompletion">/chatCompletion</a></li>
<li><a href="/streamChatCompletion">/streamChatCompletion</a></li>
<li><a href="/streamChatCompletionDeltas">/streamChatCompletionDeltas</a></li>
<li><a href="/chatCompletionTool">/chatCompletionTool</a></li>
<li><a href="/chatCompletionImage">/chatCompletionImage</a></li>
<li><a href="/embedding">/embedding</a></li>
OpenAiTest.java
@@ -1,9 +1,18 @@
package com.sap.ai.sdk.app.controllers;

import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.GPT_35_TURBO;
import static org.assertj.core.api.Assertions.assertThat;

import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;

@Slf4j
class OpenAiTest {
@Test
void chatCompletion() {
@@ -23,12 +32,44 @@ void chatCompletionImage() {
assertThat(message.getContent()).isNotEmpty();
}

@Test
void streamChatCompletion() {
final var request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText("Who is the prettiest?")));
MatKuhr (Member) commented:
(minor) Let's try also setting stream_options and assert on the token usage in the total.

CharlesDuboisSAP (Author) replied:
This is not in the Azure spec.

MatKuhr (Member) replied on Sep 3, 2024:
But as far as I can tell from my testing, it works anyway 😉 maybe they just haven't updated it yet.


final var totalOutput = new OpenAiChatCompletionOutput();
final var emptyDeltaCount = new AtomicInteger(0);
OpenAiClient.forModel(GPT_35_TURBO)
.streamChatCompletionDeltas(request)
.peek(totalOutput::addDelta)
// foreach consumes all elements, closing the stream at the end
.forEach(
delta -> {
final String deltaContent = delta.getDeltaContent();
log.info("deltaContent: {}", deltaContent);
if (deltaContent.isEmpty()) {
emptyDeltaCount.incrementAndGet();
}
});

// the first two and the last delta don't have any content
// see OpenAiChatCompletionDelta#getDeltaContent
assertThat(emptyDeltaCount.get()).isLessThanOrEqualTo(3);

assertThat(totalOutput.getChoices()).isNotEmpty();
assertThat(totalOutput.getChoices().get(0).getMessage().getContent()).isNotEmpty();
assertThat(totalOutput.getPromptFilterResults()).isNotNull();
assertThat(totalOutput.getChoices().get(0).getContentFilterResults()).isNotNull();
}

@Test
void chatCompletionTools() {
final var completion = OpenAiController.chatCompletionTools();

final var message = completion.getChoices().get(0).getMessage();
assertThat(message.getRole()).isEqualTo("assistant");
assertThat(message.getTool_calls()).isNotNull();
assertThat(message.getTool_calls().get(0).getFunction().getName()).isEqualTo("fibonacci");
}

10 changes: 10 additions & 0 deletions foundation-models/openai/pom.xml
@@ -97,6 +97,11 @@
<artifactId>junit-jupiter-api</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.wiremock</groupId>
<artifactId>wiremock</artifactId>
@@ -107,5 +112,10 @@
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>