Skip to content

Commit

Permalink
Use plain assertions where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Dec 12, 2023
1 parent 1974a50 commit 9e2529c
Show file tree
Hide file tree
Showing 6 changed files with 316 additions and 338 deletions.
27 changes: 12 additions & 15 deletions polymorphic/tests/admintestcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def admin_get_add(self, model, qs=""):
admin_instance = self.get_admin_instance(model)
request = self.create_admin_request("get", self.get_add_url(model) + qs)
response = admin_instance.add_view(request)
self.assertEqual(response.status_code, 200)
assert response.status_code == 200
return response

def admin_post_add(self, model, formdata, qs=""):
Expand All @@ -114,7 +114,7 @@ def admin_get_changelist(self, model):
admin_instance = self.get_admin_instance(model)
request = self.create_admin_request("get", self.get_changelist_url(model))
response = admin_instance.changelist_view(request)
self.assertEqual(response.status_code, 200)
assert response.status_code == 200
return response

def admin_get_change(self, model, object_id, query=None, **extra):
Expand All @@ -126,7 +126,7 @@ def admin_get_change(self, model, object_id, query=None, **extra):
"get", self.get_change_url(model, object_id), data=query, **extra
)
response = admin_instance.change_view(request, str(object_id))
self.assertEqual(response.status_code, 200)
assert response.status_code == 200
return response

def admin_post_change(self, model, object_id, formdata, **extra):
Expand All @@ -150,7 +150,7 @@ def admin_get_history(self, model, object_id, query=None, **extra):
"get", self.get_history_url(model, object_id), data=query, **extra
)
response = admin_instance.history_view(request, str(object_id))
self.assertEqual(response.status_code, 200)
assert response.status_code == 200
return response

def admin_get_delete(self, model, object_id, query=None, **extra):
Expand All @@ -162,7 +162,7 @@ def admin_get_delete(self, model, object_id, query=None, **extra):
"get", self.get_delete_url(model, object_id), data=query, **extra
)
response = admin_instance.delete_view(request, str(object_id))
self.assertEqual(response.status_code, 200)
assert response.status_code == 200
return response

def admin_post_delete(self, model, object_id, **extra):
Expand All @@ -175,7 +175,7 @@ def admin_post_delete(self, model, object_id, **extra):
admin_instance = self.get_admin_instance(model)
request = self.create_admin_request("post", self.get_delete_url(model, object_id), **extra)
response = admin_instance.delete_view(request, str(object_id))
self.assertEqual(response.status_code, 302, f"Form errors in calling {request.path}")
assert response.status_code == 302, f"Form errors in calling {request.path}"
return response

def create_admin_request(self, method, url, data=None, **extra):
Expand Down Expand Up @@ -209,7 +209,7 @@ def assertFormSuccess(self, request_url, response):
"""
Assert that the response was a redirect, not a form error.
"""
self.assertIn(response.status_code, [200, 302])
assert response.status_code in [200, 302]
if response.status_code != 302:
context_data = response.context_data
if "errors" in context_data:
Expand All @@ -219,12 +219,9 @@ def assertFormSuccess(self, request_url, response):
else:
raise KeyError("Unknown field for errors in the TemplateResponse!")

self.assertEqual(
response.status_code,
302,
"Form errors in calling {}:\n{}".format(request_url, errors.as_text()),
assert response.status_code == 302, "Form errors in calling {}:\n{}".format(
request_url, errors.as_text()
)
self.assertTrue(
"/login/?next=" not in response["Location"],
f"Received login response for {request_url}",
)
assert (
"/login/?next=" not in response["Location"]
), f"Received login response for {request_url}"
29 changes: 15 additions & 14 deletions polymorphic/tests/test_admin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from django.contrib import admin
from django.contrib.contenttypes.models import ContentType
from django.utils.html import escape
Expand Down Expand Up @@ -53,9 +54,9 @@ class Model2ChildAdmin(PolymorphicChildModelAdmin):
)

d_obj = Model2A.objects.all()[0]
self.assertEqual(d_obj.__class__, Model2D)
self.assertEqual(d_obj.field1, "A")
self.assertEqual(d_obj.field2, "B")
assert d_obj.__class__ == Model2D
assert d_obj.field1 == "A"
assert d_obj.field2 == "B"

# -- list page
self.admin_get_changelist(Model2A) # asserts 200
Expand All @@ -70,18 +71,18 @@ class Model2ChildAdmin(PolymorphicChildModelAdmin):
)

