diff --git a/x/ibc-hooks/move-hooks/message.go b/x/ibc-hooks/move-hooks/message.go index 5c16b27c..10c935f3 100644 --- a/x/ibc-hooks/move-hooks/message.go +++ b/x/ibc-hooks/move-hooks/message.go @@ -2,6 +2,8 @@ package move_hooks import ( "encoding/json" + "fmt" + "strings" movetypes "github.com/initia-labs/initia/x/move/types" ) @@ -53,44 +55,50 @@ type HookData struct { AsyncCallback *AsyncCallback `json:"async_callback,omitempty"` } -// asyncCallback is same as AsyncCallback. -type asyncCallback struct { - // callback id should be issued form the executor contract - Id uint64 `json:"id"` - ModuleAddress string `json:"module_address"` - ModuleName string `json:"module_name"` -} - -// asyncCallbackStringID is same as AsyncCallback but -// it has Id as string. -type asyncCallbackStringID struct { - // callback id should be issued form the executor contract - Id uint64 `json:"id,string"` - ModuleAddress string `json:"module_address"` - ModuleName string `json:"module_name"` +// intermediateCallback is used internally for JSON unmarshaling +type intermediateCallback struct { + Id interface{} `json:"id"` + ModuleAddress string `json:"module_address"` + ModuleName string `json:"module_name"` } // UnmarshalJSON implements the json unmarshaler interface. -// custom unmarshaler is required because we have to handle -// id as string and uint64. +// It handles both string and numeric id formats and validates the module address. func (a *AsyncCallback) UnmarshalJSON(bz []byte) error { - var ac asyncCallback - err := json.Unmarshal(bz, &ac) - if err != nil { - var acStr asyncCallbackStringID - err := json.Unmarshal(bz, &acStr) - if err != nil { - return err - } + var ic intermediateCallback + if err := json.Unmarshal(bz, &ic); err != nil { + return fmt.Errorf("failed to unmarshal AsyncCallback: %w", err) + } + + // Validate required fields + if ic.ModuleAddress == "" { + return fmt.Errorf("module_address cannot be empty") + } + if ic.ModuleName == "" { + return fmt.Errorf("module_name cannot be empty") + } + + // Validate module address format (assuming it should start with "0x") + if !strings.HasPrefix(ic.ModuleAddress, "0x") { + return fmt.Errorf("invalid module_address format: must start with '0x'") + } - a.Id = acStr.Id - a.ModuleAddress = acStr.ModuleAddress - a.ModuleName = acStr.ModuleName - return nil + // Handle ID based on type + switch v := ic.Id.(type) { + case float64: + a.Id = uint64(v) + case string: + var err error + var parsed float64 + if err = json.Unmarshal([]byte(v), &parsed); err != nil { + return fmt.Errorf("invalid id format: %w", err) + } + a.Id = uint64(parsed) + default: + return fmt.Errorf("invalid id type: expected string or number") } - a.Id = ac.Id - a.ModuleAddress = ac.ModuleAddress - a.ModuleName = ac.ModuleName + a.ModuleAddress = ic.ModuleAddress + a.ModuleName = ic.ModuleName return nil } diff --git a/x/ibc-hooks/move-hooks/message_test.go b/x/ibc-hooks/move-hooks/message_test.go index 199918d3..ca2446b5 100644 --- a/x/ibc-hooks/move-hooks/message_test.go +++ b/x/ibc-hooks/move-hooks/message_test.go @@ -9,25 +9,95 @@ import ( ) func Test_Unmarshal_AsyncCallback(t *testing.T) { - var callback movehooks.AsyncCallback - err := json.Unmarshal([]byte(`{ - "id": 99, - "module_address": "0x1", - "module_name": "Counter" - }`), &callback) - require.NoError(t, err) - require.Equal(t, movehooks.AsyncCallback{ - Id: 99, - ModuleAddress: "0x1", - ModuleName: "Counter", - }, callback) - - var callbackStringID movehooks.AsyncCallback - err = json.Unmarshal([]byte(`{ - "id": "99", - "module_address": "0x1", - "module_name": "Counter" - }`), &callbackStringID) - require.NoError(t, err) - require.Equal(t, callback, callbackStringID) + t.Run("valid numeric id", func(t *testing.T) { + var callback movehooks.AsyncCallback + err := json.Unmarshal([]byte(`{ + "id": 99, + "module_address": "0x1", + "module_name": "Counter" + }`), &callback) + require.NoError(t, err) + require.Equal(t, movehooks.AsyncCallback{ + Id: 99, + ModuleAddress: "0x1", + ModuleName: "Counter", + }, callback) + }) + + t.Run("valid string id", func(t *testing.T) { + var callbackStringID movehooks.AsyncCallback + err := json.Unmarshal([]byte(`{ + "id": "99", + "module_address": "0x1", + "module_name": "Counter" + }`), &callbackStringID) + require.NoError(t, err) + require.Equal(t, movehooks.AsyncCallback{ + Id: 99, + ModuleAddress: "0x1", + ModuleName: "Counter", + }, callbackStringID) + }) + + t.Run("empty module address", func(t *testing.T) { + var callback movehooks.AsyncCallback + err := json.Unmarshal([]byte(`{ + "id": 99, + "module_address": "", + "module_name": "Counter" + }`), &callback) + require.Error(t, err) + require.Contains(t, err.Error(), "module_address cannot be empty") + }) + + t.Run("empty module name", func(t *testing.T) { + var callback movehooks.AsyncCallback + err := json.Unmarshal([]byte(`{ + "id": 99, + "module_address": "0x1", + "module_name": "" + }`), &callback) + require.Error(t, err) + require.Contains(t, err.Error(), "module_name cannot be empty") + }) + + t.Run("invalid module address format", func(t *testing.T) { + var callback movehooks.AsyncCallback + err := json.Unmarshal([]byte(`{ + "id": 99, + "module_address": "invalid", + "module_name": "Counter" + }`), &callback) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid module_address format") + }) + + t.Run("invalid id type", func(t *testing.T) { + var callback movehooks.AsyncCallback + err := json.Unmarshal([]byte(`{ + "id": true, + "module_address": "0x1", + "module_name": "Counter" + }`), &callback) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid id type") + }) + + t.Run("invalid id string format", func(t *testing.T) { + var callback movehooks.AsyncCallback + err := json.Unmarshal([]byte(`{ + "id": "not_a_number", + "module_address": "0x1", + "module_name": "Counter" + }`), &callback) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid id format") + }) + + t.Run("malformed json", func(t *testing.T) { + var callback movehooks.AsyncCallback + err := json.Unmarshal([]byte(`{malformed`), &callback) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid character") + }) }