diff --git a/.vscode/extensions.json b/.vscode/extensions.json index a72345bb..449f697e 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -3,7 +3,7 @@ // Extension identifier format: ${publisher}.${name}. Example: vscode.csharp // List of extensions which should be recommended for users of this workspace. - "recommendations": ["bradymholt.pgformatter", "golang.go"], + "recommendations": ["bradymholt.pgformatter", "golang.go", "emeraldwalk.runonsave"], // List of extensions recommended by VS Code that should not be recommended for users of this workspace. "unwantedRecommendations": [] } diff --git a/.vscode/settings.json b/.vscode/settings.json index 12fb01d4..81a48377 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -5,5 +5,14 @@ "pgFormatter.tabs": true, "[sql]": { "editor.defaultFormatter": "bradymholt.pgformatter" + }, + // Instructions from https://github.com/segmentio/golines + "emeraldwalk.runonsave": { + "commands": [ + { + "match": "\\.go$", + "cmd": "golines ${file} -w" + } + ] } } diff --git a/README.md b/README.md index c4b39c5d..565bff4f 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,10 @@ The `xmtpd` node build provides two options for monitoring your node. To learn how to visualize node data in Grafana, see [Prometheus Histograms with Grafana Heatmaps](https://towardsdatascience.com/prometheus-histograms-with-grafana-heatmaps-d556c28612c7) and [How to visualize Prometheus histograms in Grafana](https://grafana.com/blog/2020/06/23/how-to-visualize-prometheus-histograms-in-grafana/). +# Contributing + +Please follow the [style guide](https://google.github.io/styleguide/go/decisions). + ## Modifying the protobuf schema Submit and land a PR to https://github.com/xmtp/proto. Then run: diff --git a/cmd/replication/main.go b/cmd/replication/main.go index e1906b3d..1fc47020 100644 --- a/cmd/replication/main.go +++ b/cmd/replication/main.go @@ -69,7 +69,7 @@ func addEnvVars() { } if connStr, hasConnstr := os.LookupEnv("READER_DB_CONNECTION_STRING"); hasConnstr { - options.DB.WriterConnectionString = connStr + options.DB.ReaderConnectionString = connStr } if privKey, hasPrivKey := os.LookupEnv("PRIVATE_KEY"); hasPrivKey { diff --git a/dev/up b/dev/up index f3b186a8..ade11116 100755 --- a/dev/up +++ b/dev/up @@ -8,6 +8,7 @@ if ! which golangci-lint &>/dev/null; then brew install golangci-lint; fi if ! which shellcheck &>/dev/null; then brew install shellcheck; fi if ! which mockery &>/dev/null; then brew install mockery; fi if ! which sqlc &> /dev/null; then brew install sqlc; fi +if ! which golines &>/dev/null; then go install github.com/segmentio/golines@latest; fi dev/generate dev/docker/up diff --git a/pkg/api/server.go b/pkg/api/server.go index b60817c7..f0f0d4b6 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -2,6 +2,7 @@ package api import ( "context" + "database/sql" "fmt" "net" "strings" @@ -20,24 +21,26 @@ import ( ) type ApiServer struct { + ctx context.Context + db *sql.DB + grpcListener net.Listener log *zap.Logger + service message_api.ReplicationApiServer wg sync.WaitGroup - grpcListener net.Listener - ctx context.Context - service *message_api.ReplicationApiServer } -func NewAPIServer(ctx context.Context, log *zap.Logger, port int) (*ApiServer, error) { +func NewAPIServer(ctx context.Context, writerDB *sql.DB, log *zap.Logger, port int) (*ApiServer, error) { grpcListener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", port)) if err != nil { return nil, err } s := &ApiServer{ - log: log.Named("api"), ctx: ctx, - wg: sync.WaitGroup{}, + db: writerDB, grpcListener: &proxyproto.Listener{Listener: grpcListener, ReadHeaderTimeout: 10 * time.Second}, + log: log.Named("api"), + wg: sync.WaitGroup{}, } // TODO: Add interceptors @@ -58,11 +61,11 @@ func NewAPIServer(ctx context.Context, log *zap.Logger, port int) (*ApiServer, e healthcheck := health.NewServer() healthgrpc.RegisterHealthServer(grpcServer, healthcheck) - replicationService, err := NewReplicationApiService(ctx, log) + replicationService, err := NewReplicationApiService(ctx, log, writerDB) if err != nil { return nil, err } - s.service = &replicationService + s.service = replicationService tracing.GoPanicWrap(s.ctx, &s.wg, "grpc", func(ctx context.Context) { s.log.Info("serving grpc", zap.String("address", s.grpcListener.Addr().String())) diff --git a/pkg/api/service.go b/pkg/api/service.go index 6cb9d13a..e2ccdaf1 100644 --- a/pkg/api/service.go +++ b/pkg/api/service.go @@ -2,10 +2,14 @@ package api import ( "context" + "database/sql" + "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" + "github.com/xmtp/xmtpd/pkg/utils" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" "go.uber.org/zap" ) @@ -13,18 +17,67 @@ import ( type Service struct { message_api.UnimplementedReplicationApiServer - ctx context.Context - log *zap.Logger + ctx context.Context + log *zap.Logger + queries *queries.Queries } -func NewReplicationApiService(ctx context.Context, log *zap.Logger) (message_api.ReplicationApiServer, error) { - return &Service{ctx: ctx, log: log}, nil +func NewReplicationApiService( + ctx context.Context, + log *zap.Logger, + writerDB *sql.DB, +) (*Service, error) { + return &Service{ctx: ctx, log: log, queries: queries.New(writerDB)}, nil } -func (s *Service) BatchSubscribeEnvelopes(req *message_api.BatchSubscribeEnvelopesRequest, server message_api.ReplicationApi_BatchSubscribeEnvelopesServer) error { +func (s *Service) Close() { + s.log.Info("closed") +} + +func (s *Service) BatchSubscribeEnvelopes( + req *message_api.BatchSubscribeEnvelopesRequest, + server message_api.ReplicationApi_BatchSubscribeEnvelopesServer, +) error { return status.Errorf(codes.Unimplemented, "method BatchSubscribeEnvelopes not implemented") } -func (s *Service) QueryEnvelopes(ctx context.Context, req *message_api.QueryEnvelopesRequest) (*message_api.QueryEnvelopesResponse, error) { +func (s *Service) QueryEnvelopes( + ctx context.Context, + req *message_api.QueryEnvelopesRequest, +) (*message_api.QueryEnvelopesResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method QueryEnvelopes not implemented") } + +func (s *Service) PublishEnvelope( + ctx context.Context, + req *message_api.PublishEnvelopeRequest, +) (*message_api.PublishEnvelopeResponse, error) { + payerEnv := req.GetPayerEnvelope() + clientBytes := payerEnv.GetUnsignedClientEnvelope() + sig := payerEnv.GetPayerSignature() + if (clientBytes == nil) || (sig == nil) { + return nil, status.Errorf(codes.InvalidArgument, "missing envelope or signature") + } + // TODO(rich): Verify payer signature + // TODO(rich): Verify all originators have synced past `last_originator_sids` + // TODO(rich): Check that the blockchain sequence ID is equal to the latest on the group + // TODO(rich): Perform any payload-specific validation (e.g. identity updates) + // TODO(rich): If it is a commit, publish it to blockchain instead + + payerBytes, err := proto.Marshal(payerEnv) + if err != nil { + return nil, status.Errorf(codes.Internal, "could not marshal envelope: %v", err) + } + + stagedEnv, err := s.queries.InsertStagedOriginatorEnvelope(ctx, payerBytes) + if err != nil { + return nil, status.Errorf(codes.Internal, "could not insert staged envelope: %v", err) + } + + originatorEnv, err := utils.SignStagedEnvelope(stagedEnv) + if err != nil { + return nil, status.Errorf(codes.Internal, "could not sign envelope: %v", err) + } + + return &message_api.PublishEnvelopeResponse{OriginatorEnvelope: originatorEnv}, nil +} diff --git a/pkg/api/service_test.go b/pkg/api/service_test.go new file mode 100644 index 00000000..bfb4fa34 --- /dev/null +++ b/pkg/api/service_test.go @@ -0,0 +1,53 @@ +package api + +import ( + "context" + "database/sql" + "testing" + + "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/proto/identity/associations" + "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" + test "github.com/xmtp/xmtpd/pkg/testing" + "google.golang.org/protobuf/proto" +) + +func newTestService(t *testing.T) (*Service, *sql.DB, func()) { + ctx := context.Background() + log := test.NewLog(t) + db, _, dbCleanup := test.NewDB(t, ctx) + + svc, err := NewReplicationApiService(ctx, log, db) + require.NoError(t, err) + + return svc, db, func() { + svc.Close() + dbCleanup() + } +} + +func TestSimplePublish(t *testing.T) { + svc, _, cleanup := newTestService(t) + defer cleanup() + + resp, err := svc.PublishEnvelope( + context.Background(), + &message_api.PublishEnvelopeRequest{ + PayerEnvelope: &message_api.PayerEnvelope{ + UnsignedClientEnvelope: []byte{0x5}, + PayerSignature: &associations.RecoverableEcdsaSignature{}, + }, + }, + ) + require.NoError(t, err) + require.NotNil(t, resp) + + unsignedEnv := &message_api.UnsignedOriginatorEnvelope{} + require.NoError( + t, + proto.Unmarshal(resp.GetOriginatorEnvelope().GetUnsignedOriginatorEnvelope(), unsignedEnv), + ) + require.Equal(t, uint8(0x5), unsignedEnv.GetPayerEnvelope().GetUnsignedClientEnvelope()[0]) + + // TODO(rich) Test that the published envelope is retrievable via the query API +} diff --git a/pkg/db/pgx.go b/pkg/db/pgx.go index dbfbd384..dc78b711 100644 --- a/pkg/db/pgx.go +++ b/pkg/db/pgx.go @@ -8,15 +8,22 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/stdlib" + "github.com/xmtp/xmtpd/pkg/migrations" ) -func NewDB(ctx context.Context, dsn string, waitForDB, statementTimeout time.Duration) (*sql.DB, error) { +func newPGXDB( + ctx context.Context, + dsn string, + waitForDB, statementTimeout time.Duration, +) (*sql.DB, error) { config, err := pgxpool.ParseConfig(dsn) if err != nil { return nil, err } - config.ConnConfig.RuntimeParams["statement_timeout"] = fmt.Sprint(statementTimeout.Milliseconds()) + config.ConnConfig.RuntimeParams["statement_timeout"] = fmt.Sprint( + statementTimeout.Milliseconds(), + ) dbPool, err := pgxpool.NewWithConfig(ctx, config) if err != nil { @@ -35,3 +42,19 @@ func NewDB(ctx context.Context, dsn string, waitForDB, statementTimeout time.Dur return db, err } + +func NewDB( + ctx context.Context, + dsn string, + waitForDB, statementTimeout time.Duration, +) (*sql.DB, error) { + db, err := newPGXDB(ctx, dsn, waitForDB, statementTimeout) + if err != nil { + return nil, err + } + err = migrations.Migrate(ctx, db) + if err != nil { + return nil, err + } + return db, nil +} diff --git a/pkg/migrations/migrations.go b/pkg/migrations/migrations.go index 40ae60e9..38d361b2 100644 --- a/pkg/migrations/migrations.go +++ b/pkg/migrations/migrations.go @@ -1,6 +1,7 @@ package migrations import ( + "context" "database/sql" "embed" @@ -12,23 +13,35 @@ import ( //go:embed *.sql var migrationFs embed.FS -func Migrate(db *sql.DB) error { +func Migrate(ctx context.Context, db *sql.DB) error { migrationFs, err := iofs.New(migrationFs, ".") if err != nil { return err } - driver, err := postgres.WithInstance(db, &postgres.Config{}) + defer migrationFs.Close() + + conn, err := db.Conn(ctx) if err != nil { return err } + defer conn.Close() + + driver, err := postgres.WithConnection(ctx, conn, &postgres.Config{}) + if err != nil { + return err + } + defer driver.Close() + migrator, err := migrate.NewWithInstance("iofs", migrationFs, "postgres", driver) if err != nil { return err } + defer migrator.Close() err = migrator.Up() if err != nil && err != migrate.ErrNoChange { return err } + return nil } diff --git a/pkg/server/server.go b/pkg/server/server.go index 6c8bf4a7..2c169066 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -12,7 +12,6 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/xmtp/xmtpd/pkg/api" "github.com/xmtp/xmtpd/pkg/db" - "github.com/xmtp/xmtpd/pkg/migrations" "github.com/xmtp/xmtpd/pkg/registry" "go.uber.org/zap" ) @@ -25,7 +24,7 @@ type ReplicationServer struct { apiServer *api.ApiServer nodeRegistry registry.NodeRegistry privateKey *ecdsa.PrivateKey - writerDb *sql.DB + writerDB *sql.DB // Can add reader DB later if needed } @@ -40,13 +39,13 @@ func NewReplicationServer(ctx context.Context, log *zap.Logger, options Options, if err != nil { return nil, err } - s.writerDb, err = db.NewDB(ctx, options.DB.WriterConnectionString, options.DB.WaitForDB, options.DB.ReadTimeout) + s.writerDB, err = db.NewDB(ctx, options.DB.WriterConnectionString, options.DB.WaitForDB, options.DB.ReadTimeout) if err != nil { return nil, err } s.ctx, s.cancel = context.WithCancel(ctx) - s.apiServer, err = api.NewAPIServer(ctx, log, options.API.Port) + s.apiServer, err = api.NewAPIServer(ctx, s.writerDB, log, options.API.Port) if err != nil { return nil, err } @@ -54,10 +53,6 @@ func NewReplicationServer(ctx context.Context, log *zap.Logger, options Options, return s, nil } -func (s *ReplicationServer) Migrate() error { - return migrations.Migrate(s.writerDb) -} - func (s *ReplicationServer) Addr() net.Addr { return s.apiServer.Addr() } diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 22e89079..9bf33b04 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -43,9 +43,3 @@ func TestCreateServer(t *testing.T) { server2 := NewTestServer(t, registry) require.NotEqual(t, server1.Addr(), server2.Addr()) } - -func TestMigrate(t *testing.T) { - registry := registry.NewFixedNodeRegistry([]registry.Node{}) - server := NewTestServer(t, registry) - require.NoError(t, server.Migrate()) -} diff --git a/pkg/testing/store.go b/pkg/testing/store.go new file mode 100644 index 00000000..1ff8dd35 --- /dev/null +++ b/pkg/testing/store.go @@ -0,0 +1,46 @@ +package testing + +import ( + "context" + "database/sql" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" + "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/migrations" +) + +const ( + localTestDBDSNPrefix = "postgres://postgres:xmtp@localhost:8765" + localTestDBDSNSuffix = "?sslmode=disable" +) + +func newPGXDB(t testing.TB) (*sql.DB, string, func()) { + dsn := localTestDBDSNPrefix + localTestDBDSNSuffix + config, err := pgx.ParseConfig(dsn) + require.NoError(t, err) + ctlDB := stdlib.OpenDB(*config) + dbName := "test_" + RandomStringLower(12) + _, err = ctlDB.Exec("CREATE DATABASE " + dbName) + require.NoError(t, err) + + dsn = localTestDBDSNPrefix + "/" + dbName + localTestDBDSNSuffix + config2, err := pgx.ParseConfig(dsn) + require.NoError(t, err) + db := stdlib.OpenDB(*config2) + return db, dsn, func() { + err := db.Close() + require.NoError(t, err) + _, err = ctlDB.Exec("DROP DATABASE " + dbName) + require.NoError(t, err) + ctlDB.Close() + } +} + +func NewDB(t *testing.T, ctx context.Context) (*sql.DB, string, func()) { + db, dsn, cleanup := newPGXDB(t) + require.NoError(t, migrations.Migrate(ctx, db)) + + return db, dsn, cleanup +} diff --git a/pkg/utils/envelope.go b/pkg/utils/envelope.go new file mode 100644 index 00000000..398bffcb --- /dev/null +++ b/pkg/utils/envelope.go @@ -0,0 +1,31 @@ +package utils + +import ( + "github.com/xmtp/xmtpd/pkg/db/queries" + "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" + "google.golang.org/protobuf/proto" +) + +func SignStagedEnvelope( + stagedEnv queries.StagedOriginatorEnvelope, +) (*message_api.OriginatorEnvelope, error) { + payerEnv := &message_api.PayerEnvelope{} + if err := proto.Unmarshal(stagedEnv.PayerEnvelope, payerEnv); err != nil { + return nil, err + } + unsignedEnv := message_api.UnsignedOriginatorEnvelope{ + OriginatorSid: SID(stagedEnv.ID), + OriginatorNs: stagedEnv.OriginatorTime.UnixNano(), + PayerEnvelope: payerEnv, + } + unsignedBytes, err := proto.Marshal(&unsignedEnv) + if err != nil { + return nil, err + } + // TODO(rich): Plumb through public key and properly sign envelope + signedEnv := message_api.OriginatorEnvelope{ + UnsignedOriginatorEnvelope: unsignedBytes, + Proof: nil, + } + return &signedEnv, nil +} diff --git a/pkg/utils/sid.go b/pkg/utils/sid.go new file mode 100644 index 00000000..d489ce84 --- /dev/null +++ b/pkg/utils/sid.go @@ -0,0 +1,18 @@ +package utils + +const ( + nodeIDMask uint64 = 0xFFFF << 48 + localIDMask uint64 = ^nodeIDMask +) + +// Converts a local serial ID from the database into a global SID with a node ID prefix +func SID(localID int64) uint64 { + nodeMask := uint64(localID) & nodeIDMask + if localID < 0 || nodeMask != 0 { + // Either indicates ID exhaustion or developer error - + // the service should not continue running either way + panic("Invalid local ID") + } + // TODO(rich): Plumb through and set node ID + return uint64(localID) +}