Skip to content

Commit

Permalink
chore: handle error
Browse files Browse the repository at this point in the history
  • Loading branch information
ppzqh committed Aug 8, 2024
1 parent ca3293f commit fa2958f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 25 deletions.
2 changes: 2 additions & 0 deletions pkg/kerrors/kerrors.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ var (
ErrRPCFinish = &basicError{"rpc call finished"}
// ErrRoute happens when router fail to route this call
ErrRoute = &basicError{"rpc route failed"}
// ErrPayloadValidation happens when payload validation failed
ErrPayloadValidation = &basicError{"payload validation error"}
)

// More detailed error types
Expand Down
4 changes: 2 additions & 2 deletions pkg/remote/codec/default_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ func (c *defaultCodec) DecodeMeta(ctx context.Context, message remote.Message, i
return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("ttheader read payload first 8 byte failed: %s", err.Error()))
}
if c.payloadValidator != nil {
fillRPCInfo(message)
fillRPCInfoBeforeValidate(message)
if vErr := validate(ctx, message, in, c.payloadValidator); vErr != nil {
return vErr
}
Expand Down Expand Up @@ -303,7 +303,7 @@ func (c *defaultCodec) encodeMetaAndPayloadWithPayloadValidator(ctx context.Cont
if c.payloadValidator != nil {
need, value, pErr := c.payloadValidator.Generate(ctx, flatten2DSlice(payload, payloadLen))
if pErr != nil {
return pErr
return kerrors.ErrPayloadValidation.WithCause(fmt.Errorf("generate failed, err=%v", pErr))
}
if need {
if len(value) > maxPayloadChecksumLength {
Expand Down
63 changes: 40 additions & 23 deletions pkg/remote/codec/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"encoding/binary"
"encoding/hex"
"fmt"
"github.com/cloudwego/kitex/pkg/kerrors"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
"github.com/cloudwego/kitex/pkg/remote/transmeta"
Expand All @@ -13,8 +15,8 @@ import (
)

const (
maxPayloadChecksumLength = 1024
PayloadValidatorPrefix = "PV_"
maxPayloadChecksumLength = 1024
)

func getValidatorKey(ctx context.Context, p PayloadValidator) string {
Expand All @@ -39,35 +41,50 @@ func validate(ctx context.Context, message remote.Message, in remote.ByteBuffer,
}
pass, err := p.Validate(ctx, expectedValue, payload)
if err != nil {
return err
return kerrors.ErrPayloadValidation.WithCause(fmt.Errorf("validation failed, err=%v", err))
}
if !pass {
return perrors.NewProtocolErrorWithType(perrors.InvalidData, "not pass")
return kerrors.ErrPayloadValidation.WithCause(fmt.Errorf("validation failed"))
}
return nil
}

func fillRPCInfo(message remote.Message) {
if ri := message.RPCInfo(); ri != nil {
transInfo := message.TransInfo()
intInfo := transInfo.TransIntInfo()
from := rpcinfo.AsMutableEndpointInfo(ri.From())
if from != nil {
if v := intInfo[transmeta.FromService]; v != "" {
from.SetServiceName(v)
}
if v := intInfo[transmeta.FromMethod]; v != "" {
from.SetMethod(v)
}
// fillRPCInfoBeforeValidate reads header and set into the RPCInfo, which allows Validate() to use RPCInfo.
func fillRPCInfoBeforeValidate(message remote.Message) {
if message.RPCRole() != remote.Server {
// only fill when server-side reading the request header
// TODO: client-side can read from the response header
return
}
ri := message.RPCInfo()
if ri == nil {
return
}
transInfo := message.TransInfo()
if transInfo == nil {
return
}
intInfo := transInfo.TransIntInfo()
if intInfo == nil {
return
}
from := rpcinfo.AsMutableEndpointInfo(ri.From())
if from != nil {
if v := intInfo[transmeta.FromService]; v != "" {
from.SetServiceName(v)
}
if v := intInfo[transmeta.FromMethod]; v != "" {
from.SetMethod(v)
}
}
to := rpcinfo.AsMutableEndpointInfo(ri.To())
if to != nil {
// server-side reads "to_method" from ttheader since "method" is set in thrift payload, which has not been unmarshalled
if v := intInfo[transmeta.ToMethod]; v != "" {
to.SetMethod(v)
}
to := rpcinfo.AsMutableEndpointInfo(ri.To())
if to != nil {
if v := intInfo[transmeta.ToMethod]; v != "" {
to.SetMethod(v)
}
if v := intInfo[transmeta.ToService]; v != "" {
to.SetServiceName(v)
}
if v := intInfo[transmeta.ToService]; v != "" {
to.SetServiceName(v)
}
}
}
Expand Down

0 comments on commit fa2958f

Please sign in to comment.