Skip to content

Commit

Permalink
Merge pull request #140 from o1-labs/perf/bigint-backend-3
Browse files Browse the repository at this point in the history
Bigint backend pt 3: Move ec point conversion to TS
  • Loading branch information
mitschabaude authored Sep 22, 2023
2 parents 03e1b6e + d7d58ac commit 5174c22
Show file tree
Hide file tree
Showing 10 changed files with 408 additions and 373 deletions.
2 changes: 1 addition & 1 deletion MINA_COMMIT
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
The mina commit used to generate the backends for node and web is
366a3606d8c99a807b9f398d13e9571ed5b26cb1
2892c4351234f0bb1913f3e6d5b2469d86881ff5
16 changes: 8 additions & 8 deletions compiled/node_bindings/plonk_wasm.cjs
Original file line number Diff line number Diff line change
Expand Up @@ -8665,14 +8665,6 @@ module.exports.__wbindgen_is_object = function(arg0) {
return ret;
};

module.exports.__wbg_randomFillSync_6894564c2c334c42 = function() { return handleError(function (arg0, arg1, arg2) {
getObject(arg0).randomFillSync(getArrayU8FromWasm0(arg1, arg2));
}, arguments) };

module.exports.__wbg_getRandomValues_805f1c3d65988a5a = function() { return handleError(function (arg0, arg1) {
getObject(arg0).getRandomValues(getObject(arg1));
}, arguments) };

