Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace js-tiktoken BPE merge algorithm with faster heap-based algorithm #101

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 178 additions & 17 deletions js/src/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,198 @@ import base64 from "base64-js";
import type { TiktokenModel } from "./ranks/ranks";
import { never } from "./utils";

/**
 * A contiguous part of the piece being merged. Each node is simultaneously an
 * element of a doubly linked list of parts and an entry of the merge heap.
 */
type BPEMergeNode = {
  /** Offset of the first byte of this part (used as the slice start). */
  start: number;
  /** Offset one past the last byte of this part (used as the slice end). */
  end: number;
  /** Rank the heap orders this node by: the rank of this part joined with its successor. */
  rank: number;

  /** Part immediately to the left in the list, or null at the head. */
  listPrev: BPEMergeNode | null;
  /** Part immediately to the right in the list, or null at the tail. */
  listNext: BPEMergeNode | null;

  /** True while the node is not stored in the heap. */
  removed: boolean;
  /** Lazy-deletion flag: a popped node carrying it is discarded instead of merged. */
  deleted: boolean;
  /** Lazy-update flag: a popped node carrying it is re-pushed with `updatedRank`. */
  updated: boolean;
  /** Rank to adopt when a pending `updated` flag is processed. */
  updatedRank: number;
};

/**
 * Heap ordering for merge nodes: the node with the lower rank sorts first and
 * nodes of equal rank are ordered by their start offset.
 *
 * @returns a negative number when `a` sorts before `b`, a positive number when
 * it sorts after, and 0 when both rank and start are equal.
 */
function compareNode(a: BPEMergeNode, b: BPEMergeNode) {
  const rankDelta = a.rank - b.rank;
  // A falsy delta (equal ranks) defers to the position tie-break.
  if (rankDelta) return rankDelta;
  return a.start - b.start;
}

/** Exchanges the heap entries stored at positions `i` and `j`, in place. */
function swap(heap: BPEMergeNode[], i: number, j: number) {
  [heap[i], heap[j]] = [heap[j], heap[i]];
}

/**
 * Inserts `part` into the binary min-heap `heap` (ordered by `compareNode`).
 *
 * The node is appended as the last leaf and then sifted up: while it compares
 * lower than its parent, the two are exchanged.
 */
function heapPush(heap: BPEMergeNode[], part: BPEMergeNode) {
  // Array.prototype.push returns the new length, so the new leaf sits at length - 1.
  let child = heap.push(part) - 1;

  while (child > 0) {
    const parent = (child - 1) >>> 1;
    // Stop as soon as the child no longer sorts strictly before its parent.
    if (!(compareNode(heap[child], heap[parent]) < 0)) break;
    swap(heap, child, parent);
    child = parent;
  }
}

/**
 * Removes and returns the lowest-ordered node (per `compareNode`) of the
 * binary min-heap `heap`.
 *
 * @returns the former root, or `undefined` when the heap is empty.
 */
function heapPop(heap: BPEMergeNode[]) {
  if (heap.length === 0) return undefined;

  const top = heap[0];
  const tail = heap.pop();

  // The popped leaf was the root itself: nothing is left to reorder.
  if (heap.length === 0 || !tail) return top;

  // Put the former last leaf at the root and sift it down to its place.
  heap[0] = tail;
  const size = heap.length;
  let hole = 0;

  for (;;) {
    const left = 2 * hole + 1;
    const right = left + 1;
    let best = hole;

    if (left < size && compareNode(heap[left], heap[best]) < 0) best = left;
    if (right < size && compareNode(heap[right], heap[best]) < 0) best = right;

    // Neither child sorts before the current node: heap order is restored.
    if (best === hole) return top;

    swap(heap, hole, best);
    hole = best;
  }
}

/**
 * Splits `piece` into parts by repeatedly merging the adjacent pair of parts
 * whose concatenated bytes have the lowest rank in `ranks` (the leftmost pair
 * wins ties), until no adjacent pair has a rank.
 *
 * Parts form a doubly linked list, so a merge is O(1), and merge candidates
 * live in a binary min-heap. Heap entries are never edited in place: every
 * part carries a `stamp` that is bumped whenever the pair it starts changes or
 * the part is absorbed, and a popped candidate whose stamp no longer matches
 * is discarded. This keeps the pop order identical to rescanning all pairs for
 * the minimum after every merge.
 *
 * @param piece bytes to split.
 * @param ranks rank lookup keyed by the byte values joined with ",".
 * @returns the resulting parts as `[start, end)` offsets into `piece`, in
 * order from left to right; empty when `piece` is empty.
 */
