Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow request option pass through #164

Merged
merged 7 commits into from
Apr 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/curly-ants-tie.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@instructor-ai/instructor": minor
---

Adds request-option pass-through, improves handling of non-validation errors, and no longer retries unless the error is specifically a validation error.
2 changes: 1 addition & 1 deletion docs/concepts/streaming.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ A follow-up meeting is scheduled for January 25th at 3 PM GMT to finalize the ag

const extractionStream = await client.chat.completions.create({
messages: [{ role: "user", content: textBlock }],
model: "gpt-4-1106-preview",
model: "gpt-4-turbo",
response_model: {
schema: ExtractionValuesSchema,
name: "value extraction"
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/action_items.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ const extractActionItems = async (data: string): Promise<ActionItems | undefined
"content": `Create the action items for the following transcript: ${data}`,
},
],
model: "gpt-4-1106-preview",
model: "gpt-4-turbo",
response_model: { schema: ActionItemsSchema },
max_tokens: 1000,
temperature: 0.0,
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/query_decomposition.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ const createQueryPlan = async (question: string): Promise<QueryPlan | undefined>
"content": `Consider: ${question}\nGenerate the correct query plan.`,
},
],
model: "gpt-4-1106-preview",
model: "gpt-4-turbo",
response_model: { schema: QueryPlanSchema },
max_tokens: 1000,
temperature: 0.0,
Expand Down
2 changes: 1 addition & 1 deletion examples/action_items/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ const extractActionItems = async (data: string) => {
content: `Create the action items for the following transcript: ${data}`
}
],
model: "gpt-4-1106-preview",
model: "gpt-4-turbo",
response_model: { schema: ActionItemsSchema, name: "ActionItems" },
max_tokens: 1000,
temperature: 0.0,
Expand Down
2 changes: 1 addition & 1 deletion examples/extract_user_stream/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ let extraction = {}

