diff --git a/relayer/message/message.go b/relayer/message/message.go index 1859c247..361d17bf 100644 --- a/relayer/message/message.go +++ b/relayer/message/message.go @@ -7,14 +7,16 @@ type Message struct { Data interface{} // Data associated with the message ID string // ID is used to track and identify message across networks Type MessageType // Message type + ErrChn chan error // ErrChn is used to share errors that happen on the destination handler } -func NewMessage(source, destination uint8, data interface{}, id string, msgType MessageType) *Message { +func NewMessage(source, destination uint8, data interface{}, id string, msgType MessageType, errChn chan error) *Message { return &Message{ Source: source, Destination: destination, Data: data, Type: msgType, ID: id, + ErrChn: errChn, } } diff --git a/relayer/relayer.go b/relayer/relayer.go index 82583f46..082b81e4 100644 --- a/relayer/relayer.go +++ b/relayer/relayer.go @@ -5,10 +5,12 @@ package relayer import ( "context" + "fmt" "github.com/rs/zerolog/log" "github.com/sygmaprotocol/sygma-core/relayer/message" "github.com/sygmaprotocol/sygma-core/relayer/proposal" + "github.com/sygmaprotocol/sygma-core/utils" ) type RelayedChain interface { @@ -55,9 +57,11 @@ func (r *Relayer) Start(ctx context.Context, msgChan chan []*message.Message) { // Route function routes the messages to the destination chain. func (r *Relayer) route(msgs []*message.Message) { + errChn := msgs[0].ErrChn destChain, ok := r.relayedChains[msgs[0].Destination] if !ok { log.Error().Uint8("domainID", msgs[0].Destination).Msgf("No chain registered for destination domain") + utils.TrySendError(errChn, fmt.Errorf("no chain registered")) return } @@ -69,6 +73,7 @@ func (r *Relayer) route(msgs []*message.Message) { prop, err := destChain.ReceiveMessage(m) if err != nil { log.Err(err).Msgf("Failed receiving message %+v", m) + utils.TrySendError(errChn, err) continue } @@ -79,13 +84,16 @@ func (r *Relayer) route(msgs []*message.Message) { } } if len(props) == 0 { + utils.TrySendError(errChn, nil) return } log.Debug().Msgf("Writing message") err := destChain.Write(props) if err != nil { + utils.TrySendError(errChn, err) log.Err(err).Msgf("Failed writing message") return } + utils.TrySendError(errChn, nil) } diff --git a/relayer/relayer_test.go b/relayer/relayer_test.go index 9d911a42..64bac857 100644 --- a/relayer/relayer_test.go +++ b/relayer/relayer_test.go @@ -95,9 +95,16 @@ func (s *RouteTestSuite) TestWriteFails() { chains, ) + errChn := make(chan error, 1) relayer.route([]*message.Message{ - {Destination: 1}, + { + Destination: 1, + ErrChn: errChn, + }, }) + + err := <-errChn + s.NotNil(err) } func (s *RouteTestSuite) TestWritesToChain() { @@ -113,8 +120,35 @@ func (s *RouteTestSuite) TestWritesToChain() { chains, ) + errChn := make(chan error, 1) relayer.route([]*message.Message{ - {Destination: 1}, + { + Destination: 1, + ErrChn: errChn, + }, + }) + + err := <-errChn + s.Nil(err) +} + +func (s *RouteTestSuite) TestWritesToChain_BlockingErrChn() { + props := make([]*proposal.Proposal, 1) + prop := &proposal.Proposal{} + props[0] = prop + s.mockRelayedChain.EXPECT().ReceiveMessage(gomock.Any()).Return(prop, nil) + s.mockRelayedChain.EXPECT().Write(props).Return(nil) + s.mockRelayedChain.EXPECT().DomainID().Return(uint8(1)).Times(1) + chains := make(map[uint8]RelayedChain) + chains[1] = s.mockRelayedChain + relayer := NewRelayer( + chains, + ) + + relayer.route([]*message.Message{ + { + Destination: 1, + }, }) } diff --git a/utils/channel_test.go b/utils/channel_test.go new file mode 100644 index 00000000..54faa0ca --- /dev/null +++ b/utils/channel_test.go @@ -0,0 +1,30 @@ +package utils_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/sygmaprotocol/sygma-core/utils" +) + +type ChannelTestSuite struct { + suite.Suite +} + +func TestRunChannelTestSuite(t *testing.T) { + suite.Run(t, new(ChannelTestSuite)) +} + +func (s *ChannelTestSuite) Test_TrySendError_NonBlocking() { + utils.TrySendError(nil, fmt.Errorf("error")) + + errChn := make(chan error) + utils.TrySendError(errChn, fmt.Errorf("error")) + + bufErrChn := make(chan error, 1) + utils.TrySendError(bufErrChn, fmt.Errorf("error")) + + err := <-bufErrChn + s.NotNil(err) +}