diff --git a/dev/scripts/load_test_customer_data.sh b/dev/scripts/load_test_customer_data.sh
index abfde7e561..8c5f65b077 100755
--- a/dev/scripts/load_test_customer_data.sh
+++ b/dev/scripts/load_test_customer_data.sh
@@ -333,32 +333,34 @@ build_all(){
 }
 
 # ---execute---
-case ${1} in
-    "AWS"|"aws")
+provider_arg=$(echo "${1}" | tr '[:lower:]' '[:upper:]')
+
+case ${provider_arg} in
+    "AWS")
         check-api-status "Koku" "${KOKU_URL_PREFIX}/v1/status/"
         check-api-status "Masu" "${MASU_URL_PREFIX}/v1/status/"
         build_aws_data
         enable_ocp_tags
         ;;
-    "AZURE"|"azure"|"Azure")
+    "AZURE")
         check-api-status "Koku" "${KOKU_URL_PREFIX}/v1/status/"
         check-api-status "Masu" "${MASU_URL_PREFIX}/v1/status/"
         build_azure_data
         enable_ocp_tags
         ;;
-    "GCP"|"gcp")
+    "GCP")
         check-api-status "Koku" "${KOKU_URL_PREFIX}/v1/status/"
         check-api-status "Masu" "${MASU_URL_PREFIX}/v1/status/"
         build_gcp_data
         enable_ocp_tags
         ;;
-    "ONPREM"|"onprem")
+    "ONPREM")
         check-api-status "Koku" "${KOKU_URL_PREFIX}/v1/status/"
         check-api-status "Masu" "${MASU_URL_PREFIX}/v1/status/"
         build_onprem_data
         enable_ocp_tags
         ;;
-    "all")
+    "ALL")
         check-api-status "Koku" "${KOKU_URL_PREFIX}/v1/status/"
         check-api-status "Masu" "${MASU_URL_PREFIX}/v1/status/"
         build_all
         enable_ocp_tags
         ;;
-    "HELP"|"help") usage;;
+    "HELP") usage;;
     *) usage;;
 esac
diff --git a/koku/hcs/csv_file_handler.py b/koku/hcs/csv_file_handler.py
index 1ff9291a2a..9bd5f2a4c7 100644
--- a/koku/hcs/csv_file_handler.py
+++ b/koku/hcs/csv_file_handler.py
@@ -4,7 +4,7 @@
 import pandas as pd
 
 from api.common import log_json
-from masu.util.aws.common import copy_local_report_file_to_s3_bucket
+from masu.util.aws.common import copy_local_hcs_report_file_to_s3_bucket
 
 LOG = logging.getLogger(__name__)
 
@@ -18,12 +18,13 @@ def __init__(self, schema_name, provider, provider_uuid):
         self._provider = provider
         self._provider_uuid = provider_uuid
 
-    def write_csv_to_s3(self, date, data, cols, tracing_id=None):
+    def write_csv_to_s3(self, date, data, cols, finalize=False, tracing_id=None):
         """
         Generates an HCS CSV from the specified schema and provider.
 
         :param date
         :param data
         :param cols
+        :param finalize
         :param tracing_id
 
         :return none
@@ -38,5 +39,5 @@ def write_csv_to_s3(self, date, data, cols, tracing_id=None):
         LOG.info(log_json(tracing_id, "preparing to write file to object storage"))
         my_df.to_csv(filename, header=cols, index=False)
 
-        copy_local_report_file_to_s3_bucket(tracing_id, s3_csv_path, filename, filename, "", date)
+        copy_local_hcs_report_file_to_s3_bucket(tracing_id, s3_csv_path, filename, filename, finalize, date)
         os.remove(filename)
diff --git a/koku/hcs/daily_report.py b/koku/hcs/daily_report.py
index f854031c2f..1ea273f4d8 100644
--- a/koku/hcs/daily_report.py
+++ b/koku/hcs/daily_report.py
@@ -23,15 +23,18 @@ def __init__(self, schema_name, provider, provider_uuid, tracing_id):
         self._date_accessor = DateAccessor()
         self._tracing_id = tracing_id
 
-    def generate_report(self, start_date, end_date):
+    def generate_report(self, start_date, end_date, finalize=False):
         """Generate HCS daily report
 
         :param start_date (str) The date to start populating the table
         :param end_date (str) The date to end on
+        :param finalize (bool) Set to True when the report is final (default=False)
 
         :returns (none)
         """
         sql_file = f"sql/reporting_{self._provider.lower()}_hcs_daily_summary.sql"
         with HCSReportDBAccessor(self._schema_name) as accessor:
             for date in date_range(start_date, end_date, step=1):
-                accessor.get_hcs_daily_summary(date, self._provider, self._provider_uuid, sql_file, self._tracing_id)
+                accessor.get_hcs_daily_summary(
+                    date, self._provider, self._provider_uuid, sql_file, self._tracing_id, finalize
+                )
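Note on the write path above: write_csv_to_s3 previously took tracing_id as its fourth positional argument, so any caller still passing it positionally would now be supplying finalize instead. A minimal sketch of the intended call pattern after this change; the schema, provider UUID, and data values here are made up for illustration:

    from hcs.csv_file_handler import CSVFileHandler

    handler = CSVFileHandler("acct10001", "AWS", "cabfdddb-4ed5-421e-a041-311b75daf235")
    data = [("2022-02-01", "AWS Marketplace", 10.17)]
    cols = ["usage_start", "billing_entity", "unblended_cost"]
    # Keyword arguments sidestep positional mix-ups introduced by the new signature.
    handler.write_csv_to_s3("2022-02-01", data, cols, finalize=True, tracing_id="demo-trace")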
diff --git a/koku/hcs/database/report_db_accessor.py b/koku/hcs/database/report_db_accessor.py
index 1c7b53eb4d..cd9093e561 100644
--- a/koku/hcs/database/report_db_accessor.py
+++ b/koku/hcs/database/report_db_accessor.py
@@ -37,13 +37,14 @@ def __init__(self, schema):
         self.date_accessor = DateAccessor()
         self.jinja_sql = JinjaSql()
 
-    def get_hcs_daily_summary(self, date, provider, provider_uuid, sql_summary_file, tracing_id):
+    def get_hcs_daily_summary(self, date, provider, provider_uuid, sql_summary_file, tracing_id, finalize=False):
         """Build HCS daily report.
 
         :param date (datetime.date) The date to process
         :param provider (str) The provider name
         :param provider_uuid (uuid) ID for cost source
        :param sql_summary_file (str) The sql file used for processing
         :param tracing_id (id) Logging identifier
+        :param finalize (bool) Set to True when the report is finalized (default=False)
 
         :returns (None)
         """
@@ -53,6 +54,11 @@ def get_hcs_daily_summary(self, date, provider, provider_uuid, sql_summary_file,
         try:
             sql = pkgutil.get_data("hcs.database", sql_summary_file)
             sql = sql.decode("utf-8")
+            table = HCS_TABLE_MAP.get(provider.strip("-local"))
+
+            if not self.table_exists_trino(table):
+                LOG.info(log_json(tracing_id, f"{table} does not exist, skipping..."))
+                return {}
 
             sql_params = {
                 "provider_uuid": provider_uuid,
@@ -61,8 +67,9 @@ def get_hcs_daily_summary(self, date, provider, provider_uuid, sql_summary_file,
                 "date": date,
                 "schema": self.schema,
                 "ebs_acct_num": self._ebs_acct_num,
-                "table": HCS_TABLE_MAP.get(provider),
+                "table": table,
             }
+            LOG.debug(log_json(tracing_id, f"SQL params: {sql_params}"))
 
             sql, sql_params = self.jinja_sql.prepare_query(sql, sql_params)
 
@@ -77,9 +84,14 @@ def get_hcs_daily_summary(self, date, provider, provider_uuid, sql_summary_file,
             if len(data) > 0:
                 LOG.info(log_json(tracing_id, f"data found for date: {date}"))
                 csv_handler = CSVFileHandler(self.schema, provider, provider_uuid)
-                csv_handler.write_csv_to_s3(date, data, cols, tracing_id)
+                csv_handler.write_csv_to_s3(date, data, cols, finalize, tracing_id)
             else:
-                LOG.info(log_json(tracing_id, f"no data found for date: {date}"))
+                LOG.info(
+                    log_json(
+                        tracing_id,
+                        f"no data found for date: {date}, provider: {provider}, provider_uuid: {provider_uuid}",
+                    )
+                )
 
         except FileNotFoundError:
             LOG.error(log_json(tracing_id, f"unable to locate SQL file: {sql_summary_file}"))
diff --git a/koku/hcs/database/sql/reporting_aws-local_hcs_daily_summary.sql b/koku/hcs/database/sql/reporting_aws-local_hcs_daily_summary.sql
new file mode 100644
index 0000000000..454cbca3b2
--- /dev/null
+++ b/koku/hcs/database/sql/reporting_aws-local_hcs_daily_summary.sql
@@ -0,0 +1,9 @@
+SELECT *, '{{ebs_acct_num | sqlsafe}}' as ebs_account_id
+FROM hive.{{schema | sqlsafe}}.{{table | sqlsafe}}
+WHERE source = '{{provider_uuid | sqlsafe}}'
+  AND year = '{{year | sqlsafe}}'
+  AND month = '{{month | sqlsafe}}'
+  AND bill_billingentity = 'AWS Marketplace'
+  AND lineitem_legalentity like '%Red Hat%'
+  AND lineitem_usagestartdate >= TIMESTAMP '{{date | sqlsafe}}'
+  AND lineitem_usagestartdate < date_add('day', 1, TIMESTAMP '{{date | sqlsafe}}')
diff --git a/koku/hcs/database/sql/reporting_azure-local_hcs_daily_summary.sql b/koku/hcs/database/sql/reporting_azure-local_hcs_daily_summary.sql
new file mode 100644
index 0000000000..04a4d07889
--- /dev/null
+++ b/koku/hcs/database/sql/reporting_azure-local_hcs_daily_summary.sql
@@ -0,0 +1,9 @@
+SELECT *, '{{ebs_acct_num | sqlsafe}}' as ebs_account_id
+FROM hive.{{schema | sqlsafe}}.{{table | sqlsafe}}
+WHERE source = '{{provider_uuid | sqlsafe}}'
+  AND year = '{{year | sqlsafe}}'
+  AND month = '{{month | sqlsafe}}'
+  AND publishertype = 'Marketplace'
+  AND publishername like '%Red Hat%'
+  AND coalesce(date, usagedatetime) >= TIMESTAMP '{{date | sqlsafe}}'
+  AND coalesce(date, usagedatetime) < date_add('day', 1, TIMESTAMP '{{date | sqlsafe}}')
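For context, the new -local SQL templates are rendered through JinjaSql the same way as the existing non-local ones, and every placeholder goes through sqlsafe, which disables escaping; all substituted values must therefore come from trusted code, never user input. A rough sketch of the substitution; the table name below is an assumption for illustration, not the actual HCS_TABLE_MAP entry:

    from jinjasql import JinjaSql

    sql = "SELECT * FROM hive.{{schema | sqlsafe}}.{{table | sqlsafe}} WHERE year = '{{year | sqlsafe}}'"
    params = {"schema": "acct10001", "table": "aws_line_items_daily", "year": "2022"}
    query, bind_params = JinjaSql().prepare_query(sql, params)
    print(query)  # SELECT * FROM hive.acct10001.aws_line_items_daily WHERE year = '2022'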
diff --git a/koku/hcs/tasks.py b/koku/hcs/tasks.py
index 8128abbce6..65c6f52cee 100644
--- a/koku/hcs/tasks.py
+++ b/koku/hcs/tasks.py
@@ -19,6 +19,12 @@
 LOG = logging.getLogger(__name__)
 HCS_QUEUE = "hcs"
+HCS_EXCEPTED_PROVIDERS = (
+    Provider.PROVIDER_AWS,
+    Provider.PROVIDER_AWS_LOCAL,
+    Provider.PROVIDER_AZURE,
+    Provider.PROVIDER_AZURE_LOCAL,
+)
 
 # any additional queues should be added to this list
 QUEUE_LIST = [HCS_QUEUE]
@@ -44,7 +50,6 @@ def collect_hcs_report_data_from_manifest(reports_to_hcs_summarize):
 
     Returns:
         None
-
     """
     reports = [report for report in reports_to_hcs_summarize if report]
     reports_deduplicated = [dict(t) for t in {tuple(d.items()) for d in reports}]
@@ -53,7 +58,7 @@
         start_date = None
         end_date = None
         if report.get("start") and report.get("end"):
-            LOG.info("using start and end dates from the manifest")
+            LOG.info("using start and end dates from the manifest for HCS processing")
             start_date = parser.parse(report.get("start")).strftime("%Y-%m-%d")
             end_date = parser.parse(report.get("end")).strftime("%Y-%m-%d")
@@ -63,13 +68,14 @@
         tracing_id = report.get("tracing_id", report.get("manifest_uuid", str(uuid.uuid4())))
 
         stmt = (
-            f"[collect_hcs_report_data_from_manifest] schema_name: {schema_name},"
+            f"[collect_hcs_report_data_from_manifest]:"
+            f" schema_name: {schema_name},"
             f"provider_type: {provider_type},"
             f"provider_uuid: {provider_uuid},"
             f"start: {start_date},"
             f"end: {end_date}"
         )
-        LOG.debug(log_json(tracing_id, stmt))
+        LOG.info(log_json(tracing_id, stmt))
 
         collect_hcs_report_data.s(
             schema_name, provider_type, provider_uuid, start_date, end_date, tracing_id
@@ -77,23 +83,20 @@
 @celery_app.task(name="hcs.tasks.collect_hcs_report_data", queue=HCS_QUEUE)
-def collect_hcs_report_data(schema_name, provider, provider_uuid, start_date=None, end_date=None, tracing_id=None):
+def collect_hcs_report_data(
+    schema_name, provider, provider_uuid, start_date=None, end_date=None, tracing_id=None, finalize=False
+):
     """Update Hybrid Committed Spend report.
 
     :param provider: (str) The provider type
-    :param provider_uuid: (str) The provider type
+    :param provider_uuid: (str) The provider unique identification number
     :param start_date: The date to start populating the table (default: (Today - 2 days))
     :param end_date: The date to end on (default: Today)
     :param schema_name: (Str) db schema name
     :param tracing_id: (uuid) for log tracing
+    :param finalize: (boolean) If True, run the report finalization process for the previous month (default: False)
 
     :returns None
     """
-
-    # drop "-local" from provider name when in development environment
-    if "-local" in provider:
-        LOG.debug(log_json(tracing_id, "dropping '-local' from provider name"))
-        provider = provider.strip("-local")
-
     if schema_name and not schema_name.startswith("acct"):
         schema_name = f"acct{schema_name}"
@@ -106,24 +109,63 @@
     if tracing_id is None:
         tracing_id = str(uuid.uuid4())
 
-    if enable_hcs_processing(schema_name) and provider in (Provider.PROVIDER_AWS, Provider.PROVIDER_AZURE):
+    if enable_hcs_processing(schema_name) and provider in HCS_EXCEPTED_PROVIDERS:
         stmt = (
-            f"Running HCS data collection: "
+            f"[collect_hcs_report_data]: "
             f"schema_name: {schema_name}, "
             f"provider_uuid: {provider_uuid}, "
-            f"provider: {provider}, "
+            f"provider_type: {provider}, "
             f"dates {start_date} - {end_date}"
         )
         LOG.info(log_json(tracing_id, stmt))
 
         reporter = ReportHCS(schema_name, provider, provider_uuid, tracing_id)
-        reporter.generate_report(start_date, end_date)
+        reporter.generate_report(start_date, end_date, finalize)
 
     else:
         stmt = (
             f"[SKIPPED] HCS report generation: "
-            f"Schema-name: {schema_name}, "
-            f"provider: {provider}, "
+            f"schema_name: {schema_name}, "
+            f"provider_type: {provider}, "
             f"provider_uuid: {provider_uuid}, "
             f"dates {start_date} - {end_date}"
         )
         LOG.info(log_json(tracing_id, stmt))
+
+
+@celery_app.task(name="hcs.tasks.collect_hcs_report_finalization", queue=HCS_QUEUE)
+def collect_hcs_report_finalization(tracing_id=None):
+    """Run report finalization for the previous month for all providers in HCS_EXCEPTED_PROVIDERS."""
+    if tracing_id is None:
+        tracing_id = str(uuid.uuid4())
+
+    today = DateAccessor().today()
+
+    for excepted_provider in HCS_EXCEPTED_PROVIDERS:
+        LOG.debug(log_json(tracing_id, f"excepted_provider: {excepted_provider}"))
+
+        providers = Provider.objects.filter(type=excepted_provider).all()
+
+        for provider in providers:
+            schema_name = provider.customer.schema_name
+            provider_uuid = provider.uuid
+            provider_type = provider.type
+            end_date_prev_month = today.replace(day=1) - datetime.timedelta(days=1)
+            start_date_prev_month = today.replace(day=1) - datetime.timedelta(days=end_date_prev_month.day)
+
+            stmt = (
+                f"[collect_hcs_report_finalization]: "
+                f"schema_name: {schema_name}, "
+                f"provider_type: {provider_type}, "
+                f"provider_uuid: {provider_uuid}, "
+                f"dates: {start_date_prev_month} - {end_date_prev_month}"
+            )
+            LOG.info(log_json(tracing_id, stmt))
+
+            collect_hcs_report_data.s(
+                schema_name,
+                provider_type,
+                provider_uuid,
+                start_date_prev_month,
+                end_date_prev_month,
+                tracing_id,
+                True,
+            ).apply_async()
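The previous-month window computed in collect_hcs_report_finalization is easy to misread, so here is a standalone worked example using only the standard library (the date is chosen arbitrarily):

    import datetime

    today = datetime.date(2022, 3, 15)
    end_date_prev_month = today.replace(day=1) - datetime.timedelta(days=1)
    start_date_prev_month = today.replace(day=1) - datetime.timedelta(days=end_date_prev_month.day)
    # end_date_prev_month   -> 2022-02-28 (last day of the previous month)
    # start_date_prev_month -> 2022-02-01 (first day of the previous month)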
diff --git a/koku/hcs/test/test_tasks.py b/koku/hcs/test/test_tasks.py
index 8b5d0e8144..69f3228dd3 100644
--- a/koku/hcs/test/test_tasks.py
+++ b/koku/hcs/test/test_tasks.py
@@ -43,7 +43,7 @@ def test_get_report_dates(self, mock_report):
             end_date = self.today
             collect_hcs_report_data(self.schema, self.provider, self.provider_uuid, start_date, end_date)
 
-            self.assertIn("Running HCS data collection", _logs.output[0])
+            self.assertIn("[collect_hcs_report_data]", _logs.output[0])
 
     def test_get_report_no_start_date(self, mock_report):
         """Test no start or end dates provided"""
@@ -52,7 +52,7 @@
         with self.assertLogs("hcs.tasks", "INFO") as _logs:
             collect_hcs_report_data(self.schema, self.provider, self.provider_uuid)
 
-            self.assertIn("Running HCS data collection", _logs.output[0])
+            self.assertIn("[collect_hcs_report_data]", _logs.output[0])
 
     def test_get_report_no_end_date(self, mock_report):
         """Test no start end provided"""
@@ -62,7 +62,7 @@
             start_date = self.yesterday
             collect_hcs_report_data(self.schema, self.provider, self.provider_uuid, start_date)
 
-            self.assertIn("Running HCS data collection", _logs.output[0])
+            self.assertIn("[collect_hcs_report_data]", _logs.output[0])
 
     def test_get_report_invalid_provider(self, mock_report):
         """Test invalid provider"""
@@ -98,7 +98,7 @@
             }
         ]
 
-        with self.assertLogs("hcs.tasks", "DEBUG") as _logs:
+        with self.assertLogs("hcs.tasks", "INFO") as _logs:
             collect_hcs_report_data_from_manifest(manifests)
 
             self.assertIn("[collect_hcs_report_data_from_manifest]", _logs.output[0])
@@ -110,7 +110,7 @@
     @patch("hcs.tasks.collect_hcs_report_data")
     def test_get_report_with_manifest_and_dates(self, mock_report, rd):
-        """Test invalid provider"""
+        """Test HCS reports using manifest"""
         from hcs.tasks import collect_hcs_report_data_from_manifest
 
         manifests = [
@@ -125,5 +125,21 @@
         with self.assertLogs("hcs.tasks", "INFO") as _logs:
             collect_hcs_report_data_from_manifest(manifests)
 
-            self.assertIn("using start and end dates from the manifest", _logs.output[0])
+            self.assertIn("using start and end dates from the manifest for HCS processing", _logs.output[0])
+
+    @patch("hcs.tasks.collect_hcs_report_data")
+    @patch("api.provider.models")
+    def test_get_collect_hcs_report_finalization(self, mock_report, rd, provider):
+        """Test HCS finalization"""
+        from hcs.tasks import collect_hcs_report_finalization
+
+        provider.customer.schema_name.return_value = provider(side_effect=Provider.objects.filter(type="AWS"))
+
+        with self.assertLogs("hcs.tasks", "INFO") as _logs:
+            collect_hcs_report_finalization()
+
+            self.assertIn("[collect_hcs_report_finalization]:", _logs.output[0])
+            self.assertIn("schema_name:", _logs.output[0])
+            self.assertIn("provider_type:", _logs.output[0])
+            self.assertIn("provider_uuid:", _logs.output[0])
+            self.assertIn("dates:", _logs.output[0])
diff --git a/koku/koku/celery.py b/koku/koku/celery.py
index 1469a1d3a3..86672e3ce1 100644
--- a/koku/koku/celery.py
+++ b/koku/koku/celery.py
@@ -215,8 +215,13 @@ def readiness_check(self):
     "schedule": crontab(hour=0, minute=0),
 }
 
+# Beat schedule used for HCS report finalization
+app.conf.beat_schedule["finalize_hcs_reports"] = {
+    "task": "hcs.tasks.collect_hcs_report_finalization",
+    "schedule": crontab(0, 0, day_of_month="15"),
+}
 
-# Celery timeout if broker is unavaiable to avoid blocking indefintely
+# Celery timeout if broker is unavailable to avoid blocking indefinitely
 app.conf.broker_transport_options = {"max_retries": 4, "interval_start": 0, "interval_step": 0.5, "interval_max": 3}
 
 app.autodiscover_tasks()
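The positional arguments in crontab(0, 0, day_of_month="15") map to minute and hour, so this beat entry fires at 00:00 (in Celery's configured timezone) on the 15th of every month. A sketch of the equivalent keyword form, assuming Celery's standard crontab signature:

    from celery.schedules import crontab

    # Same schedule as crontab(0, 0, day_of_month="15")
    schedule = crontab(minute=0, hour=0, day_of_month="15")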
diff --git a/koku/masu/api/hcs_report_finalization.py b/koku/masu/api/hcs_report_finalization.py
new file mode 100644
index 0000000000..bead10325b
--- /dev/null
+++ b/koku/masu/api/hcs_report_finalization.py
@@ -0,0 +1,43 @@
+#
+# Copyright 2021 Red Hat Inc.
+# SPDX-License-Identifier: Apache-2.0
+#
+"""View for the collect_hcs_report_finalization celery task endpoint."""
+import datetime
+import logging
+import uuid
+
+from django.views.decorators.cache import never_cache
+from rest_framework.decorators import api_view
+from rest_framework.decorators import permission_classes
+from rest_framework.decorators import renderer_classes
+from rest_framework.permissions import AllowAny
+from rest_framework.response import Response
+from rest_framework.settings import api_settings
+
+from hcs.tasks import collect_hcs_report_finalization
+from hcs.tasks import HCS_QUEUE
+
+
+LOG = logging.getLogger(__name__)
+
+
+@never_cache
+@api_view(http_method_names=["GET"])
+@permission_classes((AllowAny,))
+@renderer_classes(tuple(api_settings.DEFAULT_RENDERER_CLASSES))
+def hcs_report_finalization(request):
+    """Generate finalized HCS reports for the previous month (based on datetime.date.today)."""
+    tracing_id = str(uuid.uuid4())
+
+    report_data_msg_key = "HCS Report Finalization Task ID"
+    async_results = []
+
+    today = datetime.date.today()
+    first = today.replace(day=1)
+    last_month = first - datetime.timedelta(days=1)
+
+    if request.method == "GET":
+        async_result = collect_hcs_report_finalization.s(tracing_id).apply_async(queue=HCS_QUEUE)
+        async_results.append({last_month.strftime("%Y-%m"): str(async_result)})
+        return Response({report_data_msg_key: async_results})
diff --git a/koku/masu/api/urls.py b/koku/masu/api/urls.py
index 61ce5d2a7a..b2a4f97c21 100644
--- a/koku/masu/api/urls.py
+++ b/koku/masu/api/urls.py
@@ -20,6 +20,7 @@
 from masu.api.views import gcp_invoice_monthly_cost
 from masu.api.views import get_status
 from masu.api.views import hcs_report_data
+from masu.api.views import hcs_report_finalization
 from masu.api.views import lockinfo
 from masu.api.views import pg_engine_version
 from masu.api.views import report_data
@@ -39,6 +40,7 @@
     path("enabled_tags/", enabled_tags, name="enabled_tags"),
     path("expired_data/", expired_data, name="expired_data"),
     path("hcs_report_data/", hcs_report_data, name="hcs_report_data"),
+    path("hcs_report_finalization/", hcs_report_finalization, name="hcs_report_finalization"),
     path("report_data/", report_data, name="report_data"),
     path("source_cleanup/", cleanup, name="cleanup"),
     path("update_cost_model_costs/", update_cost_model_costs, name="update_cost_model_costs"),
diff --git a/koku/masu/api/views.py b/koku/masu/api/views.py
index 50880f334d..78b37231ee 100644
--- a/koku/masu/api/views.py
+++ b/koku/masu/api/views.py
@@ -17,6 +17,7 @@
 from masu.api.expired_data import expired_data
 from masu.api.gcp_invoice_monthly_cost import gcp_invoice_monthly_cost
 from masu.api.hcs_report_data import hcs_report_data
+from masu.api.hcs_report_finalization import hcs_report_finalization
 from masu.api.manifest.views import ManifestView
 from masu.api.report_data import report_data
 from masu.api.running_celery_tasks import celery_queue_lengths
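Once the route above is wired up, the endpoint can be exercised against a local masu instance. Two caveats: the host, port, and URL prefix below are assumptions about a typical development setup rather than values taken from this diff, and the view currently generates its own tracing_id, so the tracing_id query parameter documented in openapi.json is accepted but never read:

    import requests

    resp = requests.get("http://localhost:5042/api/cost-management/v1/hcs_report_finalization/")
    print(resp.status_code)  # expected: 200
    print(resp.json())       # {"HCS Report Finalization Task ID": [{"<YYYY-MM>": "<celery task id>"}]}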
- "description": "Data summary task has been queued", + "description": "HCS report task has been queued", "content": { "application/json": { "schema": { @@ -291,6 +295,40 @@ ] } }, + "/hcs_report_finalization/": { + "get": { + "operationId": "getHCSReportFinalization", + "description": "Finalize HCS report data.", + "parameters": [ + { + "name": "tracing_id", + "in": "query", + "description": "The tracing UUID", + "required": false, + "schema": { + "type": "string", + "format": "uuid", + "example": "83ee048e-3c1d-43ef-b945-108225ae52f4" + } + } + ], + "responses": { + "200": { + "description": "HCS Finalization task has been queued", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HCSDataGetResponse" + } + } + } + } + }, + "tags": [ + "HCS Report Finalization" + ] + } + }, "/report_data/": { "get": { "operationId": "getReportData", diff --git a/koku/masu/test/api/test_hcs_report_finalization.py b/koku/masu/test/api/test_hcs_report_finalization.py new file mode 100644 index 0000000000..f594719d0d --- /dev/null +++ b/koku/masu/test/api/test_hcs_report_finalization.py @@ -0,0 +1,34 @@ +# +# Copyright 2021 Red Hat Inc. +# SPDX-License-Identifier: Apache-2.0 +# +"""Test the hcs_report_data endpoint view.""" +import uuid +from unittest.mock import patch + +from django.test import TestCase +from django.test.utils import override_settings +from django.urls import reverse + + +@override_settings(ROOT_URLCONF="masu.urls") +class HCSFinalizationTests(TestCase): + """Test Cases for the hcs_report_finalization endpoint.""" + + ENDPOINT = "hcs_report_finalization" + + @patch("koku.middleware.MASU", return_value=True) + @patch("masu.api.hcs_report_finalization.collect_hcs_report_finalization") + def test_get_report_data(self, mock_celery, _): + """Test the GET report_data endpoint.""" + + params = { + "tracing_id": str(uuid.uuid4()), + } + expected_key = "HCS Report Finalization Task ID" + + response = self.client.get(reverse(self.ENDPOINT), params) + body = response.json() + self.assertEqual(response.status_code, 200) + self.assertIn(expected_key, body) + mock_celery.s.assert_called() diff --git a/koku/masu/test/util/aws/test_common.py b/koku/masu/test/util/aws/test_common.py index f855a6664a..98eb264d78 100644 --- a/koku/masu/test/util/aws/test_common.py +++ b/koku/masu/test/util/aws/test_common.py @@ -360,6 +360,22 @@ def test_copy_data_to_s3_bucket(self): upload = utils.copy_data_to_s3_bucket("request_id", "path", "filename", "data", "manifest_id") self.assertEqual(upload, None) + def test_copy_hcs_data_to_s3_bucket(self): + """Test copy_hcs_data_to_s3_bucket.""" + upload = utils.copy_hcs_data_to_s3_bucket("request_id", "path", "filename", "data") + self.assertEqual(upload, None) + + with patch("masu.util.aws.common.settings", ENABLE_S3_ARCHIVING=True): + with patch("masu.util.aws.common.get_s3_resource") as mock_s3: + upload = utils.copy_hcs_data_to_s3_bucket("request_id", "path", "filename", "data") + self.assertIsNotNone(upload) + + with patch("masu.util.aws.common.settings", ENABLE_S3_ARCHIVING=True): + with patch("masu.util.aws.common.get_s3_resource") as mock_s3: + mock_s3.side_effect = ClientError({}, "Error") + upload = utils.copy_hcs_data_to_s3_bucket("request_id", "path", "filename", "data") + self.assertEqual(upload, None) + def test_aws_post_processor(self): """Test that missing columns in a report end up in the data frame.""" column_one = "column_one" diff --git a/koku/masu/util/aws/common.py b/koku/masu/util/aws/common.py index 8a8ae95902..7625561ee2 
diff --git a/koku/masu/util/aws/common.py b/koku/masu/util/aws/common.py
index 8a8ae95902..7625561ee2 100644
--- a/koku/masu/util/aws/common.py
+++ b/koku/masu/util/aws/common.py
@@ -351,6 +351,43 @@ def copy_local_report_file_to_s3_bucket(
         copy_data_to_s3_bucket(request_id, s3_path, local_filename, fin, manifest_id, context)
 
 
+def copy_hcs_data_to_s3_bucket(request_id, path, filename, data, finalize=False, context={}):
+    """
+    Copies HCS data to the s3 bucket location
+    """
+    if not (
+        settings.ENABLE_S3_ARCHIVING
+        or enable_trino_processing(context.get("provider_uuid"), context.get("provider_type"), context.get("account"))
+    ):
+        return None
+
+    upload = None
+    upload_key = f"{path}/{filename}"
+    extra_args = {"Metadata": {"finalized": str(finalize)}}
+
+    try:
+        s3_resource = get_s3_resource()
+        s3_obj = {"bucket_name": settings.S3_BUCKET_NAME, "key": upload_key}
+        upload = s3_resource.Object(**s3_obj)
+        upload.upload_fileobj(data, ExtraArgs=extra_args)
+    except (EndpointConnectionError, ClientError) as err:
+        msg = f"Unable to copy data to {upload_key} in bucket {settings.S3_BUCKET_NAME}. Reason: {str(err)}"
+        LOG.info(log_json(request_id, msg, context))
+    return upload
+
+
+def copy_local_hcs_report_file_to_s3_bucket(
+    request_id, s3_path, full_file_path, local_filename, finalize=False, context={}
+):
+    """
+    Copies a local HCS report file to the s3 bucket
+    """
+    if s3_path and settings.ENABLE_S3_ARCHIVING:
+        LOG.info(f"copy_local_hcs_report_file_to_s3_bucket: {s3_path} {full_file_path}")
+        with open(full_file_path, "rb") as fin:
+            copy_hcs_data_to_s3_bucket(request_id, s3_path, local_filename, fin, finalize, context)
+
+
 def remove_files_not_in_set_from_s3_bucket(request_id, s3_path, manifest_id, context={}):
     """
     Removes all files in a given prefix if they are not within the given set.
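The finalized flag travels with the uploaded object as S3 user metadata rather than inside the file body. A hedged boto3 sketch of writing and reading it back; the bucket and key names are placeholders:

    import io
    import boto3

    s3 = boto3.resource("s3")
    obj = s3.Object("demo-bucket", "hcs/2022-02-01.csv")
    obj.upload_fileobj(io.BytesIO(b"usage_start,unblended_cost\n"), ExtraArgs={"Metadata": {"finalized": "True"}})
    # S3 returns user metadata keys lower-cased
    print(s3.Object("demo-bucket", "hcs/2022-02-01.csv").metadata)  # {'finalized': 'True'}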