// Copyright (c) Georg Jung. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.IO.Compression;
using System.Text.Json;
using System.Text.RegularExpressions;
using BenchmarkDotNet.Attributes;
using BenchmarkDotNet.Diagnosers;
using BERTTokenizers.Base;
using FastBertTokenizer;
using RustLibWrapper;

namespace Benchmarks;

[MemoryDiagnoser]
/*
[PerfCollectProfiler(performExtraBenchmarksRun: false)]
[EtwProfiler(performExtraBenchmarksRun: false)]
[EventPipeProfiler(EventPipeProfile.CpuSampling)] // for speedscope files
*/
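// Compares FastBertTokenizer against two alternatives on the same corpus: the managed
// BERTTokenizers package ("other lib") and the HuggingFace tokenizers Rust library
// accessed through RustLibWrapper. Typically run via BenchmarkDotNet from a Release
// build, e.g. (assuming a Program entry point that calls BenchmarkRunner.Run<TokenizeSpeed>()):
//   dotnet run -c Release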
public class TokenizeSpeed
{
    private readonly string[] _corpus;
    private readonly List<string> _otherLibCorpus;
    private readonly ConcreteUncasedTokenizer _otherLibTokenizer;
    private readonly BertTokenizer _tokenizer;
    private readonly int _maxSequenceLength;

    public TokenizeSpeed()
        : this("data/wiki-simple.json.br", "data/baai-bge-small-en-vocab.txt", "data/baai-bge-small-en-tokenizer.json", 512)
    {
    }
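
    // Loads the Brotli-compressed corpus, prepares a truncated/sanitized copy of each
    // article for the other lib (which can't handle the raw texts, see below) and
    // initializes all three tokenizers from the given vocabulary and tokenizer.json files.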
    public TokenizeSpeed(string corpusPath, string vocabTxtFile, string tokenizerJsonPath, int maxSequenceLength)
    {
        RustTokenizer.LoadTokenizer(tokenizerJsonPath, maxSequenceLength);
        using var fs = File.OpenRead(corpusPath);
        using var uncompress = new BrotliStream(fs, CompressionMode.Decompress);
        var dict = JsonSerializer.Deserialize<Dictionary<int, string>>(uncompress)!;
        _corpus = new string[dict.Count];
        _otherLibCorpus = new(dict.Count);
        var cnt = 0;
        foreach (var tx in dict.Values)
        {
            _corpus[cnt] = tx;

            // this preprocessing gives the other lib kind of an unfair advantage, but it throws otherwise
            var otherLib = tx.Substring(0, Math.Min(tx.Length, 1250)); // the other lib throws if the text is too long; 1250 works with 512 tokens, 1500 doesn't; 5000 works with 2048 tokens
            otherLib = Regex.Replace(otherLib, @"\s+", " "); // required due to bad whitespace processing of the other lib
            otherLib = Regex.Replace(otherLib, @"[^A-Za-z0-9\s\.\,;:\\/?!#$%()=+\-*\""'–_`<>&^@{}[\]\|~']+", string.Empty); // the other lib doesn't handle unknown characters
            _otherLibCorpus.Add(otherLib);
            cnt++;
        }

        _otherLibTokenizer = new(vocabTxtFile);
        _tokenizer = new();

        using var sr = File.OpenText(vocabTxtFile);
        _tokenizer.LoadVocabulary(sr, true);
        _maxSequenceLength = maxSequenceLength;
    }
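
    // Reference benchmark: tokenize the pre-sanitized corpus with the managed
    // BERTTokenizers package, allocating a fresh result per input.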
    [Benchmark]
    public IReadOnlyCollection<object> OtherLib()
    {
        List<object> res = new(_otherLibCorpus.Count);
        foreach (var text in _otherLibCorpus)
        {
            res.Add(_otherLibTokenizer.Encode(_maxSequenceLength, text));
        }

        return res;
    }
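
    // Reference benchmark: single-threaded HuggingFace tokenizers (Rust) via
    // RustTokenizer, reusing the same output buffers for every input.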
    [Benchmark]
    public object RustHuggingfaceWrapperSinglethreadedMemReuse()
    {
        var inputIds = new uint[_maxSequenceLength];
        var attMask = new uint[_maxSequenceLength];
        foreach (var text in _otherLibCorpus)
        {
            RustTokenizer.TokenizeAndGetIds(text, inputIds.AsSpan(), attMask.AsSpan());
        }

        return (inputIds, attMask);
    }
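
    // Baseline: FastBertTokenizer on a single thread, allocating new result
    // buffers for every input.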
    [Benchmark(Baseline = true)]
    public IReadOnlyCollection<object> FastBertTokenizerSinglethreadedAllocating()
    {
        List<object> res = new(_corpus.Length);
        foreach (var text in _corpus)
        {
            res.Add(_tokenizer.Tokenize(text, _maxSequenceLength));
        }

        return res;
    }
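
    // FastBertTokenizer on a single thread, writing into preallocated input id /
    // attention mask buffers instead of allocating per call.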
    [Benchmark]
    public object FastBertTokenizerSingleThreadedMemReuse()
    {
        var iids = new long[_maxSequenceLength];
        var attm = new long[_maxSequenceLength];
        var toktyp = new long[_maxSequenceLength];
        Array.Fill(toktyp, 0);
        foreach (var text in _corpus)
        {
            _tokenizer.Tokenize(text, iids, attm);
        }

        return (iids, attm, toktyp);
    }
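
    // FastBertTokenizer with PLINQ parallelism, allocating a result per input.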
    [Benchmark]
    public IReadOnlyCollection<(Memory<long> InputIds, Memory<long> AttentionMask, Memory<long> TokenTypeIds)> FastBertTokenizerMultithreadedAllocating()
    {
        // this might be interesting to benchmark but doesn't make much sense as a real world use case
        List<(Memory<long> InputIds, Memory<long> AttentionMask, Memory<long> TokenTypeIds)> res = new(_corpus.Length);
        var x = _corpus.AsParallel().AsOrdered().Select(x => _tokenizer.Tokenize(x, _maxSequenceLength));
        res.AddRange(x);
        return res;
    }
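
    // FastBertTokenizer's batched, multithreaded overload: tokenizes the corpus in
    // batches of 1000 inputs into one large preallocated buffer per output tensor.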
    [Benchmark]
    public (Memory<long> InputIds, Memory<long> AttentionMask, Memory<long> TokenTypeIds) FastBertTokenizerMultithreadedMemReuse()
    {
        var batchSize = 1000;
        var iids = new long[_maxSequenceLength * batchSize];
        var attm = new long[_maxSequenceLength * batchSize];
        var toktyp = new long[_maxSequenceLength * batchSize];
        Array.Fill(toktyp, 0);
        var corpMem = _corpus.AsMemory();
        for (var i = 0; i < corpMem.Length; i += batchSize)
        {
            var len = Math.Min(batchSize, corpMem.Length - i);
            var batchSeqLen = _maxSequenceLength * len;
            var iidsM = iids.AsMemory(0, batchSeqLen);
            var attmM = attm.AsMemory(0, batchSeqLen);
            _tokenizer.Tokenize(corpMem.Slice(i, len), iidsM, attmM, _maxSequenceLength);
        }

        return (iids.AsMemory(), attm.AsMemory(), toktyp.AsMemory());
    }
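
    // Minimal concrete wrapper around the other lib's UncasedTokenizer base class
    // so it can be constructed with an arbitrary vocabulary file.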
    private sealed class ConcreteUncasedTokenizer : UncasedTokenizer
    {
        public ConcreteUncasedTokenizer(string vocabularyFilePath)
            : base(vocabularyFilePath)
        {
        }
    }
}