Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Limit recursion depth for unknown field detection and unpack any #22901

Merged
merged 6 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ Every module contains its own CHANGELOG.md. Please refer to the module you are i
* [#22826](https://github.com/cosmos/cosmos-sdk/pull/22826) Simplify testing frameworks by removing `testutil/cmdtest`.

### Bug Fixes

* (sims) [#21906](https://github.com/cosmos/cosmos-sdk/pull/21906) Skip sims test when running dry on validators
* (cli) [#21919](https://github.com/cosmos/cosmos-sdk/pull/21919) Query address-by-acc-num by account_id instead of id.
* (cli) [#22656](https://github.com/cosmos/cosmos-sdk/pull/22656) Prune cmd should disable async pruning.
Expand Down
56 changes: 54 additions & 2 deletions codec/types/interface_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@ import (
"cosmossdk.io/x/tx/signing"
)

var (

// MaxUnpackAnySubCalls extension point that defines the maximum number of sub-calls allowed during the unpacking
// process of protobuf Any messages.
MaxUnpackAnySubCalls = 100

// MaxUnpackAnyRecursionDepth extension point that defines the maximum allowed recursion depth during protobuf Any
// message unpacking.
MaxUnpackAnyRecursionDepth = 10
)

// UnpackInterfaces is a convenience function that calls UnpackInterfaces
// on x if x implements UnpackInterfacesMessage
func UnpackInterfaces(x interface{}, unpacker gogoprotoany.AnyUnpacker) error {
Expand Down Expand Up @@ -230,6 +241,45 @@ func (registry *interfaceRegistry) ListImplementations(ifaceName string) []strin
}

func (registry *interfaceRegistry) UnpackAny(any *Any, iface interface{}) error {
unpacker := &statefulUnpacker{
registry: registry,
maxDepth: MaxUnpackAnyRecursionDepth,
maxCalls: &sharedCounter{count: MaxUnpackAnySubCalls},
}
return unpacker.UnpackAny(any, iface)
}

// sharedCounter is a type that encapsulates a counter value
type sharedCounter struct {
count int
}

// statefulUnpacker is a struct that helps in deserializing and unpacking
// protobuf Any messages while maintaining certain stateful constraints.
type statefulUnpacker struct {
registry *interfaceRegistry
maxDepth int
maxCalls *sharedCounter
}

// cloneForRecursion returns a new statefulUnpacker instance with maxDepth reduced by one, preserving the registry and maxCalls.
func (r statefulUnpacker) cloneForRecursion() *statefulUnpacker {
return &statefulUnpacker{
registry: r.registry,
maxDepth: r.maxDepth - 1,
maxCalls: r.maxCalls,
}
}

// UnpackAny deserializes a protobuf Any message into the provided interface, ensuring the interface is a pointer.
// It applies stateful constraints such as max depth and call limits, and unpacks interfaces if required.
func (r *statefulUnpacker) UnpackAny(any *Any, iface interface{}) error {
if r.maxDepth == 0 {
return errors.New("max depth exceeded")
}
if r.maxCalls.count == 0 {
return errors.New("call limit exceeded")
}
Comment on lines +277 to +282
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Update limit checks to handle non-positive values

The current checks for maxDepth and maxCalls.count only handle zero values. To enhance robustness and prevent potential underflows, consider updating the conditions to check for less than or equal to zero.

Apply this diff:

 func (r *statefulUnpacker) UnpackAny(any *Any, iface interface{}) error {
-    if r.maxDepth == 0 {
+    if r.maxDepth <= 0 {
         return errors.New("max depth exceeded")
     }
-    if r.maxCalls.count == 0 {
+    if r.maxCalls.count <= 0 {
         return errors.New("call limit exceeded")
     }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if r.maxDepth == 0 {
return errors.New("max depth exceeded")
}
if r.maxCalls.count == 0 {
return errors.New("call limit exceeded")
}
if r.maxDepth <= 0 {
return errors.New("max depth exceeded")
}
if r.maxCalls.count <= 0 {
return errors.New("call limit exceeded")
}

// here we gracefully handle the case in which `any` itself is `nil`, which may occur in message decoding
if any == nil {
return nil
Expand All @@ -240,6 +290,8 @@ func (registry *interfaceRegistry) UnpackAny(any *Any, iface interface{}) error
return nil
}

r.maxCalls.count--
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Prevent maxCalls.count from becoming negative

After decrementing r.maxCalls.count, ensure it does not become negative, which could lead to incorrect behavior. Consider adding a check after the decrement.

Apply this diff:

 r.maxCalls.count--
+if r.maxCalls.count < 0 {
+    return errors.New("call limit exceeded")
+}

Alternatively, since the limit check has been updated to handle non-positive values, this might already be addressed.

Committable suggestion skipped: line range outside the PR's diff.


rv := reflect.ValueOf(iface)
if rv.Kind() != reflect.Ptr {
return errors.New("UnpackAny expects a pointer")
Expand All @@ -255,7 +307,7 @@ func (registry *interfaceRegistry) UnpackAny(any *Any, iface interface{}) error
}
}

imap, found := registry.interfaceImpls[rt]
imap, found := r.registry.interfaceImpls[rt]
if !found {
return fmt.Errorf("no registered implementations of type %+v", rt)
}
Expand All @@ -277,7 +329,7 @@ func (registry *interfaceRegistry) UnpackAny(any *Any, iface interface{}) error
return err
}

err = UnpackInterfaces(msg, registry)
err = UnpackInterfaces(msg, r.cloneForRecursion())
if err != nil {
return err
}
Expand Down
18 changes: 16 additions & 2 deletions codec/unknownproto/unknown_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,23 @@ func RejectUnknownFieldsStrict(bz []byte, msg proto.Message, resolver jsonpb.Any
// This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
// An AnyResolver must be provided for traversing inside google.protobuf.Any's.
func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals bool, resolver jsonpb.AnyResolver) (hasUnknownNonCriticals bool, err error) {
// recursion limit with same default as https://github.com/protocolbuffers/protobuf-go/blob/v1.35.2/encoding/protowire/wire.go#L28
return doRejectUnknownFields(bz, msg, allowUnknownNonCriticals, resolver, 10_000)
}

func doRejectUnknownFields(
bz []byte,
msg proto.Message,
allowUnknownNonCriticals bool,
resolver jsonpb.AnyResolver,
recursionLimit int,
) (hasUnknownNonCriticals bool, err error) {
if len(bz) == 0 {
return hasUnknownNonCriticals, nil
}
if recursionLimit == 0 {
return false, errors.New("recursion limit reached")
}
Comment on lines +57 to +59
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Update recursion limit check to handle non-positive values

To prevent potential underflows and enhance robustness, consider updating the recursion limit check from recursionLimit == 0 to recursionLimit <= 0.

Apply this diff:

 if len(bz) == 0 {
     return hasUnknownNonCriticals, nil
 }
-if recursionLimit == 0 {
+if recursionLimit <= 0 {
     return false, errors.New("recursion limit reached")
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if recursionLimit == 0 {
return false, errors.New("recursion limit reached")
}
if len(bz) == 0 {
return hasUnknownNonCriticals, nil
}
if recursionLimit <= 0 {
return false, errors.New("recursion limit reached")
}


fieldDescProtoFromTagNum, _, err := getDescriptorInfo(msg)
if err != nil {
Expand Down Expand Up @@ -125,7 +139,7 @@ func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals

if protoMessageName == ".google.protobuf.Any" {
// Firstly typecheck types.Any to ensure nothing snuck in.
hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, (*types.Any)(nil), allowUnknownNonCriticals, resolver)
hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, (*types.Any)(nil), allowUnknownNonCriticals, resolver, recursionLimit-1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Ensure proper handling of recursion limit during recursive calls

When recursively calling doRejectUnknownFields, ensure that the decrementing of recursionLimit does not result in negative values, which could bypass termination conditions.

Consider verifying that recursionLimit is greater than zero before the recursive call or rely on the updated recursion limit check as previously suggested.

hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
if err != nil {
return hasUnknownNonCriticals, err
Expand All @@ -148,7 +162,7 @@ func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals
}
}

hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, msg, allowUnknownNonCriticals, resolver)
hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, msg, allowUnknownNonCriticals, resolver, recursionLimit-1)
hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
if err != nil {
return hasUnknownNonCriticals, err
Expand Down
2 changes: 1 addition & 1 deletion x/tx/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Since v0.13.0, x/tx follows Cosmos SDK semver: https://github.com/cosmos/cosmos-

## [Unreleased]

## [v1.0.0-alpha.3](https://github.com/cosmos/cosmos-sdk/releases/tag/x/tx/v1.0.0-alpha.3) - 2024-12-12
## [v1.0.0-alpha.3](https://github.com/cosmos/cosmos-sdk/releases/tag/x/tx/v1.0.0-alpha.3) - 2024-12-16

### Bug Fixes

Expand Down
18 changes: 16 additions & 2 deletions x/tx/decode/unknown.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,23 @@ func RejectUnknownFieldsStrict(bz []byte, msg protoreflect.MessageDescriptor, re
// This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
// An AnyResolver must be provided for traversing inside google.protobuf.Any's.
func RejectUnknownFields(bz []byte, desc protoreflect.MessageDescriptor, allowUnknownNonCriticals bool, resolver protodesc.Resolver) (hasUnknownNonCriticals bool, err error) {
// recursion limit with same default as https://github.com/protocolbuffers/protobuf-go/blob/v1.35.2/encoding/protowire/wire.go#L28
return doRejectUnknownFields(bz, desc, allowUnknownNonCriticals, resolver, 10_000)
}
Comment on lines +36 to +38
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Define recursion limit as a constant for maintainability

Currently, the recursion limit of 10,000 is hardcoded within the RejectUnknownFields function. Defining this value as a constant improves readability and makes it easier to manage or adjust in the future.

Apply this diff to define the recursion limit as a constant:

+const defaultRecursionLimit = 10000

 func RejectUnknownFields(bz []byte, desc protoreflect.MessageDescriptor, allowUnknownNonCriticals bool, resolver protodesc.Resolver) (hasUnknownNonCriticals bool, err error) {
-	return doRejectUnknownFields(bz, desc, allowUnknownNonCriticals, resolver, 10_000)
+	return doRejectUnknownFields(bz, desc, allowUnknownNonCriticals, resolver, defaultRecursionLimit)
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// recursion limit with same default as https://github.com/protocolbuffers/protobuf-go/blob/v1.35.2/encoding/protowire/wire.go#L28
return doRejectUnknownFields(bz, desc, allowUnknownNonCriticals, resolver, 10_000)
}
const defaultRecursionLimit = 10000
func RejectUnknownFields(bz []byte, desc protoreflect.MessageDescriptor, allowUnknownNonCriticals bool, resolver protodesc.Resolver) (hasUnknownNonCriticals bool, err error) {
// recursion limit with same default as https://github.com/protocolbuffers/protobuf-go/blob/v1.35.2/encoding/protowire/wire.go#L28
return doRejectUnknownFields(bz, desc, allowUnknownNonCriticals, resolver, defaultRecursionLimit)
}


func doRejectUnknownFields(
bz []byte,
desc protoreflect.MessageDescriptor,
allowUnknownNonCriticals bool,
resolver protodesc.Resolver,
recursionLimit int,
) (hasUnknownNonCriticals bool, err error) {
if len(bz) == 0 {
return hasUnknownNonCriticals, nil
}
if recursionLimit == 0 {
return false, errors.New("recursion limit reached")
}
Comment on lines +50 to +52
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Update recursion limit check to handle non-positive values

The current recursion limit check only handles the case when recursionLimit equals zero. To prevent potential underflows and enhance robustness, consider updating the condition to check for less than or equal to zero.

Apply this diff:

 if len(bz) == 0 {
     return hasUnknownNonCriticals, nil
 }
-if recursionLimit == 0 {
+if recursionLimit <= 0 {
     return false, errors.New("recursion limit reached")
 }

Committable suggestion skipped: line range outside the PR's diff.


fields := desc.Fields()

Expand Down Expand Up @@ -111,7 +125,7 @@ func RejectUnknownFields(bz []byte, desc protoreflect.MessageDescriptor, allowUn

if fieldMessage.FullName() == anyFullName {
// Firstly typecheck types.Any to ensure nothing snuck in.
hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, anyDesc, allowUnknownNonCriticals, resolver)
hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, anyDesc, allowUnknownNonCriticals, resolver, recursionLimit-1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Ensure decrementing recursion limit does not cause underflow

When recursively calling doRejectUnknownFields, ensure that decrementing recursionLimit does not result in negative values, which could bypass the termination condition.

Consider verifying that recursionLimit is greater than zero before the recursive call:

 if fieldMessage.FullName() == anyFullName {
     // Firstly typecheck types.Any to ensure nothing snuck in.
-    hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, anyDesc, allowUnknownNonCriticals, resolver, recursionLimit-1)
+    if recursionLimit > 0 {
+        hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, anyDesc, allowUnknownNonCriticals, resolver, recursionLimit-1)
+    } else {
+        return false, errors.New("recursion limit reached")
+    }
     hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
     if err != nil {
         return hasUnknownNonCriticals, err
     }

Alternatively, ensure the recursion limit check at the beginning of the function handles non-positive values as previously suggested.

Committable suggestion skipped: line range outside the PR's diff.

hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
if err != nil {
return hasUnknownNonCriticals, err
Expand All @@ -131,7 +145,7 @@ func RejectUnknownFields(bz []byte, desc protoreflect.MessageDescriptor, allowUn
fieldBytes = a.Value
}

hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, fieldMessage, allowUnknownNonCriticals, resolver)
hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, fieldMessage, allowUnknownNonCriticals, resolver, recursionLimit-1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consistent recursion limit handling in recursive calls

Similar to the previous comment, ensure that all recursive calls to doRejectUnknownFields properly handle the decrementing of recursionLimit and prevent potential underflow.

Apply this diff:

 hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, fieldMessage, allowUnknownNonCriticals, resolver, recursionLimit-1)
+if err != nil {
+    return hasUnknownNonCriticals, err
+}

Ensure that the recursive function handles recursionLimit <= 0 appropriately, as previously suggested.

Committable suggestion skipped: line range outside the PR's diff.

hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
if err != nil {
return hasUnknownNonCriticals, err
Expand Down
Loading