Skip to content

Commit

Permalink
Initial publish implementation (#101)
Browse files Browse the repository at this point in the history
Initial implementation of the publish endpoint and unit test setup. Simply writes to the staged envelopes table and does nothing else.

Also fixed some issues around DB and connection lifetimes - the migration step now closes the connection it uses (but not the whole DB instance).
  • Loading branch information
richardhuaaa authored Aug 9, 2024
2 parents 2ac15f1 + 21cfd44 commit c5f1b14
Show file tree
Hide file tree
Showing 15 changed files with 277 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .vscode/extensions.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// Extension identifier format: ${publisher}.${name}. Example: vscode.csharp

// List of extensions which should be recommended for users of this workspace.
"recommendations": ["bradymholt.pgformatter", "golang.go"],
"recommendations": ["bradymholt.pgformatter", "golang.go", "emeraldwalk.runonsave"],
// List of extensions recommended by VS Code that should not be recommended for users of this workspace.
"unwantedRecommendations": []
}
9 changes: 9 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,14 @@
"pgFormatter.tabs": true,
"[sql]": {
"editor.defaultFormatter": "bradymholt.pgformatter"
},
// Instructions from https://github.com/segmentio/golines
"emeraldwalk.runonsave": {
"commands": [
{
"match": "\\.go$",
"cmd": "golines ${file} -w"
}
]
}
}
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ The `xmtpd` node build provides two options for monitoring your node.

