diff --git a/src/ai/assistant.ts b/src/ai/assistant.ts index 11aae19d..da3826d0 100644 --- a/src/ai/assistant.ts +++ b/src/ai/assistant.ts @@ -30,6 +30,7 @@ import { buildSetOp, buildSetStateRuleOp, getChecksumAddress, + getAssistant, } from '../utils/util'; import { isObjectOwner, @@ -221,27 +222,38 @@ export class Assistants extends FactoryBase { * @returns Returns a promise that resolves with the assistant. */ async get(objectId: string, tokenId: string, assistantId: string): Promise { + const appId = AinftObject.getAppId(objectId); + await validateObject(this.ain, objectId); await validateToken(this.ain, objectId, tokenId); - await validateAssistant(this.ain, objectId, tokenId, assistantId); - - const serverName = getServerName(); - await validateServerConfigurationForObject(this.ain, objectId, serverName); - const opType = OperationType.RETRIEVE_ASSISTANT; - const body = { assistantId }; - const { data } = await request(this.ainize!, { - serverName, - opType, - data: body, - }); + const data = await getAssistant(this.ain, appId, tokenId); + const assistant = { + id: data.id, + tokenId, + model: data.config.model, + name: data.config.name, + instructions: data.config.instructions, + description: data.config.description || null, + metadata: data.config.metadata || {}, + created_at: data.createdAt, + }; - return data; + return assistant; } - async list(objectId: string, { limit = 20, offset = 0, order = 'desc' }: QueryParams) { - const address = await this.ain.signer.getAddress(); - + /** + * Retrieves a list of assistants. + * @param {string} objectId - The ID of AINFT object. + * @param {string} address - The checksum address of account. + * @param {QueryParams} QueryParams - The parameters for querying items. + * @returns Returns a promise that resolves with the list of assistants. + */ + async list( + objectId: string, + address: string, + { limit = 20, offset = 0, order = 'desc' }: QueryParams = {} + ) { await validateObject(this.ain, objectId); const tokens = await this.getTokensByAddress(objectId, address); diff --git a/src/ai/message.ts b/src/ai/message.ts index 9c5a203b..31c00f41 100644 --- a/src/ai/message.ts +++ b/src/ai/message.ts @@ -10,7 +10,6 @@ import { MessageMap, MessageTransactionResult, MessageUpdateParams, - Page, } from '../types'; import { buildSetTxBody, buildSetValueOp, getAssistant, getValue, sendTx } from '../utils/util'; import { @@ -114,67 +113,48 @@ export class Messages extends FactoryBase { * @param {string} tokenId - The ID of AINFT token. * @param {string} threadId - The ID of thread. * @param {string} messageId - The ID of message. + * @param {string} address - The checksum address of account. * @returns Returns a promise that resolves with the message. */ async get( objectId: string, tokenId: string, threadId: string, - messageId: string + messageId: string, + address: string ): Promise { - const address = await this.ain.signer.getAddress(); - await validateObject(this.ain, objectId); await validateToken(this.ain, objectId, tokenId); await validateAssistant(this.ain, objectId, tokenId); await validateThread(this.ain, objectId, tokenId, address, threadId); - await validateMessage(this.ain, objectId, tokenId, address, threadId, messageId); - const serverName = getServerName(); - await validateServerConfigurationForObject(this.ain, objectId, serverName); - - const opType = OperationType.RETRIEVE_MESSAGE; - const body = { threadId, messageId }; - - const { data } = await request(this.ainize!, { - serverName, - opType, - data: body, - }); + const appId = AinftObject.getAppId(objectId); + const messagesPath = Path.app(appId).token(tokenId).ai().history(address).thread(threadId).messages().value(); + const messages: MessageMap = await getValue(this.ain, messagesPath); + const key = this.findMessageKey(messages, messageId); - return data; + return messages[key]; } - // TODO(jiyoung): fetch from blockchain db. /** * Retrieves a list of messages. * @param {string} objectId - The ID of AINFT object. * @param {string} tokenId - The ID of AINFT token. * @param {string} threadId - The ID of thread. + * @param {string} address - The checksum address of account. * @returns Returns a promise that resolves with the list of messages. */ - async list(objectId: string, tokenId: string, threadId: string): Promise { - const appId = AinftObject.getAppId(objectId); - const address = await this.ain.signer.getAddress(); - + async list(objectId: string, tokenId: string, threadId: string, address: string): Promise { await validateObject(this.ain, objectId); await validateToken(this.ain, objectId, tokenId); await validateAssistant(this.ain, objectId, tokenId); await validateThread(this.ain, objectId, tokenId, address, threadId); - const serverName = getServerName(); - await validateServerConfigurationForObject(this.ain, objectId, serverName); - - const opType = OperationType.LIST_MESSAGES; - const body = { threadId }; - - const { data } = await request>(this.ainize!, { - serverName, - opType, - data: body, - }); + const appId = AinftObject.getAppId(objectId); + const messagesPath = Path.app(appId).token(tokenId).ai().history(address).thread(threadId).messages().value(); + const messages = await this.ain.db.ref(messagesPath).getValue(); - return data?.data || {}; + return messages; } private async createMessageAndRun( @@ -330,4 +310,18 @@ export class Messages extends FactoryBase { return buildSetTxBody(buildSetValueOp(messagePath, value), address); } + + private findMessageKey = (messages: MessageMap, messageId: string) => { + let messageKey = null; + for (const key in messages) { + if (messages[key].id === messageId) { + messageKey = key; + break; + } + } + if (!messageKey) { + throw new Error('Message not found'); + } + return messageKey; + }; } diff --git a/src/ai/thread.ts b/src/ai/thread.ts index f6dca93e..4f2433f0 100644 --- a/src/ai/thread.ts +++ b/src/ai/thread.ts @@ -173,104 +173,55 @@ export class Threads extends FactoryBase { * @param {string} objectId - The ID of AINFT object. * @param {string} tokenId - The ID of AINFT token. * @param {string} threadId - The ID of thread. + * @param {string} address - The checksum address of account. * @returns Returns a promise that resolves with the thread. */ - async get(objectId: string, tokenId: string, threadId: string): Promise { - const address = await this.ain.signer.getAddress(); - + async get(objectId: string, tokenId: string, threadId: string, address: string): Promise { await validateObject(this.ain, objectId); await validateToken(this.ain, objectId, tokenId); await validateAssistant(this.ain, objectId, tokenId); await validateThread(this.ain, objectId, tokenId, address, threadId); - const serverName = getServerName(); - await validateServerConfigurationForObject(this.ain, objectId, serverName); - - const opType = OperationType.RETRIEVE_THREAD; - const body = { threadId }; - - const { data } = await request(this.ainize!, { - serverName, - opType, - data: body, - }); + const appId = AinftObject.getAppId(objectId); + const threadPath = Path.app(appId).token(tokenId).ai().history(address).thread(threadId).value(); + const data = await this.ain.db.ref(threadPath).getValue(); + const thread = { + id: data.id, + metadata: data.metadata || {}, + created_at: data.createdAt, + }; - return data; + return thread; } + /** + * Retrieves a list of threads. + * @param {string} objectId - The ID of AINFT object. + * @param {string | null} [tokenId] - The ID of AINFT token. + * @param {string | null} [address] - The checksum address of account. + * @param {QueryParams} QueryParams - The parameters for querying items. + * @returns Returns a promise that resolves with the list of threads. + */ async list( objectId: string, tokenId?: string | null, address?: string | null, - { limit = 20, order = 'desc', next }: QueryParams = {} + { limit = 20, offset = 0, order = 'desc' }: QueryParams = {} ) { - let checksum = null; - if (address) { - checksum = getChecksumAddress(address); + await validateObject(this.ain, objectId); + if (tokenId) { + await validateToken(this.ain, objectId, tokenId); } - await validateObject(this.ain, objectId); + const tokens = await this.fetchTokens(objectId); + const threads = this.flattenThreads(tokens); + const filtered = this.filterThreads(threads, tokenId, address); + const sorted = _.orderBy(filtered, ['created_at'], [order]); - const serverName = getServerName(); - const opType = OperationType.LIST_THREADS; - const body = { - objectId, - ...(tokenId && { tokenId }), - ...(checksum && { address: checksum }), - limit, - order, - ...(next && { next }), - }; - - const { data } = await request(this.ainize!, { - serverName, - opType, - data: body, - }); + const total = sorted.length; + const items = sorted.slice(offset, offset + limit); - return data; - // NOTE(jiyoung): example data - /* - return { - items: { - '0': { - id: 'thread_yjw3LcSxSxIkrk225v7kLpCA', - assistant: { - id: 'asst_IfWuJqqO5PdCF9DbgZRcFClG', - model: 'gpt-3.5-turbo', - name: 'AINA-TKAJYJF1C5', - instructions: '', - description: '일상적인 작업에 적합합니다. GPT-3.5-turbo에 의해 구동됩니다.', - metadata: { - image: 'https://picsum.photos/id/1/200/200', - }, - }, - created_at: 1711962854, - metadata: { - title: '도와드릴까요?', - }, - }, - '1': { - id: 'thread_mmzBrZeM5vllqEceRttvu1xk', - assistant: { - id: 'asst_IfWuJqqO5PdCF9DbgZRcFClG', - model: 'gpt-3.5-turbo', - name: 'AINA-TKAJYJF1C5', - instructions: '', - description: '일상적인 작업에 적합합니다. GPT-3.5-turbo에 의해 구동됩니다.', - metadata: { - image: 'https://picsum.photos/id/1/200/200', - }, - }, - created_at: 1711961028, - metadata: { - title: '영문번역', - }, - }, - }, - next: 'e49274a2-a255-4f95-b57a-68beebc6bdf7', - }; - */ + return { total, items }; } async createAndRun( @@ -441,4 +392,56 @@ export class Threads extends FactoryBase { return buildSetTxBody(buildSetValueOp(threadPath, value), address); } + + private async fetchTokens(objectId: string) { + const appId = AinftObject.getAppId(objectId); + const tokensPath = Path.app(appId).tokens().value(); + return this.ain.db.ref(tokensPath).getValue(); + } + + private flattenThreads(tokens: any) { + const flatten: any = []; + _.forEach(tokens, (token, tokenId) => { + const assistant = token.ai; + if (!assistant) { + return; + } + const histories = assistant.history; + if (typeof histories !== 'object' || histories === true) { + return; + } + _.forEach(histories, (history, address) => { + const threads = _.get(history, 'threads'); + _.forEach(threads, (thread) => { + flatten.push({ + id: thread.id, + metadata: thread.metadata || {}, + created_at: thread.createdAt, + assistant: { + id: assistant.id, + tokenId, + model: assistant.config.model, + name: assistant.config.name, + instructions: assistant.config.instructions, + description: assistant.config.description || null, + metadata: assistant.config.metadata || {}, + created_at: assistant.createdAt, + }, + author: { address }, + }); + }); + }); + }); + return flatten; + } + + private filterThreads(threads: any, tokenId?: string | null, address?: string | null) { + return _.filter(threads, (thread) => { + const threadTokenId = _.get(thread, 'assistant.tokenId'); + const threadAddress = _.get(thread, 'author.address'); + const tokenIdMatch = tokenId ? threadTokenId === tokenId : true; + const addressMatch = address ? threadAddress === address : true; + return tokenIdMatch && addressMatch; + }); + } } diff --git a/src/types.ts b/src/types.ts index f2b34bb4..51908ea3 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1290,7 +1290,7 @@ export interface Thread { * The metadata can contain up to 16 pairs, * with keys limited to 64 characters and values to 512 characters. */ - metadata: object | null; + metadata: object | {}; /** The UNIX timestamp in seconds. */ created_at: number; } @@ -1377,10 +1377,12 @@ export interface MessageMap { } export interface QueryParams { + /** The maximum number of items to return */ limit?: number; + /** The number of items to skip */ offset?: number; + /** The order of the result set */ order?: 'asc' | 'desc'; - next?: string | null; } export interface Page { diff --git a/test/ai/assistant.test.ts b/test/ai/assistant.test.ts index 403b37a7..cf269847 100644 --- a/test/ai/assistant.test.ts +++ b/test/ai/assistant.test.ts @@ -50,7 +50,7 @@ describe.skip('assistant', () => { }); it('should list assistants', async () => { - const result = await ainft.assistant.list(objectId, { limit: 20, offset: 0, order: 'desc' }); + const result = await ainft.assistant.list(objectId, address); expect(result.total).toBeDefined(); expect(result.items).toBeDefined(); diff --git a/test/ai/message.test.ts b/test/ai/message.test.ts index 83628cfd..81c568c4 100644 --- a/test/ai/message.test.ts +++ b/test/ai/message.test.ts @@ -1,5 +1,5 @@ import AinftJs from '../../src/ainft'; -import { messageId, objectId, privateKey, threadId, tokenId } from '../test_data'; +import { messageId, objectId, privateKey, threadId, tokenId, address } from '../test_data'; import { MESSAGE_REGEX, TX_HASH_REGEX } from '../constants'; jest.setTimeout(60 * 1000); // 1min @@ -39,7 +39,7 @@ describe.skip('message', () => { }); it('should get message', async () => { - const message = await ainft.message.get(objectId, tokenId, threadId, messageId); + const message = await ainft.message.get(objectId, tokenId, threadId, messageId, address); expect(message.id).toBe(messageId); expect(message.thread_id).toBe(threadId); @@ -49,7 +49,7 @@ describe.skip('message', () => { }); it('should list messages', async () => { - const messages = await ainft.message.list(objectId, tokenId, threadId); + const messages = await ainft.message.list(objectId, tokenId, threadId, address); expect(Object.keys(messages).length).toBe(2); }); diff --git a/test/ai/thread.test.ts b/test/ai/thread.test.ts index 903ef1fb..9d52b072 100644 --- a/test/ai/thread.test.ts +++ b/test/ai/thread.test.ts @@ -1,5 +1,5 @@ import AinftJs from '../../src/ainft'; -import { privateKey, objectId, tokenId, threadId } from '../test_data'; +import { privateKey, address, objectId, tokenId, threadId } from '../test_data'; import { TX_HASH_REGEX, THREAD_REGEX } from '../constants'; describe.skip('thread', () => { @@ -31,14 +31,14 @@ describe.skip('thread', () => { }); it('should get thread', async () => { - const thread = await ainft.thread.get(objectId, tokenId, threadId); + const thread = await ainft.thread.get(objectId, tokenId, threadId, address); expect(thread.id).toBe(threadId); expect(thread.metadata).toEqual({ key1: 'value1' }); }); it('should list threads', async () => { - const result = await ainft.thread.list(objectId, null, null, { limit: 20, order: 'desc' }); + const result = await ainft.thread.list(objectId, null, null, { limit: 20, offset: 0, order: 'desc' }); expect(result.items).toBeDefined(); });