diff --git a/cmd/attested-get/main.go b/cmd/attested-get/main.go index 621d47a..4521587 100644 --- a/cmd/attested-get/main.go +++ b/cmd/attested-get/main.go @@ -14,10 +14,13 @@ package main // // go run cmd/attested-get/main.go --addr=https://instance_ip:port --out-measurements=measurements.json --out-response=response.txt // +// You can also compare the resulting measurements with a list of expected measurements: +// +// go run cmd/get-measurements/main.go --addr=https://instance_ip:port --expected-measurements=measurements.json +// import ( "encoding/asn1" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -33,6 +36,7 @@ import ( "github.com/flashbots/cvm-reverse-proxy/internal/attestation/measurements" "github.com/flashbots/cvm-reverse-proxy/internal/attestation/variant" "github.com/flashbots/cvm-reverse-proxy/internal/config" + "github.com/flashbots/cvm-reverse-proxy/multimeasurements" "github.com/flashbots/cvm-reverse-proxy/proxy" "github.com/urfave/cli/v2" // imports as package "cli" ) @@ -58,6 +62,11 @@ var flags []cli.Flag = []cli.Flag{ Value: string(proxy.AttestationAzureTDX), Usage: "type of attestation to present (currently only azure-tdx)", }, + &cli.StringFlag{ + Name: "expected-measurements", + Value: "", + Usage: "File or URL with known measurements (to compare against)", + }, &cli.BoolFlag{ Name: "log-debug", Value: false, @@ -84,6 +93,7 @@ func runClient(cCtx *cli.Context) (err error) { outMeasurements := cCtx.String("out-measurements") outResponse := cCtx.String("out-response") attestationTypeStr := cCtx.String("attestation-type") + expectedMeasurementsPath := cCtx.String("expected-measurements") // Setup logging log := common.SetupLogger(&common.LoggingOpts{ @@ -93,6 +103,7 @@ func runClient(cCtx *cli.Context) (err error) { Version: common.Version, }) + // Sanity-check addr if !strings.HasPrefix(addr, "https://") { return errors.New("address needs to start with https://") } @@ -117,6 +128,16 @@ func runClient(cCtx *cli.Context) (err error) { return errors.New("currently only azure-tdx attestation is supported") } + // Load expected measurements from file or URL (if provided) + var expectedMeasurements *multimeasurements.MultiMeasurements + if expectedMeasurementsPath != "" { + log.Info("Loading expected measurements from " + expectedMeasurementsPath + " ...") + expectedMeasurements, err = multimeasurements.New(expectedMeasurementsPath) + if err != nil { + return err + } + } + // Prepare aTLS stuff issuer, err := proxy.CreateAttestationIssuer(log, proxy.AttestationAzureTDX) if err != nil { @@ -150,17 +171,17 @@ func runClient(cCtx *cli.Context) (err error) { return err } - measurementsInHeaderFormat := make(map[uint32]string, len(extractedMeasurements)) - for pcr, value := range extractedMeasurements { - measurementsInHeaderFormat[pcr] = hex.EncodeToString(value) + printableMeasurements := make(map[uint32]string) + for k, v := range extractedMeasurements { + printableMeasurements[k] = fmt.Sprintf("%x", v) } - marshaledPcrs, err := json.MarshalIndent(measurementsInHeaderFormat, "", " ") + marshaledPcrs, err := json.MarshalIndent(printableMeasurements, "", " ") if err != nil { return errors.New("could not marshal measurement extracted from tls extension") } - log.Info(fmt.Sprintf("Measurements for %s with %d entries:", atlsVariant.String(), len(measurementsInHeaderFormat))) + log.Info(fmt.Sprintf("Measurements for %s with %d entries:", atlsVariant.String(), len(extractedMeasurements))) fmt.Println(string(marshaledPcrs)) if outMeasurements != "" { if err := os.WriteFile(outMeasurements, marshaledPcrs, 0o644); err != nil { @@ -168,6 +189,16 @@ func runClient(cCtx *cli.Context) (err error) { } } + // Compare against expected measurements + if expectedMeasurements != nil { + found, foundMeasurement := expectedMeasurements.Contains(extractedMeasurements) + if found { + log.With("matchedMeasurements", foundMeasurement.MeasurementID).Info("Measurements match expected measurements ✅") + } else { + log.Error("Measurements do not match expected measurements! ❌") + } + } + // Print the response body msg, err := io.ReadAll(resp.Body) if err != nil { diff --git a/measurements.json b/measurements.json index f2f93de..bc31e54 100644 --- a/measurements.json +++ b/measurements.json @@ -1,42 +1,50 @@ -{ - "azure-tdx-example": { - "11": { - "expected": "efa43e0beff151b0f251c4abf48152382b1452b4414dbd737b4127de05ca31f7" - }, - "12": { - "expected": "0000000000000000000000000000000000000000000000000000000000000000" - }, - "13": { - "expected": "0000000000000000000000000000000000000000000000000000000000000000" - }, - "15": { - "expected": "0000000000000000000000000000000000000000000000000000000000000000" - }, - "4": { - "expected": "ea92ff762767eae6316794f1641c485d4846bc2b9df2eab6ba7f630ce6f4d66f" - }, - "8": { - "expected": "0000000000000000000000000000000000000000000000000000000000000000" - }, - "9": { - "expected": "c9f429296634072d1063a03fb287bed0b2d177b0a504755ad9194cffd90b2489" - } - }, - "dcap-tdx-example": { - "0": { - "expected": "5d56080eb9ef8ce0bbaf6bdcdadeeb06e7c5b0a4d1ec16be868a85a953babe0c5e54d01c8e050a54fe1ca078372530d2" - }, - "1": { - "expected": "4216e925f796f4e282cfa6e72d4c77a80560987afa29155a61fdc33adb80eab0d4112abd52387e5e25a60deefb8a5287" - }, - "2": { - "expected": "4274fefb79092c164000b571b64ecb432fa2357adb421fd1c77a867168d7d7f7fe82796d1eba092c7bab35cf43f5ec55" - }, - "3": { - "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" - }, - "4": { - "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" - } - } -} +[ + { + "measurement_id": "azure-tdx-example-01", + "attestation_type": "azure-tdx", + "measurements": { + "4": { + "expected": "ea92ff762767eae6316794f1641c485d4846bc2b9df2eab6ba7f630ce6f4d66f" + }, + "9": { + "expected": "c9f429296634072d1063a03fb287bed0b2d177b0a504755ad9194cffd90b2489" + }, + "11": { + "expected": "efa43e0beff151b0f251c4abf48152382b1452b4414dbd737b4127de05ca31f7" + } + } + }, + { + "measurement_id": "cvm-image-azure-tdx.rootfs-20241107200854.wic.vhd", + "attestation_type": "azure-tdx", + "measurements": { + "4": { + "expected": "1b8cd655f5ebdf50bedabfb5db6b896a0a7c56de54f318103a2de1e7cea57b6b" + }, + "9": { + "expected": "992465f922102234c196f596fdaba86ea16eaa4c264dc425ec26bc2d1c364472" + } + } + }, + { + "measurement_id": "dcap-tdx-example-02", + "attestation_type": "dcap-tdx", + "measurements": { + "0": { + "expected": "5d56080eb9ef8ce0bbaf6bdcdadeeb06e7c5b0a4d1ec16be868a85a953babe0c5e54d01c8e050a54fe1ca078372530d2" + }, + "1": { + "expected": "4216e925f796f4e282cfa6e72d4c77a80560987afa29155a61fdc33adb80eab0d4112abd52387e5e25a60deefb8a5287" + }, + "2": { + "expected": "4274fefb79092c164000b571b64ecb432fa2357adb421fd1c77a867168d7d7f7fe82796d1eba092c7bab35cf43f5ec55" + }, + "3": { + "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + }, + "4": { + "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + } + } + } +] \ No newline at end of file diff --git a/multimeasurements/multimeasurements.go b/multimeasurements/multimeasurements.go new file mode 100644 index 0000000..ad7f348 --- /dev/null +++ b/multimeasurements/multimeasurements.go @@ -0,0 +1,94 @@ +// Package multimeasurements contains a helper to load a file with multiple measurements +// and compare provided measurements against them. +// +// Compatible with measurements data schema v2 (see measurements.json) as well as the +// legacy v1 schema. +package multimeasurements + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "os" + "strings" + + "github.com/flashbots/cvm-reverse-proxy/internal/attestation/measurements" +) + +// MultiMeasurements holds several known measurements, and can check if +// given measurements match known ones. +type MultiMeasurements struct { + Measurements []MeasurementsContainer +} + +type MeasurementsContainer struct { + MeasurementID string `json:"measurement_id"` + AttestationType string `json:"attestation_type"` + Measurements measurements.M `json:"measurements"` +} + +type LegacyMultiMeasurements map[string]measurements.M + +// New returns a MultiMeasurements instance, with the measurements +// loaded from a file or URL. +func New(path string) (m *MultiMeasurements, err error) { + var data []byte + if strings.HasPrefix(path, "http") { + // load from URL + resp, err := http.Get(path) + if err != nil { + return nil, err + } + defer resp.Body.Close() + data, err = io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + } else { + // load from file + data, err = os.ReadFile(path) + if err != nil { + return nil, err + } + } + + m = &MultiMeasurements{} + + // Try to load the v2 data schema, if that fails fall back to legacy v1 schema + if err = json.Unmarshal(data, &m.Measurements); err != nil { + var legacyData LegacyMultiMeasurements + err = json.Unmarshal(data, &legacyData) + for measurementID, measurements := range legacyData { + container := MeasurementsContainer{ + MeasurementID: measurementID, + AttestationType: "azure-tdx", + Measurements: measurements, + } + m.Measurements = append(m.Measurements, container) + } + } + + return m, err +} + +// Contains checks if the provided measurements match one of the known measurements. Any keys in the provided +// measurements which are not in the known measurements are ignored. +func (m *MultiMeasurements) Contains(measurements map[uint32][]byte) (found bool, foundMeasurement *MeasurementsContainer) { + // For every known container, all known measurements match (and additional ones are ignored) + for _, container := range m.Measurements { + allMatch := true + for key, value := range container.Measurements { + if !bytes.Equal(value.Expected, measurements[key]) { + allMatch = false + break + } + } + + if allMatch { + return true, &container + } + } + + return false, nil +} diff --git a/multimeasurements/multimeasurements_test.go b/multimeasurements/multimeasurements_test.go new file mode 100644 index 0000000..8cc6a50 --- /dev/null +++ b/multimeasurements/multimeasurements_test.go @@ -0,0 +1,94 @@ +package multimeasurements + +import ( + "encoding/hex" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestMeasurements is kept simple: map[pcr]measurement +type TestMeasurements map[uint32][]byte + +func mustBytesFromHex(hexValue string) []byte { + bytes, err := hex.DecodeString(hexValue) + if err != nil { + panic(err) + } + return bytes +} + +// Measurements V1 (legacy) JSON (from https://github.com/flashbots/cvm-reverse-proxy/blob/837588b9f87ee49d1bb6dca4712a1c2844eb1ecc/measurements.json) +var measurementsV1JSON = []byte(`{"azure-tdx-example":{"11":{"expected":"efa43e0beff151b0f251c4abf48152382b1452b4414dbd737b4127de05ca31f7"},"12":{"expected":"0000000000000000000000000000000000000000000000000000000000000000"},"13":{"expected":"0000000000000000000000000000000000000000000000000000000000000000"},"15":{"expected":"0000000000000000000000000000000000000000000000000000000000000000"},"4":{"expected":"ea92ff762767eae6316794f1641c485d4846bc2b9df2eab6ba7f630ce6f4d66f"},"8":{"expected":"0000000000000000000000000000000000000000000000000000000000000000"},"9":{"expected":"c9f429296634072d1063a03fb287bed0b2d177b0a504755ad9194cffd90b2489"}},"dcap-tdx-example":{"0":{"expected":"5d56080eb9ef8ce0bbaf6bdcdadeeb06e7c5b0a4d1ec16be868a85a953babe0c5e54d01c8e050a54fe1ca078372530d2"},"1":{"expected":"4216e925f796f4e282cfa6e72d4c77a80560987afa29155a61fdc33adb80eab0d4112abd52387e5e25a60deefb8a5287"},"2":{"expected":"4274fefb79092c164000b571b64ecb432fa2357adb421fd1c77a867168d7d7f7fe82796d1eba092c7bab35cf43f5ec55"},"3":{"expected":"000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},"4":{"expected":"000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"}}}`) + +// TestMultiMeasurementsV2 tests the v2 data schema +func TestMultiMeasurementsV2(t *testing.T) { + // Load expected measurements from JSON file (in V2 format) + m, err := New("../measurements.json") + require.NoError(t, err) + require.Len(t, m.Measurements, 3) + + // Setup test measurements (matching cvm-image-azure-tdx.rootfs-20241107200854.wic.vhd) + testMeasurements := TestMeasurements{ + 4: mustBytesFromHex("1b8cd655f5ebdf50bedabfb5db6b896a0a7c56de54f318103a2de1e7cea57b6b"), + 9: mustBytesFromHex("992465f922102234c196f596fdaba86ea16eaa4c264dc425ec26bc2d1c364472"), + } + + // Ensure matching entries works, and that additional fields are ignored + testMeasurements[11] = testMeasurements[4] + exists, foundMeasurement := m.Contains(testMeasurements) + require.True(t, exists) + require.Equal(t, "cvm-image-azure-tdx.rootfs-20241107200854.wic.vhd", foundMeasurement.MeasurementID) + require.Equal(t, "azure-tdx", foundMeasurement.AttestationType) + + // Ensure check fails with a missing required key + delete(testMeasurements, 4) + exists, _ = m.Contains(testMeasurements) + require.False(t, exists) + + // Double-check it works again + testMeasurements[4] = testMeasurements[11] + exists, _ = m.Contains(testMeasurements) + require.True(t, exists) + + // Any changed value should make it fail + testMeasurements[4] = testMeasurements[9] + exists, _ = m.Contains(testMeasurements) + require.False(t, exists) + + // Check for another set of known measurements (dcap-tdx-example) + testMeasurements = TestMeasurements{ + 0: mustBytesFromHex("5d56080eb9ef8ce0bbaf6bdcdadeeb06e7c5b0a4d1ec16be868a85a953babe0c5e54d01c8e050a54fe1ca078372530d2"), + 1: mustBytesFromHex("4216e925f796f4e282cfa6e72d4c77a80560987afa29155a61fdc33adb80eab0d4112abd52387e5e25a60deefb8a5287"), + 2: mustBytesFromHex("4274fefb79092c164000b571b64ecb432fa2357adb421fd1c77a867168d7d7f7fe82796d1eba092c7bab35cf43f5ec55"), + 3: mustBytesFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), + 4: mustBytesFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), + } + exists, foundMeasurement = m.Contains(testMeasurements) + require.True(t, exists) + require.Equal(t, "dcap-tdx-example-02", foundMeasurement.MeasurementID) +} + +func TestMultiMeasurementsV1(t *testing.T) { + tempDir := t.TempDir() + err := os.WriteFile(filepath.Join(tempDir, "measurements.json"), measurementsV1JSON, 0644) + require.NoError(t, err) + + // Load expected measurements from JSON file + m, err := New(filepath.Join(tempDir, "measurements.json")) + require.NoError(t, err) + require.Len(t, m.Measurements, 2) + + testMeasurements := TestMeasurements{ + 0: mustBytesFromHex("5d56080eb9ef8ce0bbaf6bdcdadeeb06e7c5b0a4d1ec16be868a85a953babe0c5e54d01c8e050a54fe1ca078372530d2"), + 1: mustBytesFromHex("4216e925f796f4e282cfa6e72d4c77a80560987afa29155a61fdc33adb80eab0d4112abd52387e5e25a60deefb8a5287"), + 2: mustBytesFromHex("4274fefb79092c164000b571b64ecb432fa2357adb421fd1c77a867168d7d7f7fe82796d1eba092c7bab35cf43f5ec55"), + 3: mustBytesFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), + 4: mustBytesFromHex("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), + } + exists, foundMeasurement := m.Contains(testMeasurements) + require.True(t, exists) + require.Equal(t, "dcap-tdx-example", foundMeasurement.MeasurementID) +} diff --git a/proxy-server.dockerfile b/proxy-server.dockerfile index 771a699..a2d93fc 100644 --- a/proxy-server.dockerfile +++ b/proxy-server.dockerfile @@ -18,7 +18,7 @@ FROM alpine:latest WORKDIR /app COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ COPY --from=builder /build/proxy-server /app/proxy-server -RUN echo "{\"azure-tdx\": {}}" > /app/measurements-empty.json +RUN echo "[{}]" > /app/measurements-empty.json ENV LISTEN_ADDR=":8080" EXPOSE 8080 CMD ["/app/proxy-server"] diff --git a/proxy/atls_config.go b/proxy/atls_config.go index 2f73613..eee5fb2 100644 --- a/proxy/atls_config.go +++ b/proxy/atls_config.go @@ -16,6 +16,7 @@ import ( "github.com/flashbots/cvm-reverse-proxy/internal/attestation/variant" "github.com/flashbots/cvm-reverse-proxy/internal/cloud/cloudprovider" "github.com/flashbots/cvm-reverse-proxy/internal/config" + "github.com/flashbots/cvm-reverse-proxy/multimeasurements" dcap_tdx "github.com/flashbots/cvm-reverse-proxy/tdx" ) @@ -65,7 +66,7 @@ func CreateAttestationValidators(log *slog.Logger, attestationType AttestationTy return nil, err } - parsedMeasurements := make(map[string]measurements.M) + var parsedMeasurements []multimeasurements.MeasurementsContainer err = json.Unmarshal(jsonMeasurements, &parsedMeasurements) if err != nil { return nil, err @@ -76,7 +77,7 @@ func CreateAttestationValidators(log *slog.Logger, attestationType AttestationTy validators := []atls.Validator{} for _, measurement := range parsedMeasurements { attConfig := config.DefaultForAzureTDX() - attConfig.SetMeasurements(measurement) + attConfig.SetMeasurements(measurement.Measurements) validators = append(validators, azure_tdx.NewValidator(attConfig, AttestationLogger{Log: log})) } return []atls.Validator{NewMultiValidator(validators)}, nil @@ -84,7 +85,7 @@ func CreateAttestationValidators(log *slog.Logger, attestationType AttestationTy validators := []atls.Validator{} for _, measurement := range parsedMeasurements { attConfig := &config.QEMUTDX{Measurements: measurements.DefaultsFor(cloudprovider.QEMU, variant.QEMUTDX{})} - attConfig.SetMeasurements(measurement) + attConfig.SetMeasurements(measurement.Measurements) validators = append(validators, dcap_tdx.NewValidator(attConfig, AttestationLogger{Log: log})) } return []atls.Validator{NewMultiValidator(validators)}, nil