diff --git a/app/encoding.go b/app/encoding.go index ebcbb753..f35ce7fb 100644 --- a/app/encoding.go +++ b/app/encoding.go @@ -8,12 +8,6 @@ import ( "github.com/milkyway-labs/milkyway/app/params" ) -var encodingConfig params.EncodingConfig = MakeEncodingConfig() - -func GetEncodingConfig() params.EncodingConfig { - return encodingConfig -} - // MakeEncodingConfig creates an EncodingConfig. func MakeEncodingConfig() params.EncodingConfig { encodingConfig := params.MakeEncodingConfig() diff --git a/testutils/codec.go b/testutils/codec.go new file mode 100644 index 00000000..8283d733 --- /dev/null +++ b/testutils/codec.go @@ -0,0 +1,12 @@ +package testutils + +import ( + "github.com/cosmos/cosmos-sdk/codec" + codectestutil "github.com/cosmos/cosmos-sdk/codec/testutil" +) + +// MakeCodecs constructs the *codec.Codec and *codec.LegacyAmino instances that can be used inside tests +func MakeCodecs() (codec.Codec, *codec.LegacyAmino) { + interfaceRegistry := codectestutil.CodecOptions{AccAddressPrefix: "cosmos", ValAddressPrefix: "cosmosvaloper"}.NewInterfaceRegistry() + return codec.NewProtoCodec(interfaceRegistry), codec.NewLegacyAmino() +} diff --git a/x/rewards/client/cli/tx.go b/x/rewards/client/cli/tx.go index a95cfa1c..85b0460e 100644 --- a/x/rewards/client/cli/tx.go +++ b/x/rewards/client/cli/tx.go @@ -86,6 +86,7 @@ Where rewards_plan.json contains: if err != nil { return fmt.Errorf("parsing rewards plan json: %w", err) } + err = rewardsPlan.Validate(clientCtx.Codec) if err != nil { return fmt.Errorf("invalid rewards plan json: %w", err) @@ -104,6 +105,12 @@ Where rewards_plan.json contains: creator, ) + // Validate the message + err = msg.ValidateBasic() + if err != nil { + return fmt.Errorf("invalid message: %w", err) + } + return tx.GenerateOrBroadcastTxCLI(clientCtx, cmd.Flags(), msg) }, } @@ -197,6 +204,12 @@ Where rewards_plan.json contains: sender, ) + // Validate the message + err = msg.ValidateBasic() + if err != nil { + return err + } + return tx.GenerateOrBroadcastTxCLI(clientCtx, cmd.Flags(), msg) }, } diff --git a/x/rewards/client/cli/utils_test.go b/x/rewards/client/cli/utils_test.go index 1bc1e26d..cb6fc29b 100644 --- a/x/rewards/client/cli/utils_test.go +++ b/x/rewards/client/cli/utils_test.go @@ -10,14 +10,13 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/stretchr/testify/require" - milkyway "github.com/milkyway-labs/milkyway/app" + milkywayapp "github.com/milkyway-labs/milkyway/app" "github.com/milkyway-labs/milkyway/x/rewards/client/cli" "github.com/milkyway-labs/milkyway/x/rewards/types" ) -func TestCliUtils_parseRewardsPlan(t *testing.T) { - encodingConfig := milkyway.MakeEncodingConfig() - codec := encodingConfig.Marshaler +func TestCLIUtils_parseRewardsPlan(t *testing.T) { + cdc, _ := milkywayapp.MakeCodecs() testCases := []struct { name string @@ -191,7 +190,7 @@ func TestCliUtils_parseRewardsPlan(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { require.NotNil(t, tc.jsonFile) - plan, err := cli.ParseRewardsPlan(codec, tc.jsonFile.Name()) + plan, err := cli.ParseRewardsPlan(cdc, tc.jsonFile.Name()) if tc.shouldErr { require.Error(t, err) } else { diff --git a/x/rewards/types/codec.go b/x/rewards/types/codec.go index 506dd71c..d219c1bf 100644 --- a/x/rewards/types/codec.go +++ b/x/rewards/types/codec.go @@ -4,6 +4,7 @@ import ( "github.com/cosmos/cosmos-sdk/codec" "github.com/cosmos/cosmos-sdk/codec/legacy" "github.com/cosmos/cosmos-sdk/codec/types" + cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/msgservice" ) @@ -15,6 +16,11 @@ func RegisterLegacyAminoCodec(cdc *codec.LegacyAmino) { legacy.RegisterAminoMsg(cdc, &MsgWithdrawDelegatorReward{}, "milkyway/MsgWithdrawDelegatorReward") legacy.RegisterAminoMsg(cdc, &MsgWithdrawOperatorCommission{}, "milkyway/MsgWithdrawOperatorCommission") legacy.RegisterAminoMsg(cdc, &MsgUpdateParams{}, "milkyway/rewards/MsgUpdateParams") + + cdc.RegisterInterface((*DistributionType)(nil), nil) + cdc.RegisterConcrete(&DistributionTypeBasic{}, "milkyway/DistributionTypeBasic", nil) + cdc.RegisterConcrete(&DistributionTypeWeighted{}, "milkyway/DistributionTypeWeighted", nil) + cdc.RegisterConcrete(&DistributionTypeEgalitarian{}, "milkyway/DistributionTypeEgalitarian", nil) } func RegisterInterfaces(registry types.InterfaceRegistry) { @@ -41,3 +47,17 @@ func RegisterInterfaces(registry types.InterfaceRegistry) { ) msgservice.RegisterMsgServiceDesc(registry, &_Msg_serviceDesc) } + +// AminoCdc references the global x/rewards module codec. Note, the codec should +// ONLY be used in certain instances of tests and for JSON encoding as Amino is +// still used for that purpose. +// +// The actual codec used for serialization should be provided to x/rewards and +// defined at the application level. +var AminoCdc = codec.NewLegacyAmino() + +func init() { + RegisterLegacyAminoCodec(AminoCdc) + cryptocodec.RegisterCrypto(AminoCdc) + sdk.RegisterLegacyAminoCodec(AminoCdc) +} diff --git a/x/rewards/types/genesis_test.go b/x/rewards/types/genesis_test.go index f67f283a..74c0e4a0 100644 --- a/x/rewards/types/genesis_test.go +++ b/x/rewards/types/genesis_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/require" - milkyway "github.com/milkyway-labs/milkyway/app" + "github.com/milkyway-labs/milkyway/testutils" "github.com/milkyway-labs/milkyway/utils" "github.com/milkyway-labs/milkyway/x/rewards/types" ) @@ -69,7 +69,8 @@ func TestGenesisState_Validate(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - cdc, _ := milkyway.MakeCodecs() + cdc, _ := testutils.MakeCodecs() + err := tc.genesis.Validate(cdc) if tc.shouldErr { require.Error(t, err) diff --git a/x/rewards/types/messages.go b/x/rewards/types/messages.go index 75ad3296..68b63bdb 100644 --- a/x/rewards/types/messages.go +++ b/x/rewards/types/messages.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + codectypes "github.com/cosmos/cosmos-sdk/codec/types" sdk "github.com/cosmos/cosmos-sdk/types" restakingtypes "github.com/milkyway-labs/milkyway/x/restaking/types" @@ -34,36 +35,81 @@ func NewMsgCreateRewardsPlan( } } -// NewMsgSetWithdrawAddress creates a new NewMsgSetWithdrawAddress instance -func NewMsgSetWithdrawAddress(withdrawAddress string, userAddress string) *MsgSetWithdrawAddress { - return &MsgSetWithdrawAddress{ - Sender: userAddress, - WithdrawAddress: withdrawAddress, +// ValidateBasic implements sdk.Msg +func (m *MsgCreateRewardsPlan) ValidateBasic() error { + if len(m.Description) > MaxRewardsPlanDescriptionLength { + return fmt.Errorf("too long description") + } + + if m.ServiceID == 0 { + return fmt.Errorf("invalid service ID: %d", m.ServiceID) + } + + err := m.Amount.Validate() + if err != nil { + return fmt.Errorf("invalid amount per day: %w", err) + } + + if !m.EndTime.After(m.StartTime) { + return fmt.Errorf( + "end time must be after start time: %s <= %s", + m.EndTime.Format(time.RFC3339), + m.StartTime.Format(time.RFC3339), + ) + } + + if m.PoolsDistribution.DelegationType != restakingtypes.DELEGATION_TYPE_POOL { + return fmt.Errorf("pools distribution has invalid delegation type: %v", m.PoolsDistribution.DelegationType) + } + + if m.OperatorsDistribution.DelegationType != restakingtypes.DELEGATION_TYPE_OPERATOR { + return fmt.Errorf("operators distribution has invalid delegation type: %v", m.OperatorsDistribution.DelegationType) } + + _, err = sdk.AccAddressFromBech32(m.Sender) + if err != nil { + return fmt.Errorf("invalid sender address: %s", m.Sender) + } + + return nil } -// NewMsgWithdrawDelegatorReward creates a new MsgWithdrawDelegatorReward instance -func NewMsgWithdrawDelegatorReward( - delegationType restakingtypes.DelegationType, - targetID uint32, - delegatorAddress string, -) *MsgWithdrawDelegatorReward { - return &MsgWithdrawDelegatorReward{ - DelegatorAddress: delegatorAddress, - DelegationType: delegationType, - DelegationTargetID: targetID, +// GetSignBytes implements sdk.Msg +func (m *MsgCreateRewardsPlan) GetSignBytes() []byte { + return sdk.MustSortJSON(AminoCdc.MustMarshalJSON(m)) +} + +// GetSigners implements sdk.Msg +func (m *MsgCreateRewardsPlan) GetSigners() []sdk.AccAddress { + addr, err := sdk.AccAddressFromBech32(m.Sender) + if err != nil { + panic(err) } + + return []sdk.AccAddress{addr} } -// NewMsgWithdrawOperatorCommission creates a new MsgWithdrawOperatorCommission instance -func NewMsgWithdrawOperatorCommission(operatorID uint32, senderAddress string) *MsgWithdrawOperatorCommission { - return &MsgWithdrawOperatorCommission{ - Sender: senderAddress, - OperatorID: operatorID, +// UnpackInterfaces implements codectypes.UnpackInterfacesMessage +func (m *MsgCreateRewardsPlan) UnpackInterfaces(unpacker codectypes.AnyUnpacker) error { + err := m.PoolsDistribution.UnpackInterfaces(unpacker) + if err != nil { + return nil } + + err = m.OperatorsDistribution.UnpackInterfaces(unpacker) + if err != nil { + return nil + } + + err = m.UsersDistribution.UnpackInterfaces(unpacker) + if err != nil { + return nil + } + + return nil } -// ------------------------------------------------------------------------------- +// -------------------------------------------------------------------------------------------------------------------- // NewMsgEditRewardsPlan creates a new MsgEditRewardsPlan instance. func NewMsgEditRewardsPlan( @@ -96,7 +142,28 @@ func (m *MsgEditRewardsPlan) ValidateBasic() error { return fmt.Errorf("invalid ID: %d", m.ID) } - _, err := sdk.AccAddressFromBech32(m.Sender) + err := m.Amount.Validate() + if err != nil { + return fmt.Errorf("invalid amount: %w", err) + } + + if !m.EndTime.After(m.StartTime) { + return fmt.Errorf( + "end time must be after start time: %s <= %s", + m.EndTime.Format(time.RFC3339), + m.StartTime.Format(time.RFC3339), + ) + } + + if m.PoolsDistribution.DelegationType != restakingtypes.DELEGATION_TYPE_POOL { + return fmt.Errorf("pools distribution has invalid delegation type: %v", m.PoolsDistribution.DelegationType) + } + + if m.OperatorsDistribution.DelegationType != restakingtypes.DELEGATION_TYPE_OPERATOR { + return fmt.Errorf("operators distribution has invalid delegation type: %v", m.OperatorsDistribution.DelegationType) + } + + _, err = sdk.AccAddressFromBech32(m.Sender) if err != nil { return fmt.Errorf("invalid sender address: %s, %w", m.Sender, err) } @@ -106,3 +173,165 @@ func (m *MsgEditRewardsPlan) ValidateBasic() error { return nil } + +// GetSignBytes implements sdk.Msg +func (m *MsgEditRewardsPlan) GetSignBytes() []byte { + return sdk.MustSortJSON(AminoCdc.MustMarshalJSON(m)) +} + +// GetSigners implements sdk.Msg +func (m *MsgEditRewardsPlan) GetSigners() []sdk.AccAddress { + addr, err := sdk.AccAddressFromBech32(m.Sender) + if err != nil { + panic(err) + } + + return []sdk.AccAddress{addr} +} + +// UnpackInterfaces implements codectypes.UnpackInterfacesMessage +func (m *MsgEditRewardsPlan) UnpackInterfaces(unpacker codectypes.AnyUnpacker) error { + err := m.PoolsDistribution.UnpackInterfaces(unpacker) + if err != nil { + return nil + } + + err = m.OperatorsDistribution.UnpackInterfaces(unpacker) + if err != nil { + return nil + } + + err = m.UsersDistribution.UnpackInterfaces(unpacker) + if err != nil { + return nil + } + + return nil +} + +// -------------------------------------------------------------------------------------------------------------------- + +// NewMsgSetWithdrawAddress creates a new NewMsgSetWithdrawAddress instance +func NewMsgSetWithdrawAddress(withdrawAddress string, sender string) *MsgSetWithdrawAddress { + return &MsgSetWithdrawAddress{ + Sender: sender, + WithdrawAddress: withdrawAddress, + } +} + +// ValidateBasic implements sdk.Msg +func (m *MsgSetWithdrawAddress) ValidateBasic() error { + _, err := sdk.AccAddressFromBech32(m.Sender) + if err != nil { + return fmt.Errorf("invalid sender address: %s", m.Sender) + } + + _, err = sdk.AccAddressFromBech32(m.WithdrawAddress) + if err != nil { + return fmt.Errorf("invalid withdraw address: %s", m.WithdrawAddress) + } + + return nil +} + +// GetSignBytes implements sdk.Msg +func (m *MsgSetWithdrawAddress) GetSignBytes() []byte { + return sdk.MustSortJSON(AminoCdc.MustMarshalJSON(m)) +} + +// GetSigners implements sdk.Msg +func (m *MsgSetWithdrawAddress) GetSigners() []sdk.AccAddress { + addr, err := sdk.AccAddressFromBech32(m.Sender) + if err != nil { + panic(err) + } + + return []sdk.AccAddress{addr} +} + +// -------------------------------------------------------------------------------------------------------------------- + +// NewMsgWithdrawDelegatorReward creates a new MsgWithdrawDelegatorReward instance +func NewMsgWithdrawDelegatorReward( + delegationType restakingtypes.DelegationType, + targetID uint32, + delegatorAddress string, +) *MsgWithdrawDelegatorReward { + return &MsgWithdrawDelegatorReward{ + DelegatorAddress: delegatorAddress, + DelegationType: delegationType, + DelegationTargetID: targetID, + } +} + +// ValidateBasic implements sdk.Msg +func (m *MsgWithdrawDelegatorReward) ValidateBasic() error { + if m.DelegationType == restakingtypes.DELEGATION_TYPE_UNSPECIFIED { + return fmt.Errorf("invalid delegation type: %v", m.DelegationType) + } + + if m.DelegationTargetID == 0 { + return fmt.Errorf("invalid delegation target ID: %d", m.DelegationTargetID) + } + + _, err := sdk.AccAddressFromBech32(m.DelegatorAddress) + if err != nil { + return fmt.Errorf("invalid delegator address: %s", m.DelegatorAddress) + } + + return nil +} + +// GetSignBytes implements sdk.Msg +func (m *MsgWithdrawDelegatorReward) GetSignBytes() []byte { + return sdk.MustSortJSON(AminoCdc.MustMarshalJSON(m)) +} + +// GetSigners implements sdk.Msg +func (m *MsgWithdrawDelegatorReward) GetSigners() []sdk.AccAddress { + addr, err := sdk.AccAddressFromBech32(m.DelegatorAddress) + if err != nil { + panic(err) + } + + return []sdk.AccAddress{addr} +} + +// -------------------------------------------------------------------------------------------------------------------- + +// NewMsgWithdrawOperatorCommission creates a new MsgWithdrawOperatorCommission instance +func NewMsgWithdrawOperatorCommission(operatorID uint32, senderAddress string) *MsgWithdrawOperatorCommission { + return &MsgWithdrawOperatorCommission{ + Sender: senderAddress, + OperatorID: operatorID, + } +} + +// ValidateBasic implements sdk.Msg +func (m *MsgWithdrawOperatorCommission) ValidateBasic() error { + if m.OperatorID == 0 { + return fmt.Errorf("invalid operator ID: %d", m.OperatorID) + } + + _, err := sdk.AccAddressFromBech32(m.Sender) + if err != nil { + return fmt.Errorf("invalid sender address: %s", m.Sender) + } + + return nil +} + +// GetSignBytes implements sdk.Msg +func (m *MsgWithdrawOperatorCommission) GetSignBytes() []byte { + return sdk.MustSortJSON(AminoCdc.MustMarshalJSON(m)) +} + +// GetSigners implements sdk.Msg +func (m *MsgWithdrawOperatorCommission) GetSigners() []sdk.AccAddress { + addr, err := sdk.AccAddressFromBech32(m.Sender) + if err != nil { + panic(err) + } + + return []sdk.AccAddress{addr} +} diff --git a/x/rewards/types/messages_test.go b/x/rewards/types/messages_test.go new file mode 100644 index 00000000..18111a62 --- /dev/null +++ b/x/rewards/types/messages_test.go @@ -0,0 +1,420 @@ +package types_test + +import ( + "testing" + "time" + + sdkmath "cosmossdk.io/math" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + + restakingtypes "github.com/milkyway-labs/milkyway/x/restaking/types" + "github.com/milkyway-labs/milkyway/x/rewards/types" +) + +var msgCreateRewardsPlan = types.NewMsgCreateRewardsPlan( + 1, + "Test rewards plan", + sdk.NewCoins(sdk.NewCoin("stake", sdkmath.NewInt(1000))), + time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2024, 12, 31, 23, 59, 59, 0, time.UTC), + types.NewBasicPoolsDistribution(1), + types.NewBasicOperatorsDistribution(1), + types.NewBasicUsersDistribution(1), + "cosmos10d07y265gmmuvt4z0w9aw880jnsr700j6zn9kn", +) + +func TestMsgCreateRewardsPlan_ValidateBasic(t *testing.T) { + testCases := []struct { + name string + msg *types.MsgCreateRewardsPlan + shouldErr bool + }{ + { + name: "invalid service id returns error", + msg: types.NewMsgCreateRewardsPlan( + 0, + msgCreateRewardsPlan.Description, + msgCreateRewardsPlan.Amount, + msgCreateRewardsPlan.StartTime, + msgCreateRewardsPlan.EndTime, + msgCreateRewardsPlan.PoolsDistribution, + msgCreateRewardsPlan.OperatorsDistribution, + msgCreateRewardsPlan.UsersDistribution, + msgCreateRewardsPlan.Sender, + ), + shouldErr: true, + }, + { + name: "invalid amount", + msg: types.NewMsgCreateRewardsPlan( + msgCreateRewardsPlan.ServiceID, + msgCreateRewardsPlan.Description, + sdk.Coins{sdk.Coin{Denom: "invalid", Amount: sdkmath.NewInt(-100)}}, + msgCreateRewardsPlan.StartTime, + msgCreateRewardsPlan.EndTime, + msgCreateRewardsPlan.PoolsDistribution, + msgCreateRewardsPlan.OperatorsDistribution, + msgCreateRewardsPlan.UsersDistribution, + msgCreateRewardsPlan.Sender, + ), + shouldErr: true, + }, + { + name: "invalid end time returns error", + msg: types.NewMsgCreateRewardsPlan( + msgCreateRewardsPlan.ServiceID, + msgCreateRewardsPlan.Description, + msgCreateRewardsPlan.Amount, + time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + msgCreateRewardsPlan.PoolsDistribution, + msgCreateRewardsPlan.OperatorsDistribution, + msgCreateRewardsPlan.UsersDistribution, + msgCreateRewardsPlan.Sender, + ), + shouldErr: true, + }, + { + name: "invalid sender returns error", + msg: types.NewMsgCreateRewardsPlan( + msgCreateRewardsPlan.ServiceID, + msgCreateRewardsPlan.Description, + msgCreateRewardsPlan.Amount, + msgCreateRewardsPlan.StartTime, + msgCreateRewardsPlan.EndTime, + msgCreateRewardsPlan.PoolsDistribution, + msgCreateRewardsPlan.OperatorsDistribution, + msgCreateRewardsPlan.UsersDistribution, + "invalid", + ), + shouldErr: true, + }, + { + name: "valid message returns no error", + msg: msgCreateRewardsPlan, + shouldErr: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + err := tc.msg.ValidateBasic() + if tc.shouldErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestMsgCreateRewardsPlan_GetSignBytes(t *testing.T) { + expected := `{"type":"milkyway/MsgCreateRewardsPlan","value":{"amount":[{"amount":"1000","denom":"stake"}],"description":"Test rewards plan","end_time":"2024-12-31T23:59:59Z","operators_distribution":{"delegation_type":2,"type":{"type":"milkyway/DistributionTypeBasic","value":{}},"weight":1},"pools_distribution":{"delegation_type":1,"type":{"type":"milkyway/DistributionTypeBasic","value":{}},"weight":1},"sender":"cosmos10d07y265gmmuvt4z0w9aw880jnsr700j6zn9kn","service_id":1,"start_time":"2024-01-01T00:00:00Z","users_distribution":{"type":{},"weight":1}}}` + require.Equal(t, expected, string(msgCreateRewardsPlan.GetSignBytes())) +} + +func TestMsgCreateRewardsPlan_GetSigners(t *testing.T) { + addr, _ := sdk.AccAddressFromBech32(msgCreateRewardsPlan.Sender) + require.Equal(t, []sdk.AccAddress{addr}, msgCreateRewardsPlan.GetSigners()) +} + +// -------------------------------------------------------------------------------------------------------------------- + +var msgEditRewardsPlan = types.NewMsgEditRewardsPlan( + 1, + "Test rewards plan", + sdk.NewCoins(sdk.NewCoin("stake", sdkmath.NewInt(1000))), + time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2024, 12, 31, 23, 59, 59, 0, time.UTC), + types.NewBasicPoolsDistribution(1), + types.NewBasicOperatorsDistribution(1), + types.NewBasicUsersDistribution(1), + "cosmos10d07y265gmmuvt4z0w9aw880jnsr700j6zn9kn", +) + +func TestMsgEditRewardsPlan_ValidateBasic(t *testing.T) { + testCases := []struct { + name string + msg *types.MsgEditRewardsPlan + shouldErr bool + }{ + { + name: "invalid id returns error", + msg: types.NewMsgEditRewardsPlan( + 0, + msgEditRewardsPlan.Description, + msgEditRewardsPlan.Amount, + msgEditRewardsPlan.StartTime, + msgEditRewardsPlan.EndTime, + msgEditRewardsPlan.PoolsDistribution, + msgEditRewardsPlan.OperatorsDistribution, + msgEditRewardsPlan.UsersDistribution, + msgEditRewardsPlan.Sender, + ), + shouldErr: true, + }, + { + name: "invalid amount", + msg: types.NewMsgEditRewardsPlan( + msgEditRewardsPlan.ID, + msgEditRewardsPlan.Description, + sdk.Coins{sdk.Coin{Denom: "invalid", Amount: sdkmath.NewInt(-100)}}, + msgEditRewardsPlan.StartTime, + msgEditRewardsPlan.EndTime, + msgEditRewardsPlan.PoolsDistribution, + msgEditRewardsPlan.OperatorsDistribution, + msgEditRewardsPlan.UsersDistribution, + msgEditRewardsPlan.Sender, + ), + shouldErr: true, + }, + { + name: "invalid end time returns error", + msg: types.NewMsgEditRewardsPlan( + msgEditRewardsPlan.ID, + msgEditRewardsPlan.Description, + msgEditRewardsPlan.Amount, + time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + msgEditRewardsPlan.PoolsDistribution, + msgEditRewardsPlan.OperatorsDistribution, + msgEditRewardsPlan.UsersDistribution, + msgEditRewardsPlan.Sender, + ), + shouldErr: true, + }, + { + name: "invalid sender returns error", + msg: types.NewMsgEditRewardsPlan( + msgEditRewardsPlan.ID, + msgEditRewardsPlan.Description, + msgEditRewardsPlan.Amount, + msgEditRewardsPlan.StartTime, + msgEditRewardsPlan.EndTime, + msgEditRewardsPlan.PoolsDistribution, + msgEditRewardsPlan.OperatorsDistribution, + msgEditRewardsPlan.UsersDistribution, + "invalid", + ), + shouldErr: true, + }, + { + name: "valid message returns no error", + msg: msgEditRewardsPlan, + shouldErr: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + err := tc.msg.ValidateBasic() + if tc.shouldErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestMsgEditRewardsPlan_GetSignBytes(t *testing.T) { + expected := `{"type":"milkyway/MsgEditRewardsPlan","value":{"amount":[{"amount":"1000","denom":"stake"}],"description":"Test rewards plan","end_time":"2024-12-31T23:59:59Z","id":"1","operators_distribution":{"delegation_type":2,"type":{"type":"milkyway/DistributionTypeBasic","value":{}},"weight":1},"pools_distribution":{"delegation_type":1,"type":{"type":"milkyway/DistributionTypeBasic","value":{}},"weight":1},"sender":"cosmos10d07y265gmmuvt4z0w9aw880jnsr700j6zn9kn","start_time":"2024-01-01T00:00:00Z","users_distribution":{"type":{},"weight":1}}}` + require.Equal(t, expected, string(msgEditRewardsPlan.GetSignBytes())) +} + +func TestMsgEditRewardsPlan_GetSigners(t *testing.T) { + addr, _ := sdk.AccAddressFromBech32(msgEditRewardsPlan.Sender) + require.Equal(t, []sdk.AccAddress{addr}, msgCreateRewardsPlan.GetSigners()) +} + +// -------------------------------------------------------------------------------------------------------------------- + +var msgSetWithdrawAddress = types.NewMsgSetWithdrawAddress( + "cosmos10d07y265gmmuvt4z0w9aw880jnsr700j6zn9kn", + "cosmos167x6ehhple8gwz5ezy9x0464jltvdpzl6qfdt4", +) + +func TestMsgSetWithdrawAddress_ValidateBasic(t *testing.T) { + testCases := []struct { + name string + msg *types.MsgSetWithdrawAddress + shouldErr bool + }{ + { + name: "invalid withdraw address returns error", + msg: types.NewMsgSetWithdrawAddress( + "invalid", + msgSetWithdrawAddress.WithdrawAddress, + ), + shouldErr: true, + }, + { + name: "invalid sender address returns error", + msg: types.NewMsgSetWithdrawAddress( + msgSetWithdrawAddress.Sender, + "invalid", + ), + shouldErr: true, + }, + { + name: "valid message returns no error", + msg: msgSetWithdrawAddress, + shouldErr: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + err := tc.msg.ValidateBasic() + if tc.shouldErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestMsgSetWithdrawAddress_GetSignBytes(t *testing.T) { + expected := `{"type":"milkyway/MsgSetWithdrawAddress","value":{"sender":"cosmos167x6ehhple8gwz5ezy9x0464jltvdpzl6qfdt4","withdraw_address":"cosmos10d07y265gmmuvt4z0w9aw880jnsr700j6zn9kn"}}` + require.Equal(t, expected, string(msgSetWithdrawAddress.GetSignBytes())) +} + +func TestMsgSetWithdrawAddress_GetSigners(t *testing.T) { + addr, _ := sdk.AccAddressFromBech32(msgSetWithdrawAddress.Sender) + require.Equal(t, []sdk.AccAddress{addr}, msgSetWithdrawAddress.GetSigners()) +} + +// -------------------------------------------------------------------------------------------------------------------- + +var msgWithdrawDelegatorReward = types.NewMsgWithdrawDelegatorReward( + restakingtypes.DELEGATION_TYPE_SERVICE, + 1, + "cosmos10d07y265gmmuvt4z0w9aw880jnsr700j6zn9kn", +) + +func TestMsgWithdrawDelegatorReward_ValidateBasic(t *testing.T) { + testCases := []struct { + name string + msg *types.MsgWithdrawDelegatorReward + shouldErr bool + }{ + { + name: "invalid delegation type returns error", + msg: types.NewMsgWithdrawDelegatorReward( + restakingtypes.DELEGATION_TYPE_UNSPECIFIED, + msgWithdrawDelegatorReward.DelegationTargetID, + msgWithdrawDelegatorReward.DelegatorAddress, + ), + shouldErr: true, + }, + { + name: "invalid delegation target ID returns error", + msg: types.NewMsgWithdrawDelegatorReward( + msgWithdrawDelegatorReward.DelegationType, + 0, + msgWithdrawDelegatorReward.DelegatorAddress, + ), + shouldErr: true, + }, + { + name: "invalid delegator address returns error", + msg: types.NewMsgWithdrawDelegatorReward( + msgWithdrawDelegatorReward.DelegationType, + msgWithdrawDelegatorReward.DelegationTargetID, + "invalid", + ), + shouldErr: true, + }, + { + name: "valid message returns no error", + msg: msgWithdrawDelegatorReward, + shouldErr: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + err := tc.msg.ValidateBasic() + if tc.shouldErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestMsgWithdrawDelegatorReward_GetSignBytes(t *testing.T) { + expected := `{"type":"milkyway/MsgWithdrawDelegatorReward","value":{"delegation_target_id":1,"delegation_type":3,"delegator_address":"cosmos10d07y265gmmuvt4z0w9aw880jnsr700j6zn9kn"}}` + require.Equal(t, expected, string(msgWithdrawDelegatorReward.GetSignBytes())) +} + +func TestMsgWithdrawDelegatorReward_GetSigners(t *testing.T) { + addr, _ := sdk.AccAddressFromBech32(msgWithdrawDelegatorReward.DelegatorAddress) + require.Equal(t, []sdk.AccAddress{addr}, msgWithdrawDelegatorReward.GetSigners()) +} + +// -------------------------------------------------------------------------------------------------------------------- + +var msgWithdrawOperatorCommission = types.NewMsgWithdrawOperatorCommission( + 1, + "cosmos10d07y265gmmuvt4z0w9aw880jnsr700j6zn9kn", +) + +func TestMsgWithdrawOperatorCommission_ValidateBasic(t *testing.T) { + testCases := []struct { + name string + msg *types.MsgWithdrawOperatorCommission + shouldErr bool + }{ + { + name: "invalid operator ID returns error", + msg: types.NewMsgWithdrawOperatorCommission( + 0, + msgWithdrawOperatorCommission.Sender, + ), + shouldErr: true, + }, + { + name: "invalid sender address returns error", + msg: types.NewMsgWithdrawOperatorCommission( + msgWithdrawOperatorCommission.OperatorID, + "invalid", + ), + shouldErr: true, + }, + { + name: "valid message returns no error", + msg: msgWithdrawOperatorCommission, + shouldErr: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + err := tc.msg.ValidateBasic() + if tc.shouldErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestMsgWithdrawOperatorCommission_GetSignBytes(t *testing.T) { + expected := `{"type":"milkyway/MsgWithdrawOperatorCommission","value":{"operator_id":1,"sender":"cosmos10d07y265gmmuvt4z0w9aw880jnsr700j6zn9kn"}}` + require.Equal(t, expected, string(msgWithdrawOperatorCommission.GetSignBytes())) +} + +func TestMsgWithdrawOperatorCommission_GetSigners(t *testing.T) { + addr, _ := sdk.AccAddressFromBech32(msgWithdrawOperatorCommission.Sender) + require.Equal(t, []sdk.AccAddress{addr}, msgWithdrawOperatorCommission.GetSigners()) +} diff --git a/x/rewards/types/models.go b/x/rewards/types/models.go index b14aee62..db2d4ff9 100644 --- a/x/rewards/types/models.go +++ b/x/rewards/types/models.go @@ -174,6 +174,8 @@ func GetDistributionType(unpacker codectypes.AnyUnpacker, distr Distribution) (D return distrType, nil } +// -------------------------------------------------------------------------------------------------------------------- + // NewDistributionWeight creates a new distribution weight func NewDistributionWeight(targetID, weight uint32) DistributionWeight { return DistributionWeight{ @@ -184,19 +186,33 @@ func NewDistributionWeight(targetID, weight uint32) DistributionWeight { // -------------------------------------------------------------------------------------------------------------------- -// newBasicDistribution creates a new basic distribution -func newBasicDistribution(delType restakingtypes.DelegationType, weight uint32) Distribution { - a, err := codectypes.NewAnyWithValue(&DistributionTypeBasic{}) +// NewDistribution creates a new distribution instance +func NewDistribution(delegationType restakingtypes.DelegationType, weight uint32, distributionType DistributionType) Distribution { + distrTypeAny, err := codectypes.NewAnyWithValue(distributionType) if err != nil { panic(err) } + return Distribution{ - DelegationType: delType, + DelegationType: delegationType, Weight: weight, - Type: a, + Type: distrTypeAny, } } +// UnpackInterfaces implements codectypes.UnpackInterfacesMessage +func (d *Distribution) UnpackInterfaces(unpacker codectypes.AnyUnpacker) error { + var target DistributionType + return unpacker.UnpackAny(d.Type, &target) +} + +// -------------------------------------------------------------------------------------------------------------------- + +// newBasicDistribution creates a new basic distribution +func newBasicDistribution(delType restakingtypes.DelegationType, weight uint32) Distribution { + return NewDistribution(delType, weight, &DistributionTypeBasic{}) +} + // NewBasicPoolsDistribution creates a new basic pools distribution func NewBasicPoolsDistribution(weight uint32) Distribution { return newBasicDistribution(restakingtypes.DELEGATION_TYPE_POOL, weight) @@ -219,15 +235,7 @@ func (t DistributionTypeBasic) isDistributionType() {} // newWeightedDistribution creates a new weighted distribution func newWeightedDistribution(delType restakingtypes.DelegationType, weight uint32, weights []DistributionWeight) Distribution { - a, err := codectypes.NewAnyWithValue(&DistributionTypeWeighted{Weights: weights}) - if err != nil { - panic(err) - } - return Distribution{ - DelegationType: delType, - Weight: weight, - Type: a, - } + return NewDistribution(delType, weight, &DistributionTypeWeighted{Weights: weights}) } // NewWeightedPoolsDistribution creates a new weighted pools distribution @@ -267,15 +275,7 @@ func (t DistributionTypeWeighted) isDistributionType() {} // newEgalitarianDistribution creates a new egalitarian distribution func newEgalitarianDistribution(delType restakingtypes.DelegationType, weight uint32) Distribution { - a, err := codectypes.NewAnyWithValue(&DistributionTypeEgalitarian{}) - if err != nil { - panic(err) - } - return Distribution{ - DelegationType: delType, - Weight: weight, - Type: a, - } + return NewDistribution(delType, weight, &DistributionTypeEgalitarian{}) } // NewEgalitarianPoolsDistribution creates a new egalitarian pools distribution @@ -315,18 +315,33 @@ func GetUsersDistributionType(unpacker codectypes.AnyUnpacker, distr UsersDistri return distrType, nil } -// NewBasicUsersDistribution creates a new basic users distribution -func NewBasicUsersDistribution(weight uint32) UsersDistribution { - a, err := codectypes.NewAnyWithValue(&UsersDistributionTypeBasic{}) +// -------------------------------------------------------------------------------------------------------------------- + +func NewUsersDistribution(weight uint32, distributionType UsersDistributionType) UsersDistribution { + distrTypeAny, err := codectypes.NewAnyWithValue(distributionType) if err != nil { panic(err) } + return UsersDistribution{ Weight: weight, - Type: a, + Type: distrTypeAny, } } +// UnpackInterfaces implements codectypes.UnpackInterfacesMessage +func (u *UsersDistribution) UnpackInterfaces(unpacker codectypes.AnyUnpacker) error { + var target UsersDistributionType + return unpacker.UnpackAny(u.Type, &target) +} + +// -------------------------------------------------------------------------------------------------------------------- + +// NewBasicUsersDistribution creates a new basic users distribution +func NewBasicUsersDistribution(weight uint32) UsersDistribution { + return NewUsersDistribution(weight, &UsersDistributionTypeBasic{}) +} + // Validate checks the users distribution for validity func (t UsersDistributionTypeBasic) Validate() error { return nil diff --git a/x/rewards/types/models_test.go b/x/rewards/types/models_test.go index fc32c6b6..eb49daea 100644 --- a/x/rewards/types/models_test.go +++ b/x/rewards/types/models_test.go @@ -5,12 +5,13 @@ import ( "testing" "time" + "github.com/cosmos/cosmos-sdk/codec" + codectestutil "github.com/cosmos/cosmos-sdk/codec/testutil" "github.com/stretchr/testify/require" "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" - milkyway "github.com/milkyway-labs/milkyway/app" "github.com/milkyway-labs/milkyway/utils" "github.com/milkyway-labs/milkyway/x/rewards/types" ) @@ -131,7 +132,9 @@ func TestRewardsPlan_Validate(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - cdc, _ := milkyway.MakeCodecs() + interfaceRegistry := codectestutil.CodecOptions{AccAddressPrefix: "cosmo", ValAddressPrefix: "cosmovaloper"}.NewInterfaceRegistry() + cdc := codec.NewProtoCodec(interfaceRegistry) + err := tc.plan.Validate(cdc) if tc.shouldErr { require.Error(t, err)