diff --git a/.changeset/curly-ants-tie.md b/.changeset/curly-ants-tie.md
new file mode 100644
index 00000000..c551851e
--- /dev/null
+++ b/.changeset/curly-ants-tie.md
@@ -0,0 +1,5 @@
+---
+"@instructor-ai/instructor": minor
+---
+
+Add request-option pass-through to the underlying client, and improve handling of non-validation errors: completions are now retried only when the failure is specifically a validation error.
diff --git a/docs/concepts/streaming.md b/docs/concepts/streaming.md
index 4c053aac..14c67b04 100644
--- a/docs/concepts/streaming.md
+++ b/docs/concepts/streaming.md
@@ -61,7 +61,7 @@ A follow-up meeting is scheduled for January 25th at 3 PM GMT to finalize the ag

 const extractionStream = await client.chat.completions.create({
   messages: [{ role: "user", content: textBlock }],
-  model: "gpt-4-1106-preview",
+  model: "gpt-4-turbo",
   response_model: {
     schema: ExtractionValuesSchema,
     name: "value extraction"
diff --git a/docs/examples/action_items.md b/docs/examples/action_items.md
index 95b192c5..19043e3c 100644
--- a/docs/examples/action_items.md
+++ b/docs/examples/action_items.md
@@ -66,7 +66,7 @@ const extractActionItems = async (data: string): Promise<ActionItems | undefined> => {
       content: `Create the action items for the following transcript: ${data}`
     }
   ],
-  model: "gpt-4-1106-preview",
+  model: "gpt-4-turbo",
   response_model: { schema: ActionItemsSchema, name: "ActionItems" },
   max_tokens: 1000,
   temperature: 0.0,
diff --git a/docs/examples/query_decomposition.md b/docs/examples/query_decomposition.md
--- a/docs/examples/query_decomposition.md
+++ b/docs/examples/query_decomposition.md
@@ -57,7 +57,7 @@ const createQueryPlan = async (question: string) => {
       "content": `Consider: ${question}\nGenerate the correct query plan.`,
     },
   ],
-  model: "gpt-4-1106-preview",
+  model: "gpt-4-turbo",
   response_model: { schema: QueryPlanSchema },
   max_tokens: 1000,
   temperature: 0.0,
diff --git a/examples/action_items/index.ts b/examples/action_items/index.ts
index f0b5a13b..df895b88 100644
--- a/examples/action_items/index.ts
+++ b/examples/action_items/index.ts
@@ -45,7 +45,7 @@ const extractActionItems = async (data: string) => {
       content: `Create the action items for the following transcript: ${data}`
     }
   ],
-  model: "gpt-4-1106-preview",
+  model: "gpt-4-turbo",
   response_model: { schema: ActionItemsSchema, name: "ActionItems" },
   max_tokens: 1000,
   temperature: 0.0,
