Skip to content

Commit

Permalink
feat: new stopFunction option to return after a function is called
Browse files Browse the repository at this point in the history
  • Loading branch information
dosco committed Dec 13, 2024
1 parent 103be3f commit 4a56c9c
Show file tree
Hide file tree
Showing 12 changed files with 126 additions and 63 deletions.
4 changes: 3 additions & 1 deletion src/ax/ai/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -516,12 +516,14 @@ const logResponse = (resp: Readonly<AxChatResponse>) => {
console.log(colorLog.greenBright(r.content));
}
if (r.functionCalls) {
process.stdout.write(colorLog.yellow(`Executing functions:\n`));

for (const f of r.functionCalls) {
const args =
typeof f.function.params !== 'string'
? JSON.stringify(f.function.params, null, 2)
: f.function.params;
console.log(colorLog.yellow(`${f.function.name}(${args})`));
process.stdout.write(colorLog.yellow(`${f.function.name}(${args})`));
}
}
}
Expand Down
35 changes: 20 additions & 15 deletions src/ax/ai/google-gemini/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -286,42 +286,47 @@ export class AxAIGoogleGemini extends AxBaseAI<
let tools: AxAIGoogleGeminiChatRequest['tools'] | undefined = [];

if (req.functions && req.functions.length > 0) {
tools.push({ function_declarations: req.functions });
tools.push({ functionDeclarations: req.functions });
}

if (this.options?.codeExecution) {
tools.push({ code_execution: {} });
tools.push({ codeExecution: {} });
}

if (this.options?.googleSearchRetrieval) {
tools.push({
google_search_retrieval: this.options.googleSearchRetrieval
googleSearchRetrieval: this.options.googleSearchRetrieval
});
}

if (tools.length === 0) {
tools = undefined;
}

// eslint-disable-next-line @typescript-eslint/naming-convention
let tool_config;
let toolConfig;

