diff --git a/kt/client.go b/kt/client.go index e7e5b82..7381eed 100644 --- a/kt/client.go +++ b/kt/client.go @@ -59,15 +59,15 @@ func (c *Client) checkDig(dig *SigDig) *clientErr { // checkVrfProof errors on fail. // TODO: if VRF pubkey is bad, does VRF.Verify still mean something? -func (c *Client) checkVrf(uid uint64, ver uint64, label []byte, proof []byte) bool { +func checkVrf(pk *cryptoffi.VrfPublicKey, uid uint64, ver uint64, label []byte, proof []byte) bool { pre := &MapLabelPre{Uid: uid, Ver: ver} preByt := MapLabelPreEncode(make([]byte, 0), pre) - return !c.servVrfPk.Verify(preByt, label, proof) + return !pk.Verify(preByt, label, proof) } // checkMemb errors on fail. -func (c *Client) checkMemb(uid uint64, ver uint64, dig []byte, memb *Memb) bool { - if c.checkVrf(uid, ver, memb.Label, memb.VrfProof) { +func checkMemb(pk *cryptoffi.VrfPublicKey, uid uint64, ver uint64, dig []byte, memb *Memb) bool { + if checkVrf(pk, uid, ver, memb.Label, memb.VrfProof) { return true } mapVal := compMapVal(memb.EpochAdded, memb.CommOpen) @@ -75,19 +75,19 @@ func (c *Client) checkMemb(uid uint64, ver uint64, dig []byte, memb *Memb) bool } // checkMembHide errors on fail. -func (c *Client) checkMembHide(uid uint64, ver uint64, dig []byte, memb *MembHide) bool { - if c.checkVrf(uid, ver, memb.Label, memb.VrfProof) { +func checkMembHide(pk *cryptoffi.VrfPublicKey, uid uint64, ver uint64, dig []byte, memb *MembHide) bool { + if checkVrf(pk, uid, ver, memb.Label, memb.VrfProof) { return true } return merkle.CheckProof(true, memb.MerkProof, memb.Label, memb.MapVal, dig) } // checkHist errors on fail. -func (c *Client) checkHist(uid uint64, dig []byte, membs []*MembHide) bool { +func checkHist(pk *cryptoffi.VrfPublicKey, uid uint64, dig []byte, membs []*MembHide) bool { var err0 bool for ver0, memb := range membs { ver := uint64(ver0) - if c.checkMembHide(uid, ver, dig, memb) { + if checkMembHide(pk, uid, ver, dig, memb) { err0 = true } } @@ -95,8 +95,8 @@ func (c *Client) checkHist(uid uint64, dig []byte, membs []*MembHide) bool { } // checkNonMemb errors on fail. -func (c *Client) checkNonMemb(uid uint64, ver uint64, dig []byte, nonMemb *NonMemb) bool { - if c.checkVrf(uid, ver, nonMemb.Label, nonMemb.VrfProof) { +func checkNonMemb(pk *cryptoffi.VrfPublicKey, uid uint64, ver uint64, dig []byte, nonMemb *NonMemb) bool { + if checkVrf(pk, uid, ver, nonMemb.Label, nonMemb.VrfProof) { return true } return merkle.CheckProof(false, nonMemb.MerkProof, nonMemb.Label, nil, dig) @@ -114,7 +114,7 @@ func (c *Client) Put(pk []byte) (uint64, *clientErr) { return 0, err1 } // check latest entry has right ver, epoch, pk. - if c.checkMemb(c.uid, c.nextVer, dig.Dig, latest) { + if checkMemb(c.servVrfPk, c.uid, c.nextVer, dig.Dig, latest) { return 0, stdErr } if dig.Epoch != latest.EpochAdded { @@ -124,7 +124,7 @@ func (c *Client) Put(pk []byte) (uint64, *clientErr) { return 0, stdErr } // check bound has right ver. - if c.checkNonMemb(c.uid, c.nextVer+1, dig.Dig, bound) { + if checkNonMemb(c.servVrfPk, c.uid, c.nextVer+1, dig.Dig, bound) { return 0, stdErr } c.nextVer += 1 @@ -146,7 +146,7 @@ func (c *Client) Get(uid uint64) (bool, []byte, uint64, *clientErr) { if err1.err { return false, nil, 0, err1 } - if c.checkHist(uid, dig.Dig, hist) { + if checkHist(c.servVrfPk, uid, dig.Dig, hist) { return false, nil, 0, stdErr } numHistVers := uint64(len(hist)) @@ -155,7 +155,7 @@ func (c *Client) Get(uid uint64) (bool, []byte, uint64, *clientErr) { return false, nil, 0, stdErr } // check latest has right ver. - if isReg && c.checkMemb(uid, numHistVers, dig.Dig, latest) { + if isReg && checkMemb(c.servVrfPk, uid, numHistVers, dig.Dig, latest) { return false, nil, 0, stdErr } // check bound has right ver. @@ -164,7 +164,7 @@ func (c *Client) Get(uid uint64) (bool, []byte, uint64, *clientErr) { if isReg { boundVer = numHistVers + 1 } - if c.checkNonMemb(uid, boundVer, dig.Dig, bound) { + if checkNonMemb(c.servVrfPk, uid, boundVer, dig.Dig, bound) { return false, nil, 0, stdErr } return isReg, latest.CommOpen.Pk, dig.Epoch, &clientErr{err: false} @@ -182,7 +182,7 @@ func (c *Client) SelfMon() (uint64, *clientErr) { if err1.err { return 0, err1 } - if c.checkNonMemb(c.uid, c.nextVer, dig.Dig, bound) { + if checkNonMemb(c.servVrfPk, c.uid, c.nextVer, dig.Dig, bound) { return 0, stdErr } return dig.Epoch, &clientErr{err: false}