To learn how to visualize node data in Grafana, see [Prometheus Histograms with Grafana Heatmaps](https://towardsdatascience.com/prometheus-histograms-with-grafana-heatmaps-d556c28612c7) and [How to visualize Prometheus histograms in Grafana](https://grafana.com/blog/2020/06/23/how-to-visualize-prometheus-histograms-in-grafana/).

# Contributing

Please follow the [style guide](https://google.github.io/styleguide/go/decisions).

## Modifying the protobuf schema

Submit and land a PR to https://github.com/xmtp/proto. Then run:
Expand Down
2 changes: 1 addition & 1 deletion cmd/replication/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func addEnvVars() {
}

if connStr, hasConnstr := os.LookupEnv("READER_DB_CONNECTION_STRING"); hasConnstr {
options.DB.WriterConnectionString = connStr
options.DB.ReaderConnectionString = connStr
}

if privKey, hasPrivKey := os.LookupEnv("PRIVATE_KEY"); hasPrivKey {
Expand Down
1 change: 1 addition & 0 deletions dev/up
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ if ! which golangci-lint &>/dev/null; then brew install golangci-lint; fi
if ! which shellcheck &>/dev/null; then brew install shellcheck; fi
if ! which mockery &>/dev/null; then brew install mockery; fi
if ! which sqlc &> /dev/null; then brew install sqlc; fi
if ! which golines &>/dev/null; then go install github.com/segmentio/golines@latest; fi

dev/generate
dev/docker/up
19 changes: 11 additions & 8 deletions pkg/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"context"
"database/sql"
"fmt"
"net"
"strings"
Expand All @@ -20,24 +21,26 @@ import (
)

type ApiServer struct {
ctx context.Context
db *sql.DB
grpcListener net.Listener
log *zap.Logger
service message_api.ReplicationApiServer
wg sync.WaitGroup
grpcListener net.Listener
ctx context.Context
service *message_api.ReplicationApiServer
}

func NewAPIServer(ctx context.Context, log *zap.Logger, port int) (*ApiServer, error) {
func NewAPIServer(ctx context.Context, writerDB *sql.DB, log *zap.Logger, port int) (*ApiServer, error) {
grpcListener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", port))

if err != nil {
return nil, err
}
s := &ApiServer{
log: log.Named("api"),
ctx: ctx,
wg: sync.WaitGroup{},
db: writerDB,
grpcListener: &proxyproto.Listener{Listener: grpcListener, ReadHeaderTimeout: 10 * time.Second},
log: log.Named("api"),
wg: sync.WaitGroup{},
}

// TODO: Add interceptors
Expand All @@ -58,11 +61,11 @@ func NewAPIServer(ctx context.Context, log *zap.Logger, port int) (*ApiServer, e
healthcheck := health.NewServer()
healthgrpc.RegisterHealthServer(grpcServer, healthcheck)

replicationService, err := NewReplicationApiService(ctx, log)
replicationService, err := NewReplicationApiService(ctx, log, writerDB)
if err != nil {
return nil, err
}
s.service = &replicationService
s.service = replicationService

tracing.GoPanicWrap(s.ctx, &s.wg, "grpc", func(ctx context.Context) {
s.log.Info("serving grpc", zap.String("address", s.grpcListener.Addr().String()))
Expand Down
65 changes: 59 additions & 6 deletions pkg/api/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,82 @@ package api

import (
"context"
"database/sql"

"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/utils"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"

"go.uber.org/zap"
)

type Service struct {
message_api.UnimplementedReplicationApiServer

ctx context.Context
log *zap.Logger
ctx context.Context
log *zap.Logger
queries *queries.Queries
}

func NewReplicationApiService(ctx context.Context, log *zap.Logger) (message_api.ReplicationApiServer, error) {
return &Service{ctx: ctx, log: log}, nil
func NewReplicationApiService(
ctx context.Context,
log *zap.Logger,
writerDB *sql.DB,
) (*Service, error) {
return &Service{ctx: ctx, log: log, queries: queries.New(writerDB)}, nil
}

func (s *Service) BatchSubscribeEnvelopes(req *message_api.BatchSubscribeEnvelopesRequest, server message_api.ReplicationApi_BatchSubscribeEnvelopesServer) error {
func (s *Service) Close() {
s.log.Info("closed")
}

func (s *Service) BatchSubscribeEnvelopes(
req *message_api.BatchSubscribeEnvelopesRequest,
server message_api.ReplicationApi_BatchSubscribeEnvelopesServer,
) error {
return status.Errorf(codes.Unimplemented, "method BatchSubscribeEnvelopes not implemented")
}

func (s *Service) QueryEnvelopes(ctx context.Context, req *message_api.QueryEnvelopesRequest) (*message_api.QueryEnvelopesResponse, error) {
func (s *Service) QueryEnvelopes(
ctx context.Context,
req *message_api.QueryEnvelopesRequest,
) (*message_api.QueryEnvelopesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method QueryEnvelopes not implemented")
}

func (s *Service) PublishEnvelope(
ctx context.Context,
req *message_api.PublishEnvelopeRequest,
) (*message_api.PublishEnvelopeResponse, error) {
payerEnv := req.GetPayerEnvelope()
clientBytes := payerEnv.GetUnsignedClientEnvelope()
sig := payerEnv.GetPayerSignature()
if (clientBytes == nil) || (sig == nil) {
return nil, status.Errorf(codes.InvalidArgument, "missing envelope or signature")
}
// TODO(rich): Verify payer signature
// TODO(rich): Verify all originators have synced past `last_originator_sids`
// TODO(rich): Check that the blockchain sequence ID is equal to the latest on the group
// TODO(rich): Perform any payload-specific validation (e.g. identity updates)
// TODO(rich): If it is a commit, publish it to blockchain instead

payerBytes, err := proto.Marshal(payerEnv)
if err != nil {
return nil, status.Errorf(codes.Internal, "could not marshal envelope: %v", err)
}

stagedEnv, err := s.queries.InsertStagedOriginatorEnvelope(ctx, payerBytes)
if err != nil {
return nil, status.Errorf(codes.Internal, "could not insert staged envelope: %v", err)
}

originatorEnv, err := utils.SignStagedEnvelope(stagedEnv)
if err != nil {
return nil, status.Errorf(codes.Internal, "could not sign envelope: %v", err)
}

return &message_api.PublishEnvelopeResponse{OriginatorEnvelope: originatorEnv}, nil
}
53 changes: 53 additions & 0 deletions pkg/api/service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package api

import (
"context"
"database/sql"
"testing"

"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/proto/identity/associations"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
test "github.com/xmtp/xmtpd/pkg/testing"
"google.golang.org/protobuf/proto"
)

func newTestService(t *testing.T) (*Service, *sql.DB, func()) {
ctx := context.Background()
log := test.NewLog(t)
db, _, dbCleanup := test.NewDB(t, ctx)

svc, err := NewReplicationApiService(ctx, log, db)
require.NoError(t, err)

return svc, db, func() {
svc.Close()
dbCleanup()
}
}

func TestSimplePublish(t *testing.T) {
svc, _, cleanup := newTestService(t)
defer cleanup()

resp, err := svc.PublishEnvelope(
context.Background(),
&message_api.PublishEnvelopeRequest{
PayerEnvelope: &message_api.PayerEnvelope{
UnsignedClientEnvelope: []byte{0x5},
PayerSignature: &associations.RecoverableEcdsaSignature{},
},
},
)
require.NoError(t, err)
require.NotNil(t, resp)

unsignedEnv := &message_api.UnsignedOriginatorEnvelope{}
require.NoError(
t,
proto.Unmarshal(resp.GetOriginatorEnvelope().GetUnsignedOriginatorEnvelope(), unsignedEnv),
)
require.Equal(t, uint8(0x5), unsignedEnv.GetPayerEnvelope().GetUnsignedClientEnvelope()[0])

// TODO(rich) Test that the published envelope is retrievable via the query API
}
27 changes: 25 additions & 2 deletions pkg/db/pgx.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,22 @@ import (

"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
"github.com/xmtp/xmtpd/pkg/migrations"
)

func NewDB(ctx context.Context, dsn string, waitForDB, statementTimeout time.Duration) (*sql.DB, error) {
func newPGXDB(
ctx context.Context,
dsn string,
waitForDB, statementTimeout time.Duration,
) (*sql.DB, error) {
config, err := pgxpool.ParseConfig(dsn)
if err != nil {
return nil, err
}

config.ConnConfig.RuntimeParams["statement_timeout"] = fmt.Sprint(statementTimeout.Milliseconds())
config.ConnConfig.RuntimeParams["statement_timeout"] = fmt.Sprint(
statementTimeout.Milliseconds(),
)

dbPool, err := pgxpool.NewWithConfig(ctx, config)
if err != nil {
Expand All @@ -35,3 +42,19 @@ func NewDB(ctx context.Context, dsn string, waitForDB, statementTimeout time.Dur

return db, err
}

func NewDB(
ctx context.Context,
dsn string,
waitForDB, statementTimeout time.Duration,
) (*sql.DB, error) {
db, err := newPGXDB(ctx, dsn, waitForDB, statementTimeout)
if err != nil {
return nil, err
}
err = migrations.Migrate(ctx, db)
if err != nil {
return nil, err
}
return db, nil
}
17 changes: 15 additions & 2 deletions pkg/migrations/migrations.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package migrations

import (
"context"
"database/sql"
"embed"

Expand All @@ -12,23 +13,35 @@ import (
//go:embed *.sql
var migrationFs embed.FS

func Migrate(db *sql.DB) error {
func Migrate(ctx context.Context, db *sql.DB) error {
migrationFs, err := iofs.New(migrationFs, ".")
if err != nil {
return err
}
driver, err := postgres.WithInstance(db, &postgres.Config{})
defer migrationFs.Close()

conn, err := db.Conn(ctx)
if err != nil {
return err
}
defer conn.Close()

driver, err := postgres.WithConnection(ctx, conn, &postgres.Config{})
if err != nil {
return err
}
defer driver.Close()

migrator, err := migrate.NewWithInstance("iofs", migrationFs, "postgres", driver)
if err != nil {
return err
}
defer migrator.Close()

err = migrator.Up()
if err != nil && err != migrate.ErrNoChange {
return err
}

return nil
}
11 changes: 3 additions & 8 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/xmtp/xmtpd/pkg/api"
"github.com/xmtp/xmtpd/pkg/db"
"github.com/xmtp/xmtpd/pkg/migrations"
"github.com/xmtp/xmtpd/pkg/registry"
"go.uber.org/zap"
)
Expand All @@ -25,7 +24,7 @@ type ReplicationServer struct {
apiServer *api.ApiServer
nodeRegistry registry.NodeRegistry
privateKey *ecdsa.PrivateKey
writerDb *sql.DB
writerDB *sql.DB
// Can add reader DB later if needed
}

Expand All @@ -40,24 +39,20 @@ func NewReplicationServer(ctx context.Context, log *zap.Logger, options Options,
if err != nil {
return nil, err
}
s.writerDb, err = db.NewDB(ctx, options.DB.WriterConnectionString, options.DB.WaitForDB, options.DB.ReadTimeout)
s.writerDB, err = db.NewDB(ctx, options.DB.WriterConnectionString, options.DB.WaitForDB, options.DB.ReadTimeout)
if err != nil {
return nil, err
}

s.ctx, s.cancel = context.WithCancel(ctx)
s.apiServer, err = api.NewAPIServer(ctx, log, options.API.Port)
s.apiServer, err = api.NewAPIServer(ctx, s.writerDB, log, options.API.Port)
if err != nil {
return nil, err
}
log.Info("Replication server started", zap.Int("port", options.API.Port))
return s, nil
}

func (s *ReplicationServer) Migrate() error {
return migrations.Migrate(s.writerDb)
}

func (s *ReplicationServer) Addr() net.Addr {
return s.apiServer.Addr()
}
Expand Down
6 changes: 0 additions & 6 deletions pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,3 @@ func TestCreateServer(t *testing.T) {
server2 := NewTestServer(t, registry)
require.NotEqual(t, server1.Addr(), server2.Addr())
}

func TestMigrate(t *testing.T) {
registry := registry.NewFixedNodeRegistry([]registry.Node{})
server := NewTestServer(t, registry)
require.NoError(t, server.Migrate())
}
Loading

0 comments on commit c5f1b14

Please sign in to comment.