-
Notifications
You must be signed in to change notification settings - Fork 866
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Modular RAG: Orchestration and Post-Retrieval #1767
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,71 +16,81 @@ | |
|
||
package org.springframework.ai.chat.client.advisor; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.function.Predicate; | ||
|
||
import reactor.core.publisher.Flux; | ||
import reactor.core.publisher.Mono; | ||
import reactor.core.scheduler.Schedulers; | ||
import java.util.concurrent.CompletableFuture; | ||
import java.util.stream.Collectors; | ||
|
||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; | ||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; | ||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; | ||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; | ||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; | ||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; | ||
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor; | ||
import org.springframework.ai.chat.model.ChatResponse; | ||
import org.springframework.ai.chat.prompt.PromptTemplate; | ||
import org.springframework.ai.document.Document; | ||
import org.springframework.ai.rag.Query; | ||
import org.springframework.ai.rag.analysis.query.transformation.QueryTransformer; | ||
import org.springframework.ai.rag.augmentation.ContextualQueryAugmentor; | ||
import org.springframework.ai.rag.augmentation.QueryAugmentor; | ||
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter; | ||
import org.springframework.ai.rag.generation.augmentation.QueryAugmenter; | ||
import org.springframework.ai.rag.orchestration.routing.AllRetrieversQueryRouter; | ||
import org.springframework.ai.rag.orchestration.routing.QueryRouter; | ||
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander; | ||
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; | ||
import org.springframework.ai.rag.retrieval.join.ConcatenationDocumentJoiner; | ||
import org.springframework.ai.rag.retrieval.join.DocumentJoiner; | ||
import org.springframework.ai.rag.retrieval.search.DocumentRetriever; | ||
import org.springframework.core.task.TaskExecutor; | ||
import org.springframework.core.task.support.ContextPropagatingTaskDecorator; | ||
import org.springframework.lang.Nullable; | ||
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; | ||
import org.springframework.util.Assert; | ||
import org.springframework.util.StringUtils; | ||
import reactor.core.scheduler.Scheduler; | ||
|
||
/** | ||
* Advisor that implements common Retrieval Augmented Generation (RAG) flows using the | ||
* building blocks defined in the {@link org.springframework.ai.rag} package and following | ||
* the Modular RAG Architecture. | ||
* <p> | ||
* It's the successor of the {@link QuestionAnswerAdvisor}. | ||
* | ||
* @author Christian Tzolov | ||
* @author Thomas Vitale | ||
* @since 1.0.0 | ||
* @see <a href="http://export.arxiv.org/abs/2407.21059">arXiv:2407.21059</a> | ||
* @see <a href="https://export.arxiv.org/abs/2312.10997">arXiv:2312.10997</a> | ||
*/ | ||
public final class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { | ||
public final class RetrievalAugmentationAdvisor implements BaseAdvisor { | ||
|
||
public static final String DOCUMENT_CONTEXT = "rag_document_context"; | ||
|
||
private final List<QueryTransformer> queryTransformers; | ||
|
||
private final DocumentRetriever documentRetriever; | ||
@Nullable | ||
private final QueryExpander queryExpander; | ||
|
||
private final QueryRouter queryRouter; | ||
|
||
private final DocumentJoiner documentJoiner; | ||
|
||
private final QueryAugmenter queryAugmenter; | ||
|
||
private final QueryAugmentor queryAugmentor; | ||
private final TaskExecutor taskExecutor; | ||
|
||
private final boolean protectFromBlocking; | ||
private final Scheduler scheduler; | ||
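There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. — Review comment: Instead of passing a `boolean` flag (`protectFromBlocking`), a `Scheduler` is now passed. The case where we don't need the protection from blocking is supported when passing a scheduler that does not offload the work [the exact identifier was lost in extraction; presumably `Schedulers.immediate()`]. |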
||
|
||
private final int order; | ||
|
||
public RetrievalAugmentationAdvisor(List<QueryTransformer> queryTransformers, DocumentRetriever documentRetriever, | ||
@Nullable QueryAugmentor queryAugmentor, @Nullable Boolean protectFromBlocking, @Nullable Integer order) { | ||
Assert.notNull(queryTransformers, "queryTransformers cannot be null"); | ||
public RetrievalAugmentationAdvisor(@Nullable List<QueryTransformer> queryTransformers, | ||
@Nullable QueryExpander queryExpander, QueryRouter queryRouter, @Nullable DocumentJoiner documentJoiner, | ||
@Nullable QueryAugmenter queryAugmenter, @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, | ||
@Nullable Integer order) { | ||
Assert.notNull(queryRouter, "queryRouter cannot be null"); | ||
Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements"); | ||
Assert.notNull(documentRetriever, "documentRetriever cannot be null"); | ||
this.queryTransformers = queryTransformers; | ||
this.documentRetriever = documentRetriever; | ||
this.queryAugmentor = queryAugmentor != null ? queryAugmentor : ContextualQueryAugmentor.builder().build(); | ||
this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : true; | ||
this.queryTransformers = queryTransformers != null ? queryTransformers : List.of(); | ||
this.queryExpander = queryExpander; | ||
this.queryRouter = queryRouter; | ||
this.documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner(); | ||
this.queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter.builder().build(); | ||
this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor(); | ||
this.scheduler = scheduler != null ? scheduler : BaseAdvisor.DEFAULT_SCHEDULER; | ||
this.order = order != null ? order : 0; | ||
} | ||
|
||
|
@@ -89,41 +99,7 @@ public static Builder builder() { | |
} | ||
|
||
@Override | ||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { | ||
Assert.notNull(advisedRequest, "advisedRequest cannot be null"); | ||
Assert.notNull(chain, "chain cannot be null"); | ||
|
||
AdvisedRequest processedAdvisedRequest = before(advisedRequest); | ||
AdvisedResponse advisedResponse = chain.nextAroundCall(processedAdvisedRequest); | ||
return after(advisedResponse); | ||
} | ||
|
||
@Override | ||
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { | ||
Assert.notNull(advisedRequest, "advisedRequest cannot be null"); | ||
Assert.notNull(chain, "chain cannot be null"); | ||
|
||
// This can be executed by both blocking and non-blocking Threads | ||
// E.g. a command line or Tomcat blocking Thread implementation | ||
// or by a WebFlux dispatch in a non-blocking manner. | ||
Flux<AdvisedResponse> advisedResponses = (this.protectFromBlocking) ? | ||
// @formatter:off | ||
Mono.just(advisedRequest) | ||
.publishOn(Schedulers.boundedElastic()) | ||
.map(this::before) | ||
.flatMapMany(chain::nextAroundStream) | ||
: chain.nextAroundStream(before(advisedRequest)); | ||
// @formatter:on | ||
|
||
return advisedResponses.map(ar -> { | ||
if (onFinishReason().test(ar)) { | ||
ar = after(ar); | ||
} | ||
return ar; | ||
}); | ||
} | ||
|
||
private AdvisedRequest before(AdvisedRequest request) { | ||
public AdvisedRequest before(AdvisedRequest request) { | ||
Map<String, Object> context = new HashMap<>(request.adviseContext()); | ||
|
||
// 0. Create a query from the user text and parameters. | ||
|
@@ -135,17 +111,47 @@ private AdvisedRequest before(AdvisedRequest request) { | |
transformedQuery = queryTransformer.apply(transformedQuery); | ||
} | ||
|
||
// 2. Retrieve similar documents for the original query. | ||
List<Document> documents = this.documentRetriever.retrieve(transformedQuery); | ||
// 2. Expand query into one or multiple queries. | ||
List<Query> expandedQueries = queryExpander != null ? queryExpander.expand(transformedQuery) | ||
: List.of(transformedQuery); | ||
|
||
// 3. Get similar documents for each query. | ||
Map<Query, List<List<Document>>> documentsForQuery = expandedQueries.stream() | ||
.map(query -> CompletableFuture.supplyAsync(() -> getDocumentsForQuery(query), taskExecutor)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. — Review comment 1: Query expansion and/or routing can result in parallel executions, which are supported by a `TaskExecutor` [the remainder of this comment was lost in extraction]. — There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. — Review comment 2 (reply): I think this is a good starting point that is generally useful. Users can implement their own RAG Advisor to customize how they do scatter-gather based on other frameworks and patterns. |
||
.toList() | ||
.stream() | ||
.map(CompletableFuture::join) | ||
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); | ||
|
||
// 4. Combine documents retrieved based on multiple queries and from multiple data | ||
// sources. | ||
List<Document> documents = documentJoiner.join(documentsForQuery); | ||
context.put(DOCUMENT_CONTEXT, documents); | ||
|
||
// 3. Augment user query with the document contextual data. | ||
Query augmentedQuery = this.queryAugmentor.augment(transformedQuery, documents); | ||
// 5. Augment user query with the document contextual data. | ||
Query augmentedQuery = queryAugmenter.augment(originalQuery, documents); | ||
|
||
// 6. Update advised request with augmented prompt. | ||
return AdvisedRequest.from(request).withUserText(augmentedQuery.text()).withAdviseContext(context).build(); | ||
} | ||
|
||
private AdvisedResponse after(AdvisedResponse advisedResponse) { | ||
/** | ||
* Processes a single query by routing it to document retrievers and collecting | ||
* documents. | ||
*/ | ||
private Map.Entry<Query, List<List<Document>>> getDocumentsForQuery(Query query) { | ||
List<DocumentRetriever> retrievers = queryRouter.route(query); | ||
List<List<Document>> documents = retrievers.stream() | ||
.map(retriever -> CompletableFuture.supplyAsync(() -> retriever.retrieve(query), taskExecutor)) | ||
.toList() | ||
.stream() | ||
.map(CompletableFuture::join) | ||
.toList(); | ||
return Map.entry(query, documents); | ||
} | ||
|
||
@Override | ||
public AdvisedResponse after(AdvisedResponse advisedResponse) { | ||
ChatResponse.Builder chatResponseBuilder; | ||
if (advisedResponse.response() == null) { | ||
chatResponseBuilder = ChatResponse.builder(); | ||
|
@@ -157,66 +163,91 @@ private AdvisedResponse after(AdvisedResponse advisedResponse) { | |
return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext()); | ||
} | ||
|
||
private Predicate<AdvisedResponse> onFinishReason() { | ||
return advisedResponse -> { | ||
ChatResponse chatResponse = advisedResponse.response(); | ||
return chatResponse != null && chatResponse.getResults() != null | ||
&& chatResponse.getResults() | ||
.stream() | ||
.anyMatch(result -> result != null && result.getMetadata() != null | ||
&& StringUtils.hasText(result.getMetadata().getFinishReason())); | ||
}; | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return this.getClass().getSimpleName(); | ||
public Scheduler getScheduler() { | ||
return scheduler; | ||
} | ||
|
||
@Override | ||
public int getOrder() { | ||
return this.order; | ||
} | ||
|
||
private static TaskExecutor buildDefaultTaskExecutor() { | ||
ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor(); | ||
taskExecutor.setThreadNamePrefix("ai-advisor-"); | ||
taskExecutor.setCorePoolSize(4); | ||
taskExecutor.setMaxPoolSize(16); | ||
taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator()); | ||
taskExecutor.initialize(); | ||
return taskExecutor; | ||
} | ||
|
||
public static final class Builder { | ||
|
||
private final List<QueryTransformer> queryTransformers = new ArrayList<>(); | ||
private List<QueryTransformer> queryTransformers; | ||
|
||
private QueryExpander queryExpander; | ||
|
||
private QueryRouter queryRouter; | ||
|
||
private DocumentJoiner documentJoiner; | ||
|
||
private DocumentRetriever documentRetriever; | ||
private QueryAugmenter queryAugmenter; | ||
|
||
private QueryAugmentor queryAugmentor; | ||
private TaskExecutor taskExecutor; | ||
|
||
private Boolean protectFromBlocking; | ||
private Scheduler scheduler; | ||
|
||
private Integer order; | ||
|
||
private Builder() { | ||
} | ||
|
||
public Builder queryTransformers(List<QueryTransformer> queryTransformers) { | ||
Assert.notNull(queryTransformers, "queryTransformers cannot be null"); | ||
this.queryTransformers.addAll(queryTransformers); | ||
this.queryTransformers = queryTransformers; | ||
return this; | ||
} | ||
|
||
public Builder queryTransformers(QueryTransformer... queryTransformers) { | ||
Assert.notNull(queryTransformers, "queryTransformers cannot be null"); | ||
this.queryTransformers.addAll(Arrays.asList(queryTransformers)); | ||
this.queryTransformers = Arrays.asList(queryTransformers); | ||
return this; | ||
} | ||
|
||
public Builder queryExpander(QueryExpander queryExpander) { | ||
this.queryExpander = queryExpander; | ||
return this; | ||
} | ||
|
||
public Builder queryRouter(QueryRouter queryRouter) { | ||
Assert.isNull(this.queryRouter, "Cannot set both documentRetriever and queryRouter"); | ||
this.queryRouter = queryRouter; | ||
return this; | ||
} | ||
|
||
public Builder documentRetriever(DocumentRetriever documentRetriever) { | ||
this.documentRetriever = documentRetriever; | ||
Assert.isNull(this.queryRouter, "Cannot set both documentRetriever and queryRouter"); | ||
this.queryRouter = AllRetrieversQueryRouter.builder().documentRetrievers(documentRetriever).build(); | ||
return this; | ||
} | ||
|
||
public Builder documentJoiner(DocumentJoiner documentJoiner) { | ||
this.documentJoiner = documentJoiner; | ||
return this; | ||
} | ||
|
||
public Builder queryAugmenter(QueryAugmenter queryAugmenter) { | ||
this.queryAugmenter = queryAugmenter; | ||
return this; | ||
} | ||
|
||
public Builder queryAugmentor(QueryAugmentor queryAugmentor) { | ||
this.queryAugmentor = queryAugmentor; | ||
public Builder taskExecutor(TaskExecutor taskExecutor) { | ||
this.taskExecutor = taskExecutor; | ||
return this; | ||
} | ||
|
||
public Builder protectFromBlocking(Boolean protectFromBlocking) { | ||
this.protectFromBlocking = protectFromBlocking; | ||
public Builder scheduler(Scheduler scheduler) { | ||
this.scheduler = scheduler; | ||
return this; | ||
} | ||
|
||
|
@@ -226,8 +257,8 @@ public Builder order(Integer order) { | |
} | ||
|
||
public RetrievalAugmentationAdvisor build() { | ||
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.documentRetriever, this.queryAugmentor, | ||
this.protectFromBlocking, this.order); | ||
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander, this.queryRouter, | ||
this.documentJoiner, this.queryAugmenter, this.taskExecutor, this.scheduler, this.order); | ||
} | ||
|
||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When using the `TaskExecutor` to run things asynchronously, no observation data is propagated unless we add the Micrometer Context Propagation dependency (https://docs.micrometer.io/context-propagation/reference/index.html).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