From afa68196ac4be089ab6ae32dba6cfa09c270f9f3 Mon Sep 17 00:00:00 2001 From: mj52951 <116341045+mj52951@users.noreply.github.com> Date: Thu, 4 Jan 2024 11:04:50 +0100 Subject: [PATCH] chore: Converting message and proposal fields to generics (#19) Co-authored-by: mj52951 --- chains/evm/chain.go | 8 ++++---- chains/substrate/chain.go | 8 ++++---- mock/message.go | 4 ++-- mock/relayer.go | 6 +++--- relayer/message/handler.go | 4 ++-- relayer/message/handler_test.go | 10 +++++----- relayer/message/message.go | 12 ++++++------ relayer/proposal/proposal.go | 8 ++++---- relayer/relayer.go | 10 +++++----- relayer/relayer_test.go | 20 ++++++++++---------- 10 files changed, 45 insertions(+), 45 deletions(-) diff --git a/chains/evm/chain.go b/chains/evm/chain.go index 3ff9e143..e0405b89 100644 --- a/chains/evm/chain.go +++ b/chains/evm/chain.go @@ -18,11 +18,11 @@ type EventListener interface { } type ProposalExecutor interface { - Execute(props []*proposal.Proposal) error + Execute(props []*proposal.Proposal[any]) error } type MessageHandler interface { - HandleMessage(m *message.Message) (*proposal.Proposal, error) + HandleMessage(m *message.Message[any]) (*proposal.Proposal[any], error) } // EVMChain is struct that aggregates all data required for @@ -55,11 +55,11 @@ func (c *EVMChain) PollEvents(ctx context.Context) { go c.listener.ListenToEvents(ctx, c.startBlock) } -func (c *EVMChain) ReceiveMessage(m *message.Message) (*proposal.Proposal, error) { +func (c *EVMChain) ReceiveMessage(m *message.Message[any]) (*proposal.Proposal[any], error) { return c.messageHandler.HandleMessage(m) } -func (c *EVMChain) Write(props []*proposal.Proposal) error { +func (c *EVMChain) Write(props []*proposal.Proposal[any]) error { err := c.executor.Execute(props) if err != nil { c.logger.Err(err).Msgf("error writing proposals %+v on network %d", props, c.DomainID()) diff --git a/chains/substrate/chain.go b/chains/substrate/chain.go index e4819cc5..6e9cfc4a 100644 --- a/chains/substrate/chain.go +++ b/chains/substrate/chain.go @@ -11,11 +11,11 @@ import ( ) type ProposalExecutor interface { - Execute(props []*proposal.Proposal) error + Execute(props []*proposal.Proposal[any]) error } type MessageHandler interface { - HandleMessage(m *message.Message) (*proposal.Proposal, error) + HandleMessage(m *message.Message[any]) (*proposal.Proposal[any], error) } type EventListener interface { @@ -49,11 +49,11 @@ func (c *SubstrateChain) PollEvents(ctx context.Context) { go c.listener.ListenToEvents(ctx, c.startBlock) } -func (c *SubstrateChain) ReceiveMessage(m *message.Message) (*proposal.Proposal, error) { +func (c *SubstrateChain) ReceiveMessage(m *message.Message[any]) (*proposal.Proposal[any], error) { return c.messageHandler.HandleMessage(m) } -func (c *SubstrateChain) Write(props []*proposal.Proposal) error { +func (c *SubstrateChain) Write(props []*proposal.Proposal[any]) error { err := c.executor.Execute(props) if err != nil { c.logger.Err(err).Msgf("error writing proposals %+v on network %d", props, c.DomainID()) diff --git a/mock/message.go b/mock/message.go index 57ff5684..9e488005 100644 --- a/mock/message.go +++ b/mock/message.go @@ -40,10 +40,10 @@ func (m *MockHandler) EXPECT() *MockHandlerMockRecorder { } // HandleMessage mocks base method. -func (m_2 *MockHandler) HandleMessage(m *message.Message) (*proposal.Proposal, error) { +func (m_2 *MockHandler) HandleMessage(m *message.Message[any]) (*proposal.Proposal[any], error) { m_2.ctrl.T.Helper() ret := m_2.ctrl.Call(m_2, "HandleMessage", m) - ret0, _ := ret[0].(*proposal.Proposal) + ret0, _ := ret[0].(*proposal.Proposal[any]) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/mock/relayer.go b/mock/relayer.go index d222a53a..9624ee18 100644 --- a/mock/relayer.go +++ b/mock/relayer.go @@ -67,10 +67,10 @@ func (mr *MockRelayedChainMockRecorder) PollEvents(ctx any) *gomock.Call { } // ReceiveMessage mocks base method. -func (m_2 *MockRelayedChain) ReceiveMessage(m *message.Message) (*proposal.Proposal, error) { +func (m_2 *MockRelayedChain) ReceiveMessage(m *message.Message[any]) (*proposal.Proposal[any], error) { m_2.ctrl.T.Helper() ret := m_2.ctrl.Call(m_2, "ReceiveMessage", m) - ret0, _ := ret[0].(*proposal.Proposal) + ret0, _ := ret[0].(*proposal.Proposal[any]) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -82,7 +82,7 @@ func (mr *MockRelayedChainMockRecorder) ReceiveMessage(m any) *gomock.Call { } // Write mocks base method. -func (m *MockRelayedChain) Write(proposals []*proposal.Proposal) error { +func (m *MockRelayedChain) Write(proposals []*proposal.Proposal[any]) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Write", proposals) ret0, _ := ret[0].(error) diff --git a/relayer/message/handler.go b/relayer/message/handler.go index 5e8aa41c..bb829d5d 100644 --- a/relayer/message/handler.go +++ b/relayer/message/handler.go @@ -7,7 +7,7 @@ import ( ) type Handler interface { - HandleMessage(m *Message) (*proposal.Proposal, error) + HandleMessage(m *Message[any]) (*proposal.Proposal[any], error) } type MessageHandler struct { @@ -21,7 +21,7 @@ func NewMessageHandler() *MessageHandler { } // HandlerMessage calls associated handler for that message type and returns a proposal to be submitted on-chain -func (h *MessageHandler) HandleMessage(m *Message) (*proposal.Proposal, error) { +func (h *MessageHandler) HandleMessage(m *Message[any]) (*proposal.Proposal[any], error) { mh, ok := h.handlers[m.Type] if !ok { return nil, fmt.Errorf("no handler found for type %s", m.Type) diff --git a/relayer/message/handler_test.go b/relayer/message/handler_test.go index 16f59724..0e6cfe63 100644 --- a/relayer/message/handler_test.go +++ b/relayer/message/handler_test.go @@ -29,7 +29,7 @@ func (s *MessageHandlerTestSuite) SetupTest() { func (s *MessageHandlerTestSuite) TestHandleMessageWithoutRegisteredHandler() { mh := message.NewMessageHandler() - _, err := mh.HandleMessage(&message.Message{Type: "invalid"}) + _, err := mh.HandleMessage(&message.Message[any]{Type: "invalid"}) s.NotNil(err) } @@ -38,7 +38,7 @@ func (s *MessageHandlerTestSuite) TestHandleMessageWithInvalidType() { mh := message.NewMessageHandler() mh.RegisterMessageHandler("invalid", s.mockHandler) - _, err := mh.HandleMessage(&message.Message{Type: "valid"}) + _, err := mh.HandleMessage(&message.Message[any]{Type: "valid"}) s.NotNil(err) } @@ -49,13 +49,13 @@ func (s *MessageHandlerTestSuite) TestHandleMessageHandlerReturnsError() { mh := message.NewMessageHandler() mh.RegisterMessageHandler("valid", s.mockHandler) - _, err := mh.HandleMessage(&message.Message{Type: "valid"}) + _, err := mh.HandleMessage(&message.Message[any]{Type: "valid"}) s.NotNil(err) } func (s *MessageHandlerTestSuite) TestHandleMessageWithValidType() { - expectedProp := &proposal.Proposal{ + expectedProp := &proposal.Proposal[any]{ Type: "prop", } s.mockHandler.EXPECT().HandleMessage(gomock.Any()).Return(expectedProp, nil) @@ -63,7 +63,7 @@ func (s *MessageHandlerTestSuite) TestHandleMessageWithValidType() { mh := message.NewMessageHandler() mh.RegisterMessageHandler("valid", s.mockHandler) - msg := message.NewMessage(1, 2, nil, "valid") + msg := message.NewMessage[any](1, 2, nil, "valid") prop, err := mh.HandleMessage(msg) s.Nil(err) diff --git a/relayer/message/message.go b/relayer/message/message.go index 7be00fc0..bf3963b7 100644 --- a/relayer/message/message.go +++ b/relayer/message/message.go @@ -1,20 +1,20 @@ package message type MessageType string -type Message struct { +type Message[T any] struct { Source uint8 // Source where message was initiated Destination uint8 // Destination chain of message - Data interface{} // Data associated with the message + Data T // Data associated with the message Type MessageType // Message type } -func NewMessage( +func NewMessage[T any]( source uint8, destination uint8, - data interface{}, + data T, msgType MessageType, -) *Message { - return &Message{ +) *Message[T] { + return &Message[T]{ source, destination, data, diff --git a/relayer/proposal/proposal.go b/relayer/proposal/proposal.go index f20cefd7..9d0813f9 100644 --- a/relayer/proposal/proposal.go +++ b/relayer/proposal/proposal.go @@ -1,15 +1,15 @@ package proposal type ProposalType string -type Proposal struct { +type Proposal[T any] struct { Source uint8 Destination uint8 - Data interface{} + Data T Type ProposalType } -func NewProposal(source, destination uint8, data []byte, propType ProposalType) *Proposal { - return &Proposal{ +func NewProposal[T any](source, destination uint8, data T, propType ProposalType) *Proposal[T] { + return &Proposal[T]{ Source: source, Destination: destination, Data: data, diff --git a/relayer/relayer.go b/relayer/relayer.go index 6f6a10f6..e6cfbc6a 100644 --- a/relayer/relayer.go +++ b/relayer/relayer.go @@ -16,10 +16,10 @@ type RelayedChain interface { PollEvents(ctx context.Context) // ReceiveMessage accepts the message from the source chain and converts it into // a Proposal to be submitted on-chain - ReceiveMessage(m *message.Message) (*proposal.Proposal, error) + ReceiveMessage(m *message.Message[any]) (*proposal.Proposal[any], error) // Write submits proposals on-chain. // If multiple proposals submitted they are expected to be able to be batched. - Write(proposals []*proposal.Proposal) error + Write(proposals []*proposal.Proposal[any]) error DomainID() uint8 } @@ -34,7 +34,7 @@ type Relayer struct { // Start function starts polling events for each chain and listens to cross-chain messages. // If an array of messages is sent to the channel they are expected to be to the same destination and // able to be handled in batches. -func (r *Relayer) Start(ctx context.Context, msgChan chan []*message.Message) { +func (r *Relayer) Start(ctx context.Context, msgChan chan []*message.Message[any]) { log.Info().Msgf("Starting relayer") for _, c := range r.relayedChains { @@ -54,14 +54,14 @@ func (r *Relayer) Start(ctx context.Context, msgChan chan []*message.Message) { } // Route function routes the messages to the destination chain. -func (r *Relayer) route(msgs []*message.Message) { +func (r *Relayer) route(msgs []*message.Message[any]) { destChain, ok := r.relayedChains[msgs[0].Destination] if !ok { log.Error().Uint8("domainID", destChain.DomainID()).Msgf("No chain registered for destination domain") return } - props := make([]*proposal.Proposal, 0) + props := make([]*proposal.Proposal[any], 0) for _, m := range msgs { prop, err := destChain.ReceiveMessage(m) if err != nil { diff --git a/relayer/relayer_test.go b/relayer/relayer_test.go index ec8ab3ea..4d2f4392 100644 --- a/relayer/relayer_test.go +++ b/relayer/relayer_test.go @@ -46,8 +46,8 @@ func (s *RouteTestSuite) TestStartListensOnChannel() { chains, ) - msgChan := make(chan []*message.Message, 1) - msgChan <- []*message.Message{ + msgChan := make(chan []*message.Message[any], 1) + msgChan <- []*message.Message[any]{ {Destination: 1}, } relayer.Start(ctx, msgChan) @@ -63,7 +63,7 @@ func (s *RouteTestSuite) TestReceiveMessageFails() { chains, ) - relayer.route([]*message.Message{ + relayer.route([]*message.Message[any]{ {Destination: 1}, }) } @@ -76,14 +76,14 @@ func (s *RouteTestSuite) TestAvoidWriteWithoutProposals() { chains, ) - relayer.route([]*message.Message{ + relayer.route([]*message.Message[any]{ {Destination: 1}, }) } func (s *RouteTestSuite) TestWriteFails() { - props := make([]*proposal.Proposal, 1) - prop := &proposal.Proposal{} + props := make([]*proposal.Proposal[any], 1) + prop := &proposal.Proposal[any]{} props[0] = prop s.mockRelayedChain.EXPECT().ReceiveMessage(gomock.Any()).Return(prop, nil) s.mockRelayedChain.EXPECT().Write(props).Return(fmt.Errorf("error")) @@ -94,14 +94,14 @@ func (s *RouteTestSuite) TestWriteFails() { chains, ) - relayer.route([]*message.Message{ + relayer.route([]*message.Message[any]{ {Destination: 1}, }) } func (s *RouteTestSuite) TestWritesToChain() { - props := make([]*proposal.Proposal, 1) - prop := &proposal.Proposal{} + props := make([]*proposal.Proposal[any], 1) + prop := &proposal.Proposal[any]{} props[0] = prop s.mockRelayedChain.EXPECT().ReceiveMessage(gomock.Any()).Return(prop, nil) s.mockRelayedChain.EXPECT().Write(props).Return(nil) @@ -111,7 +111,7 @@ func (s *RouteTestSuite) TestWritesToChain() { chains, ) - relayer.route([]*message.Message{ + relayer.route([]*message.Message[any]{ {Destination: 1}, }) }