diff --git a/gcp_variant_transforms/options/variant_transform_options.py b/gcp_variant_transforms/options/variant_transform_options.py index 40b92dde6..086ab6904 100644 --- a/gcp_variant_transforms/options/variant_transform_options.py +++ b/gcp_variant_transforms/options/variant_transform_options.py @@ -185,7 +185,8 @@ def add_arguments(self, parser): parser.add_argument( '--num_bigquery_write_shards', type=int, default=1, - help=('This flag is deprecated and may be removed in future releases.')) + help=('This flag is deprecated and will be removed in future ' + 'releases.')) parser.add_argument( '--null_numeric_value_replacement', type=int, diff --git a/gcp_variant_transforms/pipeline_common.py b/gcp_variant_transforms/pipeline_common.py index 378738af6..fcb8308c2 100644 --- a/gcp_variant_transforms/pipeline_common.py +++ b/gcp_variant_transforms/pipeline_common.py @@ -73,7 +73,9 @@ def parse_args(argv, command_line_options): known_args, pipeline_args = parser.parse_known_args(argv) for transform_options in options: transform_options.validate(known_args) - _raise_error_on_invalid_flags(pipeline_args) + _raise_error_on_invalid_flags( + pipeline_args, + known_args.output_table if hasattr(known_args, 'output_table') else None) if hasattr(known_args, 'input_pattern') or hasattr(known_args, 'input_file'): known_args.all_patterns = _get_all_patterns( known_args.input_pattern, known_args.input_file) @@ -304,8 +306,8 @@ def write_headers(merged_header, file_path): vcf_header_io.WriteVcfHeaders(file_path)) -def _raise_error_on_invalid_flags(pipeline_args): - # type: (List[str]) -> None +def _raise_error_on_invalid_flags(pipeline_args, output_table): + # type: (List[str], Any) -> None """Raises an error if there are unrecognized flags.""" parser = argparse.ArgumentParser() for cls in pipeline_options.PipelineOptions.__subclasses__(): @@ -318,6 +320,14 @@ def _raise_error_on_invalid_flags(pipeline_args): not known_pipeline_args.setup_file): raise ValueError('The --setup_file flag is required for DataflowRunner. ' 'Please provide a path to the setup.py file.') + if output_table: + if (not hasattr(known_pipeline_args, 'temp_location') or + not known_pipeline_args.temp_location): + raise ValueError('--temp_location is required for BigQuery imports.') + if not known_pipeline_args.temp_location.startswith('gs://'): + raise ValueError( + '--temp_location must be valid GCS location for BigQuery imports') + def is_pipeline_direct_runner(pipeline): diff --git a/gcp_variant_transforms/pipeline_common_test.py b/gcp_variant_transforms/pipeline_common_test.py index a5633b615..181c1b851 100644 --- a/gcp_variant_transforms/pipeline_common_test.py +++ b/gcp_variant_transforms/pipeline_common_test.py @@ -94,21 +94,31 @@ def test_fail_on_invalid_flags(self): 'gcp-variant-transforms-test', '--staging_location', 'gs://integration_test_runs/staging'] - pipeline_common._raise_error_on_invalid_flags(pipeline_args) + pipeline_common._raise_error_on_invalid_flags(pipeline_args, None) # Add Dataflow runner (requires --setup_file). pipeline_args.extend(['--runner', 'DataflowRunner']) with self.assertRaisesRegexp(ValueError, 'setup_file'): - pipeline_common._raise_error_on_invalid_flags(pipeline_args) + pipeline_common._raise_error_on_invalid_flags(pipeline_args, None) # Add setup.py (required for Variant Transforms run). This is now valid. pipeline_args.extend(['--setup_file', 'setup.py']) - pipeline_common._raise_error_on_invalid_flags(pipeline_args) + pipeline_common._raise_error_on_invalid_flags(pipeline_args, None) + + with self.assertRaisesRegexp(ValueError, '--temp_location is required*'): + pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output') + + pipeline_args.extend(['--temp_location', 'wrong_gcs']) + with self.assertRaisesRegexp(ValueError, '--temp_location must be valid*'): + pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output') + + pipeline_args = pipeline_args[:-1] + ['gs://valid_bucket/temp'] + pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output') # Add an unknown flag. pipeline_args.extend(['--unknown_flag', 'somevalue']) with self.assertRaisesRegexp(ValueError, 'Unrecognized.*unknown_flag'): - pipeline_common._raise_error_on_invalid_flags(pipeline_args) + pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output') def test_get_compression_type(self): vcf_metadata_list = [filesystem.FileMetadata(path, size) for diff --git a/gcp_variant_transforms/transforms/sample_info_to_bigquery.py b/gcp_variant_transforms/transforms/sample_info_to_bigquery.py index fde467de1..fe63c71e1 100644 --- a/gcp_variant_transforms/transforms/sample_info_to_bigquery.py +++ b/gcp_variant_transforms/transforms/sample_info_to_bigquery.py @@ -69,7 +69,6 @@ def __init__(self, output_table_prefix, sample_name_encoding, append=False): self._append = append self._sample_name_encoding = sample_name_encoding self._schema = sample_info_table_schema_generator.generate_schema() - self._temp_location = temp_location def expand(self, pcoll): return (pcoll @@ -84,5 +83,4 @@ def expand(self, pcoll): beam.io.BigQueryDisposition.WRITE_APPEND if self._append else beam.io.BigQueryDisposition.WRITE_TRUNCATE), - method=beam.io.WriteToBigQuery.Method.FILE_LOADS, - custom_gcs_temp_location=self._temp_location)) + method=beam.io.WriteToBigQuery.Method.FILE_LOADS)) diff --git a/gcp_variant_transforms/transforms/variant_to_bigquery.py b/gcp_variant_transforms/transforms/variant_to_bigquery.py index 5a954e502..2f17ba358 100644 --- a/gcp_variant_transforms/transforms/variant_to_bigquery.py +++ b/gcp_variant_transforms/transforms/variant_to_bigquery.py @@ -119,5 +119,4 @@ def expand(self, pcoll): beam.io.BigQueryDisposition.WRITE_APPEND if self._append else beam.io.BigQueryDisposition.WRITE_TRUNCATE), - method=beam.io.WriteToBigQuery.Method.FILE_LOADS, - custom_gcs_temp_location=self._temp_location)) + method=beam.io.WriteToBigQuery.Method.FILE_LOADS)) diff --git a/gcp_variant_transforms/vcf_to_bq.py b/gcp_variant_transforms/vcf_to_bq.py index 5b98a79f0..72ede7621 100644 --- a/gcp_variant_transforms/vcf_to_bq.py +++ b/gcp_variant_transforms/vcf_to_bq.py @@ -390,7 +390,6 @@ def _run_annotation_pipeline(known_args, pipeline_args): def _create_sample_info_table(pipeline, # type: beam.Pipeline pipeline_mode, # type: PipelineModes known_args, # type: argparse.Namespace, - temp_directory, # str ): # type: (...) -> None headers = pipeline_common.read_headers( @@ -410,8 +409,6 @@ def run(argv=None): logging.info('Command: %s', ' '.join(argv or sys.argv)) known_args, pipeline_args = pipeline_common.parse_args(argv, _COMMAND_LINE_OPTIONS) - if known_args.output_table and '--temp_location' not in pipeline_args: - raise ValueError('--temp_location is required for BigQuery imports.') if known_args.auto_flags_experiment: _get_input_dimensions(known_args, pipeline_args) @@ -486,9 +483,9 @@ def run(argv=None): for i in range(num_shards): table_suffix = '' - if sharding and sharding.get_shard_name(i): - table_suffix = '_' + sharding.get_shard_name(i) - table_name = known_args.output_table + table_suffix + table_suffix = sharding.get_output_table_suffix(i) + table_name = sample_info_table_schema_generator.compose_table_name( + known_args.output_table, table_suffix) _ = (variants[i] | 'VariantToBigQuery' + table_suffix >> variant_to_bigquery.VariantToBigQuery( table_name,