-
Notifications
You must be signed in to change notification settings - Fork 1
/
db.go
134 lines (116 loc) · 3.14 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
package main
import (
	"database/sql"
	"encoding/json"
	"errors"
	"log"
	"os"

	"github.com/google/generative-ai-go/genai"
)
// Location of the SQLite database file. initDB builds the full path as
// HOME_PATH + DB_PATH + DB_NAME by plain string concatenation, which is why
// DB_PATH carries both a leading and a trailing slash.
const (
// DB_PATH is the directory, appended to HOME_PATH, that holds the database file.
DB_PATH = "/.local/share/gemini/"
// DB_NAME is the file name of the database inside that directory.
DB_NAME = "gemini.db"
)
// GeminiChatHistory describes one row of the gemini_chat_history table
// created by initDB: a single message belonging to a chat. The db tags carry
// the matching column names.
type GeminiChatHistory struct {
// ID is the row's INTEGER PRIMARY KEY; the insert statements in this file do not set it.
ID int64 `db:"id"`
// ChatID is the chat this message belongs to; GetByChatID filters on it.
ChatID int64 `db:"chat_id"`
// Prompt is the message content. GetByChatID decodes the stored value as a
// JSON array of strings.
Prompt string `db:"prompt"`
// Role is passed through unchanged as the Role of the genai.Content built by GetByChatID.
Role string `db:"role"`
// CreateTime corresponds to the create_time column, which defaults to
// CURRENT_TIMESTAMP; the insert statements in this file do not set it.
CreateTime string `db:"create_time"`
}
// GeminiChatList describes one row of the gemini_chat_list table created by
// initDB: one chat and its title. The db tags carry the matching column names.
type GeminiChatList struct {
// ID is the row's INTEGER PRIMARY KEY; GetLatestChatID orders by it to find the newest chat.
ID int64 `db:"id"`
// ChatID is the chat's identifier, the value returned by GetLatestChatID.
ChatID int64 `db:"chat_id"`
// ChatTitle is the title stored for the chat.
ChatTitle string `db:"chat_title"`
// CreateTime corresponds to the create_time column, which defaults to
// CURRENT_TIMESTAMP; InsertChat does not set it.
CreateTime string `db:"create_time"`
}
// DB wraps the database handle that the query methods in this file run
// their statements against.
type DB struct {
// SqliteDB is the handle obtained from sql.Open in initDB.
SqliteDB *sql.DB
}
// initDB opens the SQLite database located at HOME_PATH + DB_PATH + DB_NAME,
// creating the containing directory when it is missing, and makes sure the
// gemini_chat_history and gemini_chat_list tables exist.
//
// Every failure (creating the directory, opening the database, creating a
// table) is reported through log.Fatal, which terminates the process, so the
// returned *DB is always usable.
//
// NOTE(review): the "sqlite3" driver is not imported in this file; it is
// presumably registered elsewhere in the program.
func initDB() *DB {
	dbDir := HOME_PATH + DB_PATH

	// MkdirAll succeeds without changes when the directory already exists,
	// so no os.Stat pre-check is needed. Its error is checked here so that a
	// directory that cannot be created is reported directly instead of
	// surfacing later as a less specific database error.
	if err := os.MkdirAll(dbDir, os.ModePerm); err != nil {
		log.Fatal(err)
	}

	sqliteDB, err := sql.Open("sqlite3", dbDir+DB_NAME)
	if err != nil {
		log.Fatal(err)
	}

	// Schema statements are idempotent (IF NOT EXISTS), so they run on every start.
	schema := []string{
		`CREATE TABLE IF NOT EXISTS gemini_chat_history (
			id INTEGER PRIMARY KEY,
			chat_id INTEGER,
			prompt TEXT,
			role TEXT,
			create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
		)`,
		`CREATE TABLE IF NOT EXISTS gemini_chat_list (
			id INTEGER PRIMARY KEY,
			chat_id INTEGER,
			chat_title TEXT,
			create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
		)`,
	}
	for _, stmt := range schema {
		if _, err := sqliteDB.Exec(stmt); err != nil {
			log.Fatal(err)
		}
	}

	return &DB{SqliteDB: sqliteDB}
}
// InsertHistory stores one chat message in gemini_chat_history using the
// shared database handle.
//
// Only chat.ChatID, chat.Prompt and chat.Role are written; chat.ID and
// chat.CreateTime are ignored. The error from the underlying Exec call is
// returned unchanged.
func (db *DB) InsertHistory(chat GeminiChatHistory) error {
	const insertStmt = `INSERT INTO gemini_chat_history (chat_id, prompt, role) VALUES (?, ?, ?)`

	if _, err := db.SqliteDB.Exec(insertStmt, chat.ChatID, chat.Prompt, chat.Role); err != nil {
		return err
	}
	return nil
}
// InsertHistoryWithTX stores one chat message in gemini_chat_history as part
// of the caller's transaction tx; the receiver's own handle is not used.
//
// Only chat.ChatID, chat.Prompt and chat.Role are written. Committing or
// rolling back tx is left to the caller. The error from the underlying Exec
// call is returned unchanged.
func (db *DB) InsertHistoryWithTX(tx *sql.Tx, chat GeminiChatHistory) error {
	const insertStmt = `INSERT INTO gemini_chat_history (chat_id, prompt, role) VALUES (?, ?, ?)`

	if _, err := tx.Exec(insertStmt, chat.ChatID, chat.Prompt, chat.Role); err != nil {
		return err
	}
	return nil
}
// GetLatestChatID returns the chat_id of the most recently inserted row of
// gemini_chat_list, i.e. the row with the highest id.
//
// When the table is empty it returns 0 and a nil error. Any other query or
// scan failure is returned as the error, with 0 as the ID.
func (db *DB) GetLatestChatID() (int, error) {
	var chatID int
	err := db.SqliteDB.QueryRow(`SELECT chat_id FROM gemini_chat_list ORDER BY id DESC LIMIT 1`).Scan(&chatID)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// An empty table is not an error. Match on the sentinel rather than
		// on the message text, which is not a stable contract.
		return 0, nil
	case err != nil:
		return 0, err
	}
	return chatID, nil
}
// InsertChat adds one chat to gemini_chat_list.
//
// Only chat.ChatID and chat.ChatTitle are written; chat.ID and
// chat.CreateTime are ignored. The error from the underlying Exec call is
// returned unchanged.
func (db *DB) InsertChat(chat GeminiChatList) error {
	const insertStmt = `INSERT INTO gemini_chat_list (chat_id, chat_title) VALUES (?, ?)`

	if _, err := db.SqliteDB.Exec(insertStmt, chat.ChatID, chat.ChatTitle); err != nil {
		return err
	}
	return nil
}
// GetByChatID loads every message stored for chatId from gemini_chat_history
// and converts each row into a *genai.Content, in insertion (id) order.
//
// The stored prompt is decoded by parsePrompt and the stored role is used
// unchanged. A chat without messages yields an empty, non-nil slice. Query,
// scan and row-iteration failures are returned as the error with a nil slice.
//
// Note that parsePrompt ends the process via log.Fatal when a stored prompt
// is not valid JSON, so such a row never surfaces here as a returned error.
func (db *DB) GetByChatID(chatId int) ([]*genai.Content, error) {
	// ORDER BY id makes the message order explicit; without it SQL leaves the
	// row order unspecified.
	rows, err := db.SqliteDB.Query(`SELECT prompt,role FROM gemini_chat_history WHERE chat_id = ? ORDER BY id`, chatId)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	chatHistoryList := make([]*genai.Content, 0)
	for rows.Next() {
		var prompt, role string
		if err := rows.Scan(&prompt, &role); err != nil {
			return nil, err
		}
		chatHistoryList = append(chatHistoryList, &genai.Content{
			Parts: parsePrompt(prompt),
			Role:  role,
		})
	}
	// rows.Next also returns false when iteration fails part-way; without
	// this check a truncated history would be returned as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return chatHistoryList, nil
}
// parsePrompt decodes prompt, expected to be a JSON array of strings, and
// wraps each element in a genai.Text part, keeping the original order.
//
// The returned slice is never nil, even when the array is empty. A prompt
// that cannot be decoded is reported through log.Fatal, which terminates the
// process.
func parsePrompt(prompt string) []genai.Part {
	// Decode the prompt array.
	var texts []string
	if err := json.Unmarshal([]byte(prompt), &texts); err != nil {
		log.Fatal(err)
	}

	parts := make([]genai.Part, len(texts))
	for i, text := range texts {
		parts[i] = genai.Text(text)
	}
	return parts
}