module.exports.__wbg_crypto_e1d53a1d73fb10b8 = function(arg0) {
const ret = getObject(arg0).crypto;
return addHeapObject(ret);
Expand Down Expand Up @@ -8713,6 +8705,14 @@ module.exports.__wbindgen_is_function = function(arg0) {
return ret;
};

module.exports.__wbg_getRandomValues_805f1c3d65988a5a = function() { return handleError(function (arg0, arg1) {
getObject(arg0).getRandomValues(getObject(arg1));
}, arguments) };

module.exports.__wbg_randomFillSync_6894564c2c334c42 = function() { return handleError(function (arg0, arg1, arg2) {
getObject(arg0).randomFillSync(getArrayU8FromWasm0(arg1, arg2));
}, arguments) };

module.exports.__wbg_get_27fe3dac1c4d0224 = function(arg0, arg1) {
const ret = getObject(arg0)[arg1 >>> 0];
return addHeapObject(ret);
Expand Down
Binary file modified compiled/node_bindings/plonk_wasm_bg.wasm
Binary file not shown.
230 changes: 111 additions & 119 deletions compiled/node_bindings/snarky_js_node.bc.cjs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion compiled/node_bindings/snarky_js_node.bc.map

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions compiled/web_bindings/plonk_wasm.js
Original file line number Diff line number Diff line change
Expand Up @@ -8618,12 +8618,6 @@ function getImports() {
const ret = typeof(val) === 'object' && val !== null;
return ret;
};
imports.wbg.__wbg_randomFillSync_6894564c2c334c42 = function() { return handleError(function (arg0, arg1, arg2) {
getObject(arg0).randomFillSync(getArrayU8FromWasm0(arg1, arg2));
}, arguments) };
imports.wbg.__wbg_getRandomValues_805f1c3d65988a5a = function() { return handleError(function (arg0, arg1) {
getObject(arg0).getRandomValues(getObject(arg1));
}, arguments) };
imports.wbg.__wbg_crypto_e1d53a1d73fb10b8 = function(arg0) {
const ret = getObject(arg0).crypto;
return addHeapObject(ret);
Expand Down Expand Up @@ -8656,6 +8650,12 @@ function getImports() {
const ret = typeof(getObject(arg0)) === 'function';
return ret;
};
imports.wbg.__wbg_getRandomValues_805f1c3d65988a5a = function() { return handleError(function (arg0, arg1) {
getObject(arg0).getRandomValues(getObject(arg1));
}, arguments) };
imports.wbg.__wbg_randomFillSync_6894564c2c334c42 = function() { return handleError(function (arg0, arg1, arg2) {
getObject(arg0).randomFillSync(getArrayU8FromWasm0(arg1, arg2));
}, arguments) };
imports.wbg.__wbg_get_27fe3dac1c4d0224 = function(arg0, arg1) {
const ret = getObject(arg0)[arg1 >>> 0];
return addHeapObject(ret);
Expand Down
Binary file modified compiled/web_bindings/plonk_wasm_bg.wasm
Binary file not shown.
12 changes: 6 additions & 6 deletions compiled/web_bindings/snarky_js_web.bc.js

Large diffs are not rendered by default.

198 changes: 189 additions & 9 deletions crypto/bindings/conversion.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/**
* Implementation of Kimchi_bindings.Protocol.Gates
*/
import { MlArray } from '../../../lib/ml/base.js';
import { MlArray, MlOption, MlTuple } from '../../../lib/ml/base.js';
import { mapTuple } from './util.js';
import type {
WasmFpGate,
Expand All @@ -10,9 +10,15 @@ import type {
import type * as wasmNamespace from '../../compiled/node_bindings/plonk_wasm.cjs';
import { bigIntToBytes, bytesToBigInt } from '../bigint-helpers.js';

export { createRustConversion };

type Field = Uint8Array;

export { createRustConversion };
// Kimchi_types.or_infinity
type Infinity = 0;
const Infinity = 0;
type Finite<T> = [0, T];
type OrInfinity = Infinity | Finite<MlTuple<Field, Field>>;

// ml types from kimchi_types.ml
type GateType = number;
Expand All @@ -24,6 +30,12 @@ type Gate = [
coeffs: MlArray<Field>
];

type PolyComm = [
_: 0,
unshifted: MlArray<OrInfinity>,
shifted: MlOption<OrInfinity>
];

type wasm = typeof wasmNamespace;

// TODO: Hardcoding this is a little brittle
Expand All @@ -35,9 +47,17 @@ function createRustConversion(wasm: wasm) {
return wasm.Wire.create(row, col);
}

function perField<WasmGate extends typeof WasmFpGate | typeof WasmFqGate>(
WasmGate: WasmGate
) {
function perField<WasmGate extends typeof WasmFpGate | typeof WasmFqGate>({
WasmGate,
WasmPolyComm,
CommitmentCurve,
makeCommitmentCurve,
}: {
WasmGate: WasmGate;
WasmPolyComm: WasmPolyCommClass;
CommitmentCurve: WrapperClass<WasmAffine>;
makeCommitmentCurve: MakeAffine<WasmAffine>;
}) {
return {
vectorToRust: fieldsToRustFlat,
vectorFromRust: fieldsFromRustFlat,
Expand All @@ -47,18 +67,59 @@ function createRustConversion(wasm: wasm) {
let rustCoeffs = fieldsToRustFlat(coeffs);
return new WasmGate(typ, rustWires, rustCoeffs);
},
pointToRust(point: OrInfinity) {
return affineToRust(point, makeCommitmentCurve);
},
pointFromRust(point: WasmAffine) {
return affineFromRust(point);
},
pointsToRust(points: MlArray<OrInfinity>) {
return mlArrayToRustVector(points, affineToRust, makeCommitmentCurve);
},
pointsFromRust(points: Uint32Array) {
return mlArrayFromRustVector(
points,
CommitmentCurve,
affineFromRust,
false
);
},
polyCommToRust(polyComm: PolyComm): WasmPolyComm {
return polyCommToRust(polyComm, WasmPolyComm, makeCommitmentCurve);
},
polyCommFromRust(polyComm: WasmPolyComm): PolyComm {
return polyCommFromRust(polyComm, CommitmentCurve, false);
},
};
}

const fpConversion = perField(wasm.WasmFpGate);
const fqConversion = perField(wasm.WasmFqGate);
// TODO: we have to lie about types here:
// -) the WasmGVesta class doesn't declare __wrap() but our code assumes it
// -) WasmGVesta doesn't declare the `ptr` property but our code assumes it

const fp = perField({
WasmGate: wasm.WasmFpGate,
WasmPolyComm: wasm.WasmFpPolyComm,
CommitmentCurve:
wasm.WasmGVesta as any as WrapperClass<wasmNamespace.WasmGVesta>,
makeCommitmentCurve:
wasm.caml_vesta_affine_one as MakeAffine<wasmNamespace.WasmGVesta>,
});
const fq = perField({
WasmGate: wasm.WasmFqGate,
WasmPolyComm: wasm.WasmFqPolyComm,
CommitmentCurve:
wasm.WasmGPallas as any as WrapperClass<wasmNamespace.WasmGPallas>,
makeCommitmentCurve:
wasm.caml_pallas_affine_one as MakeAffine<wasmNamespace.WasmGPallas>,
});

return {
wireToRust,
fieldsToRustFlat,
fieldsFromRustFlat,
fp: fpConversion,
fq: fqConversion,
fp,
fq,
gateFromRust(wasmGate: WasmFpGate | WasmFqGate) {
// note: this was never used and the old implementation was wrong
// (accessed non-existent fields on wasmGate)
Expand All @@ -67,6 +128,8 @@ function createRustConversion(wasm: wasm) {
};
}

// field, field vectors

// TODO make more performant
function fieldToRust(x: Field): Uint8Array {
return x;
Expand Down Expand Up @@ -98,3 +161,120 @@ function fieldsFromRustFlat(fieldBytes: Uint8Array): MlArray<Field> {
}
return [0, ...fields];
}

// affine

type WasmAffine = wasmNamespace.WasmGVesta | wasmNamespace.WasmGPallas;
type MakeAffine<A extends WasmAffine> = () => A & { ptr: number };

function affineFromRust(pt: WasmAffine): OrInfinity {
if (pt.infinity) {
pt.free();
return 0;
} else {
let x = fieldFromRust(pt.x);
let y = fieldFromRust(pt.y);
pt.free();
return [0, [0, x, y]];
}
}

function affineToRust<A extends WasmAffine>(
pt: OrInfinity,
klass: MakeAffine<A>
) {
var res = klass();
if (pt === Infinity) {
res.infinity = true;
} else {
let [, [, x, y]] = pt;
res.x = fieldToRust(x);
res.y = fieldToRust(y);
}
return res;
}

// polycomm

type WasmPolyComm = wasmNamespace.WasmFpPolyComm | wasmNamespace.WasmFqPolyComm;
type WasmPolyCommClass = wasm['WasmFpPolyComm'] | wasm['WasmFqPolyComm'];

function polyCommFromRust(
polyComm: WasmPolyComm,
klass: WrapperClass<WasmAffine>,
shouldFree: boolean
): PolyComm {
let rustShifted = polyComm.shifted;
let rustUnshifted = polyComm.unshifted;
let mlShifted = MlOption.mapTo(rustShifted, affineFromRust);
let mlUnshifted = mlArrayFromRustVector(
rustUnshifted,
klass,
affineFromRust,
shouldFree
);
return [0, mlUnshifted, mlShifted];
}

function polyCommToRust(
[, camlUnshifted, camlShifted]: PolyComm,
PolyComm: WasmPolyCommClass,
makeAffine: MakeAffine<WasmAffine>
): WasmPolyComm {
let rustShifted =
camlShifted === 0 ? undefined : affineToRust(camlShifted[1], makeAffine);
let rustUnshifted = mlArrayToRustVector(
camlUnshifted,
affineToRust,
makeAffine
);
return new PolyComm(rustUnshifted, rustShifted);
}

// generic rust helpers

type Freeable = { free(): void };
type WrappedPointer = Freeable & { ptr: number };
type WrapperClass<T extends Freeable> = {
__wrap(i: number): T;
};

const registry = new FinalizationRegistry((ptr: WrappedPointer) => {
ptr.free();
});

function mlArrayFromRustVector<TRust extends Freeable, TMl>(
rustVector: Uint32Array,
klass: WrapperClass<TRust>,
convert: (c: TRust) => TMl,
shouldFree: boolean
): MlArray<TMl> {
var n = rustVector.length;
var array: TMl[] = new Array(n);
for (let i = 0; i < n; i++) {
var rustValue = klass.__wrap(rustVector[i]);
array[i] = convert(rustValue);
if (shouldFree) rustValue.free();
}
return [0, ...array];
}

// TODO get rid of excessive indirection here

function mlArrayToRustVector<TRust extends WrappedPointer, TMl>(
[, ...array]: MlArray<TMl>,
convert: (c: TMl, makeNew: () => TRust) => TRust,
makeNew: () => TRust
): Uint32Array {
let n = array.length;
let rustVector = new Uint32Array(n);
for (var i = 0; i < n; i++) {
var rustValue = convert(array[i], makeNew);
// Beware: caller may need to do finalizer things to avoid these
// pointers disappearing out from under us.
rustVector[i] = rustValue.ptr;
// Don't free when GC runs; rust will free on its end.
registry.unregister(rustValue);
}
return rustVector;
}
Loading

0 comments on commit 5174c22

Please sign in to comment.