Skip to content

Commit

Permalink
export GetMeasurementsFromTLS and make independent of proxy (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
metachris authored Nov 18, 2024
1 parent 837588b commit 04abe8c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion proxy/atls_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ import (

"github.com/flashbots/cvm-reverse-proxy/internal/atls"
azure_tdx "github.com/flashbots/cvm-reverse-proxy/internal/attestation/azure/tdx"
dcap_tdx "github.com/flashbots/cvm-reverse-proxy/tdx"
"github.com/flashbots/cvm-reverse-proxy/internal/attestation/measurements"
"github.com/flashbots/cvm-reverse-proxy/internal/attestation/variant"
"github.com/flashbots/cvm-reverse-proxy/internal/cloud/cloudprovider"
"github.com/flashbots/cvm-reverse-proxy/internal/config"
dcap_tdx "github.com/flashbots/cvm-reverse-proxy/tdx"
)

type AttestationType string
Expand Down
10 changes: 6 additions & 4 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package proxy

import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/hex"
Expand Down Expand Up @@ -108,12 +109,12 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
p.log.With("duration", duration).Info("[proxy-request] proxying complete")
}

func (p *Proxy) getMeasurementsFromTLS(conn *tls.ConnectionState) (atlsVariant variant.Variant, measurements map[uint32][]byte, err error) {
func GetMeasurementsFromTLS(certs []*x509.Certificate, validatorOIDs []asn1.ObjectIdentifier) (atlsVariant variant.Variant, measurements map[uint32][]byte, err error) {
// In verifyEmbeddedReport which is used to validate the extensions, only the first matching extension is validated! Refuse to accept multiple
var ATLSExtension *pkix.Extension = nil
for _, cert := range conn.PeerCertificates {
for _, cert := range certs {
for _, ext := range cert.Extensions {
for _, validatorOID := range p.validatorOIDs {
for _, validatorOID := range validatorOIDs {
if ext.Id.Equal(validatorOID) {
if ATLSExtension != nil {
return nil, nil, errors.New("more than one ATLS extension provided, refusing to continue")
Expand Down Expand Up @@ -142,7 +143,8 @@ func (p *Proxy) getMeasurementsFromTLS(conn *tls.ConnectionState) (atlsVariant v
}

func (p *Proxy) copyMeasurementsToHeader(conn *tls.ConnectionState, header *http.Header) (int, error) {
atlsVariant, extractedMeasurements, err := p.getMeasurementsFromTLS(conn)
certs := conn.PeerCertificates
atlsVariant, extractedMeasurements, err := GetMeasurementsFromTLS(certs, p.validatorOIDs)
if err != nil {
return http.StatusTeapot, err
} else if extractedMeasurements == nil {
Expand Down

0 comments on commit 04abe8c

Please sign in to comment.