diff --git a/gcp_variant_transforms/options/variant_transform_options.py b/gcp_variant_transforms/options/variant_transform_options.py
index 0c849b3bb..31c2aac13 100644
--- a/gcp_variant_transforms/options/variant_transform_options.py
+++ b/gcp_variant_transforms/options/variant_transform_options.py
@@ -195,7 +195,8 @@ def add_arguments(self, parser):
     parser.add_argument(
         '--num_bigquery_write_shards',
         type=int, default=1,
-        help=('This flag is deprecated and may be removed in future releases.'))
+        help=('This flag is deprecated and will be removed in future '
+              'releases.'))
     parser.add_argument(
         '--null_numeric_value_replacement',
         type=int,
diff --git a/gcp_variant_transforms/pipeline_common.py b/gcp_variant_transforms/pipeline_common.py
index 71c767961..cbbbbbf8c 100644
--- a/gcp_variant_transforms/pipeline_common.py
+++ b/gcp_variant_transforms/pipeline_common.py
@@ -71,7 +71,7 @@ def parse_args(argv, command_line_options):
   known_args, pipeline_args = parser.parse_known_args(argv)
   for transform_options in options:
     transform_options.validate(known_args)
-  _raise_error_on_invalid_flags(pipeline_args)
+  _raise_error_on_invalid_flags(pipeline_args, known_args.output_table)
   if hasattr(known_args, 'input_pattern') or hasattr(known_args, 'input_file'):
     known_args.all_patterns = _get_all_patterns(
         known_args.input_pattern, known_args.input_file)
@@ -301,8 +301,8 @@ def write_headers(merged_header, file_path):
           vcf_header_io.WriteVcfHeaders(file_path))
 
 
-def _raise_error_on_invalid_flags(pipeline_args):
-  # type: (List[str]) -> None
+def _raise_error_on_invalid_flags(pipeline_args, output_table):
+  # type: (List[str], Any) -> None
   """Raises an error if there are unrecognized flags."""
   parser = argparse.ArgumentParser()
   for cls in pipeline_options.PipelineOptions.__subclasses__():
@@ -315,6 +315,14 @@ def _raise_error_on_invalid_flags(pipeline_args):
       not known_pipeline_args.setup_file):
     raise ValueError('The --setup_file flag is required for DataflowRunner. '
                      'Please provide a path to the setup.py file.')
+  if output_table:
+    if (not hasattr(known_pipeline_args, 'temp_location') or
+        not known_pipeline_args.temp_location):
+      raise ValueError('--temp_location is required for BigQuery imports.')
+    if not known_pipeline_args.temp_location.startswith('gs://'):
+      raise ValueError(
+          '--temp_location must be valid GCS location for BigQuery imports')
+
 
 
 def is_pipeline_direct_runner(pipeline):
diff --git a/gcp_variant_transforms/pipeline_common_test.py b/gcp_variant_transforms/pipeline_common_test.py
index cd2d2ce58..1697908ff 100644
--- a/gcp_variant_transforms/pipeline_common_test.py
+++ b/gcp_variant_transforms/pipeline_common_test.py
@@ -95,21 +95,31 @@ def test_fail_on_invalid_flags(self):
                      'gcp-variant-transforms-test',
                      '--staging_location',
                      'gs://integration_test_runs/staging']
-    pipeline_common._raise_error_on_invalid_flags(pipeline_args)
+    pipeline_common._raise_error_on_invalid_flags(pipeline_args, None)
 
     # Add Dataflow runner (requires --setup_file).
     pipeline_args.extend(['--runner', 'DataflowRunner'])
     with self.assertRaisesRegexp(ValueError, 'setup_file'):
-      pipeline_common._raise_error_on_invalid_flags(pipeline_args)
+      pipeline_common._raise_error_on_invalid_flags(pipeline_args, None)
 
     # Add setup.py (required for Variant Transforms run). This is now valid.
     pipeline_args.extend(['--setup_file', 'setup.py'])
-    pipeline_common._raise_error_on_invalid_flags(pipeline_args)
+    pipeline_common._raise_error_on_invalid_flags(pipeline_args, None)
+
+    with self.assertRaisesRegexp(ValueError, '--temp_location is required*'):
+      pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')
+
+    pipeline_args.extend(['--temp_location', 'wrong_gcs'])
+    with self.assertRaisesRegexp(ValueError, '--temp_location must be valid*'):
+      pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')
+
+    pipeline_args = pipeline_args[:-1] + ['gs://valid_bucket/temp']
+    pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')
 
     # Add an unknown flag.
     pipeline_args.extend(['--unknown_flag', 'somevalue'])
     with self.assertRaisesRegexp(ValueError, 'Unrecognized.*unknown_flag'):
-      pipeline_common._raise_error_on_invalid_flags(pipeline_args)
+      pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')
 
   def test_get_compression_type(self):
     vcf_metadata_list = [filesystem.FileMetadata(path, size) for
diff --git a/gcp_variant_transforms/transforms/sample_info_to_bigquery.py b/gcp_variant_transforms/transforms/sample_info_to_bigquery.py
index 0a6b90a1a..965bf2b28 100644
--- a/gcp_variant_transforms/transforms/sample_info_to_bigquery.py
+++ b/gcp_variant_transforms/transforms/sample_info_to_bigquery.py
@@ -50,7 +50,7 @@ def process(self, vcf_header):
 class SampleInfoToBigQuery(beam.PTransform):
   """Writes sample info to BigQuery."""
 
-  def __init__(self, output_table_prefix, temp_location, append=False,
+  def __init__(self, output_table_prefix, append=False,
                samples_span_multiple_files=False):
     # type: (str, Dict[str, str], bool, bool) -> None
     """Initializes the transform.
@@ -67,7 +67,6 @@ def __init__(self, output_table_prefix, temp_location, append=False,
     self._append = append
     self._samples_span_multiple_files = samples_span_multiple_files
     self._schema = sample_info_table_schema_generator.generate_schema()
-    self._temp_location = temp_location
 
   def expand(self, pcoll):
     return (pcoll
@@ -82,5 +81,4 @@ def expand(self, pcoll):
                 beam.io.BigQueryDisposition.WRITE_APPEND
                 if self._append else
                 beam.io.BigQueryDisposition.WRITE_TRUNCATE),
-            method=beam.io.WriteToBigQuery.Method.FILE_LOADS,
-            custom_gcs_temp_location=self._temp_location))
+            method=beam.io.WriteToBigQuery.Method.FILE_LOADS))
diff --git a/gcp_variant_transforms/transforms/sample_info_to_bigquery_test.py b/gcp_variant_transforms/transforms/sample_info_to_bigquery_test.py
index 3dd280fcf..d07835a55 100644
--- a/gcp_variant_transforms/transforms/sample_info_to_bigquery_test.py
+++ b/gcp_variant_transforms/transforms/sample_info_to_bigquery_test.py
@@ -53,7 +53,7 @@ def test_convert_sample_info_to_row(self):
         | transforms.Create([vcf_header_1, vcf_header_2])
         | 'ConvertToRow' >> transforms.ParDo(
             sample_info_to_bigquery.ConvertSampleInfoToRow(
-                ), False))
+                False), ))
 
     assert_that(bigquery_rows, equal_to(expected_rows))
     pipeline.run()
@@ -83,7 +83,7 @@ def test_convert_sample_info_to_row_without_file_in_hash(self):
         | transforms.Create([vcf_header_1, vcf_header_2])
         | 'ConvertToRow' >> transforms.ParDo(
             sample_info_to_bigquery.ConvertSampleInfoToRow(
-                ), True))
+                True), ))
 
     assert_that(bigquery_rows, equal_to(expected_rows))
     pipeline.run()
diff --git a/gcp_variant_transforms/transforms/variant_to_bigquery.py b/gcp_variant_transforms/transforms/variant_to_bigquery.py
index 58f73d1c8..73b298482 100644
--- a/gcp_variant_transforms/transforms/variant_to_bigquery.py
+++ b/gcp_variant_transforms/transforms/variant_to_bigquery.py
@@ -59,7 +59,6 @@ def __init__(
       self,
       output_table,  # type: str
       header_fields,  # type: vcf_header_io.VcfHeader
-      temp_location,  # type: str
       variant_merger=None,  # type: variant_merge_strategy.VariantMergeStrategy
       proc_var_factory=None,  # type: processed_variant.ProcessedVariantFactory
       # TODO(bashir2): proc_var_factory is a required argument and if `None` is
@@ -99,7 +98,6 @@ def __init__(
     """
     self._output_table = output_table
     self._header_fields = header_fields
-    self._temp_location = temp_location
     self._variant_merger = variant_merger
     self._proc_var_factory = proc_var_factory
     self._append = append
@@ -137,5 +135,4 @@ def expand(self, pcoll):
                 beam.io.BigQueryDisposition.WRITE_APPEND
                 if self._append else
                 beam.io.BigQueryDisposition.WRITE_TRUNCATE),
-            method=beam.io.WriteToBigQuery.Method.FILE_LOADS,
-            custom_gcs_temp_location=self._temp_location))
+            method=beam.io.WriteToBigQuery.Method.FILE_LOADS))
diff --git a/gcp_variant_transforms/vcf_to_bq.py b/gcp_variant_transforms/vcf_to_bq.py
index 7a5672c27..3df3bb428 100644
--- a/gcp_variant_transforms/vcf_to_bq.py
+++ b/gcp_variant_transforms/vcf_to_bq.py
@@ -384,7 +384,6 @@ def _run_annotation_pipeline(known_args, pipeline_args):
 def _create_sample_info_table(pipeline,  # type: beam.Pipeline
                               pipeline_mode,  # type: PipelineModes
                               known_args,  # type: argparse.Namespace,
-                              temp_directory,  # str
                               ):
   # type: (...) -> None
   headers = pipeline_common.read_headers(
@@ -395,7 +394,6 @@ def _create_sample_info_table(pipeline,  # type: beam.Pipeline
   _ = (headers
        | 'SampleInfoToBigQuery' >> sample_info_to_bigquery.SampleInfoToBigQuery(
            known_args.output_table,
-           temp_directory,
            known_args.append,
            known_args.samples_span_multiple_files))
 
@@ -406,8 +404,6 @@ def run(argv=None):
   logging.info('Command: %s', ' '.join(argv or sys.argv))
   known_args, pipeline_args = pipeline_common.parse_args(argv,
                                                          _COMMAND_LINE_OPTIONS)
-  if known_args.output_table and '--temp_location' not in pipeline_args:
-    raise ValueError('--temp_location is required for BigQuery imports.')
   if known_args.auto_flags_experiment:
     _get_input_dimensions(known_args, pipeline_args)
 
@@ -483,10 +479,6 @@ def run(argv=None):
     num_partitions = 1
 
   if known_args.output_table:
-    temp_directory = pipeline_options.PipelineOptions(pipeline_args).view_as(
-        pipeline_options.GoogleCloudOptions).temp_location
-    if not temp_directory:
-      raise ValueError('--temp_location must be set when writing to BigQuery.')
     for i in range(num_partitions):
       table_suffix = ''
       if partitioner and partitioner.get_partition_name(i):
@@ -496,7 +488,6 @@ def run(argv=None):
         table_suffix = '_' + partitioner.get_partition_name(i)
       table_name = known_args.output_table + table_suffix
       _ = (variants[i] | 'VariantToBigQuery' + table_suffix >>
            variant_to_bigquery.VariantToBigQuery(
                table_name,
                header_fields,
-               temp_directory,
                variant_merger,
                processed_variant_factory,
                append=known_args.append,
@@ -507,7 +498,7 @@ def run(argv=None):
                    known_args.null_numeric_value_replacement)))
     if known_args.generate_sample_info_table:
       _create_sample_info_table(
-          pipeline, pipeline_mode, known_args, temp_directory)
+          pipeline, pipeline_mode, known_args)
 
   if known_args.output_avro_path:
     # TODO(bashir2): Add an integration test that outputs to Avro files and
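
Note on the custom_gcs_temp_location removal: with method=FILE_LOADS and no
custom_gcs_temp_location, beam.io.WriteToBigQuery stages its load files under
the pipeline's --temp_location (GoogleCloudOptions), which is why
_raise_error_on_invalid_flags now requires a gs:// --temp_location whenever
--output_table is set. A minimal sketch of that fallback, assuming the Beam
behavior described above; the project, bucket, dataset, table, and schema
names are placeholders, not part of this change:

  import apache_beam as beam
  from apache_beam.options.pipeline_options import PipelineOptions

  # Placeholder flags; --temp_location is what FILE_LOADS falls back to
  # when custom_gcs_temp_location is not passed.
  options = PipelineOptions([
      '--project=my-project',
      '--temp_location=gs://my-bucket/temp',
  ])

  with beam.Pipeline(options=options) as p:
    _ = (p
         | beam.Create([{'sample_id': 1, 'sample_name': 'S1'}])
         | beam.io.WriteToBigQuery(
             'my-project:my_dataset.sample_info',
             schema='sample_id:INTEGER,sample_name:STRING',
             create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
             write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE,
             # No custom_gcs_temp_location: load files are staged under the
             # pipeline's --temp_location, as in the transforms above.
             method=beam.io.WriteToBigQuery.Method.FILE_LOADS))

This mirrors the new check in pipeline_common: a BigQuery import without a
gs:// --temp_location would fail at load time anyway, so the pipeline now
rejects it up front instead of relying on the per-transform temp_location
argument that this change removes.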