if (req.functionCall) {
if (req.functionCall === 'none') {
tool_config = { function_calling_config: { mode: 'NONE' as const } };
toolConfig = { functionCallingConfig: { mode: 'NONE' as const } };
} else if (req.functionCall === 'auto') {
tool_config = { function_calling_config: { mode: 'AUTO' as const } };
toolConfig = { functionCallingConfig: { mode: 'AUTO' as const } };
} else if (req.functionCall === 'required') {
tool_config = {
function_calling_config: { mode: 'ANY' as const }
toolConfig = {
functionCallingConfig: { mode: 'ANY' as const }
};
} else {
tool_config = {
function_calling_config: {
mode: 'ANY' as const,
allowed_function_names: [req.functionCall.function.name]
}
const allowedFunctionNames = req.functionCall.function.name
? {
allowedFunctionNames: [req.functionCall.function.name]
}
: {};
toolConfig = {
functionCallingConfig: { mode: 'ANY' as const },
...allowedFunctionNames
};
}
} else if (tools && tools.length > 0) {
toolConfig = { functionCallingConfig: { mode: 'AUTO' as const } };
}

const generationConfig = {
Expand All @@ -338,7 +343,7 @@ export class AxAIGoogleGemini extends AxBaseAI<
const reqValue: AxAIGoogleGeminiChatRequest = {
contents,
tools,
tool_config,
toolConfig,
systemInstruction,
generationConfig,
safetySettings
Expand Down
14 changes: 7 additions & 7 deletions src/ax/ai/google-gemini/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,19 @@ export type AxAIGoogleGeminiToolFunctionDeclaration = {

export type AxAIGoogleGeminiToolGoogleSearchRetrieval = {
mode?: 'MODE_DYNAMIC';
dynamic_threshold?: number;
dynamicThreshold?: number;
};

export type AxAIGoogleGeminiTool = {
function_declarations?: AxAIGoogleGeminiToolFunctionDeclaration[];
code_execution?: object;
google_search_retrieval?: AxAIGoogleGeminiToolGoogleSearchRetrieval;
functionDeclarations?: AxAIGoogleGeminiToolFunctionDeclaration[];
codeExecution?: object;
googleSearchRetrieval?: AxAIGoogleGeminiToolGoogleSearchRetrieval;
};

export type AxAIGoogleGeminiToolConfig = {
function_calling_config: {
functionCallingConfig: {
mode: 'ANY' | 'NONE' | 'AUTO';
allowed_function_names?: string[];
allowedFunctionNames?: string[];
};
};

Expand All @@ -113,7 +113,7 @@ export type AxAIGoogleGeminiSafetySettings = {
export type AxAIGoogleGeminiChatRequest = {
contents: AxAIGoogleGeminiContent[];
tools?: AxAIGoogleGeminiTool[];
tool_config?: AxAIGoogleGeminiToolConfig;
toolConfig?: AxAIGoogleGeminiToolConfig;
systemInstruction?: AxAIGoogleGeminiContent;
generationConfig: AxAIGoogleGeminiGenerationConfig;
safetySettings?: AxAIGoogleGeminiSafetySettings;
Expand Down
8 changes: 6 additions & 2 deletions src/ax/ai/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,16 @@ export type AxChatRequest = {
functionCalls?: {
id: string;
type: 'function';

function: { name: string; params?: string | object };
}[];
cache?: boolean;
}
| { role: 'function'; result: string; functionId: string; cache?: boolean }
| {
role: 'function';
result: string;
functionId: string;
cache?: boolean;
}
>[];
functions?: Readonly<{
name: string;
Expand Down
5 changes: 5 additions & 0 deletions src/ax/dsp/functions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,13 @@ export const processFunctions = async (
traceId?: string
) => {
const funcProc = new AxFunctionProcessor(functionList);
const functionsExecuted = new Set<string>();

// Map each function call to a promise that resolves to the function result or null
const promises = functionCalls.map((func) =>
funcProc?.execute(func, { sessionId, traceId, ai }).then((fres) => {
functionsExecuted.add(func.name.toLowerCase());

if (fres?.id) {
return {
role: 'function' as const,
Expand All @@ -152,6 +155,8 @@ export const processFunctions = async (
mem.add(result, sessionId);
}
});

return functionsExecuted;
};

// eslint-disable-next-line @typescript-eslint/naming-convention
Expand Down
56 changes: 45 additions & 11 deletions src/ax/dsp/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ export interface AxGenOptions {

functions?: InputFunctionType;
functionCall?: AxChatRequest['functionCall'];
stopFunction?: string;
promptTemplate?: typeof AxPromptTemplate;
asserts?: AxAssertion[];
streamingAsserts?: AxStreamingAssertion[];
Expand Down Expand Up @@ -85,8 +86,8 @@ export class AxGen<
private asserts: AxAssertion[];
private streamingAsserts: AxStreamingAssertion[];
private options?: Omit<AxGenOptions, 'functions'>;

private functions?: AxFunction[];
private functionsExecuted: Set<string> = new Set<string>();

constructor(
signature: Readonly<AxSignature | string>,
Expand Down Expand Up @@ -156,7 +157,8 @@ export class AxGen<
stream,
model,
rateLimiter,
functions
functions,
functionCall: _functionCall
}: Readonly<
Omit<AxProgramForwardOptions, 'ai'> & { ai: AxAIService; stream: boolean }
>) {
Expand All @@ -166,7 +168,7 @@ export class AxGen<
throw new Error('No chat prompt found');
}

const functionCall = this.options?.functionCall;
const functionCall = _functionCall ?? this.options?.functionCall;

const hasJSON = this.signature
.getOutputFields()
Expand Down Expand Up @@ -208,7 +210,8 @@ export class AxGen<
model,
rateLimiter,
stream = false,
functions
functions,
functionCall
}: Readonly<
Omit<AxProgramForwardOptions, 'ai' | 'mem'> & {
sig: Readonly<AxSignature>;
Expand All @@ -230,7 +233,8 @@ export class AxGen<
modelConfig,
model,
rateLimiter,
functions
functions,
functionCall
});

if (res instanceof ReadableStream) {
Expand Down Expand Up @@ -321,7 +325,15 @@ export class AxGen<
if (!functions) {
throw new Error('Functions are not defined');
}
await processFunctions(ai, functions, funcs, mem, sessionId, traceId);
const fx = await processFunctions(
ai,
functions,
funcs,
mem,
sessionId,
traceId
);
this.functionsExecuted = new Set([...this.functionsExecuted, ...fx]);
}

streamingExtractFinalValue(values, xstate, content);
Expand Down Expand Up @@ -360,7 +372,15 @@ export class AxGen<
if (!functions) {
throw new Error('Functions are not defined');
}
await processFunctions(ai, functions, funcs, mem, sessionId, traceId);
const fx = await processFunctions(
ai,
functions,
funcs,
mem,
sessionId,
traceId
);
this.functionsExecuted = new Set([...this.functionsExecuted, ...fx]);
}
}

Expand All @@ -379,6 +399,10 @@ export class AxGen<
options?: Readonly<AxProgramForwardOptions>,
span?: AxSpan
): Promise<OUT> {
const stopFunction = (
options?.stopFunction ?? this.options?.stopFunction
)?.toLowerCase();

const maxRetries = options?.maxRetries ?? this.options?.maxRetries ?? 5;
const maxSteps = options?.maxSteps ?? this.options?.maxSteps ?? 10;
const mem = options?.mem ?? this.options?.mem ?? new AxMemory();
Expand Down Expand Up @@ -412,16 +436,25 @@ export class AxGen<
stream: canStream && options?.stream,
maxSteps: options?.maxSteps,
rateLimiter: options?.rateLimiter,
functions: options?.functions
functions: options?.functions,
functionCall: options?.functionCall
});

const lastMemItem = mem.getLast(options?.sessionId);

const stopFunctionExecuted =
stopFunction && this.functionsExecuted.has(stopFunction);

if (lastMemItem?.role === 'function') {
continue multiStepLoop;
if (!stopFunction || !stopFunctionExecuted) {
continue multiStepLoop;
}
}

if (!stopFunctionExecuted) {
assertRequiredFields(sig, output);
}

assertRequiredFields(sig, output);
this.trace = { ...values, ...output };
return output;
} catch (e) {
Expand Down Expand Up @@ -449,14 +482,15 @@ export class AxGen<
}
}
}

if (err instanceof AxAssertionError && err.getOptional()) {
return err.getValue() as OUT;
}

throw new Error(`Unable to fix validation error: ${err?.message}`);
}

throw new Error('Could not complete task within maximum allowed steps');
throw new Error(`Max steps reached: ${maxSteps}`);
}

public override async forward(
Expand Down
3 changes: 3 additions & 0 deletions src/ax/dsp/program.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import type {
AxAIService,
AxChatRequest,
AxChatResponse,
AxFunction,
AxModelConfig,
Expand Down Expand Up @@ -54,6 +55,8 @@ export type AxProgramForwardOptions = {
stream?: boolean;
debug?: boolean;
functions?: AxFunction[];
functionCall?: AxChatRequest['functionCall'];
stopFunction?: string;
};

export interface AxTunable {
Expand Down
8 changes: 5 additions & 3 deletions src/ax/dsp/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ export class AxPromptTemplate {
this.outputFormat = {
type: 'text' as const,
text: [
'Use the following key-value output format for the output not JSON or a markdown block.',
'Use only the following key-value output format for the output.',
...this.renderOutFields(this.sig.getOutputFields()),
'---\n\n'
].join('\n\n')
Expand Down Expand Up @@ -129,7 +129,7 @@ export class AxPromptTemplate {
private renderExamples = (data: Readonly<Record<string, AxFieldValue>[]>) => {
const list: ChatRequestUserMessage = [];

for (const item of data) {
for (const [index, item] of data.entries()) {
const renderedInputItem = this.sig
.getInputFields()
.map((field) => this.renderInField(field, item, true))
Expand All @@ -143,7 +143,9 @@ export class AxPromptTemplate {
.flat();

if (renderedOutputItem.length === 0) {
throw new Error('Output fields are required for examples.');
throw new Error(
`Output fields are required in examples: index: ${index}, data: ${JSON.stringify(item)}`
);
}

const renderedItem = [...renderedInputItem, ...renderedOutputItem];
Expand Down
Loading

0 comments on commit 4a56c9c

Please sign in to comment.