diff --git a/.gitignore b/.gitignore index 1a442fb..5dc581b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ /measurements.json /build/ /quotes/ +/builder-cert.pem \ No newline at end of file diff --git a/cmd/attested-get/main.go b/cmd/attested-get/main.go index 60005fa..621d47a 100644 --- a/cmd/attested-get/main.go +++ b/cmd/attested-get/main.go @@ -29,7 +29,10 @@ import ( "github.com/flashbots/cvm-reverse-proxy/common" "github.com/flashbots/cvm-reverse-proxy/internal/atls" + azure_tdx "github.com/flashbots/cvm-reverse-proxy/internal/attestation/azure/tdx" + "github.com/flashbots/cvm-reverse-proxy/internal/attestation/measurements" "github.com/flashbots/cvm-reverse-proxy/internal/attestation/variant" + "github.com/flashbots/cvm-reverse-proxy/internal/config" "github.com/flashbots/cvm-reverse-proxy/proxy" "github.com/urfave/cli/v2" // imports as package "cli" ) @@ -50,6 +53,11 @@ var flags []cli.Flag = []cli.Flag{ Value: "", Usage: "Output file for the response payload", }, + &cli.StringFlag{ + Name: "attestation-type", // TODO: Add support for other attestation types + Value: string(proxy.AttestationAzureTDX), + Usage: "type of attestation to present (currently only azure-tdx)", + }, &cli.BoolFlag{ Name: "log-debug", Value: false, @@ -75,6 +83,7 @@ func runClient(cCtx *cli.Context) (err error) { addr := cCtx.String("addr") outMeasurements := cCtx.String("out-measurements") outResponse := cCtx.String("out-response") + attestationTypeStr := cCtx.String("attestation-type") // Setup logging log := common.SetupLogger(&common.LoggingOpts{ @@ -88,7 +97,25 @@ func runClient(cCtx *cli.Context) (err error) { return errors.New("address needs to start with https://") } - log.Info("Getting verified measurements from " + addr + " ...") + // Create validators based on the attestation type + attestationType, err := proxy.ParseAttestationType(attestationTypeStr) + if err != nil { + log.With("attestation-type", attestationType).Error("invalid attestation-type passed, see --help") + return err + } + + var validators []atls.Validator + switch attestationType { + case proxy.AttestationAzureTDX: + // Prepare an azure-tdx validator without any required measurements + attConfig := config.DefaultForAzureTDX() + attConfig.SetMeasurements(measurements.M{}) + validator := azure_tdx.NewValidator(attConfig, proxy.AttestationLogger{Log: log}) + validators = append(validators, validator) + default: + log.Error("currently only azure-tdx attestation is supported") + return errors.New("currently only azure-tdx attestation is supported") + } // Prepare aTLS stuff issuer, err := proxy.CreateAttestationIssuer(log, proxy.AttestationAzureTDX) @@ -97,24 +124,27 @@ func runClient(cCtx *cli.Context) (err error) { return err } - tlsConfig, err := atls.CreateAttestationClientTLSConfig(issuer, []atls.Validator{}) + // Create the (a)TLS config + tlsConfig, err := atls.CreateAttestationClientTLSConfig(issuer, validators) if err != nil { log.Error("could not create atls config", "err", err) return err } - tr := &http.Transport{ + // Prepare the client + client := &http.Client{Transport: &http.Transport{ TLSClientConfig: tlsConfig, - } - client := &http.Client{Transport: tr} + }} + + // Execute the GET request + log.Info("Executing attested GET request to " + addr + " ...") resp, err := client.Get(addr) if err != nil { return err } - certs := resp.TLS.PeerCertificates // Extract the aTLS variant and measurements from the TLS connection - atlsVariant, extractedMeasurements, err := proxy.GetMeasurementsFromTLS(certs, []asn1.ObjectIdentifier{variant.AzureTDX{}.OID()}) + atlsVariant, extractedMeasurements, err := proxy.GetMeasurementsFromTLS(resp.TLS.PeerCertificates, []asn1.ObjectIdentifier{variant.AzureTDX{}.OID()}) if err != nil { log.Error("Error in getMeasurementsFromTLS", "err", err) return err @@ -130,11 +160,10 @@ func runClient(cCtx *cli.Context) (err error) { return errors.New("could not marshal measurement extracted from tls extension") } - log.Info("Variant: " + atlsVariant.String()) log.Info(fmt.Sprintf("Measurements for %s with %d entries:", atlsVariant.String(), len(measurementsInHeaderFormat))) fmt.Println(string(marshaledPcrs)) if outMeasurements != "" { - if err := os.WriteFile(outMeasurements, marshaledPcrs, 0644); err != nil { + if err := os.WriteFile(outMeasurements, marshaledPcrs, 0o644); err != nil { return err } } @@ -148,7 +177,7 @@ func runClient(cCtx *cli.Context) (err error) { log.Info(fmt.Sprintf("Response body with %d bytes:", len(msg))) fmt.Println(string(msg)) if outResponse != "" { - if err := os.WriteFile(outResponse, msg, 0644); err != nil { + if err := os.WriteFile(outResponse, msg, 0o644); err != nil { return err } }