Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bigint backend pt 3: Move ec point conversion to TS #140

Merged
merged 9 commits into from
Sep 22, 2023
178 changes: 169 additions & 9 deletions crypto/bindings/conversion.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/**
* Implementation of Kimchi_bindings.Protocol.Gates
*/
import { MlArray } from '../../../lib/ml/base.js';
import { MlArray, MlOption, MlTuple } from '../../../lib/ml/base.js';
import { mapTuple } from './util.js';
import type {
WasmFpGate,
Expand All @@ -10,9 +10,12 @@ import type {
import type * as wasmNamespace from '../../compiled/node_bindings/plonk_wasm.cjs';
import { bigIntToBytes, bytesToBigInt } from '../bigint-helpers.js';

export { createRustConversion };
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved

type Field = Uint8Array;

export { createRustConversion };
// Kimchi_types.or_infinity
type OrInfinity = MlOption<MlTuple<Field, Field>>;
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved

// ml types from kimchi_types.ml
type GateType = number;
Expand All @@ -24,6 +27,12 @@ type Gate = [
coeffs: MlArray<Field>
];

type PolyComm = [
_: 0,
unshifted: MlArray<OrInfinity>,
shifted: MlOption<OrInfinity>
];

type wasm = typeof wasmNamespace;

// TODO: Hardcoding this is a little brittle
Expand All @@ -35,9 +44,17 @@ function createRustConversion(wasm: wasm) {
return wasm.Wire.create(row, col);
}

function perField<WasmGate extends typeof WasmFpGate | typeof WasmFqGate>(
WasmGate: WasmGate
) {
function perField<WasmGate extends typeof WasmFpGate | typeof WasmFqGate>({
WasmGate,
WasmPolyComm,
CommitmentCurve,
makeCommitmentCurve,
}: {
WasmGate: WasmGate;
WasmPolyComm: WasmPolyCommClass;
CommitmentCurve: WrapperClass<WasmAffine>;
makeCommitmentCurve: MakeAffine<WasmAffine>;
}) {
return {
vectorToRust: fieldsToRustFlat,
vectorFromRust: fieldsFromRustFlat,
Expand All @@ -47,18 +64,42 @@ function createRustConversion(wasm: wasm) {
let rustCoeffs = fieldsToRustFlat(coeffs);
return new WasmGate(typ, rustWires, rustCoeffs);
},
polyCommToRust(polyComm: PolyComm): WasmPolyComm {
return polyCommToRust(polyComm, WasmPolyComm, makeCommitmentCurve);
},
polyCommFromRust(polyComm: WasmPolyComm): PolyComm {
return polyCommFromRust(polyComm, CommitmentCurve, false);
},
};
}

const fpConversion = perField(wasm.WasmFpGate);
const fqConversion = perField(wasm.WasmFqGate);
// TODO: we have to lie about types here:
// -) the WasmGVesta class doesn't declare __wrap() but our code assumes it
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved
// -) WasmGVesta doesn't declare the `ptr` property but our code assumes it
dannywillems marked this conversation as resolved.
Show resolved Hide resolved

const fp = perField({
WasmGate: wasm.WasmFpGate,
WasmPolyComm: wasm.WasmFpPolyComm,
CommitmentCurve:
wasm.WasmGVesta as any as WrapperClass<wasmNamespace.WasmGVesta>,
makeCommitmentCurve:
wasm.caml_vesta_affine_one as MakeAffine<wasmNamespace.WasmGVesta>,
});
const fq = perField({
WasmGate: wasm.WasmFqGate,
WasmPolyComm: wasm.WasmFqPolyComm,
CommitmentCurve:
wasm.WasmGPallas as any as WrapperClass<wasmNamespace.WasmGPallas>,
makeCommitmentCurve:
wasm.caml_pallas_affine_one as MakeAffine<wasmNamespace.WasmGPallas>,
});

return {
wireToRust,
fieldsToRustFlat,
fieldsFromRustFlat,
fp: fpConversion,
fq: fqConversion,
fp,
fq,
gateFromRust(wasmGate: WasmFpGate | WasmFqGate) {
// note: this was never used and the old implementation was wrong
// (accessed non-existent fields on wasmGate)
Expand All @@ -67,6 +108,8 @@ function createRustConversion(wasm: wasm) {
};
}

// field, field vectors

// TODO make more performant
function fieldToRust(x: Field): Uint8Array {
return x;
Expand Down Expand Up @@ -98,3 +141,120 @@ function fieldsFromRustFlat(fieldBytes: Uint8Array): MlArray<Field> {
}
return [0, ...fields];
}

// affine

type WasmAffine = wasmNamespace.WasmGVesta | wasmNamespace.WasmGPallas;
type MakeAffine<A extends WasmAffine> = () => A & { ptr: number };

function affineFromRust(pt: WasmAffine): OrInfinity {
if (pt.infinity) {
pt.free();
return 0;
} else {
let x = fieldFromRust(pt.x);
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved
let y = fieldFromRust(pt.y);
pt.free();
return [0, [0, x, y]];
}
}

function affineToRust<A extends WasmAffine>(
pt: OrInfinity,
klass: MakeAffine<A>
) {
var res = klass();
if (pt === 0) {
dannywillems marked this conversation as resolved.
Show resolved Hide resolved
res.infinity = true;
} else {
res.x = fieldToRust(pt[1][1]);
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved
res.y = fieldToRust(pt[1][2]);
}
return res;
}

// polycomm

type WasmPolyComm = wasmNamespace.WasmFpPolyComm | wasmNamespace.WasmFqPolyComm;
type WasmPolyCommClass = wasm['WasmFpPolyComm'] | wasm['WasmFqPolyComm'];

function polyCommFromRust(
polyComm: WasmPolyComm,
klass: WrapperClass<WasmAffine>,
shouldFree: boolean
): PolyComm {
let rustShifted = polyComm.shifted;
let rustUnshifted = polyComm.unshifted;
let mlShifted: MlOption<OrInfinity> =
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved
rustShifted === undefined ? 0 : [0, affineFromRust(rustShifted)];
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved
let mlUnshifted = mlArrayFromRustVector(
rustUnshifted,
klass,
affineFromRust,
shouldFree
);
return [0, mlUnshifted, mlShifted];
}

function polyCommToRust(
[, camlUnshifted, camlShifted]: PolyComm,
PolyComm: WasmPolyCommClass,
makeAffine: MakeAffine<WasmAffine>
): WasmPolyComm {
let rustShifted =
camlShifted === 0 ? undefined : affineToRust(camlShifted[1], makeAffine);
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved
let rustUnshifted = mlArrayToRustVector(
camlUnshifted,
affineToRust,
makeAffine
);
return new PolyComm(rustUnshifted, rustShifted);
}

// generic rust helpers

type Freeable = { free(): void };
type WrappedPointer = Freeable & { ptr: number };
type WrapperClass<T extends Freeable> = {
__wrap(i: number): T;
};

const registry = new FinalizationRegistry((ptr: WrappedPointer) => {
ptr.free();
});

function mlArrayFromRustVector<TRust extends Freeable, TMl>(
rustVector: Uint32Array,
klass: WrapperClass<TRust>,
convert: (c: TRust) => TMl,
shouldFree: boolean
): MlArray<TMl> {
var n = rustVector.length;
var array: TMl[] = new Array(n);
for (let i = 0; i < n; i++) {
var rustValue = klass.__wrap(rustVector[i]);
array[i] = convert(rustValue);
if (shouldFree) rustValue.free();
}
return [0, ...array];
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved
}

// TODO get rid of excessive indirection here

function mlArrayToRustVector<TRust extends WrappedPointer, TMl>(
[, ...array]: MlArray<TMl>,
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved
convert: (c: TMl, makeNew: () => TRust) => TRust,
makeNew: () => TRust
): Uint32Array {
let n = array.length;
let rustVector = new Uint32Array(n);
for (var i = 0, l = array.length; i < l; i++) {
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved
var rustValue = convert(array[i], makeNew);
// Beware: caller may need to do finalizer things to avoid these
// pointers disappearing out from under us.
rustVector[i] = rustValue.ptr;
// Don't free when GC runs; rust will free on its end.
registry.unregister(rustValue);
}
return rustVector;
}