Skip to content

Commit

Permalink
feat: add zipzset
Browse files Browse the repository at this point in the history
  • Loading branch information
xgzlucario committed Nov 16, 2024
1 parent cbf3df7 commit d44517a
Show file tree
Hide file tree
Showing 15 changed files with 327 additions and 133 deletions.
49 changes: 31 additions & 18 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ func spopCommand(writer *resp.Writer, args []resp.RESP) {
func zaddCommand(writer *resp.Writer, args []resp.RESP) {
key := args[0]
args = args[1:]
zset, err := fetchZSet(key, true)
zs, err := fetchZSet(key, true)
if err != nil {
writer.WriteError(err)
return
Expand All @@ -435,7 +435,7 @@ func zaddCommand(writer *resp.Writer, args []resp.RESP) {
return
}
key := args[i+1].ToString()
if zset.Set(key, score) {
if zs.Set(key, score) {
count++
}
}
Expand All @@ -445,12 +445,12 @@ func zaddCommand(writer *resp.Writer, args []resp.RESP) {
func zrankCommand(writer *resp.Writer, args []resp.RESP) {
key := args[0]
member := args[1].ToStringUnsafe()
zset, err := fetchZSet(key)
zs, err := fetchZSet(key)
if err != nil {
writer.WriteError(err)
return
}
rank, _ := zset.Rank(member)
rank := zs.Rank(member)
if rank < 0 {
writer.WriteNull()
} else {
Expand All @@ -460,14 +460,14 @@ func zrankCommand(writer *resp.Writer, args []resp.RESP) {

func zremCommand(writer *resp.Writer, args []resp.RESP) {
key := args[0]
zset, err := fetchZSet(key)
zs, err := fetchZSet(key)
if err != nil {
writer.WriteError(err)
return
}
var count int
for _, arg := range args[1:] {
if zset.Remove(arg.ToStringUnsafe()) {
if zs.Remove(arg.ToStringUnsafe()) {
count++
}
}
Expand All @@ -486,29 +486,37 @@ func zrangeCommand(writer *resp.Writer, args []resp.RESP) {
writer.WriteError(err)
return
}
zset, err := fetchZSet(key)
zs, err := fetchZSet(key)
if err != nil {
writer.WriteError(err)
return
}

if stop == -1 {
stop = zset.Len()
stop = zs.Len()
}
start = min(start, stop)

withScores := len(args) == 4 && equalFold(args[3].ToStringUnsafe(), WithScores)
if withScores {
writer.WriteArrayHead((stop - start) * 2)
zset.Range(start, stop, func(key string, score float64) {
writer.WriteBulkString(key)
writer.WriteFloat(score)
zs.Scan(func(key string, score float64) {
if start <= 0 && stop >= 0 {
writer.WriteBulkString(key)
writer.WriteFloat(score)
}
start--
stop--
})

} else {
writer.WriteArrayHead(stop - start)
zset.Range(start, stop, func(key string, _ float64) {
writer.WriteBulkString(key)
zs.Scan(func(key string, _ float64) {
if start <= 0 && stop >= 0 {
writer.WriteBulkString(key)
}
start--
stop--
})
}
}
Expand All @@ -524,13 +532,11 @@ func zpopminCommand(writer *resp.Writer, args []resp.RESP) {
return
}
}

zs, err := fetchZSet(key)
if err != nil {
writer.WriteError(err)
return
}

size := min(zs.Len(), count)
writer.WriteArrayHead(size * 2)
for range size {
Expand Down Expand Up @@ -578,7 +584,7 @@ func evalCommand(writer *resp.Writer, args []resp.RESP) {
writer.WriteBulkString(res.String())

case lua.LNumber:
writer.WriteInteger(int(res)) // convert to integer
writer.WriteInteger(int(res))

case *lua.LTable:
writer.WriteArrayHead(res.Len())
Expand Down Expand Up @@ -610,7 +616,7 @@ func fetchSet(key []byte, setnx ...bool) (Set, error) {
}

func fetchZSet(key []byte, setnx ...bool) (ZSet, error) {
return fetch(key, func() ZSet { return zset.New() }, setnx...)
return fetch(key, func() ZSet { return zset.NewZipZSet() }, setnx...)
}

func fetch[T any](key []byte, new func() T, setnx ...bool) (T, error) {
Expand All @@ -628,11 +634,14 @@ func fetch[T any](key []byte, new func() T, setnx ...bool) (T, error) {
if data.Len() >= 256 {
db.dict.Set(string(key), data.ToMap())
}

case *hash.ZipSet:
if data.Len() >= 512 {
db.dict.Set(string(key), data.ToSet())
}
case *zset.ZipZSet:
if data.Len() >= 256 {
db.dict.Set(string(key), data.ToZSet())
}
}
}
return v, nil
Expand Down Expand Up @@ -666,6 +675,10 @@ func getObjectType(object any) ObjectType {
return TypeList
case *zset.ZSet:
return TypeZSet
case *zset.ZipZSet:
return TypeZipZSet
default:
panic("unknown type")
}
return TypeUnknown
}
17 changes: 10 additions & 7 deletions const.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"github.com/redis/go-redis/v9"
"github.com/xgzlucario/rotom/internal/hash"
"github.com/xgzlucario/rotom/internal/iface"
"github.com/xgzlucario/rotom/internal/list"
Expand All @@ -19,10 +20,11 @@ const (
TypeZipSet
TypeList
TypeZSet
TypeZipZSet
)

const (
TTL_FOREVER = -1
KeepTTL = redis.KeepTTL
KEY_NOT_EXIST = -2
)

Expand All @@ -34,10 +36,11 @@ const (

// type2c is objectType to new encoder.
var type2c = map[ObjectType]func() iface.Encoder{
TypeMap: func() iface.Encoder { return hash.NewMap() },
TypeZipMap: func() iface.Encoder { return hash.NewZipMap() },
TypeSet: func() iface.Encoder { return hash.NewSet() },
TypeZipSet: func() iface.Encoder { return hash.NewZipSet() },
TypeList: func() iface.Encoder { return list.New() },
TypeZSet: func() iface.Encoder { return zset.New() },
TypeMap: func() iface.Encoder { return hash.NewMap() },
TypeZipMap: func() iface.Encoder { return hash.NewZipMap() },
TypeSet: func() iface.Encoder { return hash.NewSet() },
TypeZipSet: func() iface.Encoder { return hash.NewZipSet() },
TypeList: func() iface.Encoder { return list.New() },
TypeZSet: func() iface.Encoder { return zset.New() },
TypeZipZSet: func() iface.Encoder { return zset.NewZipZSet() },
}
2 changes: 1 addition & 1 deletion dict.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (dict *Dict) Get(key string) (any, int) {

ts, ok := dict.expire.Get(key)
if !ok {
return data, TTL_FOREVER
return data, KeepTTL
}

// key expired
Expand Down
4 changes: 2 additions & 2 deletions dict_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestDict(t *testing.T) {
dict.Set("key", []byte("hello"))

data, ttl := dict.Get("key")
assert.Equal(ttl, TTL_FOREVER)
assert.Equal(ttl, KeepTTL)
assert.Equal(data, []byte("hello"))

data, ttl = dict.Get("none")
Expand All @@ -36,7 +36,7 @@ func TestDict(t *testing.T) {
res := dict.SetTTL("key", time.Now().Add(-time.Second).UnixNano())
assert.Equal(res, 1)

res = dict.SetTTL("not-exist", TTL_FOREVER)
res = dict.SetTTL("not-exist", KeepTTL)
assert.Equal(res, 0)

// get expired
Expand Down
61 changes: 39 additions & 22 deletions internal/hash/zipmap.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package hash

import (
"bytes"
"encoding/binary"
"github.com/xgzlucario/rotom/internal/iface"
"github.com/xgzlucario/rotom/internal/resp"
"unsafe"
Expand All @@ -10,46 +12,62 @@ import (

var _ iface.MapI = (*ZipMap)(nil)

// ZipMap store data as [val1, key1, val2, key2...] in listpack.
// ZipMap store data as [entryN, ..., entry1, entry0] in listpack.
type ZipMap struct {
data *list.ListPack
}

func NewZipMap() *ZipMap {
return &ZipMap{list.NewListPack()}
return &ZipMap{data: list.NewListPack()}
}

func (zm *ZipMap) seek(key string) (it *list.LpIterator, val []byte) {
func (zm *ZipMap) buildKey(key string) []byte {
entry := make([]byte, 0, 16)
entry = binary.AppendUvarint(entry, uint64(len(key)))
return append(entry, key...)
}

// entry store as [keyLen, key, val].
func (zm *ZipMap) encode(key string, val []byte) []byte {
return append(zm.buildKey(key), val...)
}

func (*ZipMap) decode(entry []byte) (string, []byte) {
klen, n := binary.Uvarint(entry)
key := entry[n : klen+uint64(n)]
val := entry[klen+uint64(n):]
return b2s(key), val
}

func (zm *ZipMap) seek(key string) (it *list.LpIterator, entry []byte) {
it = zm.data.Iterator().SeekLast()
prefix := zm.buildKey(key)
for !it.IsFirst() {
kBytes := it.Prev()
vBytes := it.Prev()
if key == b2s(kBytes) {
return it, vBytes
entry = it.Prev()
if bytes.HasPrefix(entry, prefix) {
return it, entry
}
}
return nil, nil
}

func (zm *ZipMap) Set(key string, val []byte) (newField bool) {
it, oldVal := zm.seek(key)
func (zm *ZipMap) Set(key string, val []byte) bool {
it, _ := zm.seek(key)
entry := b2s(zm.encode(key, val))
// update
if it != nil {
if len(val) == len(oldVal) {
copy(oldVal, val)
} else {
it.ReplaceNext(b2s(val))
}
it.ReplaceNext(entry)
return false
}
// insert
zm.data.RPush(b2s(val), key)
zm.data.RPush(entry)
return true
}

func (zm *ZipMap) Get(key string) ([]byte, bool) {
_, val := zm.seek(key)
if val != nil {
it, entry := zm.seek(key)
if it != nil {
_, val := zm.decode(entry)
return val, true
}
return nil, false
Expand All @@ -58,7 +76,7 @@ func (zm *ZipMap) Get(key string) ([]byte, bool) {
func (zm *ZipMap) Remove(key string) bool {
it, _ := zm.seek(key)
if it != nil {
it.RemoveNexts(2, nil)
it.RemoveNext()
return true
}
return false
Expand All @@ -67,9 +85,8 @@ func (zm *ZipMap) Remove(key string) bool {
func (zm *ZipMap) Scan(fn func(string, []byte)) {
it := zm.data.Iterator().SeekLast()
for !it.IsFirst() {
key := it.Prev()
val := it.Prev()
fn(b2s(key), val)
key, val := zm.decode(it.Prev())
fn(key, val)
}
}

Expand All @@ -81,7 +98,7 @@ func (zm *ZipMap) ToMap() *Map {
return m
}

func (zm *ZipMap) Len() int { return zm.data.Size() / 2 }
func (zm *ZipMap) Len() int { return zm.data.Size() }

func (zm *ZipMap) Encode(writer *resp.Writer) error {
return zm.data.Encode(writer)
Expand Down
2 changes: 1 addition & 1 deletion internal/hash/zipset.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (zs *ZipSet) Remove(key string) bool {
for !it.IsFirst() {
entry := it.Prev()
if key == b2s(entry) {
it.RemoveNexts(1, nil)
it.RemoveNext()
return true
}
}
Expand Down
4 changes: 4 additions & 0 deletions internal/iface/iface.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,9 @@ type ZSetI interface {
Encoder
Get(key string) (score float64, ok bool)
Set(key string, score float64) bool
Remove(key string) bool
Len() int
PopMin() (key string, score float64)
Rank(key string) int
Scan(fn func(key string, score float64))
}
Loading

0 comments on commit d44517a

Please sign in to comment.