diff --git a/certificate-authority/config.yaml b/certificate-authority/config.yaml index d96d07fc8..754a1ec13 100644 --- a/certificate-authority/config.yaml +++ b/certificate-authority/config.yaml @@ -48,6 +48,7 @@ apis: tokenTrustVerification: cacheExpiration: 30s http: + externalAddress: "https://0.0.0.0:9101" address: "0.0.0.0:9101" readTimeout: 8s readHeaderTimeout: 4s @@ -68,10 +69,6 @@ clients: keyFile: "/secrets/private/cert.key" certFile: "/secrets/public/cert.crt" useSystemCAPool: false - bulkWrite: - timeout: 1m0s - throttleTime: 500ms - documentLimit: 1000 cqlDB: table: "signedCertificateRecords" hosts: [] @@ -115,3 +112,5 @@ signer: certFile: "/secrets/public/intermediateca.crt" validFrom: "now-1h" expiresIn: "87600h" + crl: + expiresIn: "10m" diff --git a/certificate-authority/pb/README.md b/certificate-authority/pb/README.md index 0ca7dd372..34c6d0f06 100644 --- a/certificate-authority/pb/README.md +++ b/certificate-authority/pb/README.md @@ -89,8 +89,8 @@ | ----------- | ------------ | ------------- | ------------| | SignIdentityCertificate | [SignCertificateRequest](#certificateauthority-pb-SignCertificateRequest) | [SignCertificateResponse](#certificateauthority-pb-SignCertificateResponse) | SignIdentityCertificate sends a Identity Certificate Signing Request to the certificate authority and obtains a signed certificate. Both in the PEM format. It adds EKU: '1.3.6.1.4.1.44924.1.6' . | | SignCertificate | [SignCertificateRequest](#certificateauthority-pb-SignCertificateRequest) | [SignCertificateResponse](#certificateauthority-pb-SignCertificateResponse) | SignCertificate sends a Certificate Signing Request to the certificate authority and obtains a signed certificate. Both in the PEM format. | -| GetSigningRecords | [GetSigningRecordsRequest](#certificateauthority-pb-GetSigningRecordsRequest) | [SigningRecord](#certificateauthority-pb-SigningRecord) stream | Get signed certficate records. | -| DeleteSigningRecords | [DeleteSigningRecordsRequest](#certificateauthority-pb-DeleteSigningRecordsRequest) | [DeletedSigningRecords](#certificateauthority-pb-DeletedSigningRecords) | Delete signed certficate records. | +| GetSigningRecords | [GetSigningRecordsRequest](#certificateauthority-pb-GetSigningRecordsRequest) | [SigningRecord](#certificateauthority-pb-SigningRecord) stream | Get signed certificate records. | +| DeleteSigningRecords | [DeleteSigningRecordsRequest](#certificateauthority-pb-DeleteSigningRecordsRequest) | [DeletedSigningRecords](#certificateauthority-pb-DeletedSigningRecords) | Revoke signed certficate or delete expired signed certificate records. | @@ -120,6 +120,12 @@ | valid_until_date | [int64](#int64) | | Record valid until date, in unix nanoseconds timestamp format @gotags: bson:"validUntilDate" | +| serial | [string](#string) | | Serial number of the last certificat issued + +@gotags: bson:"serial" | +| issuer_id | [string](#string) | | Issuer id is calculated from the issuer's public certificate, and it is computed as uuid.NewSHA1(uuid.NameSpaceX500, publicKeyRaw) + +@gotags: bson:"issuerId" | @@ -145,7 +151,7 @@ ### DeletedSigningRecords - +Revoke or delete certificates | Field | Type | Label | Description | diff --git a/certificate-authority/pb/doc.html b/certificate-authority/pb/doc.html index 68d51f2f0..e0b12bf80 100644 --- a/certificate-authority/pb/doc.html +++ b/certificate-authority/pb/doc.html @@ -346,14 +346,14 @@

CertificateAuthority

GetSigningRecords GetSigningRecordsRequest SigningRecord stream -

Get signed certficate records.

+

Get signed certificate records.

DeleteSigningRecords DeleteSigningRecordsRequest DeletedSigningRecords -

Delete signed certficate records.

+

Revoke signed certficate or delete expired signed certificate records.

@@ -463,6 +463,24 @@

CredentialStatus

@gotags: bson:"validUntilDate"

+ + serial + string + +

Serial number of the last certificat issued + +@gotags: bson:"serial"

+ + + + issuer_id + string + +

Issuer id is calculated from the issuer's public certificate, and it is computed as uuid.NewSHA1(uuid.NameSpaceX500, publicKeyRaw) + +@gotags: bson:"issuerId"

+ + @@ -502,7 +520,7 @@

DeleteSigningRecord

DeletedSigningRecords

-

+

Revoke or delete certificates

diff --git a/certificate-authority/pb/service.proto b/certificate-authority/pb/service.proto index 7efca0723..439315af4 100644 --- a/certificate-authority/pb/service.proto +++ b/certificate-authority/pb/service.proto @@ -56,7 +56,7 @@ service CertificateAuthority { }; } - // Get signed certficate records. + // Get signed certificate records. rpc GetSigningRecords (GetSigningRecordsRequest) returns (stream SigningRecord) { option (google.api.http) = { get: "/api/v1/signing/records" @@ -66,7 +66,7 @@ service CertificateAuthority { }; }; - // Delete signed certficate records. + // Revoke signed certficate or delete expired signed certificate records. rpc DeleteSigningRecords (DeleteSigningRecordsRequest) returns (DeletedSigningRecords) { option (google.api.http) = { delete: "/api/v1/signing/records" diff --git a/certificate-authority/pb/service.swagger.json b/certificate-authority/pb/service.swagger.json index 2db104ed0..93bb25f90 100644 --- a/certificate-authority/pb/service.swagger.json +++ b/certificate-authority/pb/service.swagger.json @@ -98,7 +98,7 @@ }, "/api/v1/signing/records": { "get": { - "summary": "Get signed certficate records.", + "summary": "Get signed certificate records.", "operationId": "CertificateAuthority_GetSigningRecords", "responses": { "200": { @@ -163,7 +163,7 @@ ] }, "delete": { - "summary": "Delete signed certficate records.", + "summary": "Revoke signed certficate or delete expired signed certificate records.", "operationId": "CertificateAuthority_DeleteSigningRecords", "responses": { "200": { @@ -227,6 +227,16 @@ "format": "int64", "description": "@gotags: bson:\"validUntilDate\"", "title": "Record valid until date, in unix nanoseconds timestamp format" + }, + "serial": { + "type": "string", + "description": "@gotags: bson:\"serial\"", + "title": "Serial number of the last certificat issued" + }, + "issuerId": { + "type": "string", + "description": "@gotags: bson:\"issuerId\"", + "title": "Issuer id is calculated from the issuer's public certificate, and it is computed as uuid.NewSHA1(uuid.NameSpaceX500, publicKeyRaw)" } } }, @@ -238,7 +248,8 @@ "format": "int64", "description": "Number of deleted records." } - } + }, + "title": "Revoke or delete certificates" }, "pbSignCertificateRequest": { "type": "object", diff --git a/certificate-authority/pb/service_grpc.pb.go b/certificate-authority/pb/service_grpc.pb.go index 71dd16956..67bda0a08 100644 --- a/certificate-authority/pb/service_grpc.pb.go +++ b/certificate-authority/pb/service_grpc.pb.go @@ -35,9 +35,9 @@ type CertificateAuthorityClient interface { // SignCertificate sends a Certificate Signing Request to the certificate authority // and obtains a signed certificate. Both in the PEM format. SignCertificate(ctx context.Context, in *SignCertificateRequest, opts ...grpc.CallOption) (*SignCertificateResponse, error) - // Get signed certficate records. + // Get signed certificate records. GetSigningRecords(ctx context.Context, in *GetSigningRecordsRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SigningRecord], error) - // Delete signed certficate records. + // Revoke signed certficate or delete expired signed certificate records. DeleteSigningRecords(ctx context.Context, in *DeleteSigningRecordsRequest, opts ...grpc.CallOption) (*DeletedSigningRecords, error) } @@ -108,9 +108,9 @@ type CertificateAuthorityServer interface { // SignCertificate sends a Certificate Signing Request to the certificate authority // and obtains a signed certificate. Both in the PEM format. SignCertificate(context.Context, *SignCertificateRequest) (*SignCertificateResponse, error) - // Get signed certficate records. + // Get signed certificate records. GetSigningRecords(*GetSigningRecordsRequest, grpc.ServerStreamingServer[SigningRecord]) error - // Delete signed certficate records. + // Revoke signed certficate or delete expired signed certificate records. DeleteSigningRecords(context.Context, *DeleteSigningRecordsRequest) (*DeletedSigningRecords, error) mustEmbedUnimplementedCertificateAuthorityServer() } diff --git a/certificate-authority/pb/signingRecords.go b/certificate-authority/pb/signingRecords.go index b6eccf5cf..c9a2fdd54 100644 --- a/certificate-authority/pb/signingRecords.go +++ b/certificate-authority/pb/signingRecords.go @@ -3,6 +3,7 @@ package pb import ( "errors" "fmt" + "math/big" "sort" "github.com/google/uuid" @@ -17,6 +18,26 @@ func (p SigningRecords) Sort() { }) } +func (credential *CredentialStatus) Validate() error { + if credential.GetDate() == 0 { + return errors.New("empty signing credential date") + } + if credential.GetValidUntilDate() == 0 { + return errors.New("empty signing record credential expiration date") + } + if credential.GetCertificatePem() == "" { + return errors.New("empty signing record credential certificate") + } + serial := big.Int{} + if _, ok := serial.SetString(credential.GetSerial(), 10); !ok { + return errors.New("invalid signing record credential certificate serial number") + } + if _, err := uuid.Parse(credential.GetIssuerId()); err != nil { + return fmt.Errorf("invalid signing record issuer's ID(%v): %w", credential.GetIssuerId(), err) + } + return nil +} + func (signingRecord *SigningRecord) Marshal() ([]byte, error) { return proto.Marshal(signingRecord) } @@ -43,14 +64,9 @@ func (signingRecord *SigningRecord) Validate() error { if signingRecord.GetOwner() == "" { return errors.New("empty signing record owner") } - if signingRecord.GetCredential() != nil && signingRecord.GetCredential().GetDate() == 0 { - return errors.New("empty signing credential date") - } - if signingRecord.GetCredential() != nil && signingRecord.GetCredential().GetValidUntilDate() == 0 { - return errors.New("empty signing record credential expiration date") - } - if signingRecord.GetCredential() != nil && signingRecord.GetCredential().GetCertificatePem() == "" { - return errors.New("empty signing record credential certificate") + credential := signingRecord.GetCredential() + if credential != nil { + return credential.Validate() } return nil } diff --git a/certificate-authority/pb/signingRecords.pb.go b/certificate-authority/pb/signingRecords.pb.go index 8ce4fbb3d..5b8a059a9 100644 --- a/certificate-authority/pb/signingRecords.pb.go +++ b/certificate-authority/pb/signingRecords.pb.go @@ -97,6 +97,10 @@ type CredentialStatus struct { CertificatePem string `protobuf:"bytes,2,opt,name=certificate_pem,json=certificatePem,proto3" json:"certificate_pem,omitempty" bson:"identityCertificate"` // @gotags: bson:"identityCertificate" // Record valid until date, in unix nanoseconds timestamp format ValidUntilDate int64 `protobuf:"varint,3,opt,name=valid_until_date,json=validUntilDate,proto3" json:"valid_until_date,omitempty" bson:"validUntilDate"` // @gotags: bson:"validUntilDate" + // Serial number of the last certificat issued + Serial string `protobuf:"bytes,4,opt,name=serial,proto3" json:"serial,omitempty" bson:"serial"` // @gotags: bson:"serial" + // Issuer id is calculated from the issuer's public certificate, and it is computed as uuid.NewSHA1(uuid.NameSpaceX500, publicKeyRaw) + IssuerId string `protobuf:"bytes,5,opt,name=issuer_id,json=issuerId,proto3" json:"issuer_id,omitempty" bson:"issuerId"` // @gotags: bson:"issuerId" } func (x *CredentialStatus) Reset() { @@ -152,6 +156,20 @@ func (x *CredentialStatus) GetValidUntilDate() int64 { return 0 } +func (x *CredentialStatus) GetSerial() string { + if x != nil { + return x.Serial + } + return "" +} + +func (x *CredentialStatus) GetIssuerId() string { + if x != nil { + return x.IssuerId + } + return "" +} + type SigningRecord struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -311,6 +329,7 @@ func (x *DeleteSigningRecordsRequest) GetDeviceIdFilter() []string { return nil } +// Revoke or delete certificates type DeletedSigningRecords struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -375,45 +394,48 @@ var file_certificate_authority_pb_signingRecords_proto_rawDesc = []byte{ 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x4e, 0x61, 0x6d, 0x65, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x12, 0x28, 0x0a, 0x10, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x5f, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0e, 0x64, 0x65, 0x76, 0x69, - 0x63, 0x65, 0x49, 0x64, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x22, 0x79, 0x0a, 0x10, 0x43, 0x72, - 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x12, - 0x0a, 0x04, 0x64, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x64, 0x61, - 0x74, 0x65, 0x12, 0x27, 0x0a, 0x0f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, - 0x65, 0x5f, 0x70, 0x65, 0x6d, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x63, 0x65, 0x72, - 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x50, 0x65, 0x6d, 0x12, 0x28, 0x0a, 0x10, 0x76, - 0x61, 0x6c, 0x69, 0x64, 0x5f, 0x75, 0x6e, 0x74, 0x69, 0x6c, 0x5f, 0x64, 0x61, 0x74, 0x65, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x55, 0x6e, 0x74, 0x69, - 0x6c, 0x44, 0x61, 0x74, 0x65, 0x22, 0x82, 0x02, 0x0a, 0x0d, 0x53, 0x69, 0x67, 0x6e, 0x69, 0x6e, - 0x67, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x6f, 0x77, 0x6e, 0x65, 0x72, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6f, 0x77, 0x6e, 0x65, 0x72, 0x12, 0x1f, 0x0a, - 0x0b, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1b, - 0x0a, 0x09, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x08, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x70, - 0x75, 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x23, 0x0a, 0x0d, 0x63, 0x72, - 0x65, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x64, 0x61, 0x74, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, - 0x03, 0x52, 0x0c, 0x63, 0x72, 0x65, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x44, 0x61, 0x74, 0x65, 0x12, - 0x49, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x18, 0x07, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, - 0x65, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x79, 0x2e, 0x70, 0x62, 0x2e, 0x43, 0x72, - 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x0a, - 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x22, 0x64, 0x0a, 0x1b, 0x44, 0x65, - 0x6c, 0x65, 0x74, 0x65, 0x53, 0x69, 0x67, 0x6e, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x63, 0x6f, 0x72, - 0x64, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x64, 0x5f, - 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x69, 0x64, - 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x12, 0x28, 0x0a, 0x10, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, - 0x5f, 0x69, 0x64, 0x5f, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x0e, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, - 0x22, 0x2d, 0x0a, 0x15, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x53, 0x69, 0x67, 0x6e, 0x69, - 0x6e, 0x67, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x63, 0x6f, 0x75, - 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x42, - 0x38, 0x5a, 0x36, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6c, - 0x67, 0x64, 0x2d, 0x64, 0x65, 0x76, 0x2f, 0x68, 0x75, 0x62, 0x2f, 0x76, 0x32, 0x2f, 0x63, 0x65, - 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x2d, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, - 0x69, 0x74, 0x79, 0x2f, 0x70, 0x62, 0x3b, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x33, + 0x63, 0x65, 0x49, 0x64, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x22, 0xae, 0x01, 0x0a, 0x10, 0x43, + 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, + 0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x64, + 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x0f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, + 0x74, 0x65, 0x5f, 0x70, 0x65, 0x6d, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x63, 0x65, + 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x50, 0x65, 0x6d, 0x12, 0x28, 0x0a, 0x10, + 0x76, 0x61, 0x6c, 0x69, 0x64, 0x5f, 0x75, 0x6e, 0x74, 0x69, 0x6c, 0x5f, 0x64, 0x61, 0x74, 0x65, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x55, 0x6e, 0x74, + 0x69, 0x6c, 0x44, 0x61, 0x74, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x65, 0x72, 0x69, 0x61, 0x6c, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x1b, + 0x0a, 0x09, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x08, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x49, 0x64, 0x22, 0x82, 0x02, 0x0a, 0x0d, + 0x53, 0x69, 0x67, 0x6e, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x0e, 0x0a, + 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x14, 0x0a, + 0x05, 0x6f, 0x77, 0x6e, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6f, 0x77, + 0x6e, 0x65, 0x72, 0x12, 0x1f, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x5f, 0x6e, 0x61, + 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, + 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, + 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x49, + 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x18, + 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, + 0x12, 0x23, 0x0a, 0x0d, 0x63, 0x72, 0x65, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x64, 0x61, 0x74, + 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0c, 0x63, 0x72, 0x65, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x44, 0x61, 0x74, 0x65, 0x12, 0x49, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, + 0x69, 0x61, 0x6c, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x63, 0x65, 0x72, 0x74, + 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x79, + 0x2e, 0x70, 0x62, 0x2e, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x52, 0x0a, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, + 0x22, 0x64, 0x0a, 0x1b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x69, 0x67, 0x6e, 0x69, 0x6e, + 0x67, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, + 0x1b, 0x0a, 0x09, 0x69, 0x64, 0x5f, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x08, 0x69, 0x64, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x12, 0x28, 0x0a, 0x10, + 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x5f, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, + 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0e, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, + 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x22, 0x2d, 0x0a, 0x15, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, + 0x64, 0x53, 0x69, 0x67, 0x6e, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x12, + 0x14, 0x0a, 0x05, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, + 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x42, 0x38, 0x5a, 0x36, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, + 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6c, 0x67, 0x64, 0x2d, 0x64, 0x65, 0x76, 0x2f, 0x68, 0x75, 0x62, + 0x2f, 0x76, 0x32, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x2d, + 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x79, 0x2f, 0x70, 0x62, 0x3b, 0x70, 0x62, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/certificate-authority/pb/signingRecords.proto b/certificate-authority/pb/signingRecords.proto index 79166bfaf..8981cac10 100644 --- a/certificate-authority/pb/signingRecords.proto +++ b/certificate-authority/pb/signingRecords.proto @@ -20,6 +20,10 @@ message CredentialStatus { string certificate_pem = 2; // @gotags: bson:"identityCertificate" // Record valid until date, in unix nanoseconds timestamp format int64 valid_until_date = 3; // @gotags: bson:"validUntilDate" + // Serial number of the last certificat issued + string serial = 4; // @gotags: bson:"serial" + // Issuer id is calculated from the issuer's public certificate, and it is computed as uuid.NewSHA1(uuid.NameSpaceX500, publicKeyRaw) + string issuer_id = 5; // @gotags: bson:"issuerId" } message SigningRecord { @@ -46,7 +50,8 @@ message DeleteSigningRecordsRequest { repeated string device_id_filter = 2; } +// Revoke or delete certificates message DeletedSigningRecords { // Number of deleted records. int64 count = 1; -} \ No newline at end of file +} diff --git a/certificate-authority/pb/signingRecords_test.go b/certificate-authority/pb/signingRecords_test.go new file mode 100644 index 000000000..42e9e5d16 --- /dev/null +++ b/certificate-authority/pb/signingRecords_test.go @@ -0,0 +1,189 @@ +package pb_test + +import ( + "testing" + + "github.com/google/uuid" + "github.com/plgd-dev/hub/v2/certificate-authority/pb" + "github.com/stretchr/testify/require" +) + +func TestCredentialStatusValidate(t *testing.T) { + tests := []struct { + name string + input *pb.CredentialStatus + wantErr bool + }{ + { + name: "Valid credential", + input: &pb.CredentialStatus{ + Date: 1659462400000000000, + ValidUntilDate: 1669462400000000000, + CertificatePem: "valid-cert", + Serial: "1234567890", + IssuerId: uuid.New().String(), + }, + wantErr: false, + }, + { + name: "Missing signing credential date", + input: &pb.CredentialStatus{ + Date: 0, + ValidUntilDate: 1669462400000000000, + CertificatePem: "valid-cert", + Serial: "1234567890", + IssuerId: uuid.New().String(), + }, + wantErr: true, + }, + { + name: "Missing signing credential expiration date", + input: &pb.CredentialStatus{ + Date: 1659462400000000000, + ValidUntilDate: 0, + CertificatePem: "valid-cert", + Serial: "1234567890", + IssuerId: uuid.New().String(), + }, + wantErr: true, + }, + { + name: "Missing signing record credential certificate", + input: &pb.CredentialStatus{ + Date: 1659462400000000000, + ValidUntilDate: 1669462400000000000, + CertificatePem: "", + Serial: "1234567890", + IssuerId: uuid.New().String(), + }, + wantErr: true, + }, + { + name: "Invalid certificate serial number", + input: &pb.CredentialStatus{ + Date: 1659462400000000000, + ValidUntilDate: 1669462400000000000, + CertificatePem: "valid-cert", + Serial: "invalid-serial", + IssuerId: uuid.New().String(), + }, + wantErr: true, + }, + { + name: "Invalid issuer ID", + input: &pb.CredentialStatus{ + Date: 1659462400000000000, + ValidUntilDate: 1669462400000000000, + CertificatePem: "valid-cert", + Serial: "1234567890", + IssuerId: "invalid-uuid", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.input.Validate() + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} + +func TestSigningRecordValidate(t *testing.T) { + validCredential := &pb.CredentialStatus{ + Date: 1659462400000000000, + ValidUntilDate: 1669462400000000000, + CertificatePem: "valid-cert", + Serial: "1234567890", + IssuerId: uuid.New().String(), + } + + tests := []struct { + name string + input *pb.SigningRecord + wantErr bool + }{ + { + name: "Valid signing record", + input: &pb.SigningRecord{ + Id: uuid.New().String(), + Owner: "owner", + CommonName: "common_name", + DeviceId: uuid.New().String(), + Credential: validCredential, + }, + wantErr: false, + }, + { + name: "Missing signing record ID", + input: &pb.SigningRecord{ + Id: "", + Owner: "owner", + CommonName: "common_name", + DeviceId: uuid.New().String(), + Credential: validCredential, + }, + wantErr: true, + }, + { + name: "Invalid signing record ID", + input: &pb.SigningRecord{ + Id: "invalid-uuid", + Owner: "owner", + CommonName: "common_name", + DeviceId: uuid.New().String(), + Credential: validCredential, + }, + wantErr: true, + }, + { + name: "Invalid device ID", + input: &pb.SigningRecord{ + Id: uuid.New().String(), + Owner: "owner", + CommonName: "common_name", + DeviceId: "invalid-uuid", + Credential: validCredential, + }, + wantErr: true, + }, + { + name: "Missing common name", + input: &pb.SigningRecord{ + Id: uuid.New().String(), + Owner: "owner", + CommonName: "", + DeviceId: uuid.New().String(), + Credential: validCredential, + }, + wantErr: true, + }, + { + name: "Missing owner", + input: &pb.SigningRecord{ + Id: uuid.New().String(), + Owner: "", + CommonName: "common_name", + DeviceId: uuid.New().String(), + Credential: validCredential, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.input.Validate() + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} diff --git a/certificate-authority/service/cleanDatabase_test.go b/certificate-authority/service/cleanDatabase_test.go index afe8d2c87..fbba1cbb7 100644 --- a/certificate-authority/service/cleanDatabase_test.go +++ b/certificate-authority/service/cleanDatabase_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "math/big" "testing" "time" @@ -17,7 +18,7 @@ import ( "github.com/plgd-dev/hub/v2/identity-store/events" "github.com/plgd-dev/hub/v2/pkg/fsnotify" "github.com/plgd-dev/hub/v2/pkg/log" - kitNetGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" + pkgGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" "github.com/plgd-dev/hub/v2/test/config" testService "github.com/plgd-dev/hub/v2/test/service" "github.com/stretchr/testify/require" @@ -48,6 +49,8 @@ func TestCertificateAuthorityServerCleanUpSigningRecords(t *testing.T) { CertificatePem: "certificate1", Date: date.UnixNano(), ValidUntilDate: date.UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, } @@ -64,7 +67,7 @@ func TestCertificateAuthorityServerCleanUpSigningRecords(t *testing.T) { }() ch := new(inprocgrpc.Channel) - ca, err := grpc.NewCertificateAuthorityServer(ownerClaim, config.HubID(), test.MakeConfig(t).Signer, storeDB, fileWatcher, logger) + ca, err := grpc.NewCertificateAuthorityServer(ownerClaim, config.HubID(), "https://"+config.CERTIFICATE_AUTHORITY_HTTP_HOST, test.MakeConfig(t).Signer, storeDB, fileWatcher, logger) require.NoError(t, err) defer ca.Close() @@ -73,7 +76,7 @@ func TestCertificateAuthorityServerCleanUpSigningRecords(t *testing.T) { token := config.CreateJwtToken(t, jwt.MapClaims{ ownerClaim: owner, }) - ctx := kitNetGrpc.CtxWithToken(context.Background(), token) + ctx := pkgGrpc.CtxWithToken(context.Background(), token) client, err := grpcClient.GetSigningRecords(ctx, &pb.GetSigningRecordsRequest{}) require.NoError(t, err) var got pb.SigningRecords diff --git a/certificate-authority/service/config.go b/certificate-authority/service/config.go index 7a0852e2c..e08518c3c 100644 --- a/certificate-authority/service/config.go +++ b/certificate-authority/service/config.go @@ -3,6 +3,7 @@ package service import ( "fmt" "net" + "net/url" "time" "github.com/go-co-op/gocron/v2" @@ -40,7 +41,7 @@ func (c *Config) Validate() error { return fmt.Errorf("hubID('%v') - %w", c.HubID, err) } - _, err := grpcService.NewSigner(c.APIs.GRPC.Authorization.OwnerClaim, c.HubID, c.Signer) + _, err := grpcService.NewSigner(c.APIs.GRPC.Authorization.OwnerClaim, c.HubID, c.APIs.HTTP.ExternalAddress, c.Signer) if err != nil { return fmt.Errorf("signer('%v') - %w", c.Signer, err) } @@ -65,11 +66,15 @@ func (c *APIsConfig) Validate() error { } type HTTPConfig struct { - Addr string `yaml:"address" json:"address"` - Server httpServer.Config `yaml:",inline" json:",inline"` + ExternalAddress string `yaml:"externalAddress" json:"externalAddress"` + Addr string `yaml:"address" json:"address"` + Server httpServer.Config `yaml:",inline" json:",inline"` } func (c *HTTPConfig) Validate() error { + if _, err := url.ParseRequestURI(c.ExternalAddress); err != nil { + return fmt.Errorf("externalAddress('%v') invalid", c.ExternalAddress) + } if _, err := net.ResolveTCPAddr("tcp", c.Addr); err != nil { return fmt.Errorf("address('%v') - %w", c.Addr, err) } diff --git a/certificate-authority/service/config_test.go b/certificate-authority/service/config_test.go index 58427612d..6b96edb37 100644 --- a/certificate-authority/service/config_test.go +++ b/certificate-authority/service/config_test.go @@ -147,6 +147,57 @@ func TestConfigValidate(t *testing.T) { } } +func TestHTTPConfigValidate(t *testing.T) { + type args struct { + cfg service.HTTPConfig + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "valid", + args: args{ + cfg: test.MakeHTTPConfig(), + }, + }, + { + name: "invalid external address", + args: args{ + cfg: func() service.HTTPConfig { + cfg := test.MakeHTTPConfig() + cfg.ExternalAddress = "invalid" + return cfg + }(), + }, + wantErr: true, + }, + { + name: "invalid address", + args: args{ + cfg: func() service.HTTPConfig { + cfg := test.MakeHTTPConfig() + cfg.Addr = "invalid" + return cfg + }(), + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.args.cfg.Validate() + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} + func TestStorageConfigValidate(t *testing.T) { type args struct { cfg service.StorageConfig diff --git a/certificate-authority/service/grpc/config.go b/certificate-authority/service/grpc/config.go index 325b98da2..4871b7afa 100644 --- a/certificate-authority/service/grpc/config.go +++ b/certificate-authority/service/grpc/config.go @@ -13,12 +13,31 @@ import ( type Config = server.Config +type CRLConfig struct { + ExpiresIn time.Duration `yaml:"expiresIn" json:"expiresIn"` + + // needed by tests with cqldb - remove once support for CRL + // is implemented in cqldb or cqldb is removed + Enabled bool `yaml:"-" json:"-"` +} + +func (c *CRLConfig) Validate() error { + if !c.Enabled { + return nil + } + if c.ExpiresIn <= time.Minute { + return fmt.Errorf("expiresIn('%v') - less than %v", c.ExpiresIn, time.Minute) + } + return nil +} + type SignerConfig struct { CAPool interface{} `yaml:"caPool" json:"caPool" description:"file path to the root certificates in PEM format"` KeyFile urischeme.URIScheme `yaml:"keyFile" json:"keyFile" description:"file name of CA private key in PEM format"` CertFile urischeme.URIScheme `yaml:"certFile" json:"certFile" description:"file name of CA certificate in PEM format"` ValidFrom string `yaml:"validFrom" json:"validFrom" description:"format https://github.com/karrick/tparse"` ExpiresIn time.Duration `yaml:"expiresIn" json:"expiresIn"` + CRL CRLConfig `yaml:"crl" json:"crl"` caPoolArray []urischeme.URIScheme `yaml:"-" json:"-"` } @@ -36,13 +55,15 @@ func (c *SignerConfig) Validate() error { return fmt.Errorf("keyFile('%v')", c.KeyFile) } if c.ExpiresIn <= 0 { - return fmt.Errorf("expiresIn('%v')", c.KeyFile) + return fmt.Errorf("expiresIn('%v')", c.ExpiresIn) } _, err := tparse.ParseNow(time.RFC3339, c.ValidFrom) if err != nil { - return fmt.Errorf("validFrom('%v')", c.ValidFrom) + return fmt.Errorf("validFrom('%v').%w", c.ValidFrom, err) + } + if err := c.CRL.Validate(); err != nil { + return fmt.Errorf("crl.%w", err) } - return nil } diff --git a/certificate-authority/service/grpc/config_test.go b/certificate-authority/service/grpc/config_test.go new file mode 100644 index 000000000..560a62434 --- /dev/null +++ b/certificate-authority/service/grpc/config_test.go @@ -0,0 +1,158 @@ +package grpc_test + +import ( + "testing" + "time" + + "github.com/plgd-dev/hub/v2/certificate-authority/service/grpc" + "github.com/plgd-dev/hub/v2/pkg/config/property/urischeme" + "github.com/stretchr/testify/require" +) + +func TestCRLConfigValidate(t *testing.T) { + tests := []struct { + name string + input grpc.CRLConfig + wantErr bool + }{ + { + name: "Disabled CRLConfig", + input: grpc.CRLConfig{ + Enabled: false, + }, + }, + { + name: "Enabled CRLConfig with valid ExternalAddress and ExpiresIn", + input: grpc.CRLConfig{ + Enabled: true, + ExpiresIn: time.Hour, + }, + }, + { + name: "Enabled CRLConfig with ExpiresIn less than 1 minute", + input: grpc.CRLConfig{ + Enabled: true, + ExpiresIn: 30 * time.Second, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.input.Validate() + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} + +func TestSignerConfigValidate(t *testing.T) { + crl := grpc.CRLConfig{ + Enabled: true, + ExpiresIn: time.Hour, + } + tests := []struct { + name string + input grpc.SignerConfig + wantErr bool + }{ + { + name: "Valid SignerConfig", + input: grpc.SignerConfig{ + CAPool: []string{"ca1.pem", "ca2.pem"}, + KeyFile: urischeme.URIScheme("key.pem"), + CertFile: urischeme.URIScheme("cert.pem"), + ValidFrom: time.Now().Format(time.RFC3339), + ExpiresIn: time.Hour * 24, + CRL: crl, + }, + }, + { + name: "Invalid CA Pool", + input: grpc.SignerConfig{ + CAPool: 42, + KeyFile: urischeme.URIScheme("key.pem"), + CertFile: urischeme.URIScheme("cert.pem"), + ValidFrom: time.Now().Format(time.RFC3339), + ExpiresIn: time.Hour * 24, + CRL: crl, + }, + wantErr: true, + }, + { + name: "Empty CertFile", + input: grpc.SignerConfig{ + CAPool: []string{"ca1.pem"}, + KeyFile: urischeme.URIScheme("key.pem"), + CertFile: "", + ValidFrom: time.Now().Format(time.RFC3339), + ExpiresIn: time.Hour * 24, + CRL: crl, + }, + wantErr: true, + }, + { + name: "Empty KeyFile", + input: grpc.SignerConfig{ + CAPool: []string{"ca1.pem"}, + KeyFile: "", + CertFile: urischeme.URIScheme("cert.pem"), + ValidFrom: time.Now().Format(time.RFC3339), + ExpiresIn: time.Hour * 24, + CRL: crl, + }, + wantErr: true, + }, + { + name: "Invalid ExpiresIn", + input: grpc.SignerConfig{ + CAPool: []string{"ca1.pem", "ca2.pem"}, + KeyFile: urischeme.URIScheme("key.pem"), + CertFile: urischeme.URIScheme("cert.pem"), + ValidFrom: time.Now().Format(time.RFC3339), + ExpiresIn: -1, + CRL: crl, + }, + wantErr: true, + }, + { + name: "Invalid ValidFrom format", + input: grpc.SignerConfig{ + CAPool: []string{"ca1.pem"}, + KeyFile: urischeme.URIScheme("key.pem"), + CertFile: urischeme.URIScheme("cert.pem"), + ValidFrom: "invalid-date", + ExpiresIn: time.Hour * 24, + CRL: crl, + }, + wantErr: true, + }, + { + name: "Invalid CRL", + input: grpc.SignerConfig{ + CAPool: []string{"ca1.pem", "ca2.pem"}, + KeyFile: urischeme.URIScheme("key.pem"), + CertFile: urischeme.URIScheme("cert.pem"), + ValidFrom: time.Now().Format(time.RFC3339), + ExpiresIn: time.Hour * 24, + CRL: grpc.CRLConfig{ + Enabled: true, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.input.Validate() + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} diff --git a/certificate-authority/service/grpc/deleteSigningRecords.go b/certificate-authority/service/grpc/deleteSigningRecords.go index 53456c26a..42637619d 100644 --- a/certificate-authority/service/grpc/deleteSigningRecords.go +++ b/certificate-authority/service/grpc/deleteSigningRecords.go @@ -8,16 +8,20 @@ import ( "google.golang.org/grpc/status" ) +func errDeleteSigningRecords(err error) error { + return status.Errorf(codes.InvalidArgument, "cannot delete signing records: %v", err) +} + func (s *CertificateAuthorityServer) DeleteSigningRecords(ctx context.Context, req *pb.DeleteSigningRecordsRequest) (*pb.DeletedSigningRecords, error) { owner, err := ownerToUUID(ctx, s.ownerClaim) if err != nil { - return nil, s.logger.LogAndReturnError(status.Errorf(codes.InvalidArgument, "cannot delete signing records: %v", err)) + return nil, s.logger.LogAndReturnError(errDeleteSigningRecords(err)) } - n, err := s.store.DeleteSigningRecords(ctx, owner, req) + count, err := s.store.RevokeSigningRecords(ctx, owner, req) if err != nil { - return nil, s.logger.LogAndReturnError(status.Errorf(codes.InvalidArgument, "cannot delete signing records: %v", err)) + return nil, s.logger.LogAndReturnError(errDeleteSigningRecords(err)) } return &pb.DeletedSigningRecords{ - Count: n, + Count: count, }, nil } diff --git a/certificate-authority/service/grpc/deleteSigningRecords_test.go b/certificate-authority/service/grpc/deleteSigningRecords_test.go index 573fdff84..f59b06c4b 100644 --- a/certificate-authority/service/grpc/deleteSigningRecords_test.go +++ b/certificate-authority/service/grpc/deleteSigningRecords_test.go @@ -2,6 +2,7 @@ package grpc_test import ( "context" + "math/big" "testing" "github.com/fullstorydev/grpchan/inprocgrpc" @@ -13,7 +14,7 @@ import ( "github.com/plgd-dev/hub/v2/identity-store/events" "github.com/plgd-dev/hub/v2/pkg/fsnotify" "github.com/plgd-dev/hub/v2/pkg/log" - kitNetGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" + pkgGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" "github.com/plgd-dev/hub/v2/test/config" "github.com/stretchr/testify/require" ) @@ -21,6 +22,10 @@ import ( func TestCertificateAuthorityServerDeleteSigningRecords(t *testing.T) { owner := events.OwnerToUUID("owner") const ownerClaim = "sub" + token := config.CreateJwtToken(t, jwt.MapClaims{ + ownerClaim: owner, + }) + ctx := pkgGrpc.CtxWithToken(context.Background(), token) r := &store.SigningRecord{ Id: "9d017fad-2961-4fcc-94a9-1e1291a88ffc", Owner: owner, @@ -31,10 +36,13 @@ func TestCertificateAuthorityServerDeleteSigningRecords(t *testing.T) { CertificatePem: "certificate1", Date: constDate().UnixNano(), ValidUntilDate: constDate().UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, } type args struct { req *pb.DeleteSigningRecordsRequest + ctx context.Context } tests := []struct { name string @@ -42,14 +50,24 @@ func TestCertificateAuthorityServerDeleteSigningRecords(t *testing.T) { want int64 wantErr bool }{ + { + name: "missing token with ownerClaim in ctx", + args: args{ + req: &pb.DeleteSigningRecordsRequest{ + IdFilter: []string{r.GetId()}, + }, + ctx: context.Background(), + }, + wantErr: true, + }, { name: "invalidID", args: args{ req: &pb.DeleteSigningRecordsRequest{ IdFilter: []string{"invalidID"}, }, + ctx: ctx, }, - wantErr: true, }, { name: "valid", @@ -57,6 +75,7 @@ func TestCertificateAuthorityServerDeleteSigningRecords(t *testing.T) { req: &pb.DeleteSigningRecordsRequest{ IdFilter: []string{r.GetId()}, }, + ctx: ctx, }, want: 1, }, @@ -78,20 +97,20 @@ func TestCertificateAuthorityServerDeleteSigningRecords(t *testing.T) { }() ch := new(inprocgrpc.Channel) - ca, err := grpc.NewCertificateAuthorityServer(ownerClaim, config.HubID(), test.MakeConfig(t).Signer, store, fileWatcher, logger) + ca, err := grpc.NewCertificateAuthorityServer(ownerClaim, config.HubID(), "https://"+config.CERTIFICATE_AUTHORITY_HTTP_HOST, test.MakeConfig(t).Signer, store, fileWatcher, logger) require.NoError(t, err) defer ca.Close() pb.RegisterCertificateAuthorityServer(ch, ca) grpcClient := pb.NewCertificateAuthorityClient(ch) - token := config.CreateJwtToken(t, jwt.MapClaims{ - ownerClaim: owner, - }) - ctx := kitNetGrpc.CtxWithToken(context.Background(), token) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := grpcClient.DeleteSigningRecords(ctx, tt.args.req) + got, err := grpcClient.DeleteSigningRecords(tt.args.ctx, tt.args.req) + if tt.wantErr { + require.Error(t, err) + return + } require.NoError(t, err) require.Equal(t, tt.want, got.GetCount()) }) diff --git a/certificate-authority/service/grpc/getSigningRecords.go b/certificate-authority/service/grpc/getSigningRecords.go index 9236524a9..875954746 100644 --- a/certificate-authority/service/grpc/getSigningRecords.go +++ b/certificate-authority/service/grpc/getSigningRecords.go @@ -1,10 +1,7 @@ package grpc import ( - "context" - "github.com/plgd-dev/hub/v2/certificate-authority/pb" - "github.com/plgd-dev/hub/v2/certificate-authority/store" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -14,16 +11,11 @@ func (s *CertificateAuthorityServer) GetSigningRecords(req *pb.GetSigningRecords if err != nil { return s.logger.LogAndReturnError(status.Errorf(codes.InvalidArgument, "cannot get signing records: %v", err)) } - err = s.store.LoadSigningRecords(srv.Context(), owner, req, func(ctx context.Context, iter store.SigningRecordIter) (err error) { - for { - var sub pb.SigningRecord - if ok := iter.Next(ctx, &sub); !ok { - return iter.Err() - } - if err = srv.Send(&sub); err != nil { - return err - } + err = s.store.LoadSigningRecords(srv.Context(), owner, req, func(sr *pb.SigningRecord) (err error) { + if err = srv.Send(sr); err != nil { + return err } + return nil }) if err != nil { return s.logger.LogAndReturnError(status.Errorf(codes.InvalidArgument, "cannot get signing records: %v", err)) diff --git a/certificate-authority/service/grpc/getSigningRecords_test.go b/certificate-authority/service/grpc/getSigningRecords_test.go index c7f1d859a..1b5f0abae 100644 --- a/certificate-authority/service/grpc/getSigningRecords_test.go +++ b/certificate-authority/service/grpc/getSigningRecords_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "io" + "math/big" "testing" "time" @@ -16,7 +17,7 @@ import ( "github.com/plgd-dev/hub/v2/identity-store/events" "github.com/plgd-dev/hub/v2/pkg/fsnotify" "github.com/plgd-dev/hub/v2/pkg/log" - kitNetGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" + pkgGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" hubTest "github.com/plgd-dev/hub/v2/test" "github.com/plgd-dev/hub/v2/test/config" "github.com/stretchr/testify/require" @@ -39,6 +40,8 @@ func TestCertificateAuthorityServerGetSigningRecords(t *testing.T) { CertificatePem: "certificate1", Date: constDate().UnixNano(), ValidUntilDate: constDate().UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, } type args struct { @@ -86,7 +89,7 @@ func TestCertificateAuthorityServerGetSigningRecords(t *testing.T) { }() ch := new(inprocgrpc.Channel) - ca, err := grpc.NewCertificateAuthorityServer(ownerClaim, config.HubID(), test.MakeConfig(t).Signer, store, fileWatcher, logger) + ca, err := grpc.NewCertificateAuthorityServer(ownerClaim, config.HubID(), "https://"+config.CERTIFICATE_AUTHORITY_HTTP_HOST, test.MakeConfig(t).Signer, store, fileWatcher, logger) require.NoError(t, err) defer ca.Close() @@ -95,7 +98,7 @@ func TestCertificateAuthorityServerGetSigningRecords(t *testing.T) { token := config.CreateJwtToken(t, jwt.MapClaims{ ownerClaim: owner, }) - ctx := kitNetGrpc.CtxWithToken(context.Background(), token) + ctx := pkgGrpc.CtxWithToken(context.Background(), token) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/certificate-authority/service/grpc/server.go b/certificate-authority/service/grpc/server.go index bddabf426..20c27930a 100644 --- a/certificate-authority/service/grpc/server.go +++ b/certificate-authority/service/grpc/server.go @@ -23,18 +23,20 @@ type CertificateAuthorityServer struct { hubID string fileWatcher *fsnotify.Watcher onFileChangeFunc func(event fsnotify.Event) + crlServerAddress string signer atomic.Pointer[Signer] } -func NewCertificateAuthorityServer(ownerClaim string, hubID string, signerConfig SignerConfig, store store.Store, fileWatcher *fsnotify.Watcher, logger log.Logger) (*CertificateAuthorityServer, error) { +func NewCertificateAuthorityServer(ownerClaim, hubID, crlServerAddress string, signerConfig SignerConfig, store store.Store, fileWatcher *fsnotify.Watcher, logger log.Logger) (*CertificateAuthorityServer, error) { s := &CertificateAuthorityServer{ - signerConfig: signerConfig, - logger: logger, - ownerClaim: ownerClaim, - store: store, - hubID: hubID, - fileWatcher: fileWatcher, + signerConfig: signerConfig, + logger: logger, + ownerClaim: ownerClaim, + store: store, + hubID: hubID, + fileWatcher: fileWatcher, + crlServerAddress: crlServerAddress, } _, err := s.load() @@ -100,7 +102,7 @@ func (s *CertificateAuthorityServer) Close() { } func (s *CertificateAuthorityServer) load() (bool, error) { - signer, err := NewSigner(s.ownerClaim, s.hubID, s.signerConfig) + signer, err := NewSigner(s.ownerClaim, s.hubID, s.crlServerAddress, s.signerConfig) if err != nil { return false, fmt.Errorf("cannot create signer: %w", err) } diff --git a/certificate-authority/service/grpc/server_test.go b/certificate-authority/service/grpc/server_test.go index 34a21297b..2baf46be4 100644 --- a/certificate-authority/service/grpc/server_test.go +++ b/certificate-authority/service/grpc/server_test.go @@ -43,8 +43,6 @@ func TestReloadCerts(t *testing.T) { store, closeStore := test.NewMongoStore(t) defer closeStore() - logCfg := log.MakeDefaultConfig() - logCfg.Level = log.DebugLevel logger := log.NewLogger(log.MakeDefaultConfig()) fileWatcher, err := fsnotify.NewWatcher(logger) @@ -85,7 +83,7 @@ func TestReloadCerts(t *testing.T) { err = s.Validate() require.NoError(t, err) - ca, err := grpc.NewCertificateAuthorityServer(ownerClaim, config.HubID(), s, store, fileWatcher, logger) + ca, err := grpc.NewCertificateAuthorityServer(ownerClaim, config.HubID(), "https://"+config.CERTIFICATE_AUTHORITY_HTTP_HOST, s, store, fileWatcher, logger) require.NoError(t, err) defer ca.Close() diff --git a/certificate-authority/service/grpc/signCertificate.go b/certificate-authority/service/grpc/signCertificate.go index 45252f07c..8ac5dafe5 100644 --- a/certificate-authority/service/grpc/signCertificate.go +++ b/certificate-authority/service/grpc/signCertificate.go @@ -26,35 +26,7 @@ func (s *CertificateAuthorityServer) validateRequest(csr []byte) error { return nil } -func (s *CertificateAuthorityServer) updateSigningIdentityCertificateRecord(ctx context.Context, updateSigningRecord *pb.SigningRecord) error { - var found bool - now := time.Now().UnixNano() - err := s.store.LoadSigningRecords(ctx, updateSigningRecord.GetOwner(), &store.SigningRecordsQuery{ - CommonNameFilter: []string{updateSigningRecord.GetCommonName()}, - }, func(ctx context.Context, iter store.SigningRecordIter) (err error) { - for { - var signingRecord pb.SigningRecord - ok := iter.Next(ctx, &signingRecord) - if !ok { - break - } - if updateSigningRecord.GetPublicKey() != signingRecord.GetPublicKey() && signingRecord.GetCredential().GetValidUntilDate() > now { - return fmt.Errorf("common name %v with different public key fingerprint exist", signingRecord.GetCommonName()) - } - found = true - } - return nil - }) - if err != nil { - return err - } - if found { - return s.store.UpdateSigningRecord(ctx, updateSigningRecord) - } - return s.store.CreateSigningRecord(ctx, updateSigningRecord) -} - -func toSigningRecord(owner string, template *x509.Certificate) (*pb.SigningRecord, error) { +func toSigningRecord(owner, issuerID string, template *x509.Certificate) (*pb.SigningRecord, error) { publicKeyRaw, err := x509.MarshalPKIXPublicKey(template.PublicKey) if err != nil { return nil, err @@ -82,19 +54,95 @@ func toSigningRecord(owner string, template *x509.Certificate) (*pb.SigningRecor CertificatePem: "", Date: now, ValidUntilDate: template.NotAfter.UnixNano(), + Serial: template.SerialNumber.String(), + IssuerId: issuerID, }, }, nil } +func (s *CertificateAuthorityServer) getSigningRecord(ctx context.Context, signingRecord *pb.SigningRecord) (*pb.SigningRecord, error) { + checkForIdentity := signingRecord.GetDeviceId() != "" && signingRecord.GetDeviceId() != signingRecord.GetOwner() + var err error + var originalSr *store.SigningRecord + if checkForIdentity { + now := time.Now().UnixNano() + err = s.store.LoadSigningRecords(ctx, signingRecord.GetOwner(), &store.SigningRecordsQuery{ + CommonNameFilter: []string{signingRecord.GetCommonName()}, + }, func(sr *store.SigningRecord) (err error) { + // _id is calculated as uuid.NewSHA1(uuid.NameSpaceX500, CommonName + PublicKey) -> thus same CommonName and PublicKey == same _id + if signingRecord.GetPublicKey() != sr.GetPublicKey() && + sr.GetCredential().GetValidUntilDate() > now { + return fmt.Errorf("common name %v with different public key fingerprint exist", sr.GetCommonName()) + } + if signingRecord.GetId() == sr.GetId() { + originalSr = sr + } + return nil + }) + } else { + err = s.store.LoadSigningRecords(ctx, signingRecord.GetOwner(), &store.SigningRecordsQuery{ + IdFilter: []string{signingRecord.GetId()}, + }, func(sr *store.SigningRecord) (err error) { + originalSr = sr + return nil + }) + } + if err != nil { + return nil, err + } + return originalSr, nil +} + +func (s *CertificateAuthorityServer) updateRevocationListForSigningRecord(ctx context.Context, sr, prevSr *pb.SigningRecord) error { + if prevSr != nil { + // revoke previous signing record + prevCred := prevSr.GetCredential() + if prevCred != nil { + query := store.UpdateRevocationListQuery{ + IssuerID: prevCred.GetIssuerId(), + RevokedCertificates: []*store.RevocationListCertificate{ + { + Serial: prevCred.GetSerial(), + ValidUntil: prevCred.GetValidUntilDate(), + Revocation: time.Now().UnixNano(), + }, + }, + } + _, err := s.store.UpdateRevocationList(ctx, &query) + return err + } + return nil + } + cred := sr.GetCredential() + if cred != nil { + // create new RevocationList if it doesn't exist + err := s.store.InsertRevocationLists(ctx, &store.RevocationList{ + Id: cred.GetIssuerId(), + Number: "1", + }) + if errors.Is(err, store.ErrDuplicateID) { + return nil + } + return err + } + return nil +} + func (s *CertificateAuthorityServer) updateSigningRecord(ctx context.Context, signingRecord *pb.SigningRecord) error { - var checkForIdentity bool - if signingRecord.GetDeviceId() != "" && signingRecord.GetDeviceId() != signingRecord.GetOwner() { - checkForIdentity = true + // try to get previous signing record + prevSr, err := s.getSigningRecord(ctx, signingRecord) + if err != nil { + return err } - if checkForIdentity { - return s.updateSigningIdentityCertificateRecord(ctx, signingRecord) + if s.store.SupportsRevocationList() { + err = s.updateRevocationListForSigningRecord(ctx, signingRecord, prevSr) + if err != nil { + return err + } } - return s.store.UpdateSigningRecord(ctx, signingRecord) + // upsert new one + err = s.store.UpdateSigningRecord(ctx, signingRecord) + return err } func (s *CertificateAuthorityServer) SignCertificate(ctx context.Context, req *pb.SignCertificateRequest) (*pb.SignCertificateResponse, error) { @@ -111,10 +159,11 @@ func (s *CertificateAuthorityServer) SignCertificate(ctx context.Context, req *p if err != nil { return nil, logger.LogAndReturnError(status.Errorf(codes.InvalidArgument, fmtError, err)) } - if signingRecord.GetCredential() == nil { - return nil, logger.LogAndReturnError(status.Errorf(codes.InvalidArgument, "cannot sign certificate: cannot create signing record")) + credential := signingRecord.GetCredential() + if credential == nil { + return nil, logger.LogAndReturnError(status.Errorf(codes.InvalidArgument, fmtError, errors.New("cannot create signing record"))) } - signingRecord.Credential.CertificatePem = string(cert) + credential.CertificatePem = string(cert) if err := s.updateSigningRecord(ctx, signingRecord); err != nil { return nil, logger.LogAndReturnError(status.Errorf(codes.InvalidArgument, fmtError, err)) } diff --git a/certificate-authority/service/grpc/signCertificate_test.go b/certificate-authority/service/grpc/signCertificate_test.go index 133de2e8b..1da435196 100644 --- a/certificate-authority/service/grpc/signCertificate_test.go +++ b/certificate-authority/service/grpc/signCertificate_test.go @@ -10,12 +10,13 @@ import ( "testing" "time" + "github.com/google/uuid" "github.com/plgd-dev/device/v2/pkg/security/generateCertificate" "github.com/plgd-dev/hub/v2/certificate-authority/pb" caTest "github.com/plgd-dev/hub/v2/certificate-authority/test" m2mOauthTest "github.com/plgd-dev/hub/v2/m2m-oauth-server/test" m2mOauthUri "github.com/plgd-dev/hub/v2/m2m-oauth-server/uri" - kitNetGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" + pkgGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" "github.com/plgd-dev/hub/v2/pkg/security/jwt/validator" "github.com/plgd-dev/hub/v2/test" "github.com/plgd-dev/hub/v2/test/config" @@ -69,7 +70,7 @@ func testSigningByFunction(t *testing.T, signFn ClientSignFunc, csr ...[]byte) { tearDown := service.SetUp(ctx, t) defer tearDown() - ctx = kitNetGrpc.CtxWithToken(ctx, oauthTest.GetDefaultAccessToken(t)) + ctx = pkgGrpc.CtxWithToken(ctx, oauthTest.GetDefaultAccessToken(t)) conn, err := grpc.NewClient(config.CERTIFICATE_AUTHORITY_HOST, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ RootCAs: test.GetRootCertificatePool(t), @@ -90,9 +91,7 @@ func testSigningByFunction(t *testing.T, signFn ClientSignFunc, csr ...[]byte) { } } -func createCSR(t *testing.T, commonName string) []byte { - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) +func createCSRWithKey(t *testing.T, commonName string, priv *ecdsa.PrivateKey) []byte { var cfg generateCertificate.Configuration cfg.Subject.CommonName = commonName csr, err := generateCertificate.GenerateCSR(cfg, priv) @@ -100,6 +99,12 @@ func createCSR(t *testing.T, commonName string) []byte { return csr } +func createCSR(t *testing.T, commonName string) []byte { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + return createCSRWithKey(t, commonName, priv) +} + func TestCertificateAuthorityServerSignCSR(t *testing.T) { csr := createCSR(t, "aa") testSigningByFunction(t, func(ctx context.Context, c pb.CertificateAuthorityClient, req *pb.SignCertificateRequest) (*pb.SignCertificateResponse, error) { @@ -130,7 +135,42 @@ func TestCertificateAuthorityServerSignCSRWithDifferentPublicKeys(t *testing.T) tearDown := service.SetUp(ctx, t, service.WithCAConfig(cfg), service.WithM2MOAuthConfig(m2mCfg)) defer tearDown() - ctx = kitNetGrpc.CtxWithToken(ctx, m2mOauthTest.GetDefaultAccessToken(t)) + ctx = pkgGrpc.CtxWithToken(ctx, m2mOauthTest.GetDefaultAccessToken(t)) + + conn, err := grpc.NewClient(config.CERTIFICATE_AUTHORITY_HOST, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ + RootCAs: test.GetRootCertificatePool(t), + }))) + require.NoError(t, err) + c := pb.NewCertificateAuthorityClient(conn) + + _, err = c.SignIdentityCertificate(ctx, &pb.SignCertificateRequest{CertificateSigningRequest: csr}) + require.NoError(t, err) + + _, err = c.SignIdentityCertificate(ctx, &pb.SignCertificateRequest{CertificateSigningRequest: csr1}) + require.NoError(t, err) +} + +func TestCertificateAuthorityServerSignCSRWithSameDevice(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + cfg := caTest.MakeConfig(t) + cfg.APIs.GRPC.Authorization.Endpoints = append(cfg.APIs.GRPC.Authorization.Endpoints, validator.AuthorityConfig{ + Authority: "https://" + config.M2M_OAUTH_SERVER_HTTP_HOST + m2mOauthUri.Base, + HTTP: config.MakeHttpClientConfig(), + }) + + m2mCfg := m2mOauthTest.MakeConfig(t) + serviceOAuthClient := m2mOauthTest.ServiceOAuthClient + serviceOAuthClient.InsertTokenClaims = map[string]interface{}{ + config.OWNER_CLAIM: oauthService.DeviceUserID, + } + m2mCfg.OAuthSigner.Clients[0] = &serviceOAuthClient + + tearDown := service.SetUp(ctx, t, service.WithCAConfig(cfg), service.WithM2MOAuthConfig(m2mCfg)) + defer tearDown() + + ctx = pkgGrpc.CtxWithToken(ctx, m2mOauthTest.GetDefaultAccessToken(t)) conn, err := grpc.NewClient(config.CERTIFICATE_AUTHORITY_HOST, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ RootCAs: test.GetRootCertificatePool(t), @@ -138,11 +178,23 @@ func TestCertificateAuthorityServerSignCSRWithDifferentPublicKeys(t *testing.T) require.NoError(t, err) c := pb.NewCertificateAuthorityClient(conn) + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + deviceID := uuid.NewString() + + csr := createCSRWithKey(t, "uuid:"+deviceID, priv) _, err = c.SignIdentityCertificate(ctx, &pb.SignCertificateRequest{CertificateSigningRequest: csr}) require.NoError(t, err) + csr1 := createCSRWithKey(t, "uuid:"+deviceID, priv) _, err = c.SignIdentityCertificate(ctx, &pb.SignCertificateRequest{CertificateSigningRequest: csr1}) require.NoError(t, err) + + priv2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + csr2 := createCSRWithKey(t, "uuid:"+deviceID, priv2) + _, err = c.SignIdentityCertificate(ctx, &pb.SignCertificateRequest{CertificateSigningRequest: csr2}) + require.Error(t, err) } func TestCertificateAuthorityServerSignCSRWithEmptyCommonName(t *testing.T) { diff --git a/certificate-authority/service/grpc/signIdentityCertificate_test.go b/certificate-authority/service/grpc/signIdentityCertificate_test.go index a4a1b4260..7ee6528d9 100644 --- a/certificate-authority/service/grpc/signIdentityCertificate_test.go +++ b/certificate-authority/service/grpc/signIdentityCertificate_test.go @@ -12,7 +12,7 @@ import ( "github.com/plgd-dev/device/v2/pkg/security/generateCertificate" "github.com/plgd-dev/hub/v2/certificate-authority/pb" "github.com/plgd-dev/hub/v2/identity-store/events" - kitNetGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" + pkgGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" "github.com/plgd-dev/hub/v2/test" "github.com/plgd-dev/hub/v2/test/config" oauthService "github.com/plgd-dev/hub/v2/test/oauth-server/service" @@ -49,7 +49,7 @@ func TestCertificateAuthorityServerSignDeviceIdentityCSRWithDifferentPublicKeys( tearDown := service.SetUp(ctx, t) defer tearDown() - ctx = kitNetGrpc.CtxWithToken(ctx, oauthTest.GetDefaultAccessToken(t)) + ctx = pkgGrpc.CtxWithToken(ctx, oauthTest.GetDefaultAccessToken(t)) conn, err := grpc.NewClient(config.CERTIFICATE_AUTHORITY_HOST, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ RootCAs: test.GetRootCertificatePool(t), @@ -77,7 +77,7 @@ func TestCertificateAuthorityServerSignOwnerIdentityCSRWithDifferentPublicKeys(t tearDown := service.SetUp(ctx, t) defer tearDown() - ctx = kitNetGrpc.CtxWithToken(ctx, oauthTest.GetDefaultAccessToken(t)) + ctx = pkgGrpc.CtxWithToken(ctx, oauthTest.GetDefaultAccessToken(t)) conn, err := grpc.NewClient(config.CERTIFICATE_AUTHORITY_HOST, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ RootCAs: test.GetRootCertificatePool(t), diff --git a/certificate-authority/service/grpc/signer.go b/certificate-authority/service/grpc/signer.go index 0b4866e0a..480c25993 100644 --- a/certificate-authority/service/grpc/signer.go +++ b/certificate-authority/service/grpc/signer.go @@ -2,14 +2,15 @@ package grpc import ( "context" - "crypto" "crypto/ecdsa" "crypto/x509" "errors" "time" + "github.com/google/uuid" "github.com/karrick/tparse/v2" "github.com/plgd-dev/hub/v2/certificate-authority/pb" + "github.com/plgd-dev/hub/v2/certificate-authority/service/uri" "github.com/plgd-dev/hub/v2/pkg/security/certificateSigner" pkgX509 "github.com/plgd-dev/hub/v2/pkg/security/x509" ) @@ -18,9 +19,14 @@ type Signer struct { validFrom func() time.Time validFor time.Duration certificate []*x509.Certificate - privateKey crypto.PrivateKey + privateKey *ecdsa.PrivateKey + issuerID string ownerClaim string hubID string + crl struct { + serverAddress string + validFor time.Duration + } } func checkCertificatePrivateKey(cert []*x509.Certificate, priv *ecdsa.PrivateKey) error { @@ -39,7 +45,39 @@ func checkCertificatePrivateKey(cert []*x509.Certificate, priv *ecdsa.PrivateKey return nil } -func NewSigner(ownerClaim string, hubID string, signerConfig SignerConfig) (*Signer, error) { +func getIssuerID(rootCertificate *x509.Certificate) (string, error) { + publicKeyRaw, err := x509.MarshalPKIXPublicKey(rootCertificate.PublicKey) + if err != nil { + return "", err + } + return uuid.NewSHA1(uuid.NameSpaceX500, publicKeyRaw).String(), nil +} + +func newSigner(ownerClaim, hubID, crlServerAddress string, signerConfig SignerConfig, privateKey *ecdsa.PrivateKey, certificate []*x509.Certificate) (*Signer, error) { + issuerID, err := getIssuerID(certificate[0]) + if err != nil { + return nil, err + } + signer := &Signer{ + validFrom: func() time.Time { + t, _ := tparse.ParseNow(time.RFC3339, signerConfig.ValidFrom) + return t + }, + validFor: signerConfig.ExpiresIn, + certificate: certificate, + privateKey: privateKey, + issuerID: issuerID, + ownerClaim: ownerClaim, + hubID: hubID, + } + if signerConfig.CRL.Enabled { + signer.crl.serverAddress = crlServerAddress + signer.crl.validFor = signerConfig.CRL.ExpiresIn + } + return signer, nil +} + +func NewSigner(ownerClaim, hubID, crlServerAddress string, signerConfig SignerConfig) (*Signer, error) { data, err := signerConfig.CertFile.Read() if err != nil { return nil, err @@ -60,19 +98,8 @@ func NewSigner(ownerClaim string, hubID string, signerConfig SignerConfig) (*Sig return nil, err } if len(certificate) == 1 && pkgX509.IsRootCA(certificate[0]) { - return &Signer{ - validFrom: func() time.Time { - t, _ := tparse.ParseNow(time.RFC3339, signerConfig.ValidFrom) - return t - }, - validFor: signerConfig.ExpiresIn, - certificate: certificate, - privateKey: privateKey, - ownerClaim: ownerClaim, - hubID: hubID, - }, nil + return newSigner(ownerClaim, hubID, crlServerAddress, signerConfig, privateKey, certificate) } - certificateAuthorities := make([]*x509.Certificate, 0, len(signerConfig.caPoolArray)*4) for _, caFile := range signerConfig.caPoolArray { data, errR := caFile.Read() @@ -93,18 +120,7 @@ func NewSigner(ownerClaim string, hubID string, signerConfig SignerConfig) (*Sig if err != nil { return nil, err } - - return &Signer{ - validFrom: func() time.Time { - t, _ := tparse.ParseNow(time.RFC3339, signerConfig.ValidFrom) - return t - }, - validFor: signerConfig.ExpiresIn, - certificate: chains[0], - privateKey: privateKey, - ownerClaim: ownerClaim, - hubID: hubID, - }, nil + return newSigner(ownerClaim, hubID, crlServerAddress, signerConfig, privateKey, chains[0]) } func (s *Signer) prepareSigningRecord(ctx context.Context, template *x509.Certificate) (*pb.SigningRecord, error) { @@ -117,31 +133,62 @@ func (s *Signer) prepareSigningRecord(ctx context.Context, template *x509.Certif if err != nil { return nil, err } - return toSigningRecord(owner, template) + return toSigningRecord(owner, s.issuerID, template) } -func (s *Signer) Sign(ctx context.Context, csr []byte) ([]byte, *pb.SigningRecord, error) { - notBefore := s.validFrom() - notAfter := notBefore.Add(s.validFor) - var signingRecord *pb.SigningRecord - signer := certificateSigner.New(s.certificate, s.privateKey, certificateSigner.WithNotBefore(notBefore), certificateSigner.WithNotAfter(notAfter), certificateSigner.WithOverrideCertTemplate(func(template *x509.Certificate) error { - var err error - signingRecord, err = s.prepareSigningRecord(ctx, template) - return err - })) - crt, err := signer.Sign(ctx, csr) - return crt, signingRecord, err +func (s *Signer) GetCertificate() *x509.Certificate { + return s.certificate[0] } -func (s *Signer) SignIdentityCSR(ctx context.Context, csr []byte) ([]byte, *pb.SigningRecord, error) { +func (s *Signer) GetPrivateKey() *ecdsa.PrivateKey { + return s.privateKey +} + +func (s *Signer) GetCRLConfiguation() (string, time.Duration) { + return s.crl.serverAddress, s.crl.validFor +} + +func (s *Signer) IsCRLEnabled() bool { + return s.crl.serverAddress != "" +} + +func (s *Signer) newCertificateSigner(identitySigner bool, opts ...func(cfg *certificateSigner.SignerConfig)) *certificateSigner.CertificateSigner { + if identitySigner { + return certificateSigner.NewIdentityCertificateSigner(s.certificate, s.privateKey, opts...) + } + return certificateSigner.New(s.certificate, s.privateKey, opts...) +} + +func (s *Signer) sign(ctx context.Context, isIdentityCertificate bool, csr []byte) ([]byte, *pb.SigningRecord, error) { notBefore := s.validFrom() notAfter := notBefore.Add(s.validFor) var signingRecord *pb.SigningRecord - signer := certificateSigner.NewIdentityCertificateSigner(s.certificate, s.privateKey, certificateSigner.WithNotBefore(notBefore), certificateSigner.WithNotAfter(notAfter), certificateSigner.WithOverrideCertTemplate(func(template *x509.Certificate) error { - var err error - signingRecord, err = s.prepareSigningRecord(ctx, template) - return err - })) + opts := []certificateSigner.Opt{ + certificateSigner.WithNotBefore(notBefore), + certificateSigner.WithNotAfter(notAfter), + certificateSigner.WithOverrideCertTemplate(func(template *x509.Certificate) error { + var err error + signingRecord, err = s.prepareSigningRecord(ctx, template) + return err + }), + } + if s.IsCRLEnabled() { + opts = append(opts, certificateSigner.WithCRLDistributionPoints( + []string{s.crl.serverAddress + uri.SigningRevocationListBase + "/" + s.issuerID}, + )) + } + signer := s.newCertificateSigner(isIdentityCertificate, opts...) cert, err := signer.Sign(ctx, csr) - return cert, signingRecord, err + if err != nil { + return nil, nil, err + } + return cert, signingRecord, nil +} + +func (s *Signer) Sign(ctx context.Context, csr []byte) ([]byte, *pb.SigningRecord, error) { + return s.sign(ctx, false, csr) +} + +func (s *Signer) SignIdentityCSR(ctx context.Context, csr []byte) ([]byte, *pb.SigningRecord, error) { + return s.sign(ctx, true, csr) } diff --git a/certificate-authority/service/grpc/signer_internal_test.go b/certificate-authority/service/grpc/signer_internal_test.go index 7c0464214..ab785ea33 100644 --- a/certificate-authority/service/grpc/signer_internal_test.go +++ b/certificate-authority/service/grpc/signer_internal_test.go @@ -116,7 +116,7 @@ func TestNewSigner(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NewSigner("tt.args.ownerClaim", "tt.args.hubID", tt.args.signerConfig) + got, err := NewSigner("ownerClaim", "hubID", "", tt.args.signerConfig) if tt.wantErr { require.Error(t, err) return diff --git a/certificate-authority/service/http/config.go b/certificate-authority/service/http/config.go index f4392ddc6..d30b24556 100644 --- a/certificate-authority/service/http/config.go +++ b/certificate-authority/service/http/config.go @@ -12,6 +12,8 @@ type Config struct { Connection listener.Config `yaml:",inline" json:",inline"` Authorization validator.Config `yaml:"authorization" json:"authorization"` Server server.Config `yaml:",inline" json:",inline"` + + CRLEnabled bool `yaml:"-" json:"-"` } func (c *Config) Validate() error { diff --git a/certificate-authority/service/http/requestHandler.go b/certificate-authority/service/http/requestHandler.go index 4f8fd1953..33c0526b1 100644 --- a/certificate-authority/service/http/requestHandler.go +++ b/certificate-authority/service/http/requestHandler.go @@ -3,12 +3,15 @@ package http import ( "context" "fmt" + "net/http" "github.com/fullstorydev/grpchan/inprocgrpc" "github.com/gorilla/mux" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/plgd-dev/hub/v2/certificate-authority/pb" grpcService "github.com/plgd-dev/hub/v2/certificate-authority/service/grpc" + "github.com/plgd-dev/hub/v2/certificate-authority/service/uri" + "github.com/plgd-dev/hub/v2/certificate-authority/store" "github.com/plgd-dev/hub/v2/http-gateway/serverMux" ) @@ -16,17 +19,26 @@ import ( type RequestHandler struct { config *Config mux *runtime.ServeMux + + cas *grpcService.CertificateAuthorityServer + store store.Store } // NewHTTP returns HTTP handler -func NewRequestHandler(config *Config, r *mux.Router, certificateAuthorityServer *grpcService.CertificateAuthorityServer) (*RequestHandler, error) { +func NewRequestHandler(config *Config, r *mux.Router, cas *grpcService.CertificateAuthorityServer, s store.Store) (*RequestHandler, error) { requestHandler := &RequestHandler{ config: config, mux: serverMux.New(), + cas: cas, + store: s, + } + + if config.CRLEnabled { + r.HandleFunc(uri.SigningRevocationList, requestHandler.revocationList).Methods(http.MethodGet) } ch := new(inprocgrpc.Channel) - pb.RegisterCertificateAuthorityServer(ch, certificateAuthorityServer) + pb.RegisterCertificateAuthorityServer(ch, cas) grpcClient := pb.NewCertificateAuthorityClient(ch) // register grpc-proxy handler if err := pb.RegisterCertificateAuthorityHandlerClient(context.Background(), requestHandler.mux, grpcClient); err != nil { diff --git a/certificate-authority/service/http/revocationList.go b/certificate-authority/service/http/revocationList.go new file mode 100644 index 000000000..2afb4a9f6 --- /dev/null +++ b/certificate-authority/service/http/revocationList.go @@ -0,0 +1,89 @@ +package http + +import ( + "context" + "crypto" + "crypto/rand" + "crypto/x509" + "errors" + "net/http" + "time" + + "github.com/google/uuid" + "github.com/gorilla/mux" + "github.com/plgd-dev/hub/v2/certificate-authority/service/uri" + "github.com/plgd-dev/hub/v2/certificate-authority/store" + "github.com/plgd-dev/hub/v2/http-gateway/serverMux" + pkgGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" + pkgHttp "github.com/plgd-dev/hub/v2/pkg/net/http" + pkgTime "github.com/plgd-dev/hub/v2/pkg/time" + "google.golang.org/grpc/codes" +) + +func errCannotGetRevocationList(err error) error { + return pkgGrpc.ForwardErrorf(codes.Internal, "cannot get revocation list: %v", err) +} + +func createCRL(rl *store.RevocationList, issuer *x509.Certificate, priv crypto.Signer) ([]byte, error) { + number, err := store.ParseBigInt(rl.Number) + if err != nil { + return nil, err + } + template := &x509.RevocationList{ + Number: number, + ThisUpdate: pkgTime.Unix(0, rl.IssuedAt), + NextUpdate: pkgTime.Unix(0, rl.ValidUntil), + } + for _, c := range rl.Certificates { + sn, errP := store.ParseBigInt(c.Serial) + if errP != nil { + return nil, errP + } + template.RevokedCertificateEntries = append(template.RevokedCertificateEntries, x509.RevocationListEntry{ + SerialNumber: sn, + RevocationTime: pkgTime.Unix(0, c.Revocation), + }) + } + return x509.CreateRevocationList(rand.Reader, template, issuer, priv) +} + +func (requestHandler *RequestHandler) tryGetRevocationList(ctx context.Context, issuerID string, validFor time.Duration, tries int) (*store.RevocationList, error) { + for range tries { + rl, err := requestHandler.store.GetLatestIssuedOrIssueRevocationList(ctx, issuerID, validFor) + if err == nil { + return rl, nil + } + if errors.Is(err, store.ErrNotFound) { + continue + } + return nil, err + } + return nil, store.ErrNotFound +} + +func (requestHandler *RequestHandler) writeRevocationList(w http.ResponseWriter, r *http.Request) error { + vars := mux.Vars(r) + issuerID := vars[uri.IssuerIDKey] + if _, err := uuid.Parse(issuerID); err != nil { + return err + } + signer := requestHandler.cas.GetSigner() + _, validFor := signer.GetCRLConfiguation() + rl, err := requestHandler.tryGetRevocationList(r.Context(), issuerID, validFor, 3) + if err != nil { + return err + } + crl, err := createCRL(rl, signer.GetCertificate(), signer.GetPrivateKey()) + if err != nil { + return err + } + w.Header().Set(pkgHttp.ContentTypeHeaderKey, "application/pkix-crl") + _, err = w.Write(crl) + return err +} + +func (requestHandler *RequestHandler) revocationList(w http.ResponseWriter, r *http.Request) { + if err := requestHandler.writeRevocationList(w, r); err != nil { + serverMux.WriteError(w, errCannotGetRevocationList(err)) + } +} diff --git a/certificate-authority/service/http/revocationList_test.go b/certificate-authority/service/http/revocationList_test.go new file mode 100644 index 000000000..4f9051016 --- /dev/null +++ b/certificate-authority/service/http/revocationList_test.go @@ -0,0 +1,165 @@ +package http_test + +import ( + "context" + "crypto/x509" + "io" + "net/http" + "testing" + "time" + + certAuthURI "github.com/plgd-dev/hub/v2/certificate-authority/service/uri" + "github.com/plgd-dev/hub/v2/certificate-authority/store" + "github.com/plgd-dev/hub/v2/certificate-authority/test" + httpgwTest "github.com/plgd-dev/hub/v2/http-gateway/test" + "github.com/plgd-dev/hub/v2/pkg/config/database" + pkgGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" + pkgTime "github.com/plgd-dev/hub/v2/pkg/time" + "github.com/plgd-dev/hub/v2/test/config" + oauthTest "github.com/plgd-dev/hub/v2/test/oauth-server/test" + testService "github.com/plgd-dev/hub/v2/test/service" + "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" +) + +func checkRevocationList(t *testing.T, crl *x509.RevocationList, certificates []*store.RevocationListCertificate) { + require.NotEmpty(t, crl.ThisUpdate) + require.NotEmpty(t, crl.NextUpdate) + expected := make([]x509.RevocationListEntry, 0, len(certificates)) + for _, cert := range certificates { + serial, err := store.ParseBigInt(cert.Serial) + require.NoError(t, err) + expected = append(expected, x509.RevocationListEntry{ + SerialNumber: serial, + RevocationTime: pkgTime.Unix(pkgTime.Unix(0, cert.Revocation).Unix(), 0).UTC(), + }) + } + actual := make([]x509.RevocationListEntry, 0, len(crl.RevokedCertificateEntries)) + for _, cert := range crl.RevokedCertificateEntries { + newCert := cert + newCert.Raw = nil + actual = append(actual, newCert) + } + require.Equal(t, expected, actual) +} + +func addRevocationLists(ctx context.Context, t *testing.T, s store.Store) map[string]*store.RevocationList { + rlm := make(map[string]*store.RevocationList) + // valid + now := time.Now() + rl1 := &store.RevocationList{ + Id: test.GetIssuerID(0), + IssuedAt: now.Add(-time.Minute).UnixNano(), + ValidUntil: now.Add(time.Minute * 10).UnixNano(), + Number: "1", + } + for i := range 10 { + rlc := test.GetCertificate(i, now, now.Add(time.Hour)) + rl1.Certificates = append(rl1.Certificates, rlc) + } + rlm[rl1.Id] = rl1 + + // not issued + rl2 := &store.RevocationList{ + Id: test.GetIssuerID(1), + Number: "2", + } + for i := range 10 { + rlc := test.GetCertificate(i, now, now.Add(time.Hour)) + rl2.Certificates = append(rl2.Certificates, rlc) + } + rlm[rl2.Id] = rl2 + + // expired + + err := s.InsertRevocationLists(ctx, maps.Values(rlm)...) + require.NoError(t, err) + return rlm +} + +func TestRevocationList(t *testing.T) { + if config.ACTIVE_DATABASE() == database.CqlDB { + t.Skip("revocation list not supported for CqlDB") + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + + shutDown := testService.SetUpServices(context.Background(), t, testService.SetUpServicesOAuth|testService.SetUpServicesMachine2MachineOAuth) + defer shutDown() + caShutdown := test.New(t, test.MakeConfig(t)) + defer caShutdown() + s, cleanUpStore := test.NewStore(t) + defer cleanUpStore() + + token := oauthTest.GetDefaultAccessToken(t) + ctx = pkgGrpc.CtxWithToken(ctx, token) + + stored := addRevocationLists(ctx, t, s) + + type args struct { + issuer string + } + tests := []struct { + name string + args args + verifyCRL func(crl *x509.RevocationList) + wantErr bool + }{ + { + name: "invalid issuerID", + args: args{ + issuer: "invalid", + }, + wantErr: true, + }, + { + name: "valid", + args: args{ + issuer: test.GetIssuerID(0), + }, + verifyCRL: func(crl *x509.RevocationList) { + var certificates []*store.RevocationListCertificate + for _, issuerCerts := range stored { + if issuerCerts.Id != test.GetIssuerID(0) { + continue + } + certificates = append(certificates, issuerCerts.Certificates...) + } + checkRevocationList(t, crl, certificates) + }, + }, + { + name: "valid - not issued", + args: args{ + issuer: test.GetIssuerID(1), + }, + verifyCRL: func(crl *x509.RevocationList) { + var certificates []*store.RevocationListCertificate + for _, issuerCerts := range stored { + if issuerCerts.Id != test.GetIssuerID(1) { + continue + } + certificates = append(certificates, issuerCerts.Certificates...) + } + checkRevocationList(t, crl, certificates) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + request := httpgwTest.NewRequest(http.MethodGet, certAuthURI.SigningRevocationList, nil).Host(config.CERTIFICATE_AUTHORITY_HTTP_HOST).AuthToken(token).AddIssuerID(tt.args.issuer).Build() + httpResp := httpgwTest.HTTPDo(t, request) + respBody, err := io.ReadAll(httpResp.Body) + require.NoError(t, err) + err = httpResp.Body.Close() + require.NoError(t, err) + crl, err := x509.ParseRevocationList(respBody) + if tt.wantErr { + require.Error(t, err) + return + } + tt.verifyCRL(crl) + }) + } +} diff --git a/certificate-authority/service/http/service.go b/certificate-authority/service/http/service.go index ab124b375..edd51755e 100644 --- a/certificate-authority/service/http/service.go +++ b/certificate-authority/service/http/service.go @@ -2,12 +2,16 @@ package http import ( "fmt" + "net/http" + "regexp" grpcService "github.com/plgd-dev/hub/v2/certificate-authority/service/grpc" - "github.com/plgd-dev/hub/v2/http-gateway/uri" + "github.com/plgd-dev/hub/v2/certificate-authority/service/uri" + "github.com/plgd-dev/hub/v2/certificate-authority/store" "github.com/plgd-dev/hub/v2/pkg/fsnotify" "github.com/plgd-dev/hub/v2/pkg/log" kitNetHttp "github.com/plgd-dev/hub/v2/pkg/net/http" + pkgHttpJwt "github.com/plgd-dev/hub/v2/pkg/net/http/jwt" httpService "github.com/plgd-dev/hub/v2/pkg/net/http/service" "github.com/plgd-dev/hub/v2/pkg/security/jwt/validator" "go.opentelemetry.io/otel/trace" @@ -20,24 +24,32 @@ type Service struct { } // New parses configuration and creates new Server with provided store and bus -func New(serviceName string, config Config, ca *grpcService.CertificateAuthorityServer, validator *validator.Validator, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*Service, error) { +func New(serviceName string, config Config, s store.Store, ca *grpcService.CertificateAuthorityServer, validator *validator.Validator, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*Service, error) { + var whiteList []pkgHttpJwt.RequestMatcher + if config.CRLEnabled { + whiteList = append(whiteList, pkgHttpJwt.RequestMatcher{ + Method: http.MethodGet, + URI: regexp.MustCompile(regexp.QuoteMeta(uri.SigningRevocationListBase) + `\/.*`), + }) + } + service, err := httpService.New(httpService.Config{ - HTTPConnection: config.Connection, - HTTPServer: config.Server, - ServiceName: serviceName, - AuthRules: kitNetHttp.NewDefaultAuthorizationRules(uri.API), - // WhiteEndpointList: whiteList, - FileWatcher: fileWatcher, - Logger: logger, - TraceProvider: tracerProvider, - Validator: validator, - // QueryCaseInsensitive: map[string]string{}, + HTTPConnection: config.Connection, + HTTPServer: config.Server, + ServiceName: serviceName, + AuthRules: kitNetHttp.NewDefaultAuthorizationRules(uri.API), + WhiteEndpointList: whiteList, + FileWatcher: fileWatcher, + Logger: logger, + TraceProvider: tracerProvider, + Validator: validator, + QueryCaseInsensitive: uri.QueryCaseInsensitive, }) if err != nil { return nil, fmt.Errorf("cannot create http service: %w", err) } - requestHandler, err := NewRequestHandler(&config, service.GetRouter(), ca) + requestHandler, err := NewRequestHandler(&config, service.GetRouter(), ca, s) if err != nil { _ = service.Close() return nil, err diff --git a/certificate-authority/service/http/signCertificate_test.go b/certificate-authority/service/http/signCertificate_test.go index 10cfc1b36..ce64799d5 100644 --- a/certificate-authority/service/http/signCertificate_test.go +++ b/certificate-authority/service/http/signCertificate_test.go @@ -13,8 +13,9 @@ import ( "github.com/plgd-dev/device/v2/pkg/security/generateCertificate" "github.com/plgd-dev/hub/v2/certificate-authority/pb" + certAuthURI "github.com/plgd-dev/hub/v2/certificate-authority/service/uri" httpgwTest "github.com/plgd-dev/hub/v2/http-gateway/test" - kitNetGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" + pkgGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" "github.com/plgd-dev/hub/v2/test/config" oauthTest "github.com/plgd-dev/hub/v2/test/oauth-server/test" "github.com/plgd-dev/hub/v2/test/service" @@ -26,11 +27,6 @@ import ( type ClientSignFunc = func(context.Context, *pb.SignCertificateRequest) (*pb.SignCertificateResponse, error) -const ( - URISignIdentityCertificate = "/api/v1/sign/identity-csr" - URISignCertificate = "/api/v1/sign/csr" -) - func testSigningByFunction(t *testing.T, signFn ClientSignFunc, csr []byte) { type args struct { req *pb.SignCertificateRequest @@ -55,7 +51,14 @@ func testSigningByFunction(t *testing.T, signFn ClientSignFunc, csr []byte) { CertificateSigningRequest: csr, }, }, - wantErr: false, + }, + { + name: "valid - new with the same csr", + args: args{ + req: &pb.SignCertificateRequest{ + CertificateSigningRequest: csr, + }, + }, }, } @@ -64,7 +67,7 @@ func testSigningByFunction(t *testing.T, signFn ClientSignFunc, csr []byte) { tearDown := service.SetUp(ctx, t) defer tearDown() - ctx = kitNetGrpc.CtxWithToken(ctx, oauthTest.GetDefaultAccessToken(t)) + ctx = pkgGrpc.CtxWithToken(ctx, oauthTest.GetDefaultAccessToken(t)) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -80,7 +83,7 @@ func testSigningByFunction(t *testing.T, signFn ClientSignFunc, csr []byte) { } func httpDoSign(ctx context.Context, t *testing.T, uri string, req *pb.SignCertificateRequest, resp *pb.SignCertificateResponse) error { - token, err := kitNetGrpc.TokenFromOutgoingMD(ctx) + token, err := pkgGrpc.TokenFromOutgoingMD(ctx) require.NoError(t, err) reqBody, err := protojson.Marshal(req) require.NoError(t, err) @@ -110,7 +113,7 @@ func TestCertificateAuthorityServerSignCSR(t *testing.T) { require.NoError(t, err) testSigningByFunction(t, func(ctx context.Context, req *pb.SignCertificateRequest) (*pb.SignCertificateResponse, error) { var resp pb.SignCertificateResponse - return &resp, httpDoSign(ctx, t, URISignCertificate, req, &resp) + return &resp, httpDoSign(ctx, t, certAuthURI.SignCertificate, req, &resp) }, csr) } @@ -122,6 +125,6 @@ func TestCertificateAuthorityServerSignCSRWithEmptyCommonName(t *testing.T) { require.NoError(t, err) testSigningByFunction(t, func(ctx context.Context, req *pb.SignCertificateRequest) (*pb.SignCertificateResponse, error) { var resp pb.SignCertificateResponse - return &resp, httpDoSign(ctx, t, URISignCertificate, req, &resp) + return &resp, httpDoSign(ctx, t, certAuthURI.SignCertificate, req, &resp) }, csr) } diff --git a/certificate-authority/service/http/signIdentityCertificate_test.go b/certificate-authority/service/http/signIdentityCertificate_test.go index d42505067..2270674fa 100644 --- a/certificate-authority/service/http/signIdentityCertificate_test.go +++ b/certificate-authority/service/http/signIdentityCertificate_test.go @@ -9,6 +9,7 @@ import ( "github.com/plgd-dev/device/v2/pkg/security/generateCertificate" "github.com/plgd-dev/hub/v2/certificate-authority/pb" + certAuthURI "github.com/plgd-dev/hub/v2/certificate-authority/service/uri" "github.com/stretchr/testify/require" ) @@ -21,7 +22,7 @@ func TestCertificateAuthorityServerSignIdentityCSR(t *testing.T) { require.NoError(t, err) testSigningByFunction(t, func(ctx context.Context, req *pb.SignCertificateRequest) (*pb.SignCertificateResponse, error) { var resp pb.SignCertificateResponse - return &resp, httpDoSign(ctx, t, URISignIdentityCertificate, req, &resp) + return &resp, httpDoSign(ctx, t, certAuthURI.SignIdentityCertificate, req, &resp) }, csr) } @@ -32,6 +33,6 @@ func TestCertificateAuthorityServerSignIdentityCSRWithEmptyCN(t *testing.T) { require.NoError(t, err) testSigningByFunction(t, func(ctx context.Context, req *pb.SignCertificateRequest) (*pb.SignCertificateResponse, error) { var resp pb.SignCertificateResponse - return &resp, httpDoSign(ctx, t, URISignIdentityCertificate, req, &resp) + return &resp, httpDoSign(ctx, t, certAuthURI.SignIdentityCertificate, req, &resp) }, csr) } diff --git a/certificate-authority/service/service.go b/certificate-authority/service/service.go index c4c240c5c..8e54df5ee 100644 --- a/certificate-authority/service/service.go +++ b/certificate-authority/service/service.go @@ -100,7 +100,7 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg } closerFn.AddFunc(closeStore) - ca, err := grpcService.NewCertificateAuthorityServer(config.APIs.GRPC.Authorization.OwnerClaim, config.HubID, config.Signer, dbStorage, fileWatcher, logger) + ca, err := grpcService.NewCertificateAuthorityServer(config.APIs.GRPC.Authorization.OwnerClaim, config.HubID, config.APIs.HTTP.ExternalAddress, config.Signer, dbStorage, fileWatcher, logger) if err != nil { closerFn.Execute() return nil, fmt.Errorf("cannot create grpc certificate authority server: %w", err) @@ -119,7 +119,8 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg }, Authorization: config.APIs.GRPC.Authorization.Config, Server: config.APIs.HTTP.Server, - }, ca, httpValidator, fileWatcher, logger, tracerProvider) + CRLEnabled: config.Signer.CRL.Enabled, + }, dbStorage, ca, httpValidator, fileWatcher, logger, tracerProvider) if err != nil { closerFn.Execute() return nil, fmt.Errorf("cannot create http service: %w", err) diff --git a/certificate-authority/service/uri/uri.go b/certificate-authority/service/uri/uri.go new file mode 100644 index 000000000..b6e7e2410 --- /dev/null +++ b/certificate-authority/service/uri/uri.go @@ -0,0 +1,20 @@ +package uri + +import "strings" + +const ( + API string = "/api/v1" + Sign string = API + "/sign" + + SignIdentityCertificate string = Sign + "/identity-csr" + SignCertificate string = Sign + "/csr" + + IssuerIDKey string = "issuerId" + + SigningRevocationListBase string = API + "/signing/crl" + SigningRevocationList string = SigningRevocationListBase + "/{" + IssuerIDKey + "}" +) + +var QueryCaseInsensitive = map[string]string{ + strings.ToLower(IssuerIDKey): IssuerIDKey, +} diff --git a/certificate-authority/store/cqldb/revocationList.go b/certificate-authority/store/cqldb/revocationList.go new file mode 100644 index 000000000..02d174f88 --- /dev/null +++ b/certificate-authority/store/cqldb/revocationList.go @@ -0,0 +1,24 @@ +package cqldb + +import ( + "context" + "time" + + "github.com/plgd-dev/hub/v2/certificate-authority/store" +) + +func (s *Store) SupportsRevocationList() bool { + return false +} + +func (s *Store) InsertRevocationLists(context.Context, ...*store.RevocationList) error { + return store.ErrNotSupported +} + +func (s *Store) UpdateRevocationList(context.Context, *store.UpdateRevocationListQuery) (*store.RevocationList, error) { + return nil, store.ErrNotSupported +} + +func (s *Store) GetLatestIssuedOrIssueRevocationList(context.Context, string, time.Duration) (*store.RevocationList, error) { + return nil, store.ErrNotSupported +} diff --git a/certificate-authority/store/cqldb/signingRecords.go b/certificate-authority/store/cqldb/signingRecords.go index feaa18b9b..3ec5d4687 100644 --- a/certificate-authority/store/cqldb/signingRecords.go +++ b/certificate-authority/store/cqldb/signingRecords.go @@ -339,15 +339,25 @@ func (s *Store) DeleteNonDeviceExpiredRecords(_ context.Context, _ time.Time) (i return 0, store.ErrNotSupported } -func (s *Store) LoadSigningRecords(ctx context.Context, owner string, query *store.SigningRecordsQuery, h store.LoadSigningRecordsFunc) error { +func (s *Store) LoadSigningRecords(ctx context.Context, owner string, query *store.SigningRecordsQuery, p store.Process[store.SigningRecord]) error { i := SigningRecordsIterator{ ctx: ctx, s: s, queries: toSigningRecordsQueryFilter(owner, query, true), provided: make(map[string]struct{}, 32), } - err := h(ctx, &i) - + var err error + for { + var stored store.SigningRecord + if !i.Next(ctx, &stored) { + err = i.Err() + break + } + err = p(&stored) + if err != nil { + break + } + } errClose := i.close() if err == nil { return errClose @@ -355,6 +365,14 @@ func (s *Store) LoadSigningRecords(ctx context.Context, owner string, query *sto return err } +func (s *Store) RevokeSigningRecords(ctx context.Context, ownerID string, query *store.RevokeSigningRecordsQuery) (int64, error) { + // TODO: revocation list not yet supported by cqldb, so for now we just delete the records + return s.DeleteSigningRecords(ctx, ownerID, &store.DeleteSigningRecordsQuery{ + IdFilter: query.GetIdFilter(), + DeviceIdFilter: query.GetDeviceIdFilter(), + }) +} + type SigningRecordsIterator struct { ctx context.Context queries []string diff --git a/certificate-authority/store/cqldb/signingRecords_test.go b/certificate-authority/store/cqldb/signingRecords_test.go index f2421c41d..a953edbf9 100644 --- a/certificate-authority/store/cqldb/signingRecords_test.go +++ b/certificate-authority/store/cqldb/signingRecords_test.go @@ -2,6 +2,7 @@ package cqldb_test import ( "context" + "math/big" "strconv" "sync" "testing" @@ -31,6 +32,8 @@ func TestStoreInsertSigningRecord(t *testing.T) { CertificatePem: "certificate", Date: date.UnixNano() - 1, ValidUntilDate: date.UnixNano() - 1, + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, } tests := []struct { @@ -49,6 +52,8 @@ func TestStoreInsertSigningRecord(t *testing.T) { CertificatePem: "certificate", Date: date.UnixNano(), ValidUntilDate: date.UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }, }, @@ -102,6 +107,8 @@ func TestStoreUpdateSigningRecord(t *testing.T) { CertificatePem: "certificate", Date: date.UnixNano() - 1, ValidUntilDate: date.UnixNano() - 1, + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, } tests := []struct { @@ -120,6 +127,8 @@ func TestStoreUpdateSigningRecord(t *testing.T) { CertificatePem: "certificate", Date: date.UnixNano(), ValidUntilDate: date.UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }, }, @@ -138,6 +147,8 @@ func TestStoreUpdateSigningRecord(t *testing.T) { CertificatePem: "certificate", Date: date.UnixNano(), ValidUntilDate: date.UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }, }, @@ -155,6 +166,8 @@ func TestStoreUpdateSigningRecord(t *testing.T) { CertificatePem: "certificate1", Date: date1.UnixNano(), ValidUntilDate: date1.UnixNano(), + Serial: big.NewInt(43).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }, }, @@ -177,7 +190,7 @@ func TestStoreUpdateSigningRecord(t *testing.T) { var h testSigningRecordHandler err = s.LoadSigningRecords(ctx, tt.args.sub.GetOwner(), &pb.GetSigningRecordsRequest{ IdFilter: []string{tt.args.sub.GetId()}, - }, h.Handle) + }, h.process) require.NoError(t, err) require.Len(t, h.lcs, 1) hubTest.CheckProtobufs(t, tt.args.sub, h.lcs[0], hubTest.RequireToCheckFunc(require.Equal)) @@ -278,6 +291,8 @@ func TestStoreDeleteSigningRecord(t *testing.T) { CertificatePem: "certificate", Date: date.UnixNano(), ValidUntilDate: date.UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }) require.NoError(t, err) @@ -292,6 +307,8 @@ func TestStoreDeleteSigningRecord(t *testing.T) { CertificatePem: "certificate", Date: date.UnixNano(), ValidUntilDate: date.UnixNano(), + Serial: big.NewInt(43).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }) require.NoError(t, err) @@ -306,6 +323,8 @@ func TestStoreDeleteSigningRecord(t *testing.T) { CertificatePem: "certificate", Date: date.UnixNano(), ValidUntilDate: date.UnixNano(), + Serial: big.NewInt(44).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }) require.NoError(t, err) @@ -341,11 +360,13 @@ func TestStoreDeleteExpiredRecords(t *testing.T) { CertificatePem: "certificate", Date: date.UnixNano(), ValidUntilDate: date.UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }) require.NoError(t, err) var h testSigningRecordHandler - err = s.LoadSigningRecords(ctx, "owner", nil, h.Handle) + err = s.LoadSigningRecords(ctx, "owner", nil, h.process) require.NoError(t, err) require.Len(t, h.lcs, 1) time.Sleep(time.Second * 3) @@ -353,7 +374,7 @@ func TestStoreDeleteExpiredRecords(t *testing.T) { require.Error(t, err) require.Equal(t, store.ErrNotSupported, err) var h1 testSigningRecordHandler - err = s.LoadSigningRecords(ctx, "owner", nil, h1.Handle) + err = s.LoadSigningRecords(ctx, "owner", nil, h1.process) require.NoError(t, err) require.Empty(t, h1.lcs) } @@ -362,15 +383,9 @@ type testSigningRecordHandler struct { lcs pb.SigningRecords } -func (h *testSigningRecordHandler) Handle(ctx context.Context, iter store.SigningRecordIter) (err error) { - for { - var sub store.SigningRecord - if !iter.Next(ctx, &sub) { - break - } - h.lcs = append(h.lcs, &sub) - } - return iter.Err() +func (h *testSigningRecordHandler) process(sr *store.SigningRecord) (err error) { + h.lcs = append(h.lcs, sr) + return nil } func TestStoreLoadSigningRecords(t *testing.T) { @@ -390,6 +405,8 @@ func TestStoreLoadSigningRecords(t *testing.T) { CertificatePem: "certificate", Date: date.UnixNano(), ValidUntilDate: date.UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }, { @@ -403,6 +420,8 @@ func TestStoreLoadSigningRecords(t *testing.T) { CertificatePem: "certificate", Date: date.UnixNano(), ValidUntilDate: date.UnixNano(), + Serial: big.NewInt(43).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }, { @@ -416,6 +435,8 @@ func TestStoreLoadSigningRecords(t *testing.T) { CertificatePem: "certificate", Date: date.UnixNano(), ValidUntilDate: date.UnixNano(), + Serial: big.NewInt(44).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }, } @@ -518,7 +539,7 @@ func TestStoreLoadSigningRecords(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var h testSigningRecordHandler - err := s.LoadSigningRecords(ctx, "owner", tt.args.query, h.Handle) + err := s.LoadSigningRecords(ctx, "owner", tt.args.query, h.process) if tt.wantErr { require.Error(t, err) return @@ -550,6 +571,8 @@ func BenchmarkSigningRecords(b *testing.B) { CertificatePem: "certificate", Date: date.UnixNano(), ValidUntilDate: date.UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }) } diff --git a/certificate-authority/store/cqldb/store.go b/certificate-authority/store/cqldb/store.go index ea8fcc084..3d3e4c0ab 100644 --- a/certificate-authority/store/cqldb/store.go +++ b/certificate-authority/store/cqldb/store.go @@ -38,7 +38,7 @@ type Store struct { } func New(ctx context.Context, config *Config, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*Store, error) { - certManager, err := client.New(config.Embedded.TLS, fileWatcher, logger) + certManager, err := client.New(config.Embedded.TLS, fileWatcher, logger, tracerProvider) if err != nil { return nil, fmt.Errorf("could not create cert manager: %w", err) } diff --git a/certificate-authority/store/mongodb/bulkWriter.go b/certificate-authority/store/mongodb/bulkWriter.go deleted file mode 100644 index 795ed327e..000000000 --- a/certificate-authority/store/mongodb/bulkWriter.go +++ /dev/null @@ -1,229 +0,0 @@ -package mongodb - -import ( - "context" - "sync" - "time" - - "github.com/hashicorp/go-multierror" - "github.com/plgd-dev/hub/v2/certificate-authority/store" - "github.com/plgd-dev/hub/v2/pkg/log" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" -) - -type bulkWriter struct { - col *mongo.Collection - documentLimit uint16 // https://www.mongodb.com/docs/manual/reference/limits/#mongodb-limit-Write-Command-Batch-Limit-Size - must be <= 100000 - throttleTime time.Duration - flushTimeout time.Duration - logger log.Logger - - done chan struct{} - trigger chan bool - - mutex sync.Mutex - models map[string]*store.SigningRecord - wg sync.WaitGroup -} - -func newBulkWriter(col *mongo.Collection, documentLimit uint16, throttleTime time.Duration, flushTimeout time.Duration, logger log.Logger) *bulkWriter { - r := &bulkWriter{ - col: col, - documentLimit: documentLimit, - throttleTime: throttleTime, - flushTimeout: flushTimeout, - done: make(chan struct{}), - trigger: make(chan bool, 1), - logger: logger, - } - - r.wg.Add(1) - go func() { - defer r.wg.Done() - r.run() - }() - return r -} - -func toSigningRecordFilter(signingRecord *store.SigningRecord) bson.M { - res := bson.M{"_id": signingRecord.GetId()} - return res -} - -func getSigningRecordCreationDate(defaultTime time.Time, signingRecord *store.SigningRecord) int64 { - ret := defaultTime.UTC().UnixNano() - if signingRecord.GetCredential().GetDate() > 0 && signingRecord.GetCredential().GetDate() < ret { - ret = signingRecord.GetCredential().GetDate() - } - return ret -} - -func setValueByDate(key, datePath string, dateOperator string, date int64, value interface{}) bson.M { - return bson.M{ - "$set": bson.M{ - key: bson.M{ - "$ifNull": bson.A{ - bson.M{ - "$cond": bson.M{ - "if": bson.M{ - dateOperator: bson.A{"$" + datePath, date}, - }, - "then": value, - "else": "$" + key, - }, - }, value, - }, - }, - }, - } -} - -func updateSigningRecord(signingRecord *store.SigningRecord) []bson.M { - creationDate := signingRecord.GetCreationDate() - if creationDate == 0 { - creationDate = getSigningRecordCreationDate(time.Now(), signingRecord) - } - ret := []bson.M{ - {"$set": bson.M{ - "_id": signingRecord.GetId(), - store.CommonNameKey: signingRecord.GetCommonName(), - store.OwnerKey: signingRecord.GetOwner(), - store.PublicKeyKey: signingRecord.GetPublicKey(), - }}, - } - ret = append(ret, setValueByDate(store.CreationDateKey, store.CreationDateKey, "$gt", creationDate, creationDate)) - if signingRecord.GetCredential() != nil { - ret = append(ret, setValueByDate(store.CredentialKey, store.CredentialKey+"."+store.DateKey, "$lt", signingRecord.GetCredential().GetDate(), signingRecord.GetCredential())) - } - return ret -} - -func convertSigningRecordToWriteModel(signingRecord *store.SigningRecord) mongo.WriteModel { - return mongo.NewUpdateOneModel().SetFilter(toSigningRecordFilter(signingRecord)).SetUpdate(updateSigningRecord(signingRecord)).SetUpsert(true) -} - -func mergeLatestUpdateSigningRecord(toUpdate *store.SigningRecord, latest *store.SigningRecord) *store.SigningRecord { - if toUpdate == nil { - return latest - } - if latest.GetCredential().GetDate() > toUpdate.GetCredential().GetDate() { - toUpdate.Credential = latest.GetCredential() - } - if latest.GetCreationDate() < toUpdate.GetCreationDate() { - toUpdate.CreationDate = latest.GetCreationDate() - if toUpdate.GetCommonName() == "" { - toUpdate.CommonName = latest.GetCommonName() - } - if toUpdate.GetOwner() == "" { - toUpdate.Owner = latest.GetOwner() - } - if toUpdate.GetPublicKey() == "" { - toUpdate.PublicKey = latest.GetPublicKey() - } - } - return toUpdate -} - -func (b *bulkWriter) popSigningRecords() map[string]*store.SigningRecord { - b.mutex.Lock() - defer b.mutex.Unlock() - models := b.models - b.models = nil - return models -} - -func (b *bulkWriter) Push(signingRecords ...*store.SigningRecord) { - b.mutex.Lock() - defer b.mutex.Unlock() - if b.models == nil { - b.models = make(map[string]*store.SigningRecord) - } - for _, signingRecord := range signingRecords { - b.models[signingRecord.GetId()] = mergeLatestUpdateSigningRecord(b.models[signingRecord.GetId()], signingRecord) - } - select { - case b.trigger <- true: - default: - } -} - -func (b *bulkWriter) numSigningRecords() int { - b.mutex.Lock() - defer b.mutex.Unlock() - return len(b.models) -} - -func (b *bulkWriter) run() { - ticker := time.NewTicker(b.throttleTime) - tickerRunning := true - defer ticker.Stop() - for { - select { - case <-ticker.C: - if b.tryBulkWrite() == 0 && tickerRunning { - ticker.Stop() - tickerRunning = false - } - case <-b.trigger: - if !tickerRunning { - ticker.Reset(b.throttleTime) - tickerRunning = true - } - if b.numSigningRecords() > int(b.documentLimit) { - b.tryBulkWrite() - } - case <-b.done: - return - } - } -} - -func (b *bulkWriter) bulkWrite() (int, error) { - SigningRecords := b.popSigningRecords() - if len(SigningRecords) == 0 { - return 0, nil - } - ctx := context.Background() - if b.flushTimeout != 0 { - ctx1, cancel := context.WithTimeout(context.Background(), b.flushTimeout) - defer cancel() - ctx = ctx1 - } - m := make([]mongo.WriteModel, 0, int(b.documentLimit)+1) - - var errors *multierror.Error - for _, SigningRecord := range SigningRecords { - m = append(m, convertSigningRecordToWriteModel(SigningRecord)) - if b.documentLimit == 0 || len(m)%int(b.documentLimit) == 0 { - _, err := b.col.BulkWrite(ctx, m, options.BulkWrite().SetOrdered(false)) - if err != nil { - errors = multierror.Append(errors, err) - } - m = m[:0] - } - } - - if len(m) > 0 { - _, err := b.col.BulkWrite(ctx, m, options.BulkWrite().SetOrdered(false)) - if err != nil { - errors = multierror.Append(errors, err) - } - } - return len(SigningRecords), errors.ErrorOrNil() -} - -func (b *bulkWriter) tryBulkWrite() int { - n, err := b.bulkWrite() - if err != nil { - b.logger.Errorf("failed to bulk update Signing records: %w", err) - } - return n -} - -func (b *bulkWriter) Close() { - close(b.done) - b.wg.Wait() - b.tryBulkWrite() -} diff --git a/certificate-authority/store/mongodb/config.go b/certificate-authority/store/mongodb/config.go index b2d839864..8034a68c7 100644 --- a/certificate-authority/store/mongodb/config.go +++ b/certificate-authority/store/mongodb/config.go @@ -1,38 +1,13 @@ package mongodb import ( - "fmt" - "time" - pkgMongo "github.com/plgd-dev/hub/v2/pkg/mongodb" ) -const minDuration = time.Millisecond * 100 - -type BulkWriteConfig struct { - Timeout time.Duration `yaml:"timeout"` - ThrottleTime time.Duration `yaml:"throttleTime"` - DocumentLimit uint16 `yaml:"documentLimit"` -} - -func (c *BulkWriteConfig) Validate() error { - if c.Timeout <= minDuration { - return fmt.Errorf("timeout('%v')", c.Timeout) - } - if c.ThrottleTime <= minDuration { - return fmt.Errorf("throttleTime('%v')", c.ThrottleTime) - } - return nil -} - type Config struct { - Mongo pkgMongo.Config `yaml:",inline"` - BulkWrite BulkWriteConfig `yaml:"bulkWrite"` + Mongo pkgMongo.Config `yaml:",inline"` } func (c *Config) Validate() error { - if err := c.BulkWrite.Validate(); err != nil { - return fmt.Errorf("bulkWrite.%w", err) - } return c.Mongo.Validate() } diff --git a/certificate-authority/store/mongodb/revocationList.go b/certificate-authority/store/mongodb/revocationList.go new file mode 100644 index 000000000..28db5c7c1 --- /dev/null +++ b/certificate-authority/store/mongodb/revocationList.go @@ -0,0 +1,231 @@ +package mongodb + +import ( + "context" + "errors" + "fmt" + "math/big" + "time" + + "github.com/plgd-dev/hub/v2/certificate-authority/store" + "github.com/plgd-dev/hub/v2/pkg/mongodb" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "golang.org/x/exp/maps" +) + +const revocationListCol = "revocationList" + +func (s *Store) SupportsRevocationList() bool { + return true +} + +func (s *Store) InsertRevocationLists(ctx context.Context, rls ...*store.RevocationList) error { + documents := make([]interface{}, 0, len(rls)) + for _, rl := range rls { + if err := rl.Validate(); err != nil { + return err + } + documents = append(documents, rl) + } + _, err := s.Collection(revocationListCol).InsertMany(ctx, documents) + if err != nil && mongo.IsDuplicateKeyError(err) { + return fmt.Errorf("%w: %w", store.ErrDuplicateID, err) + } + return err +} + +type revocationListUpdate struct { + originalRevocationList *store.RevocationList + certificatesToInsert map[string]*store.RevocationListCertificate +} + +// check the database and remove serials that are already in the array +func (s *Store) getRevocationListUpdate(ctx context.Context, query *store.UpdateRevocationListQuery) (revocationListUpdate, bool, error) { + cmap := make(map[string]*store.RevocationListCertificate) + for _, cert := range query.RevokedCertificates { + if _, ok := cmap[cert.Serial]; ok { + s.logger.Debugf("skipping duplicate serial number(%v) in query", cert.Serial) + continue + } + if err := cert.Validate(); err != nil { + return revocationListUpdate{}, false, err + } + cmap[cert.Serial] = cert + } + pl := mongo.Pipeline{ + bson.D{{Key: mongodb.Match, Value: bson.D{{Key: "_id", Value: query.IssuerID}}}}, + } + if len(cmap) > 0 { + pl = append(pl, bson.D{{Key: "$addFields", Value: bson.M{ + "duplicates": bson.M{ + "$filter": bson.M{ + "input": "$" + store.CertificatesKey, + "as": "cert", + "cond": bson.M{mongodb.In: bson.A{"$$cert." + store.SerialKey, maps.Keys(cmap)}}, + }, + }, + }}}) + } + cur, err := s.Collection(revocationListCol).Aggregate(ctx, pl) + if err != nil { + return revocationListUpdate{}, false, err + } + type revocationListWithNewCertificates struct { + *store.RevocationList `bson:",inline"` + Duplicates []*store.RevocationListCertificate `bson:"duplicates,omitempty"` + } + var rl *revocationListWithNewCertificates + count, err := processCursor(ctx, cur, func(item *revocationListWithNewCertificates) error { + rl = item + return nil + }) + if err != nil { + return revocationListUpdate{}, false, err + } + if count == 0 { + return revocationListUpdate{ + certificatesToInsert: cmap, + }, true, nil + } + for _, c := range rl.Duplicates { + s.logger.Debugf("skipping duplicate serial number(%v)", c.Serial) + delete(cmap, c.Serial) + } + if len(cmap) == 0 && (!query.UpdateIfExpired || !rl.IsExpired()) { + return revocationListUpdate{ + originalRevocationList: rl.RevocationList, + }, false, nil + } + return revocationListUpdate{ + originalRevocationList: rl.RevocationList, + certificatesToInsert: cmap, + }, true, nil +} + +func (s *Store) UpdateRevocationList(ctx context.Context, query *store.UpdateRevocationListQuery) (*store.RevocationList, error) { + if err := query.Validate(); err != nil { + return nil, err + } + upd, needsUpdate, err := s.getRevocationListUpdate(ctx, query) + if err != nil { + return nil, err + } + if !needsUpdate { + return upd.originalRevocationList, nil + } + + if upd.originalRevocationList == nil { + newRL := &store.RevocationList{ + Id: query.IssuerID, + Number: "1", // the sequence for the CRL number field starts from 1 + IssuedAt: query.IssuedAt, + ValidUntil: query.ValidUntil, + Certificates: maps.Values(upd.certificatesToInsert), + } + if err = s.InsertRevocationLists(ctx, newRL); err != nil { + if mongo.IsDuplicateKeyError(err) { + return nil, fmt.Errorf("%w: %w", store.ErrDuplicateID, err) + } + return nil, err + } + return newRL, nil + } + + number, err := store.ParseBigInt(upd.originalRevocationList.Number) + if err != nil { + return nil, err + } + filter := bson.M{ + "_id": query.IssuerID, + store.NumberKey: number.String(), + } + + nextNumber := number + // for not issued (IssuedAt == 0) we don't need to increment the Number, it was already incremented when + // the list was updated and the IssuedAt was set to 0 + if upd.originalRevocationList.IssuedAt != 0 { + nextNumber = nextNumber.Add(nextNumber, big.NewInt(1)) + } + update := bson.M{ + "$set": bson.M{ + store.NumberKey: nextNumber.String(), + store.IssuedAtKey: query.IssuedAt, + store.ValidUntilKey: query.ValidUntil, + }, + } + if len(upd.certificatesToInsert) > 0 { + update["$push"] = bson.M{ + store.CertificatesKey: bson.M{"$each": maps.Values(upd.certificatesToInsert)}, + } + } + opts := options.FindOneAndUpdate().SetReturnDocument(options.After) + var updatedRL store.RevocationList + if err = s.Collection(revocationListCol).FindOneAndUpdate(ctx, filter, update, opts).Decode(&updatedRL); err != nil { + if errors.Is(err, mongo.ErrNoDocuments) { + return nil, fmt.Errorf("%w: %w", store.ErrNotFound, err) + } + return nil, err + } + return &updatedRL, nil +} + +func (s *Store) GetRevocationList(ctx context.Context, issuerID string, includeExpired bool) (*store.RevocationList, error) { + now := time.Now().UnixNano() + filter := bson.M{ + "_id": issuerID, + } + var opts []*options.FindOneOptions + if !includeExpired { + filter[store.CertificatesKey] = bson.M{ + "$elemMatch": bson.M{ + store.ValidUntilKey: bson.M{"$gte": now}, // non-expired certificates + }, + } + projection := bson.M{ + "_id": 1, + store.NumberKey: 1, + store.IssuedAtKey: 1, + store.ValidUntilKey: 1, + store.CertificatesKey: bson.M{ + "$filter": bson.M{ + "input": "$" + store.CertificatesKey, + "as": "cert", + "cond": bson.M{ + "$gte": []interface{}{"$$cert." + store.ValidUntilKey, now}, // non-expired certificates + }, + }, + }, + } + opts = append(opts, options.FindOne().SetProjection(projection)) + } + + var rl store.RevocationList + err := s.Collection(revocationListCol).FindOne(ctx, filter, opts...).Decode(&rl) + if err != nil { + if errors.Is(err, mongo.ErrNoDocuments) { + return nil, store.ErrNotFound + } + return nil, err + } + return &rl, nil +} + +func (s *Store) GetLatestIssuedOrIssueRevocationList(ctx context.Context, issuerID string, validFor time.Duration) (*store.RevocationList, error) { + rl, err := s.GetRevocationList(ctx, issuerID, true) + if err != nil { + return nil, err + } + if rl.IssuedAt > 0 && !rl.IsExpired() { + return rl, nil + } + issuedAt := time.Now() + validUntil := issuedAt.Add(validFor) + return s.UpdateRevocationList(ctx, &store.UpdateRevocationListQuery{ + IssuerID: issuerID, + IssuedAt: issuedAt.UnixNano(), + ValidUntil: validUntil.UnixNano(), + UpdateIfExpired: true, + }) +} diff --git a/certificate-authority/store/mongodb/revocationList_test.go b/certificate-authority/store/mongodb/revocationList_test.go new file mode 100644 index 000000000..65b4aaa8d --- /dev/null +++ b/certificate-authority/store/mongodb/revocationList_test.go @@ -0,0 +1,341 @@ +package mongodb_test + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "github.com/plgd-dev/hub/v2/certificate-authority/store" + "github.com/plgd-dev/hub/v2/certificate-authority/store/mongodb" + "github.com/plgd-dev/hub/v2/certificate-authority/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUpdateRevocationList(t *testing.T) { + s, cleanUpStore := test.NewMongoStore(t) + defer cleanUpStore() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + + id := uuid.NewString() + id2 := uuid.NewString() + id3 := uuid.NewString() + cert1 := &store.RevocationListCertificate{ + Serial: "1", + ValidUntil: time.Now().Add(time.Hour).Unix(), + Revocation: time.Now().Unix(), + } + rl1 := store.RevocationList{ + Id: id, + Number: "1", + IssuedAt: time.Now().UnixNano(), + ValidUntil: time.Now().Add(time.Minute).UnixNano(), + Certificates: []*store.RevocationListCertificate{cert1}, + } + cert2 := &store.RevocationListCertificate{ + Serial: "2", + ValidUntil: time.Now().Add(time.Hour).Unix(), + Revocation: time.Now().Unix(), + } + cert3 := &store.RevocationListCertificate{ + Serial: "2", + ValidUntil: time.Now().Add(time.Hour).Unix(), + Revocation: time.Now().Unix(), + } + rl3 := store.RevocationList{ + Id: id3, + Number: "1", + IssuedAt: time.Now().Add(-time.Minute).UnixNano(), + ValidUntil: time.Now().UnixNano(), + } + type args struct { + query store.UpdateRevocationListQuery + } + tests := []struct { + name string + args args + want *store.RevocationList + wantErr bool + }{ + { + name: "missing ID", + args: args{ + query: store.UpdateRevocationListQuery{ + IssuerID: "", + RevokedCertificates: []*store.RevocationListCertificate{cert1}, + }, + }, + wantErr: true, + }, + { + name: "missing serial number", + args: args{ + query: store.UpdateRevocationListQuery{ + IssuerID: id, + RevokedCertificates: []*store.RevocationListCertificate{{ + Revocation: time.Now().UnixNano(), + }}, + }, + }, + wantErr: true, + }, + { + name: "missing revocation time", + args: args{ + query: store.UpdateRevocationListQuery{ + IssuerID: id, + RevokedCertificates: []*store.RevocationListCertificate{{ + Serial: "1", + }}, + }, + }, + wantErr: true, + }, + { + name: "valid - new document", + args: args{ + query: store.UpdateRevocationListQuery{ + IssuerID: rl1.Id, + RevokedCertificates: rl1.Certificates, + IssuedAt: rl1.IssuedAt, + ValidUntil: rl1.ValidUntil, + }, + }, + want: &rl1, + }, + { + name: "valid - add to existing document", + args: args{ + query: store.UpdateRevocationListQuery{ + IssuerID: id, + RevokedCertificates: []*store.RevocationListCertificate{cert2}, + }, + }, + want: &store.RevocationList{ + Id: id, + Number: "2", + Certificates: []*store.RevocationListCertificate{ + cert1, + cert2, + }, + }, + }, + { + name: "valid - duplicate serial, noop", + args: args{ + query: store.UpdateRevocationListQuery{ + IssuerID: id, + RevokedCertificates: []*store.RevocationListCertificate{{ + Serial: cert2.Serial, + ValidUntil: time.Now().Add(time.Hour).Unix(), + Revocation: time.Now().Unix(), + }}, + }, + }, + want: &store.RevocationList{ + Id: id, + Number: "2", + Certificates: []*store.RevocationListCertificate{ + cert1, + cert2, + }, + }, + }, + { + name: "valid - different issuer, existing serial", + args: args{ + query: store.UpdateRevocationListQuery{ + IssuerID: id2, + RevokedCertificates: []*store.RevocationListCertificate{cert3}, + }, + }, + want: &store.RevocationList{ + Id: id2, + Number: "1", + Certificates: []*store.RevocationListCertificate{cert3}, + }, + }, + { + name: "valid - no certificates, set to expired", + args: args{ + query: store.UpdateRevocationListQuery{ + IssuerID: rl3.Id, + IssuedAt: rl3.IssuedAt, + ValidUntil: rl3.ValidUntil, + }, + }, + want: &rl3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + updatedRL, err := s.UpdateRevocationList(ctx, &tt.args.query) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + test.CheckRevocationList(t, tt.want, updatedRL, false) + }) + } +} + +func TestParallelUpdateRevocationList(t *testing.T) { + s, cleanUpStore := test.NewMongoStore(t) + defer cleanUpStore() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + + issuerID := uuid.NewString() + firstCount := 10 + secondCount := 10 + certificates := make([]*store.RevocationListCertificate, firstCount+secondCount) + for i := range firstCount + secondCount { + certificates[i] = test.GetCertificate(i, time.Now(), time.Now().Add(time.Hour)) + } + + // create or update + createOrUpdateRevocationList(ctx, t, 0, firstCount, certificates, issuerID, s) + + rl, err := s.GetLatestIssuedOrIssueRevocationList(ctx, issuerID, time.Hour) + require.NoError(t, err) + require.NotEmpty(t, rl.IssuedAt) + require.NotEmpty(t, rl.ValidUntil) + expected := &store.RevocationList{ + Id: issuerID, + Number: "1", + IssuedAt: rl.IssuedAt, + ValidUntil: rl.ValidUntil, + Certificates: certificates[:10], + } + test.CheckRevocationList(t, expected, rl, true) + + createOrUpdateRevocationList(ctx, t, firstCount, secondCount, certificates, issuerID, s) + + rl, err = s.GetLatestIssuedOrIssueRevocationList(ctx, issuerID, time.Hour) + require.NoError(t, err) + require.NotEmpty(t, rl.IssuedAt) + require.NotEmpty(t, rl.ValidUntil) + expected = &store.RevocationList{ + Id: issuerID, + Number: "2", + IssuedAt: rl.IssuedAt, + ValidUntil: rl.ValidUntil, + Certificates: certificates, + } + test.CheckRevocationList(t, expected, rl, true) +} + +func createOrUpdateRevocationList(ctx context.Context, t *testing.T, start, count int, certificates []*store.RevocationListCertificate, issuerID string, s *mongodb.Store) { + var failed atomic.Bool + failed.Store(false) + var wg sync.WaitGroup + wg.Add(10) + for i := start; i < start+count; i++ { + go func(index int) { + defer wg.Done() + cert := certificates[index] + var err error + // parallel execution should eventually succeed in cases when we get duplicate _id + // or not found errors + for range 100 { + q := &store.UpdateRevocationListQuery{ + IssuerID: issuerID, + RevokedCertificates: []*store.RevocationListCertificate{cert}, + } + _, err = s.UpdateRevocationList(ctx, q) + if errors.Is(err, store.ErrDuplicateID) || errors.Is(err, store.ErrNotFound) { + continue + } + if err == nil { + break + } + failed.Store(true) + assert.NoError(t, err) + } + assert.NoError(t, err) + }(i) + } + wg.Wait() + require.False(t, failed.Load()) +} + +func TestGetRevocationList(t *testing.T) { + s, cleanUpStore := test.NewMongoStore(t) + defer cleanUpStore() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + + stored := test.AddRevocationListToStore(ctx, t, s, time.Now().Add(-2*time.Hour-time.Minute)) + + type args struct { + issuerID string + includeExpired bool + } + tests := []struct { + name string + args args + want *store.RevocationList + wantErr bool + }{ + { + name: "no matching ID", + args: args{ + issuerID: "00000000-0000-0000-0000-123456789012", + }, + wantErr: true, + }, + { + name: "all from issuer0", + args: args{ + issuerID: test.GetIssuerID(0), + includeExpired: true, + }, + want: func() *store.RevocationList { + expected, ok := stored[test.GetIssuerID(0)] + require.True(t, ok) + return expected + }(), + }, + { + name: "no valid from issuer0", + args: args{ + issuerID: test.GetIssuerID(0), + }, + wantErr: true, + }, + { + name: "non-expired from issuer4", + args: args{ + issuerID: test.GetIssuerID(4), + }, + want: func() *store.RevocationList { + expected, ok := stored[test.GetIssuerID(4)] + require.True(t, ok) + return expected + }(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + retrieved, err := s.GetRevocationList(ctx, tt.args.issuerID, tt.args.includeExpired) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.NoError(t, err) + test.CheckRevocationList(t, tt.want, retrieved, false) + }) + } +} diff --git a/certificate-authority/store/mongodb/signingRecords.go b/certificate-authority/store/mongodb/signingRecords.go index 1217f9ca7..c3188f21c 100644 --- a/certificate-authority/store/mongodb/signingRecords.go +++ b/certificate-authority/store/mongodb/signingRecords.go @@ -6,9 +6,12 @@ import ( "time" "github.com/hashicorp/go-multierror" + "github.com/plgd-dev/hub/v2/certificate-authority/pb" "github.com/plgd-dev/hub/v2/certificate-authority/store" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "golang.org/x/exp/maps" ) const signingRecordsCol = "signedCertificateRecords" @@ -20,20 +23,20 @@ func (s *Store) CreateSigningRecord(ctx context.Context, signingRecord *store.Si return err } _, err := s.Collection(signingRecordsCol).InsertOne(ctx, signingRecord) - if err != nil { - return err - } - - return nil + return err } -func (s *Store) UpdateSigningRecord(_ context.Context, signingRecord *store.SigningRecord) error { +func (s *Store) UpdateSigningRecord(ctx context.Context, signingRecord *store.SigningRecord) error { if err := signingRecord.Validate(); err != nil { return err } - - s.bulkWriter.Push(signingRecord) - return nil + filter := bson.M{"_id": signingRecord.GetId()} + upsert := true + opts := &options.UpdateOptions{ + Upsert: &upsert, + } + _, err := s.Collection(signingRecordsCol).UpdateOne(ctx, filter, bson.M{"$set": signingRecord}, opts) + return err } func toCommonNameQueryFilter(owner string, commonName string) bson.D { @@ -112,44 +115,75 @@ func (s *Store) DeleteNonDeviceExpiredRecords(ctx context.Context, now time.Time if err != nil { return -1, multierror.Append(ErrCannotRemoveSigningRecord, err) } - return res.DeletedCount, nil } -func (s *Store) LoadSigningRecords(ctx context.Context, owner string, query *store.SigningRecordsQuery, h store.LoadSigningRecordsFunc) error { - col := s.Collection(signingRecordsCol) - iter, err := col.Find(ctx, toSigningRecordsQueryFilter(owner, query)) - if errors.Is(err, mongo.ErrNilDocument) { - return nil +func (s *Store) RevokeSigningRecords(ctx context.Context, ownerID string, query *store.RevokeSigningRecordsQuery) (int64, error) { + now := time.Now().UnixNano() + // get signing records to be deleted + type issuersRecord struct { + ids []string + certificates []*store.RevocationListCertificate } + idFilter := make(map[string]struct{}) + irs := make(map[string]issuersRecord) + err := s.LoadSigningRecords(ctx, ownerID, &pb.GetSigningRecordsRequest{ + IdFilter: query.GetIdFilter(), + DeviceIdFilter: query.GetDeviceIdFilter(), + }, func(v *pb.SigningRecord) error { + credential := v.GetCredential() + if credential == nil { + return nil + } + idFilter[v.GetId()] = struct{}{} + if credential.GetValidUntilDate() <= now { + return nil + } + record := irs[credential.GetIssuerId()] + record.ids = append(record.ids, v.GetId()) + record.certificates = append(record.certificates, &store.RevocationListCertificate{ + Serial: credential.GetSerial(), + ValidUntil: credential.GetValidUntilDate(), + Revocation: now, + }) + irs[credential.GetIssuerId()] = record + return nil + }) if err != nil { - return err + return 0, err } - i := SigningRecordsIterator{ - iter: iter, + // add certificates for the signing records to revocation lists + for issuerID, record := range irs { + query := store.UpdateRevocationListQuery{ + IssuerID: issuerID, + RevokedCertificates: record.certificates, + } + _, err := s.UpdateRevocationList(ctx, &query) + if err != nil { + return 0, err + } } - err = h(ctx, &i) - errClose := iter.Close(ctx) - if err == nil { - return errClose + if len(idFilter) == 0 { + return 0, nil } - return err -} -type SigningRecordsIterator struct { - iter *mongo.Cursor + // delete the signing records + return s.DeleteSigningRecords(ctx, ownerID, &pb.DeleteSigningRecordsRequest{ + IdFilter: maps.Keys(idFilter), + }) } -func (i *SigningRecordsIterator) Next(ctx context.Context, s *store.SigningRecord) bool { - if !i.iter.Next(ctx) { - return false +func (s *Store) LoadSigningRecords(ctx context.Context, owner string, query *store.SigningRecordsQuery, p store.Process[store.SigningRecord]) error { + col := s.Collection(signingRecordsCol) + cur, err := col.Find(ctx, toSigningRecordsQueryFilter(owner, query)) + if err != nil { + if errors.Is(err, mongo.ErrNilDocument) { + return nil + } + return err } - err := i.iter.Decode(s) - return err == nil -} - -func (i *SigningRecordsIterator) Err() error { - return i.iter.Err() + _, err = processCursor(ctx, cur, p) + return err } diff --git a/certificate-authority/store/mongodb/signingRecords_test.go b/certificate-authority/store/mongodb/signingRecords_test.go index 312f1a29b..b83852161 100644 --- a/certificate-authority/store/mongodb/signingRecords_test.go +++ b/certificate-authority/store/mongodb/signingRecords_test.go @@ -2,8 +2,7 @@ package mongodb_test import ( "context" - "strconv" - "sync" + "math/big" "testing" "time" @@ -11,7 +10,6 @@ import ( "github.com/plgd-dev/hub/v2/certificate-authority/store" "github.com/plgd-dev/hub/v2/certificate-authority/test" hubTest "github.com/plgd-dev/hub/v2/test" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -43,6 +41,8 @@ func TestStoreUpdateSigningRecord(t *testing.T) { CertificatePem: "certificate", Date: constDate().UnixNano(), ValidUntilDate: constDate().UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }, }, @@ -61,6 +61,8 @@ func TestStoreUpdateSigningRecord(t *testing.T) { CertificatePem: "certificate", Date: constDate().UnixNano(), ValidUntilDate: constDate().UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }, }, @@ -78,6 +80,8 @@ func TestStoreUpdateSigningRecord(t *testing.T) { CertificatePem: "certificate1", Date: constDate1().UnixNano(), ValidUntilDate: constDate1().UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }, }, @@ -94,10 +98,8 @@ func TestStoreUpdateSigningRecord(t *testing.T) { err := s.UpdateSigningRecord(ctx, tt.args.sub) if tt.wantErr { require.Error(t, err) - } else { - require.NoError(t, err) + return } - err = s.FlushBulkWriter() require.NoError(t, err) }) } @@ -212,6 +214,8 @@ func TestStoreDeleteSigningRecords(t *testing.T) { CertificatePem: "certificate", Date: constDate().UnixNano(), ValidUntilDate: constDate().UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }) require.NoError(t, err) @@ -276,6 +280,8 @@ func TestStoreDeleteExpiredRecords(t *testing.T) { CertificatePem: "certificate", Date: constDate().UnixNano(), ValidUntilDate: constDate().UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }) require.NoError(t, err) @@ -293,15 +299,9 @@ type testSigningRecordHandler struct { lcs pb.SigningRecords } -func (h *testSigningRecordHandler) Handle(ctx context.Context, iter store.SigningRecordIter) (err error) { - for { - var sub store.SigningRecord - if !iter.Next(ctx, &sub) { - break - } - h.lcs = append(h.lcs, &sub) - } - return iter.Err() +func (h *testSigningRecordHandler) process(sr *store.SigningRecord) (err error) { + h.lcs = append(h.lcs, sr) + return nil } func TestStoreLoadSigningRecords(t *testing.T) { @@ -323,6 +323,8 @@ func TestStoreLoadSigningRecords(t *testing.T) { CertificatePem: "certificate", Date: constDate().UnixNano(), ValidUntilDate: constDate().UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }, { @@ -336,6 +338,8 @@ func TestStoreLoadSigningRecords(t *testing.T) { CertificatePem: "certificate", Date: constDate().UnixNano(), ValidUntilDate: constDate().UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }, { @@ -349,6 +353,8 @@ func TestStoreLoadSigningRecords(t *testing.T) { CertificatePem: "certificate", Date: constDate().UnixNano(), ValidUntilDate: constDate().UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }, { @@ -362,6 +368,8 @@ func TestStoreLoadSigningRecords(t *testing.T) { CertificatePem: "certificate", Date: constDate().UnixNano(), ValidUntilDate: constDate().UnixNano(), + Serial: big.NewInt(42).String(), + IssuerId: "42424242-4242-4242-4242-424242424242", }, }, } @@ -466,7 +474,7 @@ func TestStoreLoadSigningRecords(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var h testSigningRecordHandler - err := s.LoadSigningRecords(ctx, tt.args.owner, tt.args.query, h.Handle) + err := s.LoadSigningRecords(ctx, tt.args.owner, tt.args.query, h.process) require.NoError(t, err) require.Len(t, h.lcs, len(tt.want)) h.lcs.Sort() @@ -478,47 +486,3 @@ func TestStoreLoadSigningRecords(t *testing.T) { }) } } - -func BenchmarkSigningRecords(b *testing.B) { - data := make([]*store.SigningRecord, 0, 5001) - dataCap := cap(data) - for i := 0; i < dataCap; i++ { - data = append(data, &store.SigningRecord{ - Id: hubTest.GenerateDeviceIDbyIdx(i), - Owner: "owner", - CommonName: "commonName" + strconv.Itoa(i), - CreationDate: constDate().UnixNano(), - PublicKey: "publicKey", - Credential: &pb.CredentialStatus{ - CertificatePem: "certificate", - Date: constDate().UnixNano(), - ValidUntilDate: constDate().UnixNano(), - }, - }) - } - - ctx := context.Background() - b.ResetTimer() - s, cleanUpStore := test.NewMongoStore(b) - defer cleanUpStore() - for i := uint32(0); i < uint32(b.N); i++ { - b.StopTimer() - err := s.Clear(ctx) - require.NoError(b, err) - b.StartTimer() - func() { - var wg sync.WaitGroup - wg.Add(len(data)) - for _, l := range data { - go func(l *pb.SigningRecord) { - defer wg.Done() - err := s.UpdateSigningRecord(ctx, l) - assert.NoError(b, err) - }(l) - } - wg.Wait() - err := s.FlushBulkWriter() - assert.NoError(b, err) - }() - } -} diff --git a/certificate-authority/store/mongodb/store.go b/certificate-authority/store/mongodb/store.go index 72f535d39..7cadf88e5 100644 --- a/certificate-authority/store/mongodb/store.go +++ b/certificate-authority/store/mongodb/store.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/hashicorp/go-multierror" "github.com/plgd-dev/hub/v2/certificate-authority/store" "github.com/plgd-dev/hub/v2/pkg/fsnotify" "github.com/plgd-dev/hub/v2/pkg/log" @@ -16,7 +17,7 @@ import ( type Store struct { *pkgMongo.Store - bulkWriter *bulkWriter + logger log.Logger } var deviceIDKeyQueryIndex = mongo.IndexModel{ @@ -33,22 +34,62 @@ var commonNameKeyQueryIndex = mongo.IndexModel{ }, } +type MongoIterator[T any] struct { + Cursor *mongo.Cursor +} + +func (i *MongoIterator[T]) Next(ctx context.Context, s *T) bool { + if !i.Cursor.Next(ctx) { + return false + } + err := i.Cursor.Decode(s) + return err == nil +} + +func (i *MongoIterator[T]) Err() error { + return i.Cursor.Err() +} + +func processCursor[T any](ctx context.Context, cr *mongo.Cursor, p store.Process[T]) (int, error) { + var errors *multierror.Error + iter := MongoIterator[T]{ + Cursor: cr, + } + count := 0 + for { + var stored T + if !iter.Next(ctx, &stored) { + break + } + err := p(&stored) + if err != nil { + errors = multierror.Append(errors, err) + break + } + count++ + } + errors = multierror.Append(errors, iter.Err()) + errClose := cr.Close(ctx) + errors = multierror.Append(errors, errClose) + return count, errors.ErrorOrNil() +} + func New(ctx context.Context, cfg *Config, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*Store, error) { - certManager, err := client.New(cfg.Mongo.TLS, fileWatcher, logger) + certManager, err := client.New(cfg.Mongo.TLS, fileWatcher, logger, tracerProvider) if err != nil { return nil, fmt.Errorf("could not create cert manager: %w", err) } - m, err := pkgMongo.NewStore(ctx, &cfg.Mongo, certManager.GetTLSConfig(), tracerProvider) + m, err := pkgMongo.NewStoreWithCollections(ctx, &cfg.Mongo, certManager.GetTLSConfig(), tracerProvider, map[string][]mongo.IndexModel{ + signingRecordsCol: {commonNameKeyQueryIndex, deviceIDKeyQueryIndex}, + revocationListCol: nil, + }) if err != nil { certManager.Close() return nil, err } - bulkWriter := newBulkWriter(m.Collection(signingRecordsCol), cfg.BulkWrite.DocumentLimit, cfg.BulkWrite.ThrottleTime, cfg.BulkWrite.Timeout, logger) - s := Store{Store: m, bulkWriter: bulkWriter} - err = s.EnsureIndex(ctx, signingRecordsCol, commonNameKeyQueryIndex, deviceIDKeyQueryIndex) - if err != nil { - certManager.Close() - return nil, err + s := Store{ + Store: m, + logger: logger, } s.SetOnClear(s.clearDatabases) s.AddCloseFunc(certManager.Close) @@ -56,15 +97,12 @@ func New(ctx context.Context, cfg *Config, fileWatcher *fsnotify.Watcher, logger } func (s *Store) clearDatabases(ctx context.Context) error { - return s.Collection(signingRecordsCol).Drop(ctx) + var errs *multierror.Error + errs = multierror.Append(errs, s.Collection(signingRecordsCol).Drop(ctx)) + errs = multierror.Append(errs, s.Collection(revocationListCol).Drop(ctx)) + return errs.ErrorOrNil() } func (s *Store) Close(ctx context.Context) error { - s.bulkWriter.Close() return s.Store.Close(ctx) } - -func (s *Store) FlushBulkWriter() error { - _, err := s.bulkWriter.bulkWrite() - return err -} diff --git a/certificate-authority/store/revocationList.go b/certificate-authority/store/revocationList.go new file mode 100644 index 000000000..02ce79c7c --- /dev/null +++ b/certificate-authority/store/revocationList.go @@ -0,0 +1,83 @@ +package store + +import ( + "errors" + "fmt" + "math/big" + "time" + + "github.com/google/uuid" +) + +const ( + CertificatesKey = "certificates" // must match with RevocationList.Certificates bson tag + IssuedAtKey = "issuedAt" // must match with RevocationListCertificate.IssuedAt bson tag + NumberKey = "number" // must match with RevocationListCertificate.NumberKey bson tag + SerialKey = "serial" // must match with RevocationListCertificate.Serial bson tag + ValidUntilKey = "validUntil" // must match with RevocationListCertificate.ValidUntil bson tag + RevocationKey = "revocation" // must match with RevocationListCertificate.Revocation bson tag +) + +type RevocationListCertificate struct { + // Serial number + Serial string `bson:"serial"` + // Time until the record is valid in Unix nanoseconds timestamp format + ValidUntil int64 `bson:"validUntil,omitempty"` + // Revocation time of the certificate in Unix nanoseconds timestamp format. 0 means that the certificate hasn't been revoked. + Revocation int64 `bson:"revocation"` +} + +func (rlc *RevocationListCertificate) Validate() error { + if rlc.Serial == "" { + return errors.New("serial number not set") + } + if rlc.Revocation == 0 { + return errors.New("revocation time not set") + } + return nil +} + +type RevocationList struct { + // The record ID is determined by applying a formula that utilizes the public key of the issuer, and it is computed as uuid.NewSHA1(uuid.NameSpaceX500, publicKeyRaw). + Id string `bson:"_id"` + // Number is used to populate the X.509 v2 cRLNumber extension in the CRL, which should be a monotonically increasing sequence number for a given + // CRL scope and CRL issuer. + Number string `bson:"number"` + // Time when the CRL was issued in Unix timestamp format + IssuedAt int64 `bson:"issuedAt"` + // Time until the issued CRL is valid in Unix nanoseconds timestamp format + ValidUntil int64 `bson:"validUntil"` + // List of certificates issued by the issuer + Certificates []*RevocationListCertificate `bson:"certificates,omitempty"` +} + +func ParseBigInt(s string) (*big.Int, error) { + var number big.Int + if _, ok := number.SetString(s, 10); !ok { + return nil, fmt.Errorf("invalid numeric string(%v)", s) + } + return &number, nil +} + +// TODO: use some delta to check expiration +func (rl *RevocationList) IsExpired() bool { + return rl.ValidUntil <= time.Now().UnixNano() +} + +func (rl *RevocationList) Validate() error { + if _, err := uuid.Parse(rl.Id); err != nil { + return fmt.Errorf("invalid ID(%v): %w", rl.Id, err) + } + if (rl.IssuedAt == 0 && rl.ValidUntil != 0) || (rl.ValidUntil < rl.IssuedAt) { + return fmt.Errorf("invalid validity period timestamps(from %v to %v)", rl.IssuedAt, rl.ValidUntil) + } + if _, err := ParseBigInt(rl.Number); err != nil { + return err + } + for _, c := range rl.Certificates { + if err := c.Validate(); err != nil { + return err + } + } + return nil +} diff --git a/certificate-authority/store/revocationList_test.go b/certificate-authority/store/revocationList_test.go new file mode 100644 index 000000000..1327df431 --- /dev/null +++ b/certificate-authority/store/revocationList_test.go @@ -0,0 +1,145 @@ +package store_test + +import ( + "testing" + "time" + + "github.com/google/uuid" + "github.com/plgd-dev/hub/v2/certificate-authority/store" + "github.com/stretchr/testify/require" +) + +func TestRevocationListCertificateValidate(t *testing.T) { + tests := []struct { + name string + input store.RevocationListCertificate + wantErr bool + }{ + { + name: "Valid certificate", + input: store.RevocationListCertificate{ + Serial: "12345", + Revocation: time.Now().UnixNano(), + }, + wantErr: false, + }, + { + name: "Missing serial number", + input: store.RevocationListCertificate{ + Serial: "", + Revocation: time.Now().UnixNano(), + }, + wantErr: true, + }, + { + name: "Missing revocation time", + input: store.RevocationListCertificate{ + Serial: "12345", + Revocation: 0, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.input.Validate() + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} + +func TestRevocationListValidate(t *testing.T) { + validCertificate := &store.RevocationListCertificate{ + Serial: "12345", + Revocation: time.Now().UnixNano(), + } + invalidCertificate := &store.RevocationListCertificate{ + Serial: "", + Revocation: time.Now().UnixNano(), + } + + tests := []struct { + name string + input store.RevocationList + wantErr bool + }{ + { + name: "Valid revocation list", + input: store.RevocationList{ + Id: uuid.New().String(), + IssuedAt: time.Now().UnixNano(), + ValidUntil: time.Now().Add(time.Minute).UnixNano(), + Number: "1", + Certificates: []*store.RevocationListCertificate{validCertificate}, + }, + wantErr: false, + }, + { + name: "Valid not-issued revocation list", + input: store.RevocationList{ + Id: uuid.New().String(), + Number: "1", + Certificates: []*store.RevocationListCertificate{validCertificate}, + }, + wantErr: false, + }, + { + name: "Invalid UUID", + input: store.RevocationList{ + Id: "invalid-uuid", + IssuedAt: time.Now().UnixNano(), + ValidUntil: time.Now().Add(time.Minute).UnixNano(), + Number: "1", + Certificates: []*store.RevocationListCertificate{validCertificate}, + }, + wantErr: true, + }, + { + name: "Missing issuedAt time", + input: store.RevocationList{ + Id: uuid.New().String(), + ValidUntil: time.Now().Add(time.Minute).UnixNano(), + Number: "1", + Certificates: []*store.RevocationListCertificate{validCertificate}, + }, + wantErr: true, + }, + { + name: "Missing validUntil time", + input: store.RevocationList{ + Id: uuid.New().String(), + IssuedAt: time.Now().UnixNano(), + Number: "1", + Certificates: []*store.RevocationListCertificate{validCertificate}, + }, + wantErr: true, + }, + { + name: "Invalid certificate in the list", + input: store.RevocationList{ + Id: uuid.New().String(), + IssuedAt: time.Now().UnixNano(), + ValidUntil: time.Now().Add(time.Minute).UnixNano(), + Number: "1", + Certificates: []*store.RevocationListCertificate{invalidCertificate}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.input.Validate() + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} diff --git a/certificate-authority/store/store.go b/certificate-authority/store/store.go index 87bd470f6..bf356cfb9 100644 --- a/certificate-authority/store/store.go +++ b/certificate-authority/store/store.go @@ -3,25 +3,33 @@ package store import ( "context" "errors" + "fmt" "time" + "github.com/google/uuid" "github.com/plgd-dev/hub/v2/certificate-authority/pb" ) -var ErrNotSupported = errors.New("not supported") +var ( + ErrNotSupported = errors.New("not supported") + ErrNotFound = errors.New("no document found") + ErrDuplicateID = errors.New("duplicate ID") +) type ( + Process[T any] func(v *T) error + SigningRecordsQuery = pb.GetSigningRecordsRequest DeleteSigningRecordsQuery = pb.DeleteSigningRecordsRequest -) - -type SigningRecordIter interface { - Next(ctx context.Context, SigningRecord *SigningRecord) bool - Err() error -} + RevokeSigningRecordsQuery = pb.DeleteSigningRecordsRequest -type ( - LoadSigningRecordsFunc = func(ctx context.Context, iter SigningRecordIter) (err error) + UpdateRevocationListQuery struct { + IssuerID string + IssuedAt int64 // 0 is allowed, the timestamp will be generated when the CRL is first issued + ValidUntil int64 // 0 is allowed, the timestamp will be generated when the CRL is first issued + UpdateIfExpired bool + RevokedCertificates []*RevocationListCertificate + } ) type Store interface { @@ -30,11 +38,30 @@ type Store interface { // UpdateSigningRecord updates an existing signing record. If the record does not exist, it will create a new one. UpdateSigningRecord(ctx context.Context, record *SigningRecord) error DeleteSigningRecords(ctx context.Context, ownerID string, query *DeleteSigningRecordsQuery) (int64, error) - LoadSigningRecords(ctx context.Context, ownerID string, query *SigningRecordsQuery, h LoadSigningRecordsFunc) error + LoadSigningRecords(ctx context.Context, ownerID string, query *SigningRecordsQuery, p Process[SigningRecord]) error // DeleteNonDeviceExpiredRecords deletes all expired records that are not associated with a device. // For CqlDB, this is a no-op because expired records are deleted by Cassandra automatically. DeleteNonDeviceExpiredRecords(ctx context.Context, now time.Time) (int64, error) + // Check if the implementation supports the RevocationList feature + SupportsRevocationList() bool + // InsertRevocationLists adds revocations lists to the database + InsertRevocationLists(ctx context.Context, rls ...*RevocationList) error + // UpdateRevocationList updates revocation list number and validity and adds certificates to revocation list. Returns the updated revocation list. + UpdateRevocationList(ctx context.Context, query *UpdateRevocationListQuery) (*RevocationList, error) + // Get valid latest issued or issue a new one revocation list + GetLatestIssuedOrIssueRevocationList(ctx context.Context, issuerID string, validFor time.Duration) (*RevocationList, error) + + // Removed matched signing records and move them to a revocation list. + RevokeSigningRecords(ctx context.Context, ownerID string, query *RevokeSigningRecordsQuery) (int64, error) + Close(ctx context.Context) error } + +func (q *UpdateRevocationListQuery) Validate() error { + if _, err := uuid.Parse(q.IssuerID); err != nil { + return fmt.Errorf("invalid revocation list issuerID(%v): %w", q.IssuerID, err) + } + return nil +} diff --git a/certificate-authority/test/revocationList.go b/certificate-authority/test/revocationList.go new file mode 100644 index 000000000..2fcdb0988 --- /dev/null +++ b/certificate-authority/test/revocationList.go @@ -0,0 +1,88 @@ +package test + +import ( + "context" + "fmt" + "math/rand" + "sort" + "strconv" + "testing" + "time" + + "github.com/plgd-dev/hub/v2/certificate-authority/store" + pkgTime "github.com/plgd-dev/hub/v2/pkg/time" + "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" +) + +var ( + serial0 = rand.Int31() + serials = make(map[int]string) +) + +func GetIssuerID(i int) string { + return fmt.Sprintf("49000000-0000-0000-0000-%012d", i) +} + +func GetCertificateSerial(i int) string { + id, ok := serials[i] + if !ok { + id = strconv.FormatInt(int64(serial0)+int64(i), 10) + serials[i] = id + } + return id +} + +func GetCertificate(c int, rev, exp time.Time) *store.RevocationListCertificate { + return &store.RevocationListCertificate{ + Serial: GetCertificateSerial(c), + ValidUntil: pkgTime.UnixNano(exp), + Revocation: pkgTime.UnixNano(rev), + } +} + +func AddRevocationListToStore(ctx context.Context, t *testing.T, s store.Store, expirationStart time.Time) map[string]*store.RevocationList { + rlm := make(map[string]*store.RevocationList) + c := 0 + for i := range 10 { + now := time.Now() + rlID := GetIssuerID(i) + actual := &store.RevocationList{ + Id: rlID, + IssuedAt: now.UnixNano(), + ValidUntil: now.Add(time.Minute * 10).UnixNano(), + Number: strconv.Itoa(i), + } + exp := expirationStart.Add(time.Duration(i) * time.Hour) + for range 10 { + rlc := GetCertificate(c, now, exp) + actual.Certificates = append(actual.Certificates, rlc) + c++ + } + rlm[rlID] = actual + } + + err := s.InsertRevocationLists(ctx, maps.Values(rlm)...) + require.NoError(t, err) + return rlm +} + +func CheckRevocationList(t *testing.T, expected, actual *store.RevocationList, ignoreRevocationTime bool) { + require.Equal(t, expected.Number, actual.Number) + require.Equal(t, expected.IssuedAt, actual.IssuedAt) + require.Equal(t, expected.ValidUntil, actual.ValidUntil) + require.Len(t, actual.Certificates, len(expected.Certificates)) + sort.Slice(actual.Certificates, func(i, j int) bool { + return actual.Certificates[i].Serial < actual.Certificates[j].Serial + }) + sort.Slice(expected.Certificates, func(i, j int) bool { + return expected.Certificates[i].Serial < expected.Certificates[j].Serial + }) + for i := range actual.Certificates { + require.Equal(t, expected.Certificates[i].Serial, actual.Certificates[i].Serial) + require.Equal(t, expected.Certificates[i].ValidUntil, actual.Certificates[i].ValidUntil) + if !ignoreRevocationTime { + require.Equal(t, expected.Certificates[i].Revocation, actual.Certificates[i].Revocation) + } + } +} diff --git a/certificate-authority/test/service.go b/certificate-authority/test/service.go index dd5669513..560d255d1 100644 --- a/certificate-authority/test/service.go +++ b/certificate-authority/test/service.go @@ -7,6 +7,7 @@ import ( "time" "github.com/plgd-dev/hub/v2/certificate-authority/service" + "github.com/plgd-dev/hub/v2/certificate-authority/service/grpc" "github.com/plgd-dev/hub/v2/certificate-authority/store" storeConfig "github.com/plgd-dev/hub/v2/certificate-authority/store/config" storeCqlDB "github.com/plgd-dev/hub/v2/certificate-authority/store/cqldb" @@ -21,6 +22,26 @@ import ( "go.opentelemetry.io/otel/trace/noop" ) +func MakeHTTPConfig() service.HTTPConfig { + return service.HTTPConfig{ + ExternalAddress: "https://" + config.CERTIFICATE_AUTHORITY_HTTP_HOST, + Addr: config.CERTIFICATE_AUTHORITY_HTTP_HOST, + Server: config.MakeHttpServerConfig(), + } +} + +func MakeCRLConfig() grpc.CRLConfig { + if config.ACTIVE_DATABASE() == database.MongoDB { + return grpc.CRLConfig{ + Enabled: true, + ExpiresIn: time.Hour, + } + } + return grpc.CRLConfig{ + Enabled: false, + } +} + func MakeConfig(t require.TestingT) service.Config { var cfg service.Config @@ -28,14 +49,14 @@ func MakeConfig(t require.TestingT) service.Config { cfg.Log = log.MakeDefaultConfig() cfg.APIs.GRPC = config.MakeGrpcServerConfig(config.CERTIFICATE_AUTHORITY_HOST) - cfg.APIs.HTTP.Addr = config.CERTIFICATE_AUTHORITY_HTTP_HOST - cfg.APIs.HTTP.Server = config.MakeHttpServerConfig() cfg.APIs.GRPC.TLS.ClientCertificateRequired = false + cfg.APIs.HTTP = MakeHTTPConfig() cfg.Signer.CAPool = []urischeme.URIScheme{urischeme.URIScheme(os.Getenv("TEST_ROOT_CA_CERT"))} cfg.Signer.KeyFile = urischeme.URIScheme(os.Getenv("TEST_ROOT_CA_KEY")) cfg.Signer.CertFile = urischeme.URIScheme(os.Getenv("TEST_ROOT_CA_CERT")) cfg.Signer.ValidFrom = "now-1h" cfg.Signer.ExpiresIn = time.Hour * 2 + cfg.Signer.CRL = MakeCRLConfig() cfg.Clients.OpenTelemetryCollector = config.MakeOpenTelemetryCollectorClient() cfg.Clients.Storage = MakeStorageConfig() @@ -88,11 +109,6 @@ func MakeStorageConfig() service.StorageConfig { Database: "certificateAuthority", TLS: config.MakeTLSClientConfig(), }, - BulkWrite: storeMongo.BulkWriteConfig{ - Timeout: time.Minute, - ThrottleTime: time.Millisecond * 500, - DocumentLimit: 1000, - }, }, CqlDB: &storeCqlDB.Config{ Embedded: config.MakeCqlDBConfig(), diff --git a/charts/plgd-hub/README.md b/charts/plgd-hub/README.md index 1ca616b16..930bc7146 100644 --- a/charts/plgd-hub/README.md +++ b/charts/plgd-hub/README.md @@ -79,9 +79,6 @@ global: | certificateauthority.clients.storage.cqlDB.tls.keyFile | string | `nil` | | | certificateauthority.clients.storage.cqlDB.tls.useSystemCAPool | bool | `false` | | | certificateauthority.clients.storage.cqlDB.useHostnameResolution | bool | `true` | Resolve IP address to hostname before validate certificate. If false, the TLS validator will use ip/hostname advertised by the Cassandra node. | -| certificateauthority.clients.storage.mongoDB.bulkWrite.documentLimit | int | `1000` | The maximum number of documents to cache before an immediate write. | -| certificateauthority.clients.storage.mongoDB.bulkWrite.throttleTime | string | `"500ms"` | The amount of time to wait until a record is written to mongodb. Any records collected during the throttle time will also be written. A throttle time of zero writes immediately. If recordLimit is reached, all records are written immediately | -| certificateauthority.clients.storage.mongoDB.bulkWrite.timeout | string | `"1m0s"` | A time limit for write bulk to mongodb. A Timeout of zero means no timeout. | | certificateauthority.clients.storage.mongoDB.database | string | `"certificateAuthorityService"` | | | certificateauthority.clients.storage.mongoDB.maxConnIdleTime | string | `"4m0s"` | | | certificateauthority.clients.storage.mongoDB.maxPoolSize | int | `16` | | diff --git a/charts/plgd-hub/templates/certificate-authority/config.yaml b/charts/plgd-hub/templates/certificate-authority/config.yaml index a72d72326..af49dec39 100644 --- a/charts/plgd-hub/templates/certificate-authority/config.yaml +++ b/charts/plgd-hub/templates/certificate-authority/config.yaml @@ -61,10 +61,6 @@ data: {{- $mongoDbTls := .clients.storage.mongoDB.tls }} {{- include "plgd-hub.internalCertificateConfig" (list $ $mongoDbTls $cert ) | indent 10 }} useSystemCAPool: {{ .clients.storage.mongoDB.tls.useSystemCAPool }} - bulkWrite: - timeout: {{ .clients.storage.mongoDB.bulkWrite.timeout | quote }} - throttleTime: {{ .clients.storage.mongoDB.bulkWrite.throttleTime | quote }} - documentLimit: {{ .clients.storage.mongoDB.bulkWrite.documentLimit }} cqlDB: hosts: {{- include "plgd-hub.cqlDBHosts" (list $ .clients.storage.cqlDB.hosts ) | indent 8 }} diff --git a/charts/plgd-hub/values.yaml b/charts/plgd-hub/values.yaml index 4ec9e6f4b..3fbaaeccf 100644 --- a/charts/plgd-hub/values.yaml +++ b/charts/plgd-hub/values.yaml @@ -2412,13 +2412,6 @@ certificateauthority: keyFile: certFile: useSystemCAPool: false - bulkWrite: - # -- A time limit for write bulk to mongodb. A Timeout of zero means no timeout. - timeout: 1m0s - # -- The amount of time to wait until a record is written to mongodb. Any records collected during the throttle time will also be written. A throttle time of zero writes immediately. If recordLimit is reached, all records are written immediately - throttleTime: 500ms - # -- The maximum number of documents to cache before an immediate write. - documentLimit: 1000 cqlDB: table: signedCertificateRecords hosts: [] diff --git a/cloud2cloud-connector/service/service.go b/cloud2cloud-connector/service/service.go index b7c1d5078..976e93e9b 100644 --- a/cloud2cloud-connector/service/service.go +++ b/cloud2cloud-connector/service/service.go @@ -89,9 +89,9 @@ func newIdentityStoreClient(config IdentityStoreConfig, fileWatcher *fsnotify.Wa return pbIS.NewIdentityStoreClient(isConn.GRPC()), closeIsConn, nil } -func newSubscriber(config natsClient.ConfigSubscriber, fileWatcher *fsnotify.Watcher, logger log.Logger) (*subscriber.Subscriber, func(), error) { +func newSubscriber(config natsClient.ConfigSubscriber, fileWatcher *fsnotify.Watcher, logger log.Logger, tp trace.TracerProvider) (*subscriber.Subscriber, func(), error) { var fl fn.FuncList - nats, err := natsClient.New(config.Config, fileWatcher, logger) + nats, err := natsClient.New(config.Config, fileWatcher, logger, tp) if err != nil { return nil, nil, fmt.Errorf("cannot create nats client: %w", err) } @@ -110,7 +110,7 @@ func newSubscriber(config natsClient.ConfigSubscriber, fileWatcher *fsnotify.Wat func newStore(ctx context.Context, config pkgMongo.Config, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*Store, func(), error) { var fl fn.FuncList - certManager, err := cmClient.New(config.TLS, fileWatcher, logger) + certManager, err := cmClient.New(config.TLS, fileWatcher, logger, tracerProvider) if err != nil { return nil, nil, fmt.Errorf("cannot create cert manager: %w", err) } @@ -175,7 +175,7 @@ func newDevicesSubscription(ctx context.Context, config Config, raClient raServi } fl.AddFunc(closeGrpcClient) - sub, closeSub, err := newSubscriber(config.Clients.Eventbus.NATS, fileWatcher, logger) + sub, closeSub, err := newSubscriber(config.Clients.Eventbus.NATS, fileWatcher, logger, tracerProvider) if err != nil { fl.Execute() return nil, nil, fmt.Errorf("cannot create subscriber: %w", err) diff --git a/cloud2cloud-connector/test/test.go b/cloud2cloud-connector/test/test.go index e07c3ce6d..a3772bbc1 100644 --- a/cloud2cloud-connector/test/test.go +++ b/cloud2cloud-connector/test/test.go @@ -159,7 +159,7 @@ func NewMongoStore(t *testing.T) (*mongodb.Store, func()) { fileWatcher, err := fsnotify.NewWatcher(logger) require.NoError(t, err) - certManager, err := cmClient.New(cfg.Clients.Storage.MongoDB.TLS, fileWatcher, logger) + certManager, err := cmClient.New(cfg.Clients.Storage.MongoDB.TLS, fileWatcher, logger, noop.NewTracerProvider()) require.NoError(t, err) ctx := context.Background() diff --git a/cloud2cloud-gateway/service/emitEvent.go b/cloud2cloud-gateway/service/emitEvent.go index c1e453438..79084b69b 100644 --- a/cloud2cloud-gateway/service/emitEvent.go +++ b/cloud2cloud-gateway/service/emitEvent.go @@ -13,6 +13,7 @@ import ( "github.com/plgd-dev/hub/v2/pkg/fsnotify" "github.com/plgd-dev/hub/v2/pkg/log" cmClient "github.com/plgd-dev/hub/v2/pkg/security/certManager/client" + "go.opentelemetry.io/otel/trace" ) type ( @@ -88,8 +89,8 @@ func makeEmitEventRequest(ctx context.Context, eventType events.EventType, s sto return req, nil } -func createEmitEventFunc(cfg cmClient.Config, timeout time.Duration, fileWatcher *fsnotify.Watcher, logger log.Logger) (emitEventFunc, func(), error) { - certManager, err := cmClient.New(cfg, fileWatcher, logger) +func createEmitEventFunc(cfg cmClient.Config, timeout time.Duration, fileWatcher *fsnotify.Watcher, logger log.Logger, tp trace.TracerProvider) (emitEventFunc, func(), error) { + certManager, err := cmClient.New(cfg, fileWatcher, logger, tp) if err != nil { return nil, nil, fmt.Errorf("cannot create cert manager: %w", err) } diff --git a/cloud2cloud-gateway/service/service.go b/cloud2cloud-gateway/service/service.go index 44f38897b..d55894234 100644 --- a/cloud2cloud-gateway/service/service.go +++ b/cloud2cloud-gateway/service/service.go @@ -153,9 +153,9 @@ func newGrpcGatewayClient(config GrpcGatewayConfig, fileWatcher *fsnotify.Watche return client, fl.ToFunction(), nil } -func newResourceSubscriber(config Config, fileWatcher *fsnotify.Watcher, logger log.Logger) (*subscriber.Subscriber, func(), error) { +func newResourceSubscriber(config Config, fileWatcher *fsnotify.Watcher, logger log.Logger, tp trace.TracerProvider) (*subscriber.Subscriber, func(), error) { var fl fn.FuncList - nats, err := natsClient.New(config.Clients.Eventbus.NATS.Config, fileWatcher, logger) + nats, err := natsClient.New(config.Clients.Eventbus.NATS.Config, fileWatcher, logger, tp) if err != nil { return nil, nil, fmt.Errorf("cannot create nats client: %w", err) } @@ -199,7 +199,7 @@ func newResourceAggregateClient(config ResourceAggregateConfig, subscriber *subs func newSubscriptionManager(ctx context.Context, cfg Config, gwClient pbGRPC.GrpcGatewayClient, emitEvent emitEventFunc, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*SubscriptionManager, func(), error) { var fl fn.FuncList - certManager, err := cmClient.New(cfg.Clients.Storage.MongoDB.TLS, fileWatcher, logger) + certManager, err := cmClient.New(cfg.Clients.Storage.MongoDB.TLS, fileWatcher, logger, tracerProvider) if err != nil { return nil, nil, fmt.Errorf("cannot create cert manager: %w", err) } @@ -259,7 +259,7 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg } listener.AddCloseFunc(closeGwClient) - subscriber, closeSubscriberFn, err := newResourceSubscriber(config, fileWatcher, logger) + subscriber, closeSubscriberFn, err := newResourceSubscriber(config, fileWatcher, logger, tracerProvider) if err != nil { closeListener() return nil, fmt.Errorf("cannot create resource subscriber: %w", err) @@ -273,7 +273,7 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg } listener.AddCloseFunc(closeRaClient) - emitEvent, closeEmitEventFn, err := createEmitEventFunc(config.Clients.Subscription.HTTP.TLS, config.Clients.Subscription.HTTP.EmitEventTimeout, fileWatcher, logger) + emitEvent, closeEmitEventFn, err := createEmitEventFunc(config.Clients.Subscription.HTTP.TLS, config.Clients.Subscription.HTTP.EmitEventTimeout, fileWatcher, logger, tracerProvider) if err != nil { closeListener() return nil, fmt.Errorf("cannot create emit event function: %w", err) diff --git a/cloud2cloud-gateway/store/mongodb/subscription_test.go b/cloud2cloud-gateway/store/mongodb/subscription_test.go index d4be5ce64..d3ff28f88 100644 --- a/cloud2cloud-gateway/store/mongodb/subscription_test.go +++ b/cloud2cloud-gateway/store/mongodb/subscription_test.go @@ -24,7 +24,7 @@ func newTestStore(t *testing.T) (*mongodb.Store, func()) { fileWatcher, err := fsnotify.NewWatcher(logger) require.NoError(t, err) - certManager, err := client.New(cfg.Clients.Storage.MongoDB.TLS, fileWatcher, logger) + certManager, err := client.New(cfg.Clients.Storage.MongoDB.TLS, fileWatcher, logger, noop.NewTracerProvider()) require.NoError(t, err) ctx := context.Background() diff --git a/coap-gateway/config.yaml b/coap-gateway/config.yaml index 6d5d4dbba..d3b8aaa18 100644 --- a/coap-gateway/config.yaml +++ b/coap-gateway/config.yaml @@ -32,6 +32,19 @@ apis: certFile: "/secrets/public/cert.crt" clientCertificateRequired: true identityPropertiesRequired: true + crl: + enabled: true + http: + maxIdleConns: 16 + maxConnsPerHost: 32 + maxIdleConnsPerHost: 16 + idleConnTimeout: "30s" + timeout: "10s" + tls: + caPool: "/secrets/public/rootca.crt" + keyFile: "/secrets/private/cert.key" + certFile: "/secrets/public/cert.crt" + useSystemCAPool: false authorization: ownerClaim: "sub" deviceIDClaim: "" diff --git a/coap-gateway/service/auth.go b/coap-gateway/service/auth.go index 72003e5b5..3b8d9fee7 100644 --- a/coap-gateway/service/auth.go +++ b/coap-gateway/service/auth.go @@ -18,7 +18,10 @@ import ( pkgX509 "github.com/plgd-dev/hub/v2/pkg/security/x509" ) -type Interceptor = func(ctx context.Context, code codes.Code, path string) (context.Context, error) +type ( + Interceptor = func(ctx context.Context, code codes.Code, path string) (context.Context, error) + VerifyByCRL = func(context.Context, *x509.Certificate, []string) error +) func newAuthInterceptor() Interceptor { return func(ctx context.Context, _ codes.Code, path string) (context.Context, error) { @@ -80,7 +83,7 @@ func (s *Service) VerifyAndResolveDeviceID(tlsDeviceID, paramDeviceID string, cl return deviceID, nil } -func verifyChain(chain []*x509.Certificate, capool *x509.CertPool, identityPropertiesRequired bool) error { +func verifyChain(ctx context.Context, chain []*x509.Certificate, capool *x509.CertPool, identityPropertiesRequired bool, verifyByCRL VerifyByCRL) error { if len(chain) == 0 { return errors.New("certificate chain is empty") } @@ -115,17 +118,26 @@ func verifyChain(chain []*x509.Certificate, capool *x509.CertPool, identityPrope if !ekuHasServer { return errors.New("the extended key usage field in the device certificate does not contain server authentication") } - if !identityPropertiesRequired { - return nil + + if identityPropertiesRequired { + _, err = coap.GetDeviceIDFromIdentityCertificate(certificate) + if err != nil { + return fmt.Errorf("the device ID is not part of the certificate's common name: %w", err) + } } - _, err = coap.GetDeviceIDFromIdentityCertificate(certificate) - if err != nil { - return fmt.Errorf("the device ID is not part of the certificate's common name: %w", err) + + if len(certificate.CRLDistributionPoints) > 0 { + if verifyByCRL == nil { + return errors.New("failed to check certificate validity by CRL") + } + if err = verifyByCRL(ctx, certificate, certificate.CRLDistributionPoints); err != nil { + return err + } } return nil } -func MakeGetConfigForClient(tlsCfg *tls.Config, identityPropertiesRequired bool) tls.Config { +func MakeGetConfigForClient(ctx context.Context, tlsCfg *tls.Config, identityPropertiesRequired bool, verifyByCRL VerifyByCRL) tls.Config { return tls.Config{ GetCertificate: tlsCfg.GetCertificate, MinVersion: tlsCfg.MinVersion, @@ -134,7 +146,7 @@ func MakeGetConfigForClient(tlsCfg *tls.Config, identityPropertiesRequired bool) VerifyPeerCertificate: func(_ [][]byte, chains [][]*x509.Certificate) error { var errs *multierror.Error for _, chain := range chains { - err := verifyChain(chain, tlsCfg.ClientCAs, identityPropertiesRequired) + err := verifyChain(ctx, chain, tlsCfg.ClientCAs, identityPropertiesRequired, verifyByCRL) if err == nil { return nil } diff --git a/coap-gateway/service/auth_test.go b/coap-gateway/service/auth_test.go new file mode 100644 index 000000000..35456d0ae --- /dev/null +++ b/coap-gateway/service/auth_test.go @@ -0,0 +1,75 @@ +//go:build test +// +build test + +package service_test + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "testing" + "time" + + pbCA "github.com/plgd-dev/hub/v2/certificate-authority/pb" + coapgwTest "github.com/plgd-dev/hub/v2/coap-gateway/test" + "github.com/plgd-dev/hub/v2/pkg/config/database" + pkgGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" + "github.com/plgd-dev/hub/v2/test" + "github.com/plgd-dev/hub/v2/test/config" + oauthTest "github.com/plgd-dev/hub/v2/test/oauth-server/test" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +func TestCertificateWithCRL(t *testing.T) { + if config.ACTIVE_DATABASE() == database.CqlDB { + t.Skip("revocation list not supported for CqlDB") + } + coapgwCfg := coapgwTest.MakeConfig(t) + coapgwCfg.APIs.COAP.TLS.Enabled = new(bool) + *coapgwCfg.APIs.COAP.TLS.Enabled = true + coapgwCfg.APIs.COAP.TLS.Embedded.ClientCertificateRequired = true + shutdown := setUp(t, coapgwCfg) + defer shutdown() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30*20) + defer cancel() + tokenWithoutDeviceID := oauthTest.GetDefaultAccessToken(t) + ctx = pkgGrpc.CtxWithToken(ctx, tokenWithoutDeviceID) + conn, err := grpc.NewClient(config.CERTIFICATE_AUTHORITY_HOST, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ + RootCAs: test.GetRootCertificatePool(t), + }))) + require.NoError(t, err) + caClient := pbCA.NewCertificateAuthorityClient(conn) + + signerKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + cg := coapgwTest.NewCACertificateGenerator(caClient, signerKey) + + crt, err := cg.GetIdentityCertificate(ctx, CertIdentity) + require.NoError(t, err) + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{crt}, + InsecureSkipVerify: true, + } + co := testCoapDialWithHandler(t, makeTestCoapHandler(t), WithTLSConfig(tlsConfig)) + require.NotEmpty(t, co) + testSignUp(t, CertIdentity, co) + _ = co.Close() + + // revoke all certs for device + resp, err := caClient.DeleteSigningRecords(ctx, &pbCA.DeleteSigningRecordsRequest{ + DeviceIdFilter: []string{CertIdentity}, + }) + require.NoError(t, err) + require.Equal(t, int64(1), resp.Count) + // sign-up with revoked cert should fail + co = testCoapDialWithHandler(t, makeTestCoapHandler(t), WithTLSConfig(tlsConfig)) + require.NotEmpty(t, co) + _, err = doSignUp(t, CertIdentity, co) + _ = co.Close() + require.Error(t, err) +} diff --git a/coap-gateway/service/config.go b/coap-gateway/service/config.go index 9c3695a74..6673ee6e5 100644 --- a/coap-gateway/service/config.go +++ b/coap-gateway/service/config.go @@ -133,10 +133,6 @@ type InjectedCOAPConfig struct { TLSConfig InjectedTLSConfig `yaml:"tls" json:"tls"` } -func (c *InjectedCOAPConfig) Validate() error { - return nil -} - type DeviceTwinConfig struct { MaxETagsCountInRequest uint32 `yaml:"maxETagsCountInRequest" json:"maxETagsCountInRequest"` UseETags bool `yaml:"useETags" json:"useETags"` @@ -183,9 +179,6 @@ func (c *COAPConfigMarshalerUnmarshaler) Validate() error { if err := c.COAPConfig.Validate(); err != nil { return err } - if err := c.InjectedCOAPConfig.Validate(); err != nil { - return err - } if !c.InjectedCOAPConfig.TLSConfig.IdentityPropertiesRequired && c.Authorization.DeviceIDClaim != "" { return fmt.Errorf("tls.identityPropertiesRequired('%v') - %w", c.InjectedCOAPConfig.TLSConfig.IdentityPropertiesRequired, errors.New("combination with authorization.deviceIDClaim is not supported")) } diff --git a/coap-gateway/service/service.go b/coap-gateway/service/service.go index beefc5bd2..4caa06b86 100644 --- a/coap-gateway/service/service.go +++ b/coap-gateway/service/service.go @@ -236,7 +236,7 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg return nil, fmt.Errorf("cannot create job queue %w", err) } - nats, err := natsClient.New(config.Clients.Eventbus.NATS.Config, fileWatcher, logger) + nats, err := natsClient.New(config.Clients.Eventbus.NATS.Config, fileWatcher, logger, tracerProvider) if err != nil { otelClient.Close() queue.Release() @@ -621,8 +621,8 @@ func (s *Service) createServices(fileWatcher *fsnotify.Watcher, logger log.Logge coapService.WithOnNewConnection(s.coapConnOnNew), coapService.WithOnInactivityConnection(s.onInactivityConnection), coapService.WithMessagePool(s.messagePool), - coapService.WithOverrideTLS(func(cfg *tls.Config) *tls.Config { - tlsCfg := MakeGetConfigForClient(cfg, s.config.APIs.COAP.InjectedCOAPConfig.TLSConfig.IdentityPropertiesRequired) + coapService.WithOverrideTLS(func(cfg *tls.Config, verifyByCRL VerifyByCRL) *tls.Config { + tlsCfg := MakeGetConfigForClient(s.ctx, cfg, s.config.APIs.COAP.InjectedCOAPConfig.TLSConfig.IdentityPropertiesRequired, verifyByCRL) return &tlsCfg }), ) diff --git a/coap-gateway/service/signIn_test.go b/coap-gateway/service/signIn_test.go index c93de8701..49eb9b8c5 100644 --- a/coap-gateway/service/signIn_test.go +++ b/coap-gateway/service/signIn_test.go @@ -20,7 +20,7 @@ import ( coapgwTest "github.com/plgd-dev/hub/v2/coap-gateway/test" "github.com/plgd-dev/hub/v2/coap-gateway/uri" "github.com/plgd-dev/hub/v2/grpc-gateway/pb" - kitNetGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" + pkgGrpc "github.com/plgd-dev/hub/v2/pkg/net/grpc" test "github.com/plgd-dev/hub/v2/test" "github.com/plgd-dev/hub/v2/test/config" oauthTest "github.com/plgd-dev/hub/v2/test/oauth-server/test" @@ -73,7 +73,7 @@ func TestSignInDeviceSubscriptionHandler(t *testing.T) { shutdown := setUp(t) defer shutdown() - ctx := kitNetGrpc.CtxWithToken(context.Background(), oauthTest.GetDefaultAccessToken(t)) + ctx := pkgGrpc.CtxWithToken(context.Background(), oauthTest.GetDefaultAccessToken(t)) conn, err := grpc.NewClient(config.GRPC_GW_HOST, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ RootCAs: test.GetRootCertificatePool(t), }))) @@ -142,7 +142,7 @@ func TestDontCreateObservationAfterRefreshTokenAndSignIn(t *testing.T) { h := makeTestCoapHandler(t) observedPath := make(map[string]struct{}) - co := testCoapDialWithHandler(t, "", true, true, time.Now().Add(time.Minute), func(w *responsewriter.ResponseWriter[*coapTcpClient.Conn], r *pool.Message) { + co := testCoapDialWithHandler(t, func(w *responsewriter.ResponseWriter[*coapTcpClient.Conn], r *pool.Message) { if r.Code() != coapCodes.GET { h(w, r) return @@ -160,7 +160,7 @@ func TestDontCreateObservationAfterRefreshTokenAndSignIn(t *testing.T) { } else { require.NoError(t, errors.New("cannot observe the same resource twice")) } - }) + }, WithGenerateTLS("", true, time.Now().Add(time.Minute))) if co == nil { return } diff --git a/coap-gateway/service/utils_test.go b/coap-gateway/service/utils_test.go index 2a11aa712..a49cce1dc 100644 --- a/coap-gateway/service/utils_test.go +++ b/coap-gateway/service/utils_test.go @@ -6,12 +6,8 @@ package service_test import ( "bytes" "context" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" "crypto/tls" "crypto/x509" - "encoding/pem" "errors" "io" "os" @@ -19,7 +15,6 @@ import ( "testing" "time" - "github.com/plgd-dev/device/v2/pkg/security/generateCertificate" "github.com/plgd-dev/device/v2/schema" "github.com/plgd-dev/device/v2/schema/interfaces" "github.com/plgd-dev/device/v2/schema/resources" @@ -31,6 +26,7 @@ import ( "github.com/plgd-dev/go-coap/v3/tcp" coapTcpClient "github.com/plgd-dev/go-coap/v3/tcp/client" "github.com/plgd-dev/hub/v2/coap-gateway/service" + "github.com/plgd-dev/hub/v2/coap-gateway/test" coapgwTest "github.com/plgd-dev/hub/v2/coap-gateway/test" "github.com/plgd-dev/hub/v2/coap-gateway/uri" pkgX509 "github.com/plgd-dev/hub/v2/pkg/security/x509" @@ -100,7 +96,7 @@ func testValidateResp(t *testing.T, test testEl, resp *pool.Message) { } } -func testSignUp(t *testing.T, deviceID string, co *coapTcpClient.Conn) service.CoapSignUpResponse { +func doSignUp(t *testing.T, deviceID string, co *coapTcpClient.Conn) (*pool.Message, error) { code := oauthTest.GetDefaultDeviceAuthorizationCode(t, deviceID) signUpReq := service.CoapSignUpRequest{ DeviceID: deviceID, @@ -123,7 +119,11 @@ func testSignUp(t *testing.T, deviceID string, co *coapTcpClient.Conn) service.C req.SetContentFormat(message.AppOcfCbor) req.SetBody(bytes.NewReader(inputCbor)) - resp, err := co.Do(req) + return co.Do(req) +} + +func testSignUp(t *testing.T, deviceID string, co *coapTcpClient.Conn) service.CoapSignUpResponse { + resp, err := doSignUp(t, deviceID, co) require.NoError(t, err) defer co.ReleaseMessage(resp) @@ -386,43 +386,65 @@ func makeTestCoapHandler(t *testing.T) func(w *responsewriter.ResponseWriter[*co } } -func testCoapDial(t *testing.T, deviceID string, withTLS, identityCert bool, validTo time.Time) *coapTcpClient.Conn { - return testCoapDialWithHandler(t, deviceID, withTLS, identityCert, validTo, makeTestCoapHandler(t)) +type testCoapDialConfig struct { + generateTLS *struct { + deviceID string + identityCert bool + validTo time.Time + } + tlsConfig *tls.Config } -func testCoapDialWithHandler(t *testing.T, deviceID string, withTLS, identityCert bool, validTo time.Time, h func(w *responsewriter.ResponseWriter[*coapTcpClient.Conn], r *pool.Message)) *coapTcpClient.Conn { - var tlsConfig *tls.Config +type option interface { + apply(*testCoapDialConfig) +} - if withTLS { - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) +type optionFunc func(*testCoapDialConfig) + +func (o optionFunc) apply(c *testCoapDialConfig) { + o(c) +} + +func WithGenerateTLS(deviceID string, identityCert bool, validTo time.Time) option { + return optionFunc(func(cfg *testCoapDialConfig) { + var generateTLS struct { + deviceID string + identityCert bool + validTo time.Time + } + generateTLS.deviceID = deviceID + generateTLS.identityCert = identityCert + generateTLS.validTo = validTo + cfg.generateTLS = &generateTLS + }) +} + +func WithTLSConfig(tlsConfig *tls.Config) option { + return optionFunc(func(cfg *testCoapDialConfig) { + cfg.tlsConfig = tlsConfig + }) +} + +func testCoapDialWithHandler(t *testing.T, h func(w *responsewriter.ResponseWriter[*coapTcpClient.Conn], r *pool.Message), opts ...option) *coapTcpClient.Conn { + c := &testCoapDialConfig{} + for _, opt := range opts { + opt.apply(c) + } + tlsConfig := c.tlsConfig + if c.generateTLS != nil { signerCert, err := pkgX509.ReadX509(os.Getenv("TEST_ROOT_CA_CERT")) require.NoError(t, err) signerKey, err := pkgX509.ReadPrivateKey(os.Getenv("TEST_ROOT_CA_KEY")) require.NoError(t, err) - - var certData []byte - - if identityCert { - certData, err = generateCertificate.GenerateIdentityCert(generateCertificate.Configuration{ - ValidFrom: time.Now().Add(-time.Hour).Format(time.RFC3339), - ValidFor: time.Until(validTo) + time.Hour, - }, deviceID, priv, signerCert, signerKey) + cg := test.NewLocalCertificateGenerator(signerCert, signerKey) + var crt tls.Certificate + if c.generateTLS.identityCert { + crt, err = cg.GetIdentityCertificate(c.generateTLS.deviceID, c.generateTLS.validTo) } else { - c := generateCertificate.Configuration{ - ValidFrom: time.Now().Add(-time.Hour).Format(time.RFC3339), - ValidFor: time.Until(validTo) + time.Hour, - } - c.Subject.CommonName = "non-identity-cert" - c.ExtensionKeyUsages = []string{"client", "server"} - certData, err = generateCertificate.GenerateCert(c, priv, signerCert, signerKey) + crt, err = cg.GetCertificate(c.generateTLS.validTo) } require.NoError(t, err) - b, err := x509.MarshalECPrivateKey(priv) - require.NoError(t, err) - key := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: b}) - crt, err := tls.X509KeyPair(certData, key) - require.NoError(t, err) + caPool := x509.NewCertPool() for _, c := range signerCert { caPool.AddCert(c) @@ -448,21 +470,13 @@ func testCoapDialWithHandler(t *testing.T, deviceID string, withTLS, identityCer for _, cert := range certs[1:] { intermediateCAPool.AddCert(cert) } - caPool := x509.NewCertPool() - for _, c := range signerCert { - caPool.AddCert(c) - } _, err := certs[0].Verify(x509.VerifyOptions{ Roots: caPool, Intermediates: intermediateCAPool, CurrentTime: time.Now(), KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, }) - if err != nil { - return err - } - - return nil + return err }, } } @@ -471,6 +485,14 @@ func testCoapDialWithHandler(t *testing.T, deviceID string, withTLS, identityCer return conn } +func testCoapDial(t *testing.T, deviceID string, withTLS, identityCert bool, validTo time.Time) *coapTcpClient.Conn { + var opts []option + if withTLS { + opts = append(opts, WithGenerateTLS(deviceID, identityCert, validTo)) + } + return testCoapDialWithHandler(t, makeTestCoapHandler(t), opts...) +} + func setUp(t *testing.T, coapgwCfgs ...service.Config) func() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() diff --git a/coap-gateway/test/certificates.go b/coap-gateway/test/certificates.go new file mode 100644 index 000000000..813717982 --- /dev/null +++ b/coap-gateway/test/certificates.go @@ -0,0 +1,102 @@ +package test + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "time" + + "github.com/plgd-dev/device/v2/pkg/security/generateCertificate" + "github.com/plgd-dev/hub/v2/certificate-authority/pb" +) + +type LocalCertificateGenerator struct { + signerCACertificate []*x509.Certificate + signerCAKey *ecdsa.PrivateKey +} + +func NewLocalCertificateGenerator(sc []*x509.Certificate, sk *ecdsa.PrivateKey) *LocalCertificateGenerator { + return &LocalCertificateGenerator{ + signerCACertificate: sc, + signerCAKey: sk, + } +} + +func getTLSCertificate(certPEMBlock []byte, pk *ecdsa.PrivateKey) (tls.Certificate, error) { + b, err := x509.MarshalECPrivateKey(pk) + key := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: b}) + if err != nil { + return tls.Certificate{}, err + } + crt, err := tls.X509KeyPair(certPEMBlock, key) + if err != nil { + return tls.Certificate{}, err + } + return crt, nil +} + +func (g *LocalCertificateGenerator) getCertificate(identityCert bool, deviceID string, validTo time.Time) (tls.Certificate, error) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, err + } + var certData []byte + if identityCert { + certData, err = generateCertificate.GenerateIdentityCert(generateCertificate.Configuration{ + ValidFrom: time.Now().Add(-time.Hour).Format(time.RFC3339), + ValidFor: time.Until(validTo) + time.Hour, + }, deviceID, priv, g.signerCACertificate, g.signerCAKey) + } else { + c := generateCertificate.Configuration{ + ValidFrom: time.Now().Add(-time.Hour).Format(time.RFC3339), + ValidFor: time.Until(validTo) + time.Hour, + } + c.Subject.CommonName = "non-identity-cert" + c.ExtensionKeyUsages = []string{"client", "server"} + certData, err = generateCertificate.GenerateCert(c, priv, g.signerCACertificate, g.signerCAKey) + } + if err != nil { + return tls.Certificate{}, err + } + return getTLSCertificate(certData, priv) +} + +func (g *LocalCertificateGenerator) GetIdentityCertificate(deviceID string, validTo time.Time) (tls.Certificate, error) { + return g.getCertificate(true, deviceID, validTo) +} + +func (g *LocalCertificateGenerator) GetCertificate(validTo time.Time) (tls.Certificate, error) { + return g.getCertificate(false, "", validTo) +} + +type CACertificateGenerator struct { + caClient pb.CertificateAuthorityClient + signerKey *ecdsa.PrivateKey +} + +func NewCACertificateGenerator(caClient pb.CertificateAuthorityClient, signerKey *ecdsa.PrivateKey) *CACertificateGenerator { + return &CACertificateGenerator{ + caClient: caClient, + signerKey: signerKey, + } +} + +func (c *CACertificateGenerator) GetIdentityCertificate(ctx context.Context, deviceID string) (tls.Certificate, error) { + csr, err := generateCertificate.GenerateIdentityCSR(generateCertificate.Configuration{}, deviceID, c.signerKey) + if err != nil { + return tls.Certificate{}, fmt.Errorf("cannot generate identity csr: %w", err) + } + + resp, err := c.caClient.SignIdentityCertificate(ctx, &pb.SignCertificateRequest{ + CertificateSigningRequest: csr, + }) + if err != nil { + return tls.Certificate{}, fmt.Errorf("certificate authority failed to sign certificate: %w", err) + } + return getTLSCertificate(resp.GetCertificate(), c.signerKey) +} diff --git a/coap-gateway/test/test.go b/coap-gateway/test/test.go index 206d70111..97ce9f235 100644 --- a/coap-gateway/test/test.go +++ b/coap-gateway/test/test.go @@ -1,4 +1,4 @@ -package service +package test import ( "context" @@ -44,6 +44,9 @@ func MakeConfig(t require.TestingT) service.Config { cfg.APIs.COAP.TLS.Embedded.ClientCertificateRequired = false cfg.APIs.COAP.TLS.Embedded.CertFile = urischeme.URIScheme(os.Getenv("TEST_COAP_GW_CERT_FILE")) cfg.APIs.COAP.TLS.Embedded.KeyFile = urischeme.URIScheme(os.Getenv("TEST_COAP_GW_KEY_FILE")) + cfg.APIs.COAP.TLS.Embedded.CRL.Enabled = true + httpClientConfig := config.MakeHttpClientConfig() + cfg.APIs.COAP.TLS.Embedded.CRL.HTTP = &httpClientConfig cfg.APIs.COAP.Authorization = service.AuthorizationConfig{ OwnerClaim: config.OWNER_CLAIM, Providers: []service.ProvidersConfig{ diff --git a/device-provisioning-service/pb/hub.go b/device-provisioning-service/pb/hub.go index 3343f786e..3ca32bcaa 100644 --- a/device-provisioning-service/pb/hub.go +++ b/device-provisioning-service/pb/hub.go @@ -13,8 +13,8 @@ import ( "github.com/plgd-dev/hub/v2/device-provisioning-service/uri" "github.com/plgd-dev/hub/v2/pkg/config/property/urischeme" "github.com/plgd-dev/hub/v2/pkg/net/grpc/client" - pkgHttpClient "github.com/plgd-dev/hub/v2/pkg/net/http/client" pkgCertManagerClient "github.com/plgd-dev/hub/v2/pkg/security/certManager/client" + pkgTls "github.com/plgd-dev/hub/v2/pkg/security/tls" "github.com/plgd-dev/hub/v2/pkg/strings" "github.com/plgd-dev/kit/v2/security" ) @@ -55,7 +55,7 @@ func (c *AuthorizationProviderConfig) ToConfig() clientcredentials.Config { ClientSecretFile: urischeme.URIScheme(c.GetClientSecret()), Scopes: c.GetScopes(), Audience: c.GetAudience(), - HTTP: pkgHttpClient.Config{ + HTTP: pkgTls.HTTPConfig{ MaxIdleConns: int(c.GetHttp().GetMaxIdleConns()), MaxConnsPerHost: int(c.GetHttp().GetMaxConnsPerHost()), MaxIdleConnsPerHost: int(c.GetHttp().GetMaxIdleConnsPerHost()), diff --git a/device-provisioning-service/security/oauth/clientcredentials/cache.go b/device-provisioning-service/security/oauth/clientcredentials/cache.go index 76b4ebce4..6ed426e25 100644 --- a/device-provisioning-service/security/oauth/clientcredentials/cache.go +++ b/device-provisioning-service/security/oauth/clientcredentials/cache.go @@ -10,6 +10,7 @@ import ( "github.com/plgd-dev/hub/v2/pkg/fsnotify" "github.com/plgd-dev/hub/v2/pkg/log" "github.com/plgd-dev/hub/v2/pkg/net/http/client" + cmClient "github.com/plgd-dev/hub/v2/pkg/security/certManager/client" "github.com/plgd-dev/hub/v2/pkg/security/jwt" "github.com/plgd-dev/hub/v2/pkg/security/openid" "github.com/plgd-dev/hub/v2/pkg/sync/task/future" @@ -32,7 +33,7 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg if err != nil { return nil, fmt.Errorf("invalid OAuth client credential config: %w", err) } - httpClient, err := client.New(config.HTTP, fileWatcher, logger, tracerProvider) + httpClient, err := cmClient.NewHTTPClient(&config.HTTP, fileWatcher, logger, tracerProvider) if err != nil { return nil, err } diff --git a/device-provisioning-service/security/oauth/clientcredentials/config.go b/device-provisioning-service/security/oauth/clientcredentials/config.go index bb414868a..18b96286f 100644 --- a/device-provisioning-service/security/oauth/clientcredentials/config.go +++ b/device-provisioning-service/security/oauth/clientcredentials/config.go @@ -5,7 +5,7 @@ import ( "net/url" "github.com/plgd-dev/hub/v2/pkg/config/property/urischeme" - "github.com/plgd-dev/hub/v2/pkg/net/http/client" + pkgTls "github.com/plgd-dev/hub/v2/pkg/security/tls" "golang.org/x/oauth2/clientcredentials" ) @@ -17,7 +17,7 @@ type Config struct { TokenURL string `yaml:"-" json:"tokenUrl"` Audience string `yaml:"audience" json:"audience"` ClientSecret string `yaml:"-" json:"clientSecret"` - HTTP client.Config `yaml:"http" json:"http"` + HTTP pkgTls.HTTPConfig `yaml:"http" json:"http"` } func (c *Config) Validate() error { diff --git a/device-provisioning-service/service/clients.go b/device-provisioning-service/service/clients.go index 1868aa221..6874c48ac 100644 --- a/device-provisioning-service/service/clients.go +++ b/device-provisioning-service/service/clients.go @@ -38,7 +38,7 @@ func newCertificateAuthorityClient(config client.Config, fileWatcher *fsnotify.W func NewStore(ctx context.Context, config mongodb.Config, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*mongodb.Store, func(), error) { var fl fn.FuncList - certManager, err := cmClient.New(config.Mongo.TLS, fileWatcher, logger) + certManager, err := cmClient.New(config.Mongo.TLS, fileWatcher, logger, tracerProvider) if err != nil { return nil, nil, fmt.Errorf("cannot create cert manager: %w", err) } diff --git a/device-provisioning-service/service/config.go b/device-provisioning-service/service/config.go index fa840de27..15dac0b0c 100644 --- a/device-provisioning-service/service/config.go +++ b/device-provisioning-service/service/config.go @@ -16,8 +16,8 @@ import ( pkgCoapService "github.com/plgd-dev/hub/v2/pkg/net/coap/service" "github.com/plgd-dev/hub/v2/pkg/net/grpc/client" pkgHttp "github.com/plgd-dev/hub/v2/pkg/net/http" - pkgHttpClient "github.com/plgd-dev/hub/v2/pkg/net/http/client" pkgCertManagerClient "github.com/plgd-dev/hub/v2/pkg/security/certManager/client" + pkgTls "github.com/plgd-dev/hub/v2/pkg/security/tls" "github.com/plgd-dev/hub/v2/pkg/strings" ) @@ -253,24 +253,24 @@ type AuthorizationProviderConfig struct { clientcredentials.Config `yaml:",inline"` } -func HTTPConfigToProto(cfg pkgHttpClient.Config) (*pb.HttpConfig, error) { - tls, err := TLSConfigToProto(cfg.TLS) +func HTTPConfigToProto(cfg pkgTls.HTTPConfigurer) (*pb.HttpConfig, error) { + tls, err := TLSConfigToProto(cfg.GetTLS()) if err != nil { return nil, err } return &pb.HttpConfig{ - MaxIdleConns: math.CastTo[uint32](cfg.MaxIdleConns), - MaxConnsPerHost: math.CastTo[uint32](cfg.MaxConnsPerHost), - MaxIdleConnsPerHost: math.CastTo[uint32](cfg.MaxIdleConnsPerHost), - IdleConnTimeout: cfg.IdleConnTimeout.Nanoseconds(), - Timeout: cfg.Timeout.Nanoseconds(), + MaxIdleConns: math.CastTo[uint32](cfg.GetMaxIdleConns()), + MaxConnsPerHost: math.CastTo[uint32](cfg.GetMaxConnsPerHost()), + MaxIdleConnsPerHost: math.CastTo[uint32](cfg.GetMaxIdleConnsPerHost()), + IdleConnTimeout: cfg.GetIdleConnTimeout().Nanoseconds(), + Timeout: cfg.GetTimeout().Nanoseconds(), Tls: tls, }, nil } func (c *AuthorizationProviderConfig) ToProto() (*pb.AuthorizationProviderConfig, error) { - http, err := HTTPConfigToProto(c.HTTP) + http, err := HTTPConfigToProto(&c.HTTP) if err != nil { return nil, err } diff --git a/device-provisioning-service/service/service.go b/device-provisioning-service/service/service.go index 7d9ec3a19..cf93196b0 100644 --- a/device-provisioning-service/service/service.go +++ b/device-provisioning-service/service/service.go @@ -355,7 +355,7 @@ func (server *Service) createServices(fileWatcher *fsnotify.Watcher, logger log. coapService.WithOnNewConnection(server.coapConnOnNew), coapService.WithOnInactivityConnection(server.onInactivityConnection), coapService.WithMessagePool(server.messagePool), - coapService.WithOverrideTLS(func(cfg *tls.Config) *tls.Config { + coapService.WithOverrideTLS(func(cfg *tls.Config, _ coapService.VerifyByCRL) *tls.Config { cfg.InsecureSkipVerify = true cfg.ClientAuth = tls.RequireAnyClientCert cfg.VerifyPeerCertificate = server.authHandler.VerifyPeerCertificate diff --git a/device-provisioning-service/test/test.go b/device-provisioning-service/test/test.go index 457658c37..530b895f3 100644 --- a/device-provisioning-service/test/test.go +++ b/device-provisioning-service/test/test.go @@ -331,7 +331,7 @@ func NewMongoStore(t require.TestingT) (*storeMongo.Store, func()) { fileWatcher, err := fsnotify.NewWatcher(logger) require.NoError(t, err) - certManager, err := cmClient.New(cfg.Clients.Storage.MongoDB.Mongo.TLS, fileWatcher, logger) + certManager, err := cmClient.New(cfg.Clients.Storage.MongoDB.Mongo.TLS, fileWatcher, logger, noop.NewTracerProvider()) require.NoError(t, err) ctx := context.Background() diff --git a/go.mod b/go.mod index ec0489316..bab8f7c5a 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( github.com/panjf2000/ants/v2 v2.10.0 github.com/pion/dtls/v3 v3.0.2 github.com/pion/logging v0.2.2 - github.com/plgd-dev/device/v2 v2.5.3-0.20240916150018-cc07b737d112 + github.com/plgd-dev/device/v2 v2.5.4-0.20241023145624-fd64dcccb418 github.com/plgd-dev/go-coap/v3 v3.3.5 github.com/plgd-dev/kit/v2 v2.0.0-20211006190727-057b33161b90 github.com/pseudomuto/protoc-gen-doc v1.5.1 diff --git a/go.sum b/go.sum index 2fb0a8ee7..177c169ce 100644 --- a/go.sum +++ b/go.sum @@ -266,8 +266,8 @@ github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/profile v1.7.0 h1:hnbDkaNWPCLMO9wGLdBFTIZvzDrDfBM2072E1S9gJkA= github.com/pkg/profile v1.7.0/go.mod h1:8Uer0jas47ZQMJ7VD+OHknK4YDY07LPUC6dEvqDjvNo= -github.com/plgd-dev/device/v2 v2.5.3-0.20240916150018-cc07b737d112 h1:nSgyZUOfQr1l6E3cXOcNzogmE13uOkfZ4mh/aK+HhyQ= -github.com/plgd-dev/device/v2 v2.5.3-0.20240916150018-cc07b737d112/go.mod h1:TXeTvVt0hi22FwhxaGxM1NiRwDDi1RmSgIrX9foWlic= +github.com/plgd-dev/device/v2 v2.5.4-0.20241023145624-fd64dcccb418 h1:j+6vyTwS2RY4rmMKMYxp6OWXhaEx64Fic10y586Y10U= +github.com/plgd-dev/device/v2 v2.5.4-0.20241023145624-fd64dcccb418/go.mod h1:TXeTvVt0hi22FwhxaGxM1NiRwDDi1RmSgIrX9foWlic= github.com/plgd-dev/go-coap/v2 v2.0.4-0.20200819112225-8eb712b901bc/go.mod h1:+tCi9Q78H/orWRtpVWyBgrr4vKFo2zYtbbxUllerBp4= github.com/plgd-dev/go-coap/v2 v2.4.1-0.20210517130748-95c37ac8e1fa/go.mod h1:rA7fc7ar+B/qa+Q0hRqv7yj/EMtIlmo1l7vkQGSrHPU= github.com/plgd-dev/go-coap/v3 v3.3.5 h1:GBdBwM/9JtJhbHxBhbzXAc40yaWvdYX16+vN0ShoX7w= diff --git a/grpc-gateway/service/grpcApi.go b/grpc-gateway/service/grpcApi.go index 9365906bf..88516e8ce 100644 --- a/grpc-gateway/service/grpcApi.go +++ b/grpc-gateway/service/grpcApi.go @@ -117,7 +117,7 @@ func newRequestHandlerFromConfig(config Config, fileWatcher *fsnotify.Watcher, l } closeFunc.AddFunc(closeIdClient) - natsClient, err := naClient.New(config.Clients.Eventbus.NATS.Config, fileWatcher, logger) + natsClient, err := naClient.New(config.Clients.Eventbus.NATS.Config, fileWatcher, logger, tracerProvider) if err != nil { return nil, fmt.Errorf("cannot create nats client: %w", err) } diff --git a/grpc-gateway/service/updateDeviceMetadata_test.go b/grpc-gateway/service/updateDeviceMetadata_test.go index 74a5cea9d..23be04971 100644 --- a/grpc-gateway/service/updateDeviceMetadata_test.go +++ b/grpc-gateway/service/updateDeviceMetadata_test.go @@ -26,6 +26,7 @@ import ( oauthTest "github.com/plgd-dev/hub/v2/test/oauth-server/test" "github.com/plgd-dev/hub/v2/test/service" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) @@ -142,7 +143,7 @@ func TestRequestHandlerUpdateDeviceMetadataTwinEnabled(t *testing.T) { require.NoError(t, errC) }() - naClient, s, err := natsTest.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, logger, subscriber.WithUnmarshaler(utils.Unmarshal)) + naClient, s, err := natsTest.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, logger, noop.NewTracerProvider(), subscriber.WithUnmarshaler(utils.Unmarshal)) require.NoError(t, err) defer func() { s.Close() @@ -304,7 +305,7 @@ func TestRequestHandlerUpdateDeviceMetadataTwinForceSynchronization(t *testing.T require.NoError(t, errC) }() - naClient, s, err := natsTest.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, logger, subscriber.WithUnmarshaler(utils.Unmarshal)) + naClient, s, err := natsTest.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, logger, noop.NewTracerProvider(), subscriber.WithUnmarshaler(utils.Unmarshal)) require.NoError(t, err) defer func() { s.Close() diff --git a/grpc-gateway/subscription/subscription_test.go b/grpc-gateway/subscription/subscription_test.go index 26ba21932..68a25cf2c 100644 --- a/grpc-gateway/subscription/subscription_test.go +++ b/grpc-gateway/subscription/subscription_test.go @@ -200,7 +200,7 @@ func prepareServicesAndSubscription(t *testing.T, owner, correlationID string, l pool, err := ants.NewPool(1) require.NoError(t, err) - natsConn, resourceSubscriber, err := natsTest.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, log.Get(), subscriber.WithGoPool(pool.Submit), subscriber.WithUnmarshaler(utils.Unmarshal)) + natsConn, resourceSubscriber, err := natsTest.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, log.Get(), noop.NewTracerProvider(), subscriber.WithGoPool(pool.Submit), subscriber.WithUnmarshaler(utils.Unmarshal)) require.NoError(t, err) cleanUp.AddFunc(func() { resourceSubscriber.Close() diff --git a/http-gateway/service/updateDeviceMetadata_test.go b/http-gateway/service/updateDeviceMetadata_test.go index 473cc1636..4bbbcce84 100644 --- a/http-gateway/service/updateDeviceMetadata_test.go +++ b/http-gateway/service/updateDeviceMetadata_test.go @@ -31,6 +31,7 @@ import ( oauthTest "github.com/plgd-dev/hub/v2/test/oauth-server/test" "github.com/plgd-dev/hub/v2/test/service" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/protobuf/encoding/protojson" @@ -151,7 +152,7 @@ func TestRequestHandlerUpdateDeviceMetadata(t *testing.T) { require.NoError(t, errC) }() - naClient, s, err := natsTest.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, logger, subscriber.WithUnmarshaler(utils.Unmarshal)) + naClient, s, err := natsTest.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, logger, noop.NewTracerProvider(), subscriber.WithUnmarshaler(utils.Unmarshal)) require.NoError(t, err) defer func() { s.Close() diff --git a/http-gateway/test/http.go b/http-gateway/test/http.go index 14d738176..cf920ea5b 100644 --- a/http-gateway/test/http.go +++ b/http-gateway/test/http.go @@ -193,6 +193,11 @@ func (c *RequestBuilder) AddTimeToLive(ttl time.Duration) *RequestBuilder { return c } +func (c *RequestBuilder) AddIssuerID(issuerID string) *RequestBuilder { + c.uriParams[uri.IssuerIDKey] = issuerID + return c +} + func (c *RequestBuilder) SetQuery(value string) *RequestBuilder { c.query = value return c diff --git a/http-gateway/uri/uri.go b/http-gateway/uri/uri.go index 05aba37dd..653b1eeef 100644 --- a/http-gateway/uri/uri.go +++ b/http-gateway/uri/uri.go @@ -24,6 +24,7 @@ const ( OnlyContentQueryKey = "onlyContent" IncludeHiddenResourcesQueryKey = "includeHiddenResources" ForceQueryKey = "force" + IssuerIDKey = "issuerId" AliasInterfaceQueryKey = "interface" AliasCommandFilterQueryKey = "command" diff --git a/identity-store/client/ownerCache_test.go b/identity-store/client/ownerCache_test.go index 77f8313a1..683ecd010 100644 --- a/identity-store/client/ownerCache_test.go +++ b/identity-store/client/ownerCache_test.go @@ -75,7 +75,7 @@ func TestOwnerCacheSubscribe(t *testing.T) { owner, err := kitNetGrpc.ParseOwnerFromJwtToken("sub", token) require.NoError(t, err) - naClient, subscriber, err := natsTest.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, log.Get()) + naClient, subscriber, err := natsTest.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, log.Get(), noop.NewTracerProvider()) require.NoError(t, err) defer func() { subscriber.Close() diff --git a/identity-store/persistence/cqldb/store.go b/identity-store/persistence/cqldb/store.go index f593b2821..eb5fe1c4f 100644 --- a/identity-store/persistence/cqldb/store.go +++ b/identity-store/persistence/cqldb/store.go @@ -36,7 +36,7 @@ type Store struct { } func New(ctx context.Context, config *Config, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*Store, error) { - certManager, err := client.New(config.Embedded.TLS, fileWatcher, logger) + certManager, err := client.New(config.Embedded.TLS, fileWatcher, logger, tracerProvider) if err != nil { return nil, fmt.Errorf("could not create cert manager: %w", err) } diff --git a/identity-store/persistence/mongodb/store.go b/identity-store/persistence/mongodb/store.go index 5948c71b1..355994448 100644 --- a/identity-store/persistence/mongodb/store.go +++ b/identity-store/persistence/mongodb/store.go @@ -33,7 +33,7 @@ type Store struct { } func New(ctx context.Context, config *Config, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*Store, error) { - certManager, err := client.New(config.TLS, fileWatcher, logger) + certManager, err := client.New(config.TLS, fileWatcher, logger, tracerProvider) if err != nil { return nil, fmt.Errorf("could not create cert manager: %w", err) } diff --git a/identity-store/service/service.go b/identity-store/service/service.go index c69e5c01d..debca7d13 100644 --- a/identity-store/service/service.go +++ b/identity-store/service/service.go @@ -105,7 +105,7 @@ func New(ctx context.Context, cfg Config, fileWatcher *fsnotify.Watcher, logger } tracerProvider := otelClient.GetTracerProvider() - naClient, err := client.New(cfg.Clients.Eventbus.NATS.Config, fileWatcher, logger) + naClient, err := client.New(cfg.Clients.Eventbus.NATS.Config, fileWatcher, logger, tracerProvider) if err != nil { otelClient.Close() return nil, fmt.Errorf("cannot create nats client %w", err) diff --git a/identity-store/service/service_test.go b/identity-store/service/service_test.go index 3b887c715..30804a33f 100644 --- a/identity-store/service/service_test.go +++ b/identity-store/service/service_test.go @@ -64,7 +64,7 @@ func newTestService(t *testing.T) (*Server, func()) { fileWatcher, err := fsnotify.NewWatcher(logger) require.NoError(t, err) - naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger) + naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider()) require.NoError(t, err) s, err := NewServer(context.Background(), cfg, fileWatcher, logger, noop.NewTracerProvider(), publisher) diff --git a/m2m-oauth-server/store/mongodb/store.go b/m2m-oauth-server/store/mongodb/store.go index bca87a192..b51bbf285 100644 --- a/m2m-oauth-server/store/mongodb/store.go +++ b/m2m-oauth-server/store/mongodb/store.go @@ -31,7 +31,7 @@ var idOwnerIndex = mongo.IndexModel{ } func New(ctx context.Context, cfg *Config, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*Store, error) { - certManager, err := client.New(cfg.Mongo.TLS, fileWatcher, logger) + certManager, err := client.New(cfg.Mongo.TLS, fileWatcher, logger, tracerProvider) if err != nil { return nil, fmt.Errorf("could not create cert manager: %w", err) } diff --git a/pkg/cqldb/cqldb_test.go b/pkg/cqldb/cqldb_test.go index 647cb8e52..b2f234242 100644 --- a/pkg/cqldb/cqldb_test.go +++ b/pkg/cqldb/cqldb_test.go @@ -41,7 +41,7 @@ func TestCqlDB(t *testing.T) { err = fileWatcher.Close() require.NoError(t, err) }() - certManagerClient, err := client.New(config.TLS, fileWatcher, logger) + certManagerClient, err := client.New(config.TLS, fileWatcher, logger, nil) require.NoError(t, err) defer certManagerClient.Close() diff --git a/pkg/net/coap/service/options.go b/pkg/net/coap/service/options.go index 9b54320b1..74bf72b35 100644 --- a/pkg/net/coap/service/options.go +++ b/pkg/net/coap/service/options.go @@ -1,20 +1,24 @@ package service import ( + "context" "crypto/tls" + "crypto/x509" "github.com/plgd-dev/go-coap/v3/message/pool" "github.com/plgd-dev/go-coap/v3/mux" ) +type VerifyByCRL = func(context.Context, *x509.Certificate, []string) error + type Options struct { - OverrideTLSConfig func(cfg *tls.Config) *tls.Config + OverrideTLSConfig func(cfg *tls.Config, verifyByCRL VerifyByCRL) *tls.Config OnNewConnection func(conn mux.Conn) OnInactivityConnection func(conn mux.Conn) MessagePool *pool.Pool } -func WithOverrideTLS(f func(cfg *tls.Config) *tls.Config) func(*Options) { +func WithOverrideTLS(f func(cfg *tls.Config, verifyByCRL VerifyByCRL) *tls.Config) func(*Options) { return func(o *Options) { o.OverrideTLSConfig = f } diff --git a/pkg/net/coap/service/service_test.go b/pkg/net/coap/service/service_test.go index 32d000615..1f6c6318e 100644 --- a/pkg/net/coap/service/service_test.go +++ b/pkg/net/coap/service/service_test.go @@ -63,7 +63,7 @@ func TestNew(t *testing.T) { WithMessagePool(pool.New(uint32(1024), 1024)), WithOnNewConnection(func(mux.Conn) {}), WithOnInactivityConnection(func(mux.Conn) {}), - WithOverrideTLS(func(cfg *tls.Config) *tls.Config { return cfg }), + WithOverrideTLS(func(cfg *tls.Config, _ VerifyByCRL) *tls.Config { return cfg }), }, }, }, @@ -90,7 +90,7 @@ func TestNew(t *testing.T) { WithMessagePool(pool.New(uint32(1024), 1024)), WithOnNewConnection(func(mux.Conn) {}), WithOnInactivityConnection(func(mux.Conn) {}), - WithOverrideTLS(func(cfg *tls.Config) *tls.Config { return cfg }), + WithOverrideTLS(func(cfg *tls.Config, _ VerifyByCRL) *tls.Config { return cfg }), }, }, }, diff --git a/pkg/net/coap/service/tcpServer.go b/pkg/net/coap/service/tcpServer.go index 96a9439a8..b985c920d 100644 --- a/pkg/net/coap/service/tcpServer.go +++ b/pkg/net/coap/service/tcpServer.go @@ -15,10 +15,15 @@ import ( certManagerServer "github.com/plgd-dev/hub/v2/pkg/security/certManager/server" ) +type tcpListener struct { + coapTcpServer.Listener + tlsManager *certManagerServer.CertManager + close func() +} + type tcpServer struct { - coapServer *coapTcpServer.Server - listener coapTcpServer.Listener - closeListener func() + coapServer *coapTcpServer.Server + listener *tcpListener } func (s *tcpServer) Serve() error { @@ -27,44 +32,52 @@ func (s *tcpServer) Serve() error { func (s *tcpServer) Close() error { s.coapServer.Stop() + s.listener.close() return nil } -func newTCPListener(config Config, serviceOpts Options, fileWatcher *fsnotify.Watcher, logger log.Logger) (coapTcpServer.Listener, func(), error) { +func newTCPListener(config Config, serviceOpts Options, fileWatcher *fsnotify.Watcher, logger log.Logger) (*tcpListener, error) { if !config.TLS.IsEnabled() { listener, err := net.NewTCPListener("tcp", config.Addr) if err != nil { - return nil, nil, fmt.Errorf("cannot create tcp listener: %w", err) + return nil, fmt.Errorf("cannot create tcp listener: %w", err) } closeListener := func() { if err := listener.Close(); err != nil { logger.Errorf("failed to close tcp listener: %w", err) } } - return listener, closeListener, nil + return &tcpListener{ + Listener: listener, + close: closeListener, + }, nil } var closeListener fn.FuncList coapsTLS, err := certManagerServer.New(config.TLS.Embedded, fileWatcher, logger) if err != nil { - return nil, nil, fmt.Errorf("cannot create tls cert manager: %w", err) + return nil, fmt.Errorf("cannot create tls cert manager: %w", err) } closeListener.AddFunc(coapsTLS.Close) tlsCfg := coapsTLS.GetTLSConfig() if serviceOpts.OverrideTLSConfig != nil { - tlsCfg = serviceOpts.OverrideTLSConfig(tlsCfg) + tlsCfg = serviceOpts.OverrideTLSConfig(tlsCfg, coapsTLS.VerifyByCRL) } listener, err := net.NewTLSListener("tcp", config.Addr, tlsCfg) if err != nil { closeListener.Execute() - return nil, nil, fmt.Errorf("cannot create tcp-tls listener: %w", err) + return nil, fmt.Errorf("cannot create tcp-tls listener: %w", err) } closeListener.AddFunc(func() { if err := listener.Close(); err != nil { logger.Errorf("failed to close tcp-tls listener: %w", err) } }) - return listener, closeListener.ToFunction(), nil + return &tcpListener{ + Listener: listener, + close: closeListener.ToFunction(), + tlsManager: coapsTLS, + }, nil } func newTCPServer(config Config, serviceOpts Options, fileWatcher *fsnotify.Watcher, logger log.Logger, opts ...interface { @@ -73,7 +86,7 @@ func newTCPServer(config Config, serviceOpts Options, fileWatcher *fsnotify.Watc coapUdpServer.Option }, ) (*tcpServer, error) { - listener, closeListener, err := newTCPListener(config, serviceOpts, fileWatcher, logger) + listener, err := newTCPListener(config, serviceOpts, fileWatcher, logger) if err != nil { return nil, fmt.Errorf("cannot create listener: %w", err) } @@ -97,8 +110,7 @@ func newTCPServer(config Config, serviceOpts Options, fileWatcher *fsnotify.Watc tcpOpts = append(tcpOpts, o) } return &tcpServer{ - coapServer: coapTcpServer.New(tcpOpts...), - listener: listener, - closeListener: closeListener, + coapServer: coapTcpServer.New(tcpOpts...), + listener: listener, }, nil } diff --git a/pkg/net/coap/service/udpServer.go b/pkg/net/coap/service/udpServer.go index 5ae9a9c85..8ffbfe678 100644 --- a/pkg/net/coap/service/udpServer.go +++ b/pkg/net/coap/service/udpServer.go @@ -17,10 +17,15 @@ import ( certManagerServer "github.com/plgd-dev/hub/v2/pkg/security/certManager/server" ) +type dtlsListerner struct { + coapDtlsServer.Listener + tlsManager *certManagerServer.CertManager + close func() +} + type dtlsServer struct { - coapServer *coapDtlsServer.Server - listener coapDtlsServer.Listener - closeListener func() + coapServer *coapDtlsServer.Server + listener *dtlsListerner } func (s *dtlsServer) Serve() error { @@ -29,37 +34,44 @@ func (s *dtlsServer) Serve() error { func (s *dtlsServer) Close() error { s.coapServer.Stop() - s.closeListener() + s.listener.close() return nil } +type udpListerner struct { + *net.UDPConn + close func() +} + type udpServer struct { - coapServer *coapUdpServer.Server - listener *net.UDPConn - closeListener func() + coapServer *coapUdpServer.Server + listener *udpListerner } func (s *udpServer) Serve() error { - return s.coapServer.Serve(s.listener) + return s.coapServer.Serve(s.listener.UDPConn) } func (s *udpServer) Close() error { s.coapServer.Stop() - s.closeListener() + s.listener.close() return nil } -func newUDPListener(config Config, logger log.Logger) (*net.UDPConn, func(), error) { +func newUDPListener(config Config, logger log.Logger) (*udpListerner, error) { listener, err := net.NewListenUDP("udp", config.Addr) if err != nil { - return nil, nil, fmt.Errorf("cannot create tcp listener: %w", err) + return nil, fmt.Errorf("cannot create tcp listener: %w", err) } closeListener := func() { if err := listener.Close(); err != nil { logger.Errorf("failed to close tcp listener: %w", err) } } - return listener, closeListener, nil + return &udpListerner{ + UDPConn: listener, + close: closeListener, + }, nil } var mapDTLSClientAuth = map[tls.ClientAuthType]dtls.ClientAuthType{ @@ -101,30 +113,35 @@ func TLSConfigToDTLSConfig(tlsConfig *tls.Config) *dtls.Config { } } -func newDTLSListener(config Config, serviceOpts Options, fileWatcher *fsnotify.Watcher, logger log.Logger) (coapDtlsServer.Listener, func(), error) { +func newDTLSListener(config Config, serviceOpts Options, fileWatcher *fsnotify.Watcher, logger log.Logger) (*dtlsListerner, error) { var closeListener fn.FuncList coapsTLS, err := certManagerServer.New(config.TLS.Embedded, fileWatcher, logger) if err != nil { - return nil, nil, fmt.Errorf("cannot create tls cert manager: %w", err) + return nil, fmt.Errorf("cannot create tls cert manager: %w", err) } closeListener.AddFunc(coapsTLS.Close) tlsCfg := coapsTLS.GetTLSConfig() if serviceOpts.OverrideTLSConfig != nil { - tlsCfg = serviceOpts.OverrideTLSConfig(tlsCfg) + tlsCfg = serviceOpts.OverrideTLSConfig(tlsCfg, coapsTLS.VerifyByCRL) } dtlsCfg := TLSConfigToDTLSConfig(tlsCfg) dtlsCfg.LoggerFactory = logger.DTLSLoggerFactory() listener, err := net.NewDTLSListener("udp", config.Addr, dtlsCfg) if err != nil { closeListener.Execute() - return nil, nil, fmt.Errorf("cannot create dtls listener: %w", err) + return nil, fmt.Errorf("cannot create dtls listener: %w", err) } closeListener.AddFunc(func() { if err := listener.Close(); err != nil { logger.Errorf("failed to close dtls listener: %w", err) } }) - return listener, closeListener.ToFunction(), nil + + return &dtlsListerner{ + Listener: listener, + close: closeListener.ToFunction(), + tlsManager: coapsTLS, + }, nil } func newDTLSServer(config Config, serviceOpts Options, fileWatcher *fsnotify.Watcher, logger log.Logger, opts ...interface { @@ -133,7 +150,7 @@ func newDTLSServer(config Config, serviceOpts Options, fileWatcher *fsnotify.Wat coapUdpServer.Option }, ) (*dtlsServer, error) { - listener, closeListener, err := newDTLSListener(config, serviceOpts, fileWatcher, logger) + listener, err := newDTLSListener(config, serviceOpts, fileWatcher, logger) if err != nil { return nil, fmt.Errorf("cannot create listener: %w", err) } @@ -157,9 +174,8 @@ func newDTLSServer(config Config, serviceOpts Options, fileWatcher *fsnotify.Wat dtlsOpts = append(dtlsOpts, o) } return &dtlsServer{ - coapServer: coapDtlsServer.New(dtlsOpts...), - listener: listener, - closeListener: closeListener, + coapServer: coapDtlsServer.New(dtlsOpts...), + listener: listener, }, nil } @@ -169,7 +185,7 @@ func newUDPServer(config Config, serviceOpts Options, logger log.Logger, opts .. coapUdpServer.Option }, ) (*udpServer, error) { - listener, closeListener, err := newUDPListener(config, logger) + listener, err := newUDPListener(config, logger) if err != nil { return nil, fmt.Errorf("cannot create listener: %w", err) } @@ -193,8 +209,7 @@ func newUDPServer(config Config, serviceOpts Options, logger log.Logger, opts .. udpOpts = append(udpOpts, o) } return &udpServer{ - coapServer: coapUdpServer.New(udpOpts...), - listener: listener, - closeListener: closeListener, + coapServer: coapUdpServer.New(udpOpts...), + listener: listener, }, nil } diff --git a/pkg/net/grpc/client/client.go b/pkg/net/grpc/client/client.go index cd4e45cc7..eb71a9c81 100644 --- a/pkg/net/grpc/client/client.go +++ b/pkg/net/grpc/client/client.go @@ -41,7 +41,7 @@ func New(config Config, fileWatcher *fsnotify.Watcher, logger log.Logger, tracer if err != nil { return nil, fmt.Errorf("invalid config: %w", err) } - certManager, err := client.New(config.TLS, fileWatcher, logger) + certManager, err := client.New(config.TLS, fileWatcher, logger, tracerProvider) if err != nil { return nil, fmt.Errorf("cannot create cert manager: %w", err) } diff --git a/pkg/net/http/client/client.go b/pkg/net/http/client/client.go index 5b232a48f..b527e8ee6 100644 --- a/pkg/net/http/client/client.go +++ b/pkg/net/http/client/client.go @@ -1,17 +1,20 @@ package client import ( - "fmt" + "crypto/tls" "net/http" "github.com/plgd-dev/hub/v2/pkg/fn" - "github.com/plgd-dev/hub/v2/pkg/fsnotify" - "github.com/plgd-dev/hub/v2/pkg/log" - "github.com/plgd-dev/hub/v2/pkg/security/certManager/client" + pkgTls "github.com/plgd-dev/hub/v2/pkg/security/tls" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel/trace" ) +type CertificateManager interface { + GetTLSConfig() *tls.Config + Close() +} + // Server handles gRPC requests to the service. type Client struct { client *http.Client @@ -31,23 +34,19 @@ func (c *Client) Close() { c.closeFunc.Execute() } -func New(config Config, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*Client, error) { - certManager, err := client.New(config.TLS, fileWatcher, logger) - if err != nil { - return nil, fmt.Errorf("cannot create cert manager %w", err) - } +func New(config pkgTls.HTTPConfigurer, cm CertificateManager, tracerProvider trace.TracerProvider) (*Client, error) { t := http.DefaultTransport.(*http.Transport).Clone() - t.MaxIdleConns = config.MaxIdleConns - t.MaxConnsPerHost = config.MaxConnsPerHost - t.MaxIdleConnsPerHost = config.MaxIdleConnsPerHost - t.IdleConnTimeout = config.IdleConnTimeout - t.TLSClientConfig = certManager.GetTLSConfig() + t.MaxIdleConns = config.GetMaxIdleConns() + t.MaxConnsPerHost = config.GetMaxConnsPerHost() + t.MaxIdleConnsPerHost = config.GetMaxIdleConnsPerHost() + t.IdleConnTimeout = config.GetIdleConnTimeout() + t.TLSClientConfig = cm.GetTLSConfig() c := &Client{ client: &http.Client{ Transport: otelhttp.NewTransport(t, otelhttp.WithTracerProvider(tracerProvider)), - Timeout: config.Timeout, + Timeout: config.GetTimeout(), }, } - c.AddCloseFunc(certManager.Close) + c.AddCloseFunc(cm.Close) return c, nil } diff --git a/pkg/security/certManager/client/certManager.go b/pkg/security/certManager/client/certManager.go index 2ca175623..09f5e8f51 100644 --- a/pkg/security/certManager/client/certManager.go +++ b/pkg/security/certManager/client/certManager.go @@ -1,94 +1,23 @@ package client import ( - "crypto/tls" - "errors" - "fmt" - - "github.com/plgd-dev/hub/v2/pkg/config/property/urischeme" "github.com/plgd-dev/hub/v2/pkg/fsnotify" "github.com/plgd-dev/hub/v2/pkg/log" + "github.com/plgd-dev/hub/v2/pkg/net/http/client" "github.com/plgd-dev/hub/v2/pkg/security/certManager/general" - "github.com/plgd-dev/hub/v2/pkg/strings" + pkgTls "github.com/plgd-dev/hub/v2/pkg/security/tls" + "go.opentelemetry.io/otel/trace" ) -// Config provides configuration of a file based Server Certificate manager. CAPool can be a string or an array of strings. -type Config struct { - CAPool interface{} `yaml:"caPool" json:"caPool" description:"file path to the root certificates in PEM format"` - KeyFile urischeme.URIScheme `yaml:"keyFile" json:"keyFile" description:"file name of private key in PEM format"` - CertFile urischeme.URIScheme `yaml:"certFile" json:"certFile" description:"file name of certificate in PEM format"` - UseSystemCAPool bool `yaml:"useSystemCAPool" json:"useSystemCaPool" description:"use system certification pool"` - caPoolArray []urischeme.URIScheme `yaml:"-" json:"-"` - validated bool -} - -func (c *Config) Validate() error { - caPoolArray, ok := strings.ToStringArray(c.CAPool) - if !ok { - return fmt.Errorf("caPool('%v') - unsupported", c.CAPool) - } - c.caPoolArray = urischeme.ToURISchemeArray(caPoolArray) - if !c.UseSystemCAPool && len(c.caPoolArray) == 0 { - return fmt.Errorf("caPool('%v') - is empty", c.CAPool) - } - if c.CertFile == "" { - return fmt.Errorf("certFile('%v')", c.CertFile) - } - if c.KeyFile == "" { - return fmt.Errorf("keyFile('%v')", c.KeyFile) - } - c.validated = true - return nil -} - -func (c *Config) CAPoolArray() ([]urischeme.URIScheme, error) { - if !c.validated { - return nil, errors.New("call Validate() first") - } - return c.caPoolArray, nil -} - -func (c *Config) CAPoolFilePathArray() ([]string, error) { - a, err := c.CAPoolArray() - if err != nil { - return nil, err - } - return urischeme.ToFilePathArray(a), nil -} +type Config = pkgTls.ClientConfig // CertManager holds certificates from filesystem watched for changes -type CertManager struct { - c *general.CertManager -} - -// GetTLSConfig returns tls configuration for clients -func (c *CertManager) GetTLSConfig() *tls.Config { - return c.c.GetClientTLSConfig() -} +type CertManager = general.ClientCertManager -// Close ends watching certificates -func (c *CertManager) Close() { - c.c.Close() +func New(config Config, fileWatcher *fsnotify.Watcher, logger log.Logger, tp trace.TracerProvider) (*CertManager, error) { + return general.NewClientCertManager(config, fileWatcher, logger, tp) } -// New creates a new certificate manager which watches for certs in a filesystem -func New(config Config, fileWatcher *fsnotify.Watcher, logger log.Logger) (*CertManager, error) { - if !config.validated { - if err := config.Validate(); err != nil { - return nil, err - } - } - c, err := general.New(general.Config{ - CAPool: config.caPoolArray, - KeyFile: config.KeyFile, - CertFile: config.CertFile, - ClientCertificateRequired: false, - UseSystemCAPool: config.UseSystemCAPool, - }, fileWatcher, logger.With(log.CertManagerKey, "client")) - if err != nil { - return nil, err - } - return &CertManager{ - c: c, - }, nil +func NewHTTPClient(config pkgTls.HTTPConfigurer, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*client.Client, error) { + return general.NewHTTPClient(config, fileWatcher, logger, tracerProvider) } diff --git a/pkg/security/certManager/client/certManager_test.go b/pkg/security/certManager/client/certManager_test.go index 077c8d9b5..eea10a587 100644 --- a/pkg/security/certManager/client/certManager_test.go +++ b/pkg/security/certManager/client/certManager_test.go @@ -11,6 +11,7 @@ import ( "github.com/plgd-dev/hub/v2/pkg/security/certManager/client" testX509 "github.com/plgd-dev/hub/v2/test/security/x509" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" "gopkg.in/yaml.v3" ) @@ -143,7 +144,7 @@ func TestNew(t *testing.T) { err = fileWatcher.Close() require.NoError(t, err) }() - mng, err := client.New(config, fileWatcher, logger) + mng, err := client.New(config, fileWatcher, logger, noop.NewTracerProvider()) require.NoError(t, err) tlsConfig := mng.GetTLSConfig() diff --git a/pkg/security/certManager/general/certManager.go b/pkg/security/certManager/general/certManager.go index 637e63358..bd8c0f057 100644 --- a/pkg/security/certManager/general/certManager.go +++ b/pkg/security/certManager/general/certManager.go @@ -2,11 +2,14 @@ package general import ( "bytes" + "context" "crypto/sha256" "crypto/tls" "crypto/x509" "errors" "fmt" + "io" + "net/http" "strings" "sync" "time" @@ -16,9 +19,12 @@ import ( "github.com/plgd-dev/hub/v2/pkg/fn" "github.com/plgd-dev/hub/v2/pkg/fsnotify" "github.com/plgd-dev/hub/v2/pkg/log" + pkgHttpClient "github.com/plgd-dev/hub/v2/pkg/net/http/client" + pkgTls "github.com/plgd-dev/hub/v2/pkg/security/tls" pkgX509 "github.com/plgd-dev/hub/v2/pkg/security/x509" pkgTime "github.com/plgd-dev/hub/v2/pkg/time" "github.com/plgd-dev/kit/v2/security" + "go.opentelemetry.io/otel/trace" "go.uber.org/atomic" ) @@ -29,6 +35,7 @@ type Config struct { CertFile urischeme.URIScheme `yaml:"certFile" json:"certFile" description:"file name of certificate in PEM format"` ClientCertificateRequired bool `yaml:"clientCertificateRequired" json:"clientCertificateRequired" description:"require client certificate"` UseSystemCAPool bool `yaml:"useSystemCAPool" json:"useSystemCaPool" description:"use system certification pool"` + CRL pkgTls.CRLConfig `yaml:"crl" json:"crl"` } func (c Config) Validate() error { @@ -53,6 +60,7 @@ type CertManager struct { logger log.Logger onFileChangeFunc func(event fsnotify.Event) done atomic.Bool + httpClient *pkgHttpClient.Client private struct { mutex sync.Mutex @@ -78,17 +86,27 @@ func tryToWatchFile(file urischeme.URIScheme, fileWatcher *fsnotify.Watcher, rem } // New creates a new certificate manager which watches for certs in a filesystem -func New(config Config, fileWatcher *fsnotify.Watcher, logger log.Logger) (*CertManager, error) { +func New(config Config, fileWatcher *fsnotify.Watcher, logger log.Logger, tp trace.TracerProvider) (*CertManager, error) { verifyClientCertificate := tls.RequireAndVerifyClientCert if !config.ClientCertificateRequired { verifyClientCertificate = tls.NoClientCert } + var httpClient *pkgHttpClient.Client + if config.CRL.Enabled { + var err error + httpClient, err = NewHTTPClient(config.CRL.HTTP, fileWatcher, logger, tp) + if err != nil { + return nil, err + } + } + c := &CertManager{ fileWatcher: fileWatcher, config: config, verifyClientCertificate: verifyClientCertificate, logger: logger, + httpClient: httpClient, } _, err := c.loadCAs() if err != nil { @@ -364,3 +382,48 @@ func (a *CertManager) onFileChange(event fsnotify.Event) { } } } + +func (a *CertManager) downloadCRL(ctx context.Context, cdp string) (*x509.RevocationList, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, cdp, nil) + if err != nil { + return nil, err + } + req.Close = true + resp, err := a.httpClient.HTTP().Do(req) + if err != nil { + return nil, err + } + defer func() { + if errC := resp.Body.Close(); errC != nil { + a.logger.Errorf("failed to close response body stream: %w", errC) + } + }() + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected statusCode %v: '%v'", resp.StatusCode, string(respBody)) + } + crl, err := x509.ParseRevocationList(respBody) + if err != nil { + return nil, err + } + return crl, nil +} + +func (a *CertManager) VerifyByCRL(ctx context.Context, certificate *x509.Certificate, cdps []string) error { + if !a.config.CRL.Enabled { + return nil + } + for _, dp := range cdps { + crl, err := a.downloadCRL(ctx, dp) + if err == nil { + if pkgX509.IsRevoked(certificate, crl) { + return fmt.Errorf("certificate(%s) was revoked by CRL(%s)", "", crl.Issuer.String()) + } + return nil + } + } + return errors.New("failed to verify certificate by CRL") +} diff --git a/pkg/security/certManager/general/certManager_test.go b/pkg/security/certManager/general/certManager_test.go index 3101e880b..27984e0b7 100644 --- a/pkg/security/certManager/general/certManager_test.go +++ b/pkg/security/certManager/general/certManager_test.go @@ -18,6 +18,7 @@ import ( "github.com/plgd-dev/hub/v2/pkg/security/certManager/general" pkgX509 "github.com/plgd-dev/hub/v2/pkg/security/x509" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" ) func getCA(t *testing.T, validFrom time.Time, validFor time.Duration) ([]byte, *ecdsa.PrivateKey) { @@ -79,7 +80,7 @@ func TestNew(t *testing.T) { defer func() { _ = fileWatcher.Close() }() - mng, err := general.New(config, fileWatcher, logger) + mng, err := general.New(config, fileWatcher, logger, noop.NewTracerProvider()) require.NoError(t, err) defer mng.Close() @@ -173,7 +174,7 @@ func TestCertManagerWithExpiredCA(t *testing.T) { defer func() { _ = fileWatcher.Close() }() - mng, err := general.New(config, fileWatcher, logger) + mng, err := general.New(config, fileWatcher, logger, noop.NewTracerProvider()) require.NoError(t, err) defer mng.Close() pool := mng.GetCertificateAuthorities() @@ -227,7 +228,7 @@ func TestCertManagerWithExpiredCertificate(t *testing.T) { defer func() { _ = fileWatcher.Close() }() - mng, err := general.New(config, fileWatcher, logger) + mng, err := general.New(config, fileWatcher, logger, noop.NewTracerProvider()) require.NoError(t, err) defer mng.Close() diff --git a/pkg/security/certManager/general/clientCertManager.go b/pkg/security/certManager/general/clientCertManager.go new file mode 100644 index 000000000..79ac8e411 --- /dev/null +++ b/pkg/security/certManager/general/clientCertManager.go @@ -0,0 +1,67 @@ +package general + +import ( + "crypto/tls" + "fmt" + + "github.com/plgd-dev/hub/v2/pkg/config/property/urischeme" + "github.com/plgd-dev/hub/v2/pkg/fsnotify" + "github.com/plgd-dev/hub/v2/pkg/log" + "github.com/plgd-dev/hub/v2/pkg/net/http/client" + pkgTls "github.com/plgd-dev/hub/v2/pkg/security/tls" + "go.opentelemetry.io/otel/trace" +) + +func ClientConfig(caPoolArray []urischeme.URIScheme, keyFile, certFile urischeme.URIScheme, useSystemCAPool bool, crl pkgTls.CRLConfig) Config { + return Config{ + CAPool: caPoolArray, + KeyFile: keyFile, + CertFile: certFile, + ClientCertificateRequired: false, + UseSystemCAPool: useSystemCAPool, + CRL: crl, + } +} + +func ClientLogger(logger log.Logger) log.Logger { + return logger.With(log.CertManagerKey, "client") +} + +// CertManager holds certificates from filesystem watched for changes +type ClientCertManager struct { + c *CertManager +} + +// GetTLSConfig returns tls configuration for clients +func (c *ClientCertManager) GetTLSConfig() *tls.Config { + return c.c.GetClientTLSConfig() +} + +// Close ends watching certificates +func (c *ClientCertManager) Close() { + c.c.Close() +} + +// New creates a new certificate manager which watches for certs in a filesystem +func NewClientCertManager(config pkgTls.ClientConfig, fileWatcher *fsnotify.Watcher, logger log.Logger, tp trace.TracerProvider) (*ClientCertManager, error) { + if err := config.Validate(); err != nil { + return nil, err + } + caPoolArray, _ := config.CAPoolArray() + + c, err := New(ClientConfig(caPoolArray, config.KeyFile, config.CertFile, config.UseSystemCAPool, config.CRL), fileWatcher, ClientLogger(logger), tp) + if err != nil { + return nil, err + } + return &ClientCertManager{ + c: c, + }, nil +} + +func NewHTTPClient(config pkgTls.HTTPConfigurer, fileWatcher *fsnotify.Watcher, logger log.Logger, tp trace.TracerProvider) (*client.Client, error) { + cm, err := NewClientCertManager(config.GetTLS(), fileWatcher, logger, tp) + if err != nil { + return nil, fmt.Errorf("cannot create cert manager %w", err) + } + return client.New(config, cm, tp) +} diff --git a/pkg/security/certManager/server/certManager.go b/pkg/security/certManager/server/certManager.go index 953d35cbb..5b65fa651 100644 --- a/pkg/security/certManager/server/certManager.go +++ b/pkg/security/certManager/server/certManager.go @@ -1,7 +1,9 @@ package server import ( + "context" "crypto/tls" + "crypto/x509" "errors" "fmt" @@ -9,21 +11,28 @@ import ( "github.com/plgd-dev/hub/v2/pkg/fsnotify" "github.com/plgd-dev/hub/v2/pkg/log" "github.com/plgd-dev/hub/v2/pkg/security/certManager/general" + pkgTls "github.com/plgd-dev/hub/v2/pkg/security/tls" "github.com/plgd-dev/hub/v2/pkg/strings" + "go.opentelemetry.io/otel/trace/noop" ) // Config provides configuration of a file based Server Certificate manager. CAPool can be a string or an array of strings. type Config struct { - CAPool interface{} `yaml:"caPool" json:"caPool" description:"file path to the root certificates in PEM format"` - KeyFile urischeme.URIScheme `yaml:"keyFile" json:"keyFile" description:"file name of private key in PEM format"` - CertFile urischeme.URIScheme `yaml:"certFile" json:"certFile" description:"file name of certificate in PEM format"` - ClientCertificateRequired bool `yaml:"clientCertificateRequired" json:"clientCertificateRequired" description:"require client certificate"` - CAPoolIsOptional bool `yaml:"-" json:"-"` - caPoolArray []urischeme.URIScheme `yaml:"-" json:"-"` - validated bool + CAPool interface{} `yaml:"caPool" json:"caPool" description:"file path to the root certificates in PEM format"` + KeyFile urischeme.URIScheme `yaml:"keyFile" json:"keyFile" description:"file name of private key in PEM format"` + CertFile urischeme.URIScheme `yaml:"certFile" json:"certFile" description:"file name of certificate in PEM format"` + ClientCertificateRequired bool `yaml:"clientCertificateRequired" json:"clientCertificateRequired" description:"require client certificate"` + CRL pkgTls.CRLConfig `yaml:"crl" json:"crl"` + + CAPoolIsOptional bool `yaml:"-" json:"-"` + caPoolArray []urischeme.URIScheme `yaml:"-" json:"-"` + validated bool } func (c *Config) Validate() error { + if c.validated { + return nil + } caPoolArray, ok := strings.ToStringArray(c.CAPool) if !ok { return fmt.Errorf("caPool('%v') - unsupported", c.CAPool) @@ -59,6 +68,10 @@ func (c *CertManager) GetTLSConfig() *tls.Config { return c.c.GetServerTLSConfig() } +func (c *CertManager) VerifyByCRL(ctx context.Context, certificate *x509.Certificate, cdp []string) error { + return c.c.VerifyByCRL(ctx, certificate, cdp) +} + // Close ends watching certificates func (c *CertManager) Close() { c.c.Close() @@ -77,7 +90,9 @@ func New(config Config, fileWatcher *fsnotify.Watcher, logger log.Logger) (*Cert CertFile: config.CertFile, ClientCertificateRequired: config.ClientCertificateRequired, UseSystemCAPool: false, - }, fileWatcher, logger.With(log.CertManagerKey, "server")) + CRL: config.CRL, + // TODO: use real trace provider + }, fileWatcher, logger.With(log.CertManagerKey, "server"), noop.NewTracerProvider()) if err != nil { return nil, err } diff --git a/pkg/security/certificateSigner/certificateSigner.go b/pkg/security/certificateSigner/certificateSigner.go index 707588448..acbbc2f76 100644 --- a/pkg/security/certificateSigner/certificateSigner.go +++ b/pkg/security/certificateSigner/certificateSigner.go @@ -15,9 +15,10 @@ import ( ) type SignerConfig struct { - ValidNotBefore time.Time - ValidNotAfter time.Time - OverrideCertTemplate func(template *x509.Certificate) error + ValidNotBefore time.Time + ValidNotAfter time.Time + CRLDistributionPoints []string + OverrideCertTemplate func(template *x509.Certificate) error } type Opt = func(cfg *SignerConfig) @@ -34,6 +35,12 @@ func WithNotAfter(validNotAfter time.Time) Opt { } } +func WithCRLDistributionPoints(crlDistributionPoints []string) Opt { + return func(cfg *SignerConfig) { + cfg.CRLDistributionPoints = crlDistributionPoints + } +} + func WithOverrideCertTemplate(overrideCertTemplate func(template *x509.Certificate) error) Opt { return func(cfg *SignerConfig) { cfg.OverrideCertTemplate = overrideCertTemplate @@ -56,10 +63,7 @@ func New(caCert []*x509.Certificate, caKey crypto.PrivateKey, opts ...Opt) *Cert return &CertificateSigner{caCert: caCert, caKey: caKey, cfg: cfg} } -func (s *CertificateSigner) Sign(_ context.Context, csr []byte) ([]byte, error) { - if len(s.caCert) == 0 { - return nil, errors.New("cannot sign with empty signer CA certificates") - } +func parseCertificateRequest(csr []byte) (*x509.CertificateRequest, error) { csrBlock, _ := pem.Decode(csr) if csrBlock == nil { return nil, errors.New("pem not found") @@ -74,7 +78,17 @@ func (s *CertificateSigner) Sign(_ context.Context, csr []byte) ([]byte, error) if err != nil { return nil, err } + return certificateRequest, nil +} +func (s *CertificateSigner) Sign(_ context.Context, csr []byte) ([]byte, error) { + if len(s.caCert) == 0 { + return nil, errors.New("cannot sign with empty signer CA certificates") + } + parsedCSR, err := parseCertificateRequest(csr) + if err != nil { + return nil, err + } notBefore := s.cfg.ValidNotBefore notAfter := s.cfg.ValidNotAfter for _, c := range s.caCert { @@ -92,25 +106,26 @@ func (s *CertificateSigner) Sign(_ context.Context, csr []byte) ([]byte, error) } template := x509.Certificate{ - SerialNumber: serialNumber, - NotBefore: notBefore, - NotAfter: notAfter, - Subject: certificateRequest.Subject, - PublicKeyAlgorithm: certificateRequest.PublicKeyAlgorithm, - PublicKey: certificateRequest.PublicKey, - SignatureAlgorithm: s.caCert[0].SignatureAlgorithm, - DNSNames: certificateRequest.DNSNames, - IPAddresses: certificateRequest.IPAddresses, - URIs: certificateRequest.URIs, - EmailAddresses: certificateRequest.EmailAddresses, - ExtraExtensions: certificateRequest.Extensions, + SerialNumber: serialNumber, + NotBefore: notBefore, + NotAfter: notAfter, + Subject: parsedCSR.Subject, + PublicKeyAlgorithm: parsedCSR.PublicKeyAlgorithm, + PublicKey: parsedCSR.PublicKey, + SignatureAlgorithm: s.caCert[0].SignatureAlgorithm, + DNSNames: parsedCSR.DNSNames, + IPAddresses: parsedCSR.IPAddresses, + URIs: parsedCSR.URIs, + EmailAddresses: parsedCSR.EmailAddresses, + ExtraExtensions: parsedCSR.Extensions, + CRLDistributionPoints: s.cfg.CRLDistributionPoints, } if s.cfg.OverrideCertTemplate != nil { if err = s.cfg.OverrideCertTemplate(&template); err != nil { return nil, err } } - signedCsr, err := x509.CreateCertificate(rand.Reader, &template, s.caCert[0], certificateRequest.PublicKey, s.caKey) + signedCsr, err := x509.CreateCertificate(rand.Reader, &template, s.caCert[0], parsedCSR.PublicKey, s.caKey) if err != nil { return nil, err } diff --git a/pkg/security/certificateSigner/certificateSigner_test.go b/pkg/security/certificateSigner/certificateSigner_test.go index 180f1f86c..031655178 100644 --- a/pkg/security/certificateSigner/certificateSigner_test.go +++ b/pkg/security/certificateSigner/certificateSigner_test.go @@ -72,7 +72,6 @@ func TestCertificateSignerSign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := s.Sign(context.Background(), tt.args.csr) - if tt.wantErr { require.Error(t, err) return diff --git a/pkg/security/certificateSigner/identityCertificateSigner_test.go b/pkg/security/certificateSigner/identityCertificateSigner_test.go index 27b1baa82..7f95d1e4d 100644 --- a/pkg/security/certificateSigner/identityCertificateSigner_test.go +++ b/pkg/security/certificateSigner/identityCertificateSigner_test.go @@ -72,7 +72,6 @@ func TestIdentityCertificateSignerSign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := s.Sign(context.Background(), tt.args.csr) - if tt.wantErr { require.Error(t, err) return diff --git a/pkg/security/jwt/validator/config.go b/pkg/security/jwt/validator/config.go index fdd3f882d..667ffd48c 100644 --- a/pkg/security/jwt/validator/config.go +++ b/pkg/security/jwt/validator/config.go @@ -4,12 +4,12 @@ import ( "fmt" "time" - "github.com/plgd-dev/hub/v2/pkg/net/http/client" + pkgTls "github.com/plgd-dev/hub/v2/pkg/security/tls" ) type AuthorityConfig struct { - Authority string `yaml:"authority" json:"authority"` - HTTP client.Config `yaml:"http" json:"http"` + Authority string `yaml:"authority" json:"authority"` + HTTP pkgTls.HTTPConfig `yaml:"http" json:"http"` } func (c *AuthorityConfig) Validate() error { @@ -38,7 +38,7 @@ type Config struct { Endpoints []AuthorityConfig `yaml:"endpoints" json:"endpoints"` TokenVerification TokenTrustVerificationConfig `yaml:"tokenTrustVerification,omitempty" json:"tokenTrustVerification,omitempty"` Authority *string `yaml:"authority,omitempty" json:"authority,omitempty"` // deprecated - HTTP *client.Config `yaml:"http,omitempty" json:"http,omitempty"` // deprecated + HTTP *pkgTls.HTTPConfig `yaml:"http,omitempty" json:"http,omitempty"` // deprecated } func (c *Config) Validate() error { diff --git a/pkg/security/jwt/validator/config_test.go b/pkg/security/jwt/validator/config_test.go index 638dcfe3c..66cc00bd5 100644 --- a/pkg/security/jwt/validator/config_test.go +++ b/pkg/security/jwt/validator/config_test.go @@ -3,8 +3,8 @@ package validator_test import ( "testing" - "github.com/plgd-dev/hub/v2/pkg/net/http/client" "github.com/plgd-dev/hub/v2/pkg/security/jwt/validator" + pkgTls "github.com/plgd-dev/hub/v2/pkg/security/tls" "github.com/plgd-dev/hub/v2/test/config" "github.com/stretchr/testify/require" ) @@ -54,7 +54,7 @@ func TestConfig_Validate(t *testing.T) { s := "example-authority" return &s }(), - HTTP: func() *client.Config { + HTTP: func() *pkgTls.HTTPConfig { c := config.MakeHttpClientConfig() return &c }(), diff --git a/pkg/security/jwt/validator/validator.go b/pkg/security/jwt/validator/validator.go index f54e3389f..2b7cb833c 100644 --- a/pkg/security/jwt/validator/validator.go +++ b/pkg/security/jwt/validator/validator.go @@ -10,8 +10,8 @@ import ( "github.com/plgd-dev/hub/v2/pkg/fn" "github.com/plgd-dev/hub/v2/pkg/fsnotify" "github.com/plgd-dev/hub/v2/pkg/log" - "github.com/plgd-dev/hub/v2/pkg/net/http/client" pkgHttpUri "github.com/plgd-dev/hub/v2/pkg/net/http/uri" + cmClient "github.com/plgd-dev/hub/v2/pkg/security/certManager/client" jwtValidator "github.com/plgd-dev/hub/v2/pkg/security/jwt" "github.com/plgd-dev/hub/v2/pkg/security/openid" "go.opentelemetry.io/otel/trace" @@ -83,7 +83,7 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg openIDConfigurations := make([]openid.Config, 0, len(config.Endpoints)) clients := make(map[string]jwtValidator.TokenIssuerClient, len(config.Endpoints)) for _, authority := range config.Endpoints { - httpClient, err := client.New(authority.HTTP, fileWatcher, logger, tracerProvider) + httpClient, err := cmClient.NewHTTPClient(&authority.HTTP, fileWatcher, logger, tracerProvider) if err != nil { onClose.Execute() return nil, fmt.Errorf("cannot create client cert manager: %w", err) diff --git a/pkg/security/oauth2/config.go b/pkg/security/oauth2/config.go index 4d7c6f4b4..c1aadcddf 100644 --- a/pkg/security/oauth2/config.go +++ b/pkg/security/oauth2/config.go @@ -3,15 +3,15 @@ package oauth2 import ( "fmt" - "github.com/plgd-dev/hub/v2/pkg/net/http/client" "github.com/plgd-dev/hub/v2/pkg/security/oauth2/oauth" + pkgTls "github.com/plgd-dev/hub/v2/pkg/security/tls" ) // Config general configuration type Config struct { Authority string `yaml:"authority" json:"authority"` oauth.Config `yaml:",inline"` - HTTP client.Config `yaml:"http" json:"http"` + HTTP pkgTls.HTTPConfig `yaml:"http" json:"http"` } func (c *Config) Validate() error { diff --git a/pkg/security/oauth2/plgd.go b/pkg/security/oauth2/plgd.go index f7e4b83f5..e86e5188d 100644 --- a/pkg/security/oauth2/plgd.go +++ b/pkg/security/oauth2/plgd.go @@ -6,7 +6,7 @@ import ( "github.com/plgd-dev/hub/v2/pkg/fsnotify" "github.com/plgd-dev/hub/v2/pkg/log" - "github.com/plgd-dev/hub/v2/pkg/net/http/client" + cmClient "github.com/plgd-dev/hub/v2/pkg/security/certManager/client" "github.com/plgd-dev/hub/v2/pkg/security/jwt" "github.com/plgd-dev/hub/v2/pkg/security/oauth2/oauth" "github.com/plgd-dev/hub/v2/pkg/security/openid" @@ -32,7 +32,7 @@ func NewPlgdProvider(ctx context.Context, config Config, fileWatcher *fsnotify.W } config.ClientSecret = string(clientSecret) - httpClient, err := client.New(config.HTTP, fileWatcher, logger, tracerProvider) + httpClient, err := cmClient.NewHTTPClient(&config.HTTP, fileWatcher, logger, tracerProvider) if err != nil { return nil, err } diff --git a/pkg/security/tls/client.go b/pkg/security/tls/client.go new file mode 100644 index 000000000..18234279a --- /dev/null +++ b/pkg/security/tls/client.go @@ -0,0 +1,72 @@ +package tls + +import ( + "errors" + "fmt" + "slices" + + "github.com/plgd-dev/hub/v2/pkg/config/property/urischeme" + "github.com/plgd-dev/hub/v2/pkg/strings" +) + +// ClientConfig provides configuration of a file based Server Certificate manager. CAPool can be a string or an array of strings. +type ClientConfig struct { + CAPool interface{} `yaml:"caPool" json:"caPool" description:"file path to the root certificates in PEM format"` + KeyFile urischeme.URIScheme `yaml:"keyFile" json:"keyFile" description:"file name of private key in PEM format"` + CertFile urischeme.URIScheme `yaml:"certFile" json:"certFile" description:"file name of certificate in PEM format"` + UseSystemCAPool bool `yaml:"useSystemCAPool" json:"useSystemCaPool" description:"use system certification pool"` + CRL CRLConfig `yaml:"crl" json:"json"` + + caPoolArray []urischeme.URIScheme `yaml:"-" json:"-"` + validated bool +} + +func (c *ClientConfig) Validate() error { + if c.validated { + return nil + } + caPoolArray, ok := strings.ToStringArray(c.CAPool) + if !ok { + return fmt.Errorf("caPool('%v') - unsupported", c.CAPool) + } + c.caPoolArray = urischeme.ToURISchemeArray(caPoolArray) + if !c.UseSystemCAPool && len(c.caPoolArray) == 0 { + return fmt.Errorf("caPool('%v') - is empty", c.CAPool) + } + if c.CertFile == "" { + return fmt.Errorf("certFile('%v')", c.CertFile) + } + if c.KeyFile == "" { + return fmt.Errorf("keyFile('%v')", c.KeyFile) + } + c.validated = true + return nil +} + +func (c *ClientConfig) CAPoolArray() ([]urischeme.URIScheme, error) { + if !c.validated { + return nil, errors.New("call Validate() first") + } + return c.caPoolArray, nil +} + +func (c *ClientConfig) CAPoolFilePathArray() ([]string, error) { + a, err := c.CAPoolArray() + if err != nil { + return nil, err + } + return urischeme.ToFilePathArray(a), nil +} + +func (c *ClientConfig) Equals(c2 ClientConfig) bool { + caPool1, ok1 := strings.ToStringArray(c.CAPool) + caPool2, ok2 := strings.ToStringArray(c2.CAPool) + if !ok1 || !ok2 { + return false + } + return slices.Equal(caPool1, caPool2) && + c.KeyFile == c2.KeyFile && + c.CertFile == c2.CertFile && + c.UseSystemCAPool == c2.UseSystemCAPool && + c.CRL.Equals(c2.CRL) +} diff --git a/pkg/net/http/client/config.go b/pkg/security/tls/crl.go similarity index 52% rename from pkg/net/http/client/config.go rename to pkg/security/tls/crl.go index a8a09edc7..21372684f 100644 --- a/pkg/net/http/client/config.go +++ b/pkg/security/tls/crl.go @@ -1,13 +1,25 @@ -package client +package tls import ( + "errors" "fmt" "time" - "github.com/plgd-dev/hub/v2/pkg/security/certManager/client" + "gopkg.in/yaml.v3" ) -type Config struct { +type HTTPConfigurer interface { + GetMaxIdleConns() int + GetMaxConnsPerHost() int + GetMaxIdleConnsPerHost() int + GetIdleConnTimeout() time.Duration + GetTimeout() time.Duration + GetTLS() ClientConfig + + Validate() error +} + +type HTTPConfig struct { // MaxIdleConns controls the maximum number of idle (keep-alive) // connections across all hosts. Zero means no limit. MaxIdleConns int `yaml:"maxIdleConns" json:"maxIdleConns"` @@ -47,10 +59,10 @@ type Config struct { // for cancellation instead of implementing CancelRequest. Timeout time.Duration `yaml:"timeout" json:"timeout"` - TLS client.Config `yaml:"tls" json:"tls"` + TLS ClientConfig `yaml:"tls" json:"tls"` } -func (c *Config) Validate() error { +func (c *HTTPConfig) Validate() error { if c.MaxIdleConns < 0 { return fmt.Errorf("maxIdleConns('%v')", c.MaxIdleConns) } @@ -71,3 +83,82 @@ func (c *Config) Validate() error { } return nil } + +func (c *HTTPConfig) GetMaxIdleConns() int { + return c.MaxIdleConns +} + +func (c *HTTPConfig) GetMaxConnsPerHost() int { + return c.MaxConnsPerHost +} + +func (c *HTTPConfig) GetMaxIdleConnsPerHost() int { + return c.MaxIdleConnsPerHost +} + +func (c *HTTPConfig) GetIdleConnTimeout() time.Duration { + return c.IdleConnTimeout +} + +func (c *HTTPConfig) GetTimeout() time.Duration { + return c.Timeout +} + +func (c *HTTPConfig) GetTLS() ClientConfig { + return c.TLS +} + +type CRLConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + HTTP HTTPConfigurer `yaml:"http" json:"http"` +} + +func (c *CRLConfig) Equals(c2 CRLConfig) bool { + if c.Enabled != c2.Enabled { + return false + } + if !c.Enabled { + r1 := c.HTTP == nil + r2 := c2.HTTP == nil + return r1 && r2 + } + if c.HTTP == nil { + return c2.HTTP == nil + } + tls := c.HTTP.GetTLS() + return c.HTTP.GetMaxIdleConns() == c2.HTTP.GetMaxIdleConns() && + c.HTTP.GetMaxConnsPerHost() == c2.HTTP.GetMaxConnsPerHost() && + c.HTTP.GetMaxIdleConnsPerHost() == c2.HTTP.GetMaxIdleConnsPerHost() && + c.HTTP.GetIdleConnTimeout() == c2.HTTP.GetIdleConnTimeout() && + c.HTTP.GetTimeout() == c2.HTTP.GetTimeout() && + tls.Equals(c2.HTTP.GetTLS()) +} + +func (c *CRLConfig) Validate() error { + if !c.Enabled { + return nil + } + if c.HTTP == nil { + return errors.New("http configuration missing") + } + return c.HTTP.Validate() +} + +func (c *CRLConfig) UnmarshalYAML(value *yaml.Node) error { + type crlConfig struct { + Enabled bool `yaml:"enabled"` + HTTP HTTPConfig `yaml:"http"` + } + cc := crlConfig{} + err := value.Decode(&cc) + if err != nil { + return err + } + c.Enabled = cc.Enabled + if !cc.Enabled { + c.HTTP = nil + return nil + } + c.HTTP = &cc.HTTP + return nil +} diff --git a/pkg/security/tls/crl_test.go b/pkg/security/tls/crl_test.go new file mode 100644 index 000000000..469d79e2b --- /dev/null +++ b/pkg/security/tls/crl_test.go @@ -0,0 +1,161 @@ +package tls_test + +import ( + "testing" + "time" + + "github.com/plgd-dev/hub/v2/pkg/security/tls" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +func marshal(t *testing.T, in interface{}) string { + out, err := yaml.Marshal(in) + require.NoError(t, err) + return string(out) +} + +func TestCRLConfigUnmarshalYAML(t *testing.T) { + type args struct { + yaml string + } + tests := []struct { + name string + args args + want tls.CRLConfig + wantErr bool + }{ + { + name: "valid - CRL disabled", + args: args{ + yaml: `enabled: false`, + }, + want: tls.CRLConfig{ + Enabled: false, + HTTP: nil, + }, + }, + { + name: "valid - CRL enabled", + args: args{ + yaml: `enabled: true +http: + maxIdleConns: 10 + maxConnsPerHost: 20 + maxIdleConnsPerHost: 5 + idleConnTimeout: 30s + timeout: 60s +`, + }, + want: tls.CRLConfig{ + Enabled: true, + HTTP: &tls.HTTPConfig{ + MaxIdleConns: 10, + MaxConnsPerHost: 20, + MaxIdleConnsPerHost: 5, + IdleConnTimeout: 30 * time.Second, + Timeout: 60 * time.Second, + }, + }, + }, + { + name: "valid - CRL enabled, HTTP with TLS", + args: args{ + yaml: `enabled: true +http: + maxIdleConns: 10 + maxConnsPerHost: 20 + maxIdleConnsPerHost: 5 + idleConnTimeout: 30s + timeout: 60s + tls: + caPool: /capool + keyFile: /keyfile + certFile: /certfile + useSystemCAPool: true +`, + }, + want: tls.CRLConfig{ + Enabled: true, + HTTP: &tls.HTTPConfig{ + MaxIdleConns: 10, + MaxConnsPerHost: 20, + MaxIdleConnsPerHost: 5, + IdleConnTimeout: 30 * time.Second, + Timeout: 60 * time.Second, + TLS: tls.ClientConfig{ + CAPool: "/capool", + KeyFile: "/keyfile", + CertFile: "/certfile", + UseSystemCAPool: true, + }, + }, + }, + }, + { + name: "valid - CRL enabled, HTTP with TLS, recursive", + args: args{ + yaml: `enabled: true +http: + maxIdleConns: 10 + maxConnsPerHost: 20 + maxIdleConnsPerHost: 5 + idleConnTimeout: 30s + timeout: 60s + tls: + caPool: /capool + keyFile: /keyfile + certFile: /certfile + useSystemCAPool: true + crl: + enabled: true + http: + maxIdleConns: 20 + maxConnsPerHost: 40 + maxIdleConnsPerHost: 10 + idleConnTimeout: 60s + timeout: 120s +`, + }, + want: tls.CRLConfig{ + Enabled: true, + HTTP: &tls.HTTPConfig{ + MaxIdleConns: 10, + MaxConnsPerHost: 20, + MaxIdleConnsPerHost: 5, + IdleConnTimeout: 30 * time.Second, + Timeout: 60 * time.Second, + TLS: tls.ClientConfig{ + CAPool: "/capool", + KeyFile: "/keyfile", + CertFile: "/certfile", + UseSystemCAPool: true, + CRL: tls.CRLConfig{ + Enabled: true, + HTTP: &tls.HTTPConfig{ + MaxIdleConns: 20, + MaxConnsPerHost: 40, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 60 * time.Second, + Timeout: 120 * time.Second, + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var crlConfig tls.CRLConfig + err := yaml.Unmarshal([]byte(tt.args.yaml), &crlConfig) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.True(t, tt.want.Equals(crlConfig), "want:\n\n%v\nactual:\n\n%v\n", marshal(t, tt.want), marshal(t, crlConfig)) + }) + } +} diff --git a/pkg/security/x509/revocationList.go b/pkg/security/x509/revocationList.go new file mode 100644 index 000000000..9423179d8 --- /dev/null +++ b/pkg/security/x509/revocationList.go @@ -0,0 +1,12 @@ +package x509 + +import "crypto/x509" + +func IsRevoked(certificate *x509.Certificate, crl *x509.RevocationList) bool { + for _, entry := range crl.RevokedCertificateEntries { + if certificate.SerialNumber.Cmp(entry.SerialNumber) == 0 { + return true + } + } + return false +} diff --git a/resource-aggregate/cqrs/eventbus/nats/client/client.go b/resource-aggregate/cqrs/eventbus/nats/client/client.go index 21a2c4862..58dc9f24b 100644 --- a/resource-aggregate/cqrs/eventbus/nats/client/client.go +++ b/resource-aggregate/cqrs/eventbus/nats/client/client.go @@ -8,6 +8,7 @@ import ( "github.com/plgd-dev/hub/v2/pkg/fsnotify" "github.com/plgd-dev/hub/v2/pkg/log" "github.com/plgd-dev/hub/v2/pkg/security/certManager/client" + "go.opentelemetry.io/otel/trace" ) type Client struct { @@ -15,8 +16,8 @@ type Client struct { closeFunc fn.FuncList } -func New(config Config, fileWatcher *fsnotify.Watcher, logger log.Logger) (*Client, error) { - certManager, err := client.New(config.TLS, fileWatcher, logger) +func New(config Config, fileWatcher *fsnotify.Watcher, logger log.Logger, tp trace.TracerProvider) (*Client, error) { + certManager, err := client.New(config.TLS, fileWatcher, logger, tp) if err != nil { return nil, fmt.Errorf("cannot create cert manager: %w", err) } diff --git a/resource-aggregate/cqrs/eventbus/nats/publisher/publisher_test.go b/resource-aggregate/cqrs/eventbus/nats/publisher/publisher_test.go index 8f95fca5c..0c3ac6803 100644 --- a/resource-aggregate/cqrs/eventbus/nats/publisher/publisher_test.go +++ b/resource-aggregate/cqrs/eventbus/nats/publisher/publisher_test.go @@ -20,6 +20,7 @@ import ( "github.com/plgd-dev/hub/v2/test/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" ) func TestPublisher(t *testing.T) { @@ -42,7 +43,7 @@ func TestPublisher(t *testing.T) { TLS: config.MakeTLSClientConfig(), FlusherTimeout: time.Second * 30, }, - }, fileWatcher, logger, publisher.WithMarshaler(json.Marshal)) + }, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(json.Marshal)) require.NoError(t, err) assert.NotNil(t, publisher) defer func() { @@ -51,7 +52,7 @@ func TestPublisher(t *testing.T) { }() naSubClient, subscriber, err := test.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, - logger, + logger, noop.NewTracerProvider(), subscriber.WithGoPool(func(f func()) error { go f(); return nil }), subscriber.WithUnmarshaler(json.Unmarshal)) require.NoError(t, err) @@ -101,7 +102,7 @@ func TestPublisherJetStream(t *testing.T) { FlusherTimeout: time.Second * 30, }, JetStream: true, - }, fileWatcher, logger, publisher.WithMarshaler(json.Marshal)) + }, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(json.Marshal)) require.NoError(t, err) assert.NotNil(t, publisher) defer func() { @@ -110,7 +111,7 @@ func TestPublisherJetStream(t *testing.T) { }() naSubClient, subscriber, err := test.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, - logger, + logger, noop.NewTracerProvider(), subscriber.WithGoPool(func(f func()) error { go f(); return nil }), subscriber.WithUnmarshaler(json.Unmarshal)) require.NoError(t, err) diff --git a/resource-aggregate/cqrs/eventbus/nats/subscriber/reconnect_test.go b/resource-aggregate/cqrs/eventbus/nats/subscriber/reconnect_test.go index e50506e32..9adfa0c7a 100644 --- a/resource-aggregate/cqrs/eventbus/nats/subscriber/reconnect_test.go +++ b/resource-aggregate/cqrs/eventbus/nats/subscriber/reconnect_test.go @@ -16,6 +16,7 @@ import ( "github.com/plgd-dev/hub/v2/test" "github.com/plgd-dev/hub/v2/test/config" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" ) func TestSubscriberReconnect(t *testing.T) { @@ -32,7 +33,7 @@ func TestSubscriberReconnect(t *testing.T) { require.NoError(t, errC) }() - naPubClient, pub, err := natsTest.NewClientAndPublisher(config.MakePublisherConfig(t), fileWatcher, logger, publisher.WithMarshaler(json.Marshal)) + naPubClient, pub, err := natsTest.NewClientAndPublisher(config.MakePublisherConfig(t), fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(json.Marshal)) require.NoError(t, err) require.NotNil(t, pub) defer func() { @@ -41,7 +42,7 @@ func TestSubscriberReconnect(t *testing.T) { }() naSubClient, subscriber, err := natsTest.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, - logger, + logger, noop.NewTracerProvider(), subscriber.WithGoPool(func(f func()) error { go f(); return nil }), subscriber.WithUnmarshaler(json.Unmarshal)) require.NoError(t, err) @@ -97,7 +98,7 @@ func TestSubscriberReconnect(t *testing.T) { case <-ctx.Done(): require.NoError(t, errors.New("Timeout")) } - naClient1, pub1, err := natsTest.NewClientAndPublisher(config.MakePublisherConfig(t), fileWatcher, logger, publisher.WithMarshaler(json.Marshal)) + naClient1, pub1, err := natsTest.NewClientAndPublisher(config.MakePublisherConfig(t), fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(json.Marshal)) require.NoError(t, err) require.NotNil(t, pub1) defer func() { diff --git a/resource-aggregate/cqrs/eventbus/nats/subscriber/subscriber_test.go b/resource-aggregate/cqrs/eventbus/nats/subscriber/subscriber_test.go index 53fb1f57b..0c2e11b42 100644 --- a/resource-aggregate/cqrs/eventbus/nats/subscriber/subscriber_test.go +++ b/resource-aggregate/cqrs/eventbus/nats/subscriber/subscriber_test.go @@ -17,6 +17,7 @@ import ( "github.com/plgd-dev/hub/v2/resource-aggregate/cqrs/eventstore" "github.com/plgd-dev/hub/v2/test/config" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" ) func TestSubscriber(t *testing.T) { @@ -34,7 +35,7 @@ func TestSubscriber(t *testing.T) { require.NoError(t, errC) }() - naPubClient, publisher, err := test.NewClientAndPublisher(config.MakePublisherConfig(t), fileWatcher, logger, publisher.WithMarshaler(json.Marshal)) + naPubClient, publisher, err := test.NewClientAndPublisher(config.MakePublisherConfig(t), fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(json.Marshal)) require.NoError(t, err) require.NotNil(t, publisher) defer func() { @@ -43,7 +44,7 @@ func TestSubscriber(t *testing.T) { }() naSubClient, subscriber, err := test.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, - logger, + logger, noop.NewTracerProvider(), subscriber.WithGoPool(func(f func()) error { go f(); return nil }), subscriber.WithUnmarshaler(json.Unmarshal), ) diff --git a/resource-aggregate/cqrs/eventbus/nats/test/publisher.go b/resource-aggregate/cqrs/eventbus/nats/test/publisher.go index 05bcac7fe..a2e1b7c24 100644 --- a/resource-aggregate/cqrs/eventbus/nats/test/publisher.go +++ b/resource-aggregate/cqrs/eventbus/nats/test/publisher.go @@ -5,10 +5,11 @@ import ( "github.com/plgd-dev/hub/v2/pkg/log" "github.com/plgd-dev/hub/v2/resource-aggregate/cqrs/eventbus/nats/client" "github.com/plgd-dev/hub/v2/resource-aggregate/cqrs/eventbus/nats/publisher" + "go.opentelemetry.io/otel/trace" ) -func NewClientAndPublisher(config client.ConfigPublisher, fileWatcher *fsnotify.Watcher, logger log.Logger, opts ...publisher.Option) (*client.Client, *publisher.Publisher, error) { - c, err := client.New(config.Config, fileWatcher, logger) +func NewClientAndPublisher(config client.ConfigPublisher, fileWatcher *fsnotify.Watcher, logger log.Logger, tp trace.TracerProvider, opts ...publisher.Option) (*client.Client, *publisher.Publisher, error) { + c, err := client.New(config.Config, fileWatcher, logger, tp) if err != nil { return nil, nil, err } diff --git a/resource-aggregate/cqrs/eventbus/nats/test/subscriber.go b/resource-aggregate/cqrs/eventbus/nats/test/subscriber.go index 42e5c4f3c..b21bcae59 100644 --- a/resource-aggregate/cqrs/eventbus/nats/test/subscriber.go +++ b/resource-aggregate/cqrs/eventbus/nats/test/subscriber.go @@ -5,10 +5,11 @@ import ( "github.com/plgd-dev/hub/v2/pkg/log" "github.com/plgd-dev/hub/v2/resource-aggregate/cqrs/eventbus/nats/client" "github.com/plgd-dev/hub/v2/resource-aggregate/cqrs/eventbus/nats/subscriber" + "go.opentelemetry.io/otel/trace" ) -func NewClientAndSubscriber(config client.ConfigSubscriber, fileWatcher *fsnotify.Watcher, logger log.Logger, opts ...subscriber.Option) (*client.Client, *subscriber.Subscriber, error) { - c, err := client.New(config.Config, fileWatcher, logger) +func NewClientAndSubscriber(config client.ConfigSubscriber, fileWatcher *fsnotify.Watcher, logger log.Logger, tp trace.TracerProvider, opts ...subscriber.Option) (*client.Client, *subscriber.Subscriber, error) { + c, err := client.New(config.Config, fileWatcher, logger, tp) if err != nil { return nil, nil, err } diff --git a/resource-aggregate/cqrs/eventstore/cqldb/eventstore.go b/resource-aggregate/cqrs/eventstore/cqldb/eventstore.go index ab63d6ef7..91d4a3751 100644 --- a/resource-aggregate/cqrs/eventstore/cqldb/eventstore.go +++ b/resource-aggregate/cqrs/eventstore/cqldb/eventstore.go @@ -53,7 +53,7 @@ func New(ctx context.Context, config *Config, fileWatcher *fsnotify.Watcher, log for _, o := range opts { o.apply(config) } - certManager, err := client.New(config.Embedded.TLS, fileWatcher, logger) + certManager, err := client.New(config.Embedded.TLS, fileWatcher, logger, tracerProvider) if err != nil { return nil, fmt.Errorf("could not create cert manager: %w", err) } diff --git a/resource-aggregate/cqrs/eventstore/mongodb/eventstore.go b/resource-aggregate/cqrs/eventstore/mongodb/eventstore.go index 7d84b0c52..b4f7bb112 100644 --- a/resource-aggregate/cqrs/eventstore/mongodb/eventstore.go +++ b/resource-aggregate/cqrs/eventstore/mongodb/eventstore.go @@ -137,7 +137,7 @@ func New(ctx context.Context, config *Config, fileWatcher *fsnotify.Watcher, log for _, o := range opts { o.apply(config) } - certManager, err := client.New(config.Embedded.TLS, fileWatcher, logger) + certManager, err := client.New(config.Embedded.TLS, fileWatcher, logger, tracerProvider) if err != nil { return nil, fmt.Errorf("could not create cert manager: %w", err) } diff --git a/resource-aggregate/cqrs/projection/projectionInternal_test.go b/resource-aggregate/cqrs/projection/projectionInternal_test.go index d48a322c9..a2b5db9c0 100644 --- a/resource-aggregate/cqrs/projection/projectionInternal_test.go +++ b/resource-aggregate/cqrs/projection/projectionInternal_test.go @@ -56,7 +56,7 @@ func TestProjection(t *testing.T) { require.NoError(t, errC) }() - naPubClient, publisher, err := natsTest.NewClientAndPublisher(config.MakePublisherConfig(t), fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naPubClient, publisher, err := natsTest.NewClientAndPublisher(config.MakePublisherConfig(t), fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) require.NotNil(t, publisher) defer func() { @@ -69,7 +69,7 @@ func TestProjection(t *testing.T) { defer pool.Release() naSubClient, subscriber, err := natsTest.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, - logger, + logger, noop.NewTracerProvider(), subscriber.WithGoPool(pool.Submit), subscriber.WithUnmarshaler(utils.Unmarshal), ) diff --git a/resource-aggregate/service/aggregate_test.go b/resource-aggregate/service/aggregate_test.go index 91b19faf4..6ac332559 100644 --- a/resource-aggregate/service/aggregate_test.go +++ b/resource-aggregate/service/aggregate_test.go @@ -105,7 +105,7 @@ func TestAggregateHandlePublishResourceLinks(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() @@ -230,7 +230,7 @@ func TestAggregateHandleUnpublishResource(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() @@ -284,7 +284,7 @@ func TestAggregateHandleUnpublishAllResources(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() @@ -345,7 +345,7 @@ func TestAggregateHandleUnpublishResourceSubset(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() @@ -658,7 +658,7 @@ func TestAggregateHandleNotifyContentChanged(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() diff --git a/resource-aggregate/service/cancelDeviceMetadataUpdates_test.go b/resource-aggregate/service/cancelDeviceMetadataUpdates_test.go index d4cdf4088..4b1e8fa92 100644 --- a/resource-aggregate/service/cancelDeviceMetadataUpdates_test.go +++ b/resource-aggregate/service/cancelDeviceMetadataUpdates_test.go @@ -96,7 +96,7 @@ func TestAggregateHandleCancelPendingMetadataUpdates(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() @@ -219,7 +219,7 @@ func TestRequestHandlerCancelPendingMetadataUpdates(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() diff --git a/resource-aggregate/service/cancelResourceCommands_test.go b/resource-aggregate/service/cancelResourceCommands_test.go index 76bdd1ed0..781a80f67 100644 --- a/resource-aggregate/service/cancelResourceCommands_test.go +++ b/resource-aggregate/service/cancelResourceCommands_test.go @@ -169,7 +169,7 @@ func TestRequestHandlerCancelPendingCommands(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() diff --git a/resource-aggregate/service/confirmDeviceMetadataUpdate_test.go b/resource-aggregate/service/confirmDeviceMetadataUpdate_test.go index 52fa9987b..260e0aaa2 100644 --- a/resource-aggregate/service/confirmDeviceMetadataUpdate_test.go +++ b/resource-aggregate/service/confirmDeviceMetadataUpdate_test.go @@ -91,7 +91,7 @@ func TestAggregateHandleConfirmDeviceMetadataUpdate(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() @@ -194,7 +194,7 @@ func TestRequestHandlerConfirmDeviceMetadataUpdate(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() diff --git a/resource-aggregate/service/deleteDevices_test.go b/resource-aggregate/service/deleteDevices_test.go index 0f24b59dd..1a40bdbab 100644 --- a/resource-aggregate/service/deleteDevices_test.go +++ b/resource-aggregate/service/deleteDevices_test.go @@ -43,7 +43,7 @@ func TestRequestHandler_DeleteDevices(t *testing.T) { require.NoError(t, errC) _ = eventstore.Close(ctx) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() diff --git a/resource-aggregate/service/grpcApi_test.go b/resource-aggregate/service/grpcApi_test.go index 50ae3c3c7..2cb3f8a3f 100644 --- a/resource-aggregate/service/grpcApi_test.go +++ b/resource-aggregate/service/grpcApi_test.go @@ -120,7 +120,7 @@ func TestRequestHandlerPublishResource(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() @@ -239,7 +239,7 @@ func TestRequestHandlerUnpublishResource(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() @@ -328,7 +328,7 @@ func TestRequestHandlerNotifyResourceChanged(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() @@ -429,7 +429,7 @@ func TestRequestHandlerUpdateResourceContent(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() @@ -526,7 +526,7 @@ func TestRequestHandlerConfirmResourceUpdate(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() @@ -621,7 +621,7 @@ func TestRequestHandlerRetrieveResource(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() @@ -718,7 +718,7 @@ func TestRequestHandlerConfirmResourceRetrieve(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() @@ -816,7 +816,7 @@ func TestRequestHandlerDeleteResource(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() @@ -914,7 +914,7 @@ func TestRequestHandlerConfirmResourceDelete(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() @@ -1012,7 +1012,7 @@ func TestRequestHandlerCreateResource(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() @@ -1110,7 +1110,7 @@ func TestRequestHandlerConfirmResourceCreate(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() diff --git a/resource-aggregate/service/service.go b/resource-aggregate/service/service.go index f719e33e1..e07452c02 100644 --- a/resource-aggregate/service/service.go +++ b/resource-aggregate/service/service.go @@ -65,7 +65,7 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg logger.Errorf("error occurs during closing of connection to eventstore: %w", errC) } } - naClient, err := natsClient.New(config.Clients.Eventbus.NATS.Config, fileWatcher, logger) + naClient, err := natsClient.New(config.Clients.Eventbus.NATS.Config, fileWatcher, logger, tracerProvider) if err != nil { closeEventStore() otelClient.Close() @@ -155,7 +155,7 @@ func NewService(ctx context.Context, config Config, fileWatcher *fsnotify.Watche } grpcServer.AddCloseFunc(closeIsClient) - nats, err := natsClient.New(config.Clients.Eventbus.NATS.Config, fileWatcher, logger) + nats, err := natsClient.New(config.Clients.Eventbus.NATS.Config, fileWatcher, logger, tracerProvider) if err != nil { return nil, closeGrpcServerOnError(fmt.Errorf("cannot create nats client: %w", err)) } diff --git a/resource-aggregate/service/updateDeviceMetadata_test.go b/resource-aggregate/service/updateDeviceMetadata_test.go index 3f8077b0d..490544a74 100644 --- a/resource-aggregate/service/updateDeviceMetadata_test.go +++ b/resource-aggregate/service/updateDeviceMetadata_test.go @@ -111,7 +111,7 @@ func TestAggregateHandleUpdateDeviceMetadata(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(cfg.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() @@ -241,7 +241,7 @@ func TestRequestHandlerUpdateDeviceMetadata(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() diff --git a/resource-aggregate/service/updateServiceHeartbeat_test.go b/resource-aggregate/service/updateServiceHeartbeat_test.go index c472d8b6b..3803ddbbe 100644 --- a/resource-aggregate/service/updateServiceHeartbeat_test.go +++ b/resource-aggregate/service/updateServiceHeartbeat_test.go @@ -45,7 +45,7 @@ func TestNewServiceHeartbeat(t *testing.T) { errC := eventstore.Close(ctx) require.NoError(t, errC) }() - naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, publisher.WithMarshaler(utils.Marshal)) + naClient, publisher, err := natsTest.NewClientAndPublisher(config.Clients.Eventbus.NATS, fileWatcher, logger, noop.NewTracerProvider(), publisher.WithMarshaler(utils.Marshal)) require.NoError(t, err) defer func() { publisher.Close() diff --git a/resource-directory/service/deviceDirectory_test.go b/resource-directory/service/deviceDirectory_test.go index b43de0d09..13bbb3fdc 100644 --- a/resource-directory/service/deviceDirectory_test.go +++ b/resource-directory/service/deviceDirectory_test.go @@ -22,6 +22,7 @@ import ( "github.com/plgd-dev/hub/v2/test/config" cbor "github.com/plgd-dev/kit/v2/codec/cbor" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -123,7 +124,7 @@ func TestDeviceDirectoryGetDevices(t *testing.T) { pool, err := ants.NewPool(1) require.NoError(t, err) naClient, resourceSubscriber, err := natsTest.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, - logger, + logger, noop.NewTracerProvider(), subscriber.WithGoPool(pool.Submit), subscriber.WithUnmarshaler(utils.Unmarshal), ) diff --git a/resource-directory/service/grpcApi.go b/resource-directory/service/grpcApi.go index 17b15f074..2fec87d7b 100644 --- a/resource-directory/service/grpcApi.go +++ b/resource-directory/service/grpcApi.go @@ -123,7 +123,7 @@ func newRequestHandlerFromConfig(ctx context.Context, config Config, publicConfi } }) - natsClient, err := naClient.New(config.Clients.Eventbus.NATS.Config, fileWatcher, logger) + natsClient, err := naClient.New(config.Clients.Eventbus.NATS.Config, fileWatcher, logger, tracerProvider) if err != nil { closeFunc.Execute() return nil, fmt.Errorf("cannot create nats client: %w", err) diff --git a/resource-directory/service/resourceDirectory_test.go b/resource-directory/service/resourceDirectory_test.go index 64e0a116c..d01b39cb0 100644 --- a/resource-directory/service/resourceDirectory_test.go +++ b/resource-directory/service/resourceDirectory_test.go @@ -21,6 +21,7 @@ import ( "github.com/plgd-dev/hub/v2/test/config" pbTest "github.com/plgd-dev/hub/v2/test/pb" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" "google.golang.org/grpc" ) @@ -63,7 +64,7 @@ func TestResourceDirectoryGetResourceLinks(t *testing.T) { pool, err := ants.NewPool(1) require.NoError(t, err) naClient, resourceSubscriber, err := natsTest.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, - logger, + logger, noop.NewTracerProvider(), subscriber.WithGoPool(pool.Submit), subscriber.WithUnmarshaler(utils.Unmarshal), ) diff --git a/resource-directory/service/resourceShadow_test.go b/resource-directory/service/resourceShadow_test.go index 087278913..8269c28e0 100644 --- a/resource-directory/service/resourceShadow_test.go +++ b/resource-directory/service/resourceShadow_test.go @@ -20,6 +20,7 @@ import ( "github.com/plgd-dev/hub/v2/test" "github.com/plgd-dev/hub/v2/test/config" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" "google.golang.org/grpc" ) @@ -223,7 +224,7 @@ func TestResourceTwinGetResources(t *testing.T) { pool, err := ants.NewPool(1) require.NoError(t, err) naClient, resourceSubscriber, err := natsTest.NewClientAndSubscriber(config.MakeSubscriberConfig(), fileWatcher, - logger, + logger, noop.NewTracerProvider(), subscriber.WithGoPool(pool.Submit), subscriber.WithUnmarshaler(utils.Unmarshal), ) diff --git a/snippet-service/service/resourceSubscriber.go b/snippet-service/service/resourceSubscriber.go index a24b224a5..9064ccd1d 100644 --- a/snippet-service/service/resourceSubscriber.go +++ b/snippet-service/service/resourceSubscriber.go @@ -12,6 +12,7 @@ import ( "github.com/plgd-dev/hub/v2/resource-aggregate/cqrs/eventbus/nats/subscriber" "github.com/plgd-dev/hub/v2/resource-aggregate/cqrs/utils" "github.com/plgd-dev/hub/v2/resource-aggregate/events" + "go.opentelemetry.io/otel/trace" ) type ResourceSubscriber struct { @@ -21,8 +22,8 @@ type ResourceSubscriber struct { observer eventbus.Observer } -func NewResourceSubscriber(ctx context.Context, config natsClient.ConfigSubscriber, subscriptionID string, fileWatcher *fsnotify.Watcher, logger log.Logger, handler eventbus.Handler) (*ResourceSubscriber, error) { - nats, err := natsClient.New(config.Config, fileWatcher, logger) +func NewResourceSubscriber(ctx context.Context, config natsClient.ConfigSubscriber, subscriptionID string, fileWatcher *fsnotify.Watcher, logger log.Logger, tp trace.TracerProvider, handler eventbus.Handler) (*ResourceSubscriber, error) { + nats, err := natsClient.New(config.Config, fileWatcher, logger, tp) if err != nil { return nil, fmt.Errorf("cannot create nats client: %w", err) } diff --git a/snippet-service/service/resourceSubscriber_test.go b/snippet-service/service/resourceSubscriber_test.go index f862209d8..6a5133260 100644 --- a/snippet-service/service/resourceSubscriber_test.go +++ b/snippet-service/service/resourceSubscriber_test.go @@ -19,6 +19,7 @@ import ( oauthTest "github.com/plgd-dev/hub/v2/test/oauth-server/test" hubTestService "github.com/plgd-dev/hub/v2/test/service" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) @@ -63,7 +64,7 @@ func TestResourceSubscriber(t *testing.T) { ch: make(chan *events.ResourceChanged, 8), } cfg := test.MakeConfig(t) - rs, err := service.NewResourceSubscriber(ctx, cfg.Clients.EventBus.NATS, cfg.Clients.EventBus.SubscriptionID, fileWatcher, logger, &h) + rs, err := service.NewResourceSubscriber(ctx, cfg.Clients.EventBus.NATS, cfg.Clients.EventBus.SubscriptionID, fileWatcher, logger, noop.NewTracerProvider(), &h) require.NoError(t, err) defer rs.Close() diff --git a/snippet-service/service/service.go b/snippet-service/service/service.go index 18af09662..d4c5d9cb0 100644 --- a/snippet-service/service/service.go +++ b/snippet-service/service/service.go @@ -130,7 +130,7 @@ func New(ctx context.Context, config Config, fileWatcher *fsnotify.Watcher, logg } }) - resourceSubscriber, err := NewResourceSubscriber(ctx, config.Clients.EventBus.NATS, config.Clients.EventBus.SubscriptionID, fileWatcher, logger, resourceUpdater.Load()) + resourceSubscriber, err := NewResourceSubscriber(ctx, config.Clients.EventBus.NATS, config.Clients.EventBus.SubscriptionID, fileWatcher, logger, tracerProvider, resourceUpdater.Load()) if err != nil { closerFn.Execute() return nil, fmt.Errorf("cannot create resource subscriber: %w", err) diff --git a/snippet-service/store/mongodb/store.go b/snippet-service/store/mongodb/store.go index 901266a50..cb1d4c98a 100644 --- a/snippet-service/store/mongodb/store.go +++ b/snippet-service/store/mongodb/store.go @@ -42,7 +42,7 @@ var deviceIDConfigurationIDUniqueIndex = mongo.IndexModel{ } func New(ctx context.Context, cfg *Config, fileWatcher *fsnotify.Watcher, logger log.Logger, tracerProvider trace.TracerProvider) (*Store, error) { - certManager, err := client.New(cfg.Mongo.TLS, fileWatcher, logger) + certManager, err := client.New(cfg.Mongo.TLS, fileWatcher, logger, tracerProvider) if err != nil { return nil, fmt.Errorf("could not create cert manager: %w", err) } diff --git a/test/config/config.go b/test/config/config.go index 507de2c58..8ba900223 100644 --- a/test/config/config.go +++ b/test/config/config.go @@ -19,7 +19,6 @@ import ( pkgMongo "github.com/plgd-dev/hub/v2/pkg/mongodb" grpcClient "github.com/plgd-dev/hub/v2/pkg/net/grpc/client" grpcServer "github.com/plgd-dev/hub/v2/pkg/net/grpc/server" - httpClient "github.com/plgd-dev/hub/v2/pkg/net/http/client" httpServer "github.com/plgd-dev/hub/v2/pkg/net/http/server" "github.com/plgd-dev/hub/v2/pkg/net/listener" otelClient "github.com/plgd-dev/hub/v2/pkg/opentelemetry/collector/client" @@ -28,6 +27,7 @@ import ( "github.com/plgd-dev/hub/v2/pkg/security/jwt/validator" "github.com/plgd-dev/hub/v2/pkg/security/oauth2" "github.com/plgd-dev/hub/v2/pkg/security/oauth2/oauth" + pkgTls "github.com/plgd-dev/hub/v2/pkg/security/tls" natsClient "github.com/plgd-dev/hub/v2/resource-aggregate/cqrs/eventbus/nats/client" "github.com/plgd-dev/hub/v2/resource-aggregate/cqrs/eventstore/cqldb" "github.com/plgd-dev/hub/v2/resource-aggregate/cqrs/eventstore/mongodb" @@ -182,8 +182,8 @@ func MakeListenerConfig(address string) listener.Config { } } -func MakeHttpClientConfig() httpClient.Config { - return httpClient.Config{ +func MakeHttpClientConfig() pkgTls.HTTPConfig { + return pkgTls.HTTPConfig{ MaxIdleConns: 16, MaxConnsPerHost: 32, MaxIdleConnsPerHost: 16, diff --git a/test/sdk/client.go b/test/sdk/client.go index 598511952..02c35998c 100644 --- a/test/sdk/client.go +++ b/test/sdk/client.go @@ -54,8 +54,9 @@ type sdkConfig struct { key []byte } // TODO: replace by notBefore and notAfter - validFrom string // RFC3339, or relative time such as now-1m - validFor string // string parsable by time.ParseDuration + validFrom string // RFC3339, or relative time such as now-1m + validFor string // string parsable by time.ParseDuration + crlDistributionPoints []string useDeviceIDInQuery bool } @@ -101,6 +102,12 @@ func WithUseDeviceIDInQuery(useDeviceIDInQuery bool) Option { }) } +func WithCRLDistributionPoints(crlDistributionPoints []string) Option { + return optionFunc(func(cfg *sdkConfig) { + cfg.crlDistributionPoints = crlDistributionPoints + }) +} + func getSDKConfig(opts ...Option) (*sdkConfig, error) { c := &sdkConfig{ id: CertIdentity, @@ -167,13 +174,15 @@ func NewClient(opts ...Option) (*client.Client, error) { } devCfg := &client.DeviceOwnershipSDKConfig{ - ID: c.id, - Cert: string(identityIntermediateCA), - CertKey: string(identityIntermediateCAKey), - CreateSignerFunc: func(caCert []*x509.Certificate, caKey crypto.PrivateKey, validNotBefore time.Time, validNotAfter time.Time) core.CertificateSigner { - return certificateSigner.NewIdentityCertificateSigner(caCert, caKey, certificateSigner.WithNotBefore(validNotBefore), certificateSigner.WithNotAfter(validNotAfter)) + ID: c.id, + Cert: string(identityIntermediateCA), + CertKey: string(identityIntermediateCAKey), + ValidFrom: c.validFrom, + CRLDistributionPoints: c.crlDistributionPoints, + CreateSignerFunc: func(caCert []*x509.Certificate, caKey crypto.PrivateKey, validNotBefore, validNotAfter time.Time, crlDistributionPoints []string) core.CertificateSigner { + return certificateSigner.NewIdentityCertificateSigner(caCert, caKey, certificateSigner.WithNotBefore(validNotBefore), certificateSigner.WithNotAfter(validNotAfter), + certificateSigner.WithCRLDistributionPoints(crlDistributionPoints)) }, - ValidFrom: c.validFrom, } if c.validFor != "" { devCfg.CertExpiry = &c.validFor diff --git a/test/service/service.go b/test/service/service.go index 2dee6e893..7baa3769f 100644 --- a/test/service/service.go +++ b/test/service/service.go @@ -92,7 +92,7 @@ func ClearDB(ctx context.Context, t require.TestingT) { err = fileWatcher.Close() require.NoError(t, err) }() - certManager, err := cmClient.New(tlsConfig, fileWatcher, logger) + certManager, err := cmClient.New(tlsConfig, fileWatcher, logger, noop.NewTracerProvider()) require.NoError(t, err) defer certManager.Close() diff --git a/tools/mongodb/standby-tool/main.go b/tools/mongodb/standby-tool/main.go index c7062086f..0654cb8c7 100644 --- a/tools/mongodb/standby-tool/main.go +++ b/tools/mongodb/standby-tool/main.go @@ -16,6 +16,7 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" + "go.opentelemetry.io/otel/trace/noop" ) type StandbyConfig struct { @@ -144,7 +145,7 @@ func main() { } var certClient *client.CertManager if cfg.Clients.Storage.MongoDB.TLS.Enabled { - certClient, err = client.New(cfg.Clients.Storage.MongoDB.TLS.TLS, fileWatcher, logger) + certClient, err = client.New(cfg.Clients.Storage.MongoDB.TLS.TLS, fileWatcher, logger, noop.NewTracerProvider()) if err != nil { logger.Fatalf("cannot create cert client: %v", err) }