diff --git a/auth/auth_test.go b/auth/auth_test.go index e14ef3c62..d806ea928 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -1,10 +1,16 @@ package auth import ( + "bytes" + "encoding/hex" + "errors" "fmt" "strings" "testing" + btcec "github.com/btcsuite/btcd/btcec/v2" + btcecdsa "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/stakwork/sphinx-tribes/config" "github.com/stretchr/testify/assert" ) @@ -346,3 +352,182 @@ func TestEncodeJwt(t *testing.T) { }) } } + +func TestVerifyAndExtract(t *testing.T) { + + privKey, err := btcec.NewPrivateKey() + assert.NoError(t, err) + + createValidSignature := func(msg []byte) []byte { + signedMsg := append(signedMsgPrefix, msg...) + digest := chainhash.DoubleHashB(signedMsg) + sig, err := btcecdsa.SignCompact(privKey, digest, true) + assert.NoError(t, err) + return sig + } + + expectedPubKeyHex := hex.EncodeToString(privKey.PubKey().SerializeCompressed()) + + tests := []struct { + name string + msg []byte + sig []byte + expectedKey string + expectedValid bool + expectedErr error + }{ + { + name: "Valid signature and message", + msg: []byte("test message"), + sig: createValidSignature([]byte("test message")), + expectedKey: expectedPubKeyHex, + expectedValid: true, + expectedErr: nil, + }, + { + name: "Empty message", + msg: []byte{}, + sig: createValidSignature([]byte{}), + expectedKey: expectedPubKeyHex, + expectedValid: true, + expectedErr: nil, + }, + { + name: "Nil signature", + msg: []byte("test message"), + sig: nil, + expectedKey: "", + expectedValid: false, + expectedErr: errors.New("bad"), + }, + { + name: "Nil message", + msg: nil, + sig: createValidSignature([]byte("test message")), + expectedKey: "", + expectedValid: false, + expectedErr: errors.New("bad"), + }, + { + name: "Both nil inputs", + msg: nil, + sig: nil, + expectedKey: "", + expectedValid: false, + expectedErr: errors.New("bad"), + }, + { + name: "Empty signature", + msg: []byte("test message"), + sig: []byte{}, + expectedKey: "", + expectedValid: false, + expectedErr: errors.New("invalid compact signature size"), + }, + { + name: "Invalid signature format", + msg: []byte("test message"), + sig: []byte{0xFF, 0xFF}, + expectedKey: "", + expectedValid: false, + expectedErr: errors.New("invalid compact signature size"), + }, + { + name: "Corrupted signature", + msg: []byte("test message"), + sig: append(createValidSignature([]byte("test message")), byte(0x00)), + expectedKey: "", + expectedValid: false, + expectedErr: errors.New("invalid compact signature size"), + }, + { + name: "Large message", + msg: bytes.Repeat([]byte("a"), 1000), + sig: createValidSignature(bytes.Repeat([]byte("a"), 1000)), + expectedKey: expectedPubKeyHex, + expectedValid: true, + expectedErr: nil, + }, + { + name: "Special characters in message", + msg: []byte("!@#$%^&*()_+{}:|<>?"), + sig: createValidSignature([]byte("!@#$%^&*()_+{}:|<>?")), + expectedKey: expectedPubKeyHex, + expectedValid: true, + expectedErr: nil, + }, + { + name: "UTF-8 characters in message", + msg: []byte("Hello, 世界"), + sig: createValidSignature([]byte("Hello, 世界")), + expectedKey: expectedPubKeyHex, + expectedValid: true, + expectedErr: nil, + }, + { + name: "Message with null bytes", + msg: []byte("test\x00message"), + sig: createValidSignature([]byte("test\x00message")), + expectedKey: expectedPubKeyHex, + expectedValid: true, + expectedErr: nil, + }, + { + name: "Message with only whitespace", + msg: []byte(" "), + sig: createValidSignature([]byte(" ")), + expectedKey: expectedPubKeyHex, + expectedValid: true, + expectedErr: nil, + }, + { + name: "Maximum length message", + msg: bytes.Repeat([]byte("x"), 1<<20), + sig: createValidSignature(bytes.Repeat([]byte("x"), 1<<20)), + expectedKey: expectedPubKeyHex, + expectedValid: true, + expectedErr: nil, + }, + { + name: "Binary data in message", + msg: []byte{0x00, 0x01, 0x02, 0x03, 0xFF}, + sig: createValidSignature([]byte{0x00, 0x01, 0x02, 0x03, 0xFF}), + expectedKey: expectedPubKeyHex, + expectedValid: true, + expectedErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pubKeyHex, valid, err := VerifyAndExtract(tt.msg, tt.sig) + + if tt.expectedErr != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedErr.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, tt.expectedValid, valid) + + if tt.expectedKey != "" { + assert.Equal(t, tt.expectedKey, pubKeyHex) + } + + if tt.msg != nil && tt.sig != nil && err == nil { + assert.True(t, bytes.HasPrefix(append(signedMsgPrefix, tt.msg...), signedMsgPrefix)) + } + + if valid && err == nil { + _, err := hex.DecodeString(pubKeyHex) + assert.NoError(t, err, "Public key should be valid hex") + + if tt.sig != nil { + assert.Equal(t, 65, len(tt.sig), + "Valid signature should be 65 bytes (64 bytes signature + 1 byte recovery ID)") + } + } + }) + } +}