Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add publish worker #115

Merged
merged 1 commit into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions pkg/api/publishWorker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package api

import (
"context"
"database/sql"
"time"

"github.com/xmtp/xmtpd/pkg/db"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/registrant"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
)

// PublishWorker consumes batches of staged originator envelopes delivered by
// a database subscription and publishes each one as a gateway envelope.
type PublishWorker struct {
	// ctx stops the processing loop when it is done and is passed to the
	// gateway-envelope insert.
	ctx context.Context
	// log is the base logger; per-envelope loggers are derived from it.
	log *zap.Logger

	// subscription is a copy of the subscription created in
	// StartPublishWorker; listener is the channel returned by its Start call.
	subscription db.DBSubscription[queries.StagedOriginatorEnvelope]
	listener     <-chan []queries.StagedOriginatorEnvelope
	// notifier is the send side of the channel handed to the subscription's
	// polling options.
	notifier chan<- bool

	// registrant signs staged envelopes and supplies the originator node ID.
	registrant *registrant.Registrant
	// store is the database handle used for the insert and delete queries.
	store *sql.DB
}

// StartPublishWorker creates a polling subscription over staged originator
// envelopes, starts it, and launches a PublishWorker goroutine that consumes
// the resulting batches.
//
// The subscription begins at lastSeenID 0, polls every second, fetches up to
// 100 rows per query, and is given a 1-slot notifier channel. An error is
// returned only if starting the subscription fails; in that case no worker
// goroutine is launched.
func StartPublishWorker(
	ctx context.Context,
	log *zap.Logger,
	reg *registrant.Registrant,
	store *sql.DB,
) (*PublishWorker, error) {
	querier := queries.New(store)

	// fetchStaged selects the next page of staged envelopes after lastSeenID
	// and reports the ID of the last row returned, or the unchanged
	// lastSeenID when the page is empty.
	fetchStaged := func(
		ctx context.Context,
		lastSeenID int64,
		numRows int32,
	) ([]queries.StagedOriginatorEnvelope, int64, error) {
		params := queries.SelectStagedOriginatorEnvelopesParams{
			LastSeenID: lastSeenID,
			NumRows:    numRows,
		}
		rows, err := querier.SelectStagedOriginatorEnvelopes(ctx, params)
		if err != nil {
			return nil, 0, err
		}
		if n := len(rows); n > 0 {
			lastSeenID = rows[n-1].ID
		}
		return rows, lastSeenID, nil
	}

	wake := make(chan bool, 1)
	pollOpts := db.PollingOptions{Interval: time.Second, Notifier: wake, NumRows: 100}
	sub := db.NewDBSubscription(
		ctx,
		log,
		fetchStaged,
		0, // lastSeenID
		pollOpts,
	)

	batches, err := sub.Start()
	if err != nil {
		return nil, err
	}

	w := &PublishWorker{
		ctx:          ctx,
		log:          log,
		registrant:   reg,
		store:        store,
		subscription: *sub,
		listener:     batches,
		notifier:     wake,
	}
	go w.start()

	return w, nil
}

// NotifyStagedPublish performs a non-blocking send on the notifier channel.
// If the send cannot proceed immediately the signal is dropped, so callers
// never block here.
func (p *PublishWorker) NotifyStagedPublish() {
select {
case p.notifier <- true:
default:
// The send would block; skip it rather than wait.
}
}

// start is the worker's processing loop. It receives batches of staged
// envelopes from the listener channel and publishes them one at a time, in
// order. An envelope that fails to publish is retried once per second; the
// worker does not move on to the next envelope until the current one
// succeeds.
//
// The loop returns when the worker's context is done, including while it is
// waiting between retries, or when the listener channel has been closed.
func (p *PublishWorker) start() {
	for {
		select {
		case <-p.ctx.Done():
			return
		case batch, ok := <-p.listener:
			if !ok {
				// A closed channel would yield nil batches forever; stop
				// instead of spinning.
				return
			}
			for _, stagedEnv := range batch {
				for !p.publishStagedEnvelope(stagedEnv) {
					// Retry until published: ordering requires that this
					// envelope is processed before the next one. Waiting in a
					// select (rather than time.Sleep) lets cancellation end
					// the retry loop; otherwise a cancelled context would make
					// every attempt fail and the goroutine would never exit.
					select {
					case <-p.ctx.Done():
						return
					case <-time.After(time.Second):
					}
				}
			}
		}
	}
}

