forked from pkoukk/tiktoken-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bpe.go
75 lines (67 loc) · 1.52 KB
/
bpe.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
package tiktoken
import (
"math"
)
// bytePairMerge repeatedly merges the adjacent pair of parts of piece whose
// concatenated bytes have the lowest rank in ranks. It stops when no adjacent
// pair's concatenation appears in ranks. It then calls f once per remaining
// part with that part's half-open byte range [start, end) within piece, and
// returns the results in order. An empty piece yields an empty slice.
//
// Every entry present in ranks is mergeable, whatever its value, except a
// rank equal to math.MaxInt, which is reserved as the "no rank" sentinel.
//
// Each merge scans all remaining parts for the minimum rank, so m merges over
// an n-byte piece cost O(m*n).
func bytePairMerge[T any](piece []byte, ranks map[string]int, f func(start, end int) T) []T {
	// noRank marks a pair whose bytes are not in ranks. It is also the
	// starting value of the minimum search, so such pairs are never chosen.
	const noRank = math.MaxInt

	// parts[i][0] is the byte offset where part i starts, and parts[i][1] is
	// the rank of the pair formed by part i and part i+1. The final entry only
	// marks the end of piece; its rank is never selected.
	parts := make([][2]int, len(piece)+1)
	for i := range parts {
		parts[i] = [2]int{i, noRank}
	}

	// rankOf returns the rank of the bytes from the start of parts[startIdx]
	// up to the start of parts[startIdx+skip+2], or noRank when that span runs
	// past the end of parts or is absent from ranks. skip=1 lets the merge step
	// price the pair that will exist once parts[startIdx+1] has been removed.
	rankOf := func(startIdx, skip int) int {
		endIdx := startIdx + skip + 2
		if endIdx >= len(parts) {
			return noRank
		}
		if rank, ok := ranks[string(piece[parts[startIdx][0]:parts[endIdx][0]])]; ok {
			return rank
		}
		return noRank
	}

	// Price every initial two-byte pair once. The merge loop below then only
	// updates the ranks around each merge.
	for i := 0; i+2 < len(parts); i++ {
		parts[i][1] = rankOf(i, 0)
	}

	for len(parts) > 1 {
		minRank, minIdx := noRank, -1
		for i, p := range parts[:len(parts)-1] {
			if p[1] < minRank {
				minRank, minIdx = p[1], i
			}
		}
		if minIdx < 0 {
			break // no adjacent pair can be merged any further
		}

		// Merge parts[minIdx] with parts[minIdx+1]. Recompute the ranks of the
		// pairs now starting at minIdx and minIdx-1 (skipping the entry that is
		// about to disappear), then drop parts[minIdx+1].
		parts[minIdx][1] = rankOf(minIdx, 1)
		if minIdx > 0 {
			parts[minIdx-1][1] = rankOf(minIdx-1, 1)
		}
		parts = append(parts[:minIdx+1], parts[minIdx+2:]...)
	}

	out := make([]T, len(parts)-1)
	for i := range out {
		out[i] = f(parts[i][0], parts[i+1][0])
	}
	return out
}
// bytePairEncode returns the ranks (token values) of the parts that
// bytePairMerge leaves when piece is merged using ranks.
//
// Each resulting span is read with a plain map index, so a span that is not
// present in ranks comes back as 0 rather than as an error.
// NOTE(review): callers presumably guarantee every single byte is in ranks;
// confirm, since a missing byte would silently become token 0.
func bytePairEncode(piece []byte, ranks map[string]int) []int {
	switch len(piece) {
	case 1:
		// A lone byte cannot be merged with anything; look it up directly.
		return []int{ranks[string(piece)]}
	default:
		tokenOf := func(start, end int) int {
			return ranks[string(piece[start:end])]
		}
		return bytePairMerge(piece, ranks, tokenOf)
	}
}