Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OpenAI Streaming #25

Merged
merged 62 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
d319aa9
OpenAI streaming
CharlesDuboisSAP Aug 16, 2024
69ae7eb
Added homepage and error handling todo
CharlesDuboisSAP Aug 16, 2024
7870e6d
Renamed vars
CharlesDuboisSAP Aug 19, 2024
652ec1e
Added todos
CharlesDuboisSAP Aug 19, 2024
727b3d4
Made stream generic, try-with resources, TEXT_EVENT_STREAM, exception…
CharlesDuboisSAP Aug 21, 2024
b3190a5
Formatting
bot-sdk-js Aug 21, 2024
f0fa3e6
close stream correctly
CharlesDuboisSAP Aug 21, 2024
09ca6ea
Formatting
bot-sdk-js Aug 21, 2024
d86243a
Created OpenAiStreamOutput
CharlesDuboisSAP Aug 21, 2024
2a4ce7b
Merge remote-tracking branch 'origin/streaming' into streaming
CharlesDuboisSAP Aug 21, 2024
cf6ec46
Formatting
bot-sdk-js Aug 21, 2024
a73f037
Renamed stream to streamChatCompletion, Added comments
CharlesDuboisSAP Aug 22, 2024
eb3f24a
Added total output
CharlesDuboisSAP Aug 23, 2024
fb2cdaf
Total output is printed
CharlesDuboisSAP Aug 23, 2024
fe078c7
Formatting
bot-sdk-js Aug 23, 2024
09e1be0
addDelta is propagated everywhere
CharlesDuboisSAP Aug 23, 2024
42ae946
addDelta is propagated everywhere
CharlesDuboisSAP Aug 23, 2024
e6e009a
forgotten addDeltas
CharlesDuboisSAP Aug 23, 2024
bee8fdc
Added jackson dependencies
CharlesDuboisSAP Aug 23, 2024
5f03c6f
Added Javadoc
CharlesDuboisSAP Aug 23, 2024
e79ca8e
Removed 1 TODO
CharlesDuboisSAP Aug 23, 2024
ba2c5e0
PMD
CharlesDuboisSAP Aug 27, 2024
c10eecb
PMD again
CharlesDuboisSAP Aug 27, 2024
cdae1c6
Merge branch 'refs/heads/main' into streaming
CharlesDuboisSAP Aug 27, 2024
faa3b70
Merge branch 'refs/heads/main' into streaming
CharlesDuboisSAP Aug 27, 2024
0e1a167
Added OpenAiClientTest.streamChatCompletion()
CharlesDuboisSAP Aug 28, 2024
31dbd52
Change return type of stream, added e2e test
CharlesDuboisSAP Aug 29, 2024
de7e7f0
Added documentation
CharlesDuboisSAP Aug 29, 2024
349936f
Added documentation framework-agnostic + throw if finish reason is in…
CharlesDuboisSAP Aug 29, 2024
58b0bc9
Merge branch 'refs/heads/main' into streaming
CharlesDuboisSAP Aug 30, 2024
3366c2e
Added error handling test
CharlesDuboisSAP Aug 30, 2024
c709d31
Updates from pair review / discussion
MatKuhr Aug 30, 2024
73031d1
Cleanup + streamChatCompletion doesn't throw
CharlesDuboisSAP Sep 2, 2024
6b1bfd0
PMD
CharlesDuboisSAP Sep 2, 2024
23474ba
Added errorHandling test
CharlesDuboisSAP Sep 2, 2024
769cd7d
Apply suggestions from code review
CharlesDuboisSAP Sep 3, 2024
118dc69
Dependency analyze
CharlesDuboisSAP Sep 3, 2024
acd21c0
Review comments
CharlesDuboisSAP Sep 3, 2024
28268b2
Make client static
CharlesDuboisSAP Sep 3, 2024
9a9a44b
Formatting
bot-sdk-js Sep 3, 2024
788db03
PMD
CharlesDuboisSAP Sep 3, 2024
0616f55
Fix tests
CharlesDuboisSAP Sep 3, 2024
3446bf0
Removed exception constructors no args
CharlesDuboisSAP Sep 3, 2024
45a20c6
Refactor exception message
CharlesDuboisSAP Sep 3, 2024
f843061
Readme sentences
CharlesDuboisSAP Sep 3, 2024
5edcf71
Remove superfluous call super
CharlesDuboisSAP Sep 3, 2024
7474fb1
reset httpclient-cache and -factory after each test case
newtork Sep 3, 2024
ac6f36c
Very minor code-style improvements in test
newtork Sep 3, 2024
ffa369a
Minor code-style in OpenAIController
newtork Sep 3, 2024
6cfeee9
Reduce README sample code
newtork Sep 3, 2024
6d4fd2f
Update OpenAiStreamingHandler.java (#43)
newtork Sep 3, 2024
a6c566a
Fix import
newtork Sep 3, 2024
f6a4fe6
Added stream_options to model
CharlesDuboisSAP Sep 4, 2024
05dedf9
Change Executor#submit() to #execute()
newtork Sep 4, 2024
2604969
Merge branch 'streaming' of https://github.com/SAP/ai-sdk-java into s…
newtork Sep 4, 2024
9a3bf2f
Merge remote-tracking branch 'origin/main' into streaming
newtork Sep 4, 2024
a0ae779
Added usage testing
CharlesDuboisSAP Sep 4, 2024
2c934f7
Added beautiful Javadoc to enableStreaming
CharlesDuboisSAP Sep 4, 2024
77eb464
typo
CharlesDuboisSAP Sep 4, 2024
488f060
Fix mistake
CharlesDuboisSAP Sep 4, 2024
5e4ff73
Merge branch 'main' into streaming
newtork Sep 4, 2024
59e1382
streaming readme (#48)
newtork Sep 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions e2e-test-app/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@
<groupId>org.springframework</groupId>
<artifactId>spring-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-webmvc</artifactId>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,21 @@
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage.ImageDetailLevel;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters;
import com.sap.cloud.sdk.cloudplatform.thread.ThreadContextExecutors;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;

/** Endpoints for OpenAI operations */
@Slf4j
@RestController
class OpenAiController {
/**
Expand All @@ -38,6 +46,47 @@ public static OpenAiChatCompletionOutput chatCompletion() {
return OpenAiClient.forModel(GPT_35_TURBO).chatCompletion(request);
}

/**
* Stream chat request to OpenAI
*
* @return the emitter that streams the assistant message response
*/
@GetMapping("/stream")
CharlesDuboisSAP marked this conversation as resolved.
Show resolved Hide resolved
@Nonnull
public static ResponseEntity<ResponseBodyEmitter> stream() {
final var request =
new OpenAiChatCompletionParameters()
.setMessages(
List.of(
new OpenAiChatUserMessage()
.addText(
"Can you give me the first 100 number of the Fibonacci sequence?")));

ResponseBodyEmitter emitter = new ResponseBodyEmitter();
// Start streaming the content asynchronously
ThreadContextExecutors.getExecutor()
.submit(
() -> {
try (var stream = OpenAiClient.forModel(GPT_35_TURBO).stream(request)) {
stream
.getDeltaStream()
.filter(delta -> delta.getDeltaContent() != null)
.forEach(
delta -> {
try {
emitter.send(delta.getDeltaContent());
} catch (IOException e) {
log.error(Arrays.toString(e.getStackTrace()));
emitter.completeWithError(e);
}
});
} finally {
emitter.complete();
}
});
return ResponseEntity.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(emitter);
CharlesDuboisSAP marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Chat request to OpenAI with an image
*
Expand Down
1 change: 1 addition & 0 deletions e2e-test-app/src/main/resources/static/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ <h2>Endpoints</h2>
<li><h4>OpenAI</h4></li>
<ul>
<li><a href="/chatCompletion">/chatCompletion</a></li>
<li><a href="/stream">/stream</a></li>
<li><a href="/chatCompletionTool">/chatCompletionTool</a></li>
<li><a href="/chatCompletionImage">/chatCompletionImage</a></li>
<li><a href="/embedding">/embedding</a></li>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ void chatCompletionImage() {
assertThat(message.getContent()).isNotEmpty();
}

@Test
void stream() {
final var emitter = OpenAiController.stream();
// TODO: assert on the emitter
}

@Test
void chatCompletionTools() {
final var completion = OpenAiController.chatCompletionTools();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
import com.sap.ai.sdk.core.Core;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionStream;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiStreamOutput;
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
import com.sap.cloud.sdk.cloudplatform.connectivity.Destination;
Expand All @@ -30,7 +32,7 @@
@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public final class OpenAiClient {
private static final String DEFAULT_API_VERSION = "2024-02-01";
private static final ObjectMapper JACKSON;
static final ObjectMapper JACKSON;

static {
JACKSON =
Expand Down Expand Up @@ -105,6 +107,20 @@ public OpenAiChatCompletionOutput chatCompletion(
return execute("/chat/completions", parameters, OpenAiChatCompletionOutput.class);
}

/**
* Generate a completion for the given prompt.
*
* @param parameters the prompt, including messages and other parameters.
* @return the completion output
* @throws OpenAiClientException if the request fails
*/
@Nonnull
public OpenAiChatCompletionStream<OpenAiStreamOutput> stream(
newtork marked this conversation as resolved.
Show resolved Hide resolved
@Nonnull final OpenAiChatCompletionParameters parameters) throws OpenAiClientException {
parameters.setStream(true);
return stream("/chat/completions", parameters, OpenAiStreamOutput.class);
}

/**
* Get a vector representation of a given input that can be easily consumed by machine learning
* models and algorithms.
Expand All @@ -129,6 +145,16 @@ private <T> T execute(
return executeRequest(request, responseType);
}

@Nonnull
private <T> OpenAiChatCompletionStream<T> stream(
@Nonnull final String path,
@Nonnull final Object payload,
@Nonnull final Class<T> responseType) {
final var request = new HttpPost(path);
CharlesDuboisSAP marked this conversation as resolved.
Show resolved Hide resolved
serializeAndSetHttpEntity(request, payload);
return streamRequest(request, responseType);
}

private static void serializeAndSetHttpEntity(
@Nonnull final BasicClassicHttpRequest request, @Nonnull final Object payload) {
try {
Expand All @@ -145,9 +171,22 @@ private <T> T executeRequest(
try {
@SuppressWarnings("UnstableApiUsage")
final var client = ApacheHttpClient5Accessor.getHttpClient(destination);
return client.execute(request, new OpenAiResponseHandler<>(responseType, JACKSON));
return client.execute(request, new OpenAiResponseHandler<>(responseType));
} catch (final IOException e) {
throw new OpenAiClientException(e);
}
}

@Nonnull
private <T> OpenAiChatCompletionStream<T> streamRequest(
final BasicClassicHttpRequest request, @Nonnull final Class<T> responseType) {
try {
@SuppressWarnings("UnstableApiUsage")
final var client = ApacheHttpClient5Accessor.getHttpClient(destination);
return new OpenAiStreamingHandler<>(responseType)
.handleResponse(client.executeOpen(null, request, null));
} catch (final IOException e) {
throw new OpenAiClientException("Request to OpenAI model failed.", e);
throw new OpenAiClientException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,23 @@

/** Generic exception for errors occurring when using OpenAI foundation models. */
public class OpenAiClientException extends RuntimeException {
static final String BASE_ERROR_MESSAGE = "Request to OpenAI model failed";
@Serial private static final long serialVersionUID = -7345541120979974432L;

/** Create a new exception with the base message: {@code Request to OpenAI model failed} */
public OpenAiClientException() {
CharlesDuboisSAP marked this conversation as resolved.
Show resolved Hide resolved
super(BASE_ERROR_MESSAGE);
}

/**
* Create a new exception with the base message: {@code Request to OpenAI model failed}
*
* @param e the cause
*/
public OpenAiClientException(@Nonnull final Exception e) {
super(BASE_ERROR_MESSAGE, e);
CharlesDuboisSAP marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Create a new exception with the given message.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package com.sap.ai.sdk.foundationmodels.openai;

import static com.sap.ai.sdk.foundationmodels.openai.OpenAiClient.JACKSON;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiError;
import io.vavr.control.Try;
import java.io.IOException;
Expand All @@ -21,8 +22,14 @@
class OpenAiResponseHandler<T> implements HttpClientResponseHandler<T> {

@Nonnull private final Class<T> responseType;
@Nonnull private final ObjectMapper jackson;

/**
* Processes a {@link ClassicHttpResponse} and returns some value corresponding to that response.
*
* @param response The response to process
* @return A model class instantiated from the response
* @throws OpenAiClientException in case of a problem or the connection was aborted
*/
@Override
public T handleResponse(@Nonnull final ClassicHttpResponse response)
throws OpenAiClientException {
Expand All @@ -35,14 +42,15 @@ public T handleResponse(@Nonnull final ClassicHttpResponse response)
// The InputStream of the HTTP entity is closed by EntityUtils.toString
@SuppressWarnings("PMD.CloseResource")
@Nonnull
private T parseResponse(@Nonnull final ClassicHttpResponse response) {
private T parseResponse(@Nonnull final ClassicHttpResponse response)
throws OpenAiClientException {
final HttpEntity responseEntity = response.getEntity();
if (responseEntity == null) {
throw new OpenAiClientException("Response from OpenAI model was empty.");
}
final var content = getContent(responseEntity);
try {
return jackson.readValue(content, responseType);
return JACKSON.readValue(content, responseType);
} catch (final JsonProcessingException e) {
log.error("Failed to parse the following response from OpenAI model: {}", content);
throw new OpenAiClientException("Failed to parse response from OpenAI model", e);
Expand All @@ -60,7 +68,8 @@ private static String getContent(@Nonnull final HttpEntity entity) {

// The InputStream of the HTTP entity is closed by EntityUtils.toString
@SuppressWarnings("PMD.CloseResource")
private void buildExceptionAndThrow(@Nonnull final ClassicHttpResponse response) {
static void buildExceptionAndThrow(@Nonnull final ClassicHttpResponse response)
throws OpenAiClientException {
final var exception =
new OpenAiClientException(
"Request to OpenAI model failed with status %s %s "
Expand All @@ -85,19 +94,28 @@ private void buildExceptionAndThrow(@Nonnull final ClassicHttpResponse response)
throw exception;
}

final var maybeError = Try.of(() -> jackson.readValue(content, OpenAiError.class));
parseErrorAndThrow(content, exception);
}

/**
* Parse the error response and throw an exception.
*
* @param errorResponse the error response, most likely a JSON of {@link OpenAiError}.
* @param baseException a base exception to add the error message to.
*/
static void parseErrorAndThrow(String errorResponse, OpenAiClientException baseException)
throws OpenAiClientException {
final var maybeError = Try.of(() -> JACKSON.readValue(errorResponse, OpenAiError.class));
if (maybeError.isFailure()) {
exception.addSuppressed(maybeError.getCause());
throw exception;
baseException.addSuppressed(maybeError.getCause());
throw baseException;
}

final var error = maybeError.get().getError();
if (error == null) {
throw exception;
throw baseException;
}
final var message =
"Request to OpenAI model failed with %s %s and error message: '%s'"
.formatted(response.getCode(), response.getReasonPhrase(), error.getMessage());
throw new OpenAiClientException(message);
throw new OpenAiClientException(
baseException.getMessage() + "and error message: '%s'".formatted(error.getMessage()));
newtork marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package com.sap.ai.sdk.foundationmodels.openai;

import static com.sap.ai.sdk.foundationmodels.openai.OpenAiClient.JACKSON;
import static com.sap.ai.sdk.foundationmodels.openai.OpenAiResponseHandler.buildExceptionAndThrow;
import static com.sap.ai.sdk.foundationmodels.openai.OpenAiResponseHandler.parseErrorAndThrow;

import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionStream;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.hc.core5.http.ClassicHttpResponse;

@Slf4j
@RequiredArgsConstructor
class OpenAiStreamingHandler<T> {

@Nonnull private final Class<T> responseType;

/**
* Processes a {@link ClassicHttpResponse} and returns some value corresponding to that response.
*
* @param response The response to process
* @return A {@link OpenAiChatCompletionStream} of a model class instantiated from the response
* @throws OpenAiClientException in case of a problem or the connection was aborted
*/
@Nonnull
public OpenAiChatCompletionStream<T> handleResponse(@Nonnull final ClassicHttpResponse response)
throws OpenAiClientException {
if (response.getCode() >= 300) {
buildExceptionAndThrow(response);
}
return parseResponse(response);
}

/**
* @param response The response to process
* @return A {@link OpenAiChatCompletionStream} of a model class instantiated from the response
* @author stippi
*/
private OpenAiChatCompletionStream<T> parseResponse(@Nonnull final ClassicHttpResponse response)
throws OpenAiClientException {

InputStream inputStream;
try {
inputStream = response.getEntity().getContent();
CharlesDuboisSAP marked this conversation as resolved.
Show resolved Hide resolved
} catch (IOException e) {
throw new OpenAiClientException("Failed to read response content.", e);
}
var output = new OpenAiChatCompletionStream<T>();
BufferedReader br =
new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8));

// TODO: set total
Stream<T> deltaStream =
br.lines()
.filter(
responseLine ->
!responseLine.isEmpty() && !"data: [DONE]".equals(responseLine.trim()))
.peek(
responseLine -> {
if (!responseLine.startsWith("data: ")) {
parseErrorAndThrow(responseLine, new OpenAiClientException());
}
})
.map(
responseLine -> {
String data = responseLine.substring(5).replace("delta", "message");
try {
return JACKSON.readValue(data, responseType);
} catch (IOException e) {
throw new RuntimeException(e);
}
});
return output.setDeltaStream(deltaStream);
}
}
Loading