diff --git a/nise/__init__.py b/nise/__init__.py index d3efa6a0..e3c44d1d 100644 --- a/nise/__init__.py +++ b/nise/__init__.py @@ -1,3 +1,3 @@ -__version__ = "4.4.11" +__version__ = "4.4.12" VERSION = __version__.split(".") diff --git a/nise/__main__.py b/nise/__main__.py index 4db5ac22..d9cbb3fb 100644 --- a/nise/__main__.py +++ b/nise/__main__.py @@ -742,7 +742,7 @@ def calculate_start_date(start_date): elif start_date == "today": generated_start_date = today().replace(hour=0, minute=0, second=0) elif start_date and isinstance(start_date, datetime.date): - generated_start_date = start_date + generated_start_date = datetime.datetime(start_date.year, start_date.month, start_date.day) elif start_date: generated_start_date = date_parser.parse(start_date) else: @@ -760,7 +760,7 @@ def calculate_end_date(start_date, end_date): elif end_date == "today": generated_end_date = today().replace(hour=0, minute=0, second=0) elif end_date and isinstance(end_date, datetime.date): - generated_end_date = end_date + generated_end_date = datetime.datetime(end_date.year, end_date.month, end_date.day) else: generated_end_date = date_parser.parse(end_date) except TypeError: diff --git a/nise/generators/aws/aws_generator.py b/nise/generators/aws/aws_generator.py index e8dd61ee..e39cf926 100644 --- a/nise/generators/aws/aws_generator.py +++ b/nise/generators/aws/aws_generator.py @@ -333,6 +333,12 @@ def _get_location(self): location = choice(REGIONS) return location + def _get_legal_entity(self): + """Pick legal entity.""" + if self.attributes and self.attributes.get("legal_entity"): + return self.attributes.get("legal_entity") + return "Amazon Web Services, Inc." + def _add_common_usage_info(self, row, start, end, **kwargs): """Add common usage information.""" row["lineItem/UsageAccountId"] = choice(self.usage_accounts) @@ -340,7 +346,7 @@ def _add_common_usage_info(self, row, start, end, **kwargs): row["lineItem/UsageStartDate"] = start row["lineItem/UsageEndDate"] = end row["lineItem/CurrencyCode"] = self.currency - row["lineItem/LegalEntity"] = "Amazon Web Services, Inc." + row["lineItem/LegalEntity"] = self._get_legal_entity() return row def _add_tag_data(self, row): diff --git a/tests/test_aws_generator.py b/tests/test_aws_generator.py index d1c34e70..470f2459 100644 --- a/tests/test_aws_generator.py +++ b/tests/test_aws_generator.py @@ -156,6 +156,21 @@ def test_get_location(self): location = generator._get_location() self.assertIn("us-west-1", location) + def test_get_legal_entity(self): + """Test the _get_legal_entity method.""" + two_hours_ago = (self.now - self.one_hour) - self.one_hour + generator = TestGenerator(two_hours_ago, self.now, self.currency, self.payer_account, self.usage_accounts) + legal_entity = generator._get_legal_entity() + self.assertEqual(legal_entity, "Amazon Web Services, Inc.") + + attributes = {} + attributes["legal_entity"] = "Corey" + generator = TestGenerator( + two_hours_ago, self.now, self.currency, self.payer_account, self.usage_accounts, attributes + ) + legal_entity = generator._get_legal_entity() + self.assertEqual(legal_entity, "Corey") + class AWSGeneratorTestCase(TestCase): """Test Base for specific generator classes."""