Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Non-oai usage meta + non-oai client types #182

Merged
merged 3 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/light-chefs-clean.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@instructor-ai/instructor": minor
---

update client types to better support non oai clients + updates to allow for passing usage properties into meta from non-oai clients
Binary file modified bun.lockb
Binary file not shown.
6 changes: 3 additions & 3 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
"zod": ">=3.22.4"
},
"devDependencies": {
"@anthropic-ai/sdk": "latest",
"@anthropic-ai/sdk": "0.22.0",
"@changesets/changelog-github": "^0.5.0",
"@changesets/cli": "^2.27.1",
"@ianvs/prettier-plugin-sort-imports": "4.1.0",
Expand All @@ -75,8 +75,8 @@
"eslint-plugin-only-warn": "^1.1.0",
"eslint-plugin-prettier": "^5.1.2",
"husky": "^8.0.3",
"llm-polyglot": "1.0.0",
"openai": "latest",
"llm-polyglot": "2.0.0",
"openai": "4.50.0",
"prettier": "latest",
"ts-inference-check": "^0.3.0",
"tsup": "^8.0.1",
Expand Down
57 changes: 50 additions & 7 deletions src/instructor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ import {
PROVIDER_SUPPORTED_MODES_BY_MODEL,
PROVIDERS
} from "./constants/providers"
import { iterableTee } from "./lib"
import { ClientTypeChatCompletionParams, CompletionMeta } from "./types"

const MAX_RETRIES_DEFAULT = 0

class Instructor<C extends GenericClient | OpenAI> {
class Instructor<C> {
readonly client: OpenAILikeClient<C>
readonly mode: Mode
readonly provider: Provider
Expand All @@ -46,7 +47,17 @@ class Instructor<C extends GenericClient | OpenAI> {
logger = undefined,
retryAllErrors = false
}: InstructorConfig<C>) {
this.client = client
if (!isGenericClient(client) && !(client instanceof OpenAI)) {
throw new Error("Client does not match the required structure")
}

if (client instanceof OpenAI) {
this.client = client as OpenAI
} else {
this.client = client as C & GenericClient
}

// this.client = client
this.mode = mode
this.debug = debug
this.retryAllErrors = retryAllErrors
Expand Down Expand Up @@ -308,7 +319,9 @@ class Instructor<C extends GenericClient | OpenAI> {
debug: this.debug ?? false
})

async function checkForUsage(reader: Stream<OpenAI.ChatCompletionChunk>) {
async function checkForUsage(
reader: Stream<OpenAI.ChatCompletionChunk> | AsyncIterable<OpenAI.ChatCompletionChunk>
) {
for await (const chunk of reader) {
if ("usage" in chunk) {
streamUsage = chunk.usage as CompletionMeta["usage"]
Expand Down Expand Up @@ -345,6 +358,24 @@ class Instructor<C extends GenericClient | OpenAI> {
})
}

//check if async iterator
if (
this.provider !== "OAI" &&
completionParams?.stream &&
completion?.[Symbol.asyncIterator]
) {
const [completion1, completion2] = await iterableTee(
completion as AsyncIterable<OpenAI.ChatCompletionChunk>,
2
)

checkForUsage(completion1)

return OAIStream({
res: completion2
})
}

return OAIStream({
res: completion as unknown as AsyncIterable<OpenAI.ChatCompletionChunk>
})
Expand Down Expand Up @@ -419,7 +450,7 @@ class Instructor<C extends GenericClient | OpenAI> {
}
}

export type InstructorClient<C extends GenericClient | OpenAI> = Instructor<C> & OpenAILikeClient<C>
export type InstructorClient<C> = Instructor<C> & OpenAILikeClient<C>

/**
* Creates an instance of the `Instructor` class.
Expand All @@ -442,9 +473,7 @@ export type InstructorClient<C extends GenericClient | OpenAI> = Instructor<C> &
* @param args
* @returns
*/
export default function createInstructor<C extends GenericClient | OpenAI>(
args: InstructorConfig<C>
): InstructorClient<C> {
export default function createInstructor<C>(args: InstructorConfig<C>): InstructorClient<C> {
const instructor = new Instructor<C>(args)
const instructorWithProxy = new Proxy(instructor, {
get: (target, prop, receiver) => {
Expand All @@ -458,3 +487,17 @@ export default function createInstructor<C extends GenericClient | OpenAI>(

return instructorWithProxy as InstructorClient<C>
}
/**
 * Type guard: does `client` structurally expose the minimal OpenAI-compatible
 * surface Instructor relies on — a `baseURL` property plus a callable
 * `chat.completions.create`?
 *
 * Note: `typeof null === "object"`, so each nested level is explicitly
 * null-checked. The previous version threw a TypeError ("Cannot use 'in'
 * operator … in null") when `client.chat` or `client.chat.completions` was
 * null; it now returns false for those inputs instead.
 */
//eslint-disable-next-line @typescript-eslint/no-explicit-any
function isGenericClient(client: any): client is GenericClient {
  // Optional chaining keeps these reads safe on primitives/null/undefined.
  const chat = client?.chat
  const completions = chat?.completions

  return (
    typeof client === "object" &&
    client !== null &&
    "baseURL" in client &&
    typeof chat === "object" &&
    chat !== null &&
    typeof completions === "object" &&
    completions !== null &&
    typeof completions.create === "function"
  )
}
42 changes: 42 additions & 0 deletions src/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,45 @@ export function omit<T extends object, K extends keyof T>(keys: K[], obj: T): Om
}
return result
}

/**
 * Splits one async iterable into `n` independent async generators that each
 * yield every item of the source exactly once, in order.
 *
 * The source is consumed eagerly by a background pump; items are buffered
 * per consumer, so a slow consumer never blocks a fast one (at the cost of
 * unbounded buffering for lagging consumers).
 *
 * Fixed: the pump previously had no error handling — if the source threw,
 * `done` was never set, every consumer hung forever on a resolver that would
 * never fire, and the floating promise rejection was unhandled. The error is
 * now captured and re-thrown to each consumer once its buffer is drained, and
 * completion/wake-up always runs via `finally`.
 *
 * @param iterable - source async iterable (consumed exactly once)
 * @param n - number of independent copies to produce
 * @returns array of `n` async generators replaying the source
 */
export async function iterableTee<T>(
  iterable: AsyncIterable<T>,
  n: number
): Promise<AsyncGenerator<T>[]> {
  const buffers: T[][] = Array.from({ length: n }, () => [])
  const resolvers: (() => void)[] = []
  const iterator = iterable[Symbol.asyncIterator]()
  let done = false
  // If the source throws, remember the error so every consumer re-throws it
  // instead of hanging on a promise nobody will ever resolve.
  let failed = false
  let failure: unknown = undefined

  const wakeAll = () => {
    while (resolvers.length > 0) {
      resolvers.shift()!()
    }
  }

  async function* reader(index: number) {
    while (true) {
      if (buffers[index].length > 0) {
        // Drain already-buffered items first, even after an error/completion.
        yield buffers[index].shift()!
      } else if (failed) {
        // Surface the producer's error to this consumer.
        throw failure
      } else if (done) {
        break
      } else {
        // Park until the pump pushes a new item or finishes.
        await new Promise<void>(resolve => resolvers.push(resolve))
      }
    }
  }

  ;(async () => {
    try {
      for await (const item of {
        [Symbol.asyncIterator]: () => iterator
      }) {
        for (const buffer of buffers) {
          buffer.push(item)
        }
        wakeAll()
      }
    } catch (err) {
      failed = true
      failure = err
    } finally {
      // Always mark completion and wake waiters — success or failure — so no
      // consumer is left awaiting a resolver that will never fire.
      done = true
      wakeAll()
    }
  })()

  return Array.from({ length: n }, (_, i) => reader(i))
}
6 changes: 3 additions & 3 deletions src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ export type GenericClient = {
baseURL?: string
chat?: {
completions?: {
create?: (params: GenericCreateParams) => Promise<unknown>
create?: <P extends GenericCreateParams>(params: P) => Promise<unknown>
}
}
}
Expand All @@ -55,7 +55,7 @@ export type ClientType<C> =
: C extends GenericClient ? "generic"
: never

export type OpenAILikeClient<C> = C extends OpenAI ? OpenAI : C & GenericClient
export type OpenAILikeClient<C> = OpenAI | (C & GenericClient)
export type SupportedInstructorClient = GenericClient | OpenAI
export type LogLevel = "debug" | "info" | "warn" | "error"

Expand All @@ -68,7 +68,7 @@ export type Mode = ZMode
export type ResponseModel<T extends z.AnyZodObject> = ZResponseModel<T>

export interface InstructorConfig<C> {
client: OpenAILikeClient<C>
client: C
mode: Mode
debug?: boolean
logger?: <T extends unknown[]>(level: LogLevel, ...args: T) => void
Expand Down
33 changes: 20 additions & 13 deletions tests/anthropic.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,16 @@ describe("LLMClient Anthropic Provider - mode: TOOLS", () => {
})
})

describe("LLMClient Anthropic Provider - mode: MD_JSON", () => {
describe("LLMClient Anthropic Provider - mode: TOOLS - stream", () => {
const instructor = Instructor({
client: anthropicClient,
mode: "MD_JSON"
mode: "TOOLS"
})

test("basic completion", async () => {
const completion = await instructor.chat.completions.create({
model: "claude-3-sonnet-20240229",
stream: true,
max_tokens: 1000,
messages: [
{
Expand All @@ -135,17 +136,24 @@ describe("LLMClient Anthropic Provider - mode: MD_JSON", () => {
}
],
response_model: {
name: "get_name",
name: "extract_name",
schema: z.object({
name: z.string()
})
}
})

expect(omit(["_meta"], completion)).toEqual({ name: "Dimitri Kennedy" })
let final = {}

for await (const result of completion) {
final = result
}

//@ts-expect-error ignore for testing
expect(omit(["_meta"], final)).toEqual({ name: "Dimitri Kennedy" })
})

test("complex schema - streaming", async () => {
test("complex schema", async () => {
const completion = await instructor.chat.completions.create({
model: "claude-3-sonnet-20240229",
max_tokens: 1000,
Expand Down Expand Up @@ -173,14 +181,15 @@ describe("LLMClient Anthropic Provider - mode: MD_JSON", () => {
Programming
Leadership
Communication


`
}
],
response_model: {
name: "process_user_data",
schema: z.object({
story: z
.string()
.describe("A long and mostly made up story about the user - minimum 500 words"),
userDetails: z.object({
firstName: z.string(),
lastName: z.string(),
Expand All @@ -196,21 +205,19 @@ describe("LLMClient Anthropic Provider - mode: MD_JSON", () => {
years: z.number().optional()
})
),
skills: z.array(z.string()),
summaryOfWorldWarOne: z
.string()
.describe("A detailed summary of World War One and its major events - min 500 words")
skills: z.array(z.string())
})
}
})

let final = {}

for await (const result of completion) {
final = result
}

//@ts-expect-error - lazy
expect(omit(["_meta", "summaryOfWorldWarOne"], final)).toEqual({
//@ts-expect-error ignore for testing
expect(omit(["_meta", "story"], final)).toEqual({
userDetails: {
firstName: "John",
lastName: "Doe",
Expand Down
1 change: 0 additions & 1 deletion tests/stream.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ async function extractUser() {
let extraction: Extraction = {}

for await (const result of extractionStream) {
console.log(result)
try {
extraction = result
expect(result).toHaveProperty("users")
Expand Down
Loading