Skip to content

Commit

Permalink
✨ feat: support sync with webrtc
Browse files Browse the repository at this point in the history
  • Loading branch information
arvinxx committed Mar 12, 2024
1 parent 4d7154b commit 303ec63
Show file tree
Hide file tree
Showing 13 changed files with 225 additions and 166 deletions.
6 changes: 3 additions & 3 deletions src/app/chat/(desktop)/features/SessionHeader.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ import { MessageSquarePlus } from 'lucide-react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { Flexbox } from 'react-layout-kit';
import useSWR from 'swr';

import { DESKTOP_HEADER_ICON_SIZE } from '@/const/layoutTokens';
import { useEnabledDataSync } from '@/hooks/useSyncData';
import { useGlobalStore } from '@/store/global';
import { useSessionStore } from '@/store/session';

Expand All @@ -27,9 +27,9 @@ const Header = memo(() => {
const { styles } = useStyles();
const { t } = useTranslation('chat');
const [createSession] = useSessionStore((s) => [s.createSession]);
const [syncEnabled, enabledSync] = useGlobalStore((s) => [s.syncEnabled, s.enabledSync]);
const [syncEnabled] = useGlobalStore((s) => [s.syncEnabled]);

useSWR('enableSync', enabledSync, { revalidateOnFocus: false });
useEnabledDataSync();

return (
<Flexbox className={styles.top} gap={16} padding={16}>
Expand Down
4 changes: 2 additions & 2 deletions src/database/core/__tests__/model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ describe('BaseModel', () => {
content: 'Hello, World!',
};

const result = await baseModel['_add'](validData);
const result = await baseModel['_addWithSync'](validData);

expect(result).toHaveProperty('id');
expect(console.error).not.toHaveBeenCalled();
Expand All @@ -49,7 +49,7 @@ describe('BaseModel', () => {
content: 'Hello, World!',
};

await expect(baseModel['_add'](invalidData)).rejects.toThrow(TypeError);
await expect(baseModel['_addWithSync'](invalidData)).rejects.toThrow(TypeError);
});
});
});
31 changes: 29 additions & 2 deletions src/database/core/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import Dexie, { BulkError } from 'dexie';
import { ZodObject } from 'zod';

import { DBBaseFieldsSchema } from '@/database/core/types/db';
import { syncBus } from '@/libs/sync';
import { nanoid } from '@/utils/uuid';

import { LocalDB, LocalDBInstance, LocalDBSchema } from './db';
Expand All @@ -21,10 +22,14 @@ export class BaseModel<N extends keyof LocalDBSchema = any, T = LocalDBSchema[N]
return this.db[this._tableName] as Dexie.Table;
}

get yMap() {
return syncBus.getYMap(this._tableName);
}

