Skip to content

Commit

Permalink
use HTTP GET request, and save to file
Browse files Browse the repository at this point in the history
  • Loading branch information
metachris committed Nov 15, 2024
1 parent 8bc40f3 commit e982121
Showing 1 changed file with 58 additions and 15 deletions.
73 changes: 58 additions & 15 deletions cmd/get-measurements/main.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
package main

//
// CLI tool to get and print verified measurements from an aTLS server.
// Make a HTTP GET request over a TEE-attested connection (to a server with aTLS support),
// and print the verified measurements and the response payload.
//
// Currently only works for Azure TDX but should be easy to expand.
//
// Usage:
//
// go run cmd/get-measurements/main.go instance_ip:port
// go run cmd/get-measurements/main.go --addr=https://instance_ip:port
//
// Can also save the verified measurements and the response body to files:
//
// go run cmd/get-measurements/main.go --addr=https://instance_ip:port --out-measurements=measurements.json --out-response=response.txt
//

import (
"crypto/tls"
"encoding/asn1"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"strings"

"github.com/flashbots/cvm-reverse-proxy/common"
"github.com/flashbots/cvm-reverse-proxy/internal/atls"
Expand All @@ -28,6 +35,21 @@ import (
)

var flags []cli.Flag = []cli.Flag{
&cli.StringFlag{
Name: "addr",
Value: "https://localhost:7936",
Usage: "TEE server address",
},
&cli.StringFlag{
Name: "out-measurements",
Value: "",
Usage: "Output file for the measurements",
},
&cli.StringFlag{
Name: "out-response",
Value: "",
Usage: "Output file for the response payload",
},
&cli.BoolFlag{
Name: "log-debug",
Value: false,
Expand All @@ -48,8 +70,11 @@ func main() {
}
}

func runClient(cCtx *cli.Context) error {
func runClient(cCtx *cli.Context) (err error) {
logDebug := cCtx.Bool("log-debug")
addr := cCtx.String("addr")
outMeasurements := cCtx.String("out-measurements")
outResponse := cCtx.String("out-response")

// Setup logging
log := common.SetupLogger(&common.LoggingOpts{
Expand All @@ -59,14 +84,11 @@ func runClient(cCtx *cli.Context) error {
Version: common.Version,
})

addr := cCtx.Args().Get(0)
if addr == "" {
log.Error("Please provide an address as cli argument")
return errors.New("provide an address as argument")
if !strings.HasPrefix(addr, "https://") {
return errors.New("address needs to start with https://")
}

log.Info("Getting verified measurements from " + addr + " ...")

// Prepare aTLS stuff
serverAttestationType := proxy.AttestationAzureTDX
issuer, err := proxy.CreateAttestationIssuer(log, serverAttestationType)
Expand All @@ -87,16 +109,18 @@ func runClient(cCtx *cli.Context) error {
return err
}

// Open connection to the TDX server and verify the aTLS attestation
conn, err := tls.Dial("tcp", addr, tlsConfig)
tr := &http.Transport{
TLSClientConfig: tlsConfig,
}
client := &http.Client{Transport: tr}
resp, err := client.Get(addr)
if err != nil {
log.Error("Error in Dial", "err", err)
return err
}
defer conn.Close()
certs := resp.TLS.PeerCertificates

// Extract the aTLS variant and measurements from the TLS connection
certs := conn.ConnectionState().PeerCertificates
// certs := conn.ConnectionState().PeerCertificates
atlsVariant, extractedMeasurements, err := proxy.GetMeasurementsFromTLS(certs, []asn1.ObjectIdentifier{variant.AzureTDX{}.OID()})
if err != nil {
log.Error("Error in getMeasurementsFromTLS", "err", err)
Expand All @@ -114,8 +138,27 @@ func runClient(cCtx *cli.Context) error {
}

log.Info("Variant: " + atlsVariant.String())
// log.Info("Measurements", "measurements", string(marshaledPcrs))
log.Info(fmt.Sprintf("Measurements for %s with %d entries:", atlsVariant.String(), len(measurementsInHeaderFormat)))
fmt.Println(string(marshaledPcrs))
if outMeasurements != "" {
if err := os.WriteFile(outMeasurements, marshaledPcrs, 0644); err != nil {
return err
}
}

// Print the response body
msg, err := io.ReadAll(resp.Body)
if err != nil {
return err
}

log.Info(fmt.Sprintf("Response body with %d bytes:", len(msg)))
fmt.Println(string(msg))
if outResponse != "" {
if err := os.WriteFile(outResponse, msg, 0644); err != nil {
return err
}
}

return nil
}

0 comments on commit e982121

Please sign in to comment.