Skip to content

Commit

Permalink
add json_schema
Browse files Browse the repository at this point in the history
  • Loading branch information
lmangani committed Oct 26, 2024
1 parent b734f82 commit 7880908
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/open_prompt_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ static void SetModelName(DataChunk &args, ExpressionState &state, Vector &result
SetConfigValue(args, state, result, "openprompt_model_name", "Model name");
}

static void SetJsonSchema(DataChunk &args, ExpressionState &state, Vector &result) {
SetConfigValue(args, state, result, "openprompt_json_schema", "JSON Schema");
}

// Main Function
static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) {
D_ASSERT(args.data.size() >= 1); // At least prompt required
Expand All @@ -142,6 +146,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
"http://localhost:11434/v1/chat/completions");
std::string api_token = GetConfigValue(context, "openprompt_api_token", "");
std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b");
std::string json_schema = GetConfigValue(context, "openprompt_json_schema", "");

// Override model if provided as second argument
if (args.data.size() > 1 && !args.data[1].GetValue(0).IsNull()) {
Expand All @@ -151,7 +156,11 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
std::string request_body = "{";
request_body += "\"model\":\"" + model_name + "\",";
request_body += "\"messages\":[";
request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},";
if (!json_schema.empty()) {
request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant. Summarize and Output JSON format (without any omissions): " + json_schema + "\"},";
} else {
request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},";
}
request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}";
request_body += "]}";

Expand All @@ -167,11 +176,11 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
}

auto res = client.Post(path.c_str(), headers, request_body, "application/json");

if (!res) {
HandleHttpError(res, "POST");
}

if (res->status != 200) {
throw std::runtime_error("HTTP error " + std::to_string(res->status) + ": " + res->reason);
}
Expand All @@ -181,7 +190,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
duckdb_yyjson::yyjson_read(res->body.c_str(), res->body.length(), 0),
&duckdb_yyjson::yyjson_doc_free
);

if (!doc) {
throw std::runtime_error("Failed to parse JSON response");
}
Expand Down Expand Up @@ -246,6 +255,8 @@ static void LoadInternal(DatabaseInstance &instance) {
"set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiUrl));
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName));
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_json_schema", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetJsonSchema));
}

void OpenPromptExtension::Load(DuckDB &db) {
Expand Down

0 comments on commit 7880908

Please sign in to comment.