Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add minimum set size threshold #36

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 83 additions & 2 deletions pkg/bucket/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"context"
"errors"
"fmt"
"io"
"math/rand/v2"
"net/url"
"strings"

"optable-pair-cli/pkg/io"

"cloud.google.com/go/storage"
"github.com/rs/zerolog"
"gocloud.dev/blob/gcsblob"
Expand All @@ -17,7 +18,9 @@ import (
"google.golang.org/api/option"
)

const CompletedFile = ".Completed"
const (
CompletedFile = ".Completed"
)

type (
// BucketReadWriter contains the storage client and read writers for the source and destination buckets.
Expand Down Expand Up @@ -286,3 +289,81 @@ func blobFromObjectName(objectName string) string {
func shortHex() string {
return fmt.Sprintf("%08x", rand.Int32())
}

func BucketObjectAboveMinimumIDCount(ctx context.Context, downscopedToken string, srcBucketURL string, minThreshold int) (bool, error) {
if downscopedToken == "" {
return false, ErrTokenRequired
}

client, err := storage.NewClient(
ctx,
option.WithTokenSource(
oauth2.StaticTokenSource(
&oauth2.Token{
AccessToken: downscopedToken,
},
),
),
)
if err != nil {
return false, fmt.Errorf("failed to create storage client: %w", err)
}

srcPrefixedBucket, err := bucketFromObjectURL(srcBucketURL)
if err != nil {
return false, fmt.Errorf("failed to parse source URL: %w", err)
}

rc, err := newObjectReadCloser(ctx, client, srcPrefixedBucket)
if err != nil {
return false, fmt.Errorf("failed to create read writers: %w", err)
}

defer func() {
for _, c := range rc {
_ = c.Close()
}
}()

rs := make([]io.Reader, len(rc))
for i, r := range rc {
rs[i] = r
}

return io.ReadAboveCount(io.MultiReader(rs...), minThreshold)
}

// newObjectReadCloser lists the objects specified by the srcPrefixedBucket and opens a reader for each object,
// except for the .Completed file.
func newObjectReadCloser(ctx context.Context, client *storage.Client, srcPrefixedBucket *prefixedBucket) ([]io.ReadCloser, error) {
logger := zerolog.Ctx(ctx) //
query := &storage.Query{Prefix: srcPrefixedBucket.prefix + "/"}

srcBucket := client.Bucket(srcPrefixedBucket.bucket)

it := srcBucket.Objects(ctx, query)
var rc []io.ReadCloser

for {
obj, err := it.Next()
if errors.Is(err, iterator.Done) {
break
} else if err != nil {
logger.Debug().Err(err).Msgf("failed to list objects from source bucket %s", srcPrefixedBucket.prefix)
return nil, err
}

if strings.HasSuffix(obj.Name, CompletedFile) || strings.HasSuffix(obj.Name, "/") || obj.Size == 0 {
continue
}

reader, err := srcBucket.Object(obj.Name).NewReader(ctx)
if err != nil {
return nil, err
}

rc = append(rc, reader)
}

return rc, nil
}
22 changes: 22 additions & 0 deletions pkg/cmd/cli/pair.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ import (
"github.com/rs/zerolog"
)

const PAIRIDMinimumThreshold = 1000

var ErrInputBelowThreshold = errors.New("input file does not contain enough IDs for a secure PAIR ID match")

type pairConfig struct {
downscopedToken string
threads int
Expand Down Expand Up @@ -73,6 +77,15 @@ func (c *pairConfig) hashEncryt(ctx context.Context, input string) error {
logger := zerolog.Ctx(ctx)
logger.Info().Msg("Step 1: Hash and encrypt the advertiser data.")

// check if the advertiser's PAIR IDs are above the threshold
ok, err := io.IsInputFileAboveCount(input, PAIRIDMinimumThreshold)
if err != nil {
return fmt.Errorf("io.IsInputFileAboveCount: %w", err)
}
if !ok {
return ErrInputBelowThreshold
}

fs, err := io.FileReaders(input)
if err != nil {
return fmt.Errorf("io.FileReaders: %w", err)
Expand Down Expand Up @@ -124,6 +137,15 @@ func (c *pairConfig) reEncrypt(ctx context.Context, publisherPAIRIDsPath string)
logger := zerolog.Ctx(ctx)
logger.Info().Msg("Step 2: Re-encrypt the publisher's hashed and encrypted PAIR IDs.")

// check if the publisher's PAIR IDs are above the threshold
ok, err := bucket.BucketObjectAboveMinimumIDCount(ctx, c.downscopedToken, c.pubTwicePath, PAIRIDMinimumThreshold)
if err != nil {
return fmt.Errorf("bucket.NewBucket: %w", err)
}
if !ok {
return ErrInputBelowThreshold
}

// defer statements are executed in Last In First Out order, so we will write the completed file last.
bucketCompleter, err := bucket.NewBucketCompleter(ctx, c.downscopedToken, c.pubTriplePath)
if err != nil {
Expand Down
21 changes: 15 additions & 6 deletions pkg/cmd/cli/participate.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,28 @@ func (c *ParticipateCmd) Run(cli *CliContext) error {
return err
}

fs, err := io.FileReaders(c.Input)
if err != nil {
return fmt.Errorf("io.FileReaders: %w", err)
}
in := io.MultiReader(fs...)

// Allow testing with local files.
if !io.IsGCSBucketURL(c.Output) {
out, err := io.FileWriter(c.Output)
if err != nil {
return fmt.Errorf("io.FileWriter: %w", err)
}

// check if the advertiser's PAIR IDs are above the threshold
ok, err := io.IsInputFileAboveCount(c.Input, PAIRIDMinimumThreshold)
if err != nil {
return fmt.Errorf("io.IsInputFileAboveCount: %w", err)
}
if !ok {
return ErrInputBelowThreshold
}

fs, err := io.FileReaders(c.Input)
if err != nil {
return fmt.Errorf("io.FileReaders: %w", err)
}
in := io.MultiReader(fs...)

rw, err := pair.NewPAIRIDReadWriter(in, out)
if err != nil {
return fmt.Errorf("NewPAIRIDReadWriter: %w", err)
Expand Down
8 changes: 8 additions & 0 deletions pkg/cmd/cli/re-encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ func (c *ReEncryptCmd) Run(cli *CliContext) error {

// Allow testing with local files.
if !io.IsGCSBucketURL(c.Input) && !io.IsGCSBucketURL(c.Output) {
ok, err := io.IsInputFileAboveCount(c.Input, PAIRIDMinimumThreshold)
if err != nil {
return fmt.Errorf("io.IsInputFileAboveCount: %w", err)
}
if !ok {
return ErrInputBelowThreshold
}

in, err := io.FileReaders(c.Input)
if err != nil {
return fmt.Errorf("fileReaders: %w", err)
Expand Down
35 changes: 35 additions & 0 deletions pkg/io/io.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package io

import (
"encoding/csv"
"errors"
"fmt"
"io"
"net/url"
Expand Down Expand Up @@ -87,3 +89,36 @@ func IsGCSBucketURL(path string) bool {

return url.Scheme == gcsblob.Scheme
}

func IsInputFileAboveCount(path string, threshold int) (bool, error) {
fs, err := FileReaders(path)
if err != nil {
return false, fmt.Errorf("FileReaders: %w", err)
}

in := MultiReader(fs...)

return ReadAboveCount(in, threshold)
}

func ReadAboveCount(r io.Reader, threshold int) (bool, error) {
csvReader := csv.NewReader(r)

count := 0
for {
_, err := csvReader.Read()
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return false, err
}

count++

if count > threshold {
return true, nil
}
}

return false, nil
}