feat(provider): add martian api #717

Draft · wants to merge 7 commits into base: main
2 changes: 2 additions & 0 deletions src/globals.ts
@@ -74,6 +74,7 @@ export const LEMONFOX_AI: string = 'lemonfox-ai';
export const UPSTAGE: string = 'upstage';
export const LAMBDA: string = 'lambda';
export const DASHSCOPE: string = 'dashscope';
export const MARTIAN: string = 'martian';

export const VALID_PROVIDERS = [
ANTHROPIC,
@@ -121,6 +122,7 @@ export const VALID_PROVIDERS = [
UPSTAGE,
LAMBDA,
DASHSCOPE,
MARTIAN,
];

export const CONTENT_TYPES = {
13 changes: 13 additions & 0 deletions src/handlers/handlerUtils.ts
@@ -14,6 +14,7 @@ import {
CONTENT_TYPES,
HUGGING_FACE,
STABILITY_AI,
MARTIAN,
} from '../globals';
import Providers from '../providers';
import { ProviderAPIConfig, endpointStrings } from '../providers/types';
@@ -1053,6 +1054,10 @@ export function constructConfigFromRequestHeaders(
openaiProject: requestHeaders[`x-${POWERED_BY}-openai-project`],
};

const martianConfig = {
Reviewer comment (Collaborator): Hey! I can see that this mapping was added in the PR, but it's not used anywhere in the provider config. Can you please verify whether it's required? We can remove it if it's not used anywhere.

openaiOrganization: requestHeaders[`x-${POWERED_BY}-martian-organization`],
openaiProject: requestHeaders[`x-${POWERED_BY}-martian-project`],
};
const huggingfaceConfig = {
huggingfaceBaseUrl: requestHeaders[`x-${POWERED_BY}-huggingface-base-url`],
};
@@ -1149,6 +1154,12 @@ export function constructConfigFromRequestHeaders(
...stabilityAiConfig,
};
}
if (parsedConfigJson.provider === MARTIAN) {
parsedConfigJson = {
...parsedConfigJson,
...openAiConfig,
};
}
}
return convertKeysToCamelCase(parsedConfigJson, [
'override_params',
@@ -1173,6 +1184,8 @@ export function constructConfigFromRequestHeaders(
...(requestHeaders[`x-${POWERED_BY}-provider`] === AZURE_AI_INFERENCE &&
azureAiInferenceConfig),
...(requestHeaders[`x-${POWERED_BY}-provider`] === OPEN_AI && openAiConfig),
...(requestHeaders[`x-${POWERED_BY}-provider`] === MARTIAN &&
martianConfig),
...(requestHeaders[`x-${POWERED_BY}-provider`] === ANTHROPIC &&
anthropicConfig),
...(requestHeaders[`x-${POWERED_BY}-provider`] === HUGGING_FACE &&
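The header-mapping branch added above can be exercised in isolation. A minimal, self-contained sketch follows; the `x-portkey-*` header prefix is an assumption (the gateway derives it from `POWERED_BY`), and `buildHeaderConfig` is a simplified stand-in for `constructConfigFromRequestHeaders`:

```typescript
// Standalone sketch of the martian branch in constructConfigFromRequestHeaders.
// The 'portkey' prefix is assumed; the real constants live in src/globals.ts.
const POWERED_BY = 'portkey';
const MARTIAN = 'martian';

function buildHeaderConfig(
  requestHeaders: Record<string, string>
): Record<string, string | undefined> {
  const openAiConfig = {
    openaiOrganization: requestHeaders[`x-${POWERED_BY}-openai-organization`],
    openaiProject: requestHeaders[`x-${POWERED_BY}-openai-project`],
  };
  const provider = requestHeaders[`x-${POWERED_BY}-provider`];
  // When the selected provider is martian, the OpenAI-style keys are spread in,
  // mirroring the `parsedConfigJson.provider === MARTIAN` branch in the diff.
  return {
    provider,
    ...(provider === MARTIAN && openAiConfig),
  };
}

const config = buildHeaderConfig({
  'x-portkey-provider': 'martian',
  'x-portkey-openai-project': 'proj_demo',
});
console.log(config.provider); // martian
```

The `...(cond && obj)` spread is the same conditional-merge idiom the surrounding diff uses for the other providers.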
2 changes: 2 additions & 0 deletions src/providers/index.ts
@@ -47,6 +47,7 @@ import { UpstageConfig } from './upstage';
import { LAMBDA } from '../globals';
import { LambdaProviderConfig } from './lambda';
import { DashScopeConfig } from './dashscope';
import { MartianConfig } from './martian';

const Providers: { [key: string]: ProviderConfigs } = {
openai: OpenAIConfig,
@@ -94,6 +95,7 @@ const Providers: { [key: string]: ProviderConfigs } = {
upstage: UpstageConfig,
[LAMBDA]: LambdaProviderConfig,
dashscope: DashScopeConfig,
martian: MartianConfig,
};

export default Providers;
23 changes: 23 additions & 0 deletions src/providers/martian/api.ts
@@ -0,0 +1,23 @@
import { ProviderAPIConfig } from '../types';

const MartianAPIConfig: ProviderAPIConfig = {
getBaseURL: ({ providerOptions }) => {
return (
providerOptions.martianBaseUrl || 'https://withmartian.com/api/openai/v1'
);
},
headers: ({ providerOptions }) => ({
'Content-Type': 'application/json',
Authorization: `Bearer ${providerOptions.apiKey}`,
}),
getEndpoint: ({ fn }) => {
switch (fn) {
case 'chatComplete':
return '/chat/completions';
default:
return '';
}
},
};

export default MartianAPIConfig;
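As a quick sanity check, the base-URL and endpoint logic above composes into the full chat-completions URL. A standalone restatement with the option type inlined (since `ProviderAPIConfig` is not reproduced here):

```typescript
// Standalone restatement of MartianAPIConfig's URL resolution.
type MartianOptions = { martianBaseUrl?: string; apiKey?: string };

const getBaseURL = (providerOptions: MartianOptions): string =>
  providerOptions.martianBaseUrl || 'https://withmartian.com/api/openai/v1';

const getEndpoint = (fn: string): string =>
  fn === 'chatComplete' ? '/chat/completions' : '';

// Default base plus the chatComplete endpoint.
const url = getBaseURL({}) + getEndpoint('chatComplete');
console.log(url); // https://withmartian.com/api/openai/v1/chat/completions
```

Unsupported functions resolve to an empty endpoint, so only chat completions are routable through this config as written.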
230 changes: 230 additions & 0 deletions src/providers/martian/chatComplete.ts
@@ -0,0 +1,230 @@
import { MARTIAN, ANTHROPIC } from '../../globals';
import {
ChatCompletionResponse,
ErrorResponse,
ProviderConfig,
} from '../../types';
import { MartianErrorResponseTransform } from './utils';

export const MartianChatCompleteConfig: ProviderConfig = {
model: {
param: 'model',
required: true,
default: 'router',
},
messages: {
param: 'messages',
required: true,
},
max_tokens: {
param: 'max_tokens',
default: 100,
min: 0,
},
temperature: {
param: 'temperature',
default: 1,
min: 0,
max: 2,
},
top_p: {
param: 'top_p',
default: 1,
min: 0,
max: 1,
},
n: {
param: 'n',
default: 1,
},
stream: {
param: 'stream',
default: false,
},
stop: {
param: 'stop',
},
max_cost: {
param: 'max_cost',
},
max_cost_per_million_tokens: {
param: 'max_cost_per_million_tokens',
},
models: {
param: 'models',
},
willingness_to_pay: {
param: 'willingness_to_pay',
},
extra: {
param: 'extra',
},
};
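The Martian-specific parameters (`max_cost`, `max_cost_per_million_tokens`, `models`, `willingness_to_pay`, `extra`) pass through unmapped. An illustrative request body; the cost cap and candidate model list are invented for the example, so consult Martian's documentation for the real semantics:

```typescript
// Illustrative chat-completions body using the params declared above.
// max_cost and models values are made-up examples, not documented defaults.
const requestBody = {
  model: 'router', // config default
  messages: [{ role: 'user', content: 'Summarize this PR.' }],
  max_tokens: 100, // config default
  max_cost: 0.01, // hypothetical per-request cost cap
  models: ['gpt-4o-mini', 'claude-3-haiku'], // hypothetical candidate set
};
console.log(requestBody.model); // router
```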

export interface MartianChatCompleteResponse extends ChatCompletionResponse {
system_fingerprint: string;
}

export const MartianChatCompleteResponseTransform: (
response: MartianChatCompleteResponse | ErrorResponse,
responseStatus: number
) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => {
if (responseStatus !== 200 && 'error' in response) {
return MartianErrorResponseTransform(response, MARTIAN);
}

return response;
};

/**
* Transforms a Martian-format chat completions JSON response into an array of OpenAI-compatible text/event-stream chunks.
*
* @param {Object} response - The MartianChatCompleteResponse object.
* @param {string} provider - The provider string.
* @returns {Array<string>} - An array of formatted stream chunks.
*/
export const MartianChatCompleteJSONToStreamResponseTransform: (
Reviewer comment (Collaborator): The JSONToStream transformer is only used for the OpenAI provider. It is not required for each provider, because the function assumes the JSON passed to it is always OpenAI-compliant.

It's already handled in the responseHandler function, so I think it would be safe to remove it from here.

response: MartianChatCompleteResponse,
provider: string
) => Array<string> = (response, provider) => {
const streamChunkArray: Array<string> = [];
const { id, model, system_fingerprint, choices } = response;

const {
prompt_tokens,
completion_tokens,
cache_read_input_tokens,
cache_creation_input_tokens,
} = response.usage || {};

let total_tokens;
if (prompt_tokens && completion_tokens)
total_tokens = prompt_tokens + completion_tokens;

const shouldSendCacheUsage =
provider === ANTHROPIC &&
(Number.isInteger(cache_read_input_tokens) ||
Number.isInteger(cache_creation_input_tokens));

const streamChunkTemplate: Record<string, any> = {
id,
object: 'chat.completion.chunk',
created: Date.now(),
model: model || '',
system_fingerprint: system_fingerprint || null,
provider,
usage: {
...(completion_tokens && { completion_tokens }),
...(prompt_tokens && { prompt_tokens }),
...(total_tokens && { total_tokens }),
...(shouldSendCacheUsage && {
cache_read_input_tokens,
cache_creation_input_tokens,
}),
},
};

for (const [index, choice] of choices.entries()) {
if (
choice.message &&
choice.message.tool_calls &&
choice.message.tool_calls.length
) {
for (const [
toolCallIndex,
toolCall,
] of choice.message.tool_calls.entries()) {
const toolCallNameChunk = {
index: toolCallIndex,
id: toolCall.id,
type: 'function',
function: {
name: toolCall.function.name,
arguments: '',
},
};

const toolCallArgumentChunk = {
index: toolCallIndex,
function: {
arguments: toolCall.function.arguments,
},
};

streamChunkArray.push(
`data: ${JSON.stringify({
...streamChunkTemplate,
choices: [
{
index: index,
delta: {
role: 'assistant',
content: null,
tool_calls: [toolCallNameChunk],
},
},
],
})}\n\n`
);

streamChunkArray.push(
`data: ${JSON.stringify({
...streamChunkTemplate,
choices: [
{
index: index,
delta: {
role: 'assistant',
tool_calls: [toolCallArgumentChunk],
},
},
],
})}\n\n`
);
}
}

if (
choice.message &&
choice.message.content &&
typeof choice.message.content === 'string'
) {
const individualWords: Array<string> = [];
for (let i = 0; i < choice.message.content.length; i += 4) {
individualWords.push(choice.message.content.slice(i, i + 4));
}
individualWords.forEach((word: string) => {
streamChunkArray.push(
`data: ${JSON.stringify({
...streamChunkTemplate,
choices: [
{
index: index,
delta: {
role: 'assistant',
content: word,
},
},
],
})}\n\n`
);
});
}

streamChunkArray.push(
`data: ${JSON.stringify({
...streamChunkTemplate,
choices: [
{
index: index,
delta: {},
finish_reason: choice.finish_reason,
},
],
})}\n\n`
);
}

streamChunkArray.push(`data: [DONE]\n\n`);
return streamChunkArray;
};
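Note that the content path above streams the message in fixed 4-character slices rather than actual words, despite the variable name. The core slicing, extracted so it can be checked on its own:

```typescript
// The transformer's content loop, isolated: split a string into
// 4-character chunks, exactly as the `i += 4` loop above does.
function sliceIntoChunks(content: string, size = 4): string[] {
  const chunks: string[] = [];
  for (let i = 0; i < content.length; i += size) {
    chunks.push(content.slice(i, i + size));
  }
  return chunks;
}

console.log(sliceIntoChunks('Hello, world')); // [ 'Hell', 'o, w', 'orld' ]
```

Each chunk then becomes one `chat.completion.chunk` SSE event, followed by a final empty-delta chunk carrying `finish_reason` and the `data: [DONE]` terminator.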
18 changes: 18 additions & 0 deletions src/providers/martian/index.ts
@@ -0,0 +1,18 @@
import { ProviderConfigs } from '../types';
import MartianAPIConfig from './api';
import {
MartianChatCompleteConfig,
MartianChatCompleteResponseTransform,
} from './chatComplete';

const MartianConfig: ProviderConfigs = {
api: MartianAPIConfig,
chatComplete: MartianChatCompleteConfig,
responseTransforms: {
Reviewer comment (Collaborator): For OpenAI-compliant providers, you can reuse the openai-base provider's function to avoid writing redundant code.

Please check the inference-net provider integration as a reference on how to use the base provider:

responseTransforms: responseTransformers(INFERENCENET, {
  chatComplete: true,
}),

chatComplete: MartianChatCompleteResponseTransform,
},
};

// Explicit export for MartianConfig
export { MartianConfig };
export default MartianConfig;
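If the reviewer's suggestion above is adopted, this file could shrink to the base-provider pattern. A sketch under stated assumptions: `responseTransformers` is taken from the reviewer's snippet, and its import path here is a guess modeled on the inference-net integration, not confirmed from the repo:

```typescript
import { MARTIAN } from '../../globals';
import { responseTransformers } from '../open-ai-base'; // assumed path
import { ProviderConfigs } from '../types';
import MartianAPIConfig from './api';
import { MartianChatCompleteConfig } from './chatComplete';

const MartianConfig: ProviderConfigs = {
  api: MartianAPIConfig,
  chatComplete: MartianChatCompleteConfig,
  // Reuses the shared OpenAI-compliant response transformers instead of
  // the hand-written MartianChatCompleteResponseTransform.
  responseTransforms: responseTransformers(MARTIAN, {
    chatComplete: true,
  }),
};

export { MartianConfig };
export default MartianConfig;
```

This would also let `chatComplete.ts` drop its custom response and JSON-to-stream transforms entirely.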
14 changes: 14 additions & 0 deletions src/providers/martian/utils.ts
@@ -0,0 +1,14 @@
import { ErrorResponse } from '../types';
import { generateErrorResponse } from '../utils';

export const MartianErrorResponseTransform: (
response: ErrorResponse,
provider: string
) => ErrorResponse = (response, provider) => {
return generateErrorResponse(
{
...response.error,
},
provider
);
};