/**
* create a new record
*/
protected async _add<T = LocalDBSchema[N]['model']>(
protected async _addWithSync<T = LocalDBSchema[N]['model']>(
data: T,
id: string | number = nanoid(),
primaryKey: string = 'id',
Expand All @@ -51,6 +56,9 @@ export class BaseModel<N extends keyof LocalDBSchema = any, T = LocalDBSchema[N]

const newId = await this.db[tableName].add(record);

// sync data to yjs data map
await this.updateYMapItem(newId);

return { id: newId };
}

Expand Down Expand Up @@ -122,6 +130,10 @@ export class BaseModel<N extends keyof LocalDBSchema = any, T = LocalDBSchema[N]
// Using bulkAdd to insert validated data
try {
await this.table.bulkAdd(validatedData);
const pools = validatedData.map(async (item) => {
await this.updateYMapItem(item.id);
});
await Promise.all(pools);

return {
added: validatedData.length,
Expand All @@ -144,7 +156,7 @@ export class BaseModel<N extends keyof LocalDBSchema = any, T = LocalDBSchema[N]
}
}

protected async _update(id: string, data: Partial<T>) {
protected async _updateWithSync(id: string, data: Partial<T>) {
// we need to check whether the data is valid
// pick data related schema from the full schema
const keys = Object.keys(data);
Expand All @@ -162,6 +174,21 @@ export class BaseModel<N extends keyof LocalDBSchema = any, T = LocalDBSchema[N]

const success = await this.table.update(id, { ...data, updatedAt: Date.now() });

// sync data to yjs data map
await this.updateYMapItem(id);

return { success };
}

protected async _deleteWithSync(id: string) {
const result = await this.table.delete(id);
// sync delete data to yjs data map
this.yMap.delete(id);
return result;
}

private updateYMapItem = async (id: string) => {
const newData = await this.table.get(id);
this.yMap.set(id, newData);
};
}
2 changes: 1 addition & 1 deletion src/database/models/file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class _FileModel extends BaseModel<'files'> {
async create(file: DB_File) {
const id = nanoid();

return this._add(file, `file-${id}`);
return this._addWithSync(file, `file-${id}`);
}

async findById(id: string) {
Expand Down
165 changes: 88 additions & 77 deletions src/database/models/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,8 @@ class _MessageModel extends BaseModel {
constructor() {
super('messages', DB_MessageSchema);
}
async create(data: CreateMessageParams) {
const id = nanoid();

const messageData: DB_Message = this.mapChatMessageToDBMessage(data as ChatMessage);

return this._add(messageData, id);
}

async batchCreate(messages: ChatMessage[]) {
const data: DB_Message[] = messages.map((m) => this.mapChatMessageToDBMessage(m));

return this._batchAdd(data);
}
// **************** Query *************** //

async query({
sessionId,
Expand Down Expand Up @@ -91,45 +80,79 @@ class _MessageModel extends BaseModel {
return this.table.get(id);
}

async delete(id: string) {
return this.table.delete(id);
async queryAll() {
const data: DBModel<DB_Message>[] = await this.table.orderBy('updatedAt').toArray();

return data.map((element) => this.mapToChatMessage(element));
}

async clearTable() {
return this.table.clear();
async queryBySessionId(sessionId: string) {
return this.table.where('sessionId').equals(sessionId).toArray();
}

async update(id: string, data: DeepPartial<DB_Message>) {
return super._update(id, data);
queryByTopicId = async (topicId: string) => {
const dbMessages = await this.table.where('topicId').equals(topicId).toArray();

return dbMessages.map((message) => this.mapToChatMessage(message));
};

async count() {
return this.table.count();
}

async updatePluginState(id: string, key: string, value: any) {
const item = await this.findById(id);
// **************** Create *************** //

return this.update(id, { pluginState: { ...item.pluginState, [key]: value } });
async create(data: CreateMessageParams) {
const id = nanoid();

const messageData: DB_Message = this.mapChatMessageToDBMessage(data as ChatMessage);

return this._addWithSync(messageData, id);
}

/**
* Batch updates multiple fields of the specified messages.
*
* @param {string[]} messageIds - The identifiers of the messages to be updated.
* @param {Partial<DB_Message>} updateFields - An object containing the fields to update and their new values.
* @returns {Promise<number>} - The number of updated messages.
*/
async batchUpdate(messageIds: string[], updateFields: Partial<DB_Message>): Promise<number> {
// Retrieve the messages by their IDs
const messagesToUpdate = await this.table.where(':id').anyOf(messageIds).toArray();
async batchCreate(messages: ChatMessage[]) {
const data: DB_Message[] = messages.map((m) => this.mapChatMessageToDBMessage(m));

// Update the specified fields of each message
const updatedMessages = messagesToUpdate.map((message) => ({
...message,
...updateFields,
}));
return this._batchAdd(data);
}

// Use the bulkPut method to update the messages in bulk
await this.table.bulkPut(updatedMessages);
async duplicateMessages(messages: ChatMessage[]): Promise<ChatMessage[]> {
const duplicatedMessages = await this.createDuplicateMessages(messages);
// 批量添加复制后的消息到数据库
await this.batchCreate(duplicatedMessages);
return duplicatedMessages;
}

return updatedMessages.length;
async createDuplicateMessages(messages: ChatMessage[]): Promise<ChatMessage[]> {
// 创建一个映射来存储原始消息ID和复制消息ID之间的关系
const idMapping = new Map<string, string>();

// 首先复制所有消息,并为每个复制的消息生成新的ID
const duplicatedMessages = messages.map((originalMessage) => {
const newId = nanoid();
idMapping.set(originalMessage.id, newId);

return { ...originalMessage, id: newId };
});

// 更新 parentId 为复制后的新ID
for (const duplicatedMessage of duplicatedMessages) {
if (duplicatedMessage.parentId && idMapping.has(duplicatedMessage.parentId)) {
duplicatedMessage.parentId = idMapping.get(duplicatedMessage.parentId);
}
}

return duplicatedMessages;
}

// **************** Delete *************** //

async delete(id: string) {
return super._deleteWithSync(id);
}

async clearTable() {
return this.table.clear();
}

/**
Expand Down Expand Up @@ -158,55 +181,43 @@ class _MessageModel extends BaseModel {
return this.table.bulkDelete(messageIds);
}

async queryAll() {
const data: DBModel<DB_Message>[] = await this.table.orderBy('updatedAt').toArray();
// **************** Update *************** //

return data.map((element) => this.mapToChatMessage(element));
}

async count() {
return this.table.count();
}

async queryBySessionId(sessionId: string) {
return this.table.where('sessionId').equals(sessionId).toArray();
async update(id: string, data: DeepPartial<DB_Message>) {
return super._updateWithSync(id, data);
}

queryByTopicId = async (topicId: string) => {
const dbMessages = await this.table.where('topicId').equals(topicId).toArray();

return dbMessages.map((message) => this.mapToChatMessage(message));
};
async updatePluginState(id: string, key: string, value: any) {
const item = await this.findById(id);

async duplicateMessages(messages: ChatMessage[]): Promise<ChatMessage[]> {
const duplicatedMessages = await this.createDuplicateMessages(messages);
// 批量添加复制后的消息到数据库
await this.batchCreate(duplicatedMessages);
return duplicatedMessages;
return this.update(id, { pluginState: { ...item.pluginState, [key]: value } });
}

async createDuplicateMessages(messages: ChatMessage[]): Promise<ChatMessage[]> {
// 创建一个映射来存储原始消息ID和复制消息ID之间的关系
const idMapping = new Map<string, string>();

// 首先复制所有消息,并为每个复制的消息生成新的ID
const duplicatedMessages = messages.map((originalMessage) => {
const newId = nanoid();
idMapping.set(originalMessage.id, newId);
/**
* Batch updates multiple fields of the specified messages.
*
* @param {string[]} messageIds - The identifiers of the messages to be updated.
* @param {Partial<DB_Message>} updateFields - An object containing the fields to update and their new values.
* @returns {Promise<number>} - The number of updated messages.
*/
async batchUpdate(messageIds: string[], updateFields: Partial<DB_Message>): Promise<number> {
// Retrieve the messages by their IDs
const messagesToUpdate = await this.table.where(':id').anyOf(messageIds).toArray();

return { ...originalMessage, id: newId };
});
// Update the specified fields of each message
const updatedMessages = messagesToUpdate.map((message) => ({
...message,
...updateFields,
}));

// 更新 parentId 为复制后的新ID
for (const duplicatedMessage of duplicatedMessages) {
if (duplicatedMessage.parentId && idMapping.has(duplicatedMessage.parentId)) {
duplicatedMessage.parentId = idMapping.get(duplicatedMessage.parentId);
}
}
// Use the bulkPut method to update the messages in bulk
await this.table.bulkPut(updatedMessages);

return duplicatedMessages;
return updatedMessages.length;
}

// **************** Helper *************** //

private mapChatMessageToDBMessage(message: ChatMessage): DB_Message {
const { extra, ...messageData } = message;

Expand Down
6 changes: 3 additions & 3 deletions src/database/models/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class _SessionModel extends BaseModel {
async create(type: 'agent' | 'group', defaultValue: Partial<LobeAgentSession>, id = uuid()) {
const data = merge(DEFAULT_AGENT_LOBE_SESSION, { type, ...defaultValue });
const dataDB = this.mapToDB_Session(data);
return this._add(dataDB, id);
return this._addWithSync(dataDB, id);
}

async batchCreate(sessions: LobeAgentSession[]) {
Expand Down Expand Up @@ -104,7 +104,7 @@ class _SessionModel extends BaseModel {
}

async update(id: string, data: Partial<DB_Session>) {
return super._update(id, data);
return super._updateWithSync(id, data);
}

async updatePinned(id: string, pinned: boolean) {
Expand Down Expand Up @@ -231,7 +231,7 @@ class _SessionModel extends BaseModel {

const newSession = merge(session, { meta: { title: newTitle } });

return this._add(newSession, uuid());
return this._addWithSync(newSession, uuid());
}

async getPinnedSessions(): Promise<LobeSessions> {
Expand Down
4 changes: 2 additions & 2 deletions src/database/models/sessionGroup.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class _SessionGroupModel extends BaseModel {
}

async create(name: string, sort?: number, id = nanoid()) {
return this._add({ name, sort }, id);
return this._addWithSync({ name, sort }, id);
}
async batchCreate(groups: SessionGroups) {
return this._batchAdd(groups, { idGenerator: nanoid });
Expand All @@ -20,7 +20,7 @@ class _SessionGroupModel extends BaseModel {
}

async update(id: string, data: Partial<DB_SessionGroup>) {
return super._update(id, data);
return super._updateWithSync(id, data);
}

async delete(id: string, removeGroupItem: boolean = false) {
Expand Down
Loading

0 comments on commit 303ec63

Please sign in to comment.