Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modular RAG: Query Augmentor #1644

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
Expand All @@ -32,12 +31,12 @@
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.model.Content;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.augmentation.ContextualQueryAugmentor;
import org.springframework.ai.rag.augmentation.QueryAugmentor;
import org.springframework.ai.rag.retrieval.source.DocumentRetriever;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
Expand All @@ -60,33 +59,19 @@ public class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAr

public static final String DOCUMENT_CONTEXT = "rag_document_context";

public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
{query}

Context information is below. Use this information to answer the user query.

---------------------
{context}
---------------------

Given the context and provided history information and not prior knowledge,
reply to the user query. If the answer is not in the context, inform
the user that you can't answer the query.
""");

private final DocumentRetriever documentRetriever;

private final PromptTemplate promptTemplate;
private final QueryAugmentor queryAugmentor;

private final boolean protectFromBlocking;

private final int order;

public RetrievalAugmentationAdvisor(DocumentRetriever documentRetriever, @Nullable PromptTemplate promptTemplate,
public RetrievalAugmentationAdvisor(DocumentRetriever documentRetriever, @Nullable QueryAugmentor queryAugmentor,
@Nullable Boolean protectFromBlocking, @Nullable Integer order) {
Assert.notNull(documentRetriever, "documentRetriever cannot be null");
this.documentRetriever = documentRetriever;
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
this.queryAugmentor = queryAugmentor != null ? queryAugmentor : ContextualQueryAugmentor.builder().build();
this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : false;
this.order = order != null ? order : 0;
}
Expand Down Expand Up @@ -140,21 +125,10 @@ private AdvisedRequest before(AdvisedRequest request) {
List<Document> documents = this.documentRetriever.retrieve(query);
context.put(DOCUMENT_CONTEXT, documents);

// 2. Combine retrieved documents.
String documentContext = documents.stream()
.map(Content::getContent)
.collect(Collectors.joining(System.lineSeparator()));

// 3. Define augmentation prompt parameters.
Map<String, Object> promptParameters = Map.of("query", query.text(), "context", documentContext);

// 4. Augment user prompt with the context data.
UserMessage augmentedUserMessage = (UserMessage) this.promptTemplate.createMessage(promptParameters);
// 2. Augment user query with the contextual data.
Query augmentedQuery = this.queryAugmentor.augment(query, documents);

return AdvisedRequest.from(request)
.withUserText(augmentedUserMessage.getContent())
.withAdviseContext(context)
.build();
return AdvisedRequest.from(request).withUserText(augmentedQuery.text()).withAdviseContext(context).build();
}

private AdvisedResponse after(AdvisedResponse advisedResponse) {
Expand Down Expand Up @@ -185,7 +159,7 @@ public static final class Builder {

private DocumentRetriever documentRetriever;

private PromptTemplate promptTemplate;
private QueryAugmentor queryAugmentor;

private Boolean protectFromBlocking;

Expand All @@ -199,8 +173,8 @@ public Builder documentRetriever(DocumentRetriever documentRetriever) {
return this;
}

public Builder promptTemplate(PromptTemplate promptTemplate) {
this.promptTemplate = promptTemplate;
public Builder queryAugmentor(QueryAugmentor queryAugmentor) {
this.queryAugmentor = queryAugmentor;
return this;
}

Expand All @@ -215,7 +189,7 @@ public Builder order(Integer order) {
}

public RetrievalAugmentationAdvisor build() {
return new RetrievalAugmentationAdvisor(this.documentRetriever, this.promptTemplate,
return new RetrievalAugmentationAdvisor(this.documentRetriever, this.queryAugmentor,
this.protectFromBlocking, this.order);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.rag.augmentation;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.model.Content;
import org.springframework.ai.rag.Query;
import org.springframework.ai.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
* Augments the user query with contextual data.
*
* <p>
* Example usage: <pre>{@code
* QueryAugmentor augmentor = ContextualQueryAugmentor.builder()
* .promptTemplate(promptTemplate)
* .emptyContextPromptTemplate(emptyContextPromptTemplate)
* .allowEmptyContext(allowEmptyContext)
* .build();
* Query augmentedQuery = augmentor.augment(query, documents);
* }</pre>
*
* @author Thomas Vitale
* @since 1.0.0
*/
public class ContextualQueryAugmentor implements QueryAugmentor {

private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
Context information is below.

---------------------
{context}
---------------------

Given the context information and no prior knowledge, answer the query.

Follow these rules:

1. If the answer is not in the context, just say that you don't know.
2. Avoid statements like "Based on the context..." or "The provided information...".

Query: {query}

Answer:
""");

