feat: added support for o1 models
dosco committed Oct 28, 2024
1 parent 1348ce4 commit 8f06b16
Showing 6 changed files with 53 additions and 21 deletions.
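Taken together, the changes make capability detection model-aware: getFeatures gains an optional model argument, and the OpenAI provider reports that o1 models support neither native function calling nor streaming. A minimal sketch of how calling code might branch on this (the Features/HasFeatures types and the planRequest helper are illustrative stand-ins, not part of the commit):

type Features = { functions: boolean; streaming: boolean };

// Structural stand-in for the AxAIService surface touched by this commit;
// any object with getFeatures(model?) fits, e.g. AxAI, AxAIOpenAI or AxBalancer.
interface HasFeatures {
  getFeatures(model?: string): Features;
}

// Decide per request whether to stream and whether to send native function
// definitions; for 'o1-mini' / 'o1-preview' both come back false after this change.
function planRequest(ai: HasFeatures, model?: string) {
  const { functions, streaming } = ai.getFeatures(model);
  return { stream: streaming, useNativeFunctions: functions };
}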
4 changes: 2 additions & 2 deletions src/ax/ai/balance.ts
@@ -81,8 +81,8 @@ export class AxBalancer implements AxAIService {
     return this.currentService.getModelConfig();
   }
 
-  getFeatures() {
-    return this.currentService.getFeatures();
+  getFeatures(model?: string) {
+    return this.currentService.getFeatures(model);
   }
 
   async chat(
16 changes: 10 additions & 6 deletions src/ax/ai/base.ts
@@ -37,7 +37,7 @@ export interface AxBaseAIArgs {
   modelInfo: Readonly<AxModelInfo[]>;
   models: Readonly<{ model: string; embedModel?: string }>;
   options?: Readonly<AxAIServiceOptions>;
-  supportFor: AxBaseAIFeatures;
+  supportFor: AxBaseAIFeatures | ((model: string) => AxBaseAIFeatures);
   modelMap?: AxAIModelMap;
 
@@ -95,7 +95,9 @@ export class AxBaseAI<
   protected apiURL: string;
   protected name: string;
   protected headers: Record<string, string>;
-  protected supportFor: AxBaseAIFeatures;
+  protected supportFor:
+    | AxBaseAIFeatures
+    | ((model: string) => AxBaseAIFeatures);
 
   constructor({
     name,
@@ -197,8 +199,10 @@
     return this.name;
   }
 
-  getFeatures(): AxBaseAIFeatures {
-    return this.supportFor;
+  getFeatures(model?: string): AxBaseAIFeatures {
+    return typeof this.supportFor === 'function'
+      ? this.supportFor(model ?? this.models.model)
+      : this.supportFor;
   }
 
   getModelConfig(): AxModelConfig {
@@ -211,7 +215,7 @@
   ): Promise<AxChatResponse | ReadableStream<AxChatResponse>> {
     const model = req.model
       ? this.modelMap?.[req.model] ?? req.model
-      : this.models.model;
+      : this.modelMap?.[this.models.model] ?? this.models.model;
 
     if (this.tracer) {
       const mc = this.getModelConfig();
@@ -366,7 +370,7 @@
   ): Promise<AxEmbedResponse> {
     const embedModel = req.embedModel
       ? this.modelMap?.[req.embedModel] ?? req.embedModel
-      : this.models.embedModel;
+      : this.modelMap?.[this.models.embedModel ?? ''] ?? this.models.embedModel;
 
     if (!embedModel) {
       throw new Error('No embed model defined');
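The last two hunks above also run the default model (and embed model) through modelMap, not only per-request overrides. A standalone sketch of the same resolution chain, with a made-up alias map (not from the commit):

// Same nullish-coalescing chain as in chat()/embed() above, extracted for illustration.
const resolveModel = (
  modelMap: Record<string, string> | undefined,
  requested: string | undefined,
  fallback: string
): string =>
  requested ? modelMap?.[requested] ?? requested : modelMap?.[fallback] ?? fallback;

// Hypothetical alias map pointing a friendly name at an o1 model.
const map = { smart: 'o1-preview' };
resolveModel(map, undefined, 'smart');      // => 'o1-preview' (the default is now mapped too)
resolveModel(map, 'gpt-4o-mini', 'smart');  // => 'gpt-4o-mini' (unmapped names pass through)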
28 changes: 25 additions & 3 deletions src/ax/ai/openai/api.ts
@@ -7,7 +7,6 @@ import {
 import type {
   AxAIPromptConfig,
   AxAIServiceOptions,
-  AxChatRequest,
   AxChatResponse,
   AxChatResponseResult,
   AxEmbedResponse,
@@ -101,7 +100,11 @@ export class AxAIOpenAI extends AxBaseAI<
         embedModel: _config.embedModel as string
       },
       options,
-      supportFor: { functions: true, streaming: true },
+      supportFor: (model: string) => {
+        return isO1Model(model)
+          ? { functions: false, streaming: false }
+          : { functions: true, streaming: true };
+      },
       modelMap
     });
     this.config = _config;
@@ -147,6 +150,10 @@
       }
     }));
 
+    if (tools && isO1Model(model)) {
+      throw new Error('Functions are not supported for O1 models');
+    }
+
     const toolsChoice =
       !req.functionCall && req.functions && req.functions.length > 0
         ? 'auto'
@@ -159,6 +166,10 @@
 
     const stream = req.modelConfig?.stream ?? this.config.stream;
 
+    if (stream && isO1Model(model)) {
+      throw new Error('Streaming is not supported for O1 models');
+    }
+
     const reqValue: AxAIOpenAIChatRequest = {
       model,
       messages,
@@ -355,9 +366,15 @@ const mapFinishReason = (
 };
 
 function createMessages(
-  req: Readonly<AxChatRequest>
+  req: Readonly<AxInternalChatRequest>
 ): AxAIOpenAIChatRequest['messages'] {
   return req.chatPrompt.map((msg) => {
+    if (msg.role === 'system' && isO1Model(req.model)) {
+      msg = {
+        role: 'user',
+        content: msg.content
+      };
+    }
     switch (msg.role) {
       case 'system':
         return { role: 'system' as const, content: msg.content };
@@ -412,3 +429,8 @@ function createMessages(
     }
   });
 }
+
+const isO1Model = (model: string): boolean =>
+  [AxAIOpenAIModel.O1Mini, AxAIOpenAIModel.O1Preview].includes(
+    model as AxAIOpenAIModel
+  );
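At the time of this commit the o1-preview and o1-mini endpoints accepted no system role, tool calls, or streaming, which is what the guards above enforce; createMessages additionally downgrades system messages to user messages. A simplified standalone mirror of that remapping (the Msg type and remapForO1 name are illustrative, not exports of the library):

type Msg = { role: 'system' | 'user'; content: string };

// Mirrors the createMessages change above: for o1 models a system message is
// resent with the user role so the request is not rejected.
const remapForO1 = (msgs: ReadonlyArray<Msg>, isO1: boolean): Msg[] =>
  msgs.map((m) =>
    isO1 && m.role === 'system' ? { role: 'user', content: m.content } : m
  );

remapForO1([{ role: 'system', content: 'Answer briefly.' }], true);
// => [{ role: 'user', content: 'Answer briefly.' }]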
2 changes: 1 addition & 1 deletion src/ax/ai/types.ts
@@ -185,7 +185,7 @@ export interface AxAIService {
   getModelInfo(): Readonly<AxModelInfoWithProvider>;
   getEmbedModelInfo(): Readonly<AxModelInfoWithProvider> | undefined;
   getModelConfig(): Readonly<AxModelConfig>;
-  getFeatures(): { functions: boolean; streaming: boolean };
+  getFeatures(model?: string): { functions: boolean; streaming: boolean };
   getModelMap(): AxAIModelMap | undefined;
 
   chat(
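Any custom AxAIService implementation now has to accept (and may act on) the optional model argument. A hedged sketch of a per-model implementation of just this method (the class name, model strings, and the prefix check are assumptions; the other interface members are omitted):

// Partial sketch only; a real implementation must provide the full AxAIService surface.
class MyProvider /* implements AxAIService */ {
  getFeatures(model?: string): { functions: boolean; streaming: boolean } {
    const m = model ?? 'my-default-model'; // hypothetical default model name
    return m.startsWith('o1')
      ? { functions: false, streaming: false }
      : { functions: true, streaming: true };
  }
}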
4 changes: 2 additions & 2 deletions src/ax/ai/wrap.ts
@@ -142,8 +142,8 @@ export class AxAI implements AxAIService {
     return this.ai.getModelConfig();
   }
 
-  getFeatures(): { functions: boolean; streaming: boolean } {
-    return this.ai.getFeatures();
+  getFeatures(model?: string): { functions: boolean; streaming: boolean } {
+    return this.ai.getFeatures(model);
   }
 
   getModelMap(): AxAIModelMap | undefined {
20 changes: 13 additions & 7 deletions src/ax/dsp/generate.ts
@@ -64,6 +64,7 @@ export type AxGenerateResult<OUT extends AxGenOut> = OUT & {
 
 export interface AxResponseHandlerArgs<T> {
   ai: Readonly<AxAIService>;
+  model?: string;
   sig: Readonly<AxSignature>;
   res: T;
   usageInfo: { ai: string; model: string };
@@ -110,10 +111,10 @@ export class AxGen<
     }
   }
 
-  private updateSigForFunctions = (ai: AxAIService) => {
+  private updateSigForFunctions = (ai: AxAIService, model?: string) => {
     // AI supports function calling natively so
     // no need to add fields for function call
-    if (ai.getFeatures().functions) {
+    if (ai.getFeatures(model).functions) {
       return;
     }
 
@@ -239,6 +240,7 @@
     if (res instanceof ReadableStream) {
       return (await this.processSteamingResponse({
         ai,
+        model,
         sig,
         res,
         usageInfo,
@@ -250,6 +252,7 @@
 
     return (await this.processResponse({
       ai,
+      model,
      sig,
      res,
      usageInfo,
@@ -262,6 +265,7 @@
   private async processSteamingResponse({
     ai,
     sig,
+    model,
     res,
     usageInfo,
     mem,
@@ -313,7 +317,7 @@
       }
     }
 
-    const funcs = parseFunctions(ai, functionCalls, values);
+    const funcs = parseFunctions(ai, functionCalls, values, model);
     if (funcs) {
       await this.processFunctions(ai, funcs, mem, sessionId, traceId);
     }
@@ -372,7 +376,7 @@
     const maxRetries = options?.maxRetries ?? this.options?.maxRetries ?? 5;
     const maxSteps = options?.maxSteps ?? this.options?.maxSteps ?? 10;
     const mem = options?.mem ?? this.options?.mem ?? new AxMemory();
-    const canStream = ai.getFeatures().streaming;
+    const canStream = ai.getFeatures(options?.model).streaming;
 
     let err: ValidationError | AxAssertionError | undefined;
 
@@ -453,7 +457,8 @@
     values: IN,
     options?: Readonly<AxProgramForwardOptions>
   ): Promise<OUT> {
-    const sig = this.updateSigForFunctions(ai) ?? this.signature;
+    const sig =
+      this.updateSigForFunctions(ai, options?.model) ?? this.signature;
 
     const tracer = this.options?.tracer ?? options?.tracer;
 
@@ -515,12 +520,13 @@
 function parseFunctions(
   ai: Readonly<AxAIService>,
   functionCalls: Readonly<AxChatResponseResult['functionCalls']>,
-  values: Record<string, unknown>
+  values: Record<string, unknown>,
+  model?: string
 ): AxChatResponseFunctionCall[] | undefined {
   if (!functionCalls || functionCalls.length === 0) {
     return;
   }
-  if (ai.getFeatures().functions) {
+  if (ai.getFeatures(model).functions) {
     const funcs: AxChatResponseFunctionCall[] = functionCalls.map((f) => ({
       id: f.id,
       name: f.function.name,
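With the model threaded through AxGen, a forward call that targets an o1 model picks up the right behavior automatically: streaming is skipped and the prompt-level function-call fields are used instead of native tools. A hedged end-to-end sketch (the package path, the AxAI constructor options, and the signature string are assumptions about the surrounding ax API rather than part of this diff):

import { AxAI, AxAIOpenAIModel, AxGen } from '@ax-llm/ax';

const ai = new AxAI({
  name: 'openai',
  apiKey: process.env.OPENAI_APIKEY as string
});

const gen = new AxGen<{ question: string }, { answer: string }>(
  'question -> answer'
);

// Passing a model in the forward options now flows into getFeatures(model),
// so streaming and native function calling are turned off for o1 automatically.
const res = await gen.forward(
  ai,
  { question: 'Summarize this commit in one line.' },
  { model: AxAIOpenAIModel.O1Mini }
);
console.log(res.answer);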
