From 638e663503ce820379d7793a48121a5680708cb4 Mon Sep 17 00:00:00 2001 From: Daniel Adam Date: Thu, 26 Sep 2024 15:10:37 +0200 Subject: [PATCH] fixup! Implement Certificate Revocation List --- .../service/grpc/getSigningRecords.go | 16 +- .../service/grpc/signCertificate.go | 18 +- .../service/grpc/signIdentityCertificate.go | 3 +- .../service/http/revocationList.go | 25 +-- .../service/http/revocationList_test.go | 5 +- .../store/cqldb/revocationList.go | 16 +- .../store/cqldb/signingRecords.go | 16 +- .../store/cqldb/signingRecords_test.go | 20 +- .../store/mongodb/revocationList.go | 205 ++++-------------- .../store/mongodb/revocationList_test.go | 182 ++++++---------- .../store/mongodb/signingRecords.go | 38 +--- .../store/mongodb/signingRecords_test.go | 14 +- certificate-authority/store/revocationList.go | 33 ++- certificate-authority/store/store.go | 40 +--- certificate-authority/test/revocationList.go | 10 +- 15 files changed, 213 insertions(+), 428 deletions(-) diff --git a/certificate-authority/service/grpc/getSigningRecords.go b/certificate-authority/service/grpc/getSigningRecords.go index 9236524a9..875954746 100644 --- a/certificate-authority/service/grpc/getSigningRecords.go +++ b/certificate-authority/service/grpc/getSigningRecords.go @@ -1,10 +1,7 @@ package grpc import ( - "context" - "github.com/plgd-dev/hub/v2/certificate-authority/pb" - "github.com/plgd-dev/hub/v2/certificate-authority/store" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -14,16 +11,11 @@ func (s *CertificateAuthorityServer) GetSigningRecords(req *pb.GetSigningRecords if err != nil { return s.logger.LogAndReturnError(status.Errorf(codes.InvalidArgument, "cannot get signing records: %v", err)) } - err = s.store.LoadSigningRecords(srv.Context(), owner, req, func(ctx context.Context, iter store.SigningRecordIter) (err error) { - for { - var sub pb.SigningRecord - if ok := iter.Next(ctx, &sub); !ok { - return iter.Err() - } - if err = srv.Send(&sub); err != nil { - return err - } + err = s.store.LoadSigningRecords(srv.Context(), owner, req, func(sr *pb.SigningRecord) (err error) { + if err = srv.Send(sr); err != nil { + return err } + return nil }) if err != nil { return s.logger.LogAndReturnError(status.Errorf(codes.InvalidArgument, "cannot get signing records: %v", err)) diff --git a/certificate-authority/service/grpc/signCertificate.go b/certificate-authority/service/grpc/signCertificate.go index e0e9f1fa3..a23396548 100644 --- a/certificate-authority/service/grpc/signCertificate.go +++ b/certificate-authority/service/grpc/signCertificate.go @@ -31,18 +31,11 @@ func (s *CertificateAuthorityServer) updateSigningIdentityCertificateRecord(ctx now := time.Now().UnixNano() err := s.store.LoadSigningRecords(ctx, updateSigningRecord.GetOwner(), &store.SigningRecordsQuery{ CommonNameFilter: []string{updateSigningRecord.GetCommonName()}, - }, func(ctx context.Context, iter store.SigningRecordIter) (err error) { - for { - var signingRecord pb.SigningRecord - ok := iter.Next(ctx, &signingRecord) - if !ok { - break - } - if updateSigningRecord.GetPublicKey() != signingRecord.GetPublicKey() && signingRecord.GetCredential().GetValidUntilDate() > now { - return fmt.Errorf("common name %v with different public key fingerprint exist", signingRecord.GetCommonName()) - } - found = true + }, func(sr *store.SigningRecord) (err error) { + if updateSigningRecord.GetPublicKey() != sr.GetPublicKey() && sr.GetCredential().GetValidUntilDate() > now { + return fmt.Errorf("common name %v with different public key fingerprint exist", sr.GetCommonName()) } + found = true return nil }) if err != nil { @@ -125,13 +118,14 @@ func (s *CertificateAuthorityServer) SignCertificate(ctx context.Context, req *p logger.With("crt", string(cert)).Debugf("CertificateAuthorityServer.SignCertificate") replacedCredential := replacedRecord.GetCredential() if replacedRecord != nil { - err = s.store.AddRevocationListCertificate(ctx, replacedCredential.GetIssuerId(), &store.RevocationListCertificate{ + err = s.store.RevokeCertificates(ctx, replacedCredential.GetIssuerId(), &store.RevocationListCertificate{ Serial: replacedCredential.GetSerial(), Expiration: replacedCredential.GetValidUntilDate(), Revocation: time.Now().UnixNano(), }) if err != nil { // TODO: what to do here? remove the new signing record? restore the original? + panic(err) } } diff --git a/certificate-authority/service/grpc/signIdentityCertificate.go b/certificate-authority/service/grpc/signIdentityCertificate.go index c7cb41bb5..f714b4e61 100644 --- a/certificate-authority/service/grpc/signIdentityCertificate.go +++ b/certificate-authority/service/grpc/signIdentityCertificate.go @@ -65,13 +65,14 @@ func (s *CertificateAuthorityServer) SignIdentityCertificate(ctx context.Context logger.With("crt", string(cert)).Debugf("CertificateAuthorityServer.SignIdentityCertificate") replacedCredential := replacedRecord.GetCredential() if replacedCredential != nil { - err = s.store.AddRevocationListCertificate(ctx, replacedCredential.GetIssuerId(), &store.RevocationListCertificate{ + err = s.store.RevokeCertificates(ctx, replacedCredential.GetIssuerId(), &store.RevocationListCertificate{ Serial: replacedCredential.GetSerial(), Expiration: replacedCredential.GetValidUntilDate(), Revocation: time.Now().UnixNano(), }) if err != nil { // TODO: what to do here? remove the new signing record? restore the original? + panic(err) } } diff --git a/certificate-authority/service/http/revocationList.go b/certificate-authority/service/http/revocationList.go index 35cfbd9e0..dabe4cfa4 100644 --- a/certificate-authority/service/http/revocationList.go +++ b/certificate-authority/service/http/revocationList.go @@ -15,23 +15,25 @@ import ( pkgTime "github.com/plgd-dev/hub/v2/pkg/time" ) -var revocationListNumber = big.NewInt(0) - func (requestHandler *RequestHandler) revocationList(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) issuerID := vars[uri.IssuerIDKey] - var rles []x509.RevocationListEntry - err := requestHandler.store.GetRevokedCertificates(r.Context(), store.CertificatesQuery{ + template := &x509.RevocationList{ + NextUpdate: time.Now().Add(time.Minute * 10), // TODO: pridat konfiguraciu, default napr. 10min + } + err := requestHandler.store.GetRevocationLists(r.Context(), store.RevocationListsQuery{ IssuerIdFilter: []string{issuerID}, }, func(rl *store.RevocationList) error { + template.Number = big.NewInt(rl.Number) + template.ThisUpdate = pkgTime.Unix(0, rl.UpdatedAt) for _, c := range rl.Certificates { var sn big.Int _, ok := sn.SetString(c.Serial, 10) if !ok { panic("invalid serial number string " + c.Serial) } - rles = append(rles, x509.RevocationListEntry{ + template.RevokedCertificateEntries = append(template.RevokedCertificateEntries, x509.RevocationListEntry{ SerialNumber: &sn, RevocationTime: pkgTime.Unix(0, c.Revocation), }) @@ -43,19 +45,14 @@ func (requestHandler *RequestHandler) revocationList(w http.ResponseWriter, r *h panic(err) } + if len(template.RevokedCertificateEntries) == 0 { + return + } + issuingCert := requestHandler.cas.GetSigner().GetCertificate() if issuingCert == nil { panic("issuer certificate not set") } - now := time.Now() - template := &x509.RevocationList{ - RevokedCertificateEntries: rles, - Number: revocationListNumber, - ThisUpdate: now, - NextUpdate: now.Add(time.Minute * 10), // TODO: pridat konfiguraciu, default napr. 10min - } - // TODO: store CRLs in DB and only increase the number if a new CRL has been generated - revocationListNumber.Add(revocationListNumber, big.NewInt(1)) signer := requestHandler.cas.GetSigner() crl, err := x509.CreateRevocationList(rand.Reader, template, issuingCert, signer.GetPrivateKey().(*ecdsa.PrivateKey)) diff --git a/certificate-authority/service/http/revocationList_test.go b/certificate-authority/service/http/revocationList_test.go index 9a8767b98..fee05e4c5 100644 --- a/certificate-authority/service/http/revocationList_test.go +++ b/certificate-authority/service/http/revocationList_test.go @@ -9,7 +9,6 @@ import ( "time" certAuthURI "github.com/plgd-dev/hub/v2/certificate-authority/service/uri" - "github.com/plgd-dev/hub/v2/certificate-authority/store" "github.com/plgd-dev/hub/v2/certificate-authority/test" httpgwTest "github.com/plgd-dev/hub/v2/http-gateway/test" "github.com/plgd-dev/hub/v2/pkg/config/database" @@ -38,8 +37,6 @@ func TestRevocationList(t *testing.T) { ctx = pkgGrpc.CtxWithToken(ctx, token) test.AddRevocationListToStore(ctx, t, s, time.Now()) - err := s.RevokeCertificates(ctx, store.CertificatesQuery{}) - require.NoError(t, err) request := httpgwTest.NewRequest(http.MethodGet, certAuthURI.SigningRevocationList, nil).Host(config.CERTIFICATE_AUTHORITY_HTTP_HOST).AuthToken(token).AddIssuerID(test.GetIssuerID(2)).Build() httpResp := httpgwTest.HTTPDo(t, request) @@ -51,5 +48,5 @@ func TestRevocationList(t *testing.T) { _, err = x509.ParseRevocationList(respBody) require.NoError(t, err) - time.Sleep(time.Minute) + // time.Sleep(time.Minute) } diff --git a/certificate-authority/store/cqldb/revocationList.go b/certificate-authority/store/cqldb/revocationList.go index 0046bde39..1c53e0f74 100644 --- a/certificate-authority/store/cqldb/revocationList.go +++ b/certificate-authority/store/cqldb/revocationList.go @@ -10,22 +10,10 @@ func (s *Store) InsertRevocationLists(context.Context, ...*store.RevocationList) return store.ErrNotSupported } -func (s *Store) AddRevocationListCertificate(context.Context, string, *store.RevocationListCertificate) error { +func (s *Store) RevokeCertificates(context.Context, string, ...*store.RevocationListCertificate) error { return store.ErrNotSupported } -func (s *Store) RevokeCertificates(context.Context, store.CertificatesQuery) error { - return store.ErrNotSupported -} - -func (s *Store) GetExpiredCertificates(context.Context, store.ExpiredCertificatesQuery, store.Process[store.RevocationList]) error { - return store.ErrNotSupported -} - -func (s *Store) GetRevokedCertificates(context.Context, store.CertificatesQuery, store.Process[store.RevocationList]) error { - return store.ErrNotSupported -} - -func (s *Store) DeleteExpiredCertificates(context.Context) error { +func (s *Store) GetRevocationLists(context.Context, store.RevocationListsQuery, store.Process[store.RevocationList]) error { return store.ErrNotSupported } diff --git a/certificate-authority/store/cqldb/signingRecords.go b/certificate-authority/store/cqldb/signingRecords.go index 9f2345f3f..32b551a48 100644 --- a/certificate-authority/store/cqldb/signingRecords.go +++ b/certificate-authority/store/cqldb/signingRecords.go @@ -339,15 +339,25 @@ func (s *Store) DeleteNonDeviceExpiredRecords(_ context.Context, _ time.Time) (i return 0, store.ErrNotSupported } -func (s *Store) LoadSigningRecords(ctx context.Context, owner string, query *store.SigningRecordsQuery, h store.LoadSigningRecordsFunc) error { +func (s *Store) LoadSigningRecords(ctx context.Context, owner string, query *store.SigningRecordsQuery, p store.Process[store.SigningRecord]) error { i := SigningRecordsIterator{ ctx: ctx, s: s, queries: toSigningRecordsQueryFilter(owner, query, true), provided: make(map[string]struct{}, 32), } - err := h(ctx, &i) - + var err error + for { + var stored store.SigningRecord + if !i.Next(ctx, &stored) { + err = i.Err() + break + } + err = p(&stored) + if err != nil { + break + } + } errClose := i.close() if err == nil { return errClose diff --git a/certificate-authority/store/cqldb/signingRecords_test.go b/certificate-authority/store/cqldb/signingRecords_test.go index af56e61f1..2c2d8189d 100644 --- a/certificate-authority/store/cqldb/signingRecords_test.go +++ b/certificate-authority/store/cqldb/signingRecords_test.go @@ -177,7 +177,7 @@ func TestStoreUpdateSigningRecord(t *testing.T) { var h testSigningRecordHandler err = s.LoadSigningRecords(ctx, tt.args.sub.GetOwner(), &pb.GetSigningRecordsRequest{ IdFilter: []string{tt.args.sub.GetId()}, - }, h.Handle) + }, h.process) require.NoError(t, err) require.Len(t, h.lcs, 1) hubTest.CheckProtobufs(t, tt.args.sub, h.lcs[0], hubTest.RequireToCheckFunc(require.Equal)) @@ -345,7 +345,7 @@ func TestStoreDeleteExpiredRecords(t *testing.T) { }) require.NoError(t, err) var h testSigningRecordHandler - err = s.LoadSigningRecords(ctx, "owner", nil, h.Handle) + err = s.LoadSigningRecords(ctx, "owner", nil, h.process) require.NoError(t, err) require.Len(t, h.lcs, 1) time.Sleep(time.Second * 3) @@ -353,7 +353,7 @@ func TestStoreDeleteExpiredRecords(t *testing.T) { require.Error(t, err) require.Equal(t, store.ErrNotSupported, err) var h1 testSigningRecordHandler - err = s.LoadSigningRecords(ctx, "owner", nil, h1.Handle) + err = s.LoadSigningRecords(ctx, "owner", nil, h1.process) require.NoError(t, err) require.Empty(t, h1.lcs) } @@ -362,15 +362,9 @@ type testSigningRecordHandler struct { lcs pb.SigningRecords } -func (h *testSigningRecordHandler) Handle(ctx context.Context, iter store.SigningRecordIter) (err error) { - for { - var sub store.SigningRecord - if !iter.Next(ctx, &sub) { - break - } - h.lcs = append(h.lcs, &sub) - } - return iter.Err() +func (h *testSigningRecordHandler) process(sr *store.SigningRecord) (err error) { + h.lcs = append(h.lcs, sr) + return nil } func TestStoreLoadSigningRecords(t *testing.T) { @@ -518,7 +512,7 @@ func TestStoreLoadSigningRecords(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var h testSigningRecordHandler - err := s.LoadSigningRecords(ctx, "owner", tt.args.query, h.Handle) + err := s.LoadSigningRecords(ctx, "owner", tt.args.query, h.process) if tt.wantErr { require.Error(t, err) return diff --git a/certificate-authority/store/mongodb/revocationList.go b/certificate-authority/store/mongodb/revocationList.go index dc44906f3..2f249dc60 100644 --- a/certificate-authority/store/mongodb/revocationList.go +++ b/certificate-authority/store/mongodb/revocationList.go @@ -8,8 +8,8 @@ import ( "github.com/plgd-dev/hub/v2/certificate-authority/store" "github.com/plgd-dev/hub/v2/pkg/mongodb" - pkgTime "github.com/plgd-dev/hub/v2/pkg/time" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) @@ -18,19 +18,20 @@ const revocationListCol = "revocationList" func (s *Store) InsertRevocationLists(ctx context.Context, rls ...*store.RevocationList) error { documents := make([]interface{}, 0, len(rls)) for _, rl := range rls { + if err := rl.Validate(); err != nil { + return err + } documents = append(documents, rl) } _, err := s.Collection(revocationListCol).InsertMany(ctx, documents) return err } -func (s *Store) existsRevocationListCertificate(ctx context.Context, id, serial string) (bool, error) { +// check if a certificate with one of the serial numbers already exists in the array +func (s *Store) checkDuplicitCertificates(ctx context.Context, id string, serials []string) (bool, error) { filter := bson.M{ "_id": id, - store.CertificatesKey + "." + store.SerialKey: serial, - store.CertificatesKey + "." + store.ExpirationKey: bson.M{ - "$gte": pkgTime.UnixNano(time.Now()), - }, + store.CertificatesKey + "." + store.SerialKey: bson.M{mongodb.In: serials}, } c, err := s.Collection(revocationListCol).CountDocuments(ctx, filter) if err != nil { @@ -39,29 +40,37 @@ func (s *Store) existsRevocationListCertificate(ctx context.Context, id, serial return c > 0, nil } -func (s *Store) AddRevocationListCertificate(ctx context.Context, id string, rlc *store.RevocationListCertificate) error { +func (s *Store) RevokeCertificates(ctx context.Context, id string, rl ...*store.RevocationListCertificate) error { if id == "" { return errors.New("revocation list ID not set") } - if err := rlc.Validate(); err != nil { - return err + serials := make([]string, len(rl)) + for _, rlc := range rl { + if err := rlc.Validate(); err != nil { + return err + } + serials = append(serials, rlc.Serial) } - - // Check if an unexpired certificate with the serial already exists in the array - exists, err := s.existsRevocationListCertificate(ctx, id, rlc.Serial) + exists, err := s.checkDuplicitCertificates(ctx, id, serials) if err != nil { return err } if exists { - return fmt.Errorf("valid certificate with serial %s already exists", rlc.Serial) + return fmt.Errorf("duplicit serial number in %v", serials) } - filter := bson.M{ "_id": id, } + now := time.Now() update := bson.M{ "$push": bson.M{ - store.CertificatesKey: rlc, + store.CertificatesKey: bson.M{"$each": rl}, + }, + "$inc": bson.M{ + "number": 1, + }, + "$set": bson.M{ + "updatedAt": now.UnixNano(), }, "$setOnInsert": bson.M{ "_id": id, @@ -72,166 +81,44 @@ func (s *Store) AddRevocationListCertificate(ctx context.Context, id string, rlc return err } -func (s *Store) revokeCertificatesByIssuer(ctx context.Context, issuerIdFilter []string, revocation int64) error { - filter := bson.M{ - "_id": bson.M{mongodb.In: issuerIdFilter}, - } - updateOptions := options.Update().SetArrayFilters(options.ArrayFilters{ - Filters: bson.A{ - // Ensure the certificate is not expired - bson.M{"cert." + store.ExpirationKey: bson.M{"$gte": revocation}}, - }, - }) - // Update operation: set the revocation date to the current time for all certificates. - _, err := s.Collection(revocationListCol).UpdateMany(ctx, filter, bson.M{ - "$set": bson.M{store.CertificatesKey + ".$[cert]." + store.RevocationKey: revocation}, - }, updateOptions) - return err -} - -func toRevokedCertificatesArrayFilter(query store.CertificatesQuery, now int64) bson.A { - expiredFilter := bson.M{"cert." + store.ExpirationKey: bson.M{"$gte": now}} - if len(query.SerialFilter) == 0 { - return bson.A{expiredFilter} - } - filters := bson.A{expiredFilter} - filters = append(filters, bson.M{"cert." + store.SerialKey: bson.M{mongodb.In: query.SerialFilter}}) - return bson.A{bson.M{mongodb.And: filters}} -} - -func (s *Store) RevokeCertificates(ctx context.Context, query store.CertificatesQuery) error { - now := time.Now().UnixNano() - var filter interface{} = bson.D{} - if len(query.IssuerIdFilter) > 0 { - err := s.revokeCertificatesByIssuer(ctx, query.IssuerIdFilter, now) - if err != nil { - return err - } - if len(query.SerialFilter) == 0 { - return nil - } - filter = bson.M{ - "_id": bson.M{"$nin": query.IssuerIdFilter}, - } - } - updateOptions := options.Update().SetArrayFilters(options.ArrayFilters{ - Filters: toRevokedCertificatesArrayFilter(query, now), - }) - _, err := s.Collection(revocationListCol).UpdateMany(ctx, filter, bson.M{ - "$set": bson.M{store.CertificatesKey + ".$[cert]." + store.RevocationKey: now}, - }, updateOptions) - return err -} - -// get expired certificates -func (s *Store) GetExpiredCertificates(ctx context.Context, query store.ExpiredCertificatesQuery, p store.Process[store.RevocationList]) error { +func (s *Store) GetRevocationLists(ctx context.Context, query store.RevocationListsQuery, p store.Process[store.RevocationList]) error { now := time.Now().UnixNano() - filter := bson.M{ - store.CertificatesKey: bson.M{ - "$elemMatch": bson.M{ - // expired certificates - store.ExpirationKey: bson.M{"$lt": now}, - }, - }, - } + filter := bson.M{} if len(query.IssuerIdFilter) > 0 { filter["_id"] = bson.M{mongodb.In: query.IssuerIdFilter} } - projection := bson.M{ - store.CertificatesKey: bson.M{ - "$filter": bson.M{ - "input": "$" + store.CertificatesKey, - "as": "cert", - "cond": bson.M{ - // expired certificates - "$lt": []interface{}{"$$cert." + store.ExpirationKey, now}, - }, - }, - }, - } - - cur, err := s.Collection(revocationListCol).Find(ctx, filter, options.Find().SetProjection(projection)) - if err != nil { - return err - } - return processCursor(ctx, cur, p) -} - -func (s *Store) GetRevokedCertificates(ctx context.Context, query store.CertificatesQuery, p store.Process[store.RevocationList]) error { - now := time.Now().UnixNano() - filter := bson.M{ - store.CertificatesKey: bson.M{ + var opts []*options.FindOptions + if !query.IncludeExpired { + filter[store.CertificatesKey] = bson.M{ "$elemMatch": bson.M{ // non-expired certificates store.ExpirationKey: bson.M{"$gte": now}, - // revoked - store.RevocationKey: bson.M{"$gt": 0}, }, - }, - } - if len(query.IssuerIdFilter) > 0 { - filter["_id"] = bson.M{mongodb.In: query.IssuerIdFilter} - } - - projection := bson.M{ - store.CertificatesKey: bson.M{ - "$filter": bson.M{ - "input": "$" + store.CertificatesKey, - "as": "cert", - "cond": bson.M{ - mongodb.And: []bson.M{ - // non-expired certificates - {"$gte": []interface{}{"$$cert." + store.ExpirationKey, now}}, - // revoked - {"$gt": []interface{}{"$$cert." + store.RevocationKey, 0}}, + } + projection := bson.M{ + store.CertificatesKey: bson.M{ + "$filter": bson.M{ + "input": "$" + store.CertificatesKey, + "as": "cert", + "cond": bson.M{ + mongodb.And: []bson.M{ + // non-expired certificates + {"$gte": []interface{}{"$$cert." + store.ExpirationKey, now}}, + }, }, }, }, - }, + } + opts = append(opts, options.Find().SetProjection(projection)) } - cur, err := s.Collection(revocationListCol).Find(ctx, filter, options.Find().SetProjection(projection)) + cur, err := s.Collection(revocationListCol).Find(ctx, filter, opts...) if err != nil { + if errors.Is(err, mongo.ErrNilDocument) { + return nil + } return err } return processCursor(ctx, cur, p) } - -func (s *Store) DeleteExpiredCertificates(ctx context.Context) error { - // delete expired certificates - now := time.Now().UnixNano() - - filter := bson.M{ - store.CertificatesKey: bson.M{ - "$elemMatch": bson.M{ - store.ExpirationKey: bson.M{"$lt": now}, - }, - }, - } - - update := bson.M{ - "$pull": bson.M{ - store.CertificatesKey: bson.M{ - store.ExpirationKey: bson.M{"$lt": now}, - }, - }, - } - - res, err := s.Collection(revocationListCol).UpdateMany(ctx, filter, update) - if err != nil { - return err - } - - if res.ModifiedCount == 0 { - return nil - } - - // delete documents with empty revocation list - emptyCertificatesFilter := bson.M{ - // This checks that no element exists at index 0, meaning the array is empty - store.CertificatesKey + ".0": bson.M{"$exists": false}, - } - _, err = s.Collection(revocationListCol).DeleteMany(ctx, emptyCertificatesFilter) - return err -} diff --git a/certificate-authority/store/mongodb/revocationList_test.go b/certificate-authority/store/mongodb/revocationList_test.go index a28cfce94..c4623075e 100644 --- a/certificate-authority/store/mongodb/revocationList_test.go +++ b/certificate-authority/store/mongodb/revocationList_test.go @@ -13,7 +13,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestAddRevocationListCertificate(t *testing.T) { +func TestRevokeCertificates(t *testing.T) { s, cleanUpStore := test.NewMongoStore(t) defer cleanUpStore() @@ -41,8 +41,20 @@ func TestAddRevocationListCertificate(t *testing.T) { { name: "missing serial number", args: args{ - id: id, - certificate: &store.RevocationListCertificate{}, + id: id, + certificate: &store.RevocationListCertificate{ + Revocation: time.Now().UnixNano(), + }, + }, + wantErr: true, + }, + { + name: "missing revocation time", + args: args{ + id: id, + certificate: &store.RevocationListCertificate{ + Serial: "1", + }, }, wantErr: true, }, @@ -53,6 +65,7 @@ func TestAddRevocationListCertificate(t *testing.T) { certificate: &store.RevocationListCertificate{ Serial: "1", Expiration: pkgTime.UnixNano(time.Now().Add(time.Hour)), + Revocation: time.Now().UnixNano(), }, }, }, @@ -63,6 +76,7 @@ func TestAddRevocationListCertificate(t *testing.T) { certificate: &store.RevocationListCertificate{ Serial: "2", Expiration: pkgTime.UnixNano(time.Now().Add(time.Hour)), + Revocation: time.Now().UnixNano(), }, }, }, @@ -73,6 +87,7 @@ func TestAddRevocationListCertificate(t *testing.T) { certificate: &store.RevocationListCertificate{ Serial: "2", Expiration: pkgTime.UnixNano(time.Now().Add(time.Hour)), + Revocation: time.Now().UnixNano(), }, }, wantErr: true, @@ -84,6 +99,7 @@ func TestAddRevocationListCertificate(t *testing.T) { certificate: &store.RevocationListCertificate{ Serial: "2", Expiration: pkgTime.UnixNano(time.Now().Add(time.Hour)), + Revocation: time.Now().UnixNano(), }, }, }, @@ -91,7 +107,7 @@ func TestAddRevocationListCertificate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := s.AddRevocationListCertificate(ctx, tt.args.id, tt.args.certificate) + err := s.RevokeCertificates(ctx, tt.args.id, tt.args.certificate) if tt.wantErr { require.Error(t, err) return @@ -101,146 +117,78 @@ func TestAddRevocationListCertificate(t *testing.T) { } } -func TestRevokeAllCertificates(t *testing.T) { +func TestGetRevocationLists(t *testing.T) { s, cleanUpStore := test.NewMongoStore(t) defer cleanUpStore() ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() - expirationStart := time.Now().Add(-time.Hour) - inserted := test.AddRevocationListToStore(ctx, t, s, expirationStart) + stored := test.AddRevocationListToStore(ctx, t, s, time.Now().Add(-2*time.Hour-time.Minute)) - // revoke certificates - err := s.RevokeCertificates(ctx, store.CertificatesQuery{}) - require.NoError(t, err) - - // get expired certificates - expired := make(map[string]*store.RevocationList) - err = s.GetExpiredCertificates(ctx, store.ExpiredCertificatesQuery{}, func(v *store.RevocationList) error { - expired[v.Id] = v + // get all + retrieved := make(map[string]*store.RevocationList) + err := s.GetRevocationLists(ctx, store.RevocationListsQuery{ + IncludeExpired: true, + }, func(v *store.RevocationList) error { + retrieved[v.Id] = v return nil }) require.NoError(t, err) + test.CheckRevocationLists(t, stored, retrieved, false) - // get revoked certificates - revoked := make(map[string]*store.RevocationList) - err = s.GetRevokedCertificates(ctx, store.CertificatesQuery{}, func(v *store.RevocationList) error { - revoked[v.Id] = v + // get all non-expired + retrieved = make(map[string]*store.RevocationList) + err = s.GetRevocationLists(ctx, store.RevocationListsQuery{}, func(v *store.RevocationList) error { + retrieved[v.Id] = v return nil }) require.NoError(t, err) - - expiredExp := make(map[string]*store.RevocationList) - revokedExp := make(map[string]*store.RevocationList) - for id, v := range inserted { - for _, cert := range v.Certificates { - if cert.Expiration > pkgTime.UnixNano(time.Now()) { - if revokedExp[id] == nil { - revokedExp[id] = &store.RevocationList{ - Id: id, - } - } - revokedExp[id].Certificates = append(revokedExp[id].Certificates, cert) - continue - } - if expiredExp[id] == nil { - expiredExp[id] = &store.RevocationList{ - Id: id, - } - } - expiredExp[id].Certificates = append(expiredExp[id].Certificates, cert) + expected := make(map[string]*store.RevocationList) + for k, v := range stored { + if v.Certificates[0].Expiration < time.Now().UnixNano() { + continue } + expected[k] = v } + test.CheckRevocationLists(t, expected, retrieved, false) - require.Len(t, expired, len(expiredExp)) - require.EqualValues(t, expiredExp, expired) - test.CheckRevocationLists(t, revokedExp, revoked, true) -} - -func TestRevokeIssuersCertificates(t *testing.T) { - s, cleanUpStore := test.NewMongoStore(t) - defer cleanUpStore() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) - defer cancel() - - expirationStart := time.Now().Add(-time.Hour) - inserted := test.AddRevocationListToStore(ctx, t, s, expirationStart) - - // revoke certificates - issuerIdFilter := []string{test.GetIssuerID(0), test.GetIssuerID(7), test.GetIssuerID(9)} - err := s.RevokeCertificates(ctx, store.CertificatesQuery{ + // get selected issuers revocation lists + retrieved = make(map[string]*store.RevocationList) + issuerIdFilter := []string{test.GetIssuerID(0), test.GetIssuerID(2), test.GetIssuerID(4)} + err = s.GetRevocationLists(ctx, store.RevocationListsQuery{ IssuerIdFilter: issuerIdFilter, - }) - require.NoError(t, err) - - // get revoked certificates - revoked := make(map[string]*store.RevocationList) - err = s.GetRevokedCertificates(ctx, store.CertificatesQuery{}, func(v *store.RevocationList) error { - revoked[v.Id] = v + IncludeExpired: true, + }, func(v *store.RevocationList) error { + retrieved[v.Id] = v return nil }) require.NoError(t, err) - - revokedExp := make(map[string]*store.RevocationList) - for id, v := range inserted { - for _, cert := range v.Certificates { - if cert.Expiration <= pkgTime.UnixNano(time.Now()) || !slices.Contains(issuerIdFilter, id) { - continue - } - if revokedExp[id] == nil { - revokedExp[id] = &store.RevocationList{ - Id: id, - } - } - revokedExp[id].Certificates = append(revokedExp[id].Certificates, cert) + expected = make(map[string]*store.RevocationList) + for k, v := range stored { + if !slices.Contains(issuerIdFilter, v.Id) { + continue } + expected[k] = v } + test.CheckRevocationLists(t, expected, retrieved, false) - test.CheckRevocationLists(t, revokedExp, revoked, true) -} - -func TestRevokeCertificatesById(t *testing.T) { - s, cleanUpStore := test.NewMongoStore(t) - defer cleanUpStore() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) - defer cancel() - - expirationStart := time.Now().Add(-time.Hour) - inserted := test.AddRevocationListToStore(ctx, t, s, expirationStart) - - // revoke certificates - serialFilter := []string{test.GetCertificateSerial(1), test.GetCertificateSerial(3), test.GetCertificateSerial(5)} - err := s.RevokeCertificates(ctx, store.CertificatesQuery{ - SerialFilter: serialFilter, - }) - require.NoError(t, err) - - // get revoked certificates - revoked := make(map[string]*store.RevocationList) - err = s.GetRevokedCertificates(ctx, store.CertificatesQuery{}, func(v *store.RevocationList) error { - revoked[v.Id] = v + // get selected issuers revocation lists, only non-expired certificates + retrieved = make(map[string]*store.RevocationList) + issuerIdFilter = []string{test.GetIssuerID(0), test.GetIssuerID(2), test.GetIssuerID(4)} + err = s.GetRevocationLists(ctx, store.RevocationListsQuery{ + IssuerIdFilter: issuerIdFilter, + }, func(v *store.RevocationList) error { + retrieved[v.Id] = v return nil }) require.NoError(t, err) - - revokedExp := make(map[string]*store.RevocationList) - for id, v := range inserted { - for _, cert := range v.Certificates { - if cert.Expiration <= pkgTime.UnixNano(time.Now()) || - !slices.Contains(serialFilter, cert.Serial) { - continue - } - if revokedExp[id] == nil { - revokedExp[id] = &store.RevocationList{ - Id: id, - } - } - revokedExp[id].Certificates = append(revokedExp[id].Certificates, cert) + expected = make(map[string]*store.RevocationList) + for k, v := range stored { + if v.Certificates[0].Expiration < time.Now().UnixNano() || !slices.Contains(issuerIdFilter, v.Id) { + continue } + expected[k] = v } - - test.CheckRevocationLists(t, revokedExp, revoked, true) + test.CheckRevocationLists(t, expected, retrieved, false) } diff --git a/certificate-authority/store/mongodb/signingRecords.go b/certificate-authority/store/mongodb/signingRecords.go index 5ca3076e3..6a7b0bde6 100644 --- a/certificate-authority/store/mongodb/signingRecords.go +++ b/certificate-authority/store/mongodb/signingRecords.go @@ -124,40 +124,14 @@ func (s *Store) DeleteNonDeviceExpiredRecords(ctx context.Context, now time.Time return res.DeletedCount, nil } -func (s *Store) LoadSigningRecords(ctx context.Context, owner string, query *store.SigningRecordsQuery, h store.LoadSigningRecordsFunc) error { +func (s *Store) LoadSigningRecords(ctx context.Context, owner string, query *store.SigningRecordsQuery, p store.Process[store.SigningRecord]) error { col := s.Collection(signingRecordsCol) - iter, err := col.Find(ctx, toSigningRecordsQueryFilter(owner, query)) - if errors.Is(err, mongo.ErrNilDocument) { - return nil - } + cur, err := col.Find(ctx, toSigningRecordsQueryFilter(owner, query)) if err != nil { + if errors.Is(err, mongo.ErrNilDocument) { + return nil + } return err } - - i := SigningRecordsIterator{ - iter: iter, - } - err = h(ctx, &i) - - errClose := iter.Close(ctx) - if err == nil { - return errClose - } - return err -} - -type SigningRecordsIterator struct { - iter *mongo.Cursor -} - -func (i *SigningRecordsIterator) Next(ctx context.Context, s *store.SigningRecord) bool { - if !i.iter.Next(ctx) { - return false - } - err := i.iter.Decode(s) - return err == nil -} - -func (i *SigningRecordsIterator) Err() error { - return i.iter.Err() + return processCursor(ctx, cur, p) } diff --git a/certificate-authority/store/mongodb/signingRecords_test.go b/certificate-authority/store/mongodb/signingRecords_test.go index 87bf805a4..e0bb3e202 100644 --- a/certificate-authority/store/mongodb/signingRecords_test.go +++ b/certificate-authority/store/mongodb/signingRecords_test.go @@ -288,15 +288,9 @@ type testSigningRecordHandler struct { lcs pb.SigningRecords } -func (h *testSigningRecordHandler) Handle(ctx context.Context, iter store.SigningRecordIter) (err error) { - for { - var sub store.SigningRecord - if !iter.Next(ctx, &sub) { - break - } - h.lcs = append(h.lcs, &sub) - } - return iter.Err() +func (h *testSigningRecordHandler) process(sr *store.SigningRecord) (err error) { + h.lcs = append(h.lcs, sr) + return nil } func TestStoreLoadSigningRecords(t *testing.T) { @@ -461,7 +455,7 @@ func TestStoreLoadSigningRecords(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var h testSigningRecordHandler - err := s.LoadSigningRecords(ctx, tt.args.owner, tt.args.query, h.Handle) + err := s.LoadSigningRecords(ctx, tt.args.owner, tt.args.query, h.process) require.NoError(t, err) require.Len(t, h.lcs, len(tt.want)) h.lcs.Sort() diff --git a/certificate-authority/store/revocationList.go b/certificate-authority/store/revocationList.go index d276dbe93..aae816f58 100644 --- a/certificate-authority/store/revocationList.go +++ b/certificate-authority/store/revocationList.go @@ -13,21 +13,44 @@ type RevocationListCertificate struct { // Serial number Serial string `bson:"serial"` // Expiration date of the certificate, in unix nanoseconds timestamp format - Expiration int64 `bson:"expiration"` + Expiration int64 `bson:"expiration,omitempty"` // Revocation date of the certificate, in unix nanoseconds timestamp format. 0 means that the certificate hasn't been revoked. - Revocation int64 `bson:"revocation,omitempty"` + Revocation int64 `bson:"revocation"` +} + +func (rlc *RevocationListCertificate) Validate() error { + if rlc.Serial == "" { + return errors.New("serial number not set") + } + if rlc.Revocation == 0 { + return errors.New("revocation time not set") + } + return nil } type RevocationList struct { // The record ID is determined by applying a formula that utilizes the public key of the issuer, and it is computed as uuid.NewSHA1(uuid.NameSpaceX500, publicKeyRaw). Id string `bson:"_id"` + // Unix timestamp when the record was last updated + UpdatedAt int64 `bson:"updatedAt"` + // Number is used to populate the X.509 v2 cRLNumber extension in the CRL, which should be a monotonically increasing sequence number for a given + // CRL scope and CRL issuer. + Number int64 `bson:"number"` // List of certificates issued by the issuer Certificates []*RevocationListCertificate `bson:"certificates"` } -func (rlc *RevocationListCertificate) Validate() error { - if rlc.Serial == "" { - return errors.New("serial number not set") +func (rl *RevocationList) Validate() error { + if rl.Id == "" { + return errors.New("id not set") + } + if rl.UpdatedAt == 0 { + return errors.New("update time not set") + } + for _, c := range rl.Certificates { + if err := c.Validate(); err != nil { + return err + } } return nil } diff --git a/certificate-authority/store/store.go b/certificate-authority/store/store.go index 758934ad8..bbfa9bd6c 100644 --- a/certificate-authority/store/store.go +++ b/certificate-authority/store/store.go @@ -11,35 +11,17 @@ import ( var ErrNotSupported = errors.New("not supported") type ( + Process[T any] func(v *T) error + SigningRecordsQuery = pb.GetSigningRecordsRequest DeleteSigningRecordsQuery = pb.DeleteSigningRecordsRequest - SigningRecordIter interface { - Next(ctx context.Context, SigningRecord *SigningRecord) bool - Err() error - } - - LoadSigningRecordsFunc = func(ctx context.Context, iter SigningRecordIter) (err error) -) - -type ( - CertificatesQuery struct { + RevocationListsQuery struct { // Filter by issuer's id (must match RevocationList.Id) IssuerIdFilter []string - // Filter by serial number. - SerialFilter []string + // Include expired certificates + IncludeExpired bool } - ExpiredCertificatesQuery struct { - // Filter by issuer's id (must match RevocationList.Id) - IssuerIdFilter []string - } - - RevocationListIter interface { - Next(ctx context.Context, rl *RevocationList) bool - Err() error - } - - Process[T any] func(v *T) error ) type Store interface { @@ -48,18 +30,18 @@ type Store interface { // UpdateSigningRecord updates an existing signing record. If the record does not exist, it will create a new one. UpdateSigningRecord(ctx context.Context, record *SigningRecord) (*SigningRecord, error) DeleteSigningRecords(ctx context.Context, ownerID string, query *DeleteSigningRecordsQuery) (int64, error) - LoadSigningRecords(ctx context.Context, ownerID string, query *SigningRecordsQuery, h LoadSigningRecordsFunc) error + LoadSigningRecords(ctx context.Context, ownerID string, query *SigningRecordsQuery, p Process[SigningRecord]) error // DeleteNonDeviceExpiredRecords deletes all expired records that are not associated with a device. // For CqlDB, this is a no-op because expired records are deleted by Cassandra automatically. DeleteNonDeviceExpiredRecords(ctx context.Context, now time.Time) (int64, error) + // InsertRevocationLists adds revocations lists to the database InsertRevocationLists(ctx context.Context, rls ...*RevocationList) error - AddRevocationListCertificate(ctx context.Context, id string, rlc *RevocationListCertificate) error - RevokeCertificates(ctx context.Context, query CertificatesQuery) error - GetExpiredCertificates(ctx context.Context, query ExpiredCertificatesQuery, p Process[RevocationList]) error - GetRevokedCertificates(ctx context.Context, query CertificatesQuery, p Process[RevocationList]) error - DeleteExpiredCertificates(ctx context.Context) error + // RevokeCertificates adds certificates to revocation list + RevokeCertificates(ctx context.Context, id string, rlc ...*RevocationListCertificate) error + // Get + GetRevocationLists(ctx context.Context, query RevocationListsQuery, p Process[RevocationList]) error Close(ctx context.Context) error } diff --git a/certificate-authority/test/revocationList.go b/certificate-authority/test/revocationList.go index 5e5b0b6f5..f756a25ee 100644 --- a/certificate-authority/test/revocationList.go +++ b/certificate-authority/test/revocationList.go @@ -32,10 +32,11 @@ func GetCertificateSerial(i int) string { return id } -func getCertificate(c int, exp time.Time) *store.RevocationListCertificate { +func getCertificate(c int, rev, exp time.Time) *store.RevocationListCertificate { return &store.RevocationListCertificate{ Serial: GetCertificateSerial(c), Expiration: pkgTime.UnixNano(exp), + Revocation: pkgTime.UnixNano(rev), } } @@ -43,13 +44,16 @@ func AddRevocationListToStore(ctx context.Context, t *testing.T, s store.Store, rlm := make(map[string]*store.RevocationList) c := 0 for i := range 10 { + now := time.Now() rlID := GetIssuerID(i) rl := &store.RevocationList{ - Id: rlID, + Id: rlID, + UpdatedAt: now.UnixNano(), + Number: int64(i), } exp := expirationStart.Add(time.Duration(i) * time.Hour) for range 10 { - rlc := getCertificate(c, exp) + rlc := getCertificate(c, now, exp) rl.Certificates = append(rl.Certificates, rlc) c++ }