diff --git a/chunk_payload_data.go b/chunk_payload_data.go index a6e50db9..a5a00064 100644 --- a/chunk_payload_data.go +++ b/chunk_payload_data.go @@ -132,7 +132,7 @@ func (p *chunkPayloadData) unmarshal(raw []byte) error { p.beginningFragment = p.flags&payloadDataBeginingFragmentBitmask != 0 p.endingFragment = p.flags&payloadDataEndingFragmentBitmask != 0 - if len(raw) < payloadDataHeaderSize { + if len(p.raw) < payloadDataHeaderSize { return ErrChunkPayloadSmall } p.tsn = binary.BigEndian.Uint32(p.raw[0:]) diff --git a/error_cause_header.go b/error_cause_header.go index 84ff209d..98c530fb 100644 --- a/error_cause_header.go +++ b/error_cause_header.go @@ -5,6 +5,7 @@ package sctp import ( "encoding/binary" + "errors" ) // errorCauseHeader represents the shared header that is shared by all error causes @@ -18,6 +19,9 @@ const ( errorCauseHeaderLength = 4 ) +// ErrInvalidSCTPChunk is returned when an SCTP chunk is invalid +var ErrInvalidSCTPChunk = errors.New("invalid SCTP chunk") + func (e *errorCauseHeader) marshal() ([]byte, error) { e.len = uint16(len(e.raw)) + uint16(errorCauseHeaderLength) raw := make([]byte, e.len) @@ -31,6 +35,9 @@ func (e *errorCauseHeader) marshal() ([]byte, error) { func (e *errorCauseHeader) unmarshal(raw []byte) error { e.code = errorCauseCode(binary.BigEndian.Uint16(raw[0:])) e.len = binary.BigEndian.Uint16(raw[2:]) + if e.len < errorCauseHeaderLength || int(e.len) > len(raw) { + return ErrInvalidSCTPChunk + } valueLength := e.len - errorCauseHeaderLength e.raw = raw[errorCauseHeaderLength : errorCauseHeaderLength+valueLength] return nil diff --git a/param_requested_hmac_algorithm.go b/param_requested_hmac_algorithm.go index ca9b0147..3e98ea71 100644 --- a/param_requested_hmac_algorithm.go +++ b/param_requested_hmac_algorithm.go @@ -13,7 +13,7 @@ type hmacAlgorithm uint16 const ( hmacResv1 hmacAlgorithm = 0 - hmacSHA128 = 1 + hmacSHA128 hmacAlgorithm = 1 hmacResv2 hmacAlgorithm = 2 hmacSHA256 hmacAlgorithm = 3 ) @@ -21,6 +21,9 @@ const ( // ErrInvalidAlgorithmType is returned if unknown auth algorithm is specified. var ErrInvalidAlgorithmType = errors.New("invalid algorithm type") +// ErrInvalidChunkLength is returned if the chunk length is invalid. +var ErrInvalidChunkLength = errors.New("invalid chunk length") + func (c hmacAlgorithm) String() string { switch c { case hmacResv1: @@ -58,6 +61,9 @@ func (r *paramRequestedHMACAlgorithm) unmarshal(raw []byte) (param, error) { if err != nil { return nil, err } + if len(r.raw)%2 == 1 { + return nil, ErrInvalidChunkLength + } i := 0 for i < len(r.raw) {