Skip to content

Commit

Permalink
Add GPT3 prompt truncation
Browse files Browse the repository at this point in the history
  • Loading branch information
rizerphe committed May 4, 2023
1 parent 2e41299 commit 199f851
Showing 1 changed file with 45 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import React from "react";
import { Completer, Model, Prompt } from "../../complete";
import available_models from "./models.json";
import {
Expand All @@ -6,13 +7,53 @@ import {
parse_settings,
} from "./provider_settings";
import { Configuration, OpenAIApi } from "openai";
import SettingsItem from "../../../components/SettingsItem";
import { z } from "zod";

export const model_settings_schema = z.object({
context_length: z.number().int().positive(),
});
export type ModelSettings = z.infer<typeof model_settings_schema>;
const parse_model_settings = (settings: string): ModelSettings => {
try {
return model_settings_schema.parse(JSON.parse(settings));
} catch (e) {
return {
context_length: 4000,
};
}
};

export default class OpenAIModel implements Model {
id: string;
name: string;
description: string;

provider_settings: Settings;
Settings = ({
settings,
saveSettings,
}: {
settings: string | null;
saveSettings: (settings: string) => void;
}) => (
<SettingsItem
name="Context length"
description="In characters, how much context should the model get"
>
<input
type="number"
value={parse_model_settings(settings || "").context_length}
onChange={(e) =>
saveSettings(
JSON.stringify({
context_length: parseInt(e.target.value),
})
)
}
/>
</SettingsItem>
);

constructor(
id: string,
Expand All @@ -26,15 +67,17 @@ export default class OpenAIModel implements Model {
this.provider_settings = parse_settings(provider_settings);
}

async complete(prompt: Prompt): Promise<string> {
async complete(prompt: Prompt, settings: string): Promise<string> {
const parsed_settings = parse_model_settings(settings);
const config = new Configuration({
apiKey: this.provider_settings.api_key,
});
const api = new OpenAIApi(config);

const response = await api.createCompletion({
model: this.id,
prompt: prompt.prefix,
prompt: prompt.prefix.slice(-parsed_settings.context_length),
max_tokens: 64,
});

if (response.status === 401) {
Expand Down

0 comments on commit 199f851

Please sign in to comment.