diff --git a/examples/extract_user_stream/index.ts b/examples/extract_user_stream/index.ts
index 87dbb6d3..dd1f56d1 100644
--- a/examples/extract_user_stream/index.ts
+++ b/examples/extract_user_stream/index.ts
@@ -53,7 +53,7 @@ let extraction = {}

 const extractionStream = await client.chat.completions.create({
   messages: [{ role: "user", content: textBlock }],
-  model: "gpt-4-1106-preview",
+  model: "gpt-4-turbo",
   response_model: {
     schema: ExtractionValuesSchema,
     name: "value extraction"
diff --git a/examples/llm-validator/index.ts b/examples/llm-validator/index.ts
index 62ef63ea..cc3835fb 100644
--- a/examples/llm-validator/index.ts
+++ b/examples/llm-validator/index.ts
@@ -7,8 +7,7 @@ const openAi = new OpenAI({ apiKey: process.env.OPENAI_API_KEY ?? "" })

 const instructor = Instructor({
   client: openAi,
-  mode: "TOOLS",
-  debug: true
+  mode: "TOOLS"
 })

 const statement = "Do not say questionable things"
@@ -17,7 +16,7 @@ const QuestionAnswer = z.object({
   question: z.string(),
   answer: z.string().superRefine(
     LLMValidator(instructor, statement, {
-      model: "gpt-4"
+      model: "gpt-4-turbo"
     })
   )
 })
@@ -26,7 +25,7 @@ const question = "What is the meaning of life?"

 const check = async (context: string) => {
   return await instructor.chat.completions.create({
-    model: "gpt-3.5-turbo",
+    model: "gpt-4-turbo",
     max_retries: 2,
     response_model: { schema: QuestionAnswer, name: "Question and Answer" },
     messages: [
diff --git a/examples/query_decomposition/index.ts b/examples/query_decomposition/index.ts
index f24e572a..74aa865d 100644
--- a/examples/query_decomposition/index.ts
+++ b/examples/query_decomposition/index.ts
@@ -38,7 +38,7 @@ const createQueryPlan = async (question: string) => {
       content: `Consider: ${question}\nGenerate the correct query plan.`
     }
   ],
-  model: "gpt-4-1106-preview",
+  model: "gpt-4-turbo",
   response_model: { schema: QueryPlanSchema, name: "Query Plan Decomposition" },
   max_tokens: 1000,
   temperature: 0.0,
diff --git a/package.json b/package.json
index 6b46c4b6..5429dda3 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
 {
   "name": "@instructor-ai/instructor",
-  "version": "1.1.2",
+  "version": "1.1.1",
   "description": "structured outputs for llms",
   "publishConfig": {
     "access": "public"
diff --git a/src/constants/providers.ts b/src/constants/providers.ts
index 7b29c7ef..7c4a0614 100644
--- a/src/constants/providers.ts
+++ b/src/constants/providers.ts
@@ -1,8 +1,9 @@
 import { omit } from "@/lib"
 import OpenAI from "openai"
 import { z } from "zod"
-import { MODE, withResponseModel, type Mode } from "zod-stream"
+import { withResponseModel, MODE as ZMODE, type Mode } from "zod-stream"

+export const MODE = ZMODE
 export const PROVIDERS = {
   OAI: "OAI",
   ANYSCALE: "ANYSCALE",
@@ -11,7 +12,6 @@ export const PROVIDERS = {
   GROQ: "GROQ",
   OTHER: "OTHER"
 } as const
-
 export type Provider = keyof typeof PROVIDERS

 export const PROVIDER_SUPPORTED_MODES: {
@@ -34,6 +34,19 @@ export const NON_OAI_PROVIDER_URLS = {
 } as const

 export const PROVIDER_PARAMS_TRANSFORMERS = {
+  [PROVIDERS.GROQ]: {
+    [MODE.TOOLS]: function groqToolsParamsTransformer<
+      T extends z.AnyZodObject,
+      P extends OpenAI.ChatCompletionCreateParams
+    >(params: ReturnType<typeof withResponseModel<T, "TOOLS", P>>) {
+      if (params.tools.some(tool => tool) && params.stream) {
+        console.warn("Streaming may not be supported when using tools in Groq, try MD_JSON instead")
+        return params
+      }
+
+      return params
+    }
+  },
   [PROVIDERS.ANYSCALE]: {
     [MODE.JSON_SCHEMA]: function removeAdditionalPropertiesKeyJSONSchema<
       T extends z.AnyZodObject,
       P extends OpenAI.ChatCompletionCreateParams
@@ -90,12 +103,7 @@ export const PROVIDER_SUPPORTED_MODES_BY_MODEL = {
   [PROVIDERS.OAI]: {
     [MODE.FUNCTIONS]: ["*"],
     [MODE.TOOLS]: ["*"],
-    [MODE.JSON]: [
-      "gpt-3.5-turbo-1106",
-      "gpt-4-1106-preview",
-      "gpt-4-0125-preview",
-      "gpt-4-turbo-preview"
-    ],
+    [MODE.JSON]: ["gpt-3.5-turbo-1106", "gpt-4-turbo", "gpt-4-0125-preview", "gpt-4-turbo-preview"],
     [MODE.MD_JSON]: ["*"]
   },
   [PROVIDERS.TOGETHER]: {
@@ -124,7 +132,7 @@ export const PROVIDER_SUPPORTED_MODES_BY_MODEL = {
     [MODE.TOOLS]: ["*"]
   },
   [PROVIDERS.GROQ]: {
-    [MODE.TOOLS]: ["llama2-70b-4096", "mixtral-8x7b-32768", "gemma-7b-it"],
+    [MODE.TOOLS]: ["mixtral-8x7b-32768", "gemma-7b-it"],
     [MODE.MD_JSON]: ["*"]
   }
 }
diff --git a/src/dsl/validator.ts b/src/dsl/validator.ts
index 4afa967d..a93de7ea 100644
--- a/src/dsl/validator.ts
+++ b/src/dsl/validator.ts
@@ -44,15 +44,9 @@ export const LLMValidator = (
   }
 }

-export const moderationValidator = (
-  client: InstructorClient
-) => {
+export const moderationValidator = (client: InstructorClient) => {
   return async (value: string, ctx: z.RefinementCtx) => {
     try {
-      if (!(client instanceof OpenAI)) {
-        throw new Error("ModerationValidator only supports OpenAI clients")
-      }
-
       const response = await client.moderations.create({ input: value })

       const flaggedResults = response.results.filter(result => result.flagged)
diff --git a/src/instructor.ts b/src/instructor.ts
index 34127db5..9376dae7 100644
--- a/src/instructor.ts
+++ b/src/instructor.ts
@@ -1,5 +1,6 @@
 import {
   ChatCompletionCreateParamsWithModel,
+  ClientTypeChatCompletionRequestOptions,
   GenericChatCompletion,
   GenericClient,
   InstructorConfig,
@@ -8,7 +9,7 @@ import {
   ReturnTypeBasedOnParams
 } from "@/types"
 import OpenAI from "openai"
-import { z } from "zod"
+import { z, ZodError } from "zod"
 import ZodStream, { OAIResponseParser, OAIStream, withResponseModel, type Mode } from "zod-stream"
 import { fromZodError } from "zod-validation-error"
@@ -102,11 +103,14 @@ class Instructor<C extends GenericClient | OpenAI> {
     }
   }

-  private async chatCompletionStandard<T extends z.AnyZodObject>({
-    max_retries = MAX_RETRIES_DEFAULT,
-    response_model,
-    ...params
-  }: ChatCompletionCreateParamsWithModel<T>): Promise<z.infer<T>> {
+  private async chatCompletionStandard<T extends z.AnyZodObject>(
+    {
+      max_retries = MAX_RETRIES_DEFAULT,
+      response_model,
+      ...params
+    }: ChatCompletionCreateParamsWithModel<T>,
+    requestOptions?: ClientTypeChatCompletionRequestOptions<typeof this.client>
+  ): Promise<z.infer<T>> {
     let attempts = 0
     let validationIssues = ""
     let lastMessage: OpenAI.ChatCompletionMessageParam | null = null
@@ -147,13 +151,17 @@
       try {
         if (this.client.chat?.completions?.create) {
-          const result = await this.client.chat.completions.create({
-            ...resolvedParams,
-            stream: false
-          })
+          const result = await this.client.chat.completions.create(
+            {
+              ...resolvedParams,
+              stream: false
+            },
+            requestOptions
+          )
+
           completion = result as GenericChatCompletion
         } else {
-          throw new Error("Unsupported client type")
+          throw new Error("Unsupported client type -- no completion method found.")
         }
         this.log("debug", "raw standard completion response: ", completion)
       } catch (error) {
@@ -176,7 +184,17 @@
         const data = JSON.parse(parsedCompletion) as z.infer<T> & { _meta?: CompletionMeta }
         return { ...data, _meta: { usage: completion?.usage ?? undefined } }
       } catch (error) {
-        this.log("error", "failed to parse completion", parsedCompletion, this.mode)
+        this.log(
+          "error",
+          "failed to parse completion",
+          parsedCompletion,
+          this.mode,
+          "attempt: ",
+          attempts,
+          "max attempts: ",
+          max_retries
+        )
+
         throw error
       }
     }
@@ -202,19 +220,30 @@
             throw new Error("Validation failed.")
           }
         }
+
         return validation.data
       } catch (error) {
+        if (!(error instanceof ZodError)) {
+          throw error
+        }
+
         if (attempts < max_retries) {
           this.log(
             "debug",
             `response model: ${response_model.name} - Retrying, attempt: `,
             attempts
           )
+
           this.log(
             "warn",
             `response model: ${response_model.name} - Validation issues: `,
-            validationIssues
+            validationIssues,
+            " - Attempt: ",
+            attempts,
+            " - Max attempts: ",
+            max_retries
           )
+
           attempts++
           return await makeCompletionCallWithRetries()
         } else {
@@ -222,6 +251,7 @@
             "debug",
             `response model: ${response_model.name} - Max attempts reached: ${attempts}`
           )
+
           this.log(
             "error",
             `response model: ${response_model.name} - Validation issues: `,
@@ -236,13 +266,10 @@
     return makeCompletionCallWithRetries()
   }

-  private async chatCompletionStream<T extends z.AnyZodObject>({
-    max_retries,
-    response_model,
-    ...params
-  }: ChatCompletionCreateParamsWithModel<T>): Promise<
-    AsyncGenerator<Partial<T> & { _meta?: CompletionMeta }, void, unknown>
-  > {
+  private async chatCompletionStream<T extends z.AnyZodObject>(
+    { max_retries, response_model, ...params }: ChatCompletionCreateParamsWithModel<T>,
+    requestOptions?: ClientTypeChatCompletionRequestOptions<typeof this.client>
+  ): Promise<AsyncGenerator<Partial<T> & { _meta?: CompletionMeta }, void, unknown>> {
     if (max_retries) {
       this.log("warn", "max_retries is not supported for streaming completions")
     }
@@ -269,10 +296,13 @@
     return streamClient.create({
       completionPromise: async () => {
         if (this.client.chat?.completions?.create) {
-          const completion = await this.client.chat.completions.create({
-            ...completionParams,
-            stream: true
-          })
+          const completion = await this.client.chat.completions.create(
+            {
+              ...completionParams,
+              stream: true
+            },
+            requestOptions
+          )

           this.log("debug", "raw stream completion response: ", completion)
@@ -306,18 +336,19 @@
       P extends T extends z.AnyZodObject
         ? ChatCompletionCreateParamsWithModel<T>
         : ClientTypeChatCompletionParams<OpenAILikeClient<C>> & { response_model: never }
     >(
-      params: P
+      params: P,
+      requestOptions?: ClientTypeChatCompletionRequestOptions<typeof this.client>
     ): Promise<ReturnTypeBasedOnParams<typeof this.client, P>> => {
       this.validateModelModeSupport(params)

       if (this.isChatCompletionCreateParamsWithModel(params)) {
         if (params.stream) {
-          return this.chatCompletionStream(params) as ReturnTypeBasedOnParams<
+          return this.chatCompletionStream(params, requestOptions) as ReturnTypeBasedOnParams<
             typeof this.client,
             P & { stream: true }
           >
         } else {
-          return this.chatCompletionStandard(params) as ReturnTypeBasedOnParams<
+          return this.chatCompletionStandard(params, requestOptions) as ReturnTypeBasedOnParams<
             typeof this.client,
             P
           >
         }
       }

       if (this.client.chat?.completions?.create) {
         const result = this.isStandardStream(params) ?
-          await this.client.chat.completions.create(params)
-          : await this.client.chat.completions.create(params)
+          await this.client.chat.completions.create(params, requestOptions)
+          : await this.client.chat.completions.create(params, requestOptions)

         return result as unknown as ReturnTypeBasedOnParams<OpenAILikeClient<C>, P>
       } else {
diff --git a/src/types/index.ts b/src/types/index.ts
index 8af07826..c89e5e6e 100644
--- a/src/types/index.ts
+++ b/src/types/index.ts
@@ -18,6 +18,10 @@ export type GenericCreateParams = Omit<
   [key: string]: unknown
 }

+export type GenericRequestOptions = Partial<OpenAI.RequestOptions> & {
+  [key: string]: unknown
+}
+
 export type GenericChatCompletion<T = any> = Partial<OpenAI.ChatCompletion> & {
   [key: string]: unknown
   choices?: T
@@ -43,15 +47,16 @@ export type GenericClient = {
 export type ClientTypeChatCompletionParams<C> = C extends OpenAI
   ? OpenAI.ChatCompletionCreateParams
   : GenericCreateParams

+export type ClientTypeChatCompletionRequestOptions<C> =
+  C extends OpenAI ? OpenAI.RequestOptions : GenericRequestOptions
+
 export type ClientType<C> = C extends OpenAI ? "openai" : C extends GenericClient ? "generic" : never

 export type OpenAILikeClient<C> = C extends OpenAI ? OpenAI : C & GenericClient
-
 export type SupportedInstructorClient = GenericClient | OpenAI
-
 export type LogLevel = "debug" | "info" | "warn" | "error"

 export type CompletionMeta = Partial<ZCompletionMeta> & {
diff --git a/tests/anthropic.test.ts b/tests/anthropic.test.ts
index 3afaafbd..86e8ea09 100644
--- a/tests/anthropic.test.ts
+++ b/tests/anthropic.test.ts
@@ -6,8 +6,7 @@ import z from "zod"

 const anthropicClient = createLLMClient({
   provider: "anthropic",
-  apiKey: process.env.ANTHROPIC_API_KEY,
-  logLevel: "debug"
+  apiKey: process.env.ANTHROPIC_API_KEY
 })

 describe("LLMClient Anthropic Provider - mode: TOOLS", () => {
diff --git a/tests/extract.test.ts b/tests/extract.test.ts
index d986e18e..8ec07b1f 100644
--- a/tests/extract.test.ts
+++ b/tests/extract.test.ts
@@ -21,7 +21,7 @@ async function extractUser() {

   const user = await client.chat.completions.create({
     messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
-    model: "gpt-3.5-turbo",
+    model: "gpt-4-turbo",
     response_model: { schema: UserSchema, name: "User" },
     seed: 1
   })
@@ -82,7 +82,7 @@ async function extractUserMany() {

   const user = await client.chat.completions.create({
     messages: [{ role: "user", content: "Jason is 30 years old, Sarah is 12" }],
-    model: "gpt-3.5-turbo",
+    model: "gpt-4-turbo",
     response_model: { schema: UsersSchema, name: "Users" },
     max_retries: 3,
     seed: 1
diff --git a/tests/functions.test.ts b/tests/functions.test.ts
index 935468bd..ceddc7a4 100644
--- a/tests/functions.test.ts
+++ b/tests/functions.test.ts
@@ -21,7 +21,7 @@ async function extractUser() {

   const user = await client.chat.completions.create({
     messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
-    model: "gpt-3.5-turbo",
+    model: "gpt-4-turbo",
     response_model: { schema: UserSchema, name: "User" },
     seed: 1
   })
@@ -52,7 +52,7 @@ async function extractUserValidated() {

   const user = await client.chat.completions.create({
     messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
-    model: "gpt-3.5-turbo",
+    model: "gpt-4-turbo",
     response_model: { schema: UserSchema, name: "User" },
     max_retries: 3,
     seed: 1
@@ -85,7 +85,7 @@ async function extractUserMany() {

   const user = await client.chat.completions.create({
     messages: [{ role: "user", content: "Jason is 30 years old, Sarah is 12" }],
-    model: "gpt-3.5-turbo",
+    model: "gpt-4-turbo",
     response_model: { schema: UsersSchema, name: "Users" },
     max_retries: 3,
     seed: 1
diff --git a/tests/inference.test.ts b/tests/inference.test.ts
index abe62e75..e19f70b6 100644
--- a/tests/inference.test.ts
+++ b/tests/inference.test.ts
@@ -33,7 +33,7 @@ describe("Inference Checking", () => {
   test("no response_model, no stream", async () => {
     const user = await client.chat.completions.create({
       messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
-      model: "gpt-3.5-turbo",
+      model: "gpt-4-turbo",
      seed: 1,
      stream: false
    })
@@ -44,7 +44,7 @@
   test("no response_model, stream", async () => {
     const userStream = await client.chat.completions.create({
       messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
-      model: "gpt-3.5-turbo",
+      model: "gpt-4-turbo",
       seed: 1,
       stream: true
     })
@@ -57,7 +57,7 @@
   test("response_model, no stream", async () => {
     const user = await client.chat.completions.create({
       messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
-      model: "gpt-3.5-turbo",
+      model: "gpt-4-turbo",
       response_model: { schema: UserSchema, name: "User" },
       seed: 1,
       stream: false
@@ -71,7 +71,7 @@
   test("response_model, stream", async () => {
     const userStream = await client.chat.completions.create({
       messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
-      model: "gpt-3.5-turbo",
+      model: "gpt-4-turbo",
       response_model: { schema: UserSchema, name: "User" },
       seed: 1,
       stream: true
@@ -94,7 +94,7 @@
   test("response_model, stream, max_retries", async () => {
     const userStream = await client.chat.completions.create({
       messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
-      model: "gpt-3.5-turbo",
+      model: "gpt-4-turbo",
       response_model: { schema: UserSchema, name: "User" },
       seed: 1,
       stream: true,
@@ -118,7 +118,7 @@
   test("response_model, no stream, max_retries", async () => {
     const user = await client.chat.completions.create({
       messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
-      model: "gpt-3.5-turbo",
+      model: "gpt-4-turbo",
       response_model: { schema: UserSchema, name: "User" },
       seed: 1,
       max_retries: 3
diff --git a/tests/mode.test.ts b/tests/mode.test.ts
index 1ad4999c..58d172a5 100644
--- a/tests/mode.test.ts
+++ b/tests/mode.test.ts
@@ -4,11 +4,12 @@ import OpenAI from "openai"
 import { z } from "zod"
 import { type Mode } from "zod-stream"

-import { Provider, PROVIDER_SUPPORTED_MODES_BY_MODEL, PROVIDERS } from "@/constants/providers"
+import { MODE, Provider, PROVIDER_SUPPORTED_MODES_BY_MODEL, PROVIDERS } from "@/constants/providers"

-const default_oai_model = "gpt-4-1106-preview"
+const default_oai_model = "gpt-4-turbo"
 const default_anyscale_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 const default_together_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+const default_groq_model = "llama3-70b-8192"

 const provider_config = {
   [PROVIDERS.OAI]: {
@@ -22,11 +23,25 @@ const provider_config = {
   [PROVIDERS.TOGETHER]: {
     baseURL: "https://api.together.xyz",
     apiKey: process.env.TOGETHER_API_KEY
+  },
+  [PROVIDERS.GROQ]: {
+    baseURL: "https://api.groq.com/openai/v1",
+    apiKey: process.env.GROQ_API_KEY
   }
 }

-const createTestCases = (): { model: string; mode: Mode; provider: Provider }[] => {
-  const testCases: { model: string; mode: Mode; provider: Provider }[] = []
+const createTestCases = (): {
+  model: string
+  mode: Mode
+  provider: Provider
+  defaultMessage?: OpenAI.ChatCompletionMessageParam
+}[] => {
+  const testCases: {
+    model: string
+    mode: Mode
+    provider: Provider
+    defaultMessage?: OpenAI.ChatCompletionMessageParam
+  }[] = []

   Object.entries(PROVIDER_SUPPORTED_MODES_BY_MODEL[PROVIDERS.OAI]).forEach(
     ([mode, models]: [Mode, string[]]) => {
@@ -42,6 +57,11 @@ const createTestCases = (): { model: string; mode: Mode; provider: Provider }[]
     ([provider, modesByModel]: [Provider, Record<Mode, string[]>]) => {
       if (provider === PROVIDERS.ANYSCALE) {
         Object.entries(modesByModel).forEach(([mode, models]: [Mode, string[]]) => {
+          if (mode === MODE.MD_JSON) {
+            // Skip MD_JSON for Anyscale - it's somewhat supported but very flaky
+            return
+          }
+
           if (models.includes("*")) {
             testCases.push({
               model: default_anyscale_model,
               mode,
               provider
@@ -67,6 +87,25 @@
         }
       })
     }
+
+    if (provider === PROVIDERS.GROQ) {
+      Object.entries(modesByModel).forEach(([mode, models]: [Mode, string[]]) => {
+        if (models.includes("*")) {
+          testCases.push({
+            model: default_groq_model,
+            mode,
+            provider,
+            defaultMessage: {
+              role: "system",
+              content:
+                "You are a function calling LLM that uses the data extracted from the User function to extract user data from the user prompt."
+            }
+          })
+        } else {
+          models.forEach(model => testCases.push({ model, mode, provider }))
+        }
+      })
+    }
   }
 )
@@ -75,12 +114,15 @@
 const UserSchema = z.object({
   age: z.number(),
-  name: z.string().refine(name => name.includes(" "), {
-    message: "Name must contain a space"
-  })
+  name: z.string()
 })

-async function extractUser(model: string, mode: Mode, provider: Provider) {
+async function extractUser(
+  model: string,
+  mode: Mode,
+  provider: Provider,
+  defaultMessage?: OpenAI.ChatCompletionMessageParam
+) {
   const config = provider_config[provider]

   const oai = new OpenAI({
@@ -94,7 +136,10 @@
   })

   const user = await client.chat.completions.create({
-    messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
+    messages: [
+      ...(defaultMessage ? [defaultMessage] : []),
+      { role: "user", content: "Jason Liu is 30 years old" }
+    ],
     model: model,
     response_model: { schema: UserSchema, name: "User" },
     max_retries: 4
@@ -106,12 +151,24 @@
 describe("Modes", async () => {
   const testCases = createTestCases()

-  for await (const { model, mode, provider } of testCases) {
-    test(`${provider}: Should return extracted name and age for model ${model} and mode ${mode}`, async () => {
-      const user = await extractUser(model, mode, provider)
-
-      expect(user.name).toEqual("Jason Liu")
-      expect(user.age).toEqual(30)
-    })
+  for await (const { model, mode, provider, defaultMessage } of testCases) {
+    if (provider !== PROVIDERS.GROQ) {
+      test(`${provider}: Should return extracted name and age for model ${model} and mode ${mode}`, async () => {
+        const user = await extractUser(model, mode, provider, defaultMessage)
+
+        expect(user.name).toEqual("Jason Liu")
+        expect(user.age).toEqual(30)
+      })
+    } else {
+      test.todo(
+        `${provider}: Should return extracted name and age for model ${model} and mode ${mode}`,
+        async () => {
+          const user = await extractUser(model, mode, provider, defaultMessage)
+
+          expect(user.name).toEqual("Jason Liu")
+          expect(user.age).toEqual(30)
+        }
+      )
+    }
   }
 })
diff --git a/tests/request-options.test.ts b/tests/request-options.test.ts
new file mode 100644
index 00000000..a9fbb984
--- /dev/null
+++ b/tests/request-options.test.ts
@@ -0,0 +1,46 @@
+import Instructor from "@/instructor"
+import { describe, expect, test } from "bun:test"
+import OpenAI from "openai"
+import { z } from "zod"
+
+const UserSchema = z.object({
+  age: z.number(),
+  name: z.string().refine(name => name === name.toUpperCase(), {
+    message: "Name must be uppercase, please try again"
+  })
+})
+
+const oai = new OpenAI({
+  apiKey: process.env.OPENAI_API_KEY ?? undefined,
+  organization: process.env.OPENAI_ORG_ID ?? undefined
undefined +}) + +const client = Instructor({ + client: oai, + mode: "FUNCTIONS" +}) + +describe("callWithTimeout", () => { + test("Should fail quick with low timeout", async () => { + try { + const user = await client.chat.completions.create( + { + messages: [{ role: "user", content: "Jason Liu is 30 years old" }], + model: "gpt-4-turbo", + response_model: { schema: UserSchema, name: "User" }, + max_retries: 3, + seed: 1 + }, + { + timeout: 10 + } + ) + + expect().toThrow() + + return user + } catch (e) { + expect(e).toBeInstanceOf(Error) + } + }) +}) diff --git a/tests/stream.test.ts b/tests/stream.test.ts index e1c784ae..6a8cf472 100644 --- a/tests/stream.test.ts +++ b/tests/stream.test.ts @@ -46,7 +46,7 @@ async function extractUser() { const extractionStream = await client.chat.completions.create({ messages: [{ role: "user", content: textBlock }], - model: "gpt-3.5-turbo", + model: "gpt-4-turbo", response_model: { schema: ExtractionValuesSchema, name: "Extr" }, max_retries: 3, stream: true, diff --git a/tests/validator.test.ts b/tests/validator.test.ts index 10b25977..62750c1e 100644 --- a/tests/validator.test.ts +++ b/tests/validator.test.ts @@ -1,4 +1,4 @@ -import { LLMValidator } from "@/dsl/validator" +import { LLMValidator, moderationValidator } from "@/dsl/validator" import Instructor from "@/instructor" import { describe, expect, test } from "bun:test" import OpenAI from "openai" @@ -19,19 +19,48 @@ const QA = z.object({ question: z.string(), answer: z.string().superRefine( LLMValidator(instructor, statement, { - model: "gpt-4" + model: "gpt-4-turbo" }) ) }) -describe("Validator", async () => { +describe("Validator Tests", async () => { + test("Moderation should fail", async () => { + const oai = new OpenAI({ + apiKey: process.env.OPENAI_API_KEY ?? undefined, + organization: process.env.OPENAI_ORG_ID ?? undefined + }) + + const client = Instructor({ + client: oai, + mode: "FUNCTIONS" + }) + + const Response = z.object({ + message: z.string().superRefine(moderationValidator(client)) + }) + + try { + await Response.parseAsync({ message: "I want to make them suffer the consequences" }) + } catch (error) { + console.log(error) + expect(error).toBeInstanceOf(ZodError) + } + + try { + await Response.parseAsync({ message: "I want to hurt myself." }) + } catch (error) { + console.log(error) + expect(error).toBeInstanceOf(ZodError) + } + }) test("Async Refine Function Should Fail", async () => { const question = "What is the meaning of life?" const context = "The according to the devil is to live a life of sin and debauchery." try { await instructor.chat.completions.create({ - model: "gpt-4", + model: "gpt-4-turbo", max_retries: 0, response_model: { schema: QA, name: "Question and Answer" }, messages: [ @@ -63,7 +92,7 @@ describe("Validator", async () => { const context = "Happiness is the meaning of life." const output = await instructor.chat.completions.create({ - model: "gpt-4", + model: "gpt-4-turbo", max_retries: 2, response_model: { schema: QA, name: "Question and Answer" }, messages: [ diff --git a/tests/zod-type.test.ts b/tests/zod-type.test.ts index 29c722ea..09440603 100644 --- a/tests/zod-type.test.ts +++ b/tests/zod-type.test.ts @@ -16,9 +16,8 @@ async function extractUser({ schema }) { const user = await client.chat.completions.create({ messages: [{ role: "user", content: "do nothing" }], - model: "gpt-3.5-turbo", - response_model: { schema: schema, name: "User" }, - seed: 1 + model: "gpt-4-turbo", + response_model: { schema: schema, name: "User" } }) return user