diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java index 82188b88fa..70c49ff6a7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java @@ -20,7 +20,6 @@ import java.util.List; import java.util.Map; import java.util.function.Predicate; -import java.util.stream.Collectors; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -32,12 +31,12 @@ import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; -import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.document.Document; -import org.springframework.ai.model.Content; import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.augmentation.ContextualQueryAugmentor; +import org.springframework.ai.rag.augmentation.QueryAugmentor; import org.springframework.ai.rag.retrieval.source.DocumentRetriever; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -60,33 +59,19 @@ public class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAr public static final String DOCUMENT_CONTEXT = "rag_document_context"; - public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate(""" - {query} - - Context information is below. Use this information to answer the user query. - - --------------------- - {context} - --------------------- - - Given the context and provided history information and not prior knowledge, - reply to the user query. If the answer is not in the context, inform - the user that you can't answer the query. - """); - private final DocumentRetriever documentRetriever; - private final PromptTemplate promptTemplate; + private final QueryAugmentor queryAugmentor; private final boolean protectFromBlocking; private final int order; - public RetrievalAugmentationAdvisor(DocumentRetriever documentRetriever, @Nullable PromptTemplate promptTemplate, + public RetrievalAugmentationAdvisor(DocumentRetriever documentRetriever, @Nullable QueryAugmentor queryAugmentor, @Nullable Boolean protectFromBlocking, @Nullable Integer order) { Assert.notNull(documentRetriever, "documentRetriever cannot be null"); this.documentRetriever = documentRetriever; - this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE; + this.queryAugmentor = queryAugmentor != null ? queryAugmentor : ContextualQueryAugmentor.builder().build(); this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : false; this.order = order != null ? order : 0; } @@ -140,21 +125,10 @@ private AdvisedRequest before(AdvisedRequest request) { List documents = this.documentRetriever.retrieve(query); context.put(DOCUMENT_CONTEXT, documents); - // 2. Combine retrieved documents. - String documentContext = documents.stream() - .map(Content::getContent) - .collect(Collectors.joining(System.lineSeparator())); - - // 3. Define augmentation prompt parameters. - Map promptParameters = Map.of("query", query.text(), "context", documentContext); - - // 4. Augment user prompt with the context data. - UserMessage augmentedUserMessage = (UserMessage) this.promptTemplate.createMessage(promptParameters); + // 2. Augment user query with the contextual data. + Query augmentedQuery = this.queryAugmentor.augment(query, documents); - return AdvisedRequest.from(request) - .withUserText(augmentedUserMessage.getContent()) - .withAdviseContext(context) - .build(); + return AdvisedRequest.from(request).withUserText(augmentedQuery.text()).withAdviseContext(context).build(); } private AdvisedResponse after(AdvisedResponse advisedResponse) { @@ -185,7 +159,7 @@ public static final class Builder { private DocumentRetriever documentRetriever; - private PromptTemplate promptTemplate; + private QueryAugmentor queryAugmentor; private Boolean protectFromBlocking; @@ -199,8 +173,8 @@ public Builder documentRetriever(DocumentRetriever documentRetriever) { return this; } - public Builder promptTemplate(PromptTemplate promptTemplate) { - this.promptTemplate = promptTemplate; + public Builder queryAugmentor(QueryAugmentor queryAugmentor) { + this.queryAugmentor = queryAugmentor; return this; } @@ -215,7 +189,7 @@ public Builder order(Integer order) { } public RetrievalAugmentationAdvisor build() { - return new RetrievalAugmentationAdvisor(this.documentRetriever, this.promptTemplate, + return new RetrievalAugmentationAdvisor(this.documentRetriever, this.queryAugmentor, this.protectFromBlocking, this.order); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/ContextualQueryAugmentor.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/ContextualQueryAugmentor.java new file mode 100644 index 0000000000..b4c6bdf3f6 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/ContextualQueryAugmentor.java @@ -0,0 +1,152 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.rag.augmentation; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.document.Document; +import org.springframework.ai.model.Content; +import org.springframework.ai.rag.Query; +import org.springframework.ai.util.PromptAssert; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Augments the user query with contextual data. + * + *

+ * Example usage:

{@code
+ * QueryAugmentor augmentor = ContextualQueryAugmentor.builder()
+ *    .promptTemplate(promptTemplate)
+ *    .emptyContextPromptTemplate(emptyContextPromptTemplate)
+ *    .allowEmptyContext(allowEmptyContext)
+ *    .build();
+ * Query augmentedQuery = augmentor.augment(query, documents);
+ * }
+ * + * @author Thomas Vitale + * @since 1.0.0 + */ +public class ContextualQueryAugmentor implements QueryAugmentor { + + private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate(""" + Context information is below. + + --------------------- + {context} + --------------------- + + Given the context information and no prior knowledge, answer the query. + + Follow these rules: + + 1. If the answer is not in the context, just say that you don't know. + 2. Avoid statements like "Based on the context..." or "The provided information...". + + Query: {query} + + Answer: + """); + + private static final PromptTemplate DEFAULT_EMPTY_CONTEXT_PROMPT_TEMPLATE = new PromptTemplate(""" + The user query is outside your knowledge base. + Politely inform the user that you can't answer it. + """); + + private static final boolean DEFAULT_ALLOW_EMPTY_CONTEXT = true; + + private final PromptTemplate promptTemplate; + + private final PromptTemplate emptyContextPromptTemplate; + + private final boolean allowEmptyContext; + + public ContextualQueryAugmentor(@Nullable PromptTemplate promptTemplate, + @Nullable PromptTemplate emptyContextPromptTemplate, @Nullable Boolean allowEmptyContext) { + this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE; + this.emptyContextPromptTemplate = emptyContextPromptTemplate != null ? emptyContextPromptTemplate + : DEFAULT_EMPTY_CONTEXT_PROMPT_TEMPLATE; + this.allowEmptyContext = allowEmptyContext != null ? allowEmptyContext : DEFAULT_ALLOW_EMPTY_CONTEXT; + PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "query", "context"); + } + + @Override + public Query augment(Query query, List documents) { + Assert.notNull(query, "query cannot be null"); + Assert.notNull(documents, "documents cannot be null"); + + if (documents.isEmpty()) { + return augmentQueryWhenEmptyContext(query); + } + + // 1. Join documents. + String documentContext = documents.stream() + .map(Content::getContent) + .collect(Collectors.joining(System.lineSeparator())); + + // 2. Define prompt parameters. + Map promptParameters = Map.of("query", query.text(), "context", documentContext); + + // 3. Augment user prompt with document context. + return new Query(this.promptTemplate.render(promptParameters)); + } + + private Query augmentQueryWhenEmptyContext(Query query) { + if (this.allowEmptyContext) { + return query; + } + return new Query(this.emptyContextPromptTemplate.render()); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private PromptTemplate promptTemplate; + + private PromptTemplate emptyContextPromptTemplate; + + private Boolean allowEmptyContext; + + public Builder promptTemplate(PromptTemplate promptTemplate) { + this.promptTemplate = promptTemplate; + return this; + } + + public Builder emptyContextPromptTemplate(PromptTemplate emptyContextPromptTemplate) { + this.emptyContextPromptTemplate = emptyContextPromptTemplate; + return this; + } + + public Builder allowEmptyContext(Boolean allowEmptyContext) { + this.allowEmptyContext = allowEmptyContext; + return this; + } + + public ContextualQueryAugmentor build() { + return new ContextualQueryAugmentor(this.promptTemplate, this.emptyContextPromptTemplate, + this.allowEmptyContext); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/QueryAugmentor.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/QueryAugmentor.java new file mode 100644 index 0000000000..97b759e5c1 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/QueryAugmentor.java @@ -0,0 +1,51 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.rag.augmentation; + +import java.util.List; +import java.util.function.BiFunction; + +import org.springframework.ai.document.Document; +import org.springframework.ai.rag.Query; + +/** + * Component for augmenting a query with contextual data based on a specific strategy. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface QueryAugmentor extends BiFunction, Query> { + + /** + * Augments the user query with contextual data. + * @param query The user query to augment + * @param documents The contextual data to use for augmentation + * @return The augmented query + */ + Query augment(Query query, List documents); + + /** + * Augments the user query with contextual data. + * @param query The user query to augment + * @param documents The contextual data to use for augmentation + * @return The augmented query + */ + default Query apply(Query query, List documents) { + return augment(query, documents); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/package-info.java new file mode 100644 index 0000000000..a82b662c6f --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/package-info.java @@ -0,0 +1,29 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * RAG Module: Augmentation. + *

+ * This package provides the functional building blocks for augmenting a user query with + * contextual data. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.rag.augmentation; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/util/PromptAssert.java b/spring-ai-core/src/main/java/org/springframework/ai/util/PromptAssert.java new file mode 100644 index 0000000000..6ebe362c12 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/util/PromptAssert.java @@ -0,0 +1,59 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.util; + +import java.util.ArrayList; +import java.util.List; + +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.util.Assert; + +/** + * Assertion utility class that assists in validating arguments for prompt-related + * operations. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public final class PromptAssert { + + private PromptAssert() { + } + + /** + * Assert that the given prompt template contains the required placeholders. + * @param promptTemplate the prompt template to check + * @param placeholders the placeholders that must be present in the prompt template + */ + public static void templateHasRequiredPlaceholders(PromptTemplate promptTemplate, String... placeholders) { + Assert.notNull(promptTemplate, "promptTemplate cannot be null"); + Assert.notEmpty(placeholders, "placeholders cannot be null or empty"); + + List missingPlaceholders = new ArrayList<>(); + for (String placeholder : placeholders) { + if (!promptTemplate.getTemplate().contains(placeholder)) { + missingPlaceholders.add(placeholder); + } + } + + if (!missingPlaceholders.isEmpty()) { + throw new IllegalArgumentException("The following placeholders must be present in the prompt template: %s" + .formatted(String.join(",", missingPlaceholders))); + } + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java index 74a18cd31c..56c5ea81cc 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java @@ -94,18 +94,23 @@ void theOneWithTheDocumentRetriever() { var prompt = promptCaptor.getValue(); assertThat(prompt.getContents()).contains(""" - What would I get if I added a pinch of Moonstone to a dash of powdered Gold? - - Context information is below. Use this information to answer the user query. + Context information is below. --------------------- doc1 doc2 --------------------- - Given the context and provided history information and not prior knowledge, - reply to the user query. If the answer is not in the context, inform - the user that you can't answer the query. + Given the context information and no prior knowledge, answer the query. + + Follow these rules: + + 1. If the answer is not in the context, just say that you don't know. + 2. Avoid statements like "Based on the context..." or "The provided information...". + + Query: What would I get if I added a pinch of Moonstone to a dash of powdered Gold? + + Answer: """); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/rag/augmentation/ContextualQueryAugmentorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/rag/augmentation/ContextualQueryAugmentorTests.java new file mode 100644 index 0000000000..7d9cb112f3 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/rag/augmentation/ContextualQueryAugmentorTests.java @@ -0,0 +1,96 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.rag.augmentation; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.document.Document; +import org.springframework.ai.rag.Query; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link ContextualQueryAugmentor}. + * + * @author Thomas Vitale + */ +class ContextualQueryAugmentorTests { + + @Test + void whenQueryIsNullThenThrow() { + QueryAugmentor augmenter = ContextualQueryAugmentor.builder().build(); + assertThatThrownBy(() -> augmenter.augment(null, List.of())).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("query cannot be null"); + } + + @Test + void whenDocumentsIsNullThenThrow() { + QueryAugmentor augmentor = ContextualQueryAugmentor.builder().build(); + Query query = new Query("test query"); + assertThatThrownBy(() -> augmentor.augment(query, null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("documents cannot be null"); + } + + @Test + void whenDocumentsIsEmptyAndAllowEmptyContextThenReturnOriginalQuery() { + QueryAugmentor augmentor = ContextualQueryAugmentor.builder().build(); + Query query = new Query("test query"); + Query augmentedQuery = augmentor.augment(query, List.of()); + assertThat(augmentedQuery).isEqualTo(query); + } + + @Test + void whenDocumentsIsEmptyAndNotAllowEmptyContextThenReturnAugmentedQueryWithCustomTemplate() { + PromptTemplate emptyContextPromptTemplate = new PromptTemplate("No context available."); + QueryAugmentor augmentor = ContextualQueryAugmentor.builder() + .allowEmptyContext(false) + .emptyContextPromptTemplate(emptyContextPromptTemplate) + .build(); + Query query = new Query("test query"); + Query augmentedQuery = augmentor.augment(query, List.of()); + assertThat(augmentedQuery.text()).isEqualTo(emptyContextPromptTemplate.getTemplate()); + } + + @Test + void whenDocumentsAreProvidedThenReturnAugmentedQueryWithCustomTemplate() { + PromptTemplate promptTemplate = new PromptTemplate(""" + Context: + {context} + + Query: + {query} + """); + QueryAugmentor augmentor = ContextualQueryAugmentor.builder().promptTemplate(promptTemplate).build(); + Query query = new Query("test query"); + List documents = List.of(new Document("content1", Map.of()), new Document("content2", Map.of())); + Query augmentedQuery = augmentor.augment(query, documents); + assertThat(augmentedQuery.text()).isEqualTo(""" + Context: + content1 + content2 + + Query: + test query + """); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/util/PromptAssertTests.java b/spring-ai-core/src/test/java/org/springframework/ai/util/PromptAssertTests.java new file mode 100644 index 0000000000..54b4b3bd1d --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/util/PromptAssertTests.java @@ -0,0 +1,67 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.util; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.prompt.PromptTemplate; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link PromptAssert}. + * + * @author Thomas Vitale + */ +class PromptAssertTests { + + @Test + void whenPlaceholderIsPresentThenOk() { + var promptTemplate = new PromptTemplate("Hello, {name}!"); + PromptAssert.templateHasRequiredPlaceholders(promptTemplate, "{name}"); + } + + @Test + void whenPlaceholderIsPresentThenThrow() { + PromptTemplate promptTemplate = new PromptTemplate("Hello, {name}!"); + assertThatThrownBy(() -> PromptAssert.templateHasRequiredPlaceholders(promptTemplate, "{name}", "{age}")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("age"); + } + + @Test + void whenPromptTemplateIsNullThenThrow() { + assertThatThrownBy(() -> PromptAssert.templateHasRequiredPlaceholders(null, "{name}")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("promptTemplate cannot be null"); + } + + @Test + void whenPlaceholdersIsNullThenThrow() { + assertThatThrownBy(() -> PromptAssert.templateHasRequiredPlaceholders(new PromptTemplate("{query}"), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("placeholders cannot be null or empty"); + } + + @Test + void whenPlaceholdersIsEmptyThenThrow() { + assertThatThrownBy(() -> PromptAssert.templateHasRequiredPlaceholders(new PromptTemplate("{query}"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("placeholders cannot be null or empty"); + } + +}