Skip to content

Commit

Permalink
chore: Converting message and proposal fields to generics (#19)
Browse files Browse the repository at this point in the history
Co-authored-by: mj52951 <[email protected]>
  • Loading branch information
mj52951 and mj52951 authored Jan 4, 2024
1 parent 62219e0 commit afa6819
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 45 deletions.
8 changes: 4 additions & 4 deletions chains/evm/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ type EventListener interface {
}

type ProposalExecutor interface {
Execute(props []*proposal.Proposal) error
Execute(props []*proposal.Proposal[any]) error
}

type MessageHandler interface {
HandleMessage(m *message.Message) (*proposal.Proposal, error)
HandleMessage(m *message.Message[any]) (*proposal.Proposal[any], error)
}

// EVMChain is struct that aggregates all data required for
Expand Down Expand Up @@ -55,11 +55,11 @@ func (c *EVMChain) PollEvents(ctx context.Context) {
go c.listener.ListenToEvents(ctx, c.startBlock)
}

func (c *EVMChain) ReceiveMessage(m *message.Message) (*proposal.Proposal, error) {
func (c *EVMChain) ReceiveMessage(m *message.Message[any]) (*proposal.Proposal[any], error) {
return c.messageHandler.HandleMessage(m)
}

func (c *EVMChain) Write(props []*proposal.Proposal) error {
func (c *EVMChain) Write(props []*proposal.Proposal[any]) error {
err := c.executor.Execute(props)
if err != nil {
c.logger.Err(err).Msgf("error writing proposals %+v on network %d", props, c.DomainID())
Expand Down
8 changes: 4 additions & 4 deletions chains/substrate/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import (
)

type ProposalExecutor interface {
Execute(props []*proposal.Proposal) error
Execute(props []*proposal.Proposal[any]) error
}

type MessageHandler interface {
HandleMessage(m *message.Message) (*proposal.Proposal, error)
HandleMessage(m *message.Message[any]) (*proposal.Proposal[any], error)
}

type EventListener interface {
Expand Down Expand Up @@ -49,11 +49,11 @@ func (c *SubstrateChain) PollEvents(ctx context.Context) {
go c.listener.ListenToEvents(ctx, c.startBlock)
}

func (c *SubstrateChain) ReceiveMessage(m *message.Message) (*proposal.Proposal, error) {
func (c *SubstrateChain) ReceiveMessage(m *message.Message[any]) (*proposal.Proposal[any], error) {
return c.messageHandler.HandleMessage(m)
}

func (c *SubstrateChain) Write(props []*proposal.Proposal) error {
func (c *SubstrateChain) Write(props []*proposal.Proposal[any]) error {
err := c.executor.Execute(props)
if err != nil {
c.logger.Err(err).Msgf("error writing proposals %+v on network %d", props, c.DomainID())
Expand Down
4 changes: 2 additions & 2 deletions mock/message.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions mock/relayer.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions relayer/message/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

type Handler interface {
HandleMessage(m *Message) (*proposal.Proposal, error)
HandleMessage(m *Message[any]) (*proposal.Proposal[any], error)
}

type MessageHandler struct {
Expand All @@ -21,7 +21,7 @@ func NewMessageHandler() *MessageHandler {
}

// HandlerMessage calls associated handler for that message type and returns a proposal to be submitted on-chain
func (h *MessageHandler) HandleMessage(m *Message) (*proposal.Proposal, error) {
func (h *MessageHandler) HandleMessage(m *Message[any]) (*proposal.Proposal[any], error) {
mh, ok := h.handlers[m.Type]
if !ok {
return nil, fmt.Errorf("no handler found for type %s", m.Type)
Expand Down
10 changes: 5 additions & 5 deletions relayer/message/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (s *MessageHandlerTestSuite) SetupTest() {
func (s *MessageHandlerTestSuite) TestHandleMessageWithoutRegisteredHandler() {
mh := message.NewMessageHandler()

_, err := mh.HandleMessage(&message.Message{Type: "invalid"})
_, err := mh.HandleMessage(&message.Message[any]{Type: "invalid"})

s.NotNil(err)
}
Expand All @@ -38,7 +38,7 @@ func (s *MessageHandlerTestSuite) TestHandleMessageWithInvalidType() {
mh := message.NewMessageHandler()
mh.RegisterMessageHandler("invalid", s.mockHandler)

_, err := mh.HandleMessage(&message.Message{Type: "valid"})
_, err := mh.HandleMessage(&message.Message[any]{Type: "valid"})

s.NotNil(err)
}
Expand All @@ -49,21 +49,21 @@ func (s *MessageHandlerTestSuite) TestHandleMessageHandlerReturnsError() {
mh := message.NewMessageHandler()
mh.RegisterMessageHandler("valid", s.mockHandler)

_, err := mh.HandleMessage(&message.Message{Type: "valid"})
_, err := mh.HandleMessage(&message.Message[any]{Type: "valid"})

s.NotNil(err)
}

func (s *MessageHandlerTestSuite) TestHandleMessageWithValidType() {
expectedProp := &proposal.Proposal{
expectedProp := &proposal.Proposal[any]{
Type: "prop",
}
s.mockHandler.EXPECT().HandleMessage(gomock.Any()).Return(expectedProp, nil)

mh := message.NewMessageHandler()
mh.RegisterMessageHandler("valid", s.mockHandler)

msg := message.NewMessage(1, 2, nil, "valid")
msg := message.NewMessage[any](1, 2, nil, "valid")
prop, err := mh.HandleMessage(msg)

s.Nil(err)
Expand Down
12 changes: 6 additions & 6 deletions relayer/message/message.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
package message

type MessageType string
type Message struct {
type Message[T any] struct {
Source uint8 // Source where message was initiated
Destination uint8 // Destination chain of message
Data interface{} // Data associated with the message
Data T // Data associated with the message
Type MessageType // Message type
}

func NewMessage(
func NewMessage[T any](
source uint8,
destination uint8,
data interface{},
data T,
msgType MessageType,
) *Message {
return &Message{
) *Message[T] {
return &Message[T]{
source,
destination,
data,
Expand Down
8 changes: 4 additions & 4 deletions relayer/proposal/proposal.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package proposal

type ProposalType string
type Proposal struct {
type Proposal[T any] struct {
Source uint8
Destination uint8
Data interface{}
Data T
Type ProposalType
}

func NewProposal(source, destination uint8, data []byte, propType ProposalType) *Proposal {
return &Proposal{
func NewProposal[T any](source, destination uint8, data T, propType ProposalType) *Proposal[T] {
return &Proposal[T]{
Source: source,
Destination: destination,
Data: data,
Expand Down
10 changes: 5 additions & 5 deletions relayer/relayer.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ type RelayedChain interface {
PollEvents(ctx context.Context)
// ReceiveMessage accepts the message from the source chain and converts it into
// a Proposal to be submitted on-chain
ReceiveMessage(m *message.Message) (*proposal.Proposal, error)
ReceiveMessage(m *message.Message[any]) (*proposal.Proposal[any], error)
// Write submits proposals on-chain.
// If multiple proposals submitted they are expected to be able to be batched.
Write(proposals []*proposal.Proposal) error
Write(proposals []*proposal.Proposal[any]) error
DomainID() uint8
}

Expand All @@ -34,7 +34,7 @@ type Relayer struct {
// Start function starts polling events for each chain and listens to cross-chain messages.
// If an array of messages is sent to the channel they are expected to be to the same destination and
// able to be handled in batches.
func (r *Relayer) Start(ctx context.Context, msgChan chan []*message.Message) {
func (r *Relayer) Start(ctx context.Context, msgChan chan []*message.Message[any]) {
log.Info().Msgf("Starting relayer")

for _, c := range r.relayedChains {
Expand All @@ -54,14 +54,14 @@ func (r *Relayer) Start(ctx context.Context, msgChan chan []*message.Message) {
}

// Route function routes the messages to the destination chain.
func (r *Relayer) route(msgs []*message.Message) {
func (r *Relayer) route(msgs []*message.Message[any]) {
destChain, ok := r.relayedChains[msgs[0].Destination]
if !ok {
log.Error().Uint8("domainID", destChain.DomainID()).Msgf("No chain registered for destination domain")
return
}

props := make([]*proposal.Proposal, 0)
props := make([]*proposal.Proposal[any], 0)
for _, m := range msgs {
prop, err := destChain.ReceiveMessage(m)
if err != nil {
Expand Down
20 changes: 10 additions & 10 deletions relayer/relayer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ func (s *RouteTestSuite) TestStartListensOnChannel() {
chains,
)

msgChan := make(chan []*message.Message, 1)
msgChan <- []*message.Message{
msgChan := make(chan []*message.Message[any], 1)
msgChan <- []*message.Message[any]{
{Destination: 1},
}
relayer.Start(ctx, msgChan)
Expand All @@ -63,7 +63,7 @@ func (s *RouteTestSuite) TestReceiveMessageFails() {
chains,
)

relayer.route([]*message.Message{
relayer.route([]*message.Message[any]{
{Destination: 1},
})
}
Expand All @@ -76,14 +76,14 @@ func (s *RouteTestSuite) TestAvoidWriteWithoutProposals() {
chains,
)

relayer.route([]*message.Message{
relayer.route([]*message.Message[any]{
{Destination: 1},
})
}

func (s *RouteTestSuite) TestWriteFails() {
props := make([]*proposal.Proposal, 1)
prop := &proposal.Proposal{}
props := make([]*proposal.Proposal[any], 1)
prop := &proposal.Proposal[any]{}
props[0] = prop
s.mockRelayedChain.EXPECT().ReceiveMessage(gomock.Any()).Return(prop, nil)
s.mockRelayedChain.EXPECT().Write(props).Return(fmt.Errorf("error"))
Expand All @@ -94,14 +94,14 @@ func (s *RouteTestSuite) TestWriteFails() {
chains,
)

relayer.route([]*message.Message{
relayer.route([]*message.Message[any]{
{Destination: 1},
})
}

func (s *RouteTestSuite) TestWritesToChain() {
props := make([]*proposal.Proposal, 1)
prop := &proposal.Proposal{}
props := make([]*proposal.Proposal[any], 1)
prop := &proposal.Proposal[any]{}
props[0] = prop
s.mockRelayedChain.EXPECT().ReceiveMessage(gomock.Any()).Return(prop, nil)
s.mockRelayedChain.EXPECT().Write(props).Return(nil)
Expand All @@ -111,7 +111,7 @@ func (s *RouteTestSuite) TestWritesToChain() {
chains,
)

relayer.route([]*message.Message{
relayer.route([]*message.Message[any]{
{Destination: 1},
})
}

0 comments on commit afa6819

Please sign in to comment.