d_obj.refresh_from_db()
self.assertEqual(d_obj.field1, "A2")
self.assertEqual(d_obj.field2, "B2")
self.assertEqual(d_obj.field3, "C2")
self.assertEqual(d_obj.field4, "D2")
assert d_obj.field1 == "A2"
assert d_obj.field2 == "B2"
assert d_obj.field3 == "C2"
assert d_obj.field4 == "D2"

# -- history
self.admin_get_history(Model2A, d_obj.pk)

# -- delete
self.admin_get_delete(Model2A, d_obj.pk)
self.admin_post_delete(Model2A, d_obj.pk)
self.assertRaises(Model2A.DoesNotExist, lambda: d_obj.refresh_from_db())
pytest.raises(Model2A.DoesNotExist, (lambda: d_obj.refresh_from_db()))

def test_admin_inlines(self):
"""
Expand All @@ -103,7 +104,7 @@ class InlineParentAdmin(PolymorphicInlineSupportMixin, admin.ModelAdmin):
inlines = (Inline,)

parent = InlineParent.objects.create(title="FOO")
self.assertEqual(parent.inline_children.count(), 0)
assert parent.inline_children.count() == 0

# -- get edit page
response = self.admin_get_change(InlineParent, parent.pk)
Expand Down Expand Up @@ -133,9 +134,9 @@ class InlineParentAdmin(PolymorphicInlineSupportMixin, admin.ModelAdmin):
)

parent.refresh_from_db()
self.assertEqual(parent.title, "FOO2")
self.assertEqual(parent.inline_children.count(), 1)
assert parent.title == "FOO2"
assert parent.inline_children.count() == 1
child = parent.inline_children.all()[0]
self.assertEqual(child.__class__, InlineModelB)
self.assertEqual(child.field1, "A2")
self.assertEqual(child.field2, "B2")
assert child.__class__ == InlineModelB
assert child.field1 == "A2"
assert child.field2 == "B2"
8 changes: 4 additions & 4 deletions polymorphic/tests/test_contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ def test_contrib_guardian(self):
# Regular Django inheritance should return the child model content type.
obj = PlainC()
ctype = get_polymorphic_base_content_type(obj)
self.assertEqual(ctype.name, "plain c")
assert ctype.name == "plain c"

ctype = get_polymorphic_base_content_type(PlainC)
self.assertEqual(ctype.name, "plain c")
assert ctype.name == "plain c"

# Polymorphic inheritance should return the parent model content type.
obj = Model2D()
ctype = get_polymorphic_base_content_type(obj)
self.assertEqual(ctype.name, "model2a")
assert ctype.name == "model2a"

ctype = get_polymorphic_base_content_type(Model2D)
self.assertEqual(ctype.name, "model2a")
assert ctype.name == "model2a"
8 changes: 4 additions & 4 deletions polymorphic/tests/test_multidb.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def func():
entry = BlogEntry.objects.db_manager("secondary").create(blog=blog, text="Text")
ContentType.objects.clear_cache()
entry = BlogEntry.objects.db_manager("secondary").get(pk=entry.id)
self.assertEqual(blog, entry.blog)
assert blog == entry.blog

# Ensure no queries are made using the default database.
self.assertNumQueries(0, func)
Expand All @@ -89,7 +89,7 @@ def func():
entry = BlogEntry.objects.db_manager("secondary").create(blog=blog, text="Text")
ContentType.objects.clear_cache()
blog = BlogA.objects.db_manager("secondary").get(pk=blog.id)
self.assertEqual(entry, blog.blogentry_set.using("secondary").get())
assert entry == blog.blogentry_set.using("secondary").get()

# Ensure no queries are made using the default database.
self.assertNumQueries(0, func)
Expand All @@ -102,7 +102,7 @@ def func():
)
ContentType.objects.clear_cache()
m2a = Model2A.objects.db_manager("secondary").get(pk=m2a.id)
self.assertEqual(one2one, m2a.one2onerelatingmodel)
assert one2one == m2a.one2onerelatingmodel

# Ensure no queries are made using the default database.
self.assertNumQueries(0, func)
Expand All @@ -114,7 +114,7 @@ def func():
rm.many2many.add(m2a)
ContentType.objects.clear_cache()
m2a = Model2A.objects.db_manager("secondary").get(pk=m2a.id)
self.assertEqual(rm, m2a.relatingmodel_set.using("secondary").get())
assert rm == m2a.relatingmodel_set.using("secondary").get()

# Ensure no queries are made using the default database.
self.assertNumQueries(0, func)
Loading

0 comments on commit 9e2529c

Please sign in to comment.