private static final PromptTemplate DEFAULT_EMPTY_CONTEXT_PROMPT_TEMPLATE = new PromptTemplate("""
The user query is outside your knowledge base.
Politely inform the user that you can't answer it.
""");

private static final boolean DEFAULT_ALLOW_EMPTY_CONTEXT = true;

private final PromptTemplate promptTemplate;

private final PromptTemplate emptyContextPromptTemplate;

private final boolean allowEmptyContext;

public ContextualQueryAugmentor(@Nullable PromptTemplate promptTemplate,
@Nullable PromptTemplate emptyContextPromptTemplate, @Nullable Boolean allowEmptyContext) {
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
this.emptyContextPromptTemplate = emptyContextPromptTemplate != null ? emptyContextPromptTemplate
: DEFAULT_EMPTY_CONTEXT_PROMPT_TEMPLATE;
this.allowEmptyContext = allowEmptyContext != null ? allowEmptyContext : DEFAULT_ALLOW_EMPTY_CONTEXT;
PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "query", "context");
}

@Override
public Query augment(Query query, List<Document> documents) {
Assert.notNull(query, "query cannot be null");
Assert.notNull(documents, "documents cannot be null");

if (documents.isEmpty()) {
return augmentQueryWhenEmptyContext(query);
}

// 1. Join documents.
String documentContext = documents.stream()
.map(Content::getContent)
.collect(Collectors.joining(System.lineSeparator()));

// 2. Define prompt parameters.
Map<String, Object> promptParameters = Map.of("query", query.text(), "context", documentContext);

// 3. Augment user prompt with document context.
return new Query(this.promptTemplate.render(promptParameters));
}

private Query augmentQueryWhenEmptyContext(Query query) {
if (this.allowEmptyContext) {
return query;
}
return new Query(this.emptyContextPromptTemplate.render());
}

public static Builder builder() {
return new Builder();
}

public static class Builder {

private PromptTemplate promptTemplate;

private PromptTemplate emptyContextPromptTemplate;

private Boolean allowEmptyContext;

public Builder promptTemplate(PromptTemplate promptTemplate) {
this.promptTemplate = promptTemplate;
return this;
}

public Builder emptyContextPromptTemplate(PromptTemplate emptyContextPromptTemplate) {
this.emptyContextPromptTemplate = emptyContextPromptTemplate;
return this;
}

public Builder allowEmptyContext(Boolean allowEmptyContext) {
this.allowEmptyContext = allowEmptyContext;
return this;
}

public ContextualQueryAugmentor build() {
return new ContextualQueryAugmentor(this.promptTemplate, this.emptyContextPromptTemplate,
this.allowEmptyContext);
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.rag.augmentation;

import java.util.List;
import java.util.function.BiFunction;

import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;

/**
* Component for augmenting a query with contextual data based on a specific strategy.
*
* @author Thomas Vitale
* @since 1.0.0
*/
public interface QueryAugmentor extends BiFunction<Query, List<Document>, Query> {

/**
* Augments the user query with contextual data.
* @param query The user query to augment
* @param documents The contextual data to use for augmentation
* @return The augmented query
*/
Query augment(Query query, List<Document> documents);

/**
* Augments the user query with contextual data.
* @param query The user query to augment
* @param documents The contextual data to use for augmentation
* @return The augmented query
*/
default Query apply(Query query, List<Document> documents) {
return augment(query, documents);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/**
* RAG Module: Augmentation.
* <p>
* This package provides the functional building blocks for augmenting a user query with
* contextual data.
*/

@NonNullApi
@NonNullFields
package org.springframework.ai.rag.augmentation;

import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;
Loading