Skip to content

Commit

Permalink
Merge pull request #64 from City-of-Helsinki/MAAS-101-fix-api
Browse files Browse the repository at this point in the history
MAAS-101 | Fix StopSerializer.departures not using distinct StopTimes
  • Loading branch information
Pekka Lampila authored Jun 1, 2021
2 parents 7c184d0 + 335d1c2 commit 4e195b3
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 37 deletions.
56 changes: 27 additions & 29 deletions gtfs/api/stops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ class CoordinateSerializer(serializers.Serializer):
longitude = serializers.FloatField(source="point.x", read_only=True)


class DepartureSerializer(serializers.ModelSerializer):
id = serializers.UUIDField(source="api_id")
class StopTimeSerializer(serializers.ModelSerializer):
id = serializers.SerializerMethodField()
short_name = serializers.CharField(source="trip.short_name")
arrival_time = serializers.SerializerMethodField()
departure_time = serializers.SerializerMethodField()
direction_id = serializers.IntegerField(source="trip.direction_id")
departure_headsign = serializers.CharField(source="trip.headsign")
stop_headsign = serializers.SerializerMethodField()
stop_sequence = serializers.SerializerMethodField()
stop_headsign = serializers.CharField()
stop_sequence = serializers.IntegerField()
wheelchair_accessible = serializers.IntegerField(
source="trip.wheelchair_accessible"
)
Expand All @@ -33,10 +33,10 @@ class DepartureSerializer(serializers.ModelSerializer):
source="trip.route", slug_field="api_id", read_only=True
)
block_id = serializers.CharField(source="trip.block_id")
timepoint = serializers.SerializerMethodField()
timepoint = serializers.IntegerField()

