Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bigint backend pt 3: Move ec point conversion to TS #140

Merged
merged 9 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion MINA_COMMIT
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
The mina commit used to generate the backends for node and web is
366a3606d8c99a807b9f398d13e9571ed5b26cb1
2892c4351234f0bb1913f3e6d5b2469d86881ff5
16 changes: 8 additions & 8 deletions compiled/node_bindings/plonk_wasm.cjs
Original file line number Diff line number Diff line change
Expand Up @@ -8665,14 +8665,6 @@ module.exports.__wbindgen_is_object = function(arg0) {
return ret;
};

module.exports.__wbg_randomFillSync_6894564c2c334c42 = function() { return handleError(function (arg0, arg1, arg2) {
getObject(arg0).randomFillSync(getArrayU8FromWasm0(arg1, arg2));
}, arguments) };

module.exports.__wbg_getRandomValues_805f1c3d65988a5a = function() { return handleError(function (arg0, arg1) {
getObject(arg0).getRandomValues(getObject(arg1));
}, arguments) };

module.exports.__wbg_crypto_e1d53a1d73fb10b8 = function(arg0) {
const ret = getObject(arg0).crypto;
return addHeapObject(ret);
Expand Down Expand Up @@ -8713,6 +8705,14 @@ module.exports.__wbindgen_is_function = function(arg0) {
return ret;
};

module.exports.__wbg_getRandomValues_805f1c3d65988a5a = function() { return handleError(function (arg0, arg1) {
getObject(arg0).getRandomValues(getObject(arg1));
}, arguments) };

module.exports.__wbg_randomFillSync_6894564c2c334c42 = function() { return handleError(function (arg0, arg1, arg2) {
getObject(arg0).randomFillSync(getArrayU8FromWasm0(arg1, arg2));
}, arguments) };

module.exports.__wbg_get_27fe3dac1c4d0224 = function(arg0, arg1) {
const ret = getObject(arg0)[arg1 >>> 0];
return addHeapObject(ret);
Expand Down
Binary file modified compiled/node_bindings/plonk_wasm_bg.wasm
Binary file not shown.
230 changes: 111 additions & 119 deletions compiled/node_bindings/snarky_js_node.bc.cjs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion compiled/node_bindings/snarky_js_node.bc.map

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions compiled/web_bindings/plonk_wasm.js
Original file line number Diff line number Diff line change
Expand Up @@ -8618,12 +8618,6 @@ function getImports() {
const ret = typeof(val) === 'object' && val !== null;
return ret;
};
imports.wbg.__wbg_randomFillSync_6894564c2c334c42 = function() { return handleError(function (arg0, arg1, arg2) {
getObject(arg0).randomFillSync(getArrayU8FromWasm0(arg1, arg2));
}, arguments) };
imports.wbg.__wbg_getRandomValues_805f1c3d65988a5a = function() { return handleError(function (arg0, arg1) {
getObject(arg0).getRandomValues(getObject(arg1));
}, arguments) };
imports.wbg.__wbg_crypto_e1d53a1d73fb10b8 = function(arg0) {
const ret = getObject(arg0).crypto;
return addHeapObject(ret);
Expand Down Expand Up @@ -8656,6 +8650,12 @@ function getImports() {
const ret = typeof(getObject(arg0)) === 'function';
return ret;
};
imports.wbg.__wbg_getRandomValues_805f1c3d65988a5a = function() { return handleError(function (arg0, arg1) {
getObject(arg0).getRandomValues(getObject(arg1));
}, arguments) };
imports.wbg.__wbg_randomFillSync_6894564c2c334c42 = function() { return handleError(function (arg0, arg1, arg2) {
getObject(arg0).randomFillSync(getArrayU8FromWasm0(arg1, arg2));
}, arguments) };
imports.wbg.__wbg_get_27fe3dac1c4d0224 = function(arg0, arg1) {
const ret = getObject(arg0)[arg1 >>> 0];
return addHeapObject(ret);
Expand Down
Binary file modified compiled/web_bindings/plonk_wasm_bg.wasm
Binary file not shown.
12 changes: 6 additions & 6 deletions compiled/web_bindings/snarky_js_web.bc.js

Large diffs are not rendered by default.

198 changes: 189 additions & 9 deletions crypto/bindings/conversion.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/**
* Implementation of Kimchi_bindings.Protocol.Gates
*/
import { MlArray } from '../../../lib/ml/base.js';
import { MlArray, MlOption, MlTuple } from '../../../lib/ml/base.js';
import { mapTuple } from './util.js';
import type {
WasmFpGate,
Expand All @@ -10,9 +10,15 @@ import type {
import type * as wasmNamespace from '../../compiled/node_bindings/plonk_wasm.cjs';
import { bigIntToBytes, bytesToBigInt } from '../bigint-helpers.js';

export { createRustConversion };
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved

type Field = Uint8Array;

export { createRustConversion };
// Kimchi_types.or_infinity
type Infinity = 0;
const Infinity = 0;
type Finite<T> = [0, T];
type OrInfinity = Infinity | Finite<MlTuple<Field, Field>>;

// ml types from kimchi_types.ml
type GateType = number;
Expand All @@ -24,6 +30,12 @@ type Gate = [
coeffs: MlArray<Field>
];

type PolyComm = [
_: 0,
unshifted: MlArray<OrInfinity>,
shifted: MlOption<OrInfinity>
];

type wasm = typeof wasmNamespace;

// TODO: Hardcoding this is a little brittle
Expand All @@ -35,9 +47,17 @@ function createRustConversion(wasm: wasm) {
return wasm.Wire.create(row, col);
}

function perField<WasmGate extends typeof WasmFpGate | typeof WasmFqGate>(
WasmGate: WasmGate
) {
function perField<WasmGate extends typeof WasmFpGate | typeof WasmFqGate>({
WasmGate,
WasmPolyComm,
CommitmentCurve,
makeCommitmentCurve,
}: {
WasmGate: WasmGate;
WasmPolyComm: WasmPolyCommClass;
CommitmentCurve: WrapperClass<WasmAffine>;
makeCommitmentCurve: MakeAffine<WasmAffine>;
}) {
return {
vectorToRust: fieldsToRustFlat,
vectorFromRust: fieldsFromRustFlat,
Expand All @@ -47,18 +67,59 @@ function createRustConversion(wasm: wasm) {
let rustCoeffs = fieldsToRustFlat(coeffs);
return new WasmGate(typ, rustWires, rustCoeffs);
},
pointToRust(point: OrInfinity) {
return affineToRust(point, makeCommitmentCurve);
},
pointFromRust(point: WasmAffine) {
return affineFromRust(point);
},
pointsToRust(points: MlArray<OrInfinity>) {
return mlArrayToRustVector(points, affineToRust, makeCommitmentCurve);
},
pointsFromRust(points: Uint32Array) {
return mlArrayFromRustVector(
points,
CommitmentCurve,
affineFromRust,
false
);
},
polyCommToRust(polyComm: PolyComm): WasmPolyComm {
return polyCommToRust(polyComm, WasmPolyComm, makeCommitmentCurve);
},
polyCommFromRust(polyComm: WasmPolyComm): PolyComm {
return polyCommFromRust(polyComm, CommitmentCurve, false);
},
};
}

const fpConversion = perField(wasm.WasmFpGate);
const fqConversion = perField(wasm.WasmFqGate);
// TODO: we have to lie about types here:
// -) the WasmGVesta class doesn't declare __wrap() but our code assumes it
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved
// -) WasmGVesta doesn't declare the `ptr` property but our code assumes it
dannywillems marked this conversation as resolved.
Show resolved Hide resolved

const fp = perField({
WasmGate: wasm.WasmFpGate,
WasmPolyComm: wasm.WasmFpPolyComm,
CommitmentCurve:
wasm.WasmGVesta as any as WrapperClass<wasmNamespace.WasmGVesta>,
makeCommitmentCurve:
wasm.caml_vesta_affine_one as MakeAffine<wasmNamespace.WasmGVesta>,
});
const fq = perField({
WasmGate: wasm.WasmFqGate,
WasmPolyComm: wasm.WasmFqPolyComm,
CommitmentCurve:
wasm.WasmGPallas as any as WrapperClass<wasmNamespace.WasmGPallas>,
makeCommitmentCurve:
wasm.caml_pallas_affine_one as MakeAffine<wasmNamespace.WasmGPallas>,
});

return {
wireToRust,
fieldsToRustFlat,
fieldsFromRustFlat,
fp: fpConversion,
fq: fqConversion,
fp,
fq,
gateFromRust(wasmGate: WasmFpGate | WasmFqGate) {
// note: this was never used and the old implementation was wrong
// (accessed non-existent fields on wasmGate)
Expand All @@ -67,6 +128,8 @@ function createRustConversion(wasm: wasm) {
};
}

// field, field vectors

// TODO make more performant
function fieldToRust(x: Field): Uint8Array {
return x;
Expand Down Expand Up @@ -98,3 +161,120 @@ function fieldsFromRustFlat(fieldBytes: Uint8Array): MlArray<Field> {
}
return [0, ...fields];
}

// affine

type WasmAffine = wasmNamespace.WasmGVesta | wasmNamespace.WasmGPallas;
type MakeAffine<A extends WasmAffine> = () => A & { ptr: number };

function affineFromRust(pt: WasmAffine): OrInfinity {
if (pt.infinity) {
pt.free();
return 0;
} else {
let x = fieldFromRust(pt.x);
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved
let y = fieldFromRust(pt.y);
pt.free();
return [0, [0, x, y]];
}
}

function affineToRust<A extends WasmAffine>(
pt: OrInfinity,
klass: MakeAffine<A>
) {
var res = klass();
if (pt === Infinity) {
res.infinity = true;
} else {
let [, [, x, y]] = pt;
res.x = fieldToRust(x);
res.y = fieldToRust(y);
}
return res;
}

// polycomm

type WasmPolyComm = wasmNamespace.WasmFpPolyComm | wasmNamespace.WasmFqPolyComm;
type WasmPolyCommClass = wasm['WasmFpPolyComm'] | wasm['WasmFqPolyComm'];

function polyCommFromRust(
polyComm: WasmPolyComm,
klass: WrapperClass<WasmAffine>,
shouldFree: boolean
): PolyComm {
let rustShifted = polyComm.shifted;
let rustUnshifted = polyComm.unshifted;
let mlShifted = MlOption.mapTo(rustShifted, affineFromRust);
let mlUnshifted = mlArrayFromRustVector(
rustUnshifted,
klass,
affineFromRust,
shouldFree
);
return [0, mlUnshifted, mlShifted];
}

function polyCommToRust(
[, camlUnshifted, camlShifted]: PolyComm,
PolyComm: WasmPolyCommClass,
makeAffine: MakeAffine<WasmAffine>
): WasmPolyComm {
let rustShifted =
camlShifted === 0 ? undefined : affineToRust(camlShifted[1], makeAffine);
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved
let rustUnshifted = mlArrayToRustVector(
camlUnshifted,
affineToRust,
makeAffine
);
return new PolyComm(rustUnshifted, rustShifted);
}

// generic rust helpers

type Freeable = { free(): void };
type WrappedPointer = Freeable & { ptr: number };
type WrapperClass<T extends Freeable> = {
__wrap(i: number): T;
};

const registry = new FinalizationRegistry((ptr: WrappedPointer) => {
ptr.free();
});

function mlArrayFromRustVector<TRust extends Freeable, TMl>(
rustVector: Uint32Array,
klass: WrapperClass<TRust>,
convert: (c: TRust) => TMl,
shouldFree: boolean
): MlArray<TMl> {
var n = rustVector.length;
var array: TMl[] = new Array(n);
for (let i = 0; i < n; i++) {
var rustValue = klass.__wrap(rustVector[i]);
array[i] = convert(rustValue);
if (shouldFree) rustValue.free();
}
return [0, ...array];
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved
}

// TODO get rid of excessive indirection here

function mlArrayToRustVector<TRust extends WrappedPointer, TMl>(
[, ...array]: MlArray<TMl>,
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved
convert: (c: TMl, makeNew: () => TRust) => TRust,
makeNew: () => TRust
): Uint32Array {
let n = array.length;
let rustVector = new Uint32Array(n);
for (var i = 0; i < n; i++) {
var rustValue = convert(array[i], makeNew);
// Beware: caller may need to do finalizer things to avoid these
// pointers disappearing out from under us.
rustVector[i] = rustValue.ptr;
// Don't free when GC runs; rust will free on its end.
registry.unregister(rustValue);
}
return rustVector;
}
Loading