diff --git a/comet_core/data_store.py b/comet_core/data_store.py index 4c57535..f9c73f8 100644 --- a/comet_core/data_store.py +++ b/comet_core/data_store.py @@ -15,25 +15,23 @@ """Data Store module - interface to database.""" from datetime import datetime, timedelta +from typing import Dict, List, Optional -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -from sqlalchemy.sql.expression import func +import sqlalchemy +import sqlalchemy.orm from comet_core.model import BaseRecord, EventRecord, IgnoreFingerprintRecord -Session = sessionmaker(autocommit=True) +def remove_duplicate_events(event_record_list: List[EventRecord]) -> List[EventRecord]: + """Removes duplicates based on fingerprint and chooses the newest issue. -def remove_duplicate_events(event_record_list): - """ - This removes duplicates based on fingerprint and chooses the newest issue Args: - event_record_list (list): list of EventRecords + event_record_list: list of EventRecords Returns: list: of EventRecords with extra fingerprints removed """ - events_hash_table = {} + events_hash_table: Dict[Optional[str], EventRecord] = {} for e in event_record_list: if e.fingerprint in events_hash_table: if events_hash_table[e.fingerprint].received_at < e.received_at: @@ -43,137 +41,145 @@ def remove_duplicate_events(event_record_list): return list(events_hash_table.values()) -class DataStore: # pylint: disable=too-many-public-methods - """Abstraction of the comet storage layer. +class DataStore: + """Abstraction of the Comet storage layer. Args: - database_uri (str): the database to use + database_uri: Database URL to connect to. Will be passed to sqlalchemy.create_engine, refer to that + documentation for formats. 
""" - def __init__(self, database_uri): - self.engine = create_engine(database_uri) - self.connection = self.engine.connect() - - Session.configure(bind=self.engine) - self.session = Session() - - BaseRecord.metadata.create_all(self.engine) - - def add_event(self, event): - """Add an event to the data store. + def __init__(self, database_uri: str) -> None: + """Creates a new DataStore instance Args: - event (PluginBase): a typed event to add to the database. + database_uri (str): Database URL to connect to. Will be passed to sqlalchemy.create_engine, refer to that + documentation for formats. """ - self.add_record(event.record) + # Setting "future" for 2.0 syntax + engine = sqlalchemy.create_engine(database_uri, future=True) + # expire_on_commit needs to be false due to https://docs.sqlalchemy.org/en/14/errors.html#error-bhk3 + self.session = sqlalchemy.orm.sessionmaker(engine, future=True, expire_on_commit=False) - def add_record(self, record): + BaseRecord.metadata.create_all(engine) + + def add_record(self, record: EventRecord) -> None: """Store a record in the data store. + Args: - record (EventRecord): the record object to store + record: the record object to store """ - self.session.add(record) + with self.session.begin() as session: + session.add(record) - def get_unprocessed_events_batch(self, wait_for_more, max_wait, source_type): + def get_unprocessed_events_batch( + self, wait_for_more: timedelta, max_wait: timedelta, source_type: str + ) -> List[EventRecord]: """Get all unprocessed events of the given source_type but only if the latest event is older than `wait_for_more` or the oldest event is older than `max_wait`. 
- Metrics emitted: - events-hit-max-wait - Args: - wait_for_more (datetime.timedelta): the amount of time to wait since the latest event - max_wait (datetime.timedelta): the amount of time to wait since the earliest event - source_type (str): source type of the events to look for + wait_for_more: the amount of time to wait since the latest event + max_wait: the amount of time to wait since the earliest event + source_type: source type of the events to look for Returns: list: list of `EventRecord`s, or empty list if there is nothing to return """ # https://explainextended.com/2009/09/18/not-in-vs-not-exists-vs-left-join-is-null-mysql/ - events = ( - self.session.query(EventRecord) - .filter((EventRecord.processed_at.is_(None)) & (EventRecord.source_type == source_type)) - .order_by(EventRecord.received_at.asc()) - .all() - ) + with self.session.begin() as session: + events: List[EventRecord] = ( + session.query(EventRecord) + .filter((EventRecord.processed_at.is_(None)) & (EventRecord.source_type == source_type)) + .order_by(EventRecord.received_at.asc()) + .all() + ) now = datetime.utcnow() if events and events[-1].received_at < now - wait_for_more: return events if events and events[0].received_at < now - max_wait: - # METRIC_RELAY.emit('events-hit-max-wait', 1, - # {'source-type': source_type}) return events return [] - def get_events_did_not_addressed(self, source_type): - """Get all events who we sent to the user + - the event haven't escalated already - and do not exist in IgnoreFingerprintRecord database. - That means that the user didn't addressed those events. + def get_events_did_not_addressed(self, source_type: str) -> List[EventRecord]: + """Get all non-escalated and non-ignored events sent to the user. + + Events that haven't escalated already and do not exist in IgnoreFingerprintRecord database are events that have + not been addressed by the user. + Args: - source_type (str): source type to filter the search by. 
+ source_type: source type to filter the search by. Returns: - list: list of `EventRecord`s not addressed, - or empty list if there is nothing to return + list: list of records not addressed, or empty list if there is nothing to return """ - non_addressed_events = ( - self.session.query(EventRecord) - .filter( - (EventRecord.sent_at.isnot(None)) - & (EventRecord.escalated_at.is_(None)) - & (EventRecord.source_type == source_type) + with self.session.begin() as session: + non_addressed_events: List[EventRecord] = ( + session.query(EventRecord) + .filter( + (EventRecord.sent_at.isnot(None)) + & (EventRecord.escalated_at.is_(None)) + & (EventRecord.source_type == source_type) + ) + .outerjoin(IgnoreFingerprintRecord, EventRecord.fingerprint == IgnoreFingerprintRecord.fingerprint) + .filter(IgnoreFingerprintRecord.fingerprint.is_(None)) + .all() ) - .outerjoin(IgnoreFingerprintRecord, EventRecord.fingerprint == IgnoreFingerprintRecord.fingerprint) - .filter(IgnoreFingerprintRecord.fingerprint.is_(None)) - .all() - ) return non_addressed_events - def check_any_issue_needs_reminder(self, search_timedelta, records): - """Checks if the issue among the provided ones with the most recent sent_at value has that value older than the + def check_any_issue_needs_reminder(self, search_timedelta: timedelta, records: List[EventRecord]) -> bool: + """Checks if a reminder should be sent for an issue by comparing sent_at value with search_timedelta. + + Check if the issue among the provided ones with the most recent sent_at value has that value older than the `search_timedelta`, that is, a reminder should be sent for the issue. + NOTE: if all database records for a fingerprint given in the `records` list have the sent_at values set to Null, - then this fingerprint will be treated as NOT needing a reminder, which might be unintuitive. + then this fingerprint will be treated as NOT needing a reminder, which might be unintuitive. 
+ Args: - search_timedelta (datetime.timedelta): reminder interval - records (list): list of EventRecord objects to check + search_timedelta: reminder interval + records: records to look at. Returns: bool: True if any of the provided records represents an issue that needs to be reminded about """ fingerprints = [record.fingerprint for record in records] - timestamps = ( - self.session.query(func.max(EventRecord.sent_at)) - .filter(EventRecord.fingerprint.in_(fingerprints) & EventRecord.sent_at.isnot(None)) - .group_by(EventRecord.fingerprint) - .all() - ) + with self.session.begin() as session: + timestamps: List[datetime] = ( + session.query(sqlalchemy.sql.expression.func.max(EventRecord.sent_at)) + .filter(EventRecord.fingerprint.in_(fingerprints) & EventRecord.sent_at.isnot(None)) + .group_by(EventRecord.fingerprint) + .all() + ) if timestamps: return max(timestamps)[0] <= datetime.utcnow() - search_timedelta + return False - def get_any_issues_need_reminder(self, search_timedelta, records): + def get_any_issues_need_reminder(self, search_timedelta: timedelta, records: List[EventRecord]) -> List[str]: """Returns all the `fingerprints` having corresponding `event` table entries with the latest `sent_at` more then search_timedelta ago. NOTE: if all database records for a fingerprint given in the `records` list have the sent_at values set to Null, then this fingerprint will be treated as NOT needing a reminder, which might be unintuitive. 
Args: - search_timedelta (datetime.timedelta): reminder interval - records (list): list of EventRecord objects to check + search_timedelta: reminder interval + records: list of EventRecord objects to check Returns: list: list of fingerprints that represent issues that need to be reminded about """ fingerprints = [record.fingerprint for record in records] - fingerprints_to_remind = ( - self.session.query(func.max(EventRecord.sent_at).label("sent_at"), EventRecord.fingerprint) - .filter(EventRecord.fingerprint.in_(fingerprints) & EventRecord.sent_at.isnot(None)) - .group_by(EventRecord.fingerprint) - .all() - ) + with self.session.begin() as session: + fingerprints_to_remind = ( + session.query( + sqlalchemy.sql.expression.func.max(EventRecord.sent_at).label("sent_at"), EventRecord.fingerprint + ) + .filter(EventRecord.fingerprint.in_(fingerprints) & EventRecord.sent_at.isnot(None)) + .group_by(EventRecord.fingerprint) + .all() + ) result = [] deltat = datetime.utcnow() - search_timedelta for f in fingerprints_to_remind: @@ -182,81 +188,87 @@ def get_any_issues_need_reminder(self, search_timedelta, records): return result - def update_timestamp_column_to_now(self, records, column_name): - """Update the `column_name` of the provided `EventRecord`s to datetime now + def update_timestamp_column_to_now(self, records: List[EventRecord], column_name: str) -> None: + """Update the `column_name` of the provided records to now Args: - records (list): `EventRecord`s to update the `column_name` for - column_name (str): the name of the datebase column to update + records: records to update the `column_name` for + column_name: the name of the datebase column to update """ time_now = datetime.utcnow() updates = [{"id": r.id, column_name: time_now} for r in records] - self.session.bulk_update_mappings(EventRecord, updates) + with self.session.begin() as session: + session.bulk_update_mappings(EventRecord, updates) - def update_processed_at_timestamp_to_now(self, records): # pylint: 
disable=invalid-name - """Update the processed_at timestamp for the given records to now. + def update_processed_at_timestamp_to_now(self, records: List[EventRecord]) -> None: # pylint: disable=invalid-name + """Update the processed_at timestamp for to now. Args: - records (list): `EventRecord`s to update the processed for + records: records to update the processed_at field """ self.update_timestamp_column_to_now(records, "processed_at") - def update_sent_at_timestamp_to_now(self, records): - """Update the sent_at timestamp for the given records to now. + def update_sent_at_timestamp_to_now(self, records: List[EventRecord]) -> None: + """Update the sent_at timestamp to now. Args: - records (list): `EventRecord`s to update the sent_at for + records: records to update the sent_at field """ self.update_timestamp_column_to_now(records, "sent_at") - def update_event_escalated_at_to_now(self, records): # pylint: disable=invalid-name - """Update the escalated_at timestamp for the given records to now. + def update_event_escalated_at_to_now(self, records: List[EventRecord]) -> None: # pylint: disable=invalid-name + """Update the escalated_at timestamp to now. + Args: - records (list): `EventRecord`s to update + records: records to update the escalated_at field """ self.update_timestamp_column_to_now(records, "escalated_at") - def get_oldest_event_with_fingerprint(self, fingerprint): # pylint: disable=invalid-name + def get_oldest_event_with_fingerprint(self, fingerprint: str) -> EventRecord: # pylint: disable=invalid-name """ Returns the oldest (first occurrence) event with the provided fingerprint. 
Args: - fingerprint (str): fingerprint to look for + fingerprint: fingerprint to look for Returns: EventRecord: oldest EventRecord with the given fingerprint """ - return ( - self.session.query(EventRecord) - .filter(EventRecord.fingerprint == fingerprint) - .order_by(EventRecord.received_at.asc()) - .limit(1) - .one_or_none() - ) + with self.session.begin() as session: + return ( + session.query(EventRecord) + .filter(EventRecord.fingerprint == fingerprint) + .order_by(EventRecord.received_at.asc()) + .limit(1) + .one_or_none() + ) - def get_latest_event_with_fingerprint(self, fingerprint): # pylint: disable=invalid-name + def get_latest_event_with_fingerprint(self, fingerprint: str) -> EventRecord: # pylint: disable=invalid-name """ Returns the latest (in other words: the newest, closest to now) event with the provided fingerprint. Args: - fingerprint (str): fingerprint to look for + fingerprint: fingerprint to look for Returns: EventRecord: latest EventRecord with the given fingerprint """ - return ( - self.session.query(EventRecord) - .filter(EventRecord.fingerprint == fingerprint) - .order_by(EventRecord.received_at.desc()) - .limit(1) - .one_or_none() - ) + with self.session.begin() as session: + return ( + session.query(EventRecord) + .filter(EventRecord.fingerprint == fingerprint) + .order_by(EventRecord.received_at.desc()) + .limit(1) + .one_or_none() + ) + + def check_needs_escalation(self, escalation_time: timedelta, event: EventRecord) -> bool: + """Checks if the event needs to be escalated. + + Returns True if the first occurrence of an event with the same fingerprint is older than the escalation time. - def check_needs_escalation(self, escalation_time, event): - """Checks if the event needs to be escalated. Returns True if the first occurrence of an event with the same - fingerprint is older than the escalation time. 
Args: - escalation_time (datetime.timedelta): time to delay escalation - event (EventRecord): EventRecord to check + escalation_time: time to delay escalation + event: EventRecord to check Returns: bool: True if the event should be escalated """ @@ -268,15 +280,21 @@ def check_needs_escalation(self, escalation_time, event): return oldest_event.received_at <= datetime.utcnow() - escalation_time def ignore_event_fingerprint( - self, fingerprint, ignore_type, expires_at=None, reported_at=None, record_metadata=None - ): - """Add a fingerprint to the list of ignored events + self, + fingerprint: str, + ignore_type: str, + expires_at: Optional[datetime] = None, + reported_at: Optional[datetime] = None, + record_metadata: Optional[Dict] = None, + ) -> None: + """Add a fingerprint to the list of ignored events. + Args: - fingerprint (str): fingerprint of the event to ignore - ignore_type (str): the type (reason) for ignoring, for example IgnoreFingerprintRecord.SNOOZE - expires_at (datetime.datetime): specify the time of the ignore expiration - reported_at (datetime.datetime): specify the time of the reported date - record_metadata (dict): metadata to hydrate the record with. + fingerprint: fingerprint of the event to ignore + ignore_type: the type (reason) for ignoring, for example IgnoreFingerprintRecord.SNOOZE + expires_at: specify the time of the ignore expiration + reported_at: specify the time of the reported date + record_metadata: metadata to hydrate the record with. """ new_record = IgnoreFingerprintRecord( fingerprint=fingerprint, @@ -285,171 +303,181 @@ def ignore_event_fingerprint( reported_at=reported_at, record_metadata=record_metadata, ) - self.session.begin() - self.session.add(new_record) - self.session.commit() + with self.session.begin() as session: + session.add(new_record) + + def fingerprint_is_ignored(self, fingerprint: str) -> bool: + """Check if a fingerprint is marked as ignored. 
- def fingerprint_is_ignored(self, fingerprint): - """Check if a fingerprint is marked as ignored (whitelisted or snoozed) Args: fingerprint (str): fingerprint of the event Returns: - bool: True if whitelisted + bool: True if whitelisted or snoozed """ - return ( - self.session.query(IgnoreFingerprintRecord) - .filter(IgnoreFingerprintRecord.fingerprint == fingerprint) - .filter( - (IgnoreFingerprintRecord.expires_at > datetime.utcnow()) - | (IgnoreFingerprintRecord.expires_at.is_(None)) + with self.session.begin() as session: + return ( + session.query(IgnoreFingerprintRecord) + .filter(IgnoreFingerprintRecord.fingerprint == fingerprint) + .filter( + (IgnoreFingerprintRecord.expires_at > datetime.utcnow()) + | (IgnoreFingerprintRecord.expires_at.is_(None)) + ) + .count() + >= 1 ) - .count() - >= 1 - ) - def may_send_escalation(self, source_type, escalation_reminder_cadence): - """Check if we are allowed to send another esclation notification to the source_type escalation recipient. + def may_send_escalation(self, source_type: str, escalation_reminder_cadence: timedelta) -> bool: + """Check if another escalation notification is allowed to the source_type escalation recipient. + Returns false if there was an escalation sent to them within `escalation_reminder_cadence`. 
Args: - source_type (str): source type of the events - escalation_reminder_cadence (datetime.timedelta): time to wait before sending next escalation + source_type: source type of the events + escalation_reminder_cadence: time to wait before sending next escalation Returns: bool: True if an escalation may be sent, False otherwise """ - last_escalated = ( - self.session.query(EventRecord.escalated_at) - .filter(EventRecord.source_type == source_type) - .order_by(EventRecord.escalated_at.desc()) - .limit(1) - .one_or_none() - ) + with self.session.begin() as session: + last_escalated = ( + session.query(EventRecord.escalated_at) + .filter(EventRecord.source_type == source_type) + .order_by(EventRecord.escalated_at.desc()) + .limit(1) + .one_or_none() + ) if not last_escalated[0]: return True return last_escalated[0] <= datetime.utcnow() - escalation_reminder_cadence - def check_if_previously_escalated(self, event): - """Checks if the issue was escalated before. This looks for previous escalations sent for events with the same - fingerprint. + def check_if_previously_escalated(self, event: EventRecord) -> bool: + """Checks if the issue was escalated before. + + This looks for previous escalations sent for events with the same fingerprint. 
Args: - event (EventRecord): one event of the issue to check + event: one event of the issue to check Returns: bool: True if any previous event with the same fingerprint was escalated, False otherwise """ - return ( - self.session.query(EventRecord) - .filter(EventRecord.fingerprint == event.fingerprint) - .filter(EventRecord.escalated_at.isnot(None)) - .count() - >= 1 - ) + with self.session.begin() as session: + return ( + session.query(EventRecord) + .filter(EventRecord.fingerprint == event.fingerprint) + .filter(EventRecord.escalated_at.isnot(None)) + .count() + >= 1 + ) - def get_open_issues(self, owners): + def get_open_issues(self, owners: List[str]) -> List[EventRecord]: """Return a list of open (newer than 24h), not whitelisted or snoozed issues for the given owners. + Args: - owners (list): list of strings, containing owners + owners: list of strings, containing owners Returns: list: list of EventRecord, representing open, non-ignored issues for the given owners """ - open_issues = ( - self.session.query(EventRecord) - .filter(EventRecord.owner.in_(owners)) - .filter(EventRecord.received_at >= datetime.utcnow() - timedelta(days=1)) - .all() - ) + with self.session.begin() as session: + open_issues = ( + session.query(EventRecord) + .filter(EventRecord.owner.in_(owners)) + .filter(EventRecord.received_at >= datetime.utcnow() - timedelta(days=1)) + .all() + ) - open_issues = remove_duplicate_events(open_issues) + open_issues = remove_duplicate_events(open_issues) - open_issues_fps = [issue.fingerprint for issue in open_issues] + open_issues_fps = [issue.fingerprint for issue in open_issues] - ignored_issues_fps_tuples = ( - self.session.query(IgnoreFingerprintRecord.fingerprint) - .filter(IgnoreFingerprintRecord.fingerprint.in_(open_issues_fps)) - .filter( - (IgnoreFingerprintRecord.expires_at > datetime.utcnow()) - | (IgnoreFingerprintRecord.expires_at.is_(None)) + ignored_issues_fps_tuples = ( + session.query(IgnoreFingerprintRecord.fingerprint) + 
.filter(IgnoreFingerprintRecord.fingerprint.in_(open_issues_fps)) + .filter( + (IgnoreFingerprintRecord.expires_at > datetime.utcnow()) + | (IgnoreFingerprintRecord.expires_at.is_(None)) + ) + .all() ) - .all() - ) ignored_issues_fps = [t[0] for t in ignored_issues_fps_tuples] return [issue for issue in open_issues if issue.fingerprint not in ignored_issues_fps] - def check_if_new(self, fingerprint, new_threshold): - """Check if an issue is new. An issue is treated as new if there are no events with the same fingerprint OR - if there are older events with the same fingerprint but the most recent one of them is older than - `new_threshold`. The idea with the second condition is to flag regressions as new issues, but allow for some - flakyness (eg a scanner not running a day should not flag all old open issues as new when it runs the day - after again). + def check_if_new(self, fingerprint: str, new_threshold: timedelta) -> bool: + """Check if an issue is new. + + An issue is treated as new if there are no events with the same fingerprint OR if there are older events with + the same fingerprint but the most recent one of them is older than `new_threshold`. + The idea with the second condition is to flag regressions as new issues, but allow for some flakyness (e.g., a + scanner not running a day should not flag all old open issues as new when it runs the day after again). Args: - fingerprint (str): fingerprint of the issue to evaluate - new_threshold (datetime.timedelta): time after which an issue should be considered new again, even if it was - seen before + fingerprint: fingerprint of the issue to evaluate + new_threshold: time after which an issue should be considered new again, even if it was seen before Returns: bool: True if the issue is new, False if it is old. 
""" - most_recent_processed = ( - self.session.query(EventRecord.received_at) - .filter(EventRecord.fingerprint == fingerprint) - .filter(EventRecord.processed_at.isnot(None)) - .order_by(EventRecord.received_at.desc()) - .limit(1) - .one_or_none() - ) + with self.session.begin() as session: + most_recent_processed = ( + session.query(EventRecord.received_at) + .filter(EventRecord.fingerprint == fingerprint) + .filter(EventRecord.processed_at.isnot(None)) + .order_by(EventRecord.received_at.desc()) + .limit(1) + .one_or_none() + ) if not most_recent_processed: return True return most_recent_processed[0] <= datetime.utcnow() - new_threshold - def get_events_need_escalation(self, source_type): - """ - Get all the events that the end user escalate manually - and weren't escalated already by comet. + def get_events_need_escalation(self, source_type: str) -> List[EventRecord]: + """Get all the events that the end user escalate manually and weren't escalated already by Comet. + Args: - source_type (str): source type to filter the search by. + source_type: source type to filter the search by. Returns: list: list of `EventRecord`s to escalate. 
""" - events_to_escalate = ( - self.session.query(EventRecord) - .filter( - (EventRecord.sent_at.isnot(None)) - & (EventRecord.escalated_at.is_(None)) - & (EventRecord.source_type == source_type) + with self.session.begin() as session: + events_to_escalate = ( + session.query(EventRecord) + .filter( + (EventRecord.sent_at.isnot(None)) + & (EventRecord.escalated_at.is_(None)) + & (EventRecord.source_type == source_type) + ) + .outerjoin(IgnoreFingerprintRecord, EventRecord.fingerprint == IgnoreFingerprintRecord.fingerprint) + .filter(IgnoreFingerprintRecord.ignore_type == IgnoreFingerprintRecord.ESCALATE_MANUALLY) + .all() ) - .outerjoin(IgnoreFingerprintRecord, EventRecord.fingerprint == IgnoreFingerprintRecord.fingerprint) - .filter(IgnoreFingerprintRecord.ignore_type == IgnoreFingerprintRecord.ESCALATE_MANUALLY) - .all() - ) - return events_to_escalate + return events_to_escalate - def get_interactions_fingerprint(self, fingerprint): + def get_interactions_fingerprint(self, fingerprint: str) -> List[IgnoreFingerprintRecord]: """Return the list of all interactions associated with a fingerprint. 
+ Args: - fingerprint (str): the fingerprint of the issue + fingerprint: the fingerprint of the issue Returns: list: list of IgnoreFingerprintRecord for the specified fingerprint """ - interactions = ( - self.session.query(IgnoreFingerprintRecord).filter(IgnoreFingerprintRecord.fingerprint == fingerprint).all() - ) - return [ - { - "id": t.id, - "fingerprint": t.fingerprint, - "ignore_type": t.ignore_type, - "reported_at": t.reported_at, - "expires_at": t.expires_at, - } - for t in interactions - ] + with self.session.begin() as session: + interactions = ( + session.query(IgnoreFingerprintRecord).filter(IgnoreFingerprintRecord.fingerprint == fingerprint).all() + ) + return [ + { + "id": t.id, + "fingerprint": t.fingerprint, + "ignore_type": t.ignore_type, + "reported_at": t.reported_at, + "expires_at": t.expires_at, + } + for t in interactions + ] diff --git a/comet_core/model.py b/comet_core/model.py index d75239e..7e92f9e 100644 --- a/comet_core/model.py +++ b/comet_core/model.py @@ -16,34 +16,16 @@ import json from datetime import datetime -from sqlalchemy import JSON, Column, DateTime, Integer, String, UnicodeText, types -from sqlalchemy.ext.declarative import declarative_base +import sqlalchemy +import sqlalchemy.orm +BaseRecord = sqlalchemy.orm.declarative_base() -class BaseRecordRepr: - """ - This class can be used by declarative_base, to add an automatic - __repr__ method to *all* subclasses of BaseRecord. - """ - - def __repr__(self): - """Return a representation of this object as a string. - - Returns: - str: a representation of the object. 
- """ - return f"{self.__class__.__name__}: " + " ".join( - [f"{k}={self.__getattribute__(k)}" for k, v in self.__class__.__dict__.items() if hasattr(v, "__set__")] - ) - -BaseRecord = declarative_base(cls=BaseRecordRepr) - - -class JSONType(types.TypeDecorator): # pylint: disable=abstract-method +class JSONType(sqlalchemy.types.TypeDecorator): # pylint: disable=abstract-method """This is for testing purposes, to make the JSON type work with sqlite.""" - impl = UnicodeText + impl = sqlalchemy.UnicodeText cache_ok = True @@ -57,7 +39,7 @@ def load_dialect_impl(self, dialect): object: if dialect name is 'mysql' it will override the type descriptor to JSON() """ if dialect.name == "mysql": - return dialect.type_descriptor(JSON()) + return dialect.type_descriptor(sqlalchemy.JSON()) return dialect.type_descriptor(self.impl) def process_bind_param(self, value, dialect): @@ -101,17 +83,17 @@ class EventRecord(BaseRecord): """ __tablename__ = "event" - id = Column(Integer, primary_key=True) - source_type = Column(String(250), nullable=False) - fingerprint = Column(String(250)) - owner = Column(String(250)) - event_metadata = Column(JSONType()) - data = Column(JSONType()) - - received_at = Column(DateTime, default=datetime.utcnow) - sent_at = Column(DateTime, default=None) - escalated_at = Column(DateTime, default=None) - processed_at = Column(DateTime, default=None) + id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) + source_type = sqlalchemy.Column(sqlalchemy.String(250), nullable=False) + fingerprint = sqlalchemy.Column(sqlalchemy.String(250)) + owner = sqlalchemy.Column(sqlalchemy.String(250)) + event_metadata = sqlalchemy.Column(JSONType()) + data = sqlalchemy.Column(JSONType()) + + received_at = sqlalchemy.Column(sqlalchemy.DateTime, default=datetime.utcnow) + sent_at = sqlalchemy.Column(sqlalchemy.DateTime, default=None) + escalated_at = sqlalchemy.Column(sqlalchemy.DateTime, default=None) + processed_at = sqlalchemy.Column(sqlalchemy.DateTime, 
default=None) def __init__(self, *args, **kwargs): self.new = False @@ -129,17 +111,20 @@ def update_metadata(self, metadata): else: self.event_metadata = metadata + def __repr__(self): + return f"EventRecord(id={self.id!r}, source_type={self.source_type!r}, fingerprint={self.fingerprint!r})" + class IgnoreFingerprintRecord(BaseRecord): """Acceptedrisk model.""" __tablename__ = "ignore_fingerprint" - id = Column(Integer, primary_key=True) - fingerprint = Column(String(250)) - ignore_type = Column(String(50)) - reported_at = Column(DateTime, default=datetime.utcnow) - expires_at = Column(DateTime, default=None) - record_metadata = Column(JSONType()) + id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) + fingerprint = sqlalchemy.Column(sqlalchemy.String(250)) + ignore_type = sqlalchemy.Column(sqlalchemy.String(50)) + reported_at = sqlalchemy.Column(sqlalchemy.DateTime, default=datetime.utcnow) + expires_at = sqlalchemy.Column(sqlalchemy.DateTime, default=None) + record_metadata = sqlalchemy.Column(JSONType()) SNOOZE = "snooze" ACCEPT_RISK = "acceptrisk" @@ -147,3 +132,6 @@ class IgnoreFingerprintRecord(BaseRecord): ACKNOWLEDGE = "acknowledge" ESCALATE_MANUALLY = "escalate_manually" RESOLVED = "resolved" + + def __repr__(self): + return f"IgnoreFingerPrintRecord(id={self.id!r}, fingerprint={self.fingerprint!r}, ignore_type={self.ignore_type!r})" diff --git a/pyproject.toml b/pyproject.toml index 2aff249..b125448 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,3 +8,33 @@ line_length = 120 [tool.black] line-length = 120 + +[tool.mypy] +ignore_missing_imports = true +warn_unused_configs = true +disallow_any_generics = true +disallow_subclassing_any = true +# disallow_untyped_calls = true # temp disabled +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_return_any = true 
+no_implicit_reexport = true +strict_equality = true +plugins = "sqlmypy" + +# Gradually roll out mypy by ignoring all files not explicitly opted in. +[[tool.mypy.overrides]] +module = "comet_core.*" +ignore_errors = true + +[[tool.mypy.overrides]] +module = [ + "comet_core.data_store", + "comet_core.model" ] +ignore_errors = false diff --git a/requirements-dev.txt b/requirements-dev.txt index bc42c7f..2b69552 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,18 @@ +# Install the package itself in dev mode +-e . + +tox==3.24.1 + +# Formatters and linters black==21.7b0 +pylint==2.9.6 +isort==5.8.0 + +# Testing pytest==6.2.4 pytest-cov==2.12.0 pytest-freezegun==0.4.2 -tox==3.24.1 -pylint==2.9.6 -isort==5.8.0 + +# Types +mypy==0.910 +sqlalchemy-stubs==0.4 diff --git a/setup.py b/setup.py index fb4650e..1e19001 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ setuptools.setup( name="comet-core", - version="2.10.0", + version="2.11.0", url="https://github.com/spotify/comet-core", author="Spotify Platform Security", author_email="wasabi@spotify.com", diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index 15edb86..0000000 --- a/tests/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2018 Spotify AB. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
diff --git a/tests/conftest.py b/tests/conftest.py index a4ad472..4e05be4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,29 +13,75 @@ # limitations under the License. """Global test fixtures""" -import os +from datetime import datetime +from typing import List import pytest from comet_core import Comet +from comet_core.app import EventContainer +from comet_core.data_store import DataStore +from comet_core.model import EventRecord + +# pylint: disable=redefined-outer-name @pytest.fixture -def app(): +def app() -> Comet: + """Returns a Comet app.""" yield Comet() @pytest.fixture -def test_db(): +def messages() -> List[EventContainer]: + """Get all test messages and their filenames as an iterator. + + Returns: + EventContainer: some test event + """ + event = EventContainer("test", {}) + event.set_owner("test@acme.org") + event.set_fingerprint("test") + return [event] + + +@pytest.fixture +def data_store() -> DataStore: + """Creates a SQLite backed datastore.""" + return DataStore("sqlite://") + + +@pytest.fixture +def test_db(messages, data_store) -> DataStore: """Setup a test database fixture Yields: DataStore: a sqlite backed datastore with all test data """ - from comet_core.data_store import DataStore - from tests.utils import get_all_test_messages - data_store = DataStore("sqlite://") - for event in get_all_test_messages(parsed=True): + for event in messages: data_store.add_record(event.get_record()) + + yield data_store + + +@pytest.fixture +def data_store_with_test_events(data_store) -> DataStore: + """Creates a populated data store.""" + one = EventRecord(received_at=datetime(2018, 7, 7, 9, 0, 0), source_type="datastoretest", owner="a", data={}) + one.fingerprint = "f1" + two = EventRecord(received_at=datetime(2018, 7, 7, 9, 30, 0), source_type="datastoretest", owner="a", data={}) + two.fingerprint = "f2" + three = EventRecord( + received_at=datetime(2018, 7, 7, 9, 0, 0), + source_type="datastoretest2", # Note that this is another source type! 
+ owner="b", + data={}, + ) + three.fingerprint = "f3" + + data_store.add_record(one) + data_store.add_record(two) + data_store.add_record(three) + yield data_store diff --git a/tests/test_data_store.py b/tests/test_data_store.py index 287e26b..cf289b1 100644 --- a/tests/test_data_store.py +++ b/tests/test_data_store.py @@ -11,54 +11,85 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# pylint: disable=invalid-name,missing-docstring,redefined-outer-name """Tests the event_parser module""" +# pylint: disable=invalid-name,redefined-outer-name + from datetime import datetime, timedelta import pytest from freezegun import freeze_time -import comet_core.data_store from comet_core.data_store import remove_duplicate_events from comet_core.model import EventRecord, IgnoreFingerprintRecord -from tests.utils import get_all_test_messages +# Fixtures used in some of the data store tests. Additional generic fixtures are found in conftest.py +# There is no generic usage of the fixtures, some tests are defining their test data inline and some are using the +# test data from fixtures. @pytest.fixture -# pylint: disable=missing-yield-doc,missing-yield-type-doc -def data_store_with_test_events(): - data_store = comet_core.data_store.DataStore("sqlite://") - - one = EventRecord(received_at=datetime(2018, 7, 7, 9, 0, 0), source_type="datastoretest", owner="a", data={}) - one.fingerprint = "f1" - two = EventRecord(received_at=datetime(2018, 7, 7, 9, 30, 0), source_type="datastoretest", owner="a", data={}) - two.fingerprint = "f2" - three = EventRecord( +def non_addressed_event(): + """Event sent but missing in the ignore_event table to indicate that it wasn't addressed by the user""" + event = EventRecord( received_at=datetime(2018, 7, 7, 9, 0, 0), - source_type="datastoretest2", # Note that this is another source type! 
- owner="b", + source_type="datastoretest", + owner="a", + sent_at=datetime(2018, 7, 7, 9, 0, 0), data={}, ) - three.fingerprint = "f3" + event.fingerprint = "f1" + return event - data_store.add_record(one) - data_store.add_record(two) - data_store.add_record(three) - yield data_store +@pytest.fixture +def addressed_event(data_store): + """Event was sent and addressed (ack by the user).""" + ack_event = EventRecord( + received_at=datetime(2018, 7, 7, 9, 30, 0), + source_type="datastoretest", + owner="a", + sent_at=datetime(2018, 7, 7, 9, 30, 0), + data={}, + ) + ack_event.fingerprint = "f2" + data_store.ignore_event_fingerprint(ack_event.fingerprint, ignore_type=IgnoreFingerprintRecord.ACKNOWLEDGE) + return ack_event -def test_data_store(): - data_store = comet_core.data_store.DataStore("sqlite://") - for event in get_all_test_messages(): - data_store.add_record(event.get_record()) +@pytest.fixture +def event_to_escalate(data_store): + """Event was sent and addressed (escalated by the user).""" + escalated_event = EventRecord( + received_at=datetime(2018, 7, 7, 9, 30, 0), + source_type="datastoretest", + owner="a", + sent_at=datetime(2018, 7, 7, 9, 30, 0), + data={}, + ) + escalated_event.fingerprint = "f3" + data_store.ignore_event_fingerprint( + escalated_event.fingerprint, ignore_type=IgnoreFingerprintRecord.ESCALATE_MANUALLY + ) + return escalated_event + + +@pytest.fixture +def data_store_with_real_time_events(data_store, addressed_event, non_addressed_event, event_to_escalate): + """Data store with real time events added.""" + data_store.add_record(addressed_event) + data_store.add_record(non_addressed_event) + data_store.add_record(event_to_escalate) + return data_store -def test_date_sorting(): - data_store = comet_core.data_store.DataStore("sqlite://") +def test_data_store(data_store, messages): + """Test that the new events can be added to the data store.""" + for event in messages: + data_store.add_record(event.get_record()) + +def 
test_date_sorting(data_store): + """Test that the date sorting work by adding two events to the database and query for the oldest/latest.""" old = EventRecord( received_at=datetime(2018, 2, 19, 0, 0, 11), source_type="datastoretest", data={"fingerprint": "same"} ) @@ -79,8 +110,8 @@ def test_date_sorting(): @freeze_time("2018-07-07 10:00:00") -# pylint: disable=missing-docstring, invalid-name def test_get_unprocessed_events_batch_will_wait(data_store_with_test_events): + """Test that no events are returned when asked to wait for a long time.""" val = data_store_with_test_events.get_unprocessed_events_batch( timedelta(days=1024 * 365), timedelta(days=1024 * 365), "datastoretest" ) @@ -88,8 +119,8 @@ def test_get_unprocessed_events_batch_will_wait(data_store_with_test_events): @freeze_time("2018-07-07 10:00:00") -# pylint: disable=missing-docstring def test_get_unprocessed_events(data_store_with_test_events): + """Test that all unprocessed events are returned for a given source type.""" val = data_store_with_test_events.get_unprocessed_events_batch( timedelta(minutes=1), timedelta(minutes=1), "datastoretest" ) @@ -97,22 +128,11 @@ def test_get_unprocessed_events(data_store_with_test_events): @freeze_time("2018-07-07 10:00:00") -# pylint: disable=missing-docstring, invalid-name -def test_get_unprocessed_events_max_wait(data_store_with_test_events): - val = data_store_with_test_events.get_unprocessed_events_batch( - timedelta(minutes=600), timedelta(minutes=60), "datastoretest" - ) - assert val == [] - - val = data_store_with_test_events.get_unprocessed_events_batch( - timedelta(minutes=600), timedelta(minutes=59), "datastoretest" - ) - assert len(val) == 2 - - -@freeze_time("2018-07-07 10:00:00") -# pylint: disable=missing-docstring, invalid-name def test_get_unprocessed_events_wait_for_more(data_store_with_test_events): + """Test that unprocessed events are returned. + + TODO: Do this test add anyting compared to the above? 
+ """ val = data_store_with_test_events.get_unprocessed_events_batch( timedelta(minutes=30), timedelta(minutes=120), "datastoretest" ) @@ -125,47 +145,61 @@ def test_get_unprocessed_events_wait_for_more(data_store_with_test_events): @freeze_time("2018-07-07 10:00:00") -# pylint: disable=missing-docstring, invalid-name def test_update_sent_at_timestamp_to_now(data_store_with_test_events): + """Tests that updating the sent_at timestamp for events is working.""" val = data_store_with_test_events.get_unprocessed_events_batch( timedelta(minutes=1), timedelta(minutes=1), "datastoretest" ) assert len(val) == 2 + data_store_with_test_events.update_sent_at_timestamp_to_now(val) record = data_store_with_test_events.get_latest_event_with_fingerprint(val[0].fingerprint) - assert record.sent_at is not None assert isinstance(record.sent_at, datetime) @freeze_time("2018-07-07 10:00:00") -# pylint: disable=missing-docstring, invalid-name def test_update_event_escalation_at_to_now(data_store_with_test_events): + """Tests that updating escalated_at for events is working.""" val = data_store_with_test_events.get_unprocessed_events_batch( timedelta(minutes=1), timedelta(minutes=1), "datastoretest" ) assert len(val) == 2 data_store_with_test_events.update_event_escalated_at_to_now(val) record = data_store_with_test_events.get_latest_event_with_fingerprint(val[0].fingerprint) - assert record.escalated_at is not None assert isinstance(record.escalated_at, datetime) @freeze_time("2018-07-07 10:00:00") -# pylint: disable=missing-docstring, invalid-name def test_update_processed_at_timestamp_to_now(data_store_with_test_events): + """Tests that updating processed_at for events is working.""" val = data_store_with_test_events.get_unprocessed_events_batch( timedelta(minutes=1), timedelta(minutes=1), "datastoretest" ) assert len(val) == 2 data_store_with_test_events.update_processed_at_timestamp_to_now(val) record = 
data_store_with_test_events.get_latest_event_with_fingerprint(val[0].fingerprint) - assert record.processed_at is not None assert isinstance(record.processed_at, datetime) -def test_get_any_issues_need_reminder(): - data_store = comet_core.data_store.DataStore("sqlite://") - +def test_get_any_issues_need_reminder(data_store): + """Tests events that needs reminders. + + Part 1: Add three events to the datastore and check that two are returned. + issue / time ---> + 1 --------a------|--------------> + 2 ----a-------b--|--------------> (2c sent_at == NULL) + 3 ---------------|--------------> (3a sent_at == NULL) + ^ + -7days + + Part 2: Add another issue and check that only one event is returned. + issue / time ---> + 1 --------a------|-----b--------> + 2 ----a-------b--|--------------> (2c sent_at == NULL) + 3 ---------------|--------------> (3a sent_at == NULL) + ^ + -7days + """ test_fingerprint1 = "f1" test_fingerprint2 = "f2" test_fingerprint3 = "f3" @@ -200,6 +234,7 @@ def test_get_any_issues_need_reminder(): # -7days result = data_store.get_any_issues_need_reminder(timedelta(days=7), [one_a, two_a, three_a]) + assert len(result) == 2 assert test_fingerprint2 in result assert test_fingerprint1 in result @@ -214,13 +249,30 @@ def test_get_any_issues_need_reminder(): # -7days result = data_store.get_any_issues_need_reminder(timedelta(days=7), [one_a, two_a, three_a]) + assert len(result) == 1 assert test_fingerprint2 in result -def test_check_any_issue_needs_reminder(): - data_store = comet_core.data_store.DataStore("sqlite://") - +def test_check_any_issue_needs_reminder(data_store): + """Checks if any event needs reminder. + + Part 1: Add three events to the datastore and check that two are returned. + issue / time ---> + 1 --------a------|--------------> + 2 ----a-------b--|--------------> (2c sent_at == NULL) + 3 ---------------|--------------> (3a sent_at == NULL) + ^ + -7days + + Part 2: Add another issue and check that only one event is returned. 
+ issue / time ---> + 1 --------a------|-----b--------> + 2 ----a-------b--|--------------> (2c sent_at == NULL) + 3 ---------------|--------------> (3a sent_at == NULL) + ^ + -7days + """ test_fingerprint1 = "f1" test_fingerprint2 = "f2" test_fingerprint3 = "f3" @@ -266,8 +318,14 @@ def test_check_any_issue_needs_reminder(): assert not data_store.check_any_issue_needs_reminder(timedelta(days=7), [one_a, two_a, three_a]) -def test_check_needs_escalation(): - data_store = comet_core.data_store.DataStore("sqlite://") +def test_check_nonexisting_event(data_store): + """Test escalating an event that does not exist.""" + event = EventRecord(source_type="datastoretest") + assert not data_store.check_needs_escalation(timedelta(days=1), event) + + +def test_check_needs_escalation(data_store): + """Checks if events needs escalation by adding multiple events with the same fingerprint.""" test_fingerprint1 = "f1" test_fingerprint2 = "f2" @@ -303,9 +361,8 @@ def test_check_needs_escalation(): assert not data_store.check_needs_escalation(timedelta(days=1), five) -def test_check_acceptedrisk_event_fingerprint(): - data_store = comet_core.data_store.DataStore("sqlite://") - +def test_check_acceptedrisk_event_fingerprint(data_store): + """Check that ignored events are properly handled by their fingerprint.""" test_fingerprint1 = "f1" assert not data_store.fingerprint_is_ignored(test_fingerprint1) @@ -314,19 +371,10 @@ def test_check_acceptedrisk_event_fingerprint(): assert data_store.fingerprint_is_ignored(test_fingerprint1) -def test_check_snoozed_event_fingerprint(): - data_store = comet_core.data_store.DataStore("sqlite://") - - test_fingerprint1 = "f1" +def test_check_snoozed_event(data_store): + """Check that snoozed events are not ignored.""" test_fingerprint2 = "f2" - assert not data_store.fingerprint_is_ignored(test_fingerprint1) - - data_store.ignore_event_fingerprint( - test_fingerprint1, ignore_type=IgnoreFingerprintRecord.SNOOZE, expires_at=datetime.utcnow() + 
timedelta(days=30) - ) - assert data_store.fingerprint_is_ignored(test_fingerprint1) - test_snooze_record = IgnoreFingerprintRecord( fingerprint=test_fingerprint2, ignore_type=IgnoreFingerprintRecord.SNOOZE, @@ -337,28 +385,24 @@ def test_check_snoozed_event_fingerprint(): assert not data_store.fingerprint_is_ignored(test_fingerprint2) -def test_may_send_escalation(): - data_store = comet_core.data_store.DataStore("sqlite://") +def test_may_send_escalation(data_store): + """Test the escalation function from a datastore with both escalated and non-escalated events.""" data_store.add_record(EventRecord(source_type="type1", escalated_at=None)) - assert data_store.may_send_escalation("type1", timedelta(days=7)) data_store.add_record(EventRecord(source_type="type1", escalated_at=datetime.utcnow() - timedelta(days=8))) - assert data_store.may_send_escalation("type1", timedelta(days=7)) data_store.add_record(EventRecord(source_type="type1", escalated_at=datetime.utcnow() - timedelta(days=6))) - assert not data_store.may_send_escalation("type1", timedelta(days=7)) data_store.add_record(EventRecord(source_type="type2", escalated_at=None)) - assert data_store.may_send_escalation("type2", timedelta(days=7)) -def test_check_if_previously_escalated(): - data_store = comet_core.data_store.DataStore("sqlite://") +def test_check_if_previously_escalated(data_store): + """Test the 'previously escalated' function by adding events and then escalate them.""" one = EventRecord(source_type="test_type", fingerprint="f1", escalated_at=None) data_store.add_record(one) @@ -376,8 +420,8 @@ def test_check_if_previously_escalated(): assert data_store.check_if_previously_escalated(one) -def test_get_open_issues(*_): - data_store = comet_core.data_store.DataStore("sqlite://") +def test_get_open_issues(data_store): + """Tests getting open issues by adding events of different types and check how many are still open.""" one = EventRecord(source_type="test_type", fingerprint="f1", 
received_at=datetime.utcnow(), owner="test") data_store.add_record(one) @@ -424,8 +468,8 @@ def test_get_open_issues(*_): assert len(open_issues) == 1 -def test_check_if_new(*_): - data_store = comet_core.data_store.DataStore("sqlite://") +def test_check_if_new(data_store): + """Check if there are new issues by adding a variety of different events.""" timestamp = datetime.utcnow() one_a = EventRecord(source_type="test_type", fingerprint="f1", received_at=timestamp) @@ -449,7 +493,7 @@ def test_check_if_new(*_): def test_remove_duplicate_events(): - """Test the remove_duplicate_events function""" + """Test the remove_duplicate_events function by ensuring that duplicate events are removed.""" one = EventRecord(received_at=datetime(2018, 2, 19, 0, 0, 11), source_type="datastoretest", owner="a", data={}) one.fingerprint = "f1" two = EventRecord(received_at=datetime(2018, 2, 20, 0, 0, 11), source_type="datastoretest", owner="a", data={}) @@ -464,95 +508,53 @@ def test_remove_duplicate_events(): assert len(records) == 2 -@pytest.fixture -def ds_instance(): - yield comet_core.data_store.DataStore("sqlite://") - - -@pytest.fixture -def non_addressed_event(): - # event sent but missing in the ignore_event table to - # indicate that it wasn't addressed by the user - event = EventRecord( - received_at=datetime(2018, 7, 7, 9, 0, 0), - source_type="datastoretest", - owner="a", - sent_at=datetime(2018, 7, 7, 9, 0, 0), - data={}, - ) - event.fingerprint = "f1" - return event - - -@pytest.fixture -def addressed_event(ds_instance): - # event was sent and addressed (ack by the user) - ack_event = EventRecord( - received_at=datetime(2018, 7, 7, 9, 30, 0), - source_type="datastoretest", - owner="a", - sent_at=datetime(2018, 7, 7, 9, 30, 0), - data={}, - ) - ack_event.fingerprint = "f2" - ds_instance.ignore_event_fingerprint(ack_event.fingerprint, ignore_type=IgnoreFingerprintRecord.ACKNOWLEDGE) - return ack_event - +def 
test_get_real_time_events_did_not_addressed(data_store_with_real_time_events, non_addressed_event): + """Test getting realtime events that were not addressed. -@pytest.fixture -def event_to_escalate(ds_instance): - # event was sent and addressed (escalated by the user) - escalated_event = EventRecord( - received_at=datetime(2018, 7, 7, 9, 30, 0), - source_type="datastoretest", - owner="a", - sent_at=datetime(2018, 7, 7, 9, 30, 0), - data={}, - ) - escalated_event.fingerprint = "f3" - ds_instance.ignore_event_fingerprint( - escalated_event.fingerprint, ignore_type=IgnoreFingerprintRecord.ESCALATE_MANUALLY - ) - return escalated_event + Compare the events by their string representation as the objects are not identical. + """ - -@pytest.fixture -def ds_with_real_time_events(ds_instance, addressed_event, non_addressed_event, event_to_escalate): - ds_instance.add_record(addressed_event) - ds_instance.add_record(non_addressed_event) - ds_instance.add_record(event_to_escalate) - return ds_instance - - -def test_get_real_time_events_did_not_addressed(ds_with_real_time_events, non_addressed_event): source_type = "datastoretest" - non_addressed_events = ds_with_real_time_events.get_events_did_not_addressed(source_type) + non_addressed_events = data_store_with_real_time_events.get_events_did_not_addressed(source_type) - assert non_addressed_event in non_addressed_events + # Compare the id of the events as the event objects themselves are not equal. + assert non_addressed_event.id in [x.id for x in non_addressed_events] -def test_get_real_time_events_need_escalation(ds_with_real_time_events, event_to_escalate): +def test_get_real_time_events_need_escalation(data_store_with_real_time_events, event_to_escalate): + """Test getting realtime events that were not escalated. + + Compare the events by their string representation as the objects are not identical. 
+ """ source_type = "datastoretest" - events_to_escalate = ds_with_real_time_events.get_events_need_escalation(source_type) + events_to_escalate = data_store_with_real_time_events.get_events_need_escalation(source_type) - assert event_to_escalate in events_to_escalate + # Compare the id of the events as the event objects themselves are not equal. + assert event_to_escalate.__repr__() in [x.__repr__() for x in events_to_escalate] -def test_ignore_event_fingerprint_with_metadata(ds_instance): +def test_ignore_event_fingerprint_with_metadata(data_store): + """Test the ignore events feature by querying the database directly to ensure the fingerprint is stored.""" fingerprint = "f1" record_metadata = {"slack_channel": "channel"} - ds_instance.ignore_event_fingerprint( + data_store.ignore_event_fingerprint( fingerprint, ignore_type=IgnoreFingerprintRecord.ESCALATE_MANUALLY, record_metadata=record_metadata ) - result = ( - ds_instance.session.query(IgnoreFingerprintRecord) - .filter(IgnoreFingerprintRecord.fingerprint == fingerprint) - .one_or_none() - ) - assert result.record_metadata == record_metadata + with data_store.session.begin() as session: + result = ( + session.query(IgnoreFingerprintRecord) + .filter(IgnoreFingerprintRecord.fingerprint == fingerprint) + .one_or_none() + ) + assert result.record_metadata == record_metadata + +def test_get_interactions_for_fingerprint(data_store): + """Test fingerprint interaction of the data store. -def test_get_interactions_for_fingerprint(ds_instance): + Create an IgnoreFingerprintRecord, store it in the data store and read it back to ensure it has been stored + properly. 
+ """ one_a = IgnoreFingerprintRecord( id=1, fingerprint="f1", @@ -562,7 +564,7 @@ def test_get_interactions_for_fingerprint(ds_instance): record_metadata=None, ) fingerprint = "f1" - ds_instance.ignore_event_fingerprint( + data_store.ignore_event_fingerprint( fingerprint, ignore_type=one_a.ignore_type, record_metadata=one_a.record_metadata, @@ -579,5 +581,5 @@ def test_get_interactions_for_fingerprint(ds_instance): "expires_at": datetime(2019, 1, 7, 0, 0, 11), } ] - result = ds_instance.get_interactions_fingerprint(fingerprint) + result = data_store.get_interactions_fingerprint(fingerprint) assert result == expected diff --git a/tests/utils.py b/tests/utils.py deleted file mode 100644 index 0bbb973..0000000 --- a/tests/utils.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2018 Spotify AB. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utilities and Mocked Data for tests""" - -from comet_core.app import EventContainer - - -def get_all_test_messages(parsed=False): - """Get all test messages and their filenames as an iterator. 
- - Args: - parsed (bool): returns Event objects if true otherwise strings - Yields: - EventContainer: some test event - """ - event = EventContainer("test", {}) - event.set_owner("test@acme.org") - event.set_fingerprint("test") - return [event] diff --git a/tox.ini b/tox.ini index 8cc9759..bd9a7a4 100644 --- a/tox.ini +++ b/tox.ini @@ -13,8 +13,7 @@ envlist= setenv = SQLALCHEMY_WARN_20=1 deps = -rrequirements-dev.txt commands = - coverage erase - pytest --junitxml=test-reports/junit.xml --cov={toxinidir}/comet_core/ --cov-report=term --cov-report=xml:test-reports/cobertura.xml {toxinidir}/tests/ + python3 -m pytest --junitxml=test-reports/junit.xml --cov={toxinidir}/comet_core --cov-report=term-missing --cov-report=xml:test-reports/cobertura.xml {toxinidir}/tests/ [testenv:format] basepython = python3.6 @@ -33,5 +32,13 @@ commands = [testenv:lint] basepython = python3.6 deps = -rrequirements-dev.txt +skip_install = true commands = python3 -m pylint --rcfile={toxinidir}/.pylintrc {toxinidir}/comet_core {toxinidir}/tests + +[testenv:types] +basepython = python3.6 +deps = -rrequirements-dev.txt +skip_install = true +commands = + python3 -m mypy --config {toxinidir}/pyproject.toml {toxinidir}/comet_core