From 7c4b8b53d206c4a4348146705f7045b1d84a7ee7 Mon Sep 17 00:00:00 2001 From: Carise F Date: Tue, 25 Apr 2017 17:54:17 -0700 Subject: [PATCH] Import scanner violations into Cloud SQL (#239) --- .gitignore | 4 + .../cloud/security/common/data_access/dao.py | 5 +- .../security/common/data_access/group_dao.py | 3 - .../common/data_access/organization_dao.py | 3 - .../common/data_access/project_dao.py | 3 - .../data_access/sql_queries/create_tables.py | 22 +- .../data_access/sql_queries/load_data.py | 7 + .../common/data_access/violation_dao.py | 123 ++++++++++ .../email_templates/scanner_summary.jinja | 110 +++------ .../security/common/gcp_type/iam_policy.py | 2 +- .../cloud/security/common/gcp_type/project.py | 1 - .../cloud/security/common/util/file_loader.py | 4 +- .../load_org_iam_policies_pipeline.py | 15 -- .../inventory/pipelines/load_orgs_pipeline.py | 15 -- .../load_projects_iam_policies_pipeline.py | 15 -- .../pipelines/load_projects_pipeline.py | 15 -- .../scanner/audit/base_rules_engine.py | 34 --- .../scanner/audit/group_rules_engine.py | 11 +- .../scanner/audit/org_rules_engine.py | 136 +++--------- google/cloud/security/scanner/audit/rules.py | 129 +++++++++++ google/cloud/security/scanner/scanner.py | 77 +++++-- tests/common/data_access/dao_test.py | 1 + tests/common/data_access/group_dao_test.py | 1 + .../data_access/organization_dao_test.py | 1 + tests/common/data_access/project_dao_test.py | 1 + .../common/data_access/violation_dao_test.py | 210 ++++++++++++++++++ tests/scanner/audit/base_rules_engine_test.py | 23 +- tests/scanner/audit/org_rules_engine_test.py | 84 ++++--- tests/scanner/scanner_test.py | 56 +++-- 29 files changed, 723 insertions(+), 388 deletions(-) create mode 100644 google/cloud/security/common/data_access/violation_dao.py create mode 100644 google/cloud/security/scanner/audit/rules.py create mode 100644 tests/common/data_access/violation_dao_test.py diff --git a/.gitignore b/.gitignore index 05b53944ba..0b6ffcd1c5 100644 --- 
a/.gitignore +++ b/.gitignore @@ -28,3 +28,7 @@ deployment-templates/*.yaml build/ dist/ out/ + +# Coverage +.coverage +htmlcov/ diff --git a/google/cloud/security/common/data_access/dao.py b/google/cloud/security/common/data_access/dao.py index 583dfa64fe..acd8b8cd4e 100755 --- a/google/cloud/security/common/data_access/dao.py +++ b/google/cloud/security/common/data_access/dao.py @@ -41,6 +41,8 @@ 'raw_project_iam_policies': create_tables.CREATE_RAW_PROJECT_IAM_POLICIES_TABLE, 'raw_org_iam_policies': create_tables.CREATE_RAW_ORG_IAM_POLICIES_TABLE, + + 'violations': create_tables.CREATE_VIOLATIONS_TABLE, } SNAPSHOT_FILTER_CLAUSE = ' where status in ({})' @@ -49,9 +51,6 @@ class Dao(_db_connector.DbConnector): """Data access object (DAO).""" - def __init__(self): - super(Dao, self).__init__() - def _create_snapshot_table(self, resource_name, timestamp): """Creates a snapshot table. diff --git a/google/cloud/security/common/data_access/group_dao.py b/google/cloud/security/common/data_access/group_dao.py index 3a5bedb6b6..075734476e 100755 --- a/google/cloud/security/common/data_access/group_dao.py +++ b/google/cloud/security/common/data_access/group_dao.py @@ -25,9 +25,6 @@ class GroupDao(dao.Dao): """Data access object (DAO) for Groups.""" - def __init__(self): - super(GroupDao, self).__init__() - def get_group_users(self, resource_name, timestamp): """Get the group members who are users. diff --git a/google/cloud/security/common/data_access/organization_dao.py b/google/cloud/security/common/data_access/organization_dao.py index 7a69af53ca..bfa0c40c23 100755 --- a/google/cloud/security/common/data_access/organization_dao.py +++ b/google/cloud/security/common/data_access/organization_dao.py @@ -35,9 +35,6 @@ class OrganizationDao(dao.Dao): """Data access object (DAO) for Organizations.""" - def __init__(self): - super(OrganizationDao, self).__init__() - def get_organizations(self, resource_name, timestamp): """Get organizations from snapshot table. 
diff --git a/google/cloud/security/common/data_access/project_dao.py b/google/cloud/security/common/data_access/project_dao.py index 6e6c433bd3..c132823bdd 100755 --- a/google/cloud/security/common/data_access/project_dao.py +++ b/google/cloud/security/common/data_access/project_dao.py @@ -36,9 +36,6 @@ class ProjectDao(dao.Dao): """Data access object (DAO).""" - def __init__(self): - super(ProjectDao, self).__init__() - def get_project_numbers(self, resource_name, timestamp): """Select the project numbers from a projects snapshot table. diff --git a/google/cloud/security/common/data_access/sql_queries/create_tables.py b/google/cloud/security/common/data_access/sql_queries/create_tables.py index a0495b2cc1..53f0d425f1 100755 --- a/google/cloud/security/common/data_access/sql_queries/create_tables.py +++ b/google/cloud/security/common/data_access/sql_queries/create_tables.py @@ -20,8 +20,8 @@ `project_number` bigint(20) NOT NULL, `project_id` varchar(255) NOT NULL, `project_name` varchar(255) DEFAULT NULL, - `lifecycle_state` enum('ACTIVE','DELETE_REQUESTED', - 'DELETE_IN_PROGRESS','DELETED') DEFAULT NULL, + `lifecycle_state` enum('LIFECYCLE_STATE_UNSPECIFIED','ACTIVE', + 'DELETE_REQUESTED','DELETED') NOT NULL, `parent_type` varchar(255) DEFAULT NULL, `parent_id` varchar(255) DEFAULT NULL, `raw_project` json DEFAULT NULL, @@ -58,8 +58,8 @@ `org_id` bigint(20) unsigned NOT NULL, `name` varchar(255) NOT NULL, `display_name` varchar(255) DEFAULT NULL, - `lifecycle_state` enum('ACTIVE','DELETE_REQUESTED', - 'DELETED','LIFECYCLE_STATE_UNSPECIFIED') DEFAULT NULL, + `lifecycle_state` enum('LIFECYCLE_STATE_UNSPECIFIED','ACTIVE', + 'DELETE_REQUESTED', 'DELETED') NOT NULL, `raw_org` json DEFAULT NULL, `creation_time` datetime DEFAULT NULL, PRIMARY KEY (`org_id`) @@ -115,3 +115,17 @@ """ # TODO: Add a RAW_GROUP_MEMBERS_TABLE. 
+ +CREATE_VIOLATIONS_TABLE = """ + CREATE TABLE `{0}` ( + `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT, + `resource_type` varchar(255) NOT NULL, + `resource_id` varchar(255) NOT NULL, + `rule_name` varchar(255) DEFAULT NULL, + `rule_index` int DEFAULT NULL, + `violation_type` enum('UNSPECIFIED','ADDED','REMOVED') NOT NULL, + `role` varchar(255) DEFAULT NULL, + `member` varchar(255) DEFAULT NULL, + PRIMARY KEY (`id`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8; +""" diff --git a/google/cloud/security/common/data_access/sql_queries/load_data.py b/google/cloud/security/common/data_access/sql_queries/load_data.py index f4b017a6a5..373c8033ce 100755 --- a/google/cloud/security/common/data_access/sql_queries/load_data.py +++ b/google/cloud/security/common/data_access/sql_queries/load_data.py @@ -19,3 +19,10 @@ INTO TABLE {1} FIELDS TERMINATED BY ',' ({2}); """ + +INSERT_VIOLATION = """ + INSERT INTO {0} + (resource_type, resource_id, rule_name, rule_index, + violation_type, role, member) + VALUES (%s, %s, %s, %s, %s, %s, %s) +""" diff --git a/google/cloud/security/common/data_access/violation_dao.py b/google/cloud/security/common/data_access/violation_dao.py new file mode 100644 index 0000000000..5ad373f4d7 --- /dev/null +++ b/google/cloud/security/common/data_access/violation_dao.py @@ -0,0 +1,123 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Provides the data access object (DAO) for rule violations.""" + +import MySQLdb + +from google.cloud.security.common.data_access import dao +from google.cloud.security.common.data_access import errors as db_errors +from google.cloud.security.common.data_access.sql_queries import load_data +from google.cloud.security.common.util import log_util + +LOGGER = log_util.get_logger(__name__) + + +class ViolationDao(dao.Dao): + """Data access object (DAO) for rule violations.""" + + RESOURCE_NAME = 'violations' + + def insert_violations(self, violations, snapshot_timestamp=None): + """Import violations into database. + + Args: + violations: An iterator of RuleViolations. + snapshot_timestamp: The snapshot timestamp to associate these + violations with. + + Returns: + A tuple of (int, list) containing the count of inserted rows and + a list of violations that encountered an error during insert. + + Raises: + MySQLError if snapshot table could not be created. + """ + + try: + # Make sure to have a reasonable timestamp to use. + if not snapshot_timestamp: + snapshot_timestamp = self.get_latest_snapshot_timestamp( + ('PARTIAL_SUCCESS', 'SUCCESS')) + + # Create the violations snapshot table. + snapshot_table = self._create_snapshot_table( + self.RESOURCE_NAME, snapshot_timestamp) + except MySQLdb.Error, e: + raise db_errors.MySQLError(self.RESOURCE_NAME, e) + + inserted_rows = 0 + violation_errors = [] + for violation in violations: + for formatted_violation in _format_violation(violation): + try: + self.execute_sql_with_commit( + self.RESOURCE_NAME, + load_data.INSERT_VIOLATION.format(snapshot_table), + formatted_violation) + inserted_rows += 1 + except MySQLdb.Error, e: + LOGGER.error('Unable to insert violation %s due to %s', + formatted_violation, e) + violation_errors.append(formatted_violation) + + return (inserted_rows, violation_errors) + + +def _format_violation(violation): + """Format the violation data into a tuple. 
+ + Also flattens the RuleViolation, since it consists of the resource, + rule, and members that don't meet the rule criteria. + + Various properties of RuleViolation may also have values that exceed the + declared column length, so truncate as necessary to prevent MySQL errors. + + Args: + violation: The RuleViolation. + + Yields: + A tuple of the rule violation properties. + """ + + resource_type = violation.resource_type + if resource_type: + resource_type = resource_type[:255] + + resource_id = violation.resource_id + if resource_id: + resource_id = str(resource_id)[:255] + + rule_name = violation.rule_name + if rule_name: + rule_name = rule_name[:255] + + role = violation.role + if role: + role = role[:255] + + iam_members = violation.members + if iam_members: + members = [str(iam_member)[:255] for iam_member in iam_members] + else: + members = [] + + for member in members: + yield (resource_type, + resource_id, + rule_name, + violation.rule_index, + violation.violation_type, + role, + member) diff --git a/google/cloud/security/common/email_templates/scanner_summary.jinja b/google/cloud/security/common/email_templates/scanner_summary.jinja index e9679fc332..3a535d0d2e 100644 --- a/google/cloud/security/common/email_templates/scanner_summary.jinja +++ b/google/cloud/security/common/email_templates/scanner_summary.jinja @@ -21,59 +21,10 @@ line-height: 16px; } -body { - background-color: #eee; -} - -.content { - background-color: #fff; - margin-top: 40px; - padding: 20px; - margin-left: 20%; - margin-right: 20%; -} - -.message { - margin: 10px 10px 15px 10px; -} - -.summary .header { - font-weight: bold; - font-size: 16px; - margin-bottom: 5px; -} - -.summary .footer { - font-style: italic; - font-size: 14px; - margin-bottom: 5px; -} - -.diff-summary { - font-size: 14px; -} - -.diff-summary > * { - margin: 5px 0; -} - a, a:visited { color: #1082d9; } -.summary { - margin: 10px 10px; -} - -.resource { - margin: 20px 10px; -} - -.resource-violations { - 
border-collapse: collapse; - border-spacing: 0 -} - .resource-violations tr > td { border: 1px solid #ddd; border: 1px solid #ddd; @@ -83,45 +34,29 @@ a, a:visited { padding: 4px; } -.resource-violations tr:first-child > th { - border-bottom: 1px solid #ddd; -} - -.resource-violations tr:last-child > td { - border-bottom: 1px solid #ddd; -} - -.resource-violations th { +th { font-size: 16px; font-weight: bold; padding: 4px; text-align: left; } -.resource-violations th.left { - width: 80%; -} - -.resource-violations th.left > a { - font-weight: normal; -} - -.resource-violations td.numeric { - padding-left: 18px; +td { + padding: 4px; } - -
-
+ +
+
Forseti Security found some issues during the scan on {{ scan_date }}.
-
-
+
+
Resource Violations:
-
+
{% for (resource_type, summary) in resource_summaries.iteritems() %}
{{ summary['pluralized_resource_type'] }}: {{ summary['violations']|length }} @@ -130,23 +65,34 @@ a, a:visited {
{% endfor %}
- -
+
+ +
{% for (resource_type, summary) in resource_summaries.iteritems() %} -
- +
+
- - + + {% for resource_id, count in summary['violations'].iteritems() %} - - + + {% endfor %}
{{ resource_type.title() }}# Violations + {{ resource_type.title() }} + + # Violations +
{{ resource_id}}{{ count }}{{ resource_id}}{{ count }}
diff --git a/google/cloud/security/common/gcp_type/iam_policy.py b/google/cloud/security/common/gcp_type/iam_policy.py index eaab5296d1..eeabe43dad 100755 --- a/google/cloud/security/common/gcp_type/iam_policy.py +++ b/google/cloud/security/common/gcp_type/iam_policy.py @@ -199,7 +199,7 @@ def __hash__(self): def __repr__(self): """String representation of IamPolicyMember.""" - return 'IamMember '.format(self.type, self.name) + return '%s:%s' % (self.type, self.name) def _member_type_exists(self, member_type): """Determine if the member type exists in valid member types.""" diff --git a/google/cloud/security/common/gcp_type/project.py b/google/cloud/security/common/gcp_type/project.py index eeba6bd016..c2af2c8aec 100755 --- a/google/cloud/security/common/gcp_type/project.py +++ b/google/cloud/security/common/gcp_type/project.py @@ -25,7 +25,6 @@ class ProjectLifecycleState(resource.LifecycleState): """Project lifecycle state.""" DELETE_REQUESTED = 'DELETE_REQUESTED' - DELETE_IN_PROGRESS = 'DELETE_IN_PROGRESS' class Project(resource.Resource): diff --git a/google/cloud/security/common/util/file_loader.py b/google/cloud/security/common/util/file_loader.py index f4cfcbf1b4..e861cb73d8 100755 --- a/google/cloud/security/common/util/file_loader.py +++ b/google/cloud/security/common/util/file_loader.py @@ -36,8 +36,8 @@ def read_and_parse_file(file_path): if file_path.startswith('gs://'): return _read_file_from_gcs(file_path) - else: - return _read_file_from_local(file_path) + + return _read_file_from_local(file_path) def _get_filetype_parser(file_path, parser_type): diff --git a/google/cloud/security/inventory/pipelines/load_org_iam_policies_pipeline.py b/google/cloud/security/inventory/pipelines/load_org_iam_policies_pipeline.py index 280ca62258..d30dcdebda 100755 --- a/google/cloud/security/inventory/pipelines/load_org_iam_policies_pipeline.py +++ b/google/cloud/security/inventory/pipelines/load_org_iam_policies_pipeline.py @@ -35,21 +35,6 @@ class 
LoadOrgIamPoliciesPipeline(base_pipeline.BasePipeline): RESOURCE_NAME = 'org_iam_policies' RAW_RESOURCE_NAME = 'raw_org_iam_policies' - def __init__(self, cycle_timestamp, configs, crm_client, dao): - """Constructor for the data pipeline. - - Args: - cycle_timestamp: String of timestamp, formatted as YYYYMMDDTHHMMSSZ. - configs: Dictionary of configurations. - crm_client: CRM API client. - dao: Data access object. - - Returns: - None - """ - super(LoadOrgIamPoliciesPipeline, self).__init__( - cycle_timestamp, configs, crm_client, dao) - def _transform(self, iam_policies): """Yield an iterator of loadable iam policies. diff --git a/google/cloud/security/inventory/pipelines/load_orgs_pipeline.py b/google/cloud/security/inventory/pipelines/load_orgs_pipeline.py index 3c3e6f3506..a30f02f3a5 100755 --- a/google/cloud/security/inventory/pipelines/load_orgs_pipeline.py +++ b/google/cloud/security/inventory/pipelines/load_orgs_pipeline.py @@ -33,21 +33,6 @@ class LoadOrgsPipeline(base_pipeline.BasePipeline): MYSQL_DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S' - def __init__(self, cycle_timestamp, configs, crm_client, dao): - """Constructor for the data pipeline. - - Args: - cycle_timestamp: String of timestamp, formatted as YYYYMMDDTHHMMSSZ. - configs: Dictionary of configurations. - crm_client: CRM API client. - dao: Data access object. - - Returns: - None - """ - super(LoadOrgsPipeline, self).__init__( - cycle_timestamp, configs, crm_client, dao) - def _transform(self, orgs): """Yield an iterator of loadable iam policies. 
diff --git a/google/cloud/security/inventory/pipelines/load_projects_iam_policies_pipeline.py b/google/cloud/security/inventory/pipelines/load_projects_iam_policies_pipeline.py index 8f7be2505a..690c46389f 100755 --- a/google/cloud/security/inventory/pipelines/load_projects_iam_policies_pipeline.py +++ b/google/cloud/security/inventory/pipelines/load_projects_iam_policies_pipeline.py @@ -36,21 +36,6 @@ class LoadProjectsIamPoliciesPipeline(base_pipeline.BasePipeline): RESOURCE_NAME = 'project_iam_policies' RAW_RESOURCE_NAME = 'raw_project_iam_policies' - def __init__(self, cycle_timestamp, configs, crm_client, dao): - """Constructor for the data pipeline. - - Args: - cycle_timestamp: String of timestamp, formatted as YYYYMMDDTHHMMSSZ. - configs: Dictionary of configurations. - crm_client: CRM API client. - dao: Data access object. - - Returns: - None - """ - super(LoadProjectsIamPoliciesPipeline, self).__init__( - cycle_timestamp, configs, crm_client, dao) - def _transform(self, iam_policy_maps): """Yield an iterator of loadable iam policies. diff --git a/google/cloud/security/inventory/pipelines/load_projects_pipeline.py b/google/cloud/security/inventory/pipelines/load_projects_pipeline.py index 47720b7958..b340bf8e42 100755 --- a/google/cloud/security/inventory/pipelines/load_projects_pipeline.py +++ b/google/cloud/security/inventory/pipelines/load_projects_pipeline.py @@ -34,21 +34,6 @@ class LoadProjectsPipeline(base_pipeline.BasePipeline): MYSQL_DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S' - def __init__(self, cycle_timestamp, configs, crm_client, dao): - """Constructor for the data pipeline. - - Args: - cycle_timestamp: String of timestamp, formatted as YYYYMMDDTHHMMSSZ. - configs: Dictionary of configurations. - crm_client: CRM API client. - dao: Data access object. - - Returns: - None - """ - super(LoadProjectsPipeline, self).__init__( - cycle_timestamp, configs, crm_client, dao) - def _transform(self, projects): """Yield an iterator of loadable iam policies. 
diff --git a/google/cloud/security/scanner/audit/base_rules_engine.py b/google/cloud/security/scanner/audit/base_rules_engine.py index 3847f1bdde..02f1e15c74 100644 --- a/google/cloud/security/scanner/audit/base_rules_engine.py +++ b/google/cloud/security/scanner/audit/base_rules_engine.py @@ -71,37 +71,3 @@ class BaseRuleBook(object): def add_rule(self, rule_def, rule_index): """Add rule to rule book.""" raise NotImplementedError('Implement add_rule() in subclass') - -class RuleAppliesTo(object): - """What the rule applies to. (Default: SELF) """ - - SELF = 'self' - CHILDREN = 'children' - SELF_AND_CHILDREN = 'self_and_children' - apply_types = frozenset([SELF, CHILDREN, SELF_AND_CHILDREN]) - - @classmethod - def verify(cls, applies_to): - """Verify whether the applies_to is valid.""" - if applies_to not in cls.apply_types: - raise audit_errors.InvalidRulesSchemaError( - 'Invalid applies_to: {}'.format(applies_to)) - return applies_to - - -class RuleMode(object): - """The rule mode.""" - - WHITELIST = 'whitelist' - BLACKLIST = 'blacklist' - REQUIRED = 'required' - - modes = frozenset([WHITELIST, BLACKLIST, REQUIRED]) - - @classmethod - def verify(cls, mode): - """Verify whether the mode is valid.""" - if mode not in cls.modes: - raise audit_errors.InvalidRulesSchemaError( - 'Invalid rule mode: {}'.format(mode)) - return mode diff --git a/google/cloud/security/scanner/audit/group_rules_engine.py b/google/cloud/security/scanner/audit/group_rules_engine.py index 0b19db2acb..8e9bb7897b 100644 --- a/google/cloud/security/scanner/audit/group_rules_engine.py +++ b/google/cloud/security/scanner/audit/group_rules_engine.py @@ -32,6 +32,7 @@ from google.cloud.security.common.gcp_type.resource_util import ResourceUtil from google.cloud.security.common.util import log_util from google.cloud.security.scanner.audit import base_rules_engine as bre +from google.cloud.security.scanner.audit import rules as audit_rules from google.cloud.security.scanner.audit import errors as 
audit_errors @@ -214,7 +215,7 @@ class ResourceRules(object): def __init__(self, resource=None, rules=None, - applies_to=bre.RuleAppliesTo.SELF, + applies_to=audit_rules.RuleAppliesTo.SELF, inherit_from_parents=False): """Initialize. @@ -230,11 +231,11 @@ def __init__(self, rules = set([]) self.resource = resource self.rules = rules - self.applies_to = bre.RuleAppliesTo.verify(applies_to) + self.applies_to = audit_rules.RuleAppliesTo.verify(applies_to) self.inherit_from_parents = inherit_from_parents self._rule_mode_methods = { - bre.RuleMode.WHITELIST: _check_whitelist_members, - bre.RuleMode.BLACKLIST: _check_blacklist_members, - bre.RuleMode.REQUIRED: _check_required_members, + audit_rules.RuleMode.WHITELIST: _check_whitelist_members, + audit_rules.RuleMode.BLACKLIST: _check_blacklist_members, + audit_rules.RuleMode.REQUIRED: _check_required_members, } diff --git a/google/cloud/security/scanner/audit/org_rules_engine.py b/google/cloud/security/scanner/audit/org_rules_engine.py index 24f438afd5..1be810e6ef 100644 --- a/google/cloud/security/scanner/audit/org_rules_engine.py +++ b/google/cloud/security/scanner/audit/org_rules_engine.py @@ -22,13 +22,13 @@ import itertools import threading -from collections import namedtuple from google.cloud.security.common.gcp_type import errors as resource_errors from google.cloud.security.common.gcp_type.iam_policy import IamPolicyBinding from google.cloud.security.common.gcp_type.resource import ResourceType from google.cloud.security.common.gcp_type.resource_util import ResourceUtil from google.cloud.security.common.util import log_util from google.cloud.security.scanner.audit import base_rules_engine as bre +from google.cloud.security.scanner.audit import rules as scanner_rules from google.cloud.security.scanner.audit import errors as audit_errors LOGGER = log_util.get_logger(__name__) @@ -291,10 +291,10 @@ def add_rule(self, rule_def, rule_index): rule_bindings = [ IamPolicyBinding.create_from(b) for b in 
rule_def.get('bindings')] - rule = Rule(rule_name=rule_def.get('name'), - rule_index=rule_index, - bindings=rule_bindings, - mode=rule_def.get('mode')) + rule = scanner_rules.Rule(rule_name=rule_def.get('name'), + rule_index=rule_index, + bindings=rule_bindings, + mode=rule_def.get('mode')) rule_applies_to = resource.get('applies_to') rule_key = (gcp_resource, rule_applies_to) @@ -329,7 +329,7 @@ def _get_resource_rules(self, resource): """ resource_rules = [] - for rule_applies_to in bre.RuleAppliesTo.apply_types: + for rule_applies_to in scanner_rules.RuleAppliesTo.apply_types: if (resource, rule_applies_to) in self.resource_rules_map: resource_rules.append(self.resource_rules_map.get( (resource, rule_applies_to))) @@ -363,13 +363,22 @@ def find_violations(self, resource, policy_binding): # SELF: check rules if the starting resource == current resource # CHILDREN: check rules if starting resource != current resource # SELF_AND_CHILDREN: always check rules + applies_to_self = ( + resource_rule.applies_to == + scanner_rules.RuleAppliesTo.SELF and + resource == curr_resource) + applies_to_children = ( + resource_rule.applies_to == + scanner_rules.RuleAppliesTo.CHILDREN and + resource != curr_resource) + applies_to_both = ( + resource_rule.applies_to == + scanner_rules.RuleAppliesTo.SELF_AND_CHILDREN) + rule_applies_to_resource = ( - (resource_rule.applies_to == bre.RuleAppliesTo.SELF and - resource == curr_resource) or - (resource_rule.applies_to == bre.RuleAppliesTo.CHILDREN and - resource != curr_resource) or - (resource_rule.applies_to == - bre.RuleAppliesTo.SELF_AND_CHILDREN)) + applies_to_self or + applies_to_children or + applies_to_both) if not rule_applies_to_resource: continue @@ -397,7 +406,7 @@ class ResourceRules(object): def __init__(self, resource=None, rules=None, - applies_to=bre.RuleAppliesTo.SELF, + applies_to=scanner_rules.RuleAppliesTo.SELF, inherit_from_parents=False): """Initialize. 
@@ -413,13 +422,13 @@ def __init__(self, rules = set([]) self.resource = resource self.rules = rules - self.applies_to = bre.RuleAppliesTo.verify(applies_to) + self.applies_to = scanner_rules.RuleAppliesTo.verify(applies_to) self.inherit_from_parents = inherit_from_parents self._rule_mode_methods = { - bre.RuleMode.WHITELIST: _check_whitelist_members, - bre.RuleMode.BLACKLIST: _check_blacklist_members, - bre.RuleMode.REQUIRED: _check_required_members, + scanner_rules.RuleMode.WHITELIST: _check_whitelist_members, + scanner_rules.RuleMode.BLACKLIST: _check_blacklist_members, + scanner_rules.RuleMode.REQUIRED: _check_required_members, } def __eq__(self, other): @@ -471,7 +480,7 @@ def find_mismatches(self, policy_resource, binding_to_match): # pattern, then check the members to see whether they match, # according to the rule mode. if binding.role_pattern.match(policy_role_name): - if rule.mode == bre.RuleMode.REQUIRED: + if rule.mode == scanner_rules.RuleMode.REQUIRED: role_name = binding.role_name else: role_name = policy_role_name @@ -481,26 +490,28 @@ def find_mismatches(self, policy_resource, binding_to_match): rule_members=binding.members, policy_members=policy_binding.members)) if violating_members: - yield RuleViolation( + yield scanner_rules.RuleViolation( resource_type=policy_resource.type, resource_id=policy_resource.id, rule_name=rule.rule_name, rule_index=rule.rule_index, - violation_type=RULE_VIOLATION_TYPE.get( - rule.mode, RULE_VIOLATION_TYPE['UNSPECIFIED']), + violation_type=scanner_rules.VIOLATION_TYPE.get( + rule.mode, + scanner_rules.VIOLATION_TYPE['UNSPECIFIED']), role=role_name, members=tuple(violating_members)) # Extra check if the role did not match in the REQUIRED case. 
- if not found_role and rule.mode == bre.RuleMode.REQUIRED: + if not found_role and rule.mode == scanner_rules.RuleMode.REQUIRED: for binding in rule.bindings: - yield RuleViolation( + yield scanner_rules.RuleViolation( resource_type=policy_resource.type, resource_id=policy_resource.id, rule_name=rule.rule_name, rule_index=rule.rule_index, - violation_type=RULE_VIOLATION_TYPE.get( - rule.mode, RULE_VIOLATION_TYPE['UNSPECIFIED']), + violation_type=scanner_rules.VIOLATION_TYPE.get( + rule.mode, + scanner_rules.VIOLATION_TYPE['UNSPECIFIED']), role=binding.role_name, members=tuple(binding.members)) @@ -518,78 +529,3 @@ def _dispatch_rule_mode_check(self, mode, rule_members=None, return self._rule_mode_methods[mode]( rule_members=rule_members, policy_members=policy_members) - - -# pylint: disable=too-few-public-methods -class Rule(object): - """Encapsulate Rule properties from the rule definition file. - - The reason this is not a named tuple is that it needs to be hashable. - The ResourceRules class has a set of Rules. - """ - - def __init__(self, rule_name, rule_index, bindings, mode=None): - """Initialize. - - Args: - rule_name: The string name of the rule. - rule_index: The rule's index in the rules file. - bindings: The list of IamPolicyBindings for this rule. - mode: The RulesMode for this rule. - """ - self.rule_name = rule_name - self.rule_index = rule_index - self.bindings = bindings - self.mode = bre.RuleMode.verify(mode) - - def __eq__(self, other): - """Test whether Rule equals other Rule.""" - if not isinstance(other, type(self)): - return NotImplemented - return (self.rule_name == other.rule_name and - self.rule_index == other.rule_index and - self.bindings == other.bindings and - self.mode == other.mode) - - def __ne__(self, other): - """Test whether Rule is not equal to another Rule.""" - return not self == other - - def __hash__(self): - """Make a hash of the rule index. 
- - For now, this will suffice since the rule index is assigned - automatically when the rules map is built, and the scanner - only handles one rule file at a time. Later on, we'll need to - revisit this hash method when we process multiple rule files. - - Returns: - The hash of the rule index. - """ - return hash(self.rule_index) - - def __repr__(self): - """Returns the string representation of this Rule.""" - return 'Rule <{}, name={}, mode={}, bindings={}>'.format( - self.rule_index, self.rule_name, self.mode, self.bindings) - - -# Rule violation. -# resource_type: string -# resource_id: string -# rule_name: string -# rule_index: int -# violation_type: RULE_VIOLATION_TYPE -# role: string -# members: tuple of IamPolicyBindings -RuleViolation = namedtuple('RuleViolation', - ['resource_type', 'resource_id', 'rule_name', - 'rule_index', 'violation_type', 'role', 'members']) - -# Rule violation types. -RULE_VIOLATION_TYPE = { - 'whitelist': 'ADDED', - 'blacklist': 'ADDED', - 'required': 'REMOVED', - 'UNSPECIFIED': 'UNSPECIFIED' -} diff --git a/google/cloud/security/scanner/audit/rules.py b/google/cloud/security/scanner/audit/rules.py new file mode 100644 index 0000000000..6ea5085a80 --- /dev/null +++ b/google/cloud/security/scanner/audit/rules.py @@ -0,0 +1,129 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Rules-related classes.""" + +from collections import namedtuple + +from google.cloud.security.scanner.audit import errors as audit_errors + + +# pylint: disable=too-few-public-methods +class Rule(object): + """Encapsulate Rule properties from the rule definition file. + + The reason this is not a named tuple is that it needs to be hashable. + The ResourceRules class has a set of Rules. + """ + + def __init__(self, rule_name, rule_index, bindings, mode=None): + """Initialize. + + Args: + rule_name: The string name of the rule. + rule_index: The rule's index in the rules file. + bindings: The list of IamPolicyBindings for this rule. + mode: The RuleMode for this rule. + """ + self.rule_name = rule_name + self.rule_index = rule_index + self.bindings = bindings + self.mode = RuleMode.verify(mode) + + def __eq__(self, other): + """Test whether Rule equals other Rule.""" + if not isinstance(other, type(self)): + return NotImplemented + return (self.rule_name == other.rule_name and + self.rule_index == other.rule_index and + self.bindings == other.bindings and + self.mode == other.mode) + + def __ne__(self, other): + """Test whether Rule is not equal to another Rule.""" + return not self == other + + def __hash__(self): + """Make a hash of the rule index. + + For now, this will suffice since the rule index is assigned + automatically when the rules map is built, and the scanner + only handles one rule file at a time. Later on, we'll need to + revisit this hash method when we process multiple rule files. + + Returns: + The hash of the rule index. + """ + return hash(self.rule_index) + + def __repr__(self): + """Returns the string representation of this Rule.""" + return 'Rule <{}, name={}, mode={}, bindings={}>'.format( + self.rule_index, self.rule_name, self.mode, self.bindings) + + +class RuleAppliesTo(object): + """What the rule applies to. 
(Default: SELF) """ + + SELF = 'self' + CHILDREN = 'children' + SELF_AND_CHILDREN = 'self_and_children' + apply_types = frozenset([SELF, CHILDREN, SELF_AND_CHILDREN]) + + @classmethod + def verify(cls, applies_to): + """Verify whether the applies_to is valid.""" + if applies_to not in cls.apply_types: + raise audit_errors.InvalidRulesSchemaError( + 'Invalid applies_to: {}'.format(applies_to)) + return applies_to + + +class RuleMode(object): + """The rule mode.""" + + WHITELIST = 'whitelist' + BLACKLIST = 'blacklist' + REQUIRED = 'required' + + modes = frozenset([WHITELIST, BLACKLIST, REQUIRED]) + + @classmethod + def verify(cls, mode): + """Verify whether the mode is valid.""" + if mode not in cls.modes: + raise audit_errors.InvalidRulesSchemaError( + 'Invalid rule mode: {}'.format(mode)) + return mode + + +# Rule violation. +# resource_type: string +# resource_id: string +# rule_name: string +# rule_index: int +# violation_type: VIOLATION_TYPE +# role: string +# members: tuple of IamPolicyMembers +RuleViolation = namedtuple('RuleViolation', + ['resource_type', 'resource_id', 'rule_name', + 'rule_index', 'violation_type', 'role', 'members']) + +# Rule violation types. +VIOLATION_TYPE = { + 'whitelist': 'ADDED', + 'blacklist': 'ADDED', + 'required': 'REMOVED', + 'UNSPECIFIED': 'UNSPECIFIED' +} diff --git a/google/cloud/security/scanner/scanner.py b/google/cloud/security/scanner/scanner.py index cd6831cc63..73bc92b858 100644 --- a/google/cloud/security/scanner/scanner.py +++ b/google/cloud/security/scanner/scanner.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Organization resource scanner. +"""GCP Resource scanner. 
Usage: @@ -37,10 +37,11 @@ from google.apputils import app from google.cloud.security.common.data_access import csv_writer +from google.cloud.security.common.data_access import dao from google.cloud.security.common.data_access import organization_dao from google.cloud.security.common.data_access import project_dao -from google.cloud.security.common.data_access.dao import Dao -from google.cloud.security.common.data_access.errors import MySQLError +from google.cloud.security.common.data_access import violation_dao +from google.cloud.security.common.data_access import errors as db_errors from google.cloud.security.common.gcp_type.resource import ResourceType from google.cloud.security.common.gcp_type.resource_util import ResourceUtil from google.cloud.security.common.util import log_util @@ -75,6 +76,7 @@ def main(_): """Run the scanner.""" + LOGGER.info('Initializing the rules engine:\nUsing rules: %s', FLAGS.rules) if not FLAGS.rules: @@ -110,7 +112,9 @@ def main(_): ResourceType.ORGANIZATION: len(org_policies), ResourceType.PROJECT: len(project_policies), } - _output_results(all_violations, resource_counts=resource_counts) + _output_results(all_violations, + snapshot_timestamp, + resource_counts=resource_counts) LOGGER.info('Done!') @@ -124,6 +128,7 @@ def _find_violations(policies, rules_engine): Returns: A list of violations. """ + all_violations = [] LOGGER.info('Finding policy violations...') for (resource, policy) in policies: @@ -143,6 +148,7 @@ def _get_output_filename(now_utc): Returns: The output filename for the csv, formatted with the now_utc timestamp. """ + output_timestamp = now_utc.strftime(OUTPUT_TIMESTAMP_FMT) output_filename = SCANNER_OUTPUT_CSV_FMT.format(output_timestamp) return output_filename @@ -153,12 +159,11 @@ def _get_timestamp(statuses=('SUCCESS', 'PARTIAL_SUCCESS')): Returns: The latest snapshot timestamp string. 
""" - dao = None + latest_timestamp = None try: - dao = Dao() - latest_timestamp = dao.get_latest_snapshot_timestamp(statuses) - except MySQLError as err: + latest_timestamp = dao.Dao().get_latest_snapshot_timestamp(statuses) + except db_errors.MySQLError as err: LOGGER.error('Error getting latest snapshot timestamp: %s', err) return latest_timestamp @@ -172,6 +177,7 @@ def _get_org_policies(timestamp): Returns: The org policies. """ + org_policies = {} org_dao = organization_dao.OrganizationDao() org_policies = org_dao.get_org_iam_policies('organizations', timestamp) @@ -186,17 +192,22 @@ def _get_project_policies(timestamp): Returns: The project policies. """ + project_policies = {} - dao = project_dao.ProjectDao() - project_policies = dao.get_project_policies('projects', timestamp) + project_policies = ( + project_dao.ProjectDao().get_project_policies('projects', timestamp)) return project_policies -def _write_violations_output(violations): - """Write violations to csv output file and store in output bucket. +def _flatten_violations(violations): + """Flatten RuleViolations into a dict for each RuleViolation member. Args: - violations: The violations to write to the csv. + violations: The RuleViolations to flatten. + + Yield: + Iterator of RuleViolations as a dict per member. """ + LOGGER.info('Writing violations to csv...') for violation in violations: for member in violation.members: @@ -210,19 +221,36 @@ def _write_violations_output(violations): 'member': '{}:{}'.format(member.type, member.name) } -def _output_results(all_violations, **kwargs): +def _output_results(all_violations, snapshot_timestamp, **kwargs): """Send the output results. Args: all_violations: The list of violations to report. **kwargs: The rest of the args. """ - # Write the CSV. + + # Write violations to database. 
+ (inserted_row_count, violation_errors) = (0, []) + try: + vdao = violation_dao.ViolationDao() + (inserted_row_count, violation_errors) = vdao.insert_violations( + all_violations, snapshot_timestamp=snapshot_timestamp) + except db_errors.MySQLError as err: + LOGGER.error('Error importing violations to database: %s', err) + + # TODO: figure out what to do with the errors. For now, just log it. + LOGGER.debug('Inserted %s rows with %s errors', + inserted_row_count, len(violation_errors)) + + output_csv_name = None + + # Write the CSV for all the violations. with csv_writer.write_csv( resource_name='policy_violations', - data=_write_violations_output(all_violations), + data=_flatten_violations(all_violations), write_header=True) as csv_file: - LOGGER.info('CSV filename: %s', csv_file.name) + output_csv_name = csv_file.name + LOGGER.info('CSV filename: %s', output_csv_name) # Scanner timestamp for output file and email. now_utc = datetime.utcnow() @@ -234,12 +262,14 @@ def _output_results(all_violations, **kwargs): if not os.path.exists(FLAGS.output_path): os.makedirs(output_path) output_path = os.path.abspath(output_path) - _upload_csv(output_path, now_utc, csv_file.name) + _upload_csv(output_path, now_utc, output_csv_name) # Send summary email. if FLAGS.email_recipient is not None: resource_counts = kwargs.get('resource_counts', {}) - _send_email(csv_file.name, now_utc, all_violations, resource_counts) + _send_email(output_csv_name, now_utc, + all_violations, resource_counts, + violation_errors) def _upload_csv(output_path, now_utc, csv_name): """Upload CSV to Cloud Storage. @@ -249,6 +279,7 @@ def _upload_csv(output_path, now_utc, csv_name): now_utc: The UTC timestamp of "now". csv_name: The csv_name. """ + from google.cloud.security.common.gcp_api import storage output_filename = _get_output_filename(now_utc) @@ -268,7 +299,8 @@ def _upload_csv(output_path, now_utc, csv_name): # Otherwise, just copy it to the output path. 
shutil.copy(csv_name, full_output_path) -def _send_email(csv_name, now_utc, all_violations, total_resources): +def _send_email(csv_name, now_utc, all_violations, + total_resources, violation_errors): """Send a summary email of the scan. Args: @@ -276,7 +308,9 @@ def _send_email(csv_name, now_utc, all_violations, total_resources): now_utc: The UTC datetime right now. all_violations: The list of violations. total_resources: A dict of the resources and their count. + violation_errors: Iterable of violation errors. """ + mail_util = EmailUtil(FLAGS.sendgrid_api_key) total_violations, resource_summaries = _build_scan_summary( all_violations, total_resources) @@ -287,6 +321,7 @@ def _send_email(csv_name, now_utc, all_violations, total_resources): 'scanner_summary.jinja', { 'scan_date': scan_date, 'resource_summaries': resource_summaries, + 'violation_errors': violation_errors, }) # Create an attachment out of the csv file and base64 encode the content. @@ -316,6 +351,7 @@ def _build_scan_summary(all_violations, total_resources): Returns: Total counts and summaries. """ + resource_summaries = {} total_violations = 0 # Build a summary of the violations and counts for the email. 
@@ -349,5 +385,6 @@ def _build_scan_summary(all_violations, total_resources): return total_violations, resource_summaries + if __name__ == '__main__': app.run() diff --git a/tests/common/data_access/dao_test.py b/tests/common/data_access/dao_test.py index fcd967e3d2..eeebb1229b 100644 --- a/tests/common/data_access/dao_test.py +++ b/tests/common/data_access/dao_test.py @@ -30,6 +30,7 @@ class DaoTest(basetest.TestCase): @mock.patch.object(_db_connector.DbConnector, '__init__', autospec=True) def setUp(self, mock_db_connector): + mock_db_connector.return_value = None self.dao = dao.Dao() self.fake_timestamp = '12345' self.resource_projects = 'projects' diff --git a/tests/common/data_access/group_dao_test.py b/tests/common/data_access/group_dao_test.py index 33a0fae6d5..281fcccef7 100644 --- a/tests/common/data_access/group_dao_test.py +++ b/tests/common/data_access/group_dao_test.py @@ -50,6 +50,7 @@ class GroupDaoTest(basetest.TestCase): @mock.patch.object(dao.Dao, '__init__', autospec=True) def setUp(self, mock_dao): + mock_dao.return_value = None #self.group_dao = mock.create_autospec(group_dao.GroupDao) self.group_dao = group_dao.GroupDao() self.resource_name = 'groups' diff --git a/tests/common/data_access/organization_dao_test.py b/tests/common/data_access/organization_dao_test.py index 953647b71a..1fad4a99a8 100644 --- a/tests/common/data_access/organization_dao_test.py +++ b/tests/common/data_access/organization_dao_test.py @@ -34,6 +34,7 @@ class OrgDaoTest(basetest.TestCase): @mock.patch.object(_db_connector.DbConnector, '__init__', autospec=True) def setUp(self, mock_db_connector): + mock_db_connector.return_value = None self.org_dao = organization_dao.OrganizationDao() self.resource_name = 'organizations' self.fake_timestamp = '12345' diff --git a/tests/common/data_access/project_dao_test.py b/tests/common/data_access/project_dao_test.py index 4e64ee9369..25058cd8e3 100644 --- a/tests/common/data_access/project_dao_test.py +++ 
b/tests/common/data_access/project_dao_test.py @@ -34,6 +34,7 @@ class ProjectDaoTest(basetest.TestCase): @mock.patch.object(_db_connector.DbConnector, '__init__', autospec=True) def setUp(self, mock_db_connector): + mock_db_connector.return_value = None self.project_dao = project_dao.ProjectDao() self.resource_name = 'projects' self.fake_timestamp = '12345' diff --git a/tests/common/data_access/violation_dao_test.py b/tests/common/data_access/violation_dao_test.py new file mode 100644 index 0000000000..c64b0d453b --- /dev/null +++ b/tests/common/data_access/violation_dao_test.py @@ -0,0 +1,210 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests the Dao.""" + +from google.apputils import basetest +import mock +import MySQLdb + +from google.cloud.security.common.data_access import _db_connector +from google.cloud.security.common.data_access import errors +from google.cloud.security.common.data_access import violation_dao +from google.cloud.security.common.data_access.sql_queries import load_data +from google.cloud.security.common.gcp_type import iam_policy as iam +from google.cloud.security.scanner.audit import rules + + +class ViolationDaoTest(basetest.TestCase): + """Tests for the Dao.""" + + @mock.patch.object(_db_connector.DbConnector, '__init__', autospec=True) + def setUp(self, mock_db_connector): + mock_db_connector.return_value = None + self.dao = violation_dao.ViolationDao() + self.fake_snapshot_timestamp = '12345' + self.fake_table_name = ('%s_%s' % + (self.dao.RESOURCE_NAME, self.fake_snapshot_timestamp)) + self.fake_violations = [ + rules.RuleViolation( + resource_type='x', + resource_id='1', + rule_name='rule name', + rule_index=0, + violation_type='ADDED', + role='roles/editor', + members=[iam.IamPolicyMember.create_from(m) + for m in ['user:a@foo.com', 'user:b@foo.com']], + ), + rules.RuleViolation( + resource_type='%sb' % ('a'*300), + resource_id='1', + rule_name='%sd' % ('c'*300), + rule_index=1, + violation_type='REMOVED', + role='%s' % ('e'*300), + members=[iam.IamPolicyMember.create_from( + 'user:%sh' % ('g'*300))], + ), + ] + + self.expected_fake_violations = [ + ('x', '1', 'rule name', 0, 'ADDED', + 'roles/editor', 'user:a@foo.com'), + ('x', '1', 'rule name', 0, 'ADDED', + 'roles/editor', 'user:b@foo.com'), + ('a'*255, '1', 'c'*255, 1, 'REMOVED', + 'e'*255, ('user:%s' % ('g'*300))[:255]), + ] + + def test_format_violation(self): + """Test that a RuleViolation is formatted and flattened properly. + + Setup: + Create some rule violations: + * With multiple members. + * With really long text values for properties. 
+ + Expect: + _format_violation() will flatten the violation and truncate the + property values accordingly. + """ + + actual = [f for v in self.fake_violations + for f in violation_dao._format_violation(v)] + + self.assertEquals(self.expected_fake_violations, actual) + + def test_insert_violations_no_timestamp(self): + """Test that insert_violations() is properly called. + + Setup: + Create mocks: + * self.dao.conn + * self.dao.conn.commit + * self.dao.get_latest_snapshot_timestamp + * self.dao._create_snapshot_table + + Expect: + * Assert that get_latest_snapshot_timestamp() gets called. + * Assert that _create_snapshot_table() gets called. + * Assert that conn.commit() is called 3x. + was called == # of formatted/flattened RuleViolations). + """ + + conn_mock = mock.MagicMock() + commit_mock = mock.MagicMock() + + self.dao.get_latest_snapshot_timestamp = mock.MagicMock( + return_value = self.fake_snapshot_timestamp) + self.dao._create_snapshot_table = mock.MagicMock( + return_value=self.fake_table_name) + self.dao.conn = conn_mock + self.dao.execute_sql_with_commit = commit_mock + + self.dao.insert_violations(self.fake_violations) + + # Assert snapshot is retrieved because no snapshot timestamp was + # provided to the method call. + self.dao.get_latest_snapshot_timestamp.assert_called_once_with( + ('PARTIAL_SUCCESS', 'SUCCESS')) + + # Assert that the snapshot table was created. + self.dao._create_snapshot_table.assert_called_once_with( + self.dao.RESOURCE_NAME, self.fake_snapshot_timestamp) + + # Assert that conn.commit() was called. + self.assertEqual(3, commit_mock.call_count) + + def test_insert_violations_with_timestamp(self): + """Test that insert_violations() is properly called with timestamp. + + Setup: + * Create fake custom timestamp. + * Create mocks: + * self.dao._create_snapshot_table + * self.dao.get_latest_snapshot_timestamp + * self.dao.conn + + Expect: + * Assert that get_latest_snapshot_timestamp() doesn't get called. 
+ * Assert that _create_snapshot_table() gets called once. + """ + + fake_custom_timestamp = '11111' + self.dao.conn = mock.MagicMock() + self.dao._create_snapshot_table = mock.MagicMock() + self.dao.get_latest_snapshot_timestamp = mock.MagicMock() + self.dao.insert_violations(self.fake_violations, fake_custom_timestamp) + + self.dao.get_latest_snapshot_timestamp.assert_not_called() + self.dao._create_snapshot_table.assert_called_once_with( + self.dao.RESOURCE_NAME, fake_custom_timestamp) + + def test_insert_violations_raises_error_on_create(self): + """Test raises MySQLError when getting a create table error. + + Expect: + Raise MySQLError when create_snapshot_table() raises an error. + """ + + self.dao.get_latest_snapshot_timestamp = mock.MagicMock( + return_value=self.fake_snapshot_timestamp) + self.dao._create_snapshot_table = mock.MagicMock( + side_effect=MySQLdb.DataError) + + with self.assertRaises(errors.MySQLError): + self.dao.insert_violations([]) + + def test_insert_violations_with_error(self): + """Test insert_violations handles errors during insert. + + Setup: + * Create mocks: + * self.dao.conn + * self.dao.get_latest_snapshot_timestamp + * self.dao._create_snapshot_table + * Create side effect for one violation to raise an error. + + Expect: + * Log MySQLError when table insert error occurs and return list + of errors. 
+ * Return a tuple of (num_violations-1, [violation]) + """ + + self.dao.get_latest_snapshot_timestamp = mock.MagicMock( + return_value=self.fake_snapshot_timestamp) + self.dao._create_snapshot_table = mock.MagicMock( + return_value=self.fake_table_name) + violation_dao.LOGGER = mock.MagicMock() + + def insert_violation_side_effect(*args, **kwargs): + if args[2] == self.expected_fake_violations[1]: + raise MySQLdb.DataError( + self.dao.RESOURCE_NAME, mock.MagicMock()) + else: + return mock.DEFAULT + + self.dao.execute_sql_with_commit = mock.MagicMock( + side_effect=insert_violation_side_effect) + + actual = self.dao.insert_violations(self.fake_violations) + expected = (2, [self.expected_fake_violations[1]]) + + self.assertEqual(expected, actual) + self.assertEquals(1, violation_dao.LOGGER.error.call_count) + + +if __name__ == '__main__': + basetest.main() diff --git a/tests/scanner/audit/base_rules_engine_test.py b/tests/scanner/audit/base_rules_engine_test.py index 543b4e5f69..ba8e8952b8 100644 --- a/tests/scanner/audit/base_rules_engine_test.py +++ b/tests/scanner/audit/base_rules_engine_test.py @@ -18,6 +18,7 @@ from google.apputils import basetest from google.cloud.security.scanner.audit import base_rules_engine as bre +from google.cloud.security.scanner.audit import rules as audit_rules from google.cloud.security.scanner.audit import errors as audit_errors @@ -60,18 +61,24 @@ class RuleAppliesToTest(basetest.TestCase): def test_rule_applies_is_verified(self): """Test valid RuleAppliesTo.""" - self.assertEqual(bre.RuleAppliesTo.SELF, - bre.RuleAppliesTo.verify(bre.RuleAppliesTo.SELF)) - self.assertEqual(bre.RuleAppliesTo.CHILDREN, - bre.RuleAppliesTo.verify(bre.RuleAppliesTo.CHILDREN)) - self.assertEqual(bre.RuleAppliesTo.SELF_AND_CHILDREN, - bre.RuleAppliesTo.verify( - bre.RuleAppliesTo.SELF_AND_CHILDREN)) + + self.assertEqual( + audit_rules.RuleAppliesTo.SELF, + audit_rules.RuleAppliesTo.verify(audit_rules.RuleAppliesTo.SELF)) + + self.assertEqual( + 
audit_rules.RuleAppliesTo.CHILDREN, + audit_rules.RuleAppliesTo.verify(audit_rules.RuleAppliesTo.CHILDREN)) + + self.assertEqual( + audit_rules.RuleAppliesTo.SELF_AND_CHILDREN, + audit_rules.RuleAppliesTo.verify( + audit_rules.RuleAppliesTo.SELF_AND_CHILDREN)) def test_invalid_rule_applies_raises_error(self): """Test invalid RuleAppliesTo raises error.""" with self.assertRaises(audit_errors.InvalidRulesSchemaError): - bre.RuleAppliesTo.verify('invalid') + audit_rules.RuleAppliesTo.verify('invalid') if __name__ == '__main__': diff --git a/tests/scanner/audit/org_rules_engine_test.py b/tests/scanner/audit/org_rules_engine_test.py index 85bc1201b2..7c7bb7ebf6 100644 --- a/tests/scanner/audit/org_rules_engine_test.py +++ b/tests/scanner/audit/org_rules_engine_test.py @@ -28,9 +28,7 @@ from google.cloud.security.scanner.audit.errors import InvalidRulesSchemaError from google.cloud.security.scanner.audit import base_rules_engine as bre from google.cloud.security.scanner.audit import org_rules_engine as ore -from google.cloud.security.scanner.audit.org_rules_engine import ResourceRules -from google.cloud.security.scanner.audit.org_rules_engine import RuleViolation -from google.cloud.security.scanner.audit.org_rules_engine import RULE_VIOLATION_TYPE +from google.cloud.security.scanner.audit import rules as scanner_rules from tests.unittest_utils import get_datafile_path from tests.scanner.audit.data import test_rules @@ -109,18 +107,18 @@ def test_add_single_rule_builds_correct_map(self): rule_bindings = [{ 'role': 'roles/*', 'members': ['user:*@company.com'] }] - rule = ore.Rule('my rule', 0, + rule = scanner_rules.Rule('my rule', 0, [IamPolicyBinding.create_from(b) for b in rule_bindings], mode='whitelist') - expected_org_rules = ResourceRules(self.org789, - rules=set([rule]), - applies_to='self_and_children') - expected_proj1_rules = ResourceRules(self.project1, - rules=set([rule]), - applies_to='self') - expected_proj2_rules = ResourceRules(self.project2, - 
rules=set([rule]), - applies_to='self') + expected_org_rules = ore.ResourceRules(self.org789, + rules=set([rule]), + applies_to='self_and_children') + expected_proj1_rules = ore.ResourceRules(self.project1, + rules=set([rule]), + applies_to='self') + expected_proj2_rules = ore.ResourceRules(self.project2, + rules=set([rule]), + applies_to='self') expected_rules = { (self.org789, 'self_and_children'): expected_org_rules, (self.project1, 'self'): expected_proj1_rules, @@ -132,12 +130,12 @@ def test_add_single_rule_builds_correct_map(self): def test_invalid_rule_mode_raises_when_verify_mode(self): """Test that an invalid rule mode raises error.""" with self.assertRaises(InvalidRulesSchemaError): - bre.RuleMode.verify('nonexistent mode') + scanner_rules.RuleMode.verify('nonexistent mode') def test_invalid_rule_mode_raises_when_create_rule(self): """Test that creating a Rule with invalid rule mode raises error.""" with self.assertRaises(InvalidRulesSchemaError): - ore.Rule('exception', 0, []) + scanner_rules.Rule('exception', 0, []) def test_policy_binding_matches_whitelist_rules(self): """Test that a policy binding matches the whitelist rules. 
@@ -171,10 +169,10 @@ def test_policy_binding_matches_whitelist_rules(self): } ] - rule = ore.Rule('test rule', 0, + rule = scanner_rules.Rule('test rule', 0, [IamPolicyBinding.create_from(b) for b in rule_bindings], mode='whitelist') - resource_rule = ResourceRules(rules=[rule]) + resource_rule = ore.ResourceRules(rules=[rule]) results = list(resource_rule.find_mismatches( self.project1, test_binding)) @@ -209,10 +207,10 @@ def test_policy_binding_does_not_match_blacklist_rules(self): } ] - rule = ore.Rule('test rule', 0, + rule = scanner_rules.Rule('test rule', 0, [IamPolicyBinding.create_from(b) for b in rule_bindings], mode='blacklist') - resource_rule = ResourceRules(rules=[rule]) + resource_rule = ore.ResourceRules(rules=[rule]) results = list(resource_rule.find_mismatches( self.project1, test_binding)) @@ -248,10 +246,10 @@ def test_policy_binding_matches_required_rules(self): } ] - rule = ore.Rule('test rule', 0, + rule = scanner_rules.Rule('test rule', 0, [IamPolicyBinding.create_from(b) for b in rule_bindings], mode='required') - resource_rule = ResourceRules(rules=[rule]) + resource_rule = ore.ResourceRules(rules=[rule]) results = list(resource_rule.find_mismatches( self.project1, test_binding)) @@ -286,10 +284,10 @@ def test_policy_binding_mismatches_required_rules(self): } ] - rule = ore.Rule('test rule', 0, + rule = scanner_rules.Rule('test rule', 0, [IamPolicyBinding.create_from(b) for b in rule_bindings], mode='required') - resource_rule = ResourceRules(resource=self.project1) + resource_rule = ore.ResourceRules(resource=self.project1) resource_rule.rules.add(rule) results = list(resource_rule.find_mismatches( self.project1, test_binding)) @@ -327,7 +325,7 @@ def test_one_member_mismatch(self): 'role': 'roles/*', 'members': ['user:*@company.com'] }] - rule = ore.Rule('my rule', 0, + rule = scanner_rules.Rule('my rule', 0, [IamPolicyBinding.create_from(b) for b in rule_bindings], mode='whitelist') expected_outstanding = { @@ -336,13 +334,13 @@ def 
test_one_member_mismatch(self): ] } expected_violations = set([ - RuleViolation( + scanner_rules.RuleViolation( resource_type=self.project1.type, resource_id=self.project1.id, rule_name=rule.rule_name, rule_index=rule.rule_index, role='roles/editor', - violation_type=RULE_VIOLATION_TYPE.get(rule.mode), + violation_type=scanner_rules.VIOLATION_TYPE.get(rule.mode), members=tuple(expected_outstanding['roles/editor'])) ]) @@ -477,7 +475,7 @@ def test_whitelist_blacklist_rules_vs_policy_has_violations(self): } expected_violations = set([ - RuleViolation( + scanner_rules.RuleViolation( rule_index=0, rule_name='my rule', resource_id=self.project1.id, @@ -485,7 +483,7 @@ def test_whitelist_blacklist_rules_vs_policy_has_violations(self): violation_type='ADDED', role=policy['bindings'][0]['role'], members=tuple(expected_outstanding1['roles/editor'])), - RuleViolation( + scanner_rules.RuleViolation( rule_index=0, rule_name='my rule', resource_type=self.project2.type, @@ -493,7 +491,7 @@ def test_whitelist_blacklist_rules_vs_policy_has_violations(self): violation_type='ADDED', role=policy['bindings'][0]['role'], members=tuple(expected_outstanding1['roles/editor'])), - RuleViolation( + scanner_rules.RuleViolation( rule_index=1, rule_name='my other rule', resource_type=self.project2.type, @@ -501,7 +499,7 @@ def test_whitelist_blacklist_rules_vs_policy_has_violations(self): violation_type='ADDED', role=policy['bindings'][0]['role'], members=tuple(expected_outstanding2['roles/editor'])), - RuleViolation( + scanner_rules.RuleViolation( rule_index=2, rule_name='required rule', resource_id=self.project1.id, @@ -609,7 +607,7 @@ def test_org_proj_rules_vs_policy_has_violations(self): } expected_violations = set([ - RuleViolation( + scanner_rules.RuleViolation( rule_index=1, rule_name='my blacklist rule', resource_id=self.org789.id, @@ -617,7 +615,7 @@ def test_org_proj_rules_vs_policy_has_violations(self): violation_type='ADDED', role=org_policy['bindings'][0]['role'], 
members=tuple(expected_outstanding_org['roles/editor'])), - RuleViolation( + scanner_rules.RuleViolation( rule_index=0, rule_name='my whitelist rule', resource_id=self.project1.id, @@ -625,7 +623,7 @@ def test_org_proj_rules_vs_policy_has_violations(self): violation_type='ADDED', role=project_policy['bindings'][0]['role'], members=tuple(expected_outstanding_project['roles/editor'])), - RuleViolation( + scanner_rules.RuleViolation( rule_index=2, rule_name='my required rule', resource_id=self.project1.id, @@ -695,7 +693,7 @@ def test_org_self_rules_work_with_org_child_rules(self): } expected_violations = set([ - RuleViolation( + scanner_rules.RuleViolation( rule_index=0, rule_name='org whitelist', resource_id=self.org789.id, @@ -703,7 +701,7 @@ def test_org_self_rules_work_with_org_child_rules(self): violation_type='ADDED', role=org_policy['bindings'][0]['role'], members=tuple(expected_outstanding_org['roles/owner'])), - RuleViolation( + scanner_rules.RuleViolation( rule_index=1, rule_name='project whitelist', resource_id=self.project1.id, @@ -744,9 +742,9 @@ def test_org_project_noinherit_project_overrides_org_rule(self): ] } - actual_violations = set(itertools.chain( + actual_violations = set( rules_engine.find_policy_violations(self.project1, project_policy) - )) + ) # expected expected_outstanding_proj = { @@ -800,7 +798,7 @@ def test_org_2_child_rules_report_violation(self): } expected_violations = set([ - RuleViolation( + scanner_rules.RuleViolation( rule_index=1, rule_name='project blacklist', resource_id=self.project1.id, @@ -855,7 +853,7 @@ def test_org_project_inherit_org_rule_violation(self): } expected_violations = set([ - RuleViolation( + scanner_rules.RuleViolation( rule_index=0, rule_name='org blacklist', resource_id=self.project1.id, @@ -962,7 +960,7 @@ def test_org_self_wl_proj_noinherit_bl_has_violation(self): } expected_violations = set([ - RuleViolation( + scanner_rules.RuleViolation( rule_index=1, rule_name='project blacklist', 
resource_id=self.project1.id, @@ -1003,9 +1001,9 @@ def test_ignore_case_works(self): ] } - actual_violations = set(itertools.chain( + actual_violations = set( rules_engine.find_policy_violations(self.project1, project_policy) - )) + ) # expected expected_outstanding_proj = { @@ -1015,7 +1013,7 @@ def test_ignore_case_works(self): } expected_violations = set([ - RuleViolation( + scanner_rules.RuleViolation( rule_index=1, rule_name='project blacklist', resource_id=self.project1.id, diff --git a/tests/scanner/scanner_test.py b/tests/scanner/scanner_test.py index 1924d3070b..9894fed302 100644 --- a/tests/scanner/scanner_test.py +++ b/tests/scanner/scanner_test.py @@ -22,14 +22,16 @@ from google.apputils import basetest from google.cloud.security.common.data_access import csv_writer -from google.cloud.security.common.data_access import dao +from google.cloud.security.common.data_access import _db_connector from google.cloud.security.common.data_access import errors +from google.cloud.security.common.data_access import violation_dao as vdao from google.cloud.security.common.gcp_type import iam_policy from google.cloud.security.common.gcp_type import organization from google.cloud.security.common.gcp_type import project from google.cloud.security.common.gcp_type import resource from google.cloud.security.scanner import scanner from google.cloud.security.scanner.audit import org_rules_engine as ore +from google.cloud.security.scanner.audit import rules as audit_rules from tests.inventory.pipelines.test_data import fake_iam_policies @@ -40,7 +42,8 @@ def setUp(self): year=1900, month=1, day=1, hour=0, minute=0, second=0, microsecond=0) self.fake_utcnow = fake_utcnow - self.fake_utcnow_str = self.fake_utcnow.strftime(scanner.OUTPUT_TIMESTAMP_FMT) + self.fake_utcnow_str = self.fake_utcnow.strftime( + scanner.OUTPUT_TIMESTAMP_FMT) self.fake_timestamp = '123456' self.scanner = scanner self.scanner.LOGGER = mock.MagicMock() @@ -48,7 +51,8 @@ def setUp(self): 
self.scanner.FLAGS.rules = 'fake/path/to/rules.yaml' self.fake_main_argv = [] self.fake_org_policies = fake_iam_policies.FAKE_ORG_IAM_POLICY_MAP - self.fake_project_policies = fake_iam_policies.FAKE_PROJECT_IAM_POLICY_MAP + self.fake_project_policies = \ + fake_iam_policies.FAKE_PROJECT_IAM_POLICY_MAP def test_missing_rules_flag_raises_systemexit(self): """Test that missing the `rules` flag raises SystemExit/calls sys.exit().""" @@ -219,23 +223,28 @@ def test_get_timestamp_db_errors(self, mock_get_ss_timestamp, mock_conn): self.assertEqual(1, scanner.LOGGER.error.call_count) self.assertIsNone(actual) + @mock.patch.object(MySQLdb, 'connect') @mock.patch.object(csv_writer, 'write_csv', autospec=True) @mock.patch.object(os, 'path', autospec=True) @mock.patch.object(scanner, '_upload_csv') @mock.patch.object(scanner, '_send_email') @mock.patch('google.cloud.security.scanner.scanner.datetime') + @mock.patch.object(vdao.ViolationDao, 'insert_violations') def test_output_results_local_no_email( self, + mock_violation_dao, mock_datetime, mock_send_email, mock_upload, mock_path, - mock_write_csv): + mock_write_csv, + mock_conn): """Test output results for local output, and don't send email. Setup: * Create fake csv filename. * Create fake file path. + * Mock out the ViolationDao. * Set FLAGS values. * Mock the context manager and the csv file name. * Mock the timestamp for the email. 
@@ -259,23 +268,30 @@ def test_output_results_local_no_email( mock_path.abspath = mock.MagicMock() mock_path.abspath.return_value = fake_full_path - self.scanner._output_results(['a']) + mock_violation_dao.return_value = (1, []) + + self.scanner._output_results(['a'], self.fake_timestamp) - mock_upload.assert_called_once_with(fake_full_path, self.fake_utcnow, fake_csv_name) + mock_upload.assert_called_once_with( + fake_full_path, self.fake_utcnow, fake_csv_name) self.assertEquals(0, mock_send_email.call_count) + @mock.patch.object(MySQLdb, 'connect') @mock.patch.object(csv_writer, 'write_csv', autospec=True) @mock.patch.object(os, 'path', autospec=True) @mock.patch.object(scanner, '_upload_csv') @mock.patch.object(scanner, '_send_email') @mock.patch('google.cloud.security.scanner.scanner.datetime') + @mock.patch.object(vdao.ViolationDao, 'insert_violations') def test_output_results_gcs_email( self, + mock_violation_dao, mock_datetime, mock_send_email, mock_upload, mock_path, - mock_write_csv): + mock_write_csv, + mock_conn): """Test output results for GCS upload and send email. Setup: @@ -283,6 +299,7 @@ def test_output_results_gcs_email( * Create fake counts. * Create fake csv filename. * Create fake file path. + * Mock out the ViolationDao. * Set FLAGS values. * Mock the context manager and the csv file name. * Mock the timestamp for the email. 
@@ -302,17 +319,23 @@ def test_output_results_gcs_email( mock_write_csv.return_value = mock.MagicMock() mock_write_csv.return_value.__enter__ = mock.MagicMock() - type(mock_write_csv.return_value.__enter__.return_value).name = fake_csv_name + type(mock_write_csv.return_value \ + .__enter__.return_value).name = fake_csv_name mock_datetime.utcnow = mock.MagicMock() mock_datetime.utcnow.return_value = self.fake_utcnow mock_path.abspath = mock.MagicMock() mock_path.abspath.return_value = fake_full_path - self.scanner._output_results(fake_violations, resource_counts=fake_counts) + mock_violation_dao.return_value = (1, []) + + self.scanner._output_results(fake_violations, + self.fake_timestamp, + resource_counts=fake_counts) - mock_upload.assert_called_once_with(fake_full_path, self.fake_utcnow, fake_csv_name) + mock_upload.assert_called_once_with( + fake_full_path, self.fake_utcnow, fake_csv_name) mock_send_email.assert_called_once_with( - fake_csv_name, self.fake_utcnow, fake_violations, fake_counts) + fake_csv_name, self.fake_utcnow, fake_violations, fake_counts, []) def test_build_scan_summary(self): """Test that the scan summary is built correctly.""" @@ -320,20 +343,20 @@ def test_build_scan_summary(self): for u in ['user:a@b.c', 'group:g@h.i', 'serviceAccount:x@y.z'] ] all_violations = [ - ore.RuleViolation( + audit_rules.RuleViolation( resource_type='organization', resource_id='abc111', rule_name='Abc 111', rule_index=0, - violation_type=ore.RULE_VIOLATION_TYPE['whitelist'], + violation_type=audit_rules.VIOLATION_TYPE['whitelist'], role='role1', members=tuple(members)), - ore.RuleViolation( + audit_rules.RuleViolation( resource_type='project', resource_id='def222', rule_name='Def 123', rule_index=1, - violation_type=ore.RULE_VIOLATION_TYPE['blacklist'], + violation_type=audit_rules.VIOLATION_TYPE['blacklist'], role='role2', members=tuple(members)), ] @@ -342,7 +365,8 @@ def test_build_scan_summary(self): resource.ResourceType.PROJECT: 1, } - actual = 
self.scanner._build_scan_summary(all_violations, total_resources) + actual = self.scanner._build_scan_summary( + all_violations, total_resources) expected_summaries = { resource.ResourceType.ORGANIZATION: {