Skip to content

Commit

Permalink
Merge pull request #44 from WSE-research/threeshot_input_generative
Browse files Browse the repository at this point in the history
Added three-shot (input) and zero-shot (output) generative support
  • Loading branch information
dschiese authored Jun 17, 2024
2 parents 7e69ce7 + f62d7f8 commit 7a8e8c2
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
</parent>
<groupId>com.wse</groupId>
<artifactId>qanary-explanation-service</artifactId>
<version>3.4.1</version>
<version>3.5.0</version>
<name>Qanary explanation service</name>
<description>Webservice for rule-based explanation of QA-Systems as well as specific components</description>
<properties>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ public QanaryResponseObject executeQanaryPipeline(QanaryRequestObject qanaryRequ
MultiValueMap<String, String> multiValueMap = new LinkedMultiValueMap();
multiValueMap.add("question", qanaryRequestObject.getQuestion());
multiValueMap.addAll(qanaryRequestObject.getComponentListAsMap());

return webClient.post().uri(uriBuilder -> uriBuilder // TODO: use new endpoint for question answering
.scheme("http").host(QANARY_PIPELINE_HOST).port(QANARY_PIPELINE_PORT).path("/startquestionansweringwithtextquestion")
.queryParams(multiValueMap)
Expand All @@ -75,6 +74,7 @@ public void initConnection() {
}

public String getQuestionFromQuestionId(String questionId) {
logger.info("Get question from url: {}", questionId);
return webClient.get().uri(questionId).retrieve().bodyToMono(String.class).block();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public class GenerativeExplanations {
put("^^http://www.w3.org/2001/XMLSchema#double", "");
}};
private static final Map<Integer, String> EXAMPLE_COUNT_AND_TEMPLATE = new HashMap<>() {{
put(0, "/prompt_templates/outputdata/zeroshot");
put(1, "/prompt_templates/outputdata/oneshot");
put(2, "/prompt_templates/outputdata/twoshot");
put(3, "/prompt_templates/outputdata/threeshot");
Expand All @@ -79,6 +80,7 @@ public class GenerativeExplanations {
put(1, "/prompt_templates/inputdata/oneshot");
put(2, "/prompt_templates/inputdata/twoshot");
put(0, "/prompt_templates/inputdata/zeroshot");
put(3, "/prompt_templates/inputdata/threeshot");
}};
private static final String EXPLANATION_NAMESPACE = "urn:qanary:explanations#";
private static final String QUESTION_QUERY = "/queries/random_question_query.rq";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,6 @@ public String createPrompt(int shots, GenerativeExplanationObject generativeExpl
generativeExplanations.getPromptTemplate(shots)
);

logger.info("Shots {} and Object {}", shots, generativeExplanationObject.getExampleComponents().get(0).getExplanation());

prompt = prompt.replace("<TASK_RDF_DATA_TEST>", generativeExplanationObject.getTestComponent().getDataSet());

ArrayList<TestDataObject> testDataObjects = generativeExplanationObject.getExampleComponents();
Expand Down Expand Up @@ -194,14 +192,17 @@ public String sendPrompt(String prompt, GptModel gptModel) throws Exception {
*/
public String getInputDataExplanationPrompt(QanaryComponent component, String body, int shots) throws Exception {
String prompt = getStringFromFile(generativeExplanations.getPromptTemplateInputData(shots));

prompt = prompt.replace("${QUERY}", body).replace("${COMPONENT}", component.getPrefixedComponentName());
if (shots > 0) {
InputQueryExample inputQueryExample = GenerativeExplanations.INPUT_QUERIES_AND_EXAMPLE.get(random.nextInt(GenerativeExplanations.INPUT_QUERIES_AND_EXAMPLE.size()));
prompt = prompt.replace("${ZEROSHOT_QUERY}", inputQueryExample.getQuery()).replace("${ZEROSHOT_EXPLANATION", inputQueryExample.getExplanations()); // select random pre-defined statements
if (shots > 1) {
InputQueryExample inputQueryExample2 = GenerativeExplanations.INPUT_QUERIES_AND_EXAMPLE.get(random.nextInt(GenerativeExplanations.INPUT_QUERIES_AND_EXAMPLE.size()));
prompt = prompt.replace("${ONESHOT_QUERY}", inputQueryExample2.getQuery()).replace("${ONESHOT_EXPLANATION", inputQueryExample2.getExplanations()); // select random pre-defined statements
if (shots > 2) {
InputQueryExample inputQueryExample3 = GenerativeExplanations.INPUT_QUERIES_AND_EXAMPLE.get(random.nextInt(GenerativeExplanations.INPUT_QUERIES_AND_EXAMPLE.size()));
prompt = prompt.replace("${TWOSHOT_QUERY}", inputQueryExample3.getQuery()).replace("${TWOSHOT_EXPLANATION", inputQueryExample3.getExplanations()); // select random pre-defined statements
}
}
}
return prompt;
Expand Down
33 changes: 33 additions & 0 deletions src/main/resources/prompt_templates/inputdata/threeshot
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
Given the following context: Here, we consider the data of a Question Answering system. The data describes a SPARQL query. As a user I'd like to understand what the query means and does. For this purpose a (text-based) explanation has to be computed.
Here's an example explanation:
The query:
```
"${ZEROSHOT_QUERY}"
```
The example explanation:
"${ZEROSHOT_EXPLANATION}"

Another example is the following:
The query:
```
${ONESHOT_QUERY}
```
The 2nd example explanation:
"${ONESHOT_EXPLANATION}"

And lastly an 3rd example:
The query
```
"${TWOSHOT_QUERY}"
```
The 3rd explanation:
```
"${TWOSHOT_EXPLANATION}"
```

Now explain the following query, used by the component "${COMPONENT}":
```
${QUERY}
```

Don't use more than 3 sentences.
12 changes: 12 additions & 0 deletions src/main/resources/prompt_templates/outputdata/zeroshot
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Given the following context:
Here, we consider the data of a Question Answering system.
The data describes the outcome of this system.
As a user I'd like to understand what happened inside that particular component.
For this purpose a (text-based) explanation has to be computed.

Given is the following data:
<TASK_RDF_DATA_TEST>

Now, create an explanation for this data.

Don't introduce your answer and only return the result.

0 comments on commit 7a8e8c2

Please sign in to comment.