From 48a45beb316cf9b6cba2807ec44fe825121dec65 Mon Sep 17 00:00:00 2001 From: Jose Date: Mon, 2 Oct 2023 10:30:34 -0500 Subject: [PATCH] Update message.Marshal message to accept structs with Go's native types --- field/composite.go | 24 ++++++++---- field/composite_test.go | 4 +- message.go | 24 ++++++++---- message_test.go | 82 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 118 insertions(+), 16 deletions(-) diff --git a/field/composite.go b/field/composite.go index 6612c89..212975c 100644 --- a/field/composite.go +++ b/field/composite.go @@ -195,13 +195,23 @@ func (f *Composite) Unmarshal(v interface{}) error { } dataField := dataStruct.Field(i) - if dataField.IsNil() { - dataField.Set(reflect.New(dataField.Type().Elem())) - } + switch dataField.Kind() { //nolint:exhaustive + case reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.UnsafePointer, reflect.Interface, reflect.Slice: + if dataField.IsNil() { + dataField.Set(reflect.New(dataField.Type().Elem())) + } - err = messageField.Unmarshal(dataField.Interface()) - if err != nil { - return fmt.Errorf("failed to get data from field %s: %w", indexOrTag, err) + err = messageField.Unmarshal(dataField.Interface()) + if err != nil { + return fmt.Errorf("failed to get data from field %s: %w", indexOrTag, err) + } + default: // Native types + vv := reflect.New(dataField.Type()).Elem() + err = messageField.Unmarshal(vv) + if err != nil { + return fmt.Errorf("failed to get data from field %s: %w", indexOrTag, err) + } + dataField.Set(vv) } } @@ -259,7 +269,7 @@ func (f *Composite) Marshal(v interface{}) error { } dataField := dataStruct.Field(i) - if dataField.IsNil() { + if dataField.IsZero() { continue } diff --git a/field/composite_test.go b/field/composite_test.go index 7993927..e32f79a 100644 --- a/field/composite_test.go +++ b/field/composite_test.go @@ -597,7 +597,7 @@ func TestCompositePacking(t *testing.T) { }) require.Error(t, err) - require.EqualError(t, err, "failed to set data from field 1: data does not match required *String type") + require.EqualError(t, err, "failed to set data from field 1: data does not match required *String or (string, *string, int, *int) type") }) t.Run("Pack returns error on failure of subfield packing", func(t *testing.T) { @@ -747,7 +747,7 @@ func TestCompositePacking(t *testing.T) { err = composite.Unmarshal(data) require.Error(t, err) - require.EqualError(t, err, "failed to get data from field 1: data does not match required *String type") + require.EqualError(t, err, "failed to get data from field 1: data does not match required *String or *string type") }) t.Run("Unpack returns an error on failure of subfield to unpack bytes", func(t *testing.T) { diff --git a/message.go b/message.go index 54f01aa..e07b472 100644 --- a/message.go +++ b/message.go @@ -445,7 +445,7 @@ func (m *Message) Marshal(v interface{}) error { } dataField := dataStruct.Field(i) - if dataField.IsNil() { + if dataField.IsZero() { continue } @@ -502,13 +502,23 @@ func (m *Message) Unmarshal(v interface{}) error { } dataField := dataStruct.Field(i) - if dataField.IsNil() { - dataField.Set(reflect.New(dataField.Type().Elem())) - } + switch dataField.Kind() { //nolint:exhaustive + case reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.UnsafePointer, reflect.Interface, reflect.Slice: + if dataField.IsNil() { + dataField.Set(reflect.New(dataField.Type().Elem())) + } - err = messageField.Unmarshal(dataField.Interface()) - if err != nil { - return fmt.Errorf("failed to get value from field %d: %w", fieldIndex, err) + err = messageField.Unmarshal(dataField.Interface()) + if err != nil { + return fmt.Errorf("failed to get value from field %d: %w", fieldIndex, err) + } + default: // Native types + vv := reflect.New(dataField.Type()).Elem() + err = messageField.Unmarshal(vv) + if err != nil { + return fmt.Errorf("failed to get value from field %d: %w", fieldIndex, err) + } + dataField.Set(vv) } } diff --git a/message_test.go b/message_test.go index 95f0634..5056580 100644 --- a/message_test.go +++ b/message_test.go @@ -250,6 +250,53 @@ func TestMessage(t *testing.T) { require.Equal(t, "100", data.F4.Value()) }) + t.Run("Test unpacking with untyped fields", func(t *testing.T) { + type TestISOF3Data struct { + F1 *string + F2 string + F3 string + } + + type ISO87Data struct { + F0 *string + F2 string + F3 *TestISOF3Data + F4 string + } + + message := NewMessage(spec) + + rawMsg := []byte("01007000000000000000164242424242424242123456000000000100") + err := message.Unpack([]byte(rawMsg)) + + require.NoError(t, err) + + s, err := message.GetString(2) + require.NoError(t, err) + require.Equal(t, "4242424242424242", s) + + s, err = message.GetString(3) + require.NoError(t, err) + require.Equal(t, "123456", s) + + s, err = message.GetString(4) + require.NoError(t, err) + require.Equal(t, "100", s) + + data := &ISO87Data{} + + require.NoError(t, message.Unmarshal(data)) + + require.NotNil(t, data.F0) + require.Equal(t, "0100", *data.F0) + require.Equal(t, "4242424242424242", data.F2) + require.NotNil(t, data.F3.F1) + require.Equal(t, "12", *data.F3.F1) + require.Equal(t, "34", data.F3.F2) + require.Equal(t, "56", data.F3.F3) + require.Equal(t, "100", data.F4) + }) + t.Run("Test packing with typed fields", func(t *testing.T) { type TestISOF3Data struct { F1 *field.String @@ -283,6 +330,41 @@ func TestMessage(t *testing.T) { wantMsg := []byte("01007000000000000000164242424242424242123456000000000100") require.Equal(t, wantMsg, rawMsg) }) + + t.Run("Test packing with untyped fields", func(t *testing.T) { + type TestISOF3Data struct { + F1 string + F2 string + F3 string + } + + type ISO87Data struct { + F0 *string + F2 string + F3 *TestISOF3Data + F4 string + } + + messageCode := "0100" + message := NewMessage(spec) + err := message.Marshal(&ISO87Data{ + F0: &messageCode, + F2: "4242424242424242", + F3: &TestISOF3Data{ + F1: "12", + F2: "34", + F3: "56", + }, + F4: "100", + }) + require.NoError(t, err) + + rawMsg, err := message.Pack() + require.NoError(t, err) + + wantMsg := []byte("01007000000000000000164242424242424242123456000000000100") + require.Equal(t, wantMsg, rawMsg) + }) } func TestPackUnpack(t *testing.T) {