const extractionStream = await client.chat.completions.create({
messages: [{ role: "user", content: textBlock }],
model: "gpt-4-1106-preview",
model: "gpt-4-turbo",
response_model: {
schema: ExtractionValuesSchema,
name: "value extraction"
Expand Down
7 changes: 3 additions & 4 deletions examples/llm-validator/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ const openAi = new OpenAI({ apiKey: process.env.OPENAI_API_KEY ?? "" })

const instructor = Instructor({
client: openAi,
mode: "TOOLS",
debug: true
mode: "TOOLS"
})

const statement = "Do not say questionable things"
Expand All @@ -17,7 +16,7 @@ const QuestionAnswer = z.object({
question: z.string(),
answer: z.string().superRefine(
LLMValidator(instructor, statement, {
model: "gpt-4"
model: "gpt-4-turbo"
})
)
})
Expand All @@ -26,7 +25,7 @@ const question = "What is the meaning of life?"

const check = async (context: string) => {
return await instructor.chat.completions.create({
model: "gpt-3.5-turbo",
model: "gpt-4-turbo",
max_retries: 2,
response_model: { schema: QuestionAnswer, name: "Question and Answer" },
messages: [
Expand Down
2 changes: 1 addition & 1 deletion examples/query_decomposition/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ const createQueryPlan = async (question: string) => {
content: `Consider: ${question}\nGenerate the correct query plan.`
}
],
model: "gpt-4-1106-preview",
model: "gpt-4-turbo",
response_model: { schema: QueryPlanSchema, name: "Query Plan Decomposition" },
max_tokens: 1000,
temperature: 0.0,
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@instructor-ai/instructor",
"version": "1.1.2",
"version": "1.1.1",
"description": "structured outputs for llms",
"publishConfig": {
"access": "public"
Expand Down
26 changes: 17 additions & 9 deletions src/constants/providers.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { omit } from "@/lib"
import OpenAI from "openai"
import { z } from "zod"
import { MODE, withResponseModel, type Mode } from "zod-stream"
import { withResponseModel, MODE as ZMODE, type Mode } from "zod-stream"

export const MODE = ZMODE
export const PROVIDERS = {
OAI: "OAI",
ANYSCALE: "ANYSCALE",
Expand All @@ -11,7 +12,6 @@ export const PROVIDERS = {
GROQ: "GROQ",
OTHER: "OTHER"
} as const

export type Provider = keyof typeof PROVIDERS

export const PROVIDER_SUPPORTED_MODES: {
Expand All @@ -34,6 +34,19 @@ export const NON_OAI_PROVIDER_URLS = {
} as const

export const PROVIDER_PARAMS_TRANSFORMERS = {
[PROVIDERS.GROQ]: {
[MODE.TOOLS]: function groqToolsParamsTransformer<
T extends z.AnyZodObject,
P extends OpenAI.ChatCompletionCreateParams
>(params: ReturnType<typeof withResponseModel<T, "TOOLS", P>>) {
if (params.tools.some(tool => tool) && params.stream) {
console.warn("Streaming may not be supported when using tools in Groq, try MD_JSON instead")
return params
}

return params
}
},
[PROVIDERS.ANYSCALE]: {
[MODE.JSON_SCHEMA]: function removeAdditionalPropertiesKeyJSONSchema<
T extends z.AnyZodObject,
Expand Down Expand Up @@ -90,12 +103,7 @@ export const PROVIDER_SUPPORTED_MODES_BY_MODEL = {
[PROVIDERS.OAI]: {
[MODE.FUNCTIONS]: ["*"],
[MODE.TOOLS]: ["*"],
[MODE.JSON]: [
"gpt-3.5-turbo-1106",
"gpt-4-1106-preview",
"gpt-4-0125-preview",
"gpt-4-turbo-preview"
],
[MODE.JSON]: ["gpt-3.5-turbo-1106", "gpt-4-turbo", "gpt-4-0125-preview", "gpt-4-turbo-preview"],
[MODE.MD_JSON]: ["*"]
},
[PROVIDERS.TOGETHER]: {
Expand Down Expand Up @@ -124,7 +132,7 @@ export const PROVIDER_SUPPORTED_MODES_BY_MODEL = {
[MODE.TOOLS]: ["*"]
},
[PROVIDERS.GROQ]: {
[MODE.TOOLS]: ["llama2-70b-4096", "mixtral-8x7b-32768", "gemma-7b-it"],
[MODE.TOOLS]: ["mixtral-8x7b-32768", "gemma-7b-it"],
[MODE.MD_JSON]: ["*"]
}
}
8 changes: 1 addition & 7 deletions src/dsl/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,9 @@ export const LLMValidator = <C extends GenericClient | OpenAI>(
}
}

export const moderationValidator = <C extends GenericClient | OpenAI>(
client: InstructorClient<C>
) => {
export const moderationValidator = (client: InstructorClient<OpenAI>) => {
return async (value: string, ctx: z.RefinementCtx) => {
try {
if (!(client instanceof OpenAI)) {
throw new Error("ModerationValidator only supports OpenAI clients")
}

const response = await client.moderations.create({ input: value })
const flaggedResults = response.results.filter(result => result.flagged)

Expand Down
89 changes: 60 additions & 29 deletions src/instructor.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import {
ChatCompletionCreateParamsWithModel,
ClientTypeChatCompletionRequestOptions,
GenericChatCompletion,
GenericClient,
InstructorConfig,
Expand All @@ -8,7 +9,7 @@ import {
ReturnTypeBasedOnParams
} from "@/types"
import OpenAI from "openai"
import { z } from "zod"
import { z, ZodError } from "zod"
import ZodStream, { OAIResponseParser, OAIStream, withResponseModel, type Mode } from "zod-stream"
import { fromZodError } from "zod-validation-error"

Expand Down Expand Up @@ -102,11 +103,14 @@ class Instructor<C extends GenericClient | OpenAI> {
}
}

private async chatCompletionStandard<T extends z.AnyZodObject>({
max_retries = MAX_RETRIES_DEFAULT,
response_model,
...params
}: ChatCompletionCreateParamsWithModel<T>): Promise<z.infer<T>> {
private async chatCompletionStandard<T extends z.AnyZodObject>(
{
max_retries = MAX_RETRIES_DEFAULT,
response_model,
...params
}: ChatCompletionCreateParamsWithModel<T>,
requestOptions?: ClientTypeChatCompletionRequestOptions<C>
): Promise<z.infer<T>> {
let attempts = 0
let validationIssues = ""
let lastMessage: OpenAI.ChatCompletionMessageParam | null = null
Expand Down Expand Up @@ -147,13 +151,17 @@ class Instructor<C extends GenericClient | OpenAI> {

try {
if (this.client.chat?.completions?.create) {
const result = await this.client.chat.completions.create({
...resolvedParams,
stream: false
})
const result = await this.client.chat.completions.create(
{
...resolvedParams,
stream: false
},
requestOptions
)

completion = result as GenericChatCompletion<typeof result>
} else {
throw new Error("Unsupported client type")
throw new Error("Unsupported client type -- no completion method found.")
}
this.log("debug", "raw standard completion response: ", completion)
} catch (error) {
Expand All @@ -176,7 +184,17 @@ class Instructor<C extends GenericClient | OpenAI> {
const data = JSON.parse(parsedCompletion) as z.infer<T> & { _meta?: CompletionMeta }
return { ...data, _meta: { usage: completion?.usage ?? undefined } }
} catch (error) {
this.log("error", "failed to parse completion", parsedCompletion, this.mode)
this.log(
"error",
"failed to parse completion",
parsedCompletion,
this.mode,
"attempt: ",
attempts,
"max attempts: ",
max_retries
)

throw error
}
}
Expand All @@ -202,26 +220,38 @@ class Instructor<C extends GenericClient | OpenAI> {
throw new Error("Validation failed.")
}
}

return validation.data
} catch (error) {
if (!(error instanceof ZodError)) {
throw error
}

if (attempts < max_retries) {
this.log(
"debug",
`response model: ${response_model.name} - Retrying, attempt: `,
attempts
)

this.log(
"warn",
`response model: ${response_model.name} - Validation issues: `,
validationIssues
validationIssues,
" - Attempt: ",
attempts,
" - Max attempts: ",
max_retries
)

attempts++
return await makeCompletionCallWithRetries()
} else {
this.log(
"debug",
`response model: ${response_model.name} - Max attempts reached: ${attempts}`
)

this.log(
"error",
`response model: ${response_model.name} - Validation issues: `,
Expand All @@ -236,13 +266,10 @@ class Instructor<C extends GenericClient | OpenAI> {
return makeCompletionCallWithRetries()
}

private async chatCompletionStream<T extends z.AnyZodObject>({
max_retries,
response_model,
...params
}: ChatCompletionCreateParamsWithModel<T>): Promise<
AsyncGenerator<Partial<T> & { _meta?: CompletionMeta }, void, unknown>
> {
private async chatCompletionStream<T extends z.AnyZodObject>(
{ max_retries, response_model, ...params }: ChatCompletionCreateParamsWithModel<T>,
requestOptions?: ClientTypeChatCompletionRequestOptions<C>
): Promise<AsyncGenerator<Partial<T> & { _meta?: CompletionMeta }, void, unknown>> {
if (max_retries) {
this.log("warn", "max_retries is not supported for streaming completions")
}
Expand All @@ -269,10 +296,13 @@ class Instructor<C extends GenericClient | OpenAI> {
return streamClient.create({
completionPromise: async () => {
if (this.client.chat?.completions?.create) {
const completion = await this.client.chat.completions.create({
...completionParams,
stream: true
})
const completion = await this.client.chat.completions.create(
{
...completionParams,
stream: true
},
requestOptions
)

this.log("debug", "raw stream completion response: ", completion)

Expand Down Expand Up @@ -306,18 +336,19 @@ class Instructor<C extends GenericClient | OpenAI> {
P extends T extends z.AnyZodObject ? ChatCompletionCreateParamsWithModel<T>
: ClientTypeChatCompletionParams<OpenAILikeClient<C>> & { response_model: never }
>(
params: P
params: P,
requestOptions?: ClientTypeChatCompletionRequestOptions<C>
): Promise<ReturnTypeBasedOnParams<typeof this.client, P>> => {
this.validateModelModeSupport(params)

if (this.isChatCompletionCreateParamsWithModel(params)) {
if (params.stream) {
return this.chatCompletionStream(params) as ReturnTypeBasedOnParams<
return this.chatCompletionStream(params, requestOptions) as ReturnTypeBasedOnParams<
typeof this.client,
P & { stream: true }
>
} else {
return this.chatCompletionStandard(params) as ReturnTypeBasedOnParams<
return this.chatCompletionStandard(params, requestOptions) as ReturnTypeBasedOnParams<
typeof this.client,
P
>
Expand All @@ -326,8 +357,8 @@ class Instructor<C extends GenericClient | OpenAI> {
if (this.client.chat?.completions?.create) {
const result =
this.isStandardStream(params) ?
await this.client.chat.completions.create(params)
: await this.client.chat.completions.create(params)
await this.client.chat.completions.create(params, requestOptions)
: await this.client.chat.completions.create(params, requestOptions)

return result as unknown as ReturnTypeBasedOnParams<OpenAILikeClient<C>, P>
} else {
Expand Down
9 changes: 7 additions & 2 deletions src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ export type GenericCreateParams<M = unknown> = Omit<
[key: string]: unknown
}

export type GenericRequestOptions = Partial<OpenAI.RequestOptions> & {
[key: string]: unknown
}

export type GenericChatCompletion<T = unknown> = Partial<OpenAI.Chat.Completions.ChatCompletion> & {
[key: string]: unknown
choices?: T
Expand All @@ -43,15 +47,16 @@ export type GenericClient = {
export type ClientTypeChatCompletionParams<C> =
C extends OpenAI ? OpenAI.ChatCompletionCreateParams : GenericCreateParams

export type ClientTypeChatCompletionRequestOptions<C> =
C extends OpenAI ? OpenAI.RequestOptions : GenericRequestOptions

export type ClientType<C> =
C extends OpenAI ? "openai"
: C extends GenericClient ? "generic"
: never

export type OpenAILikeClient<C> = C extends OpenAI ? OpenAI : C & GenericClient

export type SupportedInstructorClient = GenericClient | OpenAI

export type LogLevel = "debug" | "info" | "warn" | "error"

export type CompletionMeta = Partial<ZCompletionMeta> & {
Expand Down
Loading
Loading