diff --git a/deployment/modules/gcp/gcs/main.tf b/deployment/modules/gcp/gcs/main.tf
index e01b5f26..e5351203 100644
--- a/deployment/modules/gcp/gcs/main.tf
+++ b/deployment/modules/gcp/gcs/main.tf
@@ -58,7 +58,7 @@ resource "google_spanner_database" "log_db" {
   ddl = [
     "CREATE TABLE SeqCoord (id INT64 NOT NULL, next INT64 NOT NULL,) PRIMARY KEY (id)",
     "CREATE TABLE Seq (id INT64 NOT NULL, seq INT64 NOT NULL, v BYTES(MAX),) PRIMARY KEY (id, seq)",
-    "CREATE TABLE IntCoord (id INT64 NOT NULL, seq INT64 NOT NULL,) PRIMARY KEY (id)",
+    "CREATE TABLE IntCoord (id INT64 NOT NULL, seq INT64 NOT NULL, rootHash BYTES(32)) PRIMARY KEY (id)",
   ]
 
   deletion_protection = !var.ephemeral
diff --git a/storage/gcp/gcp.go b/storage/gcp/gcp.go
index e3edc35c..8ec90bfe 100644
--- a/storage/gcp/gcp.go
+++ b/storage/gcp/gcp.go
@@ -43,6 +43,7 @@ import (
 	"cloud.google.com/go/spanner/apiv1/spannerpb"
 	gcs "cloud.google.com/go/storage"
 	"github.com/google/go-cmp/cmp"
+	"github.com/transparency-dev/merkle/rfc6962"
 	tessera "github.com/transparency-dev/trillian-tessera"
 	"github.com/transparency-dev/trillian-tessera/api"
 	"github.com/transparency-dev/trillian-tessera/api/layout"
@@ -73,12 +74,15 @@ type Storage struct {
 	objStore objStore
 
 	queue *storage.Queue
+
+	cpUpdated chan struct{}
 }
 
 // objStore describes a type which can store and retrieve objects.
 type objStore interface {
 	getObject(ctx context.Context, obj string) ([]byte, int64, error)
 	setObject(ctx context.Context, obj string, data []byte, cond *gcs.Conditions, contType string) error
+	lastModified(ctx context.Context, obj string) (time.Time, error)
 }
 
 // sequencer describes a type which knows how to sequence entries.
@@ -92,10 +96,14 @@
 	// If forceUpdate is true, then the consumeFunc should be called, with an empty slice of entries if
 	// necessary. This allows the log to self-initialise in a transactionally safe manner.
 	consumeEntries(ctx context.Context, limit uint64, f consumeFunc, forceUpdate bool) (bool, error)
+	// currentTree returns the sequencer's view of the current tree state.
+	currentTree(ctx context.Context) (uint64, []byte, error)
 }
 
-// consumeFunc is the signature of a function which can consume entries from the sequencer.
-type consumeFunc func(ctx context.Context, from uint64, entries []storage.SequencedEntry) error
+// consumeFunc is the signature of a function which can consume entries from the sequencer and integrate
+// them into the log.
+// Returns the new rootHash once all passed entries have been integrated.
+type consumeFunc func(ctx context.Context, from uint64, entries []storage.SequencedEntry) ([]byte, error)
 
 // Config holds GCP project and resource configuration for a storage instance.
 type Config struct {
@@ -130,6 +138,7 @@ func New(ctx context.Context, cfg Config, opts ...func(*options.StorageOptions))
 		sequencer:   seq,
 		newCP:       opt.NewCP,
 		entriesPath: opt.EntriesPath,
+		cpUpdated:   make(chan struct{}),
 	}
 	r.queue = storage.NewQueue(ctx, opt.BatchMaxAge, opt.BatchMaxSize, r.sequencer.assignEntries)
 
@@ -154,11 +163,29 @@ func New(ctx context.Context, cfg Config, opts ...func(*options.StorageOptions))
 			if _, err := r.sequencer.consumeEntries(cctx, DefaultIntegrationSizeLimit, r.integrate, false); err != nil {
 				klog.Errorf("integrate: %v", err)
+				return
+			}
+			select {
+			case r.cpUpdated <- struct{}{}:
+			default:
 			}
 		}()
 	}
 	}()
 
+	go func(ctx context.Context, i time.Duration) {
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-r.cpUpdated:
+			case <-time.After(i):
+			}
+			if err := r.publishCheckpoint(ctx, i); err != nil {
+				klog.Warningf("publishCheckpoint: %v", err)
+			}
+		}
+	}(ctx, opt.CheckpointInterval)
+
 	return r, nil
 }
@@ -193,13 +220,17 @@ func (s *Storage) init(ctx context.Context) error {
 	if err != nil {
 		if errors.Is(err, gcs.ErrObjectNotExist) {
 			// No checkpoint exists, do a forced (possibly empty) integration to create one in a safe
-			// way (calling updateCP directly here would not be safe as it's outside the transactional
+			// way (setting the checkpoint directly here would not be safe as it's outside the transactional
 			// framework which prevents the tree from rolling backwards or otherwise forking).
 			cctx, c := context.WithTimeout(ctx, 10*time.Second)
 			defer c()
 			if _, err := s.sequencer.consumeEntries(cctx, DefaultIntegrationSizeLimit, s.integrate, true); err != nil {
 				return fmt.Errorf("forced integrate: %v", err)
 			}
+			select {
+			case s.cpUpdated <- struct{}{}:
+			default:
+			}
 			return nil
 		}
 		return fmt.Errorf("failed to read checkpoint: %v", err)
@@ -208,11 +239,24 @@ func (s *Storage) init(ctx context.Context) error {
 	return nil
 }
 
-func (s *Storage) updateCP(ctx context.Context, newSize uint64, newRoot []byte) error {
-	cpRaw, err := s.newCP(newSize, newRoot)
+func (s *Storage) publishCheckpoint(ctx context.Context, minStaleness time.Duration) error {
+	m, err := s.objStore.lastModified(ctx, layout.CheckpointPath)
+	if err != nil && !errors.Is(err, gcs.ErrObjectNotExist) {
+		return fmt.Errorf("lastModified(%q): %v", layout.CheckpointPath, err)
+	}
+	if time.Since(m) < minStaleness {
+		return nil
+	}
+
+	size, root, err := s.sequencer.currentTree(ctx)
+	if err != nil {
+		return fmt.Errorf("currentTree: %v", err)
+	}
+	cpRaw, err := s.newCP(size, root)
 	if err != nil {
 		return fmt.Errorf("newCP: %v", err)
 	}
+
 	if err := s.objStore.setObject(ctx, layout.CheckpointPath, cpRaw, nil, ckptContType); err != nil {
 		return fmt.Errorf("writeCheckpoint: %v", err)
 	}
@@ -301,7 +345,9 @@ func (s *Storage) setEntryBundle(ctx context.Context, bundleIndex uint64, logSiz
 }
 
 // integrate incorporates the provided entries into the log starting at fromSeq.
-func (s *Storage) integrate(ctx context.Context, fromSeq uint64, entries []storage.SequencedEntry) error {
+func (s *Storage) integrate(ctx context.Context, fromSeq uint64, entries []storage.SequencedEntry) ([]byte, error) {
+	var newRoot []byte
+
 	errG := errgroup.Group{}
 
 	errG.Go(func() error {
@@ -319,11 +365,16 @@ func (s *Storage) integrate(ctx context.Context, fromSeq uint64, entries []stora
 			}
 			return n, nil
 		}
-		newSize, newRoot, tiles, err := storage.Integrate(ctx, getTiles, fromSeq, entries)
+		newSize, root, tiles, err := storage.Integrate(ctx, getTiles, fromSeq, entries)
 		if err != nil {
 			return fmt.Errorf("Integrate: %v", err)
 		}
+		if newSize > 0 {
+			newRoot = root
+		} else {
+			newRoot = rfc6962.DefaultHasher.EmptyRoot()
+		}
 		for k, v := range tiles {
 			func(ctx context.Context, k storage.TileID, v *api.HashTile) {
 				errG.Go(func() error {
@@ -331,18 +382,12 @@ func (s *Storage) integrate(ctx context.Context, fromSeq uint64, entries []stora
 				})
 			}(ctx, k, v)
 		}
-		errG.Go(func() error {
-			klog.Infof("New CP: %d, %x", newSize, newRoot)
-			if s.newCP != nil {
-				return s.updateCP(ctx, newSize, newRoot)
-			}
-			return nil
-		})
+		klog.Infof("New tree: %d, %x", newSize, newRoot)
 		return nil
 	})
 
-	return errG.Wait()
+	return newRoot, errG.Wait()
 }
 
 // updateEntryBundles adds the entries being integrated into the entry bundles.
@@ -466,6 +511,7 @@ func (s *spannerSequencer) initDB(ctx context.Context) error {
 	CREATE TABLE IntCoord (
 		id INT64 NOT NULL,
 		seq INT64 NOT NULL,
+		rootHash BYTES(32) NOT NULL,
 	) PRIMARY KEY (id);
 	*/
@@ -476,7 +522,7 @@ func (s *spannerSequencer) initDB(ctx context.Context) error {
 	if _, err := s.dbPool.Apply(ctx, []*spanner.Mutation{spanner.Insert("SeqCoord", []string{"id", "next"}, []interface{}{0, 0})}); err != nil && spanner.ErrCode(err) != codes.AlreadyExists {
 		return err
 	}
-	if _, err := s.dbPool.Apply(ctx, []*spanner.Mutation{spanner.Insert("IntCoord", []string{"id", "seq"}, []interface{}{0, 0})}); err != nil && spanner.ErrCode(err) != codes.AlreadyExists {
+	if _, err := s.dbPool.Apply(ctx, []*spanner.Mutation{spanner.Insert("IntCoord", []string{"id", "seq", "rootHash"}, []interface{}{0, 0, rfc6962.DefaultHasher.EmptyRoot()})}); err != nil && spanner.ErrCode(err) != codes.AlreadyExists {
 		return err
 	}
 	return nil
@@ -569,12 +615,13 @@ func (s *spannerSequencer) consumeEntries(ctx context.Context, limit uint64, f c
 	didWork := false
 	_, err := s.dbPool.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
 		// Figure out the starting index of sequenced entries to start consuming from.
-		row, err := txn.ReadRowWithOptions(ctx, "IntCoord", spanner.Key{0}, []string{"seq"}, &spanner.ReadOptions{LockHint: spannerpb.ReadRequest_LOCK_HINT_EXCLUSIVE})
+		row, err := txn.ReadRowWithOptions(ctx, "IntCoord", spanner.Key{0}, []string{"seq", "rootHash"}, &spanner.ReadOptions{LockHint: spannerpb.ReadRequest_LOCK_HINT_EXCLUSIVE})
 		if err != nil {
 			return err
 		}
 		var fromSeq int64 // Spanner doesn't support uint64
-		if err := row.Column(0, &fromSeq); err != nil {
+		var rootHash []byte
+		if err := row.Columns(&fromSeq, &rootHash); err != nil {
 			return fmt.Errorf("failed to read integration coordination info: %v", err)
 		}
 		klog.V(1).Infof("Consuming from %d", fromSeq)
@@ -620,14 +667,15 @@ func (s *spannerSequencer) consumeEntries(ctx context.Context, limit uint64, f c
 		}
 
 		// Call consumeFunc with the entries we've found.
-		if err := f(ctx, uint64(fromSeq), entries); err != nil {
+		newRoot, err := f(ctx, uint64(fromSeq), entries)
+		if err != nil {
 			return err
 		}
 
 		// consumeFunc was successful, so we can update our coordination row, and delete the row(s) for
 		// the now-consumed entries.
 		m := make([]*spanner.Mutation, 0)
-		m = append(m, spanner.Update("IntCoord", []string{"id", "seq"}, []interface{}{0, int64(orderCheck)}))
+		m = append(m, spanner.Update("IntCoord", []string{"id", "seq", "rootHash"}, []interface{}{0, int64(orderCheck), newRoot}))
 		for _, c := range seqsConsumed {
 			m = append(m, spanner.Delete("Seq", spanner.Key{0, c}))
 		}
@@ -647,6 +695,21 @@ func (s *spannerSequencer) consumeEntries(ctx context.Context, limit uint64, f c
 	return didWork, nil
 }
 
+// currentTree returns the size and root hash of the currently integrated tree.
+func (s *spannerSequencer) currentTree(ctx context.Context) (uint64, []byte, error) {
+	row, err := s.dbPool.Single().ReadRow(ctx, "IntCoord", spanner.Key{0}, []string{"seq", "rootHash"})
+	if err != nil {
+		return 0, nil, fmt.Errorf("failed to read IntCoord: %v", err)
+	}
+	var fromSeq int64 // Spanner doesn't support uint64
+	var rootHash []byte
+	if err := row.Columns(&fromSeq, &rootHash); err != nil {
+		return 0, nil, fmt.Errorf("failed to read integration coordination info: %v", err)
+	}
+
+	return uint64(fromSeq), rootHash, nil
+}
+
 // gcsStorage knows how to store and retrieve objects from GCS.
 type gcsStorage struct {
 	bucket string
@@ -713,3 +776,11 @@ func (s *gcsStorage) setObject(ctx context.Context, objName string, data []byte,
 	}
 	return nil
 }
+
+func (s *gcsStorage) lastModified(ctx context.Context, obj string) (time.Time, error) {
+	r, err := s.gcsClient.Bucket(s.bucket).Object(obj).NewReader(ctx)
+	if err != nil {
+		return time.Time{}, fmt.Errorf("failed to create reader for object %q in bucket %q: %w", obj, s.bucket, err)
+	}
+	return r.Attrs.LastModified, r.Close()
+}
diff --git a/storage/gcp/gcp_test.go b/storage/gcp/gcp_test.go
index fce99317..cdfe62a2 100644
--- a/storage/gcp/gcp_test.go
+++ b/storage/gcp/gcp_test.go
@@ -24,6 +24,7 @@ import (
 	"reflect"
 	"sync"
 	"testing"
+	"time"
 
 	"cloud.google.com/go/spanner/spannertest"
 	"cloud.google.com/go/spanner/spansql"
@@ -32,7 +33,7 @@ import (
 	tessera "github.com/transparency-dev/trillian-tessera"
 	"github.com/transparency-dev/trillian-tessera/api"
 	"github.com/transparency-dev/trillian-tessera/api/layout"
-	"github.com/transparency-dev/trillian-tessera/storage/internal"
+	storage "github.com/transparency-dev/trillian-tessera/storage/internal"
 )
 
 func newSpannerDB(t *testing.T) func() {
@@ -45,7 +46,7 @@ func newSpannerDB(t *testing.T) func() {
 	dml, err := spansql.ParseDDL("", `
 	CREATE TABLE SeqCoord (id INT64 NOT NULL, next INT64 NOT NULL,) PRIMARY KEY (id);
 	CREATE TABLE Seq (id INT64 NOT NULL, seq INT64 NOT NULL, v BYTES(MAX),) PRIMARY KEY (id, seq);
-	CREATE TABLE IntCoord (id INT64 NOT NULL, seq INT64 NOT NULL,) PRIMARY KEY (id);
+	CREATE TABLE IntCoord (id INT64 NOT NULL, seq INT64 NOT NULL, rootHash BYTES(32) NOT NULL,) PRIMARY KEY (id);
 	`)
 	if err != nil {
 		t.Fatalf("Invalid DDL: %v", err)
@@ -170,18 +171,18 @@ func TestSpannerSequencerRoundTrip(t *testing.T) {
 	}
 
 	seenIdx := uint64(0)
-	f := func(_ context.Context, fromSeq uint64, entries []storage.SequencedEntry) error {
+	f := func(_ context.Context, fromSeq uint64, entries []storage.SequencedEntry) ([]byte, error) {
 		if fromSeq != seenIdx {
-			return fmt.Errorf("f called with fromSeq %d, want %d", fromSeq, seenIdx)
+			return nil, fmt.Errorf("f called with fromSeq %d, want %d", fromSeq, seenIdx)
 		}
 		for i, e := range entries {
 			if got, want := e, wantEntries[i]; !reflect.DeepEqual(got, want) {
-				return fmt.Errorf("entry %d+%d != %d", fromSeq, i, seenIdx)
+				return nil, fmt.Errorf("entry %d+%d != %d", fromSeq, i, seenIdx)
 			}
 			seenIdx++
 		}
-		return nil
+		return []byte(fmt.Sprintf("root<%d>", seenIdx)), nil
 	}
 
 	more, err := s.consumeEntries(ctx, 7, f, false)
@@ -304,9 +305,72 @@ func TestBundleRoundtrip(t *testing.T) {
 	}
 }
 
+func TestPublishCheckpoint(t *testing.T) {
+	ctx := context.Background()
+
+	close := newSpannerDB(t)
+	defer close()
+
+	s, err := newSpannerSequencer(ctx, "projects/p/instances/i/databases/d", 1000)
+	if err != nil {
+		t.Fatalf("newSpannerSequencer: %v", err)
+	}
+
+	for _, test := range []struct {
+		name            string
+		cpModifiedAt    time.Time
+		publishInterval time.Duration
+		wantUpdate      bool
+	}{
+		{
+			name:            "works ok",
+			cpModifiedAt:    time.Now().Add(-15 * time.Second),
+			publishInterval: 10 * time.Second,
+			wantUpdate:      true,
+		}, {
+			name:            "too soon, skip update",
+			cpModifiedAt:    time.Now().Add(-5 * time.Second),
+			publishInterval: 10 * time.Second,
+			wantUpdate:      false,
+		},
+	} {
+		t.Run(test.name, func(t *testing.T) {
+			m := newMemObjStore()
+			storage := &Storage{
+				objStore:    m,
+				sequencer:   s,
+				entriesPath: layout.EntriesPath,
+				newCP:       func(size uint64, hash []byte) ([]byte, error) { return []byte(fmt.Sprintf("%d/%x,", size, hash)), nil },
+			}
+			// Call init so we've got a zero-sized checkpoint to work with.
+			if err := storage.init(ctx); err != nil {
+				t.Fatalf("storage.init: %v", err)
+			}
+			m.delete(layout.CheckpointPath)
+			m.lMod = test.cpModifiedAt
+			if err := storage.publishCheckpoint(ctx, test.publishInterval); err != nil {
+				t.Fatalf("publishCheckpoint: %v", err)
+			}
+			_, _, err := m.getObject(ctx, layout.CheckpointPath)
+			cpUpdated := true
+			if err != nil {
+				if !errors.Is(err, gcs.ErrObjectNotExist) {
+					t.Fatalf("getObject: %v", err)
+				}
+				cpUpdated = false
+			}
+			if test.wantUpdate != cpUpdated {
+				t.Fatalf("got cpUpdated=%t, want %t", cpUpdated, test.wantUpdate)
+			}
+		})
+	}
+
+}
+
 type memObjStore struct {
 	sync.RWMutex
-	mem map[string][]byte
+	mem  map[string][]byte
+	lMod time.Time
 }
 
 func newMemObjStore() *memObjStore {
@@ -343,3 +407,13 @@ func (m *memObjStore) setObject(_ context.Context, obj string, data []byte, cond
 	m.mem[obj] = data
 	return nil
 }
+
+func (m *memObjStore) delete(obj string) {
+	m.Lock()
+	defer m.Unlock()
+	delete(m.mem, obj)
+}
+
+func (m *memObjStore) lastModified(_ context.Context, obj string) (time.Time, error) {
+	return m.lMod, nil
+}
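
Reviewer note: the net effect of the change is that integrate() no longer writes checkpoints at all; it transactionally records (seq, rootHash) in IntCoord, a successful integration nudges a dedicated publisher goroutine via a non-blocking send on cpUpdated, and that goroutine writes at most one checkpoint per CheckpointInterval (enforced twice: by the channel/timer wakeup and by publishCheckpoint's minStaleness gate against the object's last-modified time). The sketch below is illustrative only, not part of the diff; the names updated, interval, publish, and lastPublished are made up here, standing in for cpUpdated, CheckpointInterval, and publishCheckpoint.

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	updated := make(chan struct{}) // plays the role of cpUpdated
	interval := 500 * time.Millisecond

	var lastPublished time.Time // only touched by the publisher goroutine
	publish := func() {
		// Staleness gate, analogous to publishCheckpoint's minStaleness check:
		// if we published recently, do nothing.
		if time.Since(lastPublished) < interval {
			return
		}
		lastPublished = time.Now()
		fmt.Println("published at", lastPublished.Format("15:04:05.000"))
	}

	done := make(chan struct{})
	go func() {
		defer close(done)
		for {
			select {
			case <-ctx.Done():
				return
			case <-updated: // nudged by a successful integration
			case <-time.After(interval): // or wake periodically regardless
			}
			publish()
		}
	}()

	// Simulate rapid integrations. The non-blocking send means a busy
	// publisher drops redundant nudges instead of stalling integration.
	for i := 0; i < 50; i++ {
		select {
		case updated <- struct{}{}:
		default:
		}
		time.Sleep(20 * time.Millisecond)
	}
	cancel()
	<-done
}

One further wrinkle the diff handles: when the tree is empty, storage.Integrate has no root to report, so integrate() falls back to rfc6962.DefaultHasher.EmptyRoot() (the RFC 6962 empty-tree root, i.e. SHA-256 of the empty string), which is also the value initDB seeds into IntCoord. That way the very first, size-zero checkpoint still carries a well-defined root hash.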