generated from SAP/repository-template
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Orchestration streaming first version * Added unit tests * Formatting * Added documentation * Added tests * Release notes * Applied Alex's review comments --------- Co-authored-by: SAP Cloud SDK Bot <[email protected]>
- Loading branch information
1 parent
a174ba8
commit c326b7d
Showing
17 changed files
with
660 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
122 changes: 122 additions & 0 deletions
122
orchestration/src/main/java/com/sap/ai/sdk/orchestration/IterableStreamConverter.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
package com.sap.ai.sdk.orchestration; | ||
|
||
import static java.nio.charset.StandardCharsets.UTF_8; | ||
import static java.util.Spliterator.NONNULL; | ||
import static java.util.Spliterator.ORDERED; | ||
|
||
import io.vavr.control.Try; | ||
import java.io.BufferedReader; | ||
import java.io.IOException; | ||
import java.io.InputStream; | ||
import java.io.InputStreamReader; | ||
import java.util.Iterator; | ||
import java.util.NoSuchElementException; | ||
import java.util.Spliterators; | ||
import java.util.concurrent.Callable; | ||
import java.util.function.Function; | ||
import java.util.stream.Stream; | ||
import java.util.stream.StreamSupport; | ||
import javax.annotation.Nonnull; | ||
import javax.annotation.Nullable; | ||
import lombok.AccessLevel; | ||
import lombok.RequiredArgsConstructor; | ||
import lombok.extern.slf4j.Slf4j; | ||
import org.apache.hc.core5.http.HttpEntity; | ||
|
||
/** | ||
* Internal utility class to convert from a reading handler to {@link Iterable} and {@link Stream}. | ||
* | ||
* <p><strong>Note:</strong> All operations are sequential in nature. Thread safety is not | ||
* guaranteed. | ||
* | ||
* @param <T> Iterated item type. | ||
*/ | ||
@Slf4j | ||
@RequiredArgsConstructor(access = AccessLevel.PRIVATE) | ||
class IterableStreamConverter<T> implements Iterator<T> { | ||
/** see DEFAULT_CHAR_BUFFER_SIZE in {@link BufferedReader} * */ | ||
static final int BUFFER_SIZE = 8192; | ||
|
||
/** Read next entry for Stream or {@code null} when no further entry can be read. */ | ||
private final Callable<T> readHandler; | ||
|
||
/** Close handler to be called when Stream terminated. */ | ||
private final Runnable stopHandler; | ||
|
||
/** Error handler to be called when Stream is interrupted. */ | ||
private final Function<Exception, RuntimeException> errorHandler; | ||
|
||
private boolean isDone = false; | ||
private boolean isNextFetched = false; | ||
private T next = null; | ||
|
||
@SuppressWarnings("checkstyle:IllegalCatch") | ||
@Override | ||
public boolean hasNext() { | ||
if (isDone) { | ||
return false; | ||
} | ||
if (isNextFetched) { | ||
return true; | ||
} | ||
try { | ||
next = readHandler.call(); | ||
isNextFetched = true; | ||
if (next == null) { | ||
isDone = true; | ||
stopHandler.run(); | ||
} | ||
} catch (final Exception e) { | ||
isDone = true; | ||
stopHandler.run(); | ||
log.debug("Error while reading next element.", e); | ||
throw errorHandler.apply(e); | ||
} | ||
return !isDone; | ||
} | ||
|
||
@Override | ||
public T next() { | ||
if (next == null && !hasNext()) { | ||
throw new NoSuchElementException(); // normally not reached with Stream API | ||
} | ||
isNextFetched = false; | ||
return next; | ||
} | ||
|
||
/** | ||
* Create a sequential Stream of lines from an HTTP response string (UTF-8). The underlying {@link | ||
* InputStream} is closed, when the resulting Stream is closed (e.g. via try-with-resources) or | ||
* when an exception occurred. | ||
* | ||
* @param entity The HTTP entity object. | ||
* @return A sequential Stream object. | ||
* @throws OrchestrationClientException if the provided HTTP entity object is {@code null} or | ||
* empty. | ||
*/ | ||
@SuppressWarnings("PMD.CloseResource") // Stream is closed automatically when consumed | ||
@Nonnull | ||
static Stream<String> lines(@Nullable final HttpEntity entity) | ||
throws OrchestrationClientException { | ||
if (entity == null) { | ||
throw new OrchestrationClientException("Orchestration service response was empty."); | ||
} | ||
|
||
final InputStream inputStream; | ||
try { | ||
inputStream = entity.getContent(); | ||
} catch (final IOException e) { | ||
throw new OrchestrationClientException("Failed to read response content.", e); | ||
} | ||
|
||
final var reader = new BufferedReader(new InputStreamReader(inputStream, UTF_8), BUFFER_SIZE); | ||
final Runnable closeHandler = | ||
() -> Try.run(reader::close).onFailure(e -> log.error("Could not close input stream", e)); | ||
final Function<Exception, RuntimeException> errHandler = | ||
e -> new OrchestrationClientException("Parsing response content was interrupted.", e); | ||
|
||
final var iterator = new IterableStreamConverter<>(reader::readLine, closeHandler, errHandler); | ||
final var spliterator = Spliterators.spliteratorUnknownSize(iterator, ORDERED | NONNULL); | ||
return StreamSupport.stream(spliterator, /* NOT PARALLEL */ false).onClose(closeHandler); | ||
} | ||
} |
44 changes: 44 additions & 0 deletions
44
...stration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationChatCompletionDelta.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
package com.sap.ai.sdk.orchestration; | ||
|
||
import com.sap.ai.sdk.orchestration.model.CompletionPostResponse; | ||
import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous; | ||
import java.util.Map; | ||
import javax.annotation.Nonnull; | ||
import javax.annotation.Nullable; | ||
import lombok.val; | ||
|
||
/** Orchestration chat completion output delta for streaming. */ | ||
public class OrchestrationChatCompletionDelta extends CompletionPostResponse | ||
implements StreamedDelta { | ||
|
||
@Nonnull | ||
@Override | ||
// will be fixed once the generated code add a discriminator which will allow this class to extend | ||
// CompletionPostResponseStreaming | ||
@SuppressWarnings("unchecked") | ||
public String getDeltaContent() { | ||
val choices = ((LLMModuleResultSynchronous) getOrchestrationResult()).getChoices(); | ||
// Avoid the first delta: "choices":[] | ||
if (!choices.isEmpty() | ||
// Multiple choices are spread out on multiple deltas | ||
// A delta only contains one choice with a variable index | ||
&& choices.get(0).getIndex() == 0) { | ||
|
||
final var message = (Map<String, Object>) choices.get(0).getCustomField("delta"); | ||
// Avoid the second delta: "choices":[{"delta":{"content":"","role":"assistant"}}] | ||
if (message != null && message.get("content") != null) { | ||
return message.get("content").toString(); | ||
} | ||
} | ||
return ""; | ||
} | ||
|
||
@Nullable | ||
@Override | ||
public String getFinishReason() { | ||
return ((LLMModuleResultSynchronous) getOrchestrationResult()) | ||
.getChoices() | ||
.get(0) | ||
.getFinishReason(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
55 changes: 55 additions & 0 deletions
55
orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationStreamingHandler.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
package com.sap.ai.sdk.orchestration; | ||
|
||
import static com.sap.ai.sdk.orchestration.OrchestrationClient.JACKSON; | ||
import static com.sap.ai.sdk.orchestration.OrchestrationResponseHandler.buildExceptionAndThrow; | ||
import static com.sap.ai.sdk.orchestration.OrchestrationResponseHandler.parseErrorAndThrow; | ||
|
||
import java.io.IOException; | ||
import java.util.stream.Stream; | ||
import javax.annotation.Nonnull; | ||
import lombok.RequiredArgsConstructor; | ||
import lombok.extern.slf4j.Slf4j; | ||
import org.apache.hc.core5.http.ClassicHttpResponse; | ||
|
||
@Slf4j | ||
@RequiredArgsConstructor | ||
class OrchestrationStreamingHandler<D extends StreamedDelta> { | ||
|
||
@Nonnull private final Class<D> deltaType; | ||
|
||
/** | ||
* @param response The response to process | ||
* @return A {@link Stream} of a model class instantiated from the response | ||
*/ | ||
@SuppressWarnings("PMD.CloseResource") // Stream is closed automatically when consumed | ||
@Nonnull | ||
Stream<D> handleResponse(@Nonnull final ClassicHttpResponse response) | ||
throws OrchestrationClientException { | ||
if (response.getCode() >= 300) { | ||
buildExceptionAndThrow(response); | ||
} | ||
return IterableStreamConverter.lines(response.getEntity()) | ||
// half of the lines are empty newlines, the last line is "data: [DONE]" | ||
.peek(line -> log.info("Handler: {}", line)) | ||
.filter(line -> !line.isEmpty() && !"data: [DONE]".equals(line.trim())) | ||
.peek( | ||
line -> { | ||
if (!line.startsWith("data: ")) { | ||
final String msg = "Failed to parse response from the Orchestration service"; | ||
parseErrorAndThrow(line, new OrchestrationClientException(msg)); | ||
} | ||
}) | ||
.map( | ||
line -> { | ||
final String data = line.substring(5); // remove "data: " | ||
try { | ||
return JACKSON.readValue(data, deltaType); | ||
} catch (final IOException e) { // exception message e gets lost | ||
log.error( | ||
"Failed to parse the following response from the Orchestration service: {}", | ||
line); | ||
throw new OrchestrationClientException("Failed to parse delta message: " + line, e); | ||
} | ||
}); | ||
} | ||
} |
Oops, something went wrong.