Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement cache on getPrompt and getPromptById methods #86

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
115 changes: 67 additions & 48 deletions src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import { ReadStream } from 'fs';
import { v4 as uuidv4 } from 'uuid';

import { LiteralClient } from '.';
import { sharedCache } from './cache/sharedcache';
import { getPromptCacheKey } from './cache/utils';
import {
Dataset,
DatasetExperiment,
Expand Down Expand Up @@ -340,6 +342,8 @@ type CreateAttachmentParams = {
* Then you can use the `api` object to make calls to the Literal service.
*/
export class API {
/** @ignore */
private cache: typeof sharedCache;
/** @ignore */
public client: LiteralClient;
/** @ignore */
Expand Down Expand Up @@ -372,6 +376,8 @@ export class API {
throw new Error('LITERAL_API_URL not set');
}

this.cache = sharedCache;

this.apiKey = apiKey;
this.url = url;
this.environment = environment;
Expand Down Expand Up @@ -399,7 +405,7 @@ export class API {
* @returns The data part of the response from the GraphQL endpoint.
* @throws Will throw an error if the GraphQL call returns errors or if the request fails.
*/
private async makeGqlCall(query: string, variables: any) {
private async makeGqlCall(query: string, variables: any, timeout?: number) {
try {
const response = await axios({
url: this.graphqlEndpoint,
Expand All @@ -408,7 +414,8 @@ export class API {
data: {
query: query,
variables: variables
}
},
timeout
});
if (response.data.errors) {
throw new Error(JSON.stringify(response.data.errors));
Expand Down Expand Up @@ -2110,41 +2117,75 @@ export class API {
}

/**
* Retrieves a prompt by its id.
*
* @param id ID of the prompt to retrieve.
* @returns The prompt with given ID.
* Retrieves a prompt by its id. If the request fails, it will try to get the prompt from the cache.
*/
public async getPromptById(id: string) {
const query = `
query GetPrompt($id: String!) {
promptVersion(id: $id) {
createdAt
id
label
settings
status
tags
templateMessages
tools
type
updatedAt
url
variables
variablesDefaultValues
version
lineage {
name
query GetPrompt($id: String!) {
promptVersion(id: $id) {
createdAt
id
label
settings
status
tags
templateMessages
tools
type
updatedAt
url
variables
variablesDefaultValues
version
lineage {
name
}
}
}
}
`;

return await this.getPromptWithQuery(query, { id });
}

/**
* Retrieves a prompt by its name and optionally by its version.
* Private helper method to execute prompt queries with error handling and caching
*/
private async getPromptWithQuery(
query: string,
variables: { id?: string; name?: string; version?: number }
) {
const cachedPrompt = sharedCache.get(getPromptCacheKey(variables));
const timeout = cachedPrompt ? 1000 : undefined;

try {
const result = await this.makeGqlCall(query, variables, timeout);

if (!result.data || !result.data.promptVersion) {
return cachedPrompt;
}

const promptData = result.data.promptVersion;
promptData.provider = promptData.settings?.provider;
promptData.name = promptData.lineage?.name;
delete promptData.lineage;
if (promptData.settings) {
delete promptData.settings.provider;
}

const prompt = new Prompt(this, promptData);

sharedCache.put(prompt.id, prompt);
sharedCache.put(prompt.name, prompt);
sharedCache.put(`${prompt.name}:${prompt.version}`, prompt);

return prompt;
} catch (error) {
return cachedPrompt;
}
}

/**
* Retrieves a prompt by its name and optionally by its version. If the request fails, it will try to get the prompt from the cache.
*
* @param name - The name of the prompt to retrieve.
* @param version - The version number of the prompt (optional).
Expand All @@ -2171,31 +2212,9 @@ export class API {
}
}
`;

return await this.getPromptWithQuery(query, { name, version });
}

private async getPromptWithQuery(
query: string,
variables: Record<string, any>
) {
const result = await this.makeGqlCall(query, variables);

if (!result.data || !result.data.promptVersion) {
return null;
}

const promptData = result.data.promptVersion;
promptData.provider = promptData.settings?.provider;
promptData.name = promptData.lineage?.name;
delete promptData.lineage;
if (promptData.settings) {
delete promptData.settings.provider;
}

return new Prompt(this, promptData);
}

/**
* Retrieves a prompt A/B testing rollout by its name.
*
Expand Down
36 changes: 36 additions & 0 deletions src/cache/sharedcache.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
const cache: Map<string, any> = new Map();

class SharedCache {
private static instance: SharedCache;

public constructor() {
if (SharedCache.instance) {
throw new Error('SharedCache can only be created once');
}
SharedCache.instance = this;
}

public getInstance(): SharedCache {
return this;
}

public getCache(): Map<string, any> {
return cache;
}

public get(key: string): any {
return cache.get(key);
}

public put(key: string, value: any): void {
cache.set(key, value);
}

public clear(): void {
cache.clear();
}
}

export const sharedCache = new SharedCache();

export default sharedCache;
18 changes: 18 additions & 0 deletions src/cache/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
export function getPromptCacheKey({
id,
name,
version
}: {
id?: string;
name?: string;
version?: number;
}): string {
if (id) {
return id;
} else if (name && typeof version === 'number') {
return `${name}:${version}`;
} else if (name) {
return name;
}
throw new Error('Either id or name must be provided');
}
77 changes: 77 additions & 0 deletions tests/api.test.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import axios from 'axios';
import 'dotenv/config';
import { v4 as uuidv4 } from 'uuid';

import { ChatGeneration, IGenerationMessage, LiteralClient } from '../src';
import { sharedCache } from '../src/cache/sharedcache';
import { Dataset } from '../src/evaluation/dataset';
import { Score } from '../src/evaluation/score';
import { Prompt, PromptConstructor } from '../src/prompt-engineering/prompt';
import { sleep } from './utils';

describe('End to end tests for the SDK', function () {
Expand Down Expand Up @@ -597,6 +600,30 @@ describe('End to end tests for the SDK', function () {
});

describe('Prompt api', () => {
const mockPromptData: PromptConstructor = {
id: 'test-id',
name: 'test-prompt',
version: 1,
createdAt: new Date().toISOString(),
type: 'CHAT',
templateMessages: [],
tools: [],
settings: {
provider: 'test',
model: 'test',
frequency_penalty: 0,
presence_penalty: 0,
temperature: 0,
top_p: 0,
max_tokens: 0
},
variables: [],
variablesDefaultValues: {},
metadata: {},
items: [],
provider: 'test'
};

it('should get a prompt by name', async () => {
const prompt = await client.api.getPrompt('Default');

Expand Down Expand Up @@ -657,6 +684,56 @@ is a templated list.`;
expect(formatted[0].content).toBe(expected);
});

it('should fallback to cache when getPromptById DB call fails', async () => {
const prompt = new Prompt(client.api, mockPromptData);
sharedCache.put(prompt.id, prompt);
sharedCache.put(prompt.name, prompt);
sharedCache.put(`${prompt.name}:${prompt.version}`, prompt);

jest
.spyOn(client.api as any, 'makeGqlCall')
.mockRejectedValueOnce(new Error('DB Error'));

const result = await client.api.getPromptById(prompt.id);
expect(result).toEqual(prompt);
});

it('should fallback to cache when getPrompt DB call fails', async () => {
const prompt = new Prompt(client.api, mockPromptData);

sharedCache.put(prompt.id, prompt);
sharedCache.put(prompt.name, prompt);
sharedCache.put(`${prompt.name}:${prompt.version}`, prompt);

jest.spyOn(axios, 'post').mockRejectedValueOnce(new Error('DB Error'));

const result = await client.api.getPrompt(prompt.id);
expect(result).toEqual(prompt);
});

it('should update cache with fresh data on successful DB call', async () => {
const prompt = new Prompt(client.api, mockPromptData);

jest.spyOn(client.api as any, 'makeGqlCall').mockResolvedValueOnce({
data: { promptVersion: prompt }
});

await client.api.getPromptById(prompt.id);

const cachedPrompt = sharedCache.get(prompt.id);
expect(cachedPrompt).toBeDefined();
expect(cachedPrompt?.id).toBe(prompt.id);
});

it('should return null when both DB and cache fail', async () => {
jest
.spyOn(client.api as any, 'makeGqlCall')
.mockRejectedValueOnce(new Error('DB Error'));

const result = await client.api.getPromptById('non-existent-id');
expect(result).toBeUndefined();
});

it('should get a prompt A/B testing configuration', async () => {
const promptName = 'TypeScript SDK E2E Tests';

Expand Down
Loading
Loading