diff --git a/custom_components/nysse/const.py b/custom_components/nysse/const.py index 289d81b..964db73 100644 --- a/custom_components/nysse/const.py +++ b/custom_components/nysse/const.py @@ -1,3 +1,5 @@ +"""Constants for the Nysse component.""" + DOMAIN = "nysse" PLATFORM_NAME = "Nysse" @@ -11,29 +13,10 @@ DEFAULT_ICON = "mdi:bus-clock" TRAM_LINES = ["1", "3"] -WEEKDAYS = [ - "Monday", - "Tuesday", - "Wednesday", - "Thursday", - "Friday", - "Saturday", - "Sunday", -] - -DEFAULT_TIME_ZONE = "Europe/Helsinki" - -JOURNEY = "journey" -DEPARTURE = "departure" -AIMED_ARRIVAL_TIME = "aimedArrivalTime" AIMED_DEPARTURE_TIME = "aimedDepartureTime" -EXPECTED_ARRIVAL_TIME = "expectedArrivalTime" EXPECTED_DEPARTURE_TIME = "expectedDepartureTime" STOP_URL = "https://data.itsfactory.fi/journeys/api/1/stop-monitoring?stops={0}" -STOP_POINTS_URL = "http://data.itsfactory.fi/journeys/api/1/stop-points/" -JOURNEYS_URL = "http://data.itsfactory.fi/journeys/api/1/journeys?stopPointId={0}&dayTypes={1}&startIndex={2}" -LINES_URL = "https://data.itsfactory.fi/journeys/api/1/lines?stopPointId={0}" SERVICE_ALERTS_URL = ( "https://data.itsfactory.fi/journeys/api/1/gtfs-rt/service-alerts/json" ) diff --git a/custom_components/nysse/fetch_api.py b/custom_components/nysse/fetch_api.py index d22e7da..e214e76 100644 --- a/custom_components/nysse/fetch_api.py +++ b/custom_components/nysse/fetch_api.py @@ -1,5 +1,7 @@ +"""Fetches data from the Nysse GTFS API.""" + import csv -from datetime import UTC, datetime +from datetime import UTC, datetime, timedelta import logging import os import pathlib @@ -39,12 +41,12 @@ async def _fetch_gtfs(): try: path = _get_dir_path() filename = "extended_gtfs_tampere.zip" - timestamp = _get_file_modified_time(path + filename) - if timestamp != datetime(1970, 1, 1) or datetime.now().minute != 0: + if os.path.isfile(path + filename) and datetime.now().minute != 0: _LOGGER.debug("Skipped fetching GTFS data") return # Skip fetching if the file exists or it's not the top of the hour + timestamp = _get_file_modified_time(path + filename) - _LOGGER.debug("Fetching GTFS data") + _LOGGER.debug("Fetching GTFS data from %s", GTFS_URL) timeout = aiohttp.ClientTimeout(total=30) async with ( aiohttp.ClientSession(timeout=timeout) as session, @@ -72,7 +74,14 @@ async def _read_csv_to_db(): # Stops cursor.execute( - "CREATE TABLE IF NOT EXISTS stops (stop_id TEXT PRIMARY KEY, stop_name TEXT, stop_lat TEXT, stop_lon TEXT)" + """ + CREATE TABLE IF NOT EXISTS stops ( + stop_id TEXT PRIMARY KEY, + stop_name TEXT, + stop_lat TEXT, + stop_lon TEXT + ) + """ ) stops = _parse_csv_file(_get_dir_path() + "stops.txt") to_db = [ @@ -85,7 +94,15 @@ async def _read_csv_to_db(): # Routes cursor.execute( - "CREATE TABLE IF NOT EXISTS trips (trip_id TEXT PRIMARY KEY, route_id TEXT, service_id TEXT, trip_headsign TEXT, direction_id TEXT)" + """ + CREATE TABLE IF NOT EXISTS trips ( + trip_id TEXT PRIMARY KEY, + route_id TEXT, + service_id TEXT, + trip_headsign TEXT, + direction_id TEXT + ) + """ ) trips = _parse_csv_file(_get_dir_path() + "trips.txt") to_db = [ @@ -99,13 +116,30 @@ async def _read_csv_to_db(): for i in trips ] cursor.executemany( - "INSERT OR REPLACE INTO trips (trip_id, route_id, service_id, trip_headsign, direction_id) VALUES (?, ?, ?, ?, ?)", + """ + INSERT OR REPLACE INTO trips + (trip_id, route_id, service_id, trip_headsign, direction_id) + VALUES (?, ?, ?, ?, ?) + """, to_db, ) # Calendar cursor.execute( - "CREATE TABLE IF NOT EXISTS calendar (service_id TEXT PRIMARY KEY, monday TEXT, tuesday TEXT, wednesday TEXT, thursday TEXT, friday TEXT, saturday TEXT, sunday TEXT)" + """ + CREATE TABLE IF NOT EXISTS calendar ( + service_id TEXT PRIMARY KEY, + monday TEXT, + tuesday TEXT, + wednesday TEXT, + thursday TEXT, + friday TEXT, + saturday TEXT, + sunday TEXT, + start_date TEXT, + end_date TEXT + ) + """ ) calendar = _parse_csv_file(_get_dir_path() + "calendar.txt") to_db = [ @@ -118,17 +152,32 @@ async def _read_csv_to_db(): i["friday"], i["saturday"], i["sunday"], + i["start_date"], + i["end_date"], ) for i in calendar ] cursor.executemany( - "INSERT OR REPLACE INTO calendar (service_id, monday, tuesday, wednesday, thursday, friday, saturday, sunday) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + """ + INSERT OR REPLACE INTO calendar + (service_id, monday, tuesday, wednesday, thursday, friday, saturday, sunday, start_date, end_date) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, to_db, ) # Stop times cursor.execute( - "CREATE TABLE IF NOT EXISTS stop_times (trip_id TEXT, arrival_time TIME, departure_time TIME, stop_id TEXT, stop_sequence TEXT, PRIMARY KEY(trip_id, arrival_time))" + """ + CREATE TABLE IF NOT EXISTS stop_times ( + trip_id TEXT, + arrival_time TIME, + departure_time TIME, + stop_id TEXT, + stop_sequence TEXT, + PRIMARY KEY(trip_id, arrival_time) + ) + """ ) stop_times = _parse_csv_file(_get_dir_path() + "stop_times.txt") to_db = [ @@ -143,7 +192,11 @@ async def _read_csv_to_db(): ] to_db.sort(key=lambda x: x[2]) # Sort by departure_time cursor.executemany( - "INSERT OR REPLACE INTO stop_times (trip_id, arrival_time, departure_time, stop_id, stop_sequence) VALUES (?, ?, ?, ?, ?)", + """ + INSERT OR REPLACE INTO stop_times + (trip_id, arrival_time, departure_time, stop_id, stop_sequence) + VALUES (?, ?, ?, ?, ?) + """, to_db, ) @@ -198,7 +251,15 @@ async def get_route_ids(stop_id): await _fetch_gtfs() conn, cursor = _get_database() cursor.execute( - "SELECT DISTINCT route_id FROM trips WHERE trip_id IN (SELECT trip_id FROM stop_times WHERE stop_id = ?)", + """ + SELECT DISTINCT route_id + FROM trips + WHERE trip_id IN ( + SELECT trip_id + FROM stop_times + WHERE stop_id = ? + ) + """, (stop_id,), ) route_ids = [row[0] for row in cursor.fetchall()] @@ -221,18 +282,38 @@ async def get_stop_times(stop_id, route_ids, amount, from_time): """ await _fetch_gtfs() conn, cursor = _get_database() - today = datetime.now().strftime("%Y-%m-%d") - weekday = datetime.strptime(today, "%Y-%m-%d").strftime("%A").lower() - if from_time: - cursor.execute( - f"SELECT route_id, trip_headsign, departure_time FROM stop_times JOIN trips ON stop_times.trip_id = trips.trip_id JOIN calendar ON trips.service_id = calendar.service_id WHERE stop_id = ? AND trips.route_id IN ({','.join(['?']*len(route_ids))}) AND calendar.{weekday} = '1' AND departure_time > ? LIMIT ?", - [stop_id, *route_ids, from_time.strftime("%H:%M:%S"), amount], - ) - else: + today = datetime.now().strftime("%Y%m%d") + weekday = datetime.strptime(today, "%Y%m%d").strftime("%A").lower() + stop_times = [] + delta_days = 0 + while len(stop_times) < amount: cursor.execute( - f"SELECT route_id, trip_headsign, departure_time FROM stop_times JOIN trips ON stop_times.trip_id = trips.trip_id JOIN calendar ON trips.service_id = calendar.service_id WHERE stop_id = ? AND trips.route_id IN ({','.join(['?']*len(route_ids))}) AND calendar.{weekday} = '1' LIMIT ?", - [stop_id, *route_ids, amount], + f""" + SELECT stop_times.trip_id, route_id, trip_headsign, departure_time, {delta_days} as delta_days + FROM stop_times + JOIN trips ON stop_times.trip_id = trips.trip_id + JOIN calendar ON trips.service_id = calendar.service_id + WHERE stop_id = ? + AND trips.route_id IN ({','.join(['?']*len(route_ids))}) + AND calendar.{weekday} = '1' + AND calendar.start_date < ? + AND departure_time > ? + LIMIT ? + """, + [stop_id, *route_ids, today, from_time.strftime("%H:%M:%S"), amount], ) - stop_times = cursor.fetchall() + stop_times += cursor.fetchall() + if len(stop_times) >= amount: + break + # If there are no more stop times for today, move to the next day + delta_days += 1 + if delta_days == 7: + _LOGGER.debug( + "Not enough departures found. Consider decreasing the amount of requested departures" + ) + break + next_day = datetime.strptime(today, "%Y%m%d") + timedelta(days=1) + today = next_day.strftime("%Y%m%d") + weekday = next_day.strftime("%A").lower() conn.close() - return stop_times + return stop_times[:amount] diff --git a/custom_components/nysse/sensor.py b/custom_components/nysse/sensor.py index 10cf1f2..38c4d92 100644 --- a/custom_components/nysse/sensor.py +++ b/custom_components/nysse/sensor.py @@ -14,16 +14,12 @@ import homeassistant.util.dt as dt_util from .const import ( - AIMED_ARRIVAL_TIME, AIMED_DEPARTURE_TIME, DEFAULT_ICON, DEFAULT_MAX, DEFAULT_TIMELIMIT, - DEPARTURE, DOMAIN, - EXPECTED_ARRIVAL_TIME, EXPECTED_DEPARTURE_TIME, - JOURNEY, PLATFORM_NAME, SERVICE_ALERTS_URL, STOP_URL, @@ -34,7 +30,6 @@ _LOGGER = logging.getLogger(__name__) SCAN_INTERVAL = timedelta(seconds=30) -PAGE_SIZE = 100 async def async_setup_entry( @@ -76,27 +71,18 @@ class NysseSensor(SensorEntity): def __init__(self, stop_code, maximum, timelimit, lines) -> None: """Initialize the sensor.""" - self._unique_id = PLATFORM_NAME + "_" + stop_code - self.stop_code = stop_code - self.max_items = int(maximum) - self.timelimit = int(timelimit) - self.lines = lines + self._stop_code = stop_code + self._max_items = int(maximum) + self._timelimit = int(timelimit) + self._lines = lines self._journeys = [] self._stops = [] self._all_data = [] - self._current_weekday_int = -1 self._last_update_time = None - async def fetch_stops(self, force=False): - """Fetch stops if not fetched already.""" - if len(self._stops) == 0 or force: - _LOGGER.debug("Fetching stops") - self._stops = await get_stops() - - def remove_unwanted_data(self, departures): - """Remove stale and unwanted data.""" + def _remove_unwanted_departures(self, departures): removed_departures_count = 0 # Remove unwanted departures based on departure time and line number @@ -106,8 +92,8 @@ def remove_unwanted_data(self, departures): ) if ( departure_local - < self._last_update_time + timedelta(minutes=self.timelimit) - or departure["route_id"] not in self.lines + < self._last_update_time + timedelta(minutes=self._timelimit) + or departure["route_id"] not in self._lines ): departures.remove(departure) removed_departures_count += 1 @@ -115,41 +101,39 @@ def remove_unwanted_data(self, departures): if removed_departures_count > 0: _LOGGER.debug( "%s: Removed %s stale or unwanted departures", - self.stop_code, + self._stop_code, removed_departures_count, ) - return departures[: self.max_items] + return departures[: self._max_items] - async def fetch_departures(self): - """Fetch live stop monitoring data.""" - url = STOP_URL.format(self.stop_code) + async def _fetch_departures(self): + url = STOP_URL.format(self._stop_code) _LOGGER.debug( "%s: Fectching departures from %s", - self.stop_code, + self._stop_code, url + "&indent=yes", ) data = await get(url) if not data: _LOGGER.warning( "%s: Nysse API error: failed to fetch realtime data: no data received from %s", - self.stop_code, + self._stop_code, url, ) return unformatted_departures = json.loads(data) - return self.format_departures(unformatted_departures) + return self._format_departures(unformatted_departures) - def format_departures(self, departures): - """Format live stop monitoring data.""" + def _format_departures(self, departures): try: - body = departures["body"][self.stop_code] + body = departures["body"][self._stop_code] formatted_data = [] for departure in body: try: formatted_departure = { "route_id": departure["lineRef"], - "trip_headsign": self.get_stop_name( + "trip_headsign": self._get_stop_name( departure["destinationShortName"] ), "departure_time": departure["call"][EXPECTED_DEPARTURE_TIME], @@ -164,7 +148,7 @@ def format_departures(self, departures): except KeyError as err: _LOGGER.info( "%s: Failed to process realtime departure: %s", - self.stop_code, + self._stop_code, err, ) continue @@ -172,106 +156,30 @@ def format_departures(self, departures): except KeyError as err: _LOGGER.info( "%s: Nysse API error: failed to process realtime data: %s", - self.stop_code, + self._stop_code, err, ) return [] - def format_journeys(self, journeys, weekday): - """Format static timetable data.""" - formatted_data = [] - - if weekday == self._current_weekday_int: - delta = timedelta(seconds=0) - elif weekday > self._current_weekday_int: - delta = timedelta(days=weekday - self._current_weekday_int) - else: - delta = timedelta(days=7 - self._current_weekday_int + weekday) - - try: - for journey in journeys["body"]: - for call in journey["calls"]: - try: - formatted_journey = { - "line": journey["lineUrl"].split("/")[7], - "stopCode": call["stopPoint"]["shortName"], - "destinationCode": journey["calls"][-1]["stopPoint"][ - "shortName" - ], - "departureTime": self.get_departure_time( - call, JOURNEY, delta - ), - "realtime": False, - } - if formatted_journey["departureTime"] is not None: - formatted_data.append(formatted_journey) - except KeyError as err: - _LOGGER.info( - "%s: Failed to process timetable departure: %s", - self.stop_code, - err, - ) - continue - except KeyError as err: - _LOGGER.info( - "%s: Nysse API error: failed to fetch timetable data: %s", - self.stop_code, - err, - ) - return formatted_data - - def get_departure_time( - self, item, item_type, delta=timedelta(seconds=0), time_type="" - ): - """Calculate departure time.""" - try: - if item_type == DEPARTURE: - if time_type != "": - parsed_time = parser.parse(item["call"][time_type]) - return parsed_time - try: - time_fields = [ - item["call"][EXPECTED_ARRIVAL_TIME], - item["call"][EXPECTED_DEPARTURE_TIME], - item["call"][AIMED_ARRIVAL_TIME], - item["call"][AIMED_DEPARTURE_TIME], - ] - for field in time_fields: - parsed_time = parser.parse(field) - return parsed_time - except (ValueError, KeyError): - pass - return None - except (ValueError, KeyError): - return None - async def async_update(self) -> None: """Fetch new state data for the sensor.""" self._last_update_time = dt_util.now() - self._current_weekday_int = self._last_update_time.weekday() try: - await self.fetch_stops() if len(self._stops) == 0: - return + _LOGGER.debug("Getting stops") + self._stops = await get_stops() - departures = await self.fetch_departures() - departures = self.remove_unwanted_data(departures) - if len(departures) < self.max_items: + departures = await self._fetch_departures() + departures = self._remove_unwanted_departures(departures) + if len(departures) < self._max_items: self._journeys = await get_stop_times( - self.stop_code, - self.lines, - self.max_items - len(departures), - self._last_update_time, + self._stop_code, + self._lines, + self._max_items - len(departures), + self._last_update_time + timedelta(minutes=self._timelimit), ) for journey in self._journeys[:]: - print( - journey["route_id"] - + " - " - + journey["trip_headsign"] - + " - " - + journey["departure_time"] - ) for departure in departures: departure_time = parser.parse(departure["aimed_departure_time"]) journey_time = parser.parse(journey["departure_time"]) @@ -283,47 +191,44 @@ async def async_update(self) -> None: else: self._journeys.clear() - self._all_data = self.data_to_display_format(departures + self._journeys) + self._all_data = self._data_to_display_format(departures + self._journeys) _LOGGER.debug( "%s: Got %s valid departures and %s valid journeys", - self.stop_code, + self._stop_code, len(departures), len(self._journeys), ) - _LOGGER.debug("%s: Data fetching complete", self.stop_code) except OSError as err: - _LOGGER.error("%s: Failed to update sensor: %s", self.stop_code, err) + _LOGGER.error("%s: Failed to update sensor: %s", self._stop_code, err) - def data_to_display_format(self, data): - """Format data to be displayed in sensor attributes.""" + def _data_to_display_format(self, data): formatted_data = [] for item in data: departure = { "destination": item["trip_headsign"], "line": item["route_id"], "departure": parser.parse(item["departure_time"]).strftime("%H:%M"), - "time_to_station": self.time_to_station(item), - "icon": self.get_line_icon(item["route_id"]), + "time_to_station": self._time_to_station(item), + "icon": self._get_line_icon(item["route_id"]), "realtime": item["realtime"] if "realtime" in item else False, } formatted_data.append(departure) return formatted_data - def get_line_icon(self, line_no): - """Get line icon based on operating vehicle type.""" + def _get_line_icon(self, line_no): if line_no in TRAM_LINES: return "mdi:tram" return "mdi:bus" - def time_to_station(self, item): - """Get time until departure.""" + def _time_to_station(self, item): departure_local = dt_util.as_local(parser.parse(item["departure_time"])) + if "delta_days" in item: + departure_local += timedelta(days=item["delta_days"]) next_departure_time = (departure_local - self._last_update_time).seconds return int(next_departure_time / 60) - def get_stop_name(self, stop_id): - """Get the name of the stop.""" + def _get_stop_name(self, stop_id): return next( (stop["stop_name"] for stop in self._stops if stop["stop_id"] == stop_id), "unknown stop", @@ -332,13 +237,13 @@ def get_stop_name(self, stop_id): @property def unique_id(self) -> str: """Unique id for the sensor.""" - return PLATFORM_NAME + "_" + self.stop_code + return PLATFORM_NAME + "_" + self._stop_code @property def name(self) -> str: """Return the name of the sensor.""" - stop_name = self.get_stop_name(self.stop_code) - return f"{stop_name} ({self.stop_code})" + stop_name = self._get_stop_name(self._stop_code) + return f"{stop_name} ({self._stop_code})" @property def icon(self) -> str: @@ -358,7 +263,7 @@ def extra_state_attributes(self): attributes = { "last_refresh": self._last_update_time, "departures": self._all_data, - "station_name": self.get_stop_name(self.stop_code), + "station_name": self._get_stop_name(self._stop_code), } return attributes @@ -372,20 +277,17 @@ def __init__(self) -> None: self._alerts = [] self._empty_response_counter = 0 - def timestamp_to_local(self, timestamp): - """Convert timestamp to local datetime.""" + def _timestamp_to_local(self, timestamp): utc = dt_util.utc_from_timestamp(int(str(timestamp)[:10])) return dt_util.as_local(utc) - def conditionally_clear_alerts(self): - """Clear alerts if none received in 20 tries.""" + def _conditionally_clear_alerts(self): # TODO: Individual alerts may never be removed if self._empty_response_counter >= 20: self._empty_response_counter = 0 self._alerts.clear() - async def fetch_service_alerts(self): - """Fetch service alerts.""" + async def _fetch_service_alerts(self): try: alerts = [] data = await get(SERVICE_ALERTS_URL) @@ -397,15 +299,15 @@ async def fetch_service_alerts(self): return json_data = json.loads(data) - self._last_update = self.timestamp_to_local( + self._last_update = self._timestamp_to_local( json_data["header"]["timestamp"] ) for item in json_data["entity"]: - start_time = self.timestamp_to_local( + start_time = self._timestamp_to_local( item["alert"]["active_period"][0]["start"] ) - end_time = self.timestamp_to_local( + end_time = self._timestamp_to_local( item["alert"]["active_period"][0]["end"] ) description = item["alert"]["description_text"]["translation"][0][ @@ -423,7 +325,7 @@ async def fetch_service_alerts(self): except KeyError: self._empty_response_counter += 1 - self.conditionally_clear_alerts() + self._conditionally_clear_alerts() return self._alerts except OSError as err: _LOGGER.error("Failed to fetch service alerts: %s", err) @@ -431,7 +333,7 @@ async def fetch_service_alerts(self): async def async_update(self) -> None: """Fetch new state data for the sensor.""" - self._alerts = await self.fetch_service_alerts() + self._alerts = await self._fetch_service_alerts() @property def unique_id(self) -> str: