Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modular RAG: Orchestration and Post-Retrieval #1767

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion spring-ai-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@
<artifactId>micrometer-core</artifactId>
</dependency>

<dependency>
<groupId>io.micrometer</groupId>
<artifactId>context-propagation</artifactId>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When using the TaskExecutor to run things asynchronously, no observation data is propagated unless we add this dependency (https://docs.micrometer.io/context-propagation/reference/index.html)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

</dependency>

<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-tracing-bridge-otel</artifactId>
Expand Down Expand Up @@ -195,4 +200,4 @@
</profiles>


</project>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -16,71 +16,81 @@

package org.springframework.ai.chat.client.advisor;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.analysis.query.transformation.QueryTransformer;
import org.springframework.ai.rag.augmentation.ContextualQueryAugmentor;
import org.springframework.ai.rag.augmentation.QueryAugmentor;
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
import org.springframework.ai.rag.generation.augmentation.QueryAugmenter;
import org.springframework.ai.rag.orchestration.routing.AllRetrieversQueryRouter;
import org.springframework.ai.rag.orchestration.routing.QueryRouter;
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.retrieval.join.ConcatenationDocumentJoiner;
import org.springframework.ai.rag.retrieval.join.DocumentJoiner;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.core.task.TaskExecutor;
import org.springframework.core.task.support.ContextPropagatingTaskDecorator;
import org.springframework.lang.Nullable;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import reactor.core.scheduler.Scheduler;

/**
* Advisor that implements common Retrieval Augmented Generation (RAG) flows using the
* building blocks defined in the {@link org.springframework.ai.rag} package and following
* the Modular RAG Architecture.
* <p>
* It's the successor of the {@link QuestionAnswerAdvisor}.
*
* @author Christian Tzolov
* @author Thomas Vitale
* @since 1.0.0
* @see <a href="http://export.arxiv.org/abs/2407.21059">arXiv:2407.21059</a>
* @see <a href="https://export.arxiv.org/abs/2312.10997">arXiv:2312.10997</a>
*/
public final class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
public final class RetrievalAugmentationAdvisor implements BaseAdvisor {

public static final String DOCUMENT_CONTEXT = "rag_document_context";

private final List<QueryTransformer> queryTransformers;

private final DocumentRetriever documentRetriever;
@Nullable
private final QueryExpander queryExpander;

private final QueryRouter queryRouter;

private final DocumentJoiner documentJoiner;

private final QueryAugmenter queryAugmenter;

private final QueryAugmentor queryAugmentor;
private final TaskExecutor taskExecutor;

private final boolean protectFromBlocking;
private final Scheduler scheduler;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of passing a protectFromBlocking parameter, users can now pass a Scheduler implementation if they want to customise the streaming behavior. This has two main benefits: more clarity about what is happening under the hood and also the possibility of using different schedulers than the elastic one.

The case where we don't need the protection from blocking is supported when passing Schedulers.immediate(). The default, as we discussed last time, is running with Schedulers.boundedElastic().


private final int order;

public RetrievalAugmentationAdvisor(List<QueryTransformer> queryTransformers, DocumentRetriever documentRetriever,
@Nullable QueryAugmentor queryAugmentor, @Nullable Boolean protectFromBlocking, @Nullable Integer order) {
Assert.notNull(queryTransformers, "queryTransformers cannot be null");
public RetrievalAugmentationAdvisor(@Nullable List<QueryTransformer> queryTransformers,
@Nullable QueryExpander queryExpander, QueryRouter queryRouter, @Nullable DocumentJoiner documentJoiner,
@Nullable QueryAugmenter queryAugmenter, @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler,
@Nullable Integer order) {
Assert.notNull(queryRouter, "queryRouter cannot be null");
Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
Assert.notNull(documentRetriever, "documentRetriever cannot be null");
this.queryTransformers = queryTransformers;
this.documentRetriever = documentRetriever;
this.queryAugmentor = queryAugmentor != null ? queryAugmentor : ContextualQueryAugmentor.builder().build();
this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : true;
this.queryTransformers = queryTransformers != null ? queryTransformers : List.of();
this.queryExpander = queryExpander;
this.queryRouter = queryRouter;
this.documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner();
this.queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter.builder().build();
this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor();
this.scheduler = scheduler != null ? scheduler : BaseAdvisor.DEFAULT_SCHEDULER;
this.order = order != null ? order : 0;
}

Expand All @@ -89,41 +99,7 @@ public static Builder builder() {
}

@Override
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
Assert.notNull(advisedRequest, "advisedRequest cannot be null");
Assert.notNull(chain, "chain cannot be null");

AdvisedRequest processedAdvisedRequest = before(advisedRequest);
AdvisedResponse advisedResponse = chain.nextAroundCall(processedAdvisedRequest);
return after(advisedResponse);
}

@Override
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
Assert.notNull(advisedRequest, "advisedRequest cannot be null");
Assert.notNull(chain, "chain cannot be null");

// This can be executed by both blocking and non-blocking Threads
// E.g. a command line or Tomcat blocking Thread implementation
// or by a WebFlux dispatch in a non-blocking manner.
Flux<AdvisedResponse> advisedResponses = (this.protectFromBlocking) ?
// @formatter:off
Mono.just(advisedRequest)
.publishOn(Schedulers.boundedElastic())
.map(this::before)
.flatMapMany(chain::nextAroundStream)
: chain.nextAroundStream(before(advisedRequest));
// @formatter:on

return advisedResponses.map(ar -> {
if (onFinishReason().test(ar)) {
ar = after(ar);
}
return ar;
});
}

private AdvisedRequest before(AdvisedRequest request) {
public AdvisedRequest before(AdvisedRequest request) {
Map<String, Object> context = new HashMap<>(request.adviseContext());

// 0. Create a query from the user text and parameters.
Expand All @@ -135,17 +111,47 @@ private AdvisedRequest before(AdvisedRequest request) {
transformedQuery = queryTransformer.apply(transformedQuery);
}

// 2. Retrieve similar documents for the original query.
List<Document> documents = this.documentRetriever.retrieve(transformedQuery);
// 2. Expand query into one or multiple queries.
List<Query> expandedQueries = queryExpander != null ? queryExpander.expand(transformedQuery)
: List.of(transformedQuery);

// 3. Get similar documents for each query.
Map<Query, List<List<Document>>> documentsForQuery = expandedQueries.stream()
.map(query -> CompletableFuture.supplyAsync(() -> getDocumentsForQuery(query), taskExecutor))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Query expansion and/or routing can result in parallel executions, which are supported by a TaskExecutor, following the common practice in Spring. By default, a ThreadPoolTaskExecutor is used, but it's possible to customise it. For example, you can pass the auto-configured TaskExecutor from Spring Boot or another implementation, such as VirtualThreadTaskExecutor. I'll describe this in detail in the separate PR where I'm writing the documentation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a good starting point that is generally useful. Users can implement their own RAG Advisor to customize how they do scatter-gather based on other frameworks and patterns.

.toList()
.stream()
.map(CompletableFuture::join)
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

// 4. Combine documents retrieved based on multiple queries and from multiple data
// sources.
List<Document> documents = documentJoiner.join(documentsForQuery);
context.put(DOCUMENT_CONTEXT, documents);

// 3. Augment user query with the document contextual data.
Query augmentedQuery = this.queryAugmentor.augment(transformedQuery, documents);
// 5. Augment user query with the document contextual data.
Query augmentedQuery = queryAugmenter.augment(originalQuery, documents);

// 6. Update advised request with augmented prompt.
return AdvisedRequest.from(request).withUserText(augmentedQuery.text()).withAdviseContext(context).build();
}

private AdvisedResponse after(AdvisedResponse advisedResponse) {
/**
 * Processes a single query by routing it to document retrievers and collecting the
 * documents returned by each of them.
 * <p>
 * Retrievers are invoked one after another on the calling thread. {@code before}
 * already runs this method as a task on {@code taskExecutor}; submitting a second
 * layer of tasks to that same executor and blocking on them with {@code join} can
 * leave every worker of a bounded pool waiting on a child task that is still queued,
 * which hangs the request. Parallelism across expanded queries is unaffected.
 * @param query the query to route and to retrieve documents for
 * @return an entry mapping the query to one document list per routed retriever, in
 * the order in which the router returned the retrievers
 */
private Map.Entry<Query, List<List<Document>>> getDocumentsForQuery(Query query) {
	List<DocumentRetriever> retrievers = queryRouter.route(query);
	// Retrieve directly instead of scheduling nested tasks on the shared executor.
	List<List<Document>> documents = retrievers.stream()
		.map(retriever -> retriever.retrieve(query))
		.toList();
	return Map.entry(query, documents);
}

@Override
public AdvisedResponse after(AdvisedResponse advisedResponse) {
ChatResponse.Builder chatResponseBuilder;
if (advisedResponse.response() == null) {
chatResponseBuilder = ChatResponse.builder();
Expand All @@ -157,66 +163,91 @@ private AdvisedResponse after(AdvisedResponse advisedResponse) {
return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext());
}

private Predicate<AdvisedResponse> onFinishReason() {
return advisedResponse -> {
ChatResponse chatResponse = advisedResponse.response();
return chatResponse != null && chatResponse.getResults() != null
&& chatResponse.getResults()
.stream()
.anyMatch(result -> result != null && result.getMetadata() != null
&& StringUtils.hasText(result.getMetadata().getFinishReason()));
};
}

@Override
public String getName() {
return this.getClass().getSimpleName();
public Scheduler getScheduler() {
return scheduler;
}

/**
 * Returns the order value assigned to this advisor at construction time (0 when no
 * explicit order was supplied).
 */
@Override
public int getOrder() {
	return order;
}

/**
 * Creates and initializes the {@link ThreadPoolTaskExecutor} that is used when no
 * {@link TaskExecutor} is passed to the constructor.
 * <p>
 * The pool is configured with 4 core threads, at most 16 threads, the thread name
 * prefix {@code ai-advisor-} and a {@link ContextPropagatingTaskDecorator}.
 * @return the initialized default executor
 */
private static TaskExecutor buildDefaultTaskExecutor() {
	ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
	executor.setTaskDecorator(new ContextPropagatingTaskDecorator());
	executor.setCorePoolSize(4);
	executor.setMaxPoolSize(16);
	executor.setThreadNamePrefix("ai-advisor-");
	// All settings must be in place before the underlying pool is created.
	executor.initialize();
	return executor;
}

public static final class Builder {

private final List<QueryTransformer> queryTransformers = new ArrayList<>();
private List<QueryTransformer> queryTransformers;

private QueryExpander queryExpander;

private QueryRouter queryRouter;

private DocumentJoiner documentJoiner;

private DocumentRetriever documentRetriever;
private QueryAugmenter queryAugmenter;

private QueryAugmentor queryAugmentor;
private TaskExecutor taskExecutor;

private Boolean protectFromBlocking;
private Scheduler scheduler;

private Integer order;

private Builder() {
}

public Builder queryTransformers(List<QueryTransformer> queryTransformers) {
Assert.notNull(queryTransformers, "queryTransformers cannot be null");
this.queryTransformers.addAll(queryTransformers);
this.queryTransformers = queryTransformers;
return this;
}

public Builder queryTransformers(QueryTransformer... queryTransformers) {
Assert.notNull(queryTransformers, "queryTransformers cannot be null");
this.queryTransformers.addAll(Arrays.asList(queryTransformers));
this.queryTransformers = Arrays.asList(queryTransformers);
return this;
}

public Builder queryExpander(QueryExpander queryExpander) {
this.queryExpander = queryExpander;
return this;
}

public Builder queryRouter(QueryRouter queryRouter) {
Assert.isNull(this.queryRouter, "Cannot set both documentRetriever and queryRouter");
this.queryRouter = queryRouter;
return this;
}

public Builder documentRetriever(DocumentRetriever documentRetriever) {
this.documentRetriever = documentRetriever;
Assert.isNull(this.queryRouter, "Cannot set both documentRetriever and queryRouter");
this.queryRouter = AllRetrieversQueryRouter.builder().documentRetrievers(documentRetriever).build();
return this;
}

public Builder documentJoiner(DocumentJoiner documentJoiner) {
this.documentJoiner = documentJoiner;
return this;
}

public Builder queryAugmenter(QueryAugmenter queryAugmenter) {
this.queryAugmenter = queryAugmenter;
return this;
}

public Builder queryAugmentor(QueryAugmentor queryAugmentor) {
this.queryAugmentor = queryAugmentor;
public Builder taskExecutor(TaskExecutor taskExecutor) {
this.taskExecutor = taskExecutor;
return this;
}

public Builder protectFromBlocking(Boolean protectFromBlocking) {
this.protectFromBlocking = protectFromBlocking;
public Builder scheduler(Scheduler scheduler) {
this.scheduler = scheduler;
return this;
}

Expand All @@ -226,8 +257,8 @@ public Builder order(Integer order) {
}

public RetrievalAugmentationAdvisor build() {
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.documentRetriever, this.queryAugmentor,
this.protectFromBlocking, this.order);
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander, this.queryRouter,
this.documentJoiner, this.queryAugmenter, this.taskExecutor, this.scheduler, this.order);
}

}
Expand Down
Loading