Skip to content
This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Commit

Permalink
feat(assistant-function): change the function define and a test for a…
Browse files Browse the repository at this point in the history
…ssistant funxtion
  • Loading branch information
vacuityv committed Nov 20, 2023
1 parent 650d76b commit fb44232
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.theokanning.openai.assistants;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.Map;

/**
* @description:
* @author: vacuity
* @create: 2023-11-20 10:09
**/


@Builder
@NoArgsConstructor
@AllArgsConstructor
@Data
public class AssistantFunction {

private String description;

private String name;

private Map<String, Object> parameters;
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.theokanning.openai.assistants;

import com.theokanning.openai.completion.chat.ChatFunction;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
Expand All @@ -17,5 +16,5 @@ public class Tool {
/**
* Function definition, only used if type is "function"
*/
ChatFunction function;
AssistantFunction function;
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@
public class SubmitToolOutputsRequest {

@JsonProperty("tool_outputs")
private List<SubmitToolOutputRequestItem> tool_outputs;
private List<SubmitToolOutputRequestItem> toolOutputs;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package com.theokanning.openai.service;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.theokanning.openai.ListSearchParameters;
import com.theokanning.openai.OpenAiResponse;
import com.theokanning.openai.assistants.Assistant;
import com.theokanning.openai.assistants.AssistantFunction;
import com.theokanning.openai.assistants.AssistantRequest;
import com.theokanning.openai.assistants.AssistantToolsEnum;
import com.theokanning.openai.assistants.Tool;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatFunction;
import com.theokanning.openai.completion.chat.ChatFunctionCall;
import com.theokanning.openai.messages.Message;
import com.theokanning.openai.messages.MessageRequest;
import com.theokanning.openai.runs.RequiredAction;
import com.theokanning.openai.runs.Run;
import com.theokanning.openai.runs.RunCreateRequest;
import com.theokanning.openai.runs.RunStep;
import com.theokanning.openai.runs.SubmitToolOutputRequestItem;
import com.theokanning.openai.runs.SubmitToolOutputs;
import com.theokanning.openai.runs.SubmitToolOutputsRequest;
import com.theokanning.openai.runs.ToolCall;
import com.theokanning.openai.threads.Thread;
import com.theokanning.openai.threads.ThreadRequest;
import com.theokanning.openai.utils.TikTokensUtil;
import org.junit.jupiter.api.Test;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

class AssistantFunctionTest {
String token = System.getenv("OPENAI_TOKEN");
OpenAiService service = new OpenAiService(token, Duration.ofMinutes(1));

@Test
void createRetrieveRun() throws JsonProcessingException {

ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE);
mapper.addMixIn(ChatFunction.class, ChatFunctionMixIn.class);
mapper.addMixIn(ChatCompletionRequest.class, ChatCompletionRequestMixIn.class);
mapper.addMixIn(ChatFunctionCall.class, ChatFunctionCallMixIn.class);

String funcDef = "{\n" +
" \"type\": \"object\",\n" +
" \"properties\": {\n" +
" \"location\": {\n" +
" \"type\": \"string\",\n" +
" \"description\": \"The city and state, e.g. San Francisco, CA\"\n" +
" },\n" +
" \"unit\": {\n" +
" \"type\": \"string\",\n" +
" \"enum\": [\"celsius\", \"fahrenheit\"]\n" +
" }\n" +
" },\n" +
" \"required\": [\"location\"]\n" +
"}";
Map<String, Object> funcParameters = mapper.readValue(funcDef, new TypeReference<Map<String, Object>>() {});
AssistantFunction function = AssistantFunction.builder()
.name("weather_reporter")
.description("Get the current weather of a location")
.parameters(funcParameters)
.build();

List<Tool> toolList = new ArrayList<>();
Tool funcTool = new Tool(AssistantToolsEnum.FUNCTION, function);
toolList.add(funcTool);


AssistantRequest assistantRequest = AssistantRequest.builder()
.model(TikTokensUtil.ModelEnum.GPT_4_1106_preview.getName())
.name("MATH_TUTOR")
.instructions("You are a personal Math Tutor.")
.tools(toolList)
.build();
Assistant assistant = service.createAssistant(assistantRequest);

ThreadRequest threadRequest = ThreadRequest.builder()
.build();
Thread thread = service.createThread(threadRequest);

MessageRequest messageRequest = MessageRequest.builder()
.content("What's the weather of Xiamen?")
.build();

Message message = service.createMessage(thread.getId(), messageRequest);

RunCreateRequest runCreateRequest = RunCreateRequest.builder()
.assistantId(assistant.getId())
.build();

Run run = service.createRun(thread.getId(), runCreateRequest);
assertNotNull(run);

Run retrievedRun = service.retrieveRun(thread.getId(), run.getId());
while (!(retrievedRun.getStatus().equals("completed"))
&& !(retrievedRun.getStatus().equals("failed"))
&& !(retrievedRun.getStatus().equals("requires_action"))){
retrievedRun = service.retrieveRun(thread.getId(), run.getId());
}
if (retrievedRun.getStatus().equals("requires_action")) {
RequiredAction requiredAction = retrievedRun.getRequiredAction();
System.out.println("requiredAction");
System.out.println(mapper.writeValueAsString(requiredAction));
List<ToolCall> toolCalls = requiredAction.getSubmitToolOutputs().getToolCalls();
ToolCall toolCall = toolCalls.get(0);
String toolCallId = toolCall.getId();

SubmitToolOutputRequestItem toolOutputRequestItem = SubmitToolOutputRequestItem.builder()
.toolCallId(toolCallId)
.output("sunny")
.build();
List<SubmitToolOutputRequestItem> toolOutputRequestItems = new ArrayList<>();
toolOutputRequestItems.add(toolOutputRequestItem);
SubmitToolOutputsRequest submitToolOutputsRequest = SubmitToolOutputsRequest.builder()
.toolOutputs(toolOutputRequestItems)
.build();
retrievedRun = service.submitToolOutputs(retrievedRun.getThreadId(), retrievedRun.getId(), submitToolOutputsRequest);

while (!(retrievedRun.getStatus().equals("completed"))
&& !(retrievedRun.getStatus().equals("failed"))
&& !(retrievedRun.getStatus().equals("requires_action"))){
retrievedRun = service.retrieveRun(thread.getId(), run.getId());
}

OpenAiResponse<Message> response = service.listMessages(thread.getId());

List<Message> messages = response.getData();

System.out.println(mapper.writeValueAsString(messages));

}
}
}

0 comments on commit fb44232

Please sign in to comment.