Commit d7c4271

Address second iteration of comments.

tneymanov committed Feb 10, 2020
1 parent 5dd0a66 · commit d7c4271

Showing 6 changed files with 34 additions and 19 deletions.
gcp_variant_transforms/options/variant_transform_options.py (2 additions & 1 deletion)

@@ -185,7 +185,8 @@ def add_arguments(self, parser):
     parser.add_argument(
         '--num_bigquery_write_shards',
         type=int, default=1,
-        help=('This flag is deprecated and may be removed in future releases.'))
+        help=('This flag is deprecated and will be removed in future '
+              'releases.'))
     parser.add_argument(
         '--null_numeric_value_replacement',
         type=int,
gcp_variant_transforms/pipeline_common.py (13 additions & 3 deletions)

@@ -73,7 +73,9 @@ def parse_args(argv, command_line_options):
   known_args, pipeline_args = parser.parse_known_args(argv)
   for transform_options in options:
     transform_options.validate(known_args)
-  _raise_error_on_invalid_flags(pipeline_args)
+  _raise_error_on_invalid_flags(
+      pipeline_args,
+      known_args.output_table if hasattr(known_args, 'output_table') else None)
   if hasattr(known_args, 'input_pattern') or hasattr(known_args, 'input_file'):
     known_args.all_patterns = _get_all_patterns(
         known_args.input_pattern, known_args.input_file)

@@ -304,8 +306,8 @@ def write_headers(merged_header, file_path):
        vcf_header_io.WriteVcfHeaders(file_path))


-def _raise_error_on_invalid_flags(pipeline_args):
-  # type: (List[str]) -> None
+def _raise_error_on_invalid_flags(pipeline_args, output_table):
+  # type: (List[str], Any) -> None
   """Raises an error if there are unrecognized flags."""
   parser = argparse.ArgumentParser()
   for cls in pipeline_options.PipelineOptions.__subclasses__():

@@ -318,6 +320,14 @@ def _raise_error_on_invalid_flags(pipeline_args):
       not known_pipeline_args.setup_file):
     raise ValueError('The --setup_file flag is required for DataflowRunner. '
                      'Please provide a path to the setup.py file.')
+  if output_table:
+    if (not hasattr(known_pipeline_args, 'temp_location') or
+        not known_pipeline_args.temp_location):
+      raise ValueError('--temp_location is required for BigQuery imports.')
+    if not known_pipeline_args.temp_location.startswith('gs://'):
+      raise ValueError(
+          '--temp_location must be valid GCS location for BigQuery imports')
+

 def is_pipeline_direct_runner(pipeline):
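
For readers skimming the diff, here is a minimal standalone sketch of the new validation path. The function name check_bigquery_flags and the flag values are hypothetical; the committed implementation is _raise_error_on_invalid_flags above, which additionally walks every PipelineOptions subclass and rejects unrecognized flags.

    import argparse

    def check_bigquery_flags(pipeline_args, output_table):
      # Parse only the flag this sketch cares about.
      parser = argparse.ArgumentParser()
      parser.add_argument('--temp_location', default=None)
      known_args, _ = parser.parse_known_args(pipeline_args)
      if output_table:
        if not known_args.temp_location:
          raise ValueError('--temp_location is required for BigQuery imports.')
        if not known_args.temp_location.startswith('gs://'):
          raise ValueError(
              '--temp_location must be valid GCS location for BigQuery imports')

    check_bigquery_flags(['--temp_location', 'gs://bucket/temp'], 'out_table')  # passes
    # check_bigquery_flags(['--temp_location', '/local/tmp'], 'out_table')  # raises ValueError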
gcp_variant_transforms/pipeline_common_test.py (14 additions & 4 deletions)

@@ -94,21 +94,31 @@ def test_fail_on_invalid_flags(self):
         'gcp-variant-transforms-test',
         '--staging_location',
         'gs://integration_test_runs/staging']
-    pipeline_common._raise_error_on_invalid_flags(pipeline_args)
+    pipeline_common._raise_error_on_invalid_flags(pipeline_args, None)

     # Add Dataflow runner (requires --setup_file).
     pipeline_args.extend(['--runner', 'DataflowRunner'])
     with self.assertRaisesRegexp(ValueError, 'setup_file'):
-      pipeline_common._raise_error_on_invalid_flags(pipeline_args)
+      pipeline_common._raise_error_on_invalid_flags(pipeline_args, None)

     # Add setup.py (required for Variant Transforms run). This is now valid.
     pipeline_args.extend(['--setup_file', 'setup.py'])
-    pipeline_common._raise_error_on_invalid_flags(pipeline_args)
+    pipeline_common._raise_error_on_invalid_flags(pipeline_args, None)
+
+    with self.assertRaisesRegexp(ValueError, '--temp_location is required*'):
+      pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')
+
+    pipeline_args.extend(['--temp_location', 'wrong_gcs'])
+    with self.assertRaisesRegexp(ValueError, '--temp_location must be valid*'):
+      pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')
+
+    pipeline_args = pipeline_args[:-1] + ['gs://valid_bucket/temp']
+    pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')

     # Add an unknown flag.
     pipeline_args.extend(['--unknown_flag', 'somevalue'])
     with self.assertRaisesRegexp(ValueError, 'Unrecognized.*unknown_flag'):
-      pipeline_common._raise_error_on_invalid_flags(pipeline_args)
+      pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')

   def test_get_compression_type(self):
     vcf_metadata_list = [filesystem.FileMetadata(path, size) for
gcp_variant_transforms/transforms/sample_info_to_bigquery.py (1 addition & 3 deletions)

@@ -69,7 +69,6 @@ def __init__(self, output_table_prefix, sample_name_encoding, append=False):
     self._append = append
     self._sample_name_encoding = sample_name_encoding
     self._schema = sample_info_table_schema_generator.generate_schema()
-    self._temp_location = temp_location

   def expand(self, pcoll):
     return (pcoll

@@ -84,5 +83,4 @@ def expand(self, pcoll):
                     beam.io.BigQueryDisposition.WRITE_APPEND
                     if self._append
                     else beam.io.BigQueryDisposition.WRITE_TRUNCATE),
-                method=beam.io.WriteToBigQuery.Method.FILE_LOADS,
-                custom_gcs_temp_location=self._temp_location))
+                method=beam.io.WriteToBigQuery.Method.FILE_LOADS))
gcp_variant_transforms/transforms/variant_to_bigquery.py (1 addition & 2 deletions)

@@ -119,5 +119,4 @@ def expand(self, pcoll):
                     beam.io.BigQueryDisposition.WRITE_APPEND
                     if self._append
                     else beam.io.BigQueryDisposition.WRITE_TRUNCATE),
-                method=beam.io.WriteToBigQuery.Method.FILE_LOADS,
-                custom_gcs_temp_location=self._temp_location))
+                method=beam.io.WriteToBigQuery.Method.FILE_LOADS))
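
Dropping custom_gcs_temp_location from both transforms works because Beam's WriteToBigQuery with Method.FILE_LOADS falls back to the pipeline-level --temp_location when no custom GCS temp location is supplied — which is why the commit now requires --temp_location to be a gs:// path whenever an output table is set. A minimal sketch of that fallback (project, dataset, bucket, and schema here are hypothetical, and actually running it needs GCP credentials):

    import apache_beam as beam
    from apache_beam.options.pipeline_options import PipelineOptions

    # FILE_LOADS stages its load files under the pipeline's temp_location
    # when custom_gcs_temp_location is not passed, so it must be a GCS path.
    options = PipelineOptions(temp_location='gs://my-bucket/temp')  # hypothetical bucket
    with beam.Pipeline(options=options) as p:
      _ = (p
           | beam.Create([{'reference_name': 'chr1', 'start_position': 100}])
           | beam.io.WriteToBigQuery(
               'my-project:my_dataset.my_table',  # hypothetical table
               schema='reference_name:STRING,start_position:INTEGER',
               write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE,
               method=beam.io.WriteToBigQuery.Method.FILE_LOADS))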
gcp_variant_transforms/vcf_to_bq.py (3 additions & 6 deletions)

@@ -390,7 +390,6 @@ def _run_annotation_pipeline(known_args, pipeline_args):
 def _create_sample_info_table(pipeline,  # type: beam.Pipeline
                               pipeline_mode,  # type: PipelineModes
                               known_args,  # type: argparse.Namespace,
-                              temp_directory,  # str
                              ):
   # type: (...) -> None
   headers = pipeline_common.read_headers(

@@ -410,8 +409,6 @@ def run(argv=None):
   logging.info('Command: %s', ' '.join(argv or sys.argv))
   known_args, pipeline_args = pipeline_common.parse_args(argv,
                                                          _COMMAND_LINE_OPTIONS)
-  if known_args.output_table and '--temp_location' not in pipeline_args:
-    raise ValueError('--temp_location is required for BigQuery imports.')
   if known_args.auto_flags_experiment:
     _get_input_dimensions(known_args, pipeline_args)

@@ -486,9 +483,9 @@ def run(argv=None):

   for i in range(num_shards):
     table_suffix = ''
-    if sharding and sharding.get_shard_name(i):
-      table_suffix = '_' + sharding.get_shard_name(i)
-    table_name = known_args.output_table + table_suffix
+    table_suffix = sharding.get_output_table_suffix(i)
+    table_name = sample_info_table_schema_generator.compose_table_name(
+        known_args.output_table, table_suffix)
     _ = (variants[i] | 'VariantToBigQuery' + table_suffix >>
          variant_to_bigquery.VariantToBigQuery(
              table_name,
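
get_output_table_suffix and compose_table_name live elsewhere in the repository and are not shown in this diff. As a rough sketch of the intent — one output table per shard, with the suffix joined to the base table name by a fixed separator — the '__' separator and shard names below are assumptions, not taken from this commit:

    # Hypothetical stand-ins for sharding.get_output_table_suffix() and
    # sample_info_table_schema_generator.compose_table_name(); separator
    # and shard names are assumed for illustration only.
    TABLE_SUFFIX_SEPARATOR = '__'

    def compose_table_name(base_name, suffix):
      return base_name + TABLE_SUFFIX_SEPARATOR + suffix

    shard_suffixes = ['chr1', 'chr2', 'residual']  # e.g. one shard per region
    for suffix in shard_suffixes:
      print(compose_table_name('my-project:my_dataset.variants', suffix))
    # -> my-project:my_dataset.variants__chr1, ..._variants__chr2, ...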
