diff --git a/go.mod b/go.mod index 62a871a..478330a 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/google/uuid v1.6.0 github.com/gtank/ristretto255 v0.1.2 github.com/optable/match v1.4.0 - github.com/optable/match-api/v2 v2.6.0 + github.com/optable/match-api/v2 v2.7.0 github.com/rs/zerolog v1.33.0 gocloud.dev v0.39.0 golang.org/x/oauth2 v0.22.0 diff --git a/go.sum b/go.sum index 18cbc3b..d5db78d 100644 --- a/go.sum +++ b/go.sum @@ -138,8 +138,8 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/optable/match v1.4.0 h1:kyj1ty6qFIRVFsB6zTJab0RF3Duq9xqPIdld7+4IDa4= github.com/optable/match v1.4.0/go.mod h1:l8DT0v6TfmIT53vBbEAp+W0EFAxJ22NIEeJDz0z3WDM= -github.com/optable/match-api/v2 v2.6.0 h1:MHZD5JWjwu7evHong9m3deyHi2hDzk77odMNtG33xbk= -github.com/optable/match-api/v2 v2.6.0/go.mod h1:b4eo6B06BE4goiWwhJ3bNl1BTuMF6hIZdGEhbRgdEkI= +github.com/optable/match-api/v2 v2.7.0 h1:fn4Qhrg9CoapikvrfpXhphoe03HipPnwju47c/89UpM= +github.com/optable/match-api/v2 v2.7.0/go.mod h1:b4eo6B06BE4goiWwhJ3bNl1BTuMF6hIZdGEhbRgdEkI= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/pkg/bucket/bucket.go b/pkg/bucket/bucket.go index 95b2fd0..9575df7 100644 --- a/pkg/bucket/bucket.go +++ b/pkg/bucket/bucket.go @@ -118,6 +118,16 @@ func (b *BucketCompleter) Complete(ctx context.Context) error { return b.client.Close() } +// Checks if the .Completed file exists in the destination bucket. +func (b *BucketCompleter) HasCompleted(ctx context.Context) (bool, error) { + dstBucket := b.client.Bucket(b.dstPrefixedBucket.bucket) + _, err := dstBucket.Object(fmt.Sprintf("%s/%s", b.dstPrefixedBucket.prefix, CompletedFile)).Attrs(ctx) + if errors.Is(err, storage.ErrObjectNotExist) { + return false, nil + } + return err == nil, err +} + // NewBucketReadWriter creates a new Bucket object and opens readers and writers for the specified source and destination URLs. // Caller needs to call Close() on the returned Bucket object to release resources. func NewBucketReadWriter(ctx context.Context, downscopedToken string, dstURL string, opts ...BucketOption) (*BucketReadWriter, error) { diff --git a/pkg/cmd/cli/base.go b/pkg/cmd/cli/base.go index 04a401b..4fadfda 100644 --- a/pkg/cmd/cli/base.go +++ b/pkg/cmd/cli/base.go @@ -14,12 +14,9 @@ type CliContext struct { type ( CleanroomCmd struct { - Get GetCmd `cmd:"" help:"Get the current status and configuration associated with the specified Optable PAIR clean room."` - Participate ParticipateCmd `cmd:"" hidden:"" help:"Participate in the PAIR operation by contributing advertiser hashed and encrypted data."` - ReEncrypt ReEncryptCmd `cmd:"" hidden:"" help:"Re-encrypt publisher's PAIR IDs with the advertiser key."` - Match MatchCmd `cmd:"" hidden:"" help:"Match publisher's PAIR IDs with advertiser's PAIR IDs."` - Run RunCmd `cmd:"" help:"As the advertiser clean room, run the PAIR match protocol with the publisher that has invited you to the specified Optable PAIR clean room."` - Decrypt DecryptCmd `cmd:"" help:"Decrypt a list of previously matched triple encrypted PAIR IDs using the advertiser clean room's private key."` + Get GetCmd `cmd:"" help:"Get the current status and configuration associated with the specified Optable PAIR clean room."` + Run RunCmd `cmd:"" help:"As the advertiser clean room, run the PAIR match protocol with the publisher that has invited you to the specified Optable PAIR clean room."` + Decrypt DecryptCmd `cmd:"" help:"Decrypt a list of previously matched triple encrypted PAIR IDs using the advertiser clean room's private key."` } KeyCmd struct { diff --git a/pkg/cmd/cli/match.go b/pkg/cmd/cli/match.go deleted file mode 100644 index f644950..0000000 --- a/pkg/cmd/cli/match.go +++ /dev/null @@ -1,79 +0,0 @@ -package cli - -import ( - "fmt" - "os" - - "optable-pair-cli/pkg/io" - "optable-pair-cli/pkg/pair" -) - -type ( - MatchCmd struct { - PairCleanroomToken string `arg:"" help:"The PAIR clean room token to use for the operation."` - AdvertiserInput string `cmd:"" short:"a" help:"If given a file path, it will read from the file. If not provided, it will read from the GCS path specified from the token."` - PublisherInput string `cmd:"" short:"p" help:"If given a file path, it will read from the file. If not provided, it will read from the GCS path specified from the token."` - OutputDir string `cmd:"" short:"o" help:"The output directory path to write the decrypted and matched double encrypted PAIR IDs. Each thread will write one single file in the given directory path. If none are provided, all matched and decrypted PAIR IDs will be written to stdout."` - NumThreads int `cmd:"" short:"n" help:"The number of threads to use for the operation. Defaults to the number of the available cores on the machine."` - PublisherPAIRIDs string `cmd:"" short:"s" name:"publisher-pair-ids" help:"Use the publisher's PAIR IDs from a path."` - } -) - -func (c *MatchCmd) Help() string { - return ` -This operation produces the match rate of this PAIR clean room operation, -and output the list of decrypted and matched PAIR IDs. -` -} - -func (c *MatchCmd) Run(cli *CliContext) error { - ctx := cli.Context() - - advertiserKey, err := ReadKeyConfig(cli.keyContext, cli.config) - if err != nil { - return fmt.Errorf("ReadKeyConfig: %w", err) - } - if c.NumThreads <= 0 { - c.NumThreads = defaultThreadCount - } - pairCfg, err := NewPAIRConfig(ctx, c.PairCleanroomToken, c.NumThreads, advertiserKey) - if err != nil { - return err - } - - // Allow testing with local files. - if c.AdvertiserInput != "" && c.PublisherInput != "" && !io.IsGCSBucketURL(c.AdvertiserInput) && !io.IsGCSBucketURL(c.PublisherInput) { - adv, err := io.FileReaders(c.AdvertiserInput) - if err != nil { - return fmt.Errorf("fileReaders: %w", err) - } - - pub, err := io.FileReaders(c.PublisherInput) - if err != nil { - return fmt.Errorf("fileWriters: %w", err) - } - - if c.OutputDir != "" { - if err := os.MkdirAll(c.OutputDir, os.ModePerm); err != nil { - return fmt.Errorf("os.MkdirAll: %w", err) - } - } - - matcher, err := pair.NewMatcher(adv, pub, c.OutputDir) - if err != nil { - return fmt.Errorf("pair.NewMatcher: %w", err) - } - - return matcher.Match(ctx, c.NumThreads, pairCfg.salt, pairCfg.key) - } - - return pairCfg.match(ctx, c.OutputDir, c.PublisherPAIRIDs) -} - -func readersFromReadClosers(rs []io.ReadCloser) []io.Reader { - readers := make([]io.Reader, len(rs)) - for i, r := range rs { - readers[i] = r - } - return readers -} diff --git a/pkg/cmd/cli/pair.go b/pkg/cmd/cli/pair.go index 65d8144..49bf29f 100644 --- a/pkg/cmd/cli/pair.go +++ b/pkg/cmd/cli/pair.go @@ -84,6 +84,15 @@ func (c *pairConfig) hashEncryt(ctx context.Context, input string) (err error) { if err != nil { return fmt.Errorf("bucket.NewBucketCompleter: %w", err) } + hasCompleted, err := bucketCompleter.HasCompleted(ctx) + if err != nil { + return fmt.Errorf("bucketCompleter.HasCompleted: %w", err) + } + if hasCompleted { + // nothing to do if the advertiser data has pushed the data + return nil + } + defer func() { // don't complete the bucket if there was an error to prevent writing // unwanted files. @@ -95,6 +104,7 @@ func (c *pairConfig) hashEncryt(ctx context.Context, input string) (err error) { logger.Error().Err(err).Msg("failed to write .Completed file to bucket") return } + }() b, err := bucket.NewBucketReadWriter(ctx, c.downscopedToken, c.advTwicePath, bucket.WithReader(in)) @@ -141,6 +151,15 @@ func (c *pairConfig) reEncrypt(ctx context.Context, publisherPAIRIDsPath string) if err != nil { return fmt.Errorf("bucket.NewBucketCompleter: %w", err) } + + hasCompleted, err := bucketCompleter.HasCompleted(ctx) + if err != nil { + return fmt.Errorf("bucketCompleter.HasCompleted: %w", err) + } + if hasCompleted { + // nothing to do if the advertiser data has pushed the data + return nil + } defer func() { // don't complete the bucket if there was an error to prevent writing // unwanted files. @@ -251,3 +270,11 @@ func (c *pairConfig) match(ctx context.Context, outputPath string, publisherPAIR return nil } + +func readersFromReadClosers(rs []io.ReadCloser) []io.Reader { + readers := make([]io.Reader, len(rs)) + for i, r := range rs { + readers[i] = r + } + return readers +} diff --git a/pkg/cmd/cli/participate.go b/pkg/cmd/cli/participate.go deleted file mode 100644 index c75cfcf..0000000 --- a/pkg/cmd/cli/participate.go +++ /dev/null @@ -1,57 +0,0 @@ -package cli - -import ( - "fmt" - - "optable-pair-cli/pkg/io" - "optable-pair-cli/pkg/pair" -) - -type ( - ParticipateCmd struct { - PairCleanroomToken string `arg:"" help:"The PAIR clean room token to use for the operation."` - Input string `cmd:"" short:"i" help:"The input file containing the advertiser data to be hashed and encrypted. If given a directory, all files in the directory will be processed."` - Output string `cmd:"" short:"o" help:"The output file to write the advertiser data to, default to stdout."` - NumThreads int `cmd:"" short:"n" help:"The number of threads to use for the operation. Defaults to the number of the available cores on the machine."` - } -) - -func (c *ParticipateCmd) Run(cli *CliContext) error { - ctx := cli.Context() - - advertiserKey, err := ReadKeyConfig(cli.keyContext, cli.config) - if err != nil { - return fmt.Errorf("ReadKeyConfig: %w", err) - } - if c.NumThreads <= 0 { - c.NumThreads = defaultThreadCount - } - // instantiate pair config - pairCfg, err := NewPAIRConfig(ctx, c.PairCleanroomToken, c.NumThreads, advertiserKey) - if err != nil { - return err - } - - fs, err := io.FileReaders(c.Input) - if err != nil { - return fmt.Errorf("io.FileReaders: %w", err) - } - in := io.MultiReader(fs...) - - // Allow testing with local files. - if !io.IsGCSBucketURL(c.Output) { - out, err := io.FileWriter(c.Output) - if err != nil { - return fmt.Errorf("io.FileWriter: %w", err) - } - - rw, err := pair.NewPAIRIDReadWriter(in, out) - if err != nil { - return fmt.Errorf("NewPAIRIDReadWriter: %w", err) - } - - return rw.HashEncrypt(ctx, c.NumThreads, pairCfg.salt, pairCfg.key) - } - - return pairCfg.hashEncryt(ctx, c.Input) -} diff --git a/pkg/cmd/cli/re-encrypt.go b/pkg/cmd/cli/re-encrypt.go deleted file mode 100644 index 02052dc..0000000 --- a/pkg/cmd/cli/re-encrypt.go +++ /dev/null @@ -1,73 +0,0 @@ -package cli - -import ( - "fmt" - "os" - - "optable-pair-cli/pkg/io" - "optable-pair-cli/pkg/pair" -) - -type ( - ReEncryptCmd struct { - PairCleanroomToken string `arg:"" help:"The PAIR clean room token to use for the operation."` - Input string `cmd:"" short:"i" help:"The GCS bucket URL containing objects of publisher's encrypted PAIR IDs. If given a file path, it will read from the file instead. If not provided, it will read from stdin."` - Output string `cmd:"" short:"o" help:"The GCS bucket URL to write the re-encrypted publisher PAIR IDs to. If given a file path, it will write to the file instead. If not provided, it will write to stdout."` - NumThreads int `cmd:"" short:"n" help:"The number of threads to use for the operation. Defaults to the number of the available cores on the machine."` - PublisherPAIRIDs string `cmd:"" short:"s" name:"publisher-pair-ids" help:"Save the publisher's PAIR IDs in the provided directory, to be used later. If not provided, the publisher's PAIR IDs will not be saved."` - } -) - -func (c *ReEncryptCmd) Run(cli *CliContext) error { - ctx := cli.Context() - - advertiserKey, err := ReadKeyConfig(cli.keyContext, cli.config) - if err != nil { - return fmt.Errorf("ReadKeyConfig: %w", err) - } - if c.NumThreads <= 0 { - c.NumThreads = defaultThreadCount - } - // instantiate pair config - pairCfg, err := NewPAIRConfig(ctx, c.PairCleanroomToken, c.NumThreads, advertiserKey) - if err != nil { - return err - } - - // Allow testing with local files. - if !io.IsGCSBucketURL(c.Input) && !io.IsGCSBucketURL(c.Output) { - in, err := io.FileReaders(c.Input) - if err != nil { - return fmt.Errorf("fileReaders: %w", err) - } - - out, err := io.FileWriter(c.Output) - if err != nil { - return fmt.Errorf("fileWriters: %w", err) - } - - opts := []pair.ReadWriterOption{} - if c.PublisherPAIRIDs != "" { - // create the publisher data directory if it does not exist - if err := os.MkdirAll(c.PublisherPAIRIDs, os.ModePerm); err != nil { - return fmt.Errorf("io.CreateDirectory: %w", err) - } - - w, err := io.FileWriter(fmt.Sprintf("%s/pair_ids.csv", c.PublisherPAIRIDs)) - if err != nil { - return fmt.Errorf("fileWriters: %w", err) - } - - opts = append(opts, pair.WithSecondaryWriter(w)) - } - - rw, err := pair.NewPAIRIDReadWriter(io.MultiReader(in...), out, opts...) - if err != nil { - return fmt.Errorf("pair.NewPAIRIDReadWriter: %w", err) - } - - return rw.ReEncrypt(ctx, c.NumThreads, pairCfg.salt, pairCfg.key) - } - - return pairCfg.reEncrypt(ctx, c.PublisherPAIRIDs) -} diff --git a/pkg/cmd/cli/run.go b/pkg/cmd/cli/run.go index 0c99f14..b26da7b 100644 --- a/pkg/cmd/cli/run.go +++ b/pkg/cmd/cli/run.go @@ -163,11 +163,19 @@ func startFromStepOne(ctx context.Context, pairCfg *pairConfig, input, output st return fmt.Errorf("hashEncryt: %w", err) } + if _, err := pairCfg.cleanroomClient.AdvanceAdvertiserState(ctx); err != nil { + return fmt.Errorf("failed to advance advertiser state: %w", err) + } + // Step 2. Re-encrypt the publisher's hashed and encrypted PAIR IDs and output to pubTriplePath. if err := pairCfg.reEncrypt(ctx, publisherData); err != nil { return fmt.Errorf("reEncrypt: %w", err) } + if _, err := pairCfg.cleanroomClient.AdvanceAdvertiserState(ctx); err != nil { + return fmt.Errorf("failed to advance advertiser state: %w", err) + } + if output == "" { return nil } @@ -182,6 +190,10 @@ func startFromStepTwo(ctx context.Context, pairCfg *pairConfig, output string, p return fmt.Errorf("reEncrypt: %w", err) } + if _, err := pairCfg.cleanroomClient.AdvanceAdvertiserState(ctx); err != nil { + return fmt.Errorf("failed to advance advertiser state: %w", err) + } + if output == "" { return nil } diff --git a/pkg/internal/client.go b/pkg/internal/client.go index d3afa87..1cd41df 100644 --- a/pkg/internal/client.go +++ b/pkg/internal/client.go @@ -65,6 +65,14 @@ func (c *CleanroomClient) RefreshToken(ctx context.Context) (*v1.Cleanroom, erro return c.do(ctx, req) } +func (c *CleanroomClient) AdvanceAdvertiserState(ctx context.Context) (*v1.Cleanroom, error) { + req := &v1.AdvanceCleanroomAdvertiserStateRequest{ + Name: c.cleanroomName, + } + + return c.do(ctx, req) +} + func (c *CleanroomClient) GetDownScopedToken(ctx context.Context) (string, error) { cleanroom, err := c.GetCleanroom(ctx, true) if err != nil { @@ -159,6 +167,8 @@ func (c *CleanroomClient) do(ctx context.Context, req proto.Message) (*v1.Cleanr path = "/admin/api/external/v1/cleanroom/get" case *v1.RefreshTokenRequest: path = "/admin/api/external/v1/cleanroom/refresh-token" + case *v1.AdvanceCleanroomAdvertiserStateRequest: + path = "/admin/api/external/v1/cleanroom/advance-advertiser-state" default: return nil, fmt.Errorf("unknown request type") }