function bytePairMerge(
  piece: Uint8Array,
  ranks: Map<string, number>
): Array<{ start: number; end: number }> {
  type Part = {
    start: number;
    end: number;
    prev: Part | null;
    next: Part | null;
    stamp: number;
  };
  type Candidate = { rank: number; stamp: number; part: Part };

  if (piece.length === 0) return [];

  // One single-byte part per input byte, linked left to right.
  const parts: Part[] = Array.from({ length: piece.length }, (_, i) => ({
    start: i,
    end: i + 1,
    prev: null,
    next: null,
    stamp: 0,
  }));
  for (let i = 0; i < parts.length; ++i) {
    parts[i].prev = parts[i - 1] ?? null;
    parts[i].next = parts[i + 1] ?? null;
  }

  const heap: Candidate[] = [];

  // Lower rank first; equal ranks resolve to the leftmost pair.
  const sortsBefore = (a: Candidate, b: Candidate): boolean =>
    a.rank !== b.rank ? a.rank < b.rank : a.part.start < b.part.start;

  const push = (candidate: Candidate): void => {
    let child = heap.push(candidate) - 1;
    while (child > 0) {
      const parent = (child - 1) >>> 1;
      if (!sortsBefore(heap[child], heap[parent])) break;
      [heap[child], heap[parent]] = [heap[parent], heap[child]];
      child = parent;
    }
  };

  const pop = (): Candidate | undefined => {
    const top: Candidate | undefined = heap[0];
    const tail = heap.pop();
    if (tail === undefined || heap.length === 0) return top;

    heap[0] = tail;
    let hole = 0;
    for (;;) {
      const left = 2 * hole + 1;
      const right = left + 1;
      let best = hole;
      if (left < heap.length && sortsBefore(heap[left], heap[best])) best = left;
      if (right < heap.length && sortsBefore(heap[right], heap[best])) best = right;
      if (best === hole) return top;
      [heap[hole], heap[best]] = [heap[best], heap[hole]];
      hole = best;
    }
  };

  // Queues the pair (part, part.next) when its bytes have a rank.
  const enqueue = (part: Part): void => {
    if (part.next === null) return;
    const rank = ranks.get(piece.slice(part.start, part.next.end).join(","));
    if (rank != null) push({ rank, stamp: part.stamp, part });
  };

  for (const part of parts) enqueue(part);

  for (let top = pop(); top !== undefined; top = pop()) {
    const { part } = top;
    // Stale candidate: the pair changed or the part was absorbed after it was queued.
    if (top.stamp !== part.stamp) continue;

    const absorbed = part.next;
    if (absorbed === null) continue;

    // Merge the right neighbour into this part and unlink it.
    part.end = absorbed.end;
    part.next = absorbed.next;
    if (absorbed.next !== null) absorbed.next.prev = part;

    // Invalidate every candidate queued for the two parts involved.
    absorbed.stamp++;
    part.stamp++;
    enqueue(part);

    // The left neighbour now pairs with a longer part, so its rank must be recomputed.
    if (part.prev !== null) {
      part.prev.stamp++;
      enqueue(part.prev);
    }
  }

  // The first part is never absorbed, so it is always the head of the list.
  const result: Array<{ start: number; end: number }> = [];
  for (let node: Part | null = parts[0]; node !== null; node = node.next) {
    result.push({ start: node.start, end: node.end });
  }
  return result;
}

function bytePairEncode(piece: Uint8Array, ranks: Map<string, number>) {
Expand Down
1 change: 1 addition & 0 deletions js/test/fixtures/evil-string.ts

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions js/test/perf.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import { test, expect, describe, afterAll } from "vitest";
import { get_encoding } from "../../wasm/dist";
import { getEncoding } from "../src/index";
import { EVIL_STRING } from "./fixtures/evil-string";

// Regression guard: each encoder must tokenize EVIL_STRING within the time budget.
describe("Tokenizer resolves in acceptable time", () => {
  // Upper bound, in milliseconds, for a single encode call.
  const maxDurationMs = 5000;

  // Encoder imported from the wasm build.
  const full = get_encoding("cl100k_base");

  // Call free() on the wasm encoder once all tests in this suite have run.
  afterAll(() => full.free());

  test("Test wasm performance", () => {
    const start = Date.now();
    full.encode(EVIL_STRING);
    const end = Date.now();
    expect(end - start).toBeLessThanOrEqual(maxDurationMs);
  });

  // Encoder imported from the package's own TypeScript sources.
  const lite = getEncoding("cl100k_base");

  test("Test js performance", () => {
    const start = Date.now();
    lite.encode(EVIL_STRING);
    const end = Date.now();
    expect(end - start).toBeLessThanOrEqual(maxDurationMs);
  });
});
Loading