class Meta:
model = Departure
model = StopTime
fields = (
"id",
"short_name",
Expand All @@ -59,25 +59,20 @@ def get_fields(self):
del fields["route_id"]
return fields

@extend_schema_field(OpenApiTypes.UUID)
def get_id(self, obj):
departure = obj.trip.dates_departure[0]
return departure.api_id

@extend_schema_field(OpenApiTypes.DATETIME)
def get_arrival_time(self, obj):
return obj.trip.stops_stop_times[0].get_arrival_time_datetime(obj)
departure = obj.trip.dates_departure[0]
return obj.get_arrival_time_datetime(departure)

@extend_schema_field(OpenApiTypes.DATETIME)
def get_departure_time(self, obj):
return obj.trip.stops_stop_times[0].get_departure_time_datetime(obj)

@extend_schema_field(OpenApiTypes.STR)
def get_stop_headsign(self, obj):
return obj.trip.stops_stop_times[0].stop_headsign

@extend_schema_field(OpenApiTypes.INT)
def get_stop_sequence(self, obj):
return obj.trip.stops_stop_times[0].stop_sequence

@extend_schema_field(OpenApiTypes.INT)
def get_timepoint(self, obj):
return obj.trip.stops_stop_times[0].timepoint
departure = obj.trip.dates_departure[0]
return obj.get_departure_time_datetime(departure)


class StopSerializer(serializers.ModelSerializer):
Expand Down Expand Up @@ -109,30 +104,33 @@ def get_coordinates(self, obj):
if obj.point:
return CoordinateSerializer(obj).data

@extend_schema_field(DepartureSerializer(many=True))
@extend_schema_field(StopTimeSerializer(many=True))
def get_departures(self, obj):
if "date" not in self.context:
return None

queryset = (
Departure.objects.filter(
trip__stop_times__stop=obj, date=self.context["date"]
StopTime.objects.filter(
stop=obj, trip__departures__date=self.context["date"]
)
.select_related("trip", "trip__route", "trip__route__agency")
.prefetch_related(
Prefetch(
"trip__stop_times",
queryset=StopTime.objects.filter(stop=obj),
to_attr="stops_stop_times",
),
"trip__departures",
queryset=Departure.objects.filter(
date=self.context["date"], trip__stop_times__stop=obj
),
to_attr="dates_departure",
)
)
.order_by("trip__stop_times__departure_time")
.order_by("departure_time")
)
if "direction_id" in self.context:
queryset = queryset.filter(trip__direction_id=self.context["direction_id"])
if "route_id" in self.context:
queryset = queryset.filter(trip__route_id=self.context["route_id"])

return DepartureSerializer(queryset, many=True, context=self.context).data
return StopTimeSerializer(queryset, many=True, context=self.context).data


class RadiusToLocationFilter(DistanceToPointFilter):
Expand Down
21 changes: 15 additions & 6 deletions gtfs/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,7 @@ def api_id_generator():


@pytest.fixture
def route_with_departures(maas_operator, api_id_generator):
"""
A route with
* 2 trips each having 2 stops and 2 stop times
* 3 departures (first trip having two departures on separate days)
"""
def route_for_maas_operator(maas_operator, api_id_generator):
feed = get_feed_for_maas_operator(maas_operator, True)

agency = baker.make(
Expand All @@ -63,6 +58,20 @@ def route_with_departures(maas_operator, api_id_generator):
url="url of test route ",
capacity_sales=Route.CapacitySales.DISABLED,
)

return route


@pytest.fixture
def route_with_departures(api_id_generator, route_for_maas_operator):
"""
A route with
* 2 trips each having 2 stops and 2 stop times
* 3 departures (first trip having two departures on separate days)
"""
route = route_for_maas_operator
feed = route_for_maas_operator.feed

trips = baker.make(
Trip,
route=route,
Expand Down
33 changes: 33 additions & 0 deletions gtfs/tests/snapshots/snap_test_stops_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,36 @@
]

snapshots["test_stops_departures[filters6] 1"] = []

snapshots["test_stops_departures__stop_appears_multiple_times_in_trip 1"] = [
{
"arrival_time": "2021-02-18T07:00:00Z",
"bikes_allowed": 0,
"block_id": "block_id of test trip 1",
"departure_headsign": "headsign of test trip ",
"departure_time": "2021-02-18T08:00:00Z",
"direction_id": 0,
"id": "00000000-0000-0000-0000-000000000004",
"route_id": "00000000-0000-0000-0000-000000000000",
"short_name": "short_name of test trip ",
"stop_headsign": "stop_headsign of test stop time ",
"stop_sequence": 2,
"timepoint": 1,
"wheelchair_accessible": 0,
},
{
"arrival_time": "2021-02-18T09:00:00Z",
"bikes_allowed": 0,
"block_id": "block_id of test trip 1",
"departure_headsign": "headsign of test trip ",
"departure_time": "2021-02-18T10:00:00Z",
"direction_id": 0,
"id": "00000000-0000-0000-0000-000000000004",
"route_id": "00000000-0000-0000-0000-000000000000",
"short_name": "short_name of test trip ",
"stop_headsign": "stop_headsign of test stop time ",
"stop_sequence": 4,
"timepoint": 1,
"wheelchair_accessible": 0,
},
]
86 changes: 84 additions & 2 deletions gtfs/tests/test_stops_api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import itertools
import json
from datetime import date, timedelta
from uuid import UUID

import pytest
from django.contrib.gis.geos import Point
from model_bakery import baker
from model_bakery import baker, seq

from gtfs.models import Stop
from gtfs.models import Departure, Stop, StopTime, Trip
from gtfs.tests.utils import clean_stops_for_snapshot, get_feed_for_maas_operator

ENDPOINT = "/v1/stops/"
Expand Down Expand Up @@ -85,3 +86,84 @@ def test_stops_departures(maas_api_client, snapshot, filters, route_with_departu
snapshot.assert_match(content["departures"])
else:
assert "departures" not in content


@pytest.mark.django_db
def test_stops_departures__stop_appears_multiple_times_in_trip(
maas_api_client,
route_for_maas_operator,
api_id_generator,
snapshot,
django_assert_max_num_queries,
):
"""Same stop appears twice (or more) in a trip.
There is more than one StopTime with the same Stop and Trip, but with different
arrival and departure times. Stop serializer should return departures with separate
stop times.
"""
route = route_for_maas_operator
feed = route_for_maas_operator.feed

trip = baker.make(
Trip,
route=route,
feed=feed,
source_id="source_id of test trip ",
short_name="short_name of test trip ",
headsign="headsign of test trip ",
direction_id=0,
block_id=seq("block_id of test trip "),
)
stops = baker.make(
Stop,
feed=feed,
api_id=api_id_generator,
name="stop ",
tts_name="tts_name of stop ",
code=seq("code of stop"),
desc="desc of test stop ",
_quantity=3,
)
baker.make(
StopTime,
trip=trip,
stop=iter([stops[0], stops[1], stops[2], stops[1]]),
feed=feed,
# -2 hours in Helsinki time
arrival_time=iter(
[
timedelta(hours=8),
timedelta(hours=9),
timedelta(hours=10),
timedelta(hours=11),
]
),
# -2 hours in Helsinki time
departure_time=iter(
[
timedelta(hours=9),
timedelta(hours=10),
timedelta(hours=11),
timedelta(hours=12),
]
),
stop_headsign="stop_headsign of test stop time ",
stop_sequence=seq(0),
timepoint=StopTime.Timepoint.EXACT,
_quantity=4,
)
baker.make(
Departure,
api_id=api_id_generator,
trip=trip,
date=date(2021, 2, 18),
)

with django_assert_max_num_queries(5):
response = maas_api_client.get(
ENDPOINT + f"{stops[1].api_id}/", {"date": "2021-02-18"}
)
content = response.json()

snapshot.assert_match(content["departures"])

0 comments on commit 4e195b3

Please sign in to comment.