From f7147cd8f200bfb2f9bb24ede2088294827e57aa Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 12 Nov 2024 18:51:16 +0100 Subject: [PATCH 1/8] wip: services simulation tests --- app/modules.go | 5 +- testutils/simtesting/utils.go | 53 +++++++++++++++++++ x/services/keeper/keeper.go | 4 +- x/services/module.go | 44 ++++++++++++++-- x/services/simulation/decoder.go | 47 +++++++++++++++++ x/services/simulation/genesis.go | 76 ++++++++++++++++++++++++++++ x/services/simulation/msg_factory.go | 73 ++++++++++++++++++++++++++ x/services/simulation/utils.go | 26 ++++++++++ 8 files changed, 322 insertions(+), 6 deletions(-) create mode 100644 testutils/simtesting/utils.go create mode 100644 x/services/simulation/decoder.go create mode 100644 x/services/simulation/genesis.go create mode 100644 x/services/simulation/msg_factory.go create mode 100644 x/services/simulation/utils.go diff --git a/app/modules.go b/app/modules.go index 531341cb..f12c75cd 100644 --- a/app/modules.go +++ b/app/modules.go @@ -176,7 +176,7 @@ func appModules( icacallbacks.NewAppModule(appCodec, *app.ICACallbacksKeeper, app.AccountKeeper, app.BankKeeper), // MilkyWay modules - services.NewAppModule(appCodec, app.ServicesKeeper, app.PoolsKeeper), + services.NewAppModule(appCodec, app.ServicesKeeper, app.PoolsKeeper, app.AccountKeeper, app.BankKeeper), operators.NewAppModule(appCodec, app.OperatorsKeeper), pools.NewAppModule(appCodec, app.PoolsKeeper), restaking.NewAppModule(appCodec, app.RestakingKeeper), @@ -229,6 +229,9 @@ func simulationModules( ibc.NewAppModule(app.IBCKeeper), app.TransferModule, app.ICAModule, + + // MilkyWay modules + services.NewAppModule(appCodec, app.ServicesKeeper, app.PoolsKeeper, app.AccountKeeper, app.BankKeeper), } } diff --git a/testutils/simtesting/utils.go b/testutils/simtesting/utils.go new file mode 100644 index 00000000..d447a91a --- /dev/null +++ b/testutils/simtesting/utils.go @@ -0,0 +1,53 @@ +package simtesting + +import ( + "math/rand" + + "github.com/cosmos/cosmos-sdk/baseapp" + "github.com/cosmos/cosmos-sdk/codec" + codectypes "github.com/cosmos/cosmos-sdk/codec/types" + + sdk "github.com/cosmos/cosmos-sdk/types" + simtypes "github.com/cosmos/cosmos-sdk/types/simulation" + authkeeper "github.com/cosmos/cosmos-sdk/x/auth/keeper" + "github.com/cosmos/cosmos-sdk/x/auth/tx" + bankkeeper "github.com/cosmos/cosmos-sdk/x/bank/keeper" + "github.com/cosmos/cosmos-sdk/x/simulation" +) + +// SendMsg sends a transaction with the specified message. +func SendMsg( + r *rand.Rand, moduleName string, app *baseapp.BaseApp, ak authkeeper.AccountKeeper, bk bankkeeper.Keeper, + msg sdk.Msg, ctx sdk.Context, + simAccount simtypes.Account, +) (simtypes.OperationMsg, []simtypes.FutureOperation, error) { + deposit := sdk.Coins{} + spendableCoins := bk.SpendableCoins(ctx, simAccount.Address) + for _, v := range spendableCoins { + if bk.IsSendEnabledCoin(ctx, v) { + deposit = deposit.Add(simtypes.RandSubsetCoins(r, sdk.NewCoins(v))...) + } + } + + if deposit.IsZero() { + msgType := sdk.MsgTypeURL(msg) + return simtypes.NoOpMsg(moduleName, msgType, "skip because of broke account"), nil, nil + } + + interfaceRegistry := codectypes.NewInterfaceRegistry() + txConfig := tx.NewTxConfig(codec.NewProtoCodec(interfaceRegistry), tx.DefaultSignModes) + txCtx := simulation.OperationInput{ + R: r, + App: app, + TxGen: txConfig, + Cdc: nil, + Msg: msg, + Context: ctx, + SimAccount: simAccount, + AccountKeeper: ak, + Bankkeeper: bk, + ModuleName: moduleName, + CoinsSpentInMsg: deposit, + } + return simulation.GenAndDeliverTxWithRandFees(txCtx) +} diff --git a/x/services/keeper/keeper.go b/x/services/keeper/keeper.go index 2e43f414..01108a26 100644 --- a/x/services/keeper/keeper.go +++ b/x/services/keeper/keeper.go @@ -21,7 +21,7 @@ type Keeper struct { // Data storeService corestoretypes.KVStoreService - schema collections.Schema + Schema collections.Schema serviceAddressSet collections.KeySet[string] // serviceParams associated a service ID with its parameters serviceParams collections.Map[uint32, types.ServiceParams] @@ -68,7 +68,7 @@ func NewKeeper( if err != nil { panic(err) } - k.schema = schema + k.Schema = schema return k } diff --git a/x/services/module.go b/x/services/module.go index be722348..7b959093 100644 --- a/x/services/module.go +++ b/x/services/module.go @@ -16,9 +16,13 @@ import ( cdctypes "github.com/cosmos/cosmos-sdk/codec/types" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/module" + simtypes "github.com/cosmos/cosmos-sdk/types/simulation" + authkeeper "github.com/cosmos/cosmos-sdk/x/auth/keeper" + bankkeeper "github.com/cosmos/cosmos-sdk/x/bank/keeper" "github.com/milkyway-labs/milkyway/x/services/client/cli" "github.com/milkyway-labs/milkyway/x/services/keeper" + "github.com/milkyway-labs/milkyway/x/services/simulation" "github.com/milkyway-labs/milkyway/x/services/types" ) @@ -37,10 +41,10 @@ var ( // AppModuleBasic implements the AppModuleBasic interface for the capability module. type AppModuleBasic struct { - cdc codec.BinaryCodec + cdc codec.Codec } -func NewAppModuleBasic(cdc codec.BinaryCodec) AppModuleBasic { +func NewAppModuleBasic(cdc codec.Codec) AppModuleBasic { return AppModuleBasic{cdc: cdc} } @@ -97,14 +101,23 @@ type AppModule struct { keeper *keeper.Keeper pk types.PoolsKeeper + bk bankkeeper.Keeper + ak authkeeper.AccountKeeper } -func NewAppModule(cdc codec.Codec, keeper *keeper.Keeper, pk types.PoolsKeeper) AppModule { +func NewAppModule(cdc codec.Codec, + keeper *keeper.Keeper, + pk types.PoolsKeeper, + ak authkeeper.AccountKeeper, + bk bankkeeper.Keeper, +) AppModule { return AppModule{ AppModuleBasic: NewAppModuleBasic(cdc), keeper: keeper, pk: pk, + ak: ak, + bk: bk, } } @@ -152,3 +165,28 @@ func (AppModule) ConsensusVersion() uint64 { return consensusVersion } func (am AppModule) IsOnePerModuleType() {} func (am AppModule) IsAppModule() {} + +// AppModuleSimulation functions + +// GenerateGenesisState creates a randomized GenState of the staking module. +func (AppModule) GenerateGenesisState(simState *module.SimulationState) { + simulation.RandomizedGenState(simState) +} + +// ProposalMsgs returns msgs used for governance proposals for simulations. +// func (AppModule) ProposalMsgs(simState module.SimulationState) []simtypes.WeightedProposalMsg { +// return simulation.ProposalMsgs() +// } + +// RegisterStoreDecoder registers a decoder for staking module's types +func (am AppModule) RegisterStoreDecoder(sdr simtypes.StoreDecoderRegistry) { + sdr[types.StoreKey] = simulation.NewDecodeStore(am.cdc, am.keeper) +} + +// WeightedOperations returns the all the staking module operations with their respective weights. +func (am AppModule) WeightedOperations(simState module.SimulationState) []simtypes.WeightedOperation { + return simulation.WeightedOperations( + simState.AppParams, simState.Cdc, simState.TxConfig, + am.ak, am.bk, am.keeper, + ) +} diff --git a/x/services/simulation/decoder.go b/x/services/simulation/decoder.go new file mode 100644 index 00000000..23680a0b --- /dev/null +++ b/x/services/simulation/decoder.go @@ -0,0 +1,47 @@ +package simulation + +import ( + "bytes" + "fmt" + + "github.com/cosmos/cosmos-sdk/codec" + "github.com/cosmos/cosmos-sdk/types/kv" + simtypes "github.com/cosmos/cosmos-sdk/types/simulation" + + "github.com/milkyway-labs/milkyway/x/services/keeper" + "github.com/milkyway-labs/milkyway/x/services/types" +) + +// NewDecodeStore returns a decoder function closure that unmarshals the KVPair's +// Value to the corresponding services type. +func NewDecodeStore(cdc codec.Codec, keeper *keeper.Keeper) func(kvA, kvB kv.Pair) string { + collectionsDecoder := simtypes.NewStoreDecoderFuncFromCollectionsSchema(keeper.Schema) + + return func(kvA, kvB kv.Pair) string { + switch { + case bytes.Equal(kvA.Key[:1], types.ServicePrefix): + var serviceA, serviceB types.Service + if err := cdc.Unmarshal(kvA.Value, &serviceA); err != nil { + panic(err) + } + if err := cdc.Unmarshal(kvB.Value, &serviceB); err != nil { + panic(err) + } + return fmt.Sprintf("%v\n%v", serviceA, serviceB) + + case bytes.Equal(kvA.Key[:1], types.NextServiceIDKey): + idA := types.GetServiceIDFromBytes(kvA.Value) + idB := types.GetServiceIDFromBytes(kvB.Value) + return fmt.Sprintf("%v\n%v", idA, idB) + + case bytes.Equal(kvA.Key[:1], types.ServiceAddressSetPrefix): + return collectionsDecoder(kvA, kvB) + + case bytes.Equal(kvA.Key[:1], types.ServiceParamsPrefix): + return collectionsDecoder(kvA, kvB) + + default: + panic(fmt.Sprintf("invalid services key prefix %X", kvA.Key[:1])) + } + } +} diff --git a/x/services/simulation/genesis.go b/x/services/simulation/genesis.go new file mode 100644 index 00000000..04c735b1 --- /dev/null +++ b/x/services/simulation/genesis.go @@ -0,0 +1,76 @@ +package simulation + +import ( + "math/rand" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/module" + "github.com/cosmos/cosmos-sdk/types/simulation" + + "github.com/milkyway-labs/milkyway/x/services/types" +) + +// Simulation parameter constants +const ( + keyServiceRegistrationFees = "operators_count" + keyServices = "services" + keyServicesParams = "services_params" +) + +func getServiceRegistrationFees(r *rand.Rand) sdk.Coins { + amount := int64(r.Intn(10_000_000) + 1) + return sdk.NewCoins(sdk.NewInt64Coin("umilk", amount)) +} + +func getServices(r *rand.Rand) []types.Service { + count := r.Intn(10) + 1 + var services []types.Service + for i := 0; i < count; i++ { + adminAccount := simulation.RandomAccounts(r, 1)[0] + service := RandomService(r, uint32(i), adminAccount.Address.String()) + services = append(services, service) + } + + return services +} + +func getServiceParams(r *rand.Rand, services []types.Service) []types.ServiceParamsRecord { + var params []types.ServiceParamsRecord + for _, service := range services { + generate := (r.Uint64() % 2) == 0 + if !generate { + continue + } + + serviceParams := types.NewServiceParams([]string{"umilk"}) + params = append(params, types.NewServiceParamsRecord(service.ID, serviceParams)) + } + return params +} + +// RandomizedGenState generates a random GenesisState for the services module +func RandomizedGenState(simState *module.SimulationState) { + var ( + serviceRegistrationFees sdk.Coins + services []types.Service + servicesParams []types.ServiceParamsRecord + ) + + simState.AppParams.GetOrGenerate(keyServiceRegistrationFees, &serviceRegistrationFees, simState.Rand, func(r *rand.Rand) { + serviceRegistrationFees = getServiceRegistrationFees(r) + }) + + simState.AppParams.GetOrGenerate(keyServices, &services, simState.Rand, func(r *rand.Rand) { + services = getServices(r) + }) + + simState.AppParams.GetOrGenerate(keyServicesParams, &servicesParams, simState.Rand, func(r *rand.Rand) { + servicesParams = getServiceParams(r, services) + }) + + params := types.NewParams(serviceRegistrationFees) + nextServiceId := uint32(len(services)) + 1 + + servicesGenesis := types.NewGenesisState(nextServiceId, services, servicesParams, params) + simState.GenState[types.ModuleName] = simState.Cdc.MustMarshalJSON(servicesGenesis) +} diff --git a/x/services/simulation/msg_factory.go b/x/services/simulation/msg_factory.go new file mode 100644 index 00000000..37741672 --- /dev/null +++ b/x/services/simulation/msg_factory.go @@ -0,0 +1,73 @@ +package simulation + +import ( + "math/rand" + + "github.com/cosmos/cosmos-sdk/baseapp" + "github.com/cosmos/cosmos-sdk/client" + "github.com/cosmos/cosmos-sdk/codec" + sdk "github.com/cosmos/cosmos-sdk/types" + simtypes "github.com/cosmos/cosmos-sdk/types/simulation" + authkeeper "github.com/cosmos/cosmos-sdk/x/auth/keeper" + bankkeeper "github.com/cosmos/cosmos-sdk/x/bank/keeper" + "github.com/cosmos/cosmos-sdk/x/simulation" + + "github.com/milkyway-labs/milkyway/testutils/simtesting" + "github.com/milkyway-labs/milkyway/x/services/keeper" + "github.com/milkyway-labs/milkyway/x/services/types" +) + +// Simulation operation weights constants +const ( + DefaultWeightMsgCreateService int = 100 + + OpWeightMsgCreateService = "op_weight_msg_create_service" +) + +// WeightedOperations returns all the operations from the module with their respective weights +func WeightedOperations( + appParams simtypes.AppParams, + cdc codec.JSONCodec, + txGen client.TxConfig, + ak authkeeper.AccountKeeper, + bk bankkeeper.Keeper, + k *keeper.Keeper, +) simulation.WeightedOperations { + var weightMsgCreateService int + + appParams.GetOrGenerate(OpWeightMsgCreateService, &weightMsgCreateService, nil, func(_ *rand.Rand) { + weightMsgCreateService = DefaultWeightMsgCreateService + }) + + return simulation.WeightedOperations{ + simulation.NewWeightedOperation(weightMsgCreateService, SimulateMsgCreateService(txGen, ak, bk, k)), + } +} + +func SimulateMsgCreateService( + txGen client.TxConfig, + ak authkeeper.AccountKeeper, + bk bankkeeper.Keeper, + k *keeper.Keeper, +) simtypes.Operation { + return func( + r *rand.Rand, app *baseapp.BaseApp, ctx sdk.Context, accs []simtypes.Account, chainID string, + ) (simtypes.OperationMsg, []simtypes.FutureOperation, error) { + // No account skipping + if len(accs) == 0 { + return simtypes.NoOpMsg(types.ModuleName, "MsgCreateService", "skip"), nil, nil + } + + signer, _ := simtypes.RandomAcc(r, accs) + service := RandomService(r, 1, signer.Address.String()) + msg := types.NewMsgCreateService( + service.Name, + service.Description, + service.Website, + service.PictureURL, + service.Admin, + ) + + return simtesting.SendMsg(r, types.ModuleName, app, ak, bk, msg, ctx, signer) + } +} diff --git a/x/services/simulation/utils.go b/x/services/simulation/utils.go new file mode 100644 index 00000000..d9b65c11 --- /dev/null +++ b/x/services/simulation/utils.go @@ -0,0 +1,26 @@ +package simulation + +import ( + "math/rand" + + "github.com/cosmos/cosmos-sdk/types/simulation" + "github.com/milkyway-labs/milkyway/x/services/types" +) + +func RandomServiceStatus(r *rand.Rand) types.ServiceStatus { + value := (int32(r.Uint64())) % 3 + // Here we add 1 since 0 is SERVICE_STATUS_UNSPECIFIED. + return types.ServiceStatus(value + 1) +} + +func RandomService(r *rand.Rand, id uint32, admin string) types.Service { + return types.NewService(id, + RandomServiceStatus(r), + simulation.RandStringOfLength(r, 24), + simulation.RandStringOfLength(r, 24), + simulation.RandStringOfLength(r, 24), + simulation.RandStringOfLength(r, 24), + admin, + (r.Uint64()%2) == 0, + ) +} From 5c8b1482922f3fb485d54ec1454ac3e797807d44 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 13 Nov 2024 15:04:23 +0100 Subject: [PATCH 2/8] test: completed x/services simulation tests --- testutils/simtesting/utils.go | 10 + utils/slices.go | 11 ++ x/services/module.go | 6 +- x/services/simulation/msg_factory.go | 281 ++++++++++++++++++++++++++- x/services/simulation/proposals.go | 100 ++++++++++ x/services/simulation/utils.go | 20 ++ 6 files changed, 422 insertions(+), 6 deletions(-) create mode 100644 x/services/simulation/proposals.go diff --git a/testutils/simtesting/utils.go b/testutils/simtesting/utils.go index d447a91a..a46affb4 100644 --- a/testutils/simtesting/utils.go +++ b/testutils/simtesting/utils.go @@ -51,3 +51,13 @@ func SendMsg( } return simulation.GenAndDeliverTxWithRandFees(txCtx) } + +// GetSimAccount gets the Account with the given address +func GetSimAccount(address sdk.Address, accs []simtypes.Account) (simtypes.Account, bool) { + for _, acc := range accs { + if acc.Address.Equals(address) { + return acc, true + } + } + return simtypes.Account{}, false +} diff --git a/utils/slices.go b/utils/slices.go index 1573b070..ba8e8323 100644 --- a/utils/slices.go +++ b/utils/slices.go @@ -80,3 +80,14 @@ func Intersect[T comparable](a, b []T) []T { return result } + +// Filter returns the elements of the slice that satisfy the given predicate. +func Filter[T any](slice []T, f func(T) bool) []T { + var result []T + for _, v := range slice { + if f(v) { + result = append(result, v) + } + } + return result +} diff --git a/x/services/module.go b/x/services/module.go index 7b959093..a5f986fb 100644 --- a/x/services/module.go +++ b/x/services/module.go @@ -174,9 +174,9 @@ func (AppModule) GenerateGenesisState(simState *module.SimulationState) { } // ProposalMsgs returns msgs used for governance proposals for simulations. -// func (AppModule) ProposalMsgs(simState module.SimulationState) []simtypes.WeightedProposalMsg { -// return simulation.ProposalMsgs() -// } +func (am AppModule) ProposalMsgs(simState module.SimulationState) []simtypes.WeightedProposalMsg { + return simulation.ProposalMsgs(am.keeper) +} // RegisterStoreDecoder registers a decoder for staking module's types func (am AppModule) RegisterStoreDecoder(sdr simtypes.StoreDecoderRegistry) { diff --git a/x/services/simulation/msg_factory.go b/x/services/simulation/msg_factory.go index 37741672..01ea1908 100644 --- a/x/services/simulation/msg_factory.go +++ b/x/services/simulation/msg_factory.go @@ -19,9 +19,21 @@ import ( // Simulation operation weights constants const ( - DefaultWeightMsgCreateService int = 100 + DefaultWeightMsgCreateService int = 100 + DefaultWeightMsgUpdateService int = 100 + DefaultWeightMsgActivateService int = 100 + DefaultWeightMsgDeactivateService int = 100 + DefaultWeightMsgTransferServiceOwnership int = 100 + DefaultWeightMsgDeleteService int = 100 + DefaultWeightMsgSetServiceParams int = 100 - OpWeightMsgCreateService = "op_weight_msg_create_service" + OpWeightMsgCreateService = "op_weight_msg_create_service" + OpWeightMsgUpdateService = "op_weight_msg_update_service" + OpWeightMsgActivateService = "op_weight_msg_activate_service" + OpWeightMsgDeactivateService = "op_weight_msg_deactivate_service" + OpWeightMsgTransferServiceOwnership = "op_weight_msg_transfer_service_ownership" + OpWeightMsgDeleteService = "op_weight_msg_delete_service" + OpWeightMsgSetServiceParams = "op_weight_msg_set_service_params" ) // WeightedOperations returns all the operations from the module with their respective weights @@ -33,14 +45,53 @@ func WeightedOperations( bk bankkeeper.Keeper, k *keeper.Keeper, ) simulation.WeightedOperations { - var weightMsgCreateService int + var ( + weightMsgCreateService int + weightMsgUpdateService int + weightMsgActivateService int + weightMsgDeactivateService int + weightMsgTransferServiceOwnership int + weightMsgDeleteService int + weightMsgSetServiceParams int + ) + // Generate the weights appParams.GetOrGenerate(OpWeightMsgCreateService, &weightMsgCreateService, nil, func(_ *rand.Rand) { weightMsgCreateService = DefaultWeightMsgCreateService }) + appParams.GetOrGenerate(OpWeightMsgUpdateService, &weightMsgUpdateService, nil, func(_ *rand.Rand) { + weightMsgUpdateService = DefaultWeightMsgUpdateService + }) + + appParams.GetOrGenerate(OpWeightMsgActivateService, &weightMsgActivateService, nil, func(_ *rand.Rand) { + weightMsgActivateService = DefaultWeightMsgActivateService + }) + + appParams.GetOrGenerate(OpWeightMsgDeactivateService, &weightMsgDeactivateService, nil, func(_ *rand.Rand) { + weightMsgDeactivateService = DefaultWeightMsgDeactivateService + }) + + appParams.GetOrGenerate(OpWeightMsgTransferServiceOwnership, &weightMsgTransferServiceOwnership, nil, func(_ *rand.Rand) { + weightMsgTransferServiceOwnership = DefaultWeightMsgTransferServiceOwnership + }) + + appParams.GetOrGenerate(OpWeightMsgDeleteService, &weightMsgDeleteService, nil, func(_ *rand.Rand) { + weightMsgDeleteService = DefaultWeightMsgDeleteService + }) + + appParams.GetOrGenerate(OpWeightMsgSetServiceParams, &weightMsgSetServiceParams, nil, func(_ *rand.Rand) { + weightMsgSetServiceParams = DefaultWeightMsgSetServiceParams + }) + return simulation.WeightedOperations{ simulation.NewWeightedOperation(weightMsgCreateService, SimulateMsgCreateService(txGen, ak, bk, k)), + simulation.NewWeightedOperation(weightMsgUpdateService, SimulateMsgUpdateService(txGen, ak, bk, k)), + simulation.NewWeightedOperation(weightMsgActivateService, SimulateMsgActivateService(txGen, ak, bk, k)), + simulation.NewWeightedOperation(weightMsgDeactivateService, SimulateMsgDeactivateService(txGen, ak, bk, k)), + simulation.NewWeightedOperation(weightMsgTransferServiceOwnership, SimulateMsgTransferServiceOwnership(txGen, ak, bk, k)), + simulation.NewWeightedOperation(weightMsgDeleteService, SimulateMsgDeleteService(txGen, ak, bk, k)), + simulation.NewWeightedOperation(weightMsgSetServiceParams, SimulateMsgSetServiceParams(txGen, ak, bk, k)), } } @@ -71,3 +122,227 @@ func SimulateMsgCreateService( return simtesting.SendMsg(r, types.ModuleName, app, ak, bk, msg, ctx, signer) } } + +func SimulateMsgUpdateService( + txGen client.TxConfig, + ak authkeeper.AccountKeeper, + bk bankkeeper.Keeper, + k *keeper.Keeper, +) simtypes.Operation { + return func( + r *rand.Rand, app *baseapp.BaseApp, ctx sdk.Context, accs []simtypes.Account, chainID string, + ) (simtypes.OperationMsg, []simtypes.FutureOperation, error) { + // No account skipping + if len(accs) == 0 { + return simtypes.NoOpMsg(types.ModuleName, "MsgUpdateService", "skip"), nil, nil + } + + // Get a random service to update + service, found := GetRandomExistingService(r, ctx, k, nil) + if !found { + return simtypes.NoOpMsg(types.ModuleName, "MsgUpdateService", "skip"), nil, nil + } + + // Get the service admin sim account + adminAddr := sdk.MustAccAddressFromBech32(service.Admin) + simAccount, found := simtesting.GetSimAccount(adminAddr, accs) + if !found { + return simtypes.NoOpMsg(types.ModuleName, "service admin not found", "skip"), nil, nil + } + + // Generate the new service fields + newService := RandomService(r, service.ID, service.Admin) + // Create the msg + msg := types.NewMsgUpdateService( + service.ID, + newService.Name, + newService.Description, + newService.Website, + newService.PictureURL, + simAccount.Address.String(), + ) + + return simtesting.SendMsg(r, types.ModuleName, app, ak, bk, msg, ctx, simAccount) + } +} + +func SimulateMsgActivateService( + txGen client.TxConfig, + ak authkeeper.AccountKeeper, + bk bankkeeper.Keeper, + k *keeper.Keeper, +) simtypes.Operation { + return func( + r *rand.Rand, app *baseapp.BaseApp, ctx sdk.Context, accs []simtypes.Account, chainID string, + ) (simtypes.OperationMsg, []simtypes.FutureOperation, error) { + // No account skipping + if len(accs) == 0 { + return simtypes.NoOpMsg(types.ModuleName, "MsgActivateService", "skip"), nil, nil + } + + // Get a random service to activate + service, found := GetRandomExistingService(r, ctx, k, func(s types.Service) bool { + return s.Status == types.SERVICE_STATUS_CREATED || s.Status == types.SERVICE_STATUS_INACTIVE + }) + if !found { + return simtypes.NoOpMsg(types.ModuleName, "MsgActivateService", "skip"), nil, nil + } + + // Get the service admin sim account + adminAddr := sdk.MustAccAddressFromBech32(service.Admin) + simAccount, found := simtesting.GetSimAccount(adminAddr, accs) + if !found { + return simtypes.NoOpMsg(types.ModuleName, "service admin not found", "skip"), nil, nil + } + + // Create the msg + msg := types.NewMsgActivateService(service.ID, simAccount.Address.String()) + + return simtesting.SendMsg(r, types.ModuleName, app, ak, bk, msg, ctx, simAccount) + } +} + +func SimulateMsgDeactivateService( + txGen client.TxConfig, + ak authkeeper.AccountKeeper, + bk bankkeeper.Keeper, + k *keeper.Keeper, +) simtypes.Operation { + return func( + r *rand.Rand, app *baseapp.BaseApp, ctx sdk.Context, accs []simtypes.Account, chainID string, + ) (simtypes.OperationMsg, []simtypes.FutureOperation, error) { + // No account skipping + if len(accs) == 0 { + return simtypes.NoOpMsg(types.ModuleName, "MsgDeactivateService", "skip"), nil, nil + } + + // Get a random service + service, found := GetRandomExistingService(r, ctx, k, func(s types.Service) bool { + return s.Status == types.SERVICE_STATUS_ACTIVE + }) + if !found { + return simtypes.NoOpMsg(types.ModuleName, "MsgDeactivateService", "skip"), nil, nil + } + + // Get the service admin sim account + adminAddr := sdk.MustAccAddressFromBech32(service.Admin) + simAccount, found := simtesting.GetSimAccount(adminAddr, accs) + if !found { + return simtypes.NoOpMsg(types.ModuleName, "service admin not found", "skip"), nil, nil + } + + // Create the msg + msg := types.NewMsgDeactivateService(service.ID, simAccount.Address.String()) + + return simtesting.SendMsg(r, types.ModuleName, app, ak, bk, msg, ctx, simAccount) + } +} + +func SimulateMsgTransferServiceOwnership( + txGen client.TxConfig, + ak authkeeper.AccountKeeper, + bk bankkeeper.Keeper, + k *keeper.Keeper, +) simtypes.Operation { + return func( + r *rand.Rand, app *baseapp.BaseApp, ctx sdk.Context, accs []simtypes.Account, chainID string, + ) (simtypes.OperationMsg, []simtypes.FutureOperation, error) { + // No account skipping + if len(accs) == 0 { + return simtypes.NoOpMsg(types.ModuleName, "MsgTransferServiceOwnership", "skip"), nil, nil + } + + // Get a random service + service, found := GetRandomExistingService(r, ctx, k, nil) + if !found { + return simtypes.NoOpMsg(types.ModuleName, "MsgTransferServiceOwnership", "skip"), nil, nil + } + + // Get the service admin sim account + adminAddr := sdk.MustAccAddressFromBech32(service.Admin) + simAccount, found := simtesting.GetSimAccount(adminAddr, accs) + if !found { + return simtypes.NoOpMsg(types.ModuleName, "service admin not found", "skip"), nil, nil + } + + // Get a new admin + newAdminAccount, _ := simtypes.RandomAcc(r, accs) + + // Create the msg + msg := types.NewMsgTransferServiceOwnership(service.ID, newAdminAccount.Address.String(), simAccount.Address.String()) + + return simtesting.SendMsg(r, types.ModuleName, app, ak, bk, msg, ctx, simAccount) + } +} + +func SimulateMsgDeleteService( + txGen client.TxConfig, + ak authkeeper.AccountKeeper, + bk bankkeeper.Keeper, + k *keeper.Keeper, +) simtypes.Operation { + return func( + r *rand.Rand, app *baseapp.BaseApp, ctx sdk.Context, accs []simtypes.Account, chainID string, + ) (simtypes.OperationMsg, []simtypes.FutureOperation, error) { + // No account skipping + if len(accs) == 0 { + return simtypes.NoOpMsg(types.ModuleName, "MsgDeleteService", "skip"), nil, nil + } + + // Get a random service + service, found := GetRandomExistingService(r, ctx, k, func(s types.Service) bool { + return s.Status == types.SERVICE_STATUS_INACTIVE + }) + if !found { + return simtypes.NoOpMsg(types.ModuleName, "MsgDeleteService", "skip"), nil, nil + } + + // Get the service admin sim account + adminAddr := sdk.MustAccAddressFromBech32(service.Admin) + simAccount, found := simtesting.GetSimAccount(adminAddr, accs) + if !found { + return simtypes.NoOpMsg(types.ModuleName, "service admin not found", "skip"), nil, nil + } + + // Create the msg + msg := types.NewMsgDeleteService(service.ID, simAccount.Address.String()) + + return simtesting.SendMsg(r, types.ModuleName, app, ak, bk, msg, ctx, simAccount) + } +} + +func SimulateMsgSetServiceParams( + txGen client.TxConfig, + ak authkeeper.AccountKeeper, + bk bankkeeper.Keeper, + k *keeper.Keeper, +) simtypes.Operation { + return func( + r *rand.Rand, app *baseapp.BaseApp, ctx sdk.Context, accs []simtypes.Account, chainID string, + ) (simtypes.OperationMsg, []simtypes.FutureOperation, error) { + // No account skipping + if len(accs) == 0 { + return simtypes.NoOpMsg(types.ModuleName, "MsgSetServiceParams", "skip"), nil, nil + } + + // Get a random service + service, found := GetRandomExistingService(r, ctx, k, nil) + if !found { + return simtypes.NoOpMsg(types.ModuleName, "MsgSetServiceParams", "skip"), nil, nil + } + + // Get the service admin sim account + adminAddr := sdk.MustAccAddressFromBech32(service.Admin) + simAccount, found := simtesting.GetSimAccount(adminAddr, accs) + if !found { + return simtypes.NoOpMsg(types.ModuleName, "service admin not found", "skip"), nil, nil + } + + serviceParams := types.DefaultServiceParams() + + // Create the msg + msg := types.NewMsgSetServiceParams(service.ID, serviceParams, service.Admin) + + return simtesting.SendMsg(r, types.ModuleName, app, ak, bk, msg, ctx, simAccount) + } +} diff --git a/x/services/simulation/proposals.go b/x/services/simulation/proposals.go new file mode 100644 index 00000000..1be5fd82 --- /dev/null +++ b/x/services/simulation/proposals.go @@ -0,0 +1,100 @@ +package simulation + +import ( + "math/rand" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/address" + simtypes "github.com/cosmos/cosmos-sdk/types/simulation" + govtypes "github.com/cosmos/cosmos-sdk/x/gov/types" + "github.com/cosmos/cosmos-sdk/x/simulation" + + "github.com/milkyway-labs/milkyway/x/services/keeper" + "github.com/milkyway-labs/milkyway/x/services/types" +) + +// Simulation operation weights constants +const ( + DefaultWeightMsgUpdateParams int = 50 + DefaultWeightMsgAccreditService int = 50 + DefaultWeightMsgRevokeServiceAccreditation int = 50 + + OpWeightMsgUpdateParams = "op_weight_msg_update_params" + OpWeightMsgAccreditService = "op_weight_msg_accredit_service" + OpWeightMsgRevokeServiceAccreditation = "op_weight_msg_revoke_service_accreditation" +) + +// ProposalMsgs defines the module weighted proposals' contents +func ProposalMsgs(keeper *keeper.Keeper) []simtypes.WeightedProposalMsg { + return []simtypes.WeightedProposalMsg{ + simulation.NewWeightedProposalMsg( + OpWeightMsgUpdateParams, + DefaultWeightMsgUpdateParams, + SimulateMsgUpdateParams, + ), + simulation.NewWeightedProposalMsg( + OpWeightMsgAccreditService, + DefaultWeightMsgAccreditService, + SimulateMsgAccreditService(keeper), + ), + simulation.NewWeightedProposalMsg( + OpWeightMsgRevokeServiceAccreditation, + DefaultWeightMsgRevokeServiceAccreditation, + SimulateMsgRevokeServiceAccreditation(keeper), + ), + } +} + +// SimulateMsgUpdateParams returns a random MsgUpdateParams +func SimulateMsgUpdateParams(r *rand.Rand, _ sdk.Context, _ []simtypes.Account) sdk.Msg { + // use the default gov module account address as authority + var authority sdk.AccAddress = address.Module("gov") + + params := types.DefaultParams() + params.ServiceRegistrationFee = sdk.NewCoins(sdk.NewInt64Coin("umilk", int64(r.Intn(10_000_000)+1))) + + return &types.MsgUpdateParams{ + Authority: authority.String(), + Params: params, + } +} + +func SimulateMsgAccreditService( + k *keeper.Keeper, +) simtypes.MsgSimulatorFn { + return func(r *rand.Rand, ctx sdk.Context, accs []simtypes.Account) sdk.Msg { + // Get a random service + service, found := GetRandomExistingService(r, ctx, k, func(s types.Service) bool { + return !s.Accredited + }) + if !found { + return nil + } + + // use the default gov module account address as authority + var authority sdk.AccAddress = address.Module(govtypes.ModuleName) + + // Create the msg + return types.NewMsgAccreditService(service.ID, authority.String()) + } +} + +func SimulateMsgRevokeServiceAccreditation( + k *keeper.Keeper, +) simtypes.MsgSimulatorFn { + return func(r *rand.Rand, ctx sdk.Context, accs []simtypes.Account) sdk.Msg { + // Get a random service + service, found := GetRandomExistingService(r, ctx, k, func(s types.Service) bool { + return s.Accredited + }) + if !found { + return nil + } + + // use the default gov module account address as authority + var authority sdk.AccAddress = address.Module(govtypes.ModuleName) + + // Create the msg + return types.NewMsgRevokeServiceAccreditation(service.ID, authority.String()) + } +} diff --git a/x/services/simulation/utils.go b/x/services/simulation/utils.go index d9b65c11..7aa75adf 100644 --- a/x/services/simulation/utils.go +++ b/x/services/simulation/utils.go @@ -3,7 +3,11 @@ package simulation import ( "math/rand" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/simulation" + "github.com/milkyway-labs/milkyway/utils" + "github.com/milkyway-labs/milkyway/x/services/keeper" "github.com/milkyway-labs/milkyway/x/services/types" ) @@ -24,3 +28,19 @@ func RandomService(r *rand.Rand, id uint32, admin string) types.Service { (r.Uint64()%2) == 0, ) } + +func GetRandomExistingService(r *rand.Rand, ctx sdk.Context, k *keeper.Keeper, filter func(s types.Service) bool) (types.Service, bool) { + services := k.GetServices(ctx) + if len(services) == 0 { + return types.Service{}, false + } + if filter != nil { + services = utils.Filter(services, filter) + if len(services) == 0 { + return types.Service{}, false + } + } + + randomServiceIndex := r.Intn(len(services)) + return services[randomServiceIndex], true +} From fbdc34ad685bdd21a93c82cce203c2c32ae9e31b Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 13 Nov 2024 16:21:06 +0100 Subject: [PATCH 3/8] test: fix simulation tests --- x/services/simulation/genesis.go | 19 ++++--------------- x/services/simulation/utils.go | 11 ++++++++--- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/x/services/simulation/genesis.go b/x/services/simulation/genesis.go index 04c735b1..87bf3510 100644 --- a/x/services/simulation/genesis.go +++ b/x/services/simulation/genesis.go @@ -3,7 +3,6 @@ package simulation import ( "math/rand" - sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/module" "github.com/cosmos/cosmos-sdk/types/simulation" @@ -17,17 +16,12 @@ const ( keyServicesParams = "services_params" ) -func getServiceRegistrationFees(r *rand.Rand) sdk.Coins { - amount := int64(r.Intn(10_000_000) + 1) - return sdk.NewCoins(sdk.NewInt64Coin("umilk", amount)) -} - func getServices(r *rand.Rand) []types.Service { count := r.Intn(10) + 1 var services []types.Service for i := 0; i < count; i++ { adminAccount := simulation.RandomAccounts(r, 1)[0] - service := RandomService(r, uint32(i), adminAccount.Address.String()) + service := RandomService(r, uint32(i)+1, adminAccount.Address.String()) services = append(services, service) } @@ -51,15 +45,10 @@ func getServiceParams(r *rand.Rand, services []types.Service) []types.ServicePar // RandomizedGenState generates a random GenesisState for the services module func RandomizedGenState(simState *module.SimulationState) { var ( - serviceRegistrationFees sdk.Coins - services []types.Service - servicesParams []types.ServiceParamsRecord + services []types.Service + servicesParams []types.ServiceParamsRecord ) - simState.AppParams.GetOrGenerate(keyServiceRegistrationFees, &serviceRegistrationFees, simState.Rand, func(r *rand.Rand) { - serviceRegistrationFees = getServiceRegistrationFees(r) - }) - simState.AppParams.GetOrGenerate(keyServices, &services, simState.Rand, func(r *rand.Rand) { services = getServices(r) }) @@ -68,7 +57,7 @@ func RandomizedGenState(simState *module.SimulationState) { servicesParams = getServiceParams(r, services) }) - params := types.NewParams(serviceRegistrationFees) + params := types.DefaultParams() nextServiceId := uint32(len(services)) + 1 servicesGenesis := types.NewGenesisState(nextServiceId, services, servicesParams, params) diff --git a/x/services/simulation/utils.go b/x/services/simulation/utils.go index 7aa75adf..66db5947 100644 --- a/x/services/simulation/utils.go +++ b/x/services/simulation/utils.go @@ -12,9 +12,14 @@ import ( ) func RandomServiceStatus(r *rand.Rand) types.ServiceStatus { - value := (int32(r.Uint64())) % 3 - // Here we add 1 since 0 is SERVICE_STATUS_UNSPECIFIED. - return types.ServiceStatus(value + 1) + switch r.Intn(2) { + case 0: + return types.SERVICE_STATUS_INACTIVE + case 1: + return types.SERVICE_STATUS_CREATED + default: + return types.SERVICE_STATUS_ACTIVE + } } func RandomService(r *rand.Rand, id uint32, admin string) types.Service { From b64bfc6ae6a668cad4a8d4428f5d335117787be1 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 13 Nov 2024 16:21:31 +0100 Subject: [PATCH 4/8] fix: rewards invariant check --- x/rewards/keeper/invariants.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/x/rewards/keeper/invariants.go b/x/rewards/keeper/invariants.go index 101b324c..ce45b2f6 100644 --- a/x/rewards/keeper/invariants.go +++ b/x/rewards/keeper/invariants.go @@ -252,9 +252,11 @@ func ReferenceCountInvariant(k *Keeper) sdk.Invariant { // TODO: handle slash events expected := targetCount + delCount count := uint64(0) + elements := 0 err := k.PoolHistoricalRewards.Walk( ctx, nil, func(key collections.Pair[uint32, uint64], rewards types.HistoricalRewards) (stop bool, err error) { count += uint64(rewards.ReferenceCount) + elements += 1 return false, nil }, ) @@ -262,7 +264,7 @@ func ReferenceCountInvariant(k *Keeper) sdk.Invariant { panic(err) } - broken := count != expected + broken := elements > 0 && count != expected return sdk.FormatInvariant(types.ModuleName, "reference count", fmt.Sprintf("expected historical reference count: %d = %v delegation targets + %v delegations\n"+ From 854e2418399d6088a673e9a150b8e9ecb53cc5d0 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 13 Nov 2024 16:41:48 +0100 Subject: [PATCH 5/8] chore: self review --- x/services/module.go | 6 +++--- x/services/simulation/genesis.go | 4 ++-- x/services/simulation/msg_factory.go | 28 ++++++++++++++-------------- x/services/simulation/proposals.go | 14 ++++++++------ x/services/simulation/utils.go | 1 + 5 files changed, 28 insertions(+), 25 deletions(-) diff --git a/x/services/module.go b/x/services/module.go index a5f986fb..b2c66e39 100644 --- a/x/services/module.go +++ b/x/services/module.go @@ -168,7 +168,7 @@ func (am AppModule) IsAppModule() {} // AppModuleSimulation functions -// GenerateGenesisState creates a randomized GenState of the staking module. +// GenerateGenesisState creates a randomized GenState of the services module. func (AppModule) GenerateGenesisState(simState *module.SimulationState) { simulation.RandomizedGenState(simState) } @@ -178,12 +178,12 @@ func (am AppModule) ProposalMsgs(simState module.SimulationState) []simtypes.Wei return simulation.ProposalMsgs(am.keeper) } -// RegisterStoreDecoder registers a decoder for staking module's types +// RegisterStoreDecoder registers a decoder for services module's types. func (am AppModule) RegisterStoreDecoder(sdr simtypes.StoreDecoderRegistry) { sdr[types.StoreKey] = simulation.NewDecodeStore(am.cdc, am.keeper) } -// WeightedOperations returns the all the staking module operations with their respective weights. +// WeightedOperations returns the all the services module operations with their respective weights. func (am AppModule) WeightedOperations(simState module.SimulationState) []simtypes.WeightedOperation { return simulation.WeightedOperations( simState.AppParams, simState.Cdc, simState.TxConfig, diff --git a/x/services/simulation/genesis.go b/x/services/simulation/genesis.go index 87bf3510..c8f5c48e 100644 --- a/x/services/simulation/genesis.go +++ b/x/services/simulation/genesis.go @@ -58,8 +58,8 @@ func RandomizedGenState(simState *module.SimulationState) { }) params := types.DefaultParams() - nextServiceId := uint32(len(services)) + 1 + nextServiceID := uint32(len(services)) + 1 - servicesGenesis := types.NewGenesisState(nextServiceId, services, servicesParams, params) + servicesGenesis := types.NewGenesisState(nextServiceID, services, servicesParams, params) simState.GenState[types.ModuleName] = simState.Cdc.MustMarshalJSON(servicesGenesis) } diff --git a/x/services/simulation/msg_factory.go b/x/services/simulation/msg_factory.go index 01ea1908..39757fac 100644 --- a/x/services/simulation/msg_factory.go +++ b/x/services/simulation/msg_factory.go @@ -27,13 +27,13 @@ const ( DefaultWeightMsgDeleteService int = 100 DefaultWeightMsgSetServiceParams int = 100 - OpWeightMsgCreateService = "op_weight_msg_create_service" - OpWeightMsgUpdateService = "op_weight_msg_update_service" - OpWeightMsgActivateService = "op_weight_msg_activate_service" - OpWeightMsgDeactivateService = "op_weight_msg_deactivate_service" - OpWeightMsgTransferServiceOwnership = "op_weight_msg_transfer_service_ownership" - OpWeightMsgDeleteService = "op_weight_msg_delete_service" - OpWeightMsgSetServiceParams = "op_weight_msg_set_service_params" + OperationWeightMsgCreateService = "op_weight_msg_create_service" + OperationWeightMsgUpdateService = "op_weight_msg_update_service" + OperationWeightMsgActivateService = "op_weight_msg_activate_service" + OperationWeightMsgDeactivateService = "op_weight_msg_deactivate_service" + OperationWeightMsgTransferServiceOwnership = "op_weight_msg_transfer_service_ownership" + OperationWeightMsgDeleteService = "op_weight_msg_delete_service" + OperationWeightMsgSetServiceParams = "op_weight_msg_set_service_params" ) // WeightedOperations returns all the operations from the module with their respective weights @@ -56,31 +56,31 @@ func WeightedOperations( ) // Generate the weights - appParams.GetOrGenerate(OpWeightMsgCreateService, &weightMsgCreateService, nil, func(_ *rand.Rand) { + appParams.GetOrGenerate(OperationWeightMsgCreateService, &weightMsgCreateService, nil, func(_ *rand.Rand) { weightMsgCreateService = DefaultWeightMsgCreateService }) - appParams.GetOrGenerate(OpWeightMsgUpdateService, &weightMsgUpdateService, nil, func(_ *rand.Rand) { + appParams.GetOrGenerate(OperationWeightMsgUpdateService, &weightMsgUpdateService, nil, func(_ *rand.Rand) { weightMsgUpdateService = DefaultWeightMsgUpdateService }) - appParams.GetOrGenerate(OpWeightMsgActivateService, &weightMsgActivateService, nil, func(_ *rand.Rand) { + appParams.GetOrGenerate(OperationWeightMsgActivateService, &weightMsgActivateService, nil, func(_ *rand.Rand) { weightMsgActivateService = DefaultWeightMsgActivateService }) - appParams.GetOrGenerate(OpWeightMsgDeactivateService, &weightMsgDeactivateService, nil, func(_ *rand.Rand) { + appParams.GetOrGenerate(OperationWeightMsgDeactivateService, &weightMsgDeactivateService, nil, func(_ *rand.Rand) { weightMsgDeactivateService = DefaultWeightMsgDeactivateService }) - appParams.GetOrGenerate(OpWeightMsgTransferServiceOwnership, &weightMsgTransferServiceOwnership, nil, func(_ *rand.Rand) { + appParams.GetOrGenerate(OperationWeightMsgTransferServiceOwnership, &weightMsgTransferServiceOwnership, nil, func(_ *rand.Rand) { weightMsgTransferServiceOwnership = DefaultWeightMsgTransferServiceOwnership }) - appParams.GetOrGenerate(OpWeightMsgDeleteService, &weightMsgDeleteService, nil, func(_ *rand.Rand) { + appParams.GetOrGenerate(OperationWeightMsgDeleteService, &weightMsgDeleteService, nil, func(_ *rand.Rand) { weightMsgDeleteService = DefaultWeightMsgDeleteService }) - appParams.GetOrGenerate(OpWeightMsgSetServiceParams, &weightMsgSetServiceParams, nil, func(_ *rand.Rand) { + appParams.GetOrGenerate(OperationWeightMsgSetServiceParams, &weightMsgSetServiceParams, nil, func(_ *rand.Rand) { weightMsgSetServiceParams = DefaultWeightMsgSetServiceParams }) diff --git a/x/services/simulation/proposals.go b/x/services/simulation/proposals.go index 1be5fd82..1c7e3951 100644 --- a/x/services/simulation/proposals.go +++ b/x/services/simulation/proposals.go @@ -19,26 +19,28 @@ const ( DefaultWeightMsgAccreditService int = 50 DefaultWeightMsgRevokeServiceAccreditation int = 50 - OpWeightMsgUpdateParams = "op_weight_msg_update_params" - OpWeightMsgAccreditService = "op_weight_msg_accredit_service" - OpWeightMsgRevokeServiceAccreditation = "op_weight_msg_revoke_service_accreditation" + OperationWeightMsgUpdateParams = "op_weight_msg_update_params" + //nolint:gosec + OperationWeightMsgAccreditService = "op_weight_msg_accredit_service" + //nolint:gosec + OperationWeightMsgRevokeServiceAccreditation = "op_weight_msg_revoke_service_accreditation" ) // ProposalMsgs defines the module weighted proposals' contents func ProposalMsgs(keeper *keeper.Keeper) []simtypes.WeightedProposalMsg { return []simtypes.WeightedProposalMsg{ simulation.NewWeightedProposalMsg( - OpWeightMsgUpdateParams, + OperationWeightMsgUpdateParams, DefaultWeightMsgUpdateParams, SimulateMsgUpdateParams, ), simulation.NewWeightedProposalMsg( - OpWeightMsgAccreditService, + OperationWeightMsgAccreditService, DefaultWeightMsgAccreditService, SimulateMsgAccreditService(keeper), ), simulation.NewWeightedProposalMsg( - OpWeightMsgRevokeServiceAccreditation, + OperationWeightMsgRevokeServiceAccreditation, DefaultWeightMsgRevokeServiceAccreditation, SimulateMsgRevokeServiceAccreditation(keeper), ), diff --git a/x/services/simulation/utils.go b/x/services/simulation/utils.go index 66db5947..652c7ac9 100644 --- a/x/services/simulation/utils.go +++ b/x/services/simulation/utils.go @@ -6,6 +6,7 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/simulation" + "github.com/milkyway-labs/milkyway/utils" "github.com/milkyway-labs/milkyway/x/services/keeper" "github.com/milkyway-labs/milkyway/x/services/types" From 6f07a00cdf345208fd288b62807ba8491b9a5e71 Mon Sep 17 00:00:00 2001 From: Riccardo Montagnin Date: Thu, 14 Nov 2024 14:58:40 +0900 Subject: [PATCH 6/8] refactor: use hooks to setup the rewards module state --- app/keepers/community_pool.go | 27 ----- app/keepers/keepers.go | 35 ++++--- testutils/storetesting/keeper.go | 23 ++++- x/liquidvesting/keeper/common_test.go | 2 +- x/liquidvesting/testutils/keeper.go | 9 +- x/operators/keeper/genesis.go | 2 +- x/operators/keeper/genesis_test.go | 12 +-- x/operators/keeper/grpc_query_test.go | 14 +-- x/operators/keeper/invariants_test.go | 14 +-- x/operators/keeper/keeper.go | 2 +- x/operators/keeper/msg_server.go | 18 +++- x/operators/keeper/msg_server_test.go | 5 +- x/operators/keeper/operators.go | 18 +--- x/operators/keeper/operators_test.go | 47 ++------- x/operators/testutils/keeper.go | 3 +- x/pools/keeper/alias_functions.go | 6 ++ x/pools/keeper/genesis.go | 6 ++ x/pools/keeper/hooks.go | 16 +++ x/pools/keeper/keeper.go | 11 ++ x/pools/types/hooks.go | 9 ++ x/restaking/keeper/alias_functions_test.go | 6 +- x/restaking/keeper/grpc_query_test.go | 6 +- x/restaking/testutils/keeper.go | 7 +- x/rewards/keeper/alias_functions.go | 96 +++++++++++++++++ x/rewards/keeper/hooks.go | 97 +++++------------ x/rewards/keeper/opertors_hooks.go | 43 ++++++++ x/rewards/keeper/pools_hooks.go | 25 +++++ x/rewards/keeper/restaking_keeper.go | 82 +++++++++++++++ x/rewards/keeper/services_hooks.go | 38 +++++++ x/rewards/keeper/target.go | 115 ++++++++++++++++++++- x/rewards/testutils/keeper.go | 22 ++-- x/rewards/types/dec_pool.go | 9 ++ x/services/keeper/genesis.go | 2 +- x/services/keeper/msg_server.go | 16 +++ x/services/keeper/msg_server_test.go | 3 +- x/services/keeper/services.go | 14 --- x/services/keeper/services_test.go | 36 ------- x/services/testutils/keeper.go | 3 +- x/tokenfactory/keeper/common_test.go | 27 ++++- 39 files changed, 636 insertions(+), 290 deletions(-) delete mode 100644 app/keepers/community_pool.go create mode 100644 x/pools/keeper/hooks.go create mode 100644 x/pools/types/hooks.go create mode 100644 x/rewards/keeper/opertors_hooks.go create mode 100644 x/rewards/keeper/pools_hooks.go create mode 100644 x/rewards/keeper/restaking_keeper.go create mode 100644 x/rewards/keeper/services_hooks.go diff --git a/app/keepers/community_pool.go b/app/keepers/community_pool.go deleted file mode 100644 index 4a66c193..00000000 --- a/app/keepers/community_pool.go +++ /dev/null @@ -1,27 +0,0 @@ -package keepers - -import ( - "context" - - sdk "github.com/cosmos/cosmos-sdk/types" -) - -type bankKeeperForCommunityPoolKeeper interface { - SendCoinsFromAccountToModule(ctx context.Context, senderAddr sdk.AccAddress, recipientModule string, amt sdk.Coins) error -} - -type CommunityPoolKeeper struct { - bk bankKeeperForCommunityPoolKeeper - feeCollectorName string -} - -func NewCommunityPoolKeeper(bk bankKeeperForCommunityPoolKeeper, feeCollectorName string) CommunityPoolKeeper { - return CommunityPoolKeeper{ - bk: bk, - feeCollectorName: feeCollectorName, - } -} - -func (k CommunityPoolKeeper) FundCommunityPool(ctx context.Context, amount sdk.Coins, sender sdk.AccAddress) error { - return k.bk.SendCoinsFromAccountToModule(ctx, sender, k.feeCollectorName, amount) -} diff --git a/app/keepers/keepers.go b/app/keepers/keepers.go index 0996b5ec..af1848ad 100644 --- a/app/keepers/keepers.go +++ b/app/keepers/keepers.go @@ -275,8 +275,6 @@ func NewAppKeeper( logger, ) - communityPoolKeeper := NewCommunityPoolKeeper(appKeepers.BankKeeper, authtypes.FeeCollectorName) - appKeepers.CrisisKeeper = crisiskeeper.NewKeeper( appCodec, runtime.NewKVStoreService(appKeepers.keys[crisistypes.StoreKey]), @@ -421,7 +419,7 @@ func NewAppKeeper( runtime.NewKVStoreService(appKeepers.keys[tokenfactorytypes.StoreKey]), appKeepers.AccountKeeper, appKeepers.BankKeeper, - communityPoolKeeper, + appKeepers.DistrKeeper, authtypes.NewModuleAddress(govtypes.ModuleName).String(), ) appKeepers.TokenFactoryKeeper.SetContractKeeper(contractKeeper) @@ -580,7 +578,7 @@ func NewAppKeeper( appKeepers.keys[servicestypes.StoreKey], runtime.NewKVStoreService(appKeepers.keys[servicestypes.StoreKey]), appKeepers.AccountKeeper, - communityPoolKeeper, + appKeepers.DistrKeeper, govAuthority, ) appKeepers.OperatorsKeeper = operatorskeeper.NewKeeper( @@ -588,7 +586,7 @@ func NewAppKeeper( appKeepers.keys[operatorstypes.StoreKey], runtime.NewKVStoreService(appKeepers.keys[operatorstypes.StoreKey]), appKeepers.AccountKeeper, - communityPoolKeeper, + appKeepers.DistrKeeper, govAuthority, ) appKeepers.PoolsKeeper = poolskeeper.NewKeeper( @@ -608,11 +606,6 @@ func NewAppKeeper( appKeepers.ServicesKeeper, govAuthority, ) - - // Set hooks based on the restaking keeper - appKeepers.OperatorsKeeper.SetHooks(appKeepers.RestakingKeeper.OperatorsHooks()) - appKeepers.ServicesKeeper.SetHooks(appKeepers.RestakingKeeper.ServicesHooks()) - appKeepers.AssetsKeeper = assetskeeper.NewKeeper( appCodec, runtime.NewKVStoreService(appKeepers.keys[assetstypes.StoreKey]), @@ -623,7 +616,7 @@ func NewAppKeeper( runtime.NewKVStoreService(appKeepers.keys[rewardstypes.StoreKey]), appKeepers.AccountKeeper, appKeepers.BankKeeper, - communityPoolKeeper, + appKeepers.DistrKeeper, appKeepers.OracleKeeper, appKeepers.PoolsKeeper, appKeepers.OperatorsKeeper, @@ -634,8 +627,6 @@ func NewAppKeeper( ) // Set hooks based on the rewards keeper - appKeepers.RestakingKeeper.SetHooks(appKeepers.RewardsKeeper.Hooks()) - appKeepers.LiquidVestingKeeper = liquidvestingkeeper.NewKeeper( appCodec, appKeepers.keys[liquidvestingtypes.StoreKey], @@ -650,9 +641,25 @@ func NewAppKeeper( govAuthority, ) - // Set hooks based on the liquid vesting keeper + // Set the restrictions on sending tokens appKeepers.BankKeeper.AppendSendRestriction(appKeepers.LiquidVestingKeeper.SendRestrictionFn) + // Set the hooks up to this point + appKeepers.PoolsKeeper.SetHooks( + appKeepers.RewardsKeeper.PoolsHooks(), + ) + appKeepers.OperatorsKeeper.SetHooks(operatorstypes.NewMultiOperatorsHooks( + appKeepers.RestakingKeeper.OperatorsHooks(), + appKeepers.RewardsKeeper.OperatorsHooks(), + )) + appKeepers.ServicesKeeper.SetHooks(servicestypes.NewMultiServicesHooks( + appKeepers.RestakingKeeper.ServicesHooks(), + appKeepers.RewardsKeeper.ServicesHooks(), + )) + appKeepers.RestakingKeeper.SetHooks( + appKeepers.RewardsKeeper.RestakingHooks(), + ) + // ---------------------- // // --- Stride Keepers --- // // ---------------------- // diff --git a/testutils/storetesting/keeper.go b/testutils/storetesting/keeper.go index cd8dde0e..acbae1ea 100644 --- a/testutils/storetesting/keeper.go +++ b/testutils/storetesting/keeper.go @@ -13,6 +13,8 @@ import ( authkeeper "github.com/cosmos/cosmos-sdk/x/auth/keeper" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" + distrkeeper "github.com/cosmos/cosmos-sdk/x/distribution/keeper" + distrtypes "github.com/cosmos/cosmos-sdk/x/distribution/types" govtypes "github.com/cosmos/cosmos-sdk/x/gov/types" milkyway "github.com/milkyway-labs/milkyway/app" @@ -28,8 +30,9 @@ type BaseKeeperTestData struct { AuthorityAddress string - AccountKeeper authkeeper.AccountKeeper - BankKeeper bankkeeper.Keeper + AccountKeeper authkeeper.AccountKeeper + BankKeeper bankkeeper.Keeper + DistributionKeeper distrkeeper.Keeper } // NewBaseKeeperTestData returns a new BaseKeeperTestData @@ -45,7 +48,7 @@ func NewBaseKeeperTestData(t *testing.T, keys []string) BaseKeeperTestData { var data BaseKeeperTestData // Define store keys - keys = append(keys, []string{authtypes.StoreKey, banktypes.StoreKey}...) + keys = append(keys, []string{authtypes.StoreKey, banktypes.StoreKey, distrtypes.StoreKey}...) slices.Sort(keys) keys = slices.Compact(keys) data.Keys = storetypes.NewKVStoreKeys(keys...) @@ -77,6 +80,20 @@ func NewBaseKeeperTestData(t *testing.T, keys []string) BaseKeeperTestData { data.AuthorityAddress, log.NewNopLogger(), ) + data.DistributionKeeper = distrkeeper.NewKeeper( + data.Cdc, + runtime.NewKVStoreService(data.Keys[distrtypes.StoreKey]), + data.AccountKeeper, + data.BankKeeper, + nil, + authtypes.FeeCollectorName, + data.AuthorityAddress, + ) + + // Init the module's genesis state as the default ones + data.AccountKeeper.InitGenesis(data.Context, *authtypes.DefaultGenesisState()) + data.BankKeeper.InitGenesis(data.Context, banktypes.DefaultGenesisState()) + data.DistributionKeeper.InitGenesis(data.Context, *distrtypes.DefaultGenesisState()) return data } diff --git a/x/liquidvesting/keeper/common_test.go b/x/liquidvesting/keeper/common_test.go index 4e7b05f7..ffe66642 100644 --- a/x/liquidvesting/keeper/common_test.go +++ b/x/liquidvesting/keeper/common_test.go @@ -143,7 +143,7 @@ func (suite *KeeperTestSuite) createService(ctx sdk.Context, id uint32) { } func (suite *KeeperTestSuite) createOperator(ctx sdk.Context, id uint32) { - err := suite.ok.RegisterOperator(ctx, operatorstypes.NewOperator( + err := suite.ok.CreateOperator(ctx, operatorstypes.NewOperator( id, operatorstypes.OPERATOR_STATUS_ACTIVE, fmt.Sprintf("operator-%d", id), diff --git a/x/liquidvesting/testutils/keeper.go b/x/liquidvesting/testutils/keeper.go index 92f1a127..4b9ed55c 100644 --- a/x/liquidvesting/testutils/keeper.go +++ b/x/liquidvesting/testutils/keeper.go @@ -15,7 +15,6 @@ import ( stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" - appkeepers "github.com/milkyway-labs/milkyway/app/keepers" "github.com/milkyway-labs/milkyway/testutils/storetesting" "github.com/milkyway-labs/milkyway/x/liquidvesting" "github.com/milkyway-labs/milkyway/x/liquidvesting/keeper" @@ -68,16 +67,12 @@ func NewKeeperTestData(t *testing.T) KeeperTestData { runtime.NewKVStoreService(data.Keys[poolstypes.StoreKey]), data.AccountKeeper, ) - communityPoolKeeper := appkeepers.NewCommunityPoolKeeper( - data.BankKeeper, - authtypes.FeeCollectorName, - ) data.OperatorsKeeper = operatorskeeper.NewKeeper( data.Cdc, data.Keys[operatorstypes.StoreKey], runtime.NewKVStoreService(data.Keys[operatorstypes.StoreKey]), data.AccountKeeper, - communityPoolKeeper, + data.DistributionKeeper, data.AuthorityAddress, ) data.ServicesKeeper = serviceskeeper.NewKeeper( @@ -85,7 +80,7 @@ func NewKeeperTestData(t *testing.T) KeeperTestData { data.Keys[servicestypes.StoreKey], runtime.NewKVStoreService(data.Keys[servicestypes.StoreKey]), data.AccountKeeper, - communityPoolKeeper, + data.DistributionKeeper, data.AuthorityAddress, ) data.RestakingKeeper = restakingkeeper.NewKeeper( diff --git a/x/operators/keeper/genesis.go b/x/operators/keeper/genesis.go index 192a5e7c..f4372b54 100644 --- a/x/operators/keeper/genesis.go +++ b/x/operators/keeper/genesis.go @@ -47,7 +47,7 @@ func (k *Keeper) InitGenesis(ctx sdk.Context, state types.GenesisState) error { // Store the operators for _, operator := range state.Operators { - if err := k.SaveOperator(ctx, operator); err != nil { + if err := k.CreateOperator(ctx, operator); err != nil { return err } } diff --git a/x/operators/keeper/genesis_test.go b/x/operators/keeper/genesis_test.go index fc40e626..399fdcea 100644 --- a/x/operators/keeper/genesis_test.go +++ b/x/operators/keeper/genesis_test.go @@ -35,7 +35,7 @@ func (suite *KeeperTestSuite) TestKeeper_ExportGenesis() { suite.k.SetNextOperatorID(ctx, 10) suite.k.SetParams(ctx, types.DefaultParams()) - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, "MilkyWay Operator", @@ -45,7 +45,7 @@ func (suite *KeeperTestSuite) TestKeeper_ExportGenesis() { )) suite.Require().NoError(err) - err = suite.k.RegisterOperator(ctx, types.NewOperator( + err = suite.k.CreateOperator(ctx, types.NewOperator( 2, types.OPERATOR_STATUS_INACTIVATING, "Inertia", @@ -95,13 +95,13 @@ func (suite *KeeperTestSuite) TestKeeper_ExportGenesis() { "https://milkyway.com/picture", "cosmos167x6ehhple8gwz5ezy9x0464jltvdpzl6qfdt4", ) - err := suite.k.RegisterOperator(ctx, activeValidator) + err := suite.k.CreateOperator(ctx, activeValidator) suite.Require().NoError(err) err = suite.k.StartOperatorInactivation(ctx, activeValidator) suite.Require().NoError(err) - err = suite.k.RegisterOperator(ctx, types.NewOperator( + err = suite.k.CreateOperator(ctx, types.NewOperator( 2, types.OPERATOR_STATUS_ACTIVE, "Inertia", @@ -146,7 +146,7 @@ func (suite *KeeperTestSuite) TestKeeper_ExportGenesis() { suite.k.SetNextOperatorID(ctx, 10) suite.k.SetParams(ctx, types.DefaultParams()) - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, "MilkyWay Operator", @@ -162,7 +162,7 @@ func (suite *KeeperTestSuite) TestKeeper_ExportGenesis() { )) suite.Require().NoError(err) - err = suite.k.RegisterOperator(ctx, types.NewOperator( + err = suite.k.CreateOperator(ctx, types.NewOperator( 2, types.OPERATOR_STATUS_INACTIVATING, "Inertia", diff --git a/x/operators/keeper/grpc_query_test.go b/x/operators/keeper/grpc_query_test.go index 877c66f1..233b29bf 100644 --- a/x/operators/keeper/grpc_query_test.go +++ b/x/operators/keeper/grpc_query_test.go @@ -24,7 +24,7 @@ func (suite *KeeperTestSuite) TestQueryServer_Operator() { { name: "existing operator is returned properly", store: func(ctx sdk.Context) { - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, "MilkyWay Operator", @@ -83,7 +83,7 @@ func (suite *KeeperTestSuite) TestQueryServer_OperatorParams() { { name: "default operator params are returned properly", store: func(ctx sdk.Context) { - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, "MilkyWay Operator", @@ -102,7 +102,7 @@ func (suite *KeeperTestSuite) TestQueryServer_OperatorParams() { { name: "updated operator params are returned properly", store: func(ctx sdk.Context) { - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, "MilkyWay Operator", @@ -155,7 +155,7 @@ func (suite *KeeperTestSuite) TestQueryServer_Operators() { { name: "query without pagination returns data properly", store: func(ctx sdk.Context) { - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, "MilkyWay Operator", @@ -165,7 +165,7 @@ func (suite *KeeperTestSuite) TestQueryServer_Operators() { )) suite.Require().NoError(err) - err = suite.k.RegisterOperator(ctx, types.NewOperator( + err = suite.k.CreateOperator(ctx, types.NewOperator( 2, types.OPERATOR_STATUS_INACTIVATING, "Inertia", @@ -199,7 +199,7 @@ func (suite *KeeperTestSuite) TestQueryServer_Operators() { { name: "query with pagination returns data properly", store: func(ctx sdk.Context) { - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, "MilkyWay Operator", @@ -209,7 +209,7 @@ func (suite *KeeperTestSuite) TestQueryServer_Operators() { )) suite.Require().NoError(err) - err = suite.k.RegisterOperator(ctx, types.NewOperator( + err = suite.k.CreateOperator(ctx, types.NewOperator( 2, types.OPERATOR_STATUS_INACTIVATING, "Inertia", diff --git a/x/operators/keeper/invariants_test.go b/x/operators/keeper/invariants_test.go index 533b32ff..72323384 100644 --- a/x/operators/keeper/invariants_test.go +++ b/x/operators/keeper/invariants_test.go @@ -17,7 +17,7 @@ func (suite *KeeperTestSuite) TestValidOperatorsInvariant() { { name: "not found next operator id breaks invariant", store: func(ctx sdk.Context) { - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, "MilkyWay Operator", @@ -27,7 +27,7 @@ func (suite *KeeperTestSuite) TestValidOperatorsInvariant() { )) suite.Require().NoError(err) - err = suite.k.RegisterOperator(ctx, types.NewOperator( + err = suite.k.CreateOperator(ctx, types.NewOperator( 2, types.OPERATOR_STATUS_INACTIVATING, "Inertia", @@ -43,7 +43,7 @@ func (suite *KeeperTestSuite) TestValidOperatorsInvariant() { name: "operator with id equals to next operator id breaks invariant", store: func(ctx sdk.Context) { suite.k.SetNextOperatorID(ctx, 1) - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, "MilkyWay Operator", @@ -59,7 +59,7 @@ func (suite *KeeperTestSuite) TestValidOperatorsInvariant() { name: "operator with id higher than next operator id breaks invariant", store: func(ctx sdk.Context) { suite.k.SetNextOperatorID(ctx, 1) - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( 2, types.OPERATOR_STATUS_ACTIVE, "MilkyWay Operator", @@ -75,7 +75,7 @@ func (suite *KeeperTestSuite) TestValidOperatorsInvariant() { name: "invalid operator breaks invariant", store: func(ctx sdk.Context) { suite.k.SetNextOperatorID(ctx, 1) - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( 1, types.OPERATOR_STATUS_UNSPECIFIED, "MilkyWay Operator", @@ -91,7 +91,7 @@ func (suite *KeeperTestSuite) TestValidOperatorsInvariant() { name: "valid data does not break invariant", store: func(ctx sdk.Context) { suite.k.SetNextOperatorID(ctx, 3) - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, "MilkyWay Operator", @@ -101,7 +101,7 @@ func (suite *KeeperTestSuite) TestValidOperatorsInvariant() { )) suite.Require().NoError(err) - err = suite.k.RegisterOperator(ctx, types.NewOperator( + err = suite.k.CreateOperator(ctx, types.NewOperator( 2, types.OPERATOR_STATUS_INACTIVATING, "Inertia", diff --git a/x/operators/keeper/keeper.go b/x/operators/keeper/keeper.go index 1adb7a16..d186c009 100644 --- a/x/operators/keeper/keeper.go +++ b/x/operators/keeper/keeper.go @@ -76,7 +76,7 @@ func (k *Keeper) Logger(ctx sdk.Context) log.Logger { // SetHooks allows to set the operators hooks func (k *Keeper) SetHooks(rs types.OperatorsHooks) *Keeper { if k.hooks != nil { - panic("cannot set avs hooks twice") + panic("cannot set operators hooks twice") } k.hooks = rs diff --git a/x/operators/keeper/msg_server.go b/x/operators/keeper/msg_server.go index d0787b3d..e3de5cb2 100644 --- a/x/operators/keeper/msg_server.go +++ b/x/operators/keeper/msg_server.go @@ -22,7 +22,7 @@ func NewMsgServer(k *Keeper) types.MsgServer { return &msgServer{Keeper: k} } -// RegisterOperator defines the rpc method for Msg/RegisterOperator +// RegisterOperator defines the rpc method for Msg/CreateOperator func (k msgServer) RegisterOperator(goCtx context.Context, msg *types.MsgRegisterOperator) (*types.MsgRegisterOperatorResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) @@ -48,8 +48,22 @@ func (k msgServer) RegisterOperator(goCtx context.Context, msg *types.MsgRegiste return nil, errors.Wrap(sdkerrors.ErrInvalidRequest, err.Error()) } + // Charge for the creation + registrationFees := k.GetParams(ctx).OperatorRegistrationFee + if !registrationFees.IsZero() { + userAddress, err := sdk.AccAddressFromBech32(operator.Admin) + if err != nil { + return nil, errors.Wrapf(sdkerrors.ErrInvalidAddress, "invalid operator admin address: %s", operator.Admin) + } + + err = k.poolKeeper.FundCommunityPool(ctx, registrationFees, userAddress) + if err != nil { + return nil, err + } + } + // Store the operator - err = k.Keeper.RegisterOperator(ctx, operator) + err = k.Keeper.CreateOperator(ctx, operator) if err != nil { return nil, err } diff --git a/x/operators/keeper/msg_server_test.go b/x/operators/keeper/msg_server_test.go index da9da1aa..f4dfbff0 100644 --- a/x/operators/keeper/msg_server_test.go +++ b/x/operators/keeper/msg_server_test.go @@ -6,6 +6,7 @@ import ( sdkmath "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + distrtypes "github.com/cosmos/cosmos-sdk/x/distribution/types" "github.com/milkyway-labs/milkyway/x/operators/keeper" "github.com/milkyway-labs/milkyway/x/operators/types" @@ -112,7 +113,7 @@ func (suite *KeeperTestSuite) TestMsgServer_RegisterOperator() { suite.Require().Equal(sdk.NewCoin("uatom", sdkmath.NewInt(100_000_000)), balance) // Make sure the community pool was funded - poolBalance := suite.bk.GetBalance(ctx, authtypes.NewModuleAddress(authtypes.FeeCollectorName), "uatom") + poolBalance := suite.bk.GetBalance(ctx, authtypes.NewModuleAddress(distrtypes.ModuleName), "uatom") suite.Require().Equal(sdk.NewCoin("uatom", sdkmath.NewInt(100_000_000)), poolBalance) }, }, @@ -935,7 +936,7 @@ func (suite *KeeperTestSuite) TestMsgServer_SetOperatorParams() { } // Register a test operator - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( testOperatorId, types.OPERATOR_STATUS_ACTIVE, "MilkyWay Operator", diff --git a/x/operators/keeper/operators.go b/x/operators/keeper/operators.go index e414edad..7d850914 100644 --- a/x/operators/keeper/operators.go +++ b/x/operators/keeper/operators.go @@ -32,22 +32,8 @@ func (k *Keeper) GetNextOperatorID(ctx sdk.Context) (operatorID uint32, err erro // -------------------------------------------------------------------------------------------------------------------- -// RegisterOperator creates a new Operator and stores it in the KVStore -func (k *Keeper) RegisterOperator(ctx sdk.Context, operator types.Operator) error { - // Charge for the creation - registrationFees := k.GetParams(ctx).OperatorRegistrationFee - if !registrationFees.IsZero() { - userAddress, err := sdk.AccAddressFromBech32(operator.Admin) - if err != nil { - return errors.Wrapf(sdkerrors.ErrInvalidAddress, "invalid operator admin address: %s", operator.Admin) - } - - err = k.poolKeeper.FundCommunityPool(ctx, registrationFees, userAddress) - if err != nil { - return err - } - } - +// CreateOperator creates a new Operator and stores it in the KVStore +func (k *Keeper) CreateOperator(ctx sdk.Context, operator types.Operator) error { // Create the operator account if it does not exist operatorAddress, err := sdk.AccAddressFromBech32(operator.Address) if err != nil { diff --git a/x/operators/keeper/operators_test.go b/x/operators/keeper/operators_test.go index e85666db..36aad57f 100644 --- a/x/operators/keeper/operators_test.go +++ b/x/operators/keeper/operators_test.go @@ -5,7 +5,6 @@ import ( sdkmath "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" - authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" "github.com/milkyway-labs/milkyway/x/operators/types" ) @@ -97,7 +96,7 @@ func (suite *KeeperTestSuite) TestKeeper_GetNextOperatorID() { // -------------------------------------------------------------------------------------------------------------------- -func (suite *KeeperTestSuite) TestKeeper_RegisterOperator() { +func (suite *KeeperTestSuite) TestKeeper_CreateOperator() { testCases := []struct { name string setup func() @@ -106,28 +105,6 @@ func (suite *KeeperTestSuite) TestKeeper_RegisterOperator() { shouldErr bool check func(ctx sdk.Context) }{ - { - name: "user without enough funds to pay for registration fee returns erorr", - store: func(ctx sdk.Context) { - // Set the registration fee - suite.k.SetParams(ctx, types.NewParams( - sdk.NewCoins(sdk.NewCoin("uatom", sdkmath.NewInt(200_000_000))), - 24*time.Hour, - )) - - // Fund the user account - suite.fundAccount(ctx, "cosmos167x6ehhple8gwz5ezy9x0464jltvdpzl6qfdt4", sdk.NewCoins(sdk.NewCoin("uatom", sdkmath.NewInt(100_000_000)))) - }, - operator: types.NewOperator( - 1, - types.OPERATOR_STATUS_ACTIVE, - "MilkyWay Operator", - "https://milkyway.com", - "https://milkyway.com/picture", - "cosmos167x6ehhple8gwz5ezy9x0464jltvdpzl6qfdt4", - ), - shouldErr: true, - }, { name: "operator is registered correctly", store: func(ctx sdk.Context) { @@ -162,16 +139,6 @@ func (suite *KeeperTestSuite) TestKeeper_RegisterOperator() { "cosmos167x6ehhple8gwz5ezy9x0464jltvdpzl6qfdt4", ), stored) - // Make sure the user has been charged - userAddress, err := sdk.AccAddressFromBech32("cosmos167x6ehhple8gwz5ezy9x0464jltvdpzl6qfdt4") - suite.Require().NoError(err) - userBalance := suite.bk.GetBalance(ctx, userAddress, "uatom") - suite.Require().Equal(sdk.NewCoin("uatom", sdkmath.NewInt(100_000_000)), userBalance) - - // Make sure the community pool has been funded - poolBalance := suite.bk.GetBalance(ctx, authtypes.NewModuleAddress(authtypes.FeeCollectorName), "uatom") - suite.Require().Equal(sdk.NewCoin("uatom", sdkmath.NewInt(100_000_000)), poolBalance) - // Make sure the hook has been called suite.Require().True(suite.hooks.CalledMap["AfterOperatorRegistered"]) }, @@ -189,7 +156,7 @@ func (suite *KeeperTestSuite) TestKeeper_RegisterOperator() { tc.store(ctx) } - err := suite.k.RegisterOperator(ctx, tc.operator) + err := suite.k.CreateOperator(ctx, tc.operator) if tc.shouldErr { suite.Require().Error(err) } else { @@ -220,7 +187,7 @@ func (suite *KeeperTestSuite) TestKeeper_GetOperator() { { name: "existing operator is returned properly", store: func(ctx sdk.Context) { - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, "MilkyWay Operator", @@ -297,7 +264,7 @@ func (suite *KeeperTestSuite) TestKeeper_SaveOperator() { { name: "existing operator is returned properly", store: func(ctx sdk.Context) { - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, "MilkyWay Operator", @@ -577,7 +544,7 @@ func (suite *KeeperTestSuite) TestKeeper_ReactivateInactiveOperator() { { name: "reactivate active operator fails", store: func(ctx sdk.Context) { - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, "MilkyWay Operator", @@ -593,7 +560,7 @@ func (suite *KeeperTestSuite) TestKeeper_ReactivateInactiveOperator() { { name: "reactivate inactivating operator fails", store: func(ctx sdk.Context) { - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( 1, types.OPERATOR_STATUS_INACTIVATING, "MilkyWay Operator", @@ -609,7 +576,7 @@ func (suite *KeeperTestSuite) TestKeeper_ReactivateInactiveOperator() { { name: "reactivate inactive operator works properly", store: func(ctx sdk.Context) { - err := suite.k.RegisterOperator(ctx, types.NewOperator( + err := suite.k.CreateOperator(ctx, types.NewOperator( 1, types.OPERATOR_STATUS_INACTIVE, "MilkyWay Operator", diff --git a/x/operators/testutils/keeper.go b/x/operators/testutils/keeper.go index b3c91b6f..7c41ee28 100644 --- a/x/operators/testutils/keeper.go +++ b/x/operators/testutils/keeper.go @@ -8,7 +8,6 @@ import ( authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" - "github.com/milkyway-labs/milkyway/app/keepers" "github.com/milkyway-labs/milkyway/testutils/storetesting" "github.com/milkyway-labs/milkyway/x/operators/keeper" "github.com/milkyway-labs/milkyway/x/operators/types" @@ -41,7 +40,7 @@ func NewKeeperTestData(t *testing.T) KeeperTestData { data.StoreKey, runtime.NewKVStoreService(data.Keys[types.StoreKey]), data.AccountKeeper, - keepers.NewCommunityPoolKeeper(data.BankKeeper, authtypes.FeeCollectorName), + data.DistributionKeeper, data.AuthorityAddress, ) diff --git a/x/pools/keeper/alias_functions.go b/x/pools/keeper/alias_functions.go index 4bc8e68d..65f961e3 100644 --- a/x/pools/keeper/alias_functions.go +++ b/x/pools/keeper/alias_functions.go @@ -95,6 +95,12 @@ func (k *Keeper) CreateOrGetPoolByDenom(ctx sdk.Context, denom string) (types.Po // Log the event k.Logger(ctx).Debug("created pool", "id", poolID, "denom", denom) + // Call the hook + err = k.AfterPoolCreated(ctx, pool.ID) + if err != nil { + return pool, err + } + return pool, nil } diff --git a/x/pools/keeper/genesis.go b/x/pools/keeper/genesis.go index 3c16dab1..6ff8c705 100644 --- a/x/pools/keeper/genesis.go +++ b/x/pools/keeper/genesis.go @@ -39,5 +39,11 @@ func (k *Keeper) InitGenesis(ctx sdk.Context, data *types.GenesisState) { if err != nil { panic(err) } + + // Call the hook + err = k.AfterPoolCreated(ctx, pool.ID) + if err != nil { + panic(err) + } } } diff --git a/x/pools/keeper/hooks.go b/x/pools/keeper/hooks.go new file mode 100644 index 00000000..2873e7f7 --- /dev/null +++ b/x/pools/keeper/hooks.go @@ -0,0 +1,16 @@ +package keeper + +import ( + sdk "github.com/cosmos/cosmos-sdk/types" + + "github.com/milkyway-labs/milkyway/x/pools/types" +) + +var _ types.PoolsHooks = &Keeper{} + +func (k *Keeper) AfterPoolCreated(ctx sdk.Context, poolID uint32) error { + if k.hooks != nil { + return k.hooks.AfterPoolCreated(ctx, poolID) + } + return nil +} diff --git a/x/pools/keeper/keeper.go b/x/pools/keeper/keeper.go index ee1095a8..ada23e10 100644 --- a/x/pools/keeper/keeper.go +++ b/x/pools/keeper/keeper.go @@ -15,6 +15,7 @@ type Keeper struct { storeKey storetypes.StoreKey cdc codec.Codec storeService corestoretypes.KVStoreService + hooks types.PoolsHooks accountKeeper types.AccountKeeper @@ -56,3 +57,13 @@ func NewKeeper(cdc codec.Codec, func (k *Keeper) Logger(ctx sdk.Context) log.Logger { return ctx.Logger().With("module", "x/"+types.ModuleName) } + +// SetHooks allows to set the pools hooks +func (k *Keeper) SetHooks(rs types.PoolsHooks) *Keeper { + if k.hooks != nil { + panic("cannot set pools hooks twice") + } + + k.hooks = rs + return k +} diff --git a/x/pools/types/hooks.go b/x/pools/types/hooks.go new file mode 100644 index 00000000..319bdd6d --- /dev/null +++ b/x/pools/types/hooks.go @@ -0,0 +1,9 @@ +package types + +import ( + sdk "github.com/cosmos/cosmos-sdk/types" +) + +type PoolsHooks interface { + AfterPoolCreated(ctx sdk.Context, poolID uint32) error +} diff --git a/x/restaking/keeper/alias_functions_test.go b/x/restaking/keeper/alias_functions_test.go index c29534a4..89cb7701 100644 --- a/x/restaking/keeper/alias_functions_test.go +++ b/x/restaking/keeper/alias_functions_test.go @@ -692,7 +692,7 @@ func (suite *KeeperTestSuite) TestKeeper_UnbondRestakedAssets() { suite.Assert().NoError(err) // Delegate to operator - err = suite.ok.RegisterOperator(ctx, operatorstypes.NewOperator( + err = suite.ok.CreateOperator(ctx, operatorstypes.NewOperator( 1, operatorstypes.OPERATOR_STATUS_ACTIVE, "", "", "", "", )) suite.Assert().NoError(err) @@ -756,7 +756,7 @@ func (suite *KeeperTestSuite) TestKeeper_UnbondRestakedAssets() { suite.Assert().NoError(err) // Delegate to operator - err = suite.ok.RegisterOperator(ctx, operatorstypes.NewOperator( + err = suite.ok.CreateOperator(ctx, operatorstypes.NewOperator( 1, operatorstypes.OPERATOR_STATUS_ACTIVE, "", "", "", "", )) suite.Assert().NoError(err) @@ -791,7 +791,7 @@ func (suite *KeeperTestSuite) TestKeeper_UnbondRestakedAssets() { "", false, )) - err = suite.ok.RegisterOperator(ctx, operatorstypes.NewOperator( + err = suite.ok.CreateOperator(ctx, operatorstypes.NewOperator( 1, operatorstypes.OPERATOR_STATUS_ACTIVE, "", "", "", "", )) diff --git a/x/restaking/keeper/grpc_query_test.go b/x/restaking/keeper/grpc_query_test.go index ba5fec21..a0ed4ef2 100644 --- a/x/restaking/keeper/grpc_query_test.go +++ b/x/restaking/keeper/grpc_query_test.go @@ -40,7 +40,7 @@ func (suite *KeeperTestSuite) TestQuerier_OperatorJoinedServices() { { name: "operator without joined services returns empty serviceIDs", store: func(ctx sdk.Context) { - err := suite.ok.RegisterOperator(ctx, operatorstypes.NewOperator( + err := suite.ok.CreateOperator(ctx, operatorstypes.NewOperator( 1, operatorstypes.OPERATOR_STATUS_ACTIVE, "", "", "", "", )) suite.Require().NoError(err) @@ -52,7 +52,7 @@ func (suite *KeeperTestSuite) TestQuerier_OperatorJoinedServices() { { name: "configured joined services are returned properly", store: func(ctx sdk.Context) { - err := suite.ok.RegisterOperator(ctx, operatorstypes.NewOperator( + err := suite.ok.CreateOperator(ctx, operatorstypes.NewOperator( 1, operatorstypes.OPERATOR_STATUS_ACTIVE, "", "", "", "", )) suite.Require().NoError(err) @@ -69,7 +69,7 @@ func (suite *KeeperTestSuite) TestQuerier_OperatorJoinedServices() { { name: "pagination is handled properly", store: func(ctx sdk.Context) { - err := suite.ok.RegisterOperator(ctx, operatorstypes.NewOperator( + err := suite.ok.CreateOperator(ctx, operatorstypes.NewOperator( 1, operatorstypes.OPERATOR_STATUS_ACTIVE, "", "", "", "", )) suite.Require().NoError(err) diff --git a/x/restaking/testutils/keeper.go b/x/restaking/testutils/keeper.go index 59276b4c..a4703d76 100644 --- a/x/restaking/testutils/keeper.go +++ b/x/restaking/testutils/keeper.go @@ -8,7 +8,6 @@ import ( authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" - appkeepers "github.com/milkyway-labs/milkyway/app/keepers" "github.com/milkyway-labs/milkyway/testutils/storetesting" operatorskeeper "github.com/milkyway-labs/milkyway/x/operators/keeper" operatorstypes "github.com/milkyway-labs/milkyway/x/operators/types" @@ -45,8 +44,6 @@ func NewKeeperTestData(t *testing.T) KeeperTestData { data.StoreKey = data.Keys[types.StoreKey] // Build the keepers - communityPoolKeeper := appkeepers.NewCommunityPoolKeeper(data.BankKeeper, authtypes.FeeCollectorName) - data.PoolsKeeper = poolskeeper.NewKeeper( data.Cdc, data.Keys[poolstypes.StoreKey], @@ -58,7 +55,7 @@ func NewKeeperTestData(t *testing.T) KeeperTestData { data.Keys[operatorstypes.StoreKey], runtime.NewKVStoreService(data.Keys[operatorstypes.StoreKey]), data.AccountKeeper, - communityPoolKeeper, + data.DistributionKeeper, data.AuthorityAddress, ) data.ServicesKeeper = serviceskeeper.NewKeeper( @@ -66,7 +63,7 @@ func NewKeeperTestData(t *testing.T) KeeperTestData { data.Keys[servicestypes.StoreKey], runtime.NewKVStoreService(data.Keys[servicestypes.StoreKey]), data.AccountKeeper, - communityPoolKeeper, + data.DistributionKeeper, data.AuthorityAddress, ) data.Keeper = keeper.NewKeeper( diff --git a/x/rewards/keeper/alias_functions.go b/x/rewards/keeper/alias_functions.go index 7d62a71c..a3151f06 100644 --- a/x/rewards/keeper/alias_functions.go +++ b/x/rewards/keeper/alias_functions.go @@ -163,6 +163,20 @@ func (k *Keeper) GetCurrentRewards(ctx context.Context, target restakingtypes.De } } +// DeleteCurrentRewards deletes the current rewards for a target +func (k *Keeper) DeleteCurrentRewards(ctx context.Context, target restakingtypes.DelegationTarget) error { + switch target.(type) { + case *poolstypes.Pool: + return k.PoolCurrentRewards.Remove(ctx, target.GetID()) + case *operatorstypes.Operator: + return k.OperatorCurrentRewards.Remove(ctx, target.GetID()) + case *servicestypes.Service: + return k.ServiceCurrentRewards.Remove(ctx, target.GetID()) + default: + return errors.Wrapf(restakingtypes.ErrInvalidDelegationType, "invalid delegation target type %T", target) + } +} + // -------------------------------------------------------------------------------------------------------------------- // SetHistoricalRewards sets the historical rewards for a target and period @@ -216,6 +230,43 @@ func (k *Keeper) RemoveHistoricalRewards( } } +// DeleteHistoricalRewards deletes all historical rewards for a target +func (k *Keeper) DeleteHistoricalRewards(ctx context.Context, target restakingtypes.DelegationTarget) error { + var collection collections.Map[collections.Pair[uint32, uint64], types.HistoricalRewards] + switch target.(type) { + case *poolstypes.Pool: + collection = k.PoolHistoricalRewards + case *operatorstypes.Operator: + collection = k.OperatorHistoricalRewards + case *servicestypes.Service: + collection = k.ServiceHistoricalRewards + default: + return errors.Wrapf(restakingtypes.ErrInvalidDelegationType, "invalid delegation target type %T", target) + } + + // Walk over the collection and get the list of keys to be deleted + // TODO: Find a more efficient way of doing this by using rangers + var keys []collections.Pair[uint32, uint64] + err := collection.Walk(ctx, nil, func(key collections.Pair[uint32, uint64], value types.HistoricalRewards) (stop bool, err error) { + if key.K1() == target.GetID() { + keys = append(keys, key) + } + return false, nil + }) + if err != nil { + return err + } + + // Delete all the keys from the collection + for _, key := range keys { + if err := collection.Remove(ctx, key); err != nil { + return err + } + } + + return nil +} + // -------------------------------------------------------------------------------------------------------------------- // GetOutstandingRewardsCoins returns the outstanding rewards coins for a target @@ -240,6 +291,20 @@ func (k *Keeper) GetOutstandingRewardsCoins(ctx context.Context, target restakin return rewards.Rewards, nil } +// DeleteOutstandingRewards deletes the outstanding rewards for a target +func (k *Keeper) DeleteOutstandingRewards(ctx context.Context, target restakingtypes.DelegationTarget) error { + switch target.(type) { + case *poolstypes.Pool: + return k.PoolOutstandingRewards.Remove(ctx, target.GetID()) + case *operatorstypes.Operator: + return k.OperatorOutstandingRewards.Remove(ctx, target.GetID()) + case *servicestypes.Service: + return k.ServiceOutstandingRewards.Remove(ctx, target.GetID()) + default: + return errors.Wrapf(restakingtypes.ErrInvalidDelegationType, "invalid delegation target type %T", target) + } +} + // GetOperatorAccumulatedCommission returns the accumulated commission for an operator. func (k *Keeper) GetOperatorAccumulatedCommission(ctx context.Context, operatorID uint32) (commission types.AccumulatedCommission, err error) { commission, err = k.OperatorAccumulatedCommissions.Get(ctx, operatorID) @@ -252,6 +317,37 @@ func (k *Keeper) GetOperatorAccumulatedCommission(ctx context.Context, operatorI return } +// DeleteOperatorAccumulatedCommission deletes the accumulated commission for an operator. +func (k *Keeper) DeleteOperatorAccumulatedCommission(ctx context.Context, operatorID uint32) error { + return k.OperatorAccumulatedCommissions.Remove(ctx, operatorID) +} + +// GetOperatorWithdrawAddr returns the outstanding rewards coins for an operator +func (k *Keeper) GetOperatorWithdrawAddr(ctx context.Context, operator *operatorstypes.Operator) (sdk.AccAddress, error) { + // Try getting a custom withdraw address + operatorAddr, err := sdk.AccAddressFromBech32(operator.Address) + if err != nil { + return nil, err + } + + withdrawAddr, err := k.GetDelegatorWithdrawAddr(ctx, operatorAddr) + if err != nil { + return nil, err + } + + if withdrawAddr != nil { + return withdrawAddr, nil + } + + // By default, use the operator admin address as the withdraw address + adminAddress, err := sdk.AccAddressFromBech32(operator.Admin) + if err != nil { + return nil, err + } + + return adminAddress, nil +} + // GetDelegatorWithdrawAddr returns the delegator's withdraw address if set, otherwise the delegator address is returned. func (k *Keeper) GetDelegatorWithdrawAddr(ctx context.Context, delegator sdk.AccAddress) (sdk.AccAddress, error) { addr, err := k.DelegatorWithdrawAddrs.Get(ctx, delegator) diff --git a/x/rewards/keeper/hooks.go b/x/rewards/keeper/hooks.go index 6a733313..7bb5267f 100644 --- a/x/rewards/keeper/hooks.go +++ b/x/rewards/keeper/hooks.go @@ -7,73 +7,38 @@ import ( restakingtypes "github.com/milkyway-labs/milkyway/x/restaking/types" ) -var _ restakingtypes.RestakingHooks = Hooks{} - -type Hooks struct { - k *Keeper -} - -func (k *Keeper) Hooks() Hooks { - return Hooks{k} -} - -func (h Hooks) BeforePoolDelegationCreated(ctx sdk.Context, poolID uint32, delegator string) error { - return h.k.BeforeDelegationCreated(ctx, restakingtypes.DELEGATION_TYPE_POOL, poolID) -} - -func (h Hooks) BeforePoolDelegationSharesModified(ctx sdk.Context, poolID uint32, delegator string) error { - return h.k.BeforeDelegationSharesModified(ctx, restakingtypes.DELEGATION_TYPE_POOL, poolID, delegator) -} - -func (h Hooks) AfterPoolDelegationModified(ctx sdk.Context, poolID uint32, delegator string) error { - return h.k.AfterDelegationModified(ctx, restakingtypes.DELEGATION_TYPE_POOL, poolID, delegator) -} - -func (h Hooks) BeforeOperatorDelegationCreated(ctx sdk.Context, operatorID uint32, delegator string) error { - return h.k.BeforeDelegationCreated(ctx, restakingtypes.DELEGATION_TYPE_OPERATOR, operatorID) -} - -func (h Hooks) BeforeOperatorDelegationSharesModified(ctx sdk.Context, operatorID uint32, delegator string) error { - return h.k.BeforeDelegationSharesModified(ctx, restakingtypes.DELEGATION_TYPE_OPERATOR, operatorID, delegator) -} - -func (h Hooks) AfterOperatorDelegationModified(ctx sdk.Context, operatorID uint32, delegator string) error { - return h.k.AfterDelegationModified(ctx, restakingtypes.DELEGATION_TYPE_OPERATOR, operatorID, delegator) -} - -func (h Hooks) BeforeServiceDelegationCreated(ctx sdk.Context, serviceID uint32, delegator string) error { - return h.k.BeforeDelegationCreated(ctx, restakingtypes.DELEGATION_TYPE_SERVICE, serviceID) -} - -func (h Hooks) BeforeServiceDelegationSharesModified(ctx sdk.Context, serviceID uint32, delegator string) error { - return h.k.BeforeDelegationSharesModified(ctx, restakingtypes.DELEGATION_TYPE_SERVICE, serviceID, delegator) -} +// AfterDelegationTargetCreated is called after a delegation target is created +func (k *Keeper) AfterDelegationTargetCreated(ctx sdk.Context, delType restakingtypes.DelegationType, targetID uint32) error { + target, err := k.GetDelegationTarget(ctx, delType, targetID) + if err != nil { + return err + } -func (h Hooks) AfterServiceDelegationModified(ctx sdk.Context, serviceID uint32, delegator string) error { - return h.k.AfterDelegationModified(ctx, restakingtypes.DELEGATION_TYPE_SERVICE, serviceID, delegator) + return k.initializeDelegationTarget(ctx, target) } -func (k *Keeper) BeforeDelegationCreated(ctx sdk.Context, delType restakingtypes.DelegationType, targetID uint32) error { +// AfterDelegationTargetRemoved is called after a delegation target is removed +func (k *Keeper) AfterDelegationTargetRemoved(ctx sdk.Context, delType restakingtypes.DelegationType, targetID uint32) error { target, err := k.GetDelegationTarget(ctx, delType, targetID) if err != nil { return err } - // Initialize target if it doesn't exist yet. - exists, err := k.HasCurrentRewards(ctx, target) + return k.clearDelegationTarget(ctx, target) +} + +// BeforeDelegationCreated is called before a delegation to a target is created +func (k *Keeper) BeforeDelegationCreated(ctx sdk.Context, delType restakingtypes.DelegationType, targetID uint32) error { + target, err := k.GetDelegationTarget(ctx, delType, targetID) if err != nil { return err } - if !exists { - if err := k.initializeDelegationTarget(ctx, target); err != nil { - return err - } - } _, err = k.IncrementDelegationTargetPeriod(ctx, target) return err } +// BeforeDelegationSharesModified is called before a delegation to a target is modified func (k *Keeper) BeforeDelegationSharesModified(ctx sdk.Context, delType restakingtypes.DelegationType, targetID uint32, delegator string) error { target, err := k.GetDelegationTarget(ctx, delType, targetID) if err != nil { @@ -88,37 +53,21 @@ func (k *Keeper) BeforeDelegationSharesModified(ctx sdk.Context, delType restaki return sdkerrors.ErrNotFound.Wrapf("delegation not found: %d, %s", target.GetID(), delegator) } - if _, err := k.withdrawDelegationRewards(ctx, target, del); err != nil { - return err - } - - return nil + _, err = k.withdrawDelegationRewards(ctx, target, del) + return err } +// AfterDelegationModified is called after a delegation to a target is modified func (k *Keeper) AfterDelegationModified(ctx sdk.Context, delType restakingtypes.DelegationType, targetID uint32, delegator string) error { - delAddr, err := k.accountKeeper.AddressCodec().StringToBytes(delegator) + target, err := k.GetDelegationTarget(ctx, delType, targetID) if err != nil { return err } - target, err := k.GetDelegationTarget(ctx, delType, targetID) + + delAddr, err := k.accountKeeper.AddressCodec().StringToBytes(delegator) if err != nil { return err } - return k.initializeDelegation(ctx, target, delAddr) -} - -func (h Hooks) BeforePoolDelegationRemoved(_ sdk.Context, _ uint32, _ string) error { - return nil -} - -func (h Hooks) BeforeOperatorDelegationRemoved(_ sdk.Context, _ uint32, _ string) error { - return nil -} -func (h Hooks) BeforeServiceDelegationRemoved(_ sdk.Context, _ uint32, _ string) error { - return nil -} - -func (h Hooks) AfterUnbondingInitiated(_ sdk.Context, _ uint64) error { - return nil + return k.initializeDelegation(ctx, target, delAddr) } diff --git a/x/rewards/keeper/opertors_hooks.go b/x/rewards/keeper/opertors_hooks.go new file mode 100644 index 00000000..d091b107 --- /dev/null +++ b/x/rewards/keeper/opertors_hooks.go @@ -0,0 +1,43 @@ +package keeper + +import ( + sdk "github.com/cosmos/cosmos-sdk/types" + + operatorstypes "github.com/milkyway-labs/milkyway/x/operators/types" + restakingtypes "github.com/milkyway-labs/milkyway/x/restaking/types" +) + +var _ operatorstypes.OperatorsHooks = OperatorsHooks{} + +type OperatorsHooks struct { + k *Keeper +} + +func (k *Keeper) OperatorsHooks() OperatorsHooks { + return OperatorsHooks{k} +} + +// AfterOperatorRegistered implements operatorstypes.OperatorsHooks +func (h OperatorsHooks) AfterOperatorRegistered(ctx sdk.Context, operatorID uint32) error { + return h.k.AfterDelegationTargetCreated(ctx, restakingtypes.DELEGATION_TYPE_OPERATOR, operatorID) +} + +// AfterOperatorInactivatingStarted implements operatorstypes.OperatorsHooks +func (h OperatorsHooks) AfterOperatorInactivatingStarted(sdk.Context, uint32) error { + return nil +} + +// AfterOperatorInactivatingCompleted implements operatorstypes.OperatorsHooks +func (h OperatorsHooks) AfterOperatorInactivatingCompleted(sdk.Context, uint32) error { + return nil +} + +// AfterOperatorReactivated implements operatorstypes.OperatorsHooks +func (h OperatorsHooks) AfterOperatorReactivated(sdk.Context, uint32) error { + return nil +} + +// AfterOperatorDeleted implements operatorstypes.OperatorsHooks +func (h OperatorsHooks) AfterOperatorDeleted(ctx sdk.Context, operatorID uint32) error { + return h.k.AfterDelegationTargetRemoved(ctx, restakingtypes.DELEGATION_TYPE_OPERATOR, operatorID) +} diff --git a/x/rewards/keeper/pools_hooks.go b/x/rewards/keeper/pools_hooks.go new file mode 100644 index 00000000..95d60da0 --- /dev/null +++ b/x/rewards/keeper/pools_hooks.go @@ -0,0 +1,25 @@ +package keeper + +import ( + sdk "github.com/cosmos/cosmos-sdk/types" + + poolstypes "github.com/milkyway-labs/milkyway/x/pools/types" + restakingtypes "github.com/milkyway-labs/milkyway/x/restaking/types" +) + +var ( + _ poolstypes.PoolsHooks = PoolsHooks{} +) + +type PoolsHooks struct { + k *Keeper +} + +func (k *Keeper) PoolsHooks() PoolsHooks { + return PoolsHooks{k} +} + +// AfterPoolCreated implements poolstypes.PoolsHooks +func (h PoolsHooks) AfterPoolCreated(ctx sdk.Context, poolID uint32) error { + return h.k.AfterDelegationTargetCreated(ctx, restakingtypes.DELEGATION_TYPE_POOL, poolID) +} diff --git a/x/rewards/keeper/restaking_keeper.go b/x/rewards/keeper/restaking_keeper.go new file mode 100644 index 00000000..e3c00b6a --- /dev/null +++ b/x/rewards/keeper/restaking_keeper.go @@ -0,0 +1,82 @@ +package keeper + +import ( + sdk "github.com/cosmos/cosmos-sdk/types" + + restakingtypes "github.com/milkyway-labs/milkyway/x/restaking/types" +) + +var _ restakingtypes.RestakingHooks = RestakingHooks{} + +type RestakingHooks struct { + k *Keeper +} + +func (k *Keeper) RestakingHooks() RestakingHooks { + return RestakingHooks{k} +} + +// BeforePoolDelegationCreated implements restakingtypes.RestakingHooks +func (h RestakingHooks) BeforePoolDelegationCreated(ctx sdk.Context, poolID uint32, _ string) error { + return h.k.BeforeDelegationCreated(ctx, restakingtypes.DELEGATION_TYPE_POOL, poolID) +} + +// BeforePoolDelegationSharesModified implements restakingtypes.RestakingHooks +func (h RestakingHooks) BeforePoolDelegationSharesModified(ctx sdk.Context, poolID uint32, delegator string) error { + return h.k.BeforeDelegationSharesModified(ctx, restakingtypes.DELEGATION_TYPE_POOL, poolID, delegator) +} + +// AfterPoolDelegationModified implements restakingtypes.RestakingHooks +func (h RestakingHooks) AfterPoolDelegationModified(ctx sdk.Context, poolID uint32, delegator string) error { + return h.k.AfterDelegationModified(ctx, restakingtypes.DELEGATION_TYPE_POOL, poolID, delegator) +} + +// BeforeOperatorDelegationCreated implements restakingtypes.RestakingHooks +func (h RestakingHooks) BeforeOperatorDelegationCreated(ctx sdk.Context, operatorID uint32, _ string) error { + return h.k.BeforeDelegationCreated(ctx, restakingtypes.DELEGATION_TYPE_OPERATOR, operatorID) +} + +// BeforeOperatorDelegationSharesModified implements restakingtypes.RestakingHooks +func (h RestakingHooks) BeforeOperatorDelegationSharesModified(ctx sdk.Context, operatorID uint32, delegator string) error { + return h.k.BeforeDelegationSharesModified(ctx, restakingtypes.DELEGATION_TYPE_OPERATOR, operatorID, delegator) +} + +// AfterOperatorDelegationModified implements restakingtypes.RestakingHooks +func (h RestakingHooks) AfterOperatorDelegationModified(ctx sdk.Context, operatorID uint32, delegator string) error { + return h.k.AfterDelegationModified(ctx, restakingtypes.DELEGATION_TYPE_OPERATOR, operatorID, delegator) +} + +// BeforeServiceDelegationCreated implements restakingtypes.RestakingHooks +func (h RestakingHooks) BeforeServiceDelegationCreated(ctx sdk.Context, serviceID uint32, _ string) error { + return h.k.BeforeDelegationCreated(ctx, restakingtypes.DELEGATION_TYPE_SERVICE, serviceID) +} + +// BeforeServiceDelegationSharesModified implements restakingtypes.RestakingHooks +func (h RestakingHooks) BeforeServiceDelegationSharesModified(ctx sdk.Context, serviceID uint32, delegator string) error { + return h.k.BeforeDelegationSharesModified(ctx, restakingtypes.DELEGATION_TYPE_SERVICE, serviceID, delegator) +} + +// AfterServiceDelegationModified implements restakingtypes.RestakingHooks +func (h RestakingHooks) AfterServiceDelegationModified(ctx sdk.Context, serviceID uint32, delegator string) error { + return h.k.AfterDelegationModified(ctx, restakingtypes.DELEGATION_TYPE_SERVICE, serviceID, delegator) +} + +// BeforePoolDelegationRemoved implements restakingtypes.RestakingHooks +func (h RestakingHooks) BeforePoolDelegationRemoved(_ sdk.Context, _ uint32, _ string) error { + return nil +} + +// BeforeOperatorDelegationRemoved implements restakingtypes.RestakingHooks +func (h RestakingHooks) BeforeOperatorDelegationRemoved(_ sdk.Context, _ uint32, _ string) error { + return nil +} + +// BeforeServiceDelegationRemoved implements restakingtypes.RestakingHooks +func (h RestakingHooks) BeforeServiceDelegationRemoved(_ sdk.Context, _ uint32, _ string) error { + return nil +} + +// AfterUnbondingInitiated implements restakingtypes.RestakingHooks +func (h RestakingHooks) AfterUnbondingInitiated(_ sdk.Context, _ uint64) error { + return nil +} diff --git a/x/rewards/keeper/services_hooks.go b/x/rewards/keeper/services_hooks.go new file mode 100644 index 00000000..9479d2e6 --- /dev/null +++ b/x/rewards/keeper/services_hooks.go @@ -0,0 +1,38 @@ +package keeper + +import ( + sdk "github.com/cosmos/cosmos-sdk/types" + + restakingtypes "github.com/milkyway-labs/milkyway/x/restaking/types" + servicestypes "github.com/milkyway-labs/milkyway/x/services/types" +) + +var _ servicestypes.ServicesHooks = ServicesHooks{} + +type ServicesHooks struct { + k *Keeper +} + +func (k *Keeper) ServicesHooks() ServicesHooks { + return ServicesHooks{k} +} + +// AfterServiceCreated implements servicestypes.ServicesHooks +func (h ServicesHooks) AfterServiceCreated(ctx sdk.Context, serviceID uint32) error { + return h.k.AfterDelegationTargetCreated(ctx, restakingtypes.DELEGATION_TYPE_SERVICE, serviceID) +} + +// AfterServiceActivated implements servicestypes.ServicesHooks +func (h ServicesHooks) AfterServiceActivated(sdk.Context, uint32) error { + return nil +} + +// AfterServiceDeactivated implements servicestypes.ServicesHooks +func (h ServicesHooks) AfterServiceDeactivated(sdk.Context, uint32) error { + return nil +} + +// AfterServiceDeleted implements servicestypes.ServicesHooks +func (h ServicesHooks) AfterServiceDeleted(ctx sdk.Context, serviceID uint32) error { + return h.k.AfterDelegationTargetRemoved(ctx, restakingtypes.DELEGATION_TYPE_SERVICE, serviceID) +} diff --git a/x/rewards/keeper/target.go b/x/rewards/keeper/target.go index 3a671d44..2a870fb6 100644 --- a/x/rewards/keeper/target.go +++ b/x/rewards/keeper/target.go @@ -69,7 +69,7 @@ func (k *Keeper) initializeDelegationTarget(ctx context.Context, target restakin return err } -// increment period, returning the period just ended +// IncrementDelegationTargetPeriod increments the period, returning the period that just ended func (k *Keeper) IncrementDelegationTargetPeriod(ctx context.Context, target restakingtypes.DelegationTarget) (uint64, error) { // fetch current rewards rewards, err := k.GetCurrentRewards(ctx, target) @@ -176,3 +176,116 @@ func (k *Keeper) decrementReferenceCount(ctx context.Context, target restakingty return k.SetHistoricalRewards(ctx, target, period, historical) } + +// clearDelegateTarget clears all rewards for a delegation target +func (k *Keeper) clearDelegationTarget(ctx sdk.Context, target restakingtypes.DelegationTarget) error { + // fetch outstanding + outstandingCoins, err := k.GetOutstandingRewardsCoins(ctx, target) + if err != nil { + return err + } + + outstanding := outstandingCoins.CoinsAmount() + + // Clear data related to an operator + if operator, ok := target.(*operatorstypes.Operator); ok { + outstanding, err = k.clearOperator(ctx, outstanding, operator) + if err != nil { + return err + } + } + + // Add outstanding to community pool + // The target is removed only after it has no more delegations. + // This operation sends only the remaining dust to the community pool. + operatorAddr, err := sdk.AccAddressFromBech32(target.GetAddress()) + if err != nil { + return err + } + + // We truncate the outstanding to be able to send it to the community pool + // The remainder will be just be removed + outstandingTruncated, _ := outstanding.TruncateDecimal() + err = k.communityPoolKeeper.FundCommunityPool(ctx, outstandingTruncated, operatorAddr) + if err != nil { + return err + } + + // Delete outstanding rewards + err = k.DeleteOutstandingRewards(ctx, target) + if err != nil { + return err + } + + // Remove the commission record + if operator, ok := target.(*operatorstypes.Operator); ok { + err = k.DeleteOperatorAccumulatedCommission(ctx, operator.ID) + if err != nil { + return err + } + } + + // TODO: Clear slash events when we introduce slashing + + // Clear historical rewards + err = k.DeleteHistoricalRewards(ctx, target) + if err != nil { + return err + } + + // Clear current rewards + err = k.DeleteCurrentRewards(ctx, target) + if err != nil { + return err + } + + return nil +} + +func (k *Keeper) clearOperator(ctx context.Context, outstanding sdk.DecCoins, operator *operatorstypes.Operator) (outstandingLeftOver sdk.DecCoins, err error) { + // Force-withdraw commission + valCommission, err := k.GetOperatorAccumulatedCommission(ctx, operator.ID) + if err != nil { + return outstanding, err + } + + commission := valCommission.Commissions.CoinsAmount() + + if !commission.IsZero() { + // Subtract from outstanding + outstanding = outstanding.Sub(commission) + + // Split into integral & remainder + coins, remainder := commission.TruncateDecimal() + + // Send remainder to community pool + operatorAddress, err := sdk.AccAddressFromBech32(operator.Address) + if err != nil { + return outstanding, err + } + + // We truncate the remainder to be able to send it to the community pool + // The remainder will be just be removed + remainderTruncated, _ := remainder.TruncateDecimal() + + err = k.communityPoolKeeper.FundCommunityPool(ctx, remainderTruncated, operatorAddress) + if err != nil { + return outstanding, err + } + + // Add to operator account + if !coins.IsZero() { + withdrawAddr, err := k.GetOperatorWithdrawAddr(ctx, operator) + if err != nil { + return outstanding, err + } + + err = k.bankKeeper.SendCoinsFromModuleToAccount(ctx, types.ModuleName, withdrawAddr, coins) + if err != nil { + return outstanding, err + } + } + } + + return outstanding, nil +} diff --git a/x/rewards/testutils/keeper.go b/x/rewards/testutils/keeper.go index 20003a4b..a2ed5713 100644 --- a/x/rewards/testutils/keeper.go +++ b/x/rewards/testutils/keeper.go @@ -13,7 +13,6 @@ import ( oracletypes "github.com/skip-mev/connect/v2/x/oracle/types" "github.com/stretchr/testify/require" - "github.com/milkyway-labs/milkyway/app/keepers" "github.com/milkyway-labs/milkyway/testutils/storetesting" assetskeeper "github.com/milkyway-labs/milkyway/x/assets/keeper" assetstypes "github.com/milkyway-labs/milkyway/x/assets/types" @@ -61,8 +60,6 @@ func NewKeeperTestData(t *testing.T) KeeperTestData { WithBlockHeight(1). WithBlockTime(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)) - communityPoolKeeper := keepers.NewCommunityPoolKeeper(data.BankKeeper, authtypes.FeeCollectorName) - data.MarketMapKeeper = marketmapkeeper.NewKeeper( runtime.NewKVStoreService(data.Keys[marketmaptypes.StoreKey]), data.Cdc, @@ -87,7 +84,7 @@ func NewKeeperTestData(t *testing.T) KeeperTestData { data.Keys[operatorstypes.StoreKey], runtime.NewKVStoreService(data.Keys[operatorstypes.StoreKey]), data.AccountKeeper, - communityPoolKeeper, + data.DistributionKeeper, data.AuthorityAddress, ) data.ServicesKeeper = serviceskeeper.NewKeeper( @@ -95,7 +92,7 @@ func NewKeeperTestData(t *testing.T) KeeperTestData { data.Keys[servicestypes.StoreKey], runtime.NewKVStoreService(data.Keys[servicestypes.StoreKey]), data.AccountKeeper, - communityPoolKeeper, + data.DistributionKeeper, data.AuthorityAddress, ) data.RestakingKeeper = restakingkeeper.NewKeeper( @@ -120,7 +117,7 @@ func NewKeeperTestData(t *testing.T) KeeperTestData { runtime.NewKVStoreService(data.Keys[rewardstypes.StoreKey]), data.AccountKeeper, data.BankKeeper, - communityPoolKeeper, + data.DistributionKeeper, &data.OracleKeeper, data.PoolsKeeper, data.OperatorsKeeper, @@ -131,9 +128,16 @@ func NewKeeperTestData(t *testing.T) KeeperTestData { ) // Set the hooks - data.OperatorsKeeper.SetHooks(data.RestakingKeeper.OperatorsHooks()) - data.ServicesKeeper.SetHooks(data.RestakingKeeper.ServicesHooks()) - data.RestakingKeeper.SetHooks(data.Keeper.Hooks()) + data.PoolsKeeper.SetHooks(data.Keeper.PoolsHooks()) + data.OperatorsKeeper.SetHooks(operatorstypes.NewMultiOperatorsHooks( + data.RestakingKeeper.OperatorsHooks(), + data.Keeper.OperatorsHooks(), + )) + data.ServicesKeeper.SetHooks(servicestypes.NewMultiServicesHooks( + data.RestakingKeeper.ServicesHooks(), + data.Keeper.ServicesHooks(), + )) + data.RestakingKeeper.SetHooks(data.Keeper.RestakingHooks()) // Set the base params data.PoolsKeeper.SetNextPoolID(data.Context, 1) diff --git a/x/rewards/types/dec_pool.go b/x/rewards/types/dec_pool.go index 0419a533..556f62de 100644 --- a/x/rewards/types/dec_pool.go +++ b/x/rewards/types/dec_pool.go @@ -215,6 +215,15 @@ func (pools DecPools) TruncateDecimal() (truncatedDecPools Pools, changeDecPools return truncatedDecPools, changeDecPools } +// CoinsAmount returns the total amount of coins in the DecPools +func (pools DecPools) CoinsAmount() sdk.DecCoins { + coins := sdk.NewDecCoins() + for _, pool := range pools { + coins = coins.Add(pool.DecCoins...) + } + return coins +} + // Intersect will return a new set of pools which contains the minimum pool Coins // for common denoms found in both `pools` and `poolsB`. For denoms not common // to both `pools` and `poolsB` the minimum is considered to be 0, thus they diff --git a/x/services/keeper/genesis.go b/x/services/keeper/genesis.go index b3cc4b16..361f5fd1 100644 --- a/x/services/keeper/genesis.go +++ b/x/services/keeper/genesis.go @@ -39,7 +39,7 @@ func (k *Keeper) InitGenesis(ctx sdk.Context, state *types.GenesisState) error { // Store the services for _, service := range state.Services { - if err := k.SaveService(ctx, service); err != nil { + if err := k.CreateService(ctx, service); err != nil { return err } } diff --git a/x/services/keeper/msg_server.go b/x/services/keeper/msg_server.go index fb732eb6..f8214261 100644 --- a/x/services/keeper/msg_server.go +++ b/x/services/keeper/msg_server.go @@ -50,6 +50,22 @@ func (k msgServer) CreateService(goCtx context.Context, msg *types.MsgCreateServ return nil, errors.Wrap(sdkerrors.ErrInvalidRequest, err.Error()) } + // Charge for the creation + // We do not place this inside the CreateService method to avoid charging fees during genesis + // init and other places that use that method + registrationFees := k.GetParams(ctx).ServiceRegistrationFee + if !registrationFees.IsZero() { + userAddress, err := sdk.AccAddressFromBech32(service.Admin) + if err != nil { + return nil, errors.Wrapf(sdkerrors.ErrInvalidAddress, "invalid service admin address: %s", service.Admin) + } + + err = k.poolKeeper.FundCommunityPool(ctx, registrationFees, userAddress) + if err != nil { + return nil, err + } + } + // Create the service err = k.Keeper.CreateService(ctx, service) if err != nil { diff --git a/x/services/keeper/msg_server_test.go b/x/services/keeper/msg_server_test.go index 2c613c47..4ae49e76 100644 --- a/x/services/keeper/msg_server_test.go +++ b/x/services/keeper/msg_server_test.go @@ -4,6 +4,7 @@ import ( sdkmath "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + distrtypes "github.com/cosmos/cosmos-sdk/x/distribution/types" "github.com/milkyway-labs/milkyway/x/services/keeper" "github.com/milkyway-labs/milkyway/x/services/types" @@ -128,7 +129,7 @@ func (suite *KeeperTestSuite) TestMsgServer_CreateService() { suite.Require().Equal(sdk.NewCoin("uatom", sdkmath.NewInt(100_000)), balance) // Make sure the fee was transferred to the module account - poolBalance := suite.bk.GetBalance(ctx, authtypes.NewModuleAddress(authtypes.FeeCollectorName), "uatom") + poolBalance := suite.bk.GetBalance(ctx, authtypes.NewModuleAddress(distrtypes.ModuleName), "uatom") suite.Require().Equal(sdk.NewCoin("uatom", sdkmath.NewInt(100_000)), poolBalance) }, }, diff --git a/x/services/keeper/services.go b/x/services/keeper/services.go index 69f90503..4a52da54 100644 --- a/x/services/keeper/services.go +++ b/x/services/keeper/services.go @@ -38,20 +38,6 @@ func (k *Keeper) SaveService(ctx sdk.Context, service types.Service) error { // CreateService creates a new Service and stores it in the KVStore func (k *Keeper) CreateService(ctx sdk.Context, service types.Service) error { - // Charge for the creation - registrationFees := k.GetParams(ctx).ServiceRegistrationFee - if !registrationFees.IsZero() { - userAddress, err := sdk.AccAddressFromBech32(service.Admin) - if err != nil { - return err - } - - err = k.poolKeeper.FundCommunityPool(ctx, registrationFees, userAddress) - if err != nil { - return err - } - } - // Create the service account serviceAddress, err := sdk.AccAddressFromBech32(service.Address) if err != nil { diff --git a/x/services/keeper/services_test.go b/x/services/keeper/services_test.go index 0df679af..c787e88c 100644 --- a/x/services/keeper/services_test.go +++ b/x/services/keeper/services_test.go @@ -3,7 +3,6 @@ package keeper_test import ( sdkmath "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" - authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" "github.com/milkyway-labs/milkyway/x/services/types" ) @@ -104,30 +103,6 @@ func (suite *KeeperTestSuite) TestKeeper_CreateService() { shouldErr bool check func(ctx sdk.Context) }{ - { - name: "user without enough funds to pay for registration fees returns error", - store: func(ctx sdk.Context) { - // Set the params - suite.k.SetParams(ctx, types.NewParams( - sdk.NewCoins(sdk.NewCoin("uatom", sdkmath.NewInt(100_000_000))), - )) - - // Fund the user account - userBalance := sdk.NewCoins(sdk.NewCoin("uatom", sdkmath.NewInt(50_000_000))) - suite.fundAccount(ctx, "cosmos13t6y2nnugtshwuy0zkrq287a95lyy8vzleaxmd", userBalance) - }, - service: types.NewService( - 1, - types.SERVICE_STATUS_ACTIVE, - "MilkyWay", - "MilkyWay is an AVS of a restaking platform", - "https://milkyway.com", - "https://milkyway.com/logo.png", - "cosmos13t6y2nnugtshwuy0zkrq287a95lyy8vzleaxmd", - false, - ), - shouldErr: true, - }, { name: "service is created properly", store: func(ctx sdk.Context) { @@ -152,17 +127,6 @@ func (suite *KeeperTestSuite) TestKeeper_CreateService() { ), shouldErr: false, check: func(ctx sdk.Context) { - // Make sure the user balance has been reduced - userAddress, err := sdk.AccAddressFromBech32("cosmos13t6y2nnugtshwuy0zkrq287a95lyy8vzleaxmd") - suite.Require().NoError(err) - - userBalance := suite.bk.GetBalance(ctx, userAddress, "uatom") - suite.Require().Equal(sdk.NewCoin("uatom", sdkmath.NewInt(50_000_000)), userBalance) - - // Make sure the community pool has been funded - poolBalance := suite.bk.GetBalance(ctx, authtypes.NewModuleAddress(authtypes.FeeCollectorName), "uatom") - suite.Require().Equal(sdk.NewCoin("uatom", sdkmath.NewInt(100_000_000)), poolBalance) - // Make sure the service account has been created hasAccount := suite.ak.HasAccount(ctx, types.GetServiceAddress(1)) suite.Require().True(hasAccount) diff --git a/x/services/testutils/keeper.go b/x/services/testutils/keeper.go index d5bda10a..3f7465fd 100644 --- a/x/services/testutils/keeper.go +++ b/x/services/testutils/keeper.go @@ -15,7 +15,6 @@ import ( "go.uber.org/mock/gomock" milkyway "github.com/milkyway-labs/milkyway/app" - "github.com/milkyway-labs/milkyway/app/keepers" "github.com/milkyway-labs/milkyway/testutils/storetesting" bankkeeper "github.com/milkyway-labs/milkyway/x/bank/keeper" poolskeeper "github.com/milkyway-labs/milkyway/x/pools/keeper" @@ -90,7 +89,7 @@ func NewKeeperTestData(t *testing.T) KeeperTestData { data.StoreKey, runtime.NewKVStoreService(data.Keys[servicestypes.StoreKey]), data.AccountKeeper, - keepers.NewCommunityPoolKeeper(data.BankKeeper, authtypes.FeeCollectorName), + data.DistributionKeeper, authorityAddr, ) diff --git a/x/tokenfactory/keeper/common_test.go b/x/tokenfactory/keeper/common_test.go index 489fe68d..13349bd9 100644 --- a/x/tokenfactory/keeper/common_test.go +++ b/x/tokenfactory/keeper/common_test.go @@ -5,7 +5,8 @@ import ( "testing" "time" - appkeepers "github.com/milkyway-labs/milkyway/app/keepers" + distrkeeper "github.com/cosmos/cosmos-sdk/x/distribution/keeper" + distrtypes "github.com/cosmos/cosmos-sdk/x/distribution/types" "cosmossdk.io/core/address" "cosmossdk.io/log" @@ -197,7 +198,7 @@ type TestKeepers struct { ContractKeeper *wasmkeeper.PermissionedKeeper WasmKeeper *wasmkeeper.Keeper TokenFactoryKeeper *tokenfactorykeeper.Keeper - CommunityPoolKeeper *appkeepers.CommunityPoolKeeper + CommunityPoolKeeper *distrkeeper.Keeper EncodingConfig initiaappparams.EncodingConfig Faucet *TestFaucet MultiStore storetypes.CommitMultiStore @@ -237,7 +238,12 @@ func _createTestInput( db dbm.DB, ) (sdk.Context, TestKeepers) { keys := storetypes.NewKVStoreKeys( - authtypes.StoreKey, banktypes.StoreKey, tokenfactorytypes.StoreKey, govtypes.StoreKey, wasmtypes.StoreKey, + authtypes.StoreKey, + banktypes.StoreKey, + distrtypes.StoreKey, + tokenfactorytypes.StoreKey, + govtypes.StoreKey, + wasmtypes.StoreKey, ) ms := store.NewCommitMultiStore(db, log.NewNopLogger(), metrics.NewNoOpMetrics()) for _, v := range keys { @@ -260,6 +266,7 @@ func _createTestInput( maccPerms := map[string][]string{ // module account permissions authtypes.FeeCollectorName: nil, + distrtypes.ModuleName: nil, govtypes.ModuleName: {authtypes.Burner}, authtypes.Minter: {authtypes.Minter, authtypes.Burner}, tokenfactorytypes.ModuleName: {authtypes.Minter, authtypes.Burner}, @@ -291,7 +298,17 @@ func _createTestInput( ) require.NoError(t, bankKeeper.SetParams(ctx, banktypes.DefaultParams())) - communityPoolKeeper := appkeepers.NewCommunityPoolKeeper(bankKeeper, authtypes.FeeCollectorName) + distrKeeper := distrkeeper.NewKeeper( + appCodec, + runtime.NewKVStoreService(keys[distrtypes.StoreKey]), + accountKeeper, + bankKeeper, + nil, + authtypes.FeeCollectorName, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + ) + distrKeeper.InitGenesis(ctx, *distrtypes.DefaultGenesisState()) + faucet := NewTestFaucet(t, ctx, bankKeeper, authtypes.Minter, initialTotalSupply()...) msgRouter := baseapp.NewMsgServiceRouter() msgRouter.SetInterfaceRegistry(encodingConfig.InterfaceRegistry) @@ -299,7 +316,7 @@ func _createTestInput( keepers := TestKeepers{ AccountKeeper: &accountKeeper, BankKeeper: &bankKeeper, - CommunityPoolKeeper: &communityPoolKeeper, + CommunityPoolKeeper: &distrKeeper, EncodingConfig: encodingConfig, Faucet: faucet, MultiStore: ms, From 8e1235b7d5fc10f04d081d8c35fcb99af8c8e323 Mon Sep 17 00:00:00 2001 From: Riccardo Montagnin Date: Thu, 14 Nov 2024 15:36:13 +0900 Subject: [PATCH 7/8] feat(tests): add tests to make sure the invariants are not broken --- x/rewards/keeper/invariants.go | 122 ++++++++++++++++++---------- x/rewards/keeper/invariants_test.go | 101 +++++++++++++++++++++++ 2 files changed, 178 insertions(+), 45 deletions(-) create mode 100644 x/rewards/keeper/invariants_test.go diff --git a/x/rewards/keeper/invariants.go b/x/rewards/keeper/invariants.go index ce45b2f6..f15ca981 100644 --- a/x/rewards/keeper/invariants.go +++ b/x/rewards/keeper/invariants.go @@ -216,63 +216,95 @@ func CanWithdrawInvariant(k *Keeper) sdk.Invariant { } } +// -------------------------------------------------------------------------------------------------------------------- + // ReferenceCountInvariant checks that the number of historical rewards records is correct func ReferenceCountInvariant(k *Keeper) sdk.Invariant { return func(ctx sdk.Context) (string, bool) { - targetCount := uint64(0) - k.poolsKeeper.IteratePools(ctx, func(_ poolstypes.Pool) (stop bool) { - targetCount++ - return false - }) - k.operatorsKeeper.IterateOperators(ctx, func(_ operatorstypes.Operator) (stop bool) { - targetCount++ - return false - }) - k.servicesKeeper.IterateServices(ctx, func(_ servicestypes.Service) (stop bool) { - targetCount++ - return false - }) + // Check the reference count for pools + msg, broken := checkReferencesCount( + ctx, + restakingtypes.DELEGATION_TYPE_POOL, + k.poolsKeeper.IteratePools, + k.restakingKeeper.IterateAllPoolDelegations, + k.PoolHistoricalRewards, + ) + if broken { + return sdk.FormatInvariant(types.ModuleName, "reference count", msg), broken + } - delCount := uint64(0) - k.restakingKeeper.IterateAllPoolDelegations(ctx, func(_ restakingtypes.Delegation) (stop bool) { - delCount++ - return false - }) - k.restakingKeeper.IterateAllOperatorDelegations(ctx, func(_ restakingtypes.Delegation) (stop bool) { - delCount++ - return false - }) - k.restakingKeeper.IterateAllServiceDelegations(ctx, func(_ restakingtypes.Delegation) (stop bool) { - delCount++ - return false - }) + // Check the reference count for operators + msg, broken = checkReferencesCount( + ctx, + restakingtypes.DELEGATION_TYPE_OPERATOR, + k.operatorsKeeper.IterateOperators, + k.restakingKeeper.IterateAllOperatorDelegations, + k.OperatorHistoricalRewards, + ) + if broken { + return sdk.FormatInvariant(types.ModuleName, "reference count", msg), broken + } - // one record per delegation target (last tracked period), one record per - // delegation (previous period) - // TODO: handle slash events - expected := targetCount + delCount - count := uint64(0) - elements := 0 - err := k.PoolHistoricalRewards.Walk( - ctx, nil, func(key collections.Pair[uint32, uint64], rewards types.HistoricalRewards) (stop bool, err error) { - count += uint64(rewards.ReferenceCount) - elements += 1 - return false, nil - }, + // Check the reference count for services + msg, broken = checkReferencesCount( + ctx, + restakingtypes.DELEGATION_TYPE_SERVICE, + k.servicesKeeper.IterateServices, + k.restakingKeeper.IterateAllServiceDelegations, + k.ServiceHistoricalRewards, ) - if err != nil { - panic(err) + if broken { + return sdk.FormatInvariant(types.ModuleName, "reference count", msg), broken } - broken := elements > 0 && count != expected + return "", false + } +} - return sdk.FormatInvariant(types.ModuleName, "reference count", - fmt.Sprintf("expected historical reference count: %d = %v delegation targets + %v delegations\n"+ - "total validator historical reference count: %d\n", - expected, targetCount, delCount, count)), broken +// checkReferencesCount checks the reference count for a given delegation target type +func checkReferencesCount[T any]( + ctx sdk.Context, + delegationTargetType restakingtypes.DelegationType, + targetsIterator func(ctx sdk.Context, fn func(T) bool), + delegationsIterator func(ctx sdk.Context, fn func(restakingtypes.Delegation) bool), + historicalRewardsCollection collections.Map[collections.Pair[uint32, uint64], types.HistoricalRewards], +) (msg string, broken bool) { + + targetCount := uint64(0) + targetsIterator(ctx, func(_ T) bool { + targetCount++ + return false + }) + + delegationsCount := uint64(0) + delegationsIterator(ctx, func(_ restakingtypes.Delegation) bool { + delegationsCount++ + return false + }) + + referencesCount := uint64(0) + err := historicalRewardsCollection.Walk(ctx, nil, func(key collections.Pair[uint32, uint64], value types.HistoricalRewards) (stop bool, err error) { + referencesCount += uint64(value.ReferenceCount) + return false, nil + }) + if err != nil { + panic(err) } + + // Make sure we have one record per delegation target (last tracked period) and + // one record per delegation (previous period) + expected := targetCount + delegationsCount + + broken = referencesCount != expected + + return fmt.Sprintf("expected historical reference count: %d = %v delegation targets + %v delegations\n"+ + "total %s historical reference count: %d\n", + expected, targetCount, delegationsCount, delegationTargetType, referencesCount, + ), broken } +// -------------------------------------------------------------------------------------------------------------------- + // ModuleAccountInvariant checks that the coins held by the global rewards pool // is consistent with the sum of outstanding rewards func ModuleAccountInvariant(k *Keeper) sdk.Invariant { diff --git a/x/rewards/keeper/invariants_test.go b/x/rewards/keeper/invariants_test.go new file mode 100644 index 00000000..f0490bdb --- /dev/null +++ b/x/rewards/keeper/invariants_test.go @@ -0,0 +1,101 @@ +package keeper_test + +import ( + "cosmossdk.io/collections" + sdk "github.com/cosmos/cosmos-sdk/types" + + operatorstypes "github.com/milkyway-labs/milkyway/x/operators/types" + poolstypes "github.com/milkyway-labs/milkyway/x/pools/types" + "github.com/milkyway-labs/milkyway/x/rewards/keeper" + "github.com/milkyway-labs/milkyway/x/rewards/types" + servicestypes "github.com/milkyway-labs/milkyway/x/services/types" +) + +func (suite *KeeperTestSuite) TestInvariants_ReferenceCountInvariant() { + testCases := []struct { + name string + store func(ctx sdk.Context) + expBroken bool + }{ + { + name: "default genesis does not return errors", + store: func(ctx sdk.Context) { + // Store the genesis data for all the modules involved + suite.poolsKeeper.InitGenesis(ctx, poolstypes.DefaultGenesis()) + err := suite.servicesKeeper.InitGenesis(ctx, servicestypes.DefaultGenesis()) + suite.NoError(err) + err = suite.operatorsKeeper.InitGenesis(ctx, *operatorstypes.DefaultGenesis()) + suite.NoError(err) + err = suite.keeper.InitGenesis(ctx, *types.DefaultGenesis()) + suite.NoError(err) + }, + expBroken: false, + }, + { + name: "initializing services and operators does not break the invariant", + store: func(ctx sdk.Context) { + suite.CreateOperator(ctx, "MilkyWay", "cosmos167x6ehhple8gwz5ezy9x0464jltvdpzl6qfdt4") + suite.CreateService(ctx, "MilkyWay AVS", "cosmos167x6ehhple8gwz5ezy9x0464jltvdpzl6qfdt4") + }, + expBroken: false, + }, + { + name: "invalid pools historical rewards reference count breaks the invariant", + store: func(ctx sdk.Context) { + pool, err := suite.poolsKeeper.CreateOrGetPoolByDenom(ctx, "stake") + suite.NoError(err) + + // Create an invalid number of historical rewards + historicalRewards := types.NewHistoricalRewards([]types.DecPool{types.NewDecPool("umilk", nil)}, 1) + err = suite.keeper.PoolHistoricalRewards.Set(ctx, collections.Join[uint32, uint64](pool.ID, 1), historicalRewards) + err = suite.keeper.PoolHistoricalRewards.Set(ctx, collections.Join[uint32, uint64](pool.ID, 2), historicalRewards) + err = suite.keeper.PoolHistoricalRewards.Set(ctx, collections.Join[uint32, uint64](pool.ID, 3), historicalRewards) + suite.Require().NoError(err) + }, + expBroken: true, + }, + { + name: "invalid service historical rewards reference count breaks the invariant", + store: func(ctx sdk.Context) { + service := suite.CreateService(ctx, "MilkyWay AVS", "cosmos167x6ehhple8gwz5ezy9x0464jltvdpzl6qfdt4") + + // Create an invalid number of historical rewards + historicalRewards := types.NewHistoricalRewards([]types.DecPool{types.NewDecPool("umilk", nil)}, 1) + err := suite.keeper.ServiceHistoricalRewards.Set(ctx, collections.Join[uint32, uint64](service.ID, 1), historicalRewards) + err = suite.keeper.ServiceHistoricalRewards.Set(ctx, collections.Join[uint32, uint64](service.ID, 2), historicalRewards) + err = suite.keeper.ServiceHistoricalRewards.Set(ctx, collections.Join[uint32, uint64](service.ID, 3), historicalRewards) + suite.Require().NoError(err) + }, + expBroken: true, + }, + { + name: "invalid operator historical rewards reference count breaks the invariant", + store: func(ctx sdk.Context) { + service := suite.CreateOperator(ctx, "MilkyWay", "cosmos167x6ehhple8gwz5ezy9x0464jltvdpzl6qfdt4") + + // Create an invalid number of historical rewards + historicalRewards := types.NewHistoricalRewards([]types.DecPool{types.NewDecPool("umilk", nil)}, 1) + err := suite.keeper.OperatorHistoricalRewards.Set(ctx, collections.Join[uint32, uint64](service.ID, 1), historicalRewards) + err = suite.keeper.OperatorHistoricalRewards.Set(ctx, collections.Join[uint32, uint64](service.ID, 2), historicalRewards) + err = suite.keeper.OperatorHistoricalRewards.Set(ctx, collections.Join[uint32, uint64](service.ID, 3), historicalRewards) + suite.Require().NoError(err) + }, + expBroken: true, + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + ctx, _ := suite.ctx.CacheContext() + if tc.store != nil { + tc.store(ctx) + } + + res, broken := keeper.ReferenceCountInvariant(suite.keeper)(ctx) + suite.Equal(tc.expBroken, broken, res) + }) + } +} From 54bd043eb08618f01c1d0db6dd676e2d4c27645c Mon Sep 17 00:00:00 2001 From: Riccardo Montagnin Date: Thu, 14 Nov 2024 16:19:03 +0900 Subject: [PATCH 8/8] fix(tests): fixed services sim genesis setup --- x/services/simulation/genesis.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/x/services/simulation/genesis.go b/x/services/simulation/genesis.go index c8f5c48e..d7b7debb 100644 --- a/x/services/simulation/genesis.go +++ b/x/services/simulation/genesis.go @@ -16,11 +16,11 @@ const ( keyServicesParams = "services_params" ) -func getServices(r *rand.Rand) []types.Service { +func getServices(r *rand.Rand, simState *module.SimulationState) []types.Service { count := r.Intn(10) + 1 var services []types.Service for i := 0; i < count; i++ { - adminAccount := simulation.RandomAccounts(r, 1)[0] + adminAccount, _ := simulation.RandomAcc(r, simState.Accounts) service := RandomService(r, uint32(i)+1, adminAccount.Address.String()) services = append(services, service) } @@ -50,7 +50,7 @@ func RandomizedGenState(simState *module.SimulationState) { ) simState.AppParams.GetOrGenerate(keyServices, &services, simState.Rand, func(r *rand.Rand) { - services = getServices(r) + services = getServices(r, simState) }) simState.AppParams.GetOrGenerate(keyServicesParams, &servicesParams, simState.Rand, func(r *rand.Rand) {