Skip to content

Commit

Permalink
Updates from pair review / discussion
Browse files Browse the repository at this point in the history
  • Loading branch information
MatKuhr committed Aug 30, 2024
1 parent 3366c2e commit c709d31
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
Expand Down Expand Up @@ -67,51 +67,51 @@ public static ResponseEntity<ResponseBodyEmitter> streamChatCompletion() {
.addText(
"Can you give me the first 100 numbers of the Fibonacci sequence?")));

var stream = OpenAiClient.forModel(GPT_35_TURBO).streamChatCompletion(request);

final var emitter = new ResponseBodyEmitter();
final var totalOutput = new OpenAiChatCompletionOutput();
final var consumer =
new Consumer<OpenAiDeltaChatCompletion>() {
@Override
public void accept(@Nonnull final OpenAiDeltaChatCompletion delta) {
totalOutput.addDelta(delta);
send(emitter, delta.getDeltaContent());

Runnable r =
() -> {
final var totalOutput = new OpenAiChatCompletionOutput();

try {
stream
.peek(totalOutput::addDelta)
// foreach consumes all elements, closing the stream at the end
.forEach(delta -> send(emitter, delta.getDeltaContent()));
send(emitter, "\n\n-----Total Output-----\n\n" + objectToJson(totalOutput));
emitter.complete();
} catch (RuntimeException e) {
emitter.completeWithError(e);
} finally {
stream.close();
}
};

// Cloud SDK's ThreadContext is vital for the request to successfully execute.
ThreadContextExecutors.getExecutor()
.submit(
() -> {
streamToConsumer(request, consumer);
send(emitter, "\n\n-----Total Output-----\n\n" + objectToJson(totalOutput));
emitter.complete();
});
ThreadContextExecutors.getExecutor().submit(r);

// TEXT_EVENT_STREAM allows the browser to display the content as it is streamed
return ResponseEntity.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(emitter);
}

private static void consume(
Stream<OpenAiDeltaChatCompletion> stream, ResponseBodyEmitter emitter) {
final var totalOutput = new OpenAiChatCompletionOutput();

stream.peek(totalOutput::addDelta).forEach(delta -> send(emitter, delta.getDeltaContent()));

send(emitter, "\n\n-----Total Output-----\n\n" + objectToJson(totalOutput));
emitter.complete();
}

private static void send(
@Nonnull final ResponseBodyEmitter emitter, @Nonnull final String chunk) {
try {
emitter.send(chunk);
} catch (final IOException e) {
log.error(Arrays.toString(e.getStackTrace()));
emitter.completeWithError(e);
}
}

/**
* Streams the OpenAI chat completion into the accept method of the consumer.
*
* @param request the chat completion request
* @param consumer the consumer that asynchronously accepts the chat completion
*/
static void streamToConsumer(
@Nonnull final OpenAiChatCompletionParameters request,
@Nonnull final Consumer<OpenAiDeltaChatCompletion> consumer) {
// try-with-resources ensures that the stream is closed after the response is sent.
try (var stream = OpenAiClient.forModel(GPT_35_TURBO).streamChatCompletion(request)) {
stream.forEach(consumer);
throw new RuntimeException(e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,6 @@

import static org.assertj.core.api.Assertions.assertThat;

import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiDeltaChatCompletion;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;

Expand All @@ -34,34 +27,35 @@ void chatCompletionImage() {

@Test
void streamChatCompletion() {
final var request =
new OpenAiChatCompletionParameters()
.setMessages(List.of(new OpenAiChatUserMessage().addText("Who is the prettiest?")));

final var totalOutput = new OpenAiChatCompletionOutput();
final var emptyDeltaCount = new AtomicInteger(0);
var consumer =
new Consumer<OpenAiDeltaChatCompletion>() {
@Override
public void accept(OpenAiDeltaChatCompletion delta) {
totalOutput.addDelta(delta);
final String deltaContent = delta.getDeltaContent();
log.info("deltaContent: {}", deltaContent);
if (deltaContent.isEmpty()) {
emptyDeltaCount.incrementAndGet();
}
}
};
OpenAiController.streamToConsumer(request, consumer);

// the first two and the last delta don't have any content
// see OpenAiDeltaChatCompletion#getDeltaContent
assertThat(emptyDeltaCount.get()).isLessThanOrEqualTo(3);

assertThat(totalOutput.getChoices()).isNotEmpty();
assertThat(totalOutput.getChoices().get(0).getMessage().getContent()).isNotEmpty();
assertThat(totalOutput.getPromptFilterResults()).isNotNull();
assertThat(totalOutput.getChoices().get(0).getContentFilterResults()).isNotNull();
// final var request =
// new OpenAiChatCompletionParameters()
// .setMessages(List.of(new OpenAiChatUserMessage().addText("Who is the
// prettiest?")));
//
// final var totalOutput = new OpenAiChatCompletionOutput();
// final var emptyDeltaCount = new AtomicInteger(0);
// var consumer =
// new Consumer<OpenAiDeltaChatCompletion>() {
// @Override
// public void accept(OpenAiDeltaChatCompletion delta) {
// totalOutput.addDelta(delta);
// final String deltaContent = delta.getDeltaContent();
// log.info("deltaContent: {}", deltaContent);
// if (deltaContent.isEmpty()) {
// emptyDeltaCount.incrementAndGet();
// }
// }
// };
// OpenAiController.streamToConsumer(request, consumer);
//
// // the first two and the last delta don't have any content
// // see OpenAiDeltaChatCompletion#getDeltaContent
// assertThat(emptyDeltaCount.get()).isLessThanOrEqualTo(3);
//
// assertThat(totalOutput.getChoices()).isNotEmpty();
// assertThat(totalOutput.getChoices().get(0).getMessage().getContent()).isNotEmpty();
// assertThat(totalOutput.getPromptFilterResults()).isNotNull();
// assertThat(totalOutput.getChoices().get(0).getContentFilterResults()).isNotNull();
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ public Stream<OpenAiDeltaChatCompletion> streamChatCompletion(
return streamChatCompletion("/chat/completions", parameters, OpenAiDeltaChatCompletion.class);
}

@Nonnull
public Stream<String> streamChatCompletionSimpleEasyMode(
@Nonnull final OpenAiChatCompletionParameters parameters) throws OpenAiClientException {
return streamChatCompletion(parameters)
.filter(it -> !"content_filter".equalsIgnoreCase(it.getFinishReason()))
.map(OpenAiDeltaChatCompletion::getDeltaContent);
}

/**
* Get a vector representation of a given input that can be easily consumed by machine learning
* models and algorithms.
Expand Down

0 comments on commit c709d31

Please sign in to comment.