From 9e1ec886c78c6665e4c6963b821876628cb1fffd Mon Sep 17 00:00:00 2001 From: beer-1 <147697694+beer-1@users.noreply.github.com> Date: Mon, 12 Aug 2024 14:20:40 +0900 Subject: [PATCH] fix to encode intermediate sender changed packet with counterparty port consideration put source port add testcase --- x/ibc-hooks/move-hooks/receive.go | 13 +----- x/ibc-hooks/move-hooks/receive_test.go | 59 +++++++++++++++++++++++++- x/ibc-hooks/move-hooks/util.go | 4 +- x/ibc/nft-transfer/types/packet.go | 8 ++-- 4 files changed, 66 insertions(+), 18 deletions(-) diff --git a/x/ibc-hooks/move-hooks/receive.go b/x/ibc-hooks/move-hooks/receive.go index c9362e22..87bf4e1a 100644 --- a/x/ibc-hooks/move-hooks/receive.go +++ b/x/ibc-hooks/move-hooks/receive.go @@ -1,7 +1,6 @@ package move_hooks import ( - "encoding/json" "fmt" sdk "github.com/cosmos/cosmos-sdk/types" @@ -51,11 +50,7 @@ func (h MoveHooks) onRecvIcs20Packet( // // If that succeeds, we make the contract call data.Receiver = intermediateSender - bz, err := json.Marshal(data) - if err != nil { - return newEmitErrorAcknowledgement(err) - } - packet.Data = bz + packet.Data = data.GetBytes() ack := im.App.OnRecvPacket(ctx, packet, relayer) if !ack.Success() { @@ -107,11 +102,7 @@ func (h MoveHooks) onRecvIcs721Packet( // // If that succeeds, we make the contract call data.Receiver = intermediateSender - bz, err := json.Marshal(data) - if err != nil { - return newEmitErrorAcknowledgement(err) - } - packet.Data = bz + packet.Data = data.GetBytes(packet.GetSourcePort()) ack := im.App.OnRecvPacket(ctx, packet, relayer) if !ack.Success() { diff --git a/x/ibc-hooks/move-hooks/receive_test.go b/x/ibc-hooks/move-hooks/receive_test.go index a0897a48..41909ad0 100644 --- a/x/ibc-hooks/move-hooks/receive_test.go +++ b/x/ibc-hooks/move-hooks/receive_test.go @@ -115,7 +115,7 @@ func Test_OnReceivePacket_ICS721(t *testing.T) { require.True(t, ack.Success()) } -func Test_onReceiveIcs20Packet_memo_ICS721(t *testing.T) { +func Test_onReceivePacket_memo_ICS721(t *testing.T) { ctx, input := createDefaultTestInput(t) _, _, addr := keyPubAddr() @@ -169,3 +169,60 @@ func Test_onReceiveIcs20Packet_memo_ICS721(t *testing.T) { require.NoError(t, err) require.Equal(t, "\"1\"", queryRes.Ret) } + +func Test_onReceivePacket_memo_ICS721_Wasm(t *testing.T) { + ctx, input := createDefaultTestInput(t) + _, _, addr := keyPubAddr() + + data := nfttransfertypes.NonFungibleTokenPacketDataWasm{ + ClassId: "classId", + ClassUri: "classUri", + ClassData: "classData", + TokenIds: []string{"tokenId"}, + TokenUris: []string{"tokenUri"}, + TokenData: []string{"tokenData"}, + Sender: addr.String(), + Receiver: "0x1::Counter::increase", + Memo: `{ + "move": { + "message": { + "module_address": "0x1", + "module_name": "Counter", + "function_name": "increase" + } + } + }`, + } + + dataBz, err := json.Marshal(&data) + require.NoError(t, err) + + // failed to due to acl + ack := input.IBCHooksMiddleware.OnRecvPacket(ctx, channeltypes.Packet{ + SourcePort: "wasm.contract_address", + Data: dataBz, + }, addr) + require.False(t, ack.Success()) + + // set acl + require.NoError(t, input.IBCHooksKeeper.SetAllowed(ctx, movetypes.ConvertVMAddressToSDKAddress(vmtypes.StdAddress), true)) + + // success + ack = input.IBCHooksMiddleware.OnRecvPacket(ctx, channeltypes.Packet{ + SourcePort: "wasm.contract_address", + Data: dataBz, + }, addr) + require.True(t, ack.Success()) + + // check the contract state + queryRes, err := input.MoveKeeper.ExecuteViewFunction( + ctx, + vmtypes.StdAddress, + "Counter", + "get", + []vmtypes.TypeTag{}, + [][]byte{}, + ) + require.NoError(t, err) + require.Equal(t, "\"1\"", queryRes.Ret) +} diff --git a/x/ibc-hooks/move-hooks/util.go b/x/ibc-hooks/move-hooks/util.go index 240cfd67..28185e20 100644 --- a/x/ibc-hooks/move-hooks/util.go +++ b/x/ibc-hooks/move-hooks/util.go @@ -38,8 +38,8 @@ func isIcs20Packet(packetData []byte) (isIcs20 bool, ics20data transfertypes.Fun return true, data } -func isIcs721Packet(packetData []byte, counterPartyPort string) (isIcs721 bool, ics721data nfttransfertypes.NonFungibleTokenPacketData) { - if data, err := nfttransfertypes.DecodePacketData(packetData, counterPartyPort); err != nil { +func isIcs721Packet(packetData []byte, counterpartyPort string) (isIcs721 bool, ics721data nfttransfertypes.NonFungibleTokenPacketData) { + if data, err := nfttransfertypes.DecodePacketData(packetData, counterpartyPort); err != nil { return false, data } else { return true, data diff --git a/x/ibc/nft-transfer/types/packet.go b/x/ibc/nft-transfer/types/packet.go index aca9c57e..236856e1 100644 --- a/x/ibc/nft-transfer/types/packet.go +++ b/x/ibc/nft-transfer/types/packet.go @@ -71,10 +71,10 @@ func (nftpd NonFungibleTokenPacketData) ValidateBasic() error { } // GetBytes is a helper for serializing -func (nftpd NonFungibleTokenPacketData) GetBytes(counterPartyPort string) []byte { +func (nftpd NonFungibleTokenPacketData) GetBytes(counterpartyPort string) []byte { var bz []byte var err error - if isWasmPacket(counterPartyPort) { + if isWasmPacket(counterpartyPort) { bz, err = json.Marshal(nftpd.ToWasmData()) } else { bz, err = json.Marshal(nftpd) @@ -87,11 +87,11 @@ func (nftpd NonFungibleTokenPacketData) GetBytes(counterPartyPort string) []byte } // decode packet data to NonFungibleTokenPacketData -func DecodePacketData(packetData []byte, counterPartyPort string) (NonFungibleTokenPacketData, error) { +func DecodePacketData(packetData []byte, counterpartyPort string) (NonFungibleTokenPacketData, error) { decoder := json.NewDecoder(strings.NewReader(string(packetData))) decoder.DisallowUnknownFields() - if isWasmPacket(counterPartyPort) { + if isWasmPacket(counterpartyPort) { var wasmData NonFungibleTokenPacketDataWasm if err := decoder.Decode(&wasmData); err != nil { return NonFungibleTokenPacketData{}, sdkerrors.ErrInvalidRequest.Wrap(err.Error())