diff --git a/polymorphic/tests.py b/polymorphic/tests.py index 9c2cb3d4..114db05b 100644 --- a/polymorphic/tests.py +++ b/polymorphic/tests.py @@ -167,6 +167,22 @@ class ModelUnderRelChild(PolymorphicModel): _private2 = models.CharField(max_length=10) +class Participant(PolymorphicModel): + pass + + +class UserProfile(Participant): + name = models.CharField(max_length=100) + + def __str__(self): + return self.name + + +class Team(models.Model): + team_name = models.CharField(max_length=100) + user_profiles = models.ManyToManyField(UserProfile, related_name='user_teams') + + class MyManagerQuerySet(PolymorphicQuerySet): def my_queryset_foo(self): @@ -1325,6 +1341,44 @@ def func(): # Ensure no queries are made using the default database. self.assertNumQueries(0, func) + def test_unknown_issue(self): + user_a = UserProfile.objects.create(name='a') + user_b = UserProfile.objects.create(name='b') + user_c = UserProfile.objects.create(name='c') + + team1 = Team.objects.create(team_name='team1') + team1.user_profiles.add(user_a, user_b, user_c) + team1.save() + + team2 = Team.objects.create(team_name='team2') + team2.user_profiles.add(user_c) + team2.save() + + # without prefetch_related, the test passes + my_teams = Team.objects.filter(user_profiles=user_c).prefetch_related('user_profiles').distinct() + + print(my_teams[0].user_profiles.all()) + print(my_teams[1].user_profiles.all()) + self.assertEqual(len(my_teams[0].user_profiles.all()), 3) + self.assertEqual(len(my_teams[1].user_profiles.all()), 1) + + print(my_teams[0].user_profiles.all()) + print(my_teams[1].user_profiles.all()) + self.assertEqual(len(my_teams[0].user_profiles.all()), 3) + self.assertEqual(len(my_teams[1].user_profiles.all()), 1) + + # without this "for" loop, the test passes + for _ in my_teams: + pass + + print(my_teams[0].user_profiles.all()) + print(my_teams[1].user_profiles.all()) + # This time, test fails. + # with sqlite: 4 != 3 + # with postgresql: 2 != 3 + self.assertEqual(len(my_teams[0].user_profiles.all()), 3) + self.assertEqual(len(my_teams[1].user_profiles.all()), 1) + def qrepr(data): """