// publishStagedEnvelope signs stagedEnv via the registrant, marshals the
// signed envelope, inserts it as a gateway envelope, and then attempts to
// delete the staged row.
//
// It returns false when signing, marshaling, or the insert returns an error;
// the caller in start retries on false. It returns true otherwise, including
// when the insert affected zero rows and when the delete fails.
func (p *PublishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginatorEnvelope) bool {
logger := p.log.With(zap.Int64("sequenceID", stagedEnv.ID))
originatorEnv, err := p.registrant.SignStagedEnvelope(stagedEnv)
if err != nil {
logger.Error(
"Failed to sign staged envelope",
zap.Error(err),
)
return false
}
originatorBytes, err := proto.Marshal(originatorEnv)
if err != nil {
logger.Error("Failed to marshal originator envelope", zap.Error(err))
return false
}

q := queries.New(p.store)
richardhuaaa marked this conversation as resolved.
Show resolved Hide resolved

// On unique constraint conflicts, no error is thrown, but numRows is 0
inserted, err := q.InsertGatewayEnvelope(
p.ctx,
queries.InsertGatewayEnvelopeParams{
OriginatorID: int32(p.registrant.NodeID()),
OriginatorSequenceID: stagedEnv.ID,
Topic: stagedEnv.Topic,
OriginatorEnvelope: originatorBytes,
},
)
if err != nil {
logger.Error("Failed to insert gateway envelope", zap.Error(err))
return false
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a risk here that we infinitely retry on this step or one of the steps above in an unanticipated error case. I think that we should be okay here but a second pair of eyes can't hurt

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know why the proto marshaling would fail, but if it did we would be in an endless loop of retries here.

Would probably have to be a maliciously crafted payload

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my research, it seems like it fails if some expected fields of the proto are not set. In this case, the proto is OriginatorEnvelope with two fields that we are setting here, so I expect that we should be able to prevent this error happening.

I still think we should catch the error just in case though, firstly so that we can detect it, secondly so that we don't screw up payload ordering if there is an unanticipated error type that is super common.

} else if inserted == 0 {
// Envelope was already inserted by another worker
logger.Debug("Envelope already inserted")
}

// Try to delete the row regardless of if the gateway envelope was inserted elsewhere
// NOTE(review): this call uses context.Background() rather than p.ctx,
// presumably so the cleanup is still attempted after cancellation; confirm.
deleted, err := q.DeleteStagedOriginatorEnvelope(context.Background(), stagedEnv.ID)
if err != nil {
logger.Error("Failed to delete staged envelope", zap.Error(err))
// Envelope is already inserted, so it is safe to continue
return true
} else if deleted == 0 {
// Envelope was already deleted by another worker
logger.Debug("Envelope already deleted")
}

return true
}
84 changes: 70 additions & 14 deletions pkg/api/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,27 @@ type Service struct {
ctx context.Context
log *zap.Logger
registrant *registrant.Registrant
queries *queries.Queries
store *sql.DB
worker *PublishWorker
}

func NewReplicationApiService(
ctx context.Context,
log *zap.Logger,
registrant *registrant.Registrant,
writerDB *sql.DB,
store *sql.DB,
) (*Service, error) {
return &Service{ctx: ctx, log: log, registrant: registrant, queries: queries.New(writerDB)}, nil
worker, err := StartPublishWorker(ctx, log, registrant, store)
if err != nil {
return nil, err
}
return &Service{
ctx: ctx,
log: log,
registrant: registrant,
store: store,
worker: worker,
}, nil
}

func (s *Service) Close() {
Expand All @@ -54,27 +65,32 @@ func (s *Service) PublishEnvelope(
ctx context.Context,
req *message_api.PublishEnvelopeRequest,
) (*message_api.PublishEnvelopeResponse, error) {
payerEnv := req.GetPayerEnvelope()
clientBytes := payerEnv.GetUnsignedClientEnvelope()
sig := payerEnv.GetPayerSignature()
if (clientBytes == nil) || (sig == nil) {
return nil, status.Errorf(codes.InvalidArgument, "missing envelope or signature")
clientEnv, err := s.validatePayerInfo(req.GetPayerEnvelope())
if err != nil {
return nil, err
}
// TODO(rich): Verify payer signature
// TODO(rich): Verify all originators have synced past `last_originator_sids`
// TODO(rich): Check that the blockchain sequence ID is equal to the latest on the group
// TODO(rich): Perform any payload-specific validation (e.g. identity updates)

topic, err := s.validateClientInfo(clientEnv)
if err != nil {
return nil, err
}

// TODO(rich): If it is a commit, publish it to blockchain instead

payerBytes, err := proto.Marshal(payerEnv)
payerBytes, err := proto.Marshal(req.GetPayerEnvelope())
if err != nil {
return nil, status.Errorf(codes.Internal, "could not marshal envelope: %v", err)
}

stagedEnv, err := s.queries.InsertStagedOriginatorEnvelope(ctx, payerBytes)
stagedEnv, err := queries.New(s.store).
InsertStagedOriginatorEnvelope(ctx, queries.InsertStagedOriginatorEnvelopeParams{
Topic: topic,
PayerEnvelope: payerBytes,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "could not insert staged envelope: %v", err)
}
s.worker.NotifyStagedPublish()

originatorEnv, err := s.registrant.SignStagedEnvelope(stagedEnv)
if err != nil {
Expand All @@ -83,3 +99,43 @@ func (s *Service) PublishEnvelope(

return &message_api.PublishEnvelopeResponse{OriginatorEnvelope: originatorEnv}, nil
}

// validatePayerInfo checks that the payer envelope carries both client
// envelope bytes and a payer signature, then decodes the client envelope.
//
// It returns an InvalidArgument status error when either field is nil or when
// the client envelope bytes cannot be unmarshaled.
func (s *Service) validatePayerInfo(
	payerEnv *message_api.PayerEnvelope,
) (*message_api.ClientEnvelope, error) {
	rawClientEnv := payerEnv.GetUnsignedClientEnvelope()
	if rawClientEnv == nil || payerEnv.GetPayerSignature() == nil {
		return nil, status.Error(codes.InvalidArgument, "missing envelope or signature")
	}
	// TODO(rich): Verify payer signature

	var decoded message_api.ClientEnvelope
	if err := proto.Unmarshal(rawClientEnv, &decoded); err != nil {
		return nil, status.Errorf(
			codes.InvalidArgument,
			"could not unmarshal client envelope: %v",
			err,
		)
	}

	return &decoded, nil
}

// validateClientInfo checks the client envelope's authenticated data and
// returns its target topic.
//
// It returns an InvalidArgument status error when the target originator does
// not equal this registrant's node ID, or when the target topic is empty.
func (s *Service) validateClientInfo(clientEnv *message_api.ClientEnvelope) ([]byte, error) {
	aad := clientEnv.GetAad()
	topic := aad.GetTargetTopic()

	switch {
	case aad.GetTargetOriginator() != uint32(s.registrant.NodeID()):
		return nil, status.Error(codes.InvalidArgument, "invalid target originator")
	case len(topic) == 0:
		return nil, status.Error(codes.InvalidArgument, "missing target topic")
	}

	// TODO(rich): Verify all originators have synced past `last_originator_sids`
	// TODO(rich): Check that the blockchain sequence ID is equal to the latest on the group
	// TODO(rich): Perform any payload-specific validation (e.g. identity updates)

	return topic, nil
}
102 changes: 95 additions & 7 deletions pkg/api/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"testing"
"time"

"github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -40,17 +41,41 @@ func newTestService(t *testing.T) (*Service, *sql.DB, func()) {
}
}

// createClientEnvelope builds a client envelope fixture with no payload,
// target originator 1, target topic {0x5}, and an empty list of last
// originator sequence IDs.
func createClientEnvelope() *message_api.ClientEnvelope {
	aad := &message_api.AuthenticatedData{
		TargetOriginator:   1,
		TargetTopic:        []byte{0x5},
		LastOriginatorSids: []uint64{},
	}
	// Payload is left at its zero value (nil).
	return &message_api.ClientEnvelope{Aad: aad}
}

// createPayerEnvelope builds a payer envelope fixture wrapping the marshaled
// form of the first client envelope passed in, or of createClientEnvelope()
// when none is given. The payer signature is an empty, non-nil value. The
// test fails immediately if marshaling returns an error.
func createPayerEnvelope(
	t *testing.T,
	clientEnv ...*message_api.ClientEnvelope,
) *message_api.PayerEnvelope {
	var inner *message_api.ClientEnvelope
	if len(clientEnv) > 0 {
		inner = clientEnv[0]
	} else {
		inner = createClientEnvelope()
	}

	raw, err := proto.Marshal(inner)
	require.NoError(t, err)

	payerEnv := &message_api.PayerEnvelope{
		UnsignedClientEnvelope: raw,
		PayerSignature:         &associations.RecoverableEcdsaSignature{},
	}
	return payerEnv
}

func TestSimplePublish(t *testing.T) {
svc, _, cleanup := newTestService(t)
svc, db, cleanup := newTestService(t)
defer cleanup()

resp, err := svc.PublishEnvelope(
context.Background(),
&message_api.PublishEnvelopeRequest{
PayerEnvelope: &message_api.PayerEnvelope{
UnsignedClientEnvelope: []byte{0x5},
PayerSignature: &associations.RecoverableEcdsaSignature{},
},
PayerEnvelope: createPayerEnvelope(t),
},
)
require.NoError(t, err)
Expand All @@ -61,7 +86,70 @@ func TestSimplePublish(t *testing.T) {
t,
proto.Unmarshal(resp.GetOriginatorEnvelope().GetUnsignedOriginatorEnvelope(), unsignedEnv),
)
require.Equal(t, uint8(0x5), unsignedEnv.GetPayerEnvelope().GetUnsignedClientEnvelope()[0])
clientEnv := &message_api.ClientEnvelope{}
require.NoError(
t,
proto.Unmarshal(unsignedEnv.GetPayerEnvelope().GetUnsignedClientEnvelope(), clientEnv),
)
require.Equal(t, uint8(0x5), clientEnv.Aad.GetTargetTopic()[0])

// TODO(rich) Test that the published envelope is retrievable via the query API
// Check that the envelope was published to the database after a delay
require.Eventually(t, func() bool {
envs, err := queries.New(db).
SelectGatewayEnvelopes(context.Background(), queries.SelectGatewayEnvelopesParams{})
require.NoError(t, err)

if len(envs) != 1 {
return false
}

originatorEnv := &message_api.OriginatorEnvelope{}
require.NoError(t, proto.Unmarshal(envs[0].OriginatorEnvelope, originatorEnv))
return proto.Equal(originatorEnv, resp.GetOriginatorEnvelope())
}, 500*time.Millisecond, 50*time.Millisecond)
}

// TestUnmarshalError publishes a payer envelope whose client envelope bytes
// have been replaced with "invalidbytes" and expects an error mentioning
// "unmarshal".
func TestUnmarshalError(t *testing.T) {
	svc, _, cleanup := newTestService(t)
	defer cleanup()

	payerEnv := createPayerEnvelope(t)
	payerEnv.UnsignedClientEnvelope = []byte("invalidbytes")

	req := &message_api.PublishEnvelopeRequest{PayerEnvelope: payerEnv}
	_, err := svc.PublishEnvelope(context.Background(), req)

	require.ErrorContains(t, err, "unmarshal")
}

// TestMismatchingOriginator publishes a client envelope whose target
// originator is set to 2 and expects an error mentioning "originator".
func TestMismatchingOriginator(t *testing.T) {
	svc, _, cleanup := newTestService(t)
	defer cleanup()

	wrongTarget := createClientEnvelope()
	wrongTarget.Aad.TargetOriginator = 2

	req := &message_api.PublishEnvelopeRequest{
		PayerEnvelope: createPayerEnvelope(t, wrongTarget),
	}
	_, err := svc.PublishEnvelope(context.Background(), req)

	require.ErrorContains(t, err, "originator")
}

// TestMissingTopic publishes a client envelope whose target topic has been
// cleared and expects an error mentioning "topic".
func TestMissingTopic(t *testing.T) {
	svc, _, cleanup := newTestService(t)
	defer cleanup()

	noTopic := createClientEnvelope()
	noTopic.Aad.TargetTopic = nil

	req := &message_api.PublishEnvelopeRequest{
		PayerEnvelope: createPayerEnvelope(t, noTopic),
	}
	_, err := svc.PublishEnvelope(context.Background(), req)

	require.ErrorContains(t, err, "topic")
}
Loading
Loading