Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: perpetual memory WIP #267

Closed
wants to merge 13 commits into from
3 changes: 3 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ nlp:
server_url: "http://localhost:5557"
memory:
message_window: 12
perpetual:
# The number of messages to return alongside the summaries
last_n: 4
extractors:
documents:
embeddings:
Expand Down
8 changes: 7 additions & 1 deletion config/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,13 @@ type NLP struct {
}

type MemoryConfig struct {
MessageWindow int `mapstructure:"message_window"`
MessageWindow int `mapstructure:"message_window"`
Perpetual PerpetualMemoryConfig `mapstructure:"perpetual"`
}

type PerpetualMemoryConfig struct {
// LastN is the number of messages to return alongside the summaries
LastN int `mapstructure:"last_n"`
}

type PostgresConfig struct {
Expand Down
2 changes: 1 addition & 1 deletion docs/docs.go

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/swagger.json

Large diffs are not rendered by default.

19 changes: 17 additions & 2 deletions docs/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ definitions:
metadata:
additionalProperties: true
type: object
min_score:
description: 'TODO: implement for documents'
type: number
mmr_lambda:
type: number
search_type:
Expand Down Expand Up @@ -186,6 +189,10 @@ definitions:
metadata:
additionalProperties: true
type: object
summaries:
items:
$ref: '#/definitions/models.Summary'
type: array
summary:
$ref: '#/definitions/models.Summary'
type: object
Expand All @@ -194,6 +201,8 @@ definitions:
metadata:
additionalProperties: true
type: object
min_score:
type: number
mmr_lambda:
type: number
search_scope:
Expand Down Expand Up @@ -1155,6 +1164,13 @@ paths:
in: query
name: lastn
type: integer
- description: Memory type. Default is 'simple'
enum:
- simple
- perpetual
in: query
name: type
type: string
produces:
- application/json
responses:
Expand All @@ -1174,8 +1190,7 @@ paths:
$ref: '#/definitions/apihandlers.APIError'
security:
- Bearer: []
summary: Returns a memory (latest summary and list of messages) for a given
session
summary: Returns a memory for a given session
tags:
- memory
post:
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ require (
github.com/segmentio/asm v1.2.0 // indirect
github.com/shopspring/decimal v1.3.1 // indirect
github.com/sony/gobreaker v0.5.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.10.0 // indirect
github.com/spf13/cast v1.5.1 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,8 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sony/gobreaker v0.5.0 h1:dRCvqm0P490vZPmy7ppEk2qCnCieBooFJ+YoXGYB+yg=
github.com/sony/gobreaker v0.5.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.10.0 h1:EaGW2JJh15aKOejeuJ+wpFSHnbd7GE6Wvp3TsNhb6LY=
github.com/spf13/afero v1.10.0/go.mod h1:UBogFpq8E9Hx+xc5CNTTEpTnuHVmXDwZcZcE1eb/UhQ=
github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
Expand Down
7 changes: 5 additions & 2 deletions internal/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ package internal
import (
"bytes"
"reflect"
"strings"
"text/template"

"github.com/getzep/sprig/v3"
)

func ParsePrompt(promptTemplate string, data any) (string, error) {
tmpl, err := template.New("prompt").Parse(promptTemplate)
tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(promptTemplate)
if err != nil {
return "", err
}
Expand All @@ -18,7 +21,7 @@ func ParsePrompt(promptTemplate string, data any) (string, error) {
return "", err
}

return buf.String(), nil
return strings.TrimSpace(buf.String()), nil
}

func ReverseSlice[T any](slice []T) {
Expand Down
13 changes: 13 additions & 0 deletions pkg/models/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,16 @@ type Memory struct {
Summary *Summary `json:"summary,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}

type MemoryConfig struct {
SessionID string `json:"session_id"`
LastNMessages int `json:"last_n"`
Type MemoryType `json:"type"`
}

type MemoryType string

const (
SimpleMemoryType MemoryType = "simple"
PerpetualMemoryType MemoryType = "perpetual"
)
14 changes: 5 additions & 9 deletions pkg/models/memorystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,13 @@ type MessageStorer interface {
}

type MemoryStorer interface {
// GetMemory returns the most recent Summary and a list of messages for a given sessionID.
// GetMemory returns:
// - the most recent Summary, if one exists
// - the lastNMessages messages, if lastNMessages > 0
// - all messages since the last SummaryPoint, if lastNMessages == 0
// - if no Summary (and no SummaryPoint) exists and lastNMessages == 0, returns
// all undeleted messages
// GetMemory returns memory for a given sessionID.
// If config.Type is SimpleMemoryType, returns the most recent Summary and a list of messages.
// If config.Type is PerpetualMemoryType, returns the last X messages, optionally the most recent summary
// and a list of summaries semantically similar to the most recent messages.
GetMemory(ctx context.Context,
appState *AppState,
sessionID string,
lastNMessages int) (*Memory, error)
config *MemoryConfig) (*Memory, error)
// PutMemory stores a Memory for a given sessionID. If the SessionID doesn't exist, a new one is created.
PutMemory(ctx context.Context,
appState *AppState,
Expand Down
4 changes: 4 additions & 0 deletions pkg/models/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ type MemorySearchResult struct {

type MemorySearchPayload struct {
Text string `json:"text"`
Embedding []float32 `json:"embedding,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
SearchScope SearchScope `json:"search_scope,omitempty"`
SearchType SearchType `json:"search_type,omitempty"`
MMRLambda float32 `json:"mmr_lambda,omitempty"`
MinScore float32 `json:"min_score,omitempty"`
}

type DocumentSearchPayload struct {
Expand All @@ -37,6 +39,8 @@ type DocumentSearchPayload struct {
Metadata map[string]interface{} `json:"metadata,omitempty"`
SearchType SearchType `json:"search_type"`
MMRLambda float32 `json:"mmr_lambda,omitempty"`
// TODO: implement for documents
MinScore float32 `json:"min_score,omitempty"`
}

type DocumentSearchResult struct {
Expand Down
Loading
Loading