Skip to content

Commit

Permalink
ScopingMixin fix, project should be done. Now to make good examples a…
Browse files Browse the repository at this point in the history
…nd update the docs
  • Loading branch information
robdox committed Jul 11, 2018
1 parent 99a3807 commit 06e5656
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 107 deletions.
64 changes: 24 additions & 40 deletions easy_scoping/DjangoEasyScoping/ScopingMixin.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,24 @@
from django.db import models
from django.db.models import Count, Case, When
class ScopingQuerySet(models.QuerySet):
def __getattr__(self, name):
if name in self.model.scopes():
def scoped_query(*args, **kwargs):
return self.model.scopes()[name](self, *args, **kwargs)
return scoped_query
elif name not in self.model.aggs():
raise AttributeError('Queryset for %s has no attribute %s' %(self.model, name))

if name in self.model.aggs():
def aggregate_query(*args, **kwargs):
return self.model.aggs()[name](self, *args, **kwargs)
return aggregate_query
elif name not in self.model.scopes():
raise AttributeError('Aggregate for %s has no attribute %s' %(self.model, name))

def a(self, *args, **kwargs):
return self.all()

def f(self, *args, **kwargs):
return self.filter(*args, **kwargs)

def e(self, *args, **kwargs):
return self.exclude(*args, **kwargs)

def g(self, operation, *args, **kwargs):
if operation.lower() == 'count':
return self.aggregate(ret=Count(Case(When(then=1, **kwargs))))['ret']
def __getattr__(self, attr):
for plugin in ['scopes', 'aggregates']:
if attr in getattr(self.model, '__%s__'%plugin):
def scoped_query(*args, **kwargs):
return getattr(self.model, '__%s__'%plugin)[attr](self, *args, **kwargs)
return scoped_query
raise AttributeError('Queryset for %s has no attribute %s' %(self.model, attr))

class ScopingMixin(object):

@classmethod
def a(cls):
return cls.objects.all()
def check_names(cls, name):
if name in cls.scopes():
print('CATCH THIS STATEMENT', name)
raise AttributeError('%s already has a scope named %s' %(cls, name))

if name in cls.aggregates():
print('CATCH THIS STATEMENT', name)
raise AttributeError('%s already has an aggregate named %s' %(cls, name))

@classmethod
def scopes(cls):
Expand All @@ -42,11 +27,10 @@ def scopes(cls):
return cls.__scopes__

@classmethod
def scope(cls, name, func):
def register_scope(cls, name, func):
from types import MethodType
if name in cls.scopes():
name = '_%s'%(name)

cls.check_names(name)
cls.__scopes__[name] = func

def scoped_query_classmethod(klss, *args, **kwargs):
Expand All @@ -60,18 +44,17 @@ def instance_in_scope(self, *args, **kwargs):
setattr(cls, 'in_scope_%s'%name, MethodType(instance_in_scope, cls))

@classmethod
def aggs(cls):
if not getattr(cls, '__aggs__', None):
setattr(cls, '__aggs__', dict())
return cls.__aggs__
def aggregates(cls):
if not getattr(cls, '__aggregates__', None):
setattr(cls, '__aggregates__', dict())
return cls.__aggregates__

@classmethod
def register_aggregate(cls, name, func):
from types import MethodType
if name in cls.aggs():
name = '_%s'%(name)

cls.__aggs__[name] = func
cls.check_names(name)
cls.__aggregates__[name] = func

def aggregate_classmethod(klss, *args, **kwargs):
return getattr(klss.a(), name)(*args, **kwargs)
Expand All @@ -82,3 +65,4 @@ def instance_in_agg(self, *args, **kwargs):
return bool(func(self.a(), *args, **kwargs).g(pk=self.pk))

setattr(cls, 'in_agg_%s'%name, MethodType(instance_in_agg, cls))

33 changes: 15 additions & 18 deletions easy_scoping/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,66 +14,63 @@ def test_db_data_loaded(self):
obj = Widget.objects.all()
self.assertEqual(obj.count(), 2500)

obj = obj.a()
self.assertEqual(obj.count(), 2500)

def test_num_blue(self):
obj1 = Widget.a().aggregate(ret=Count(Case(When(color='blue',
obj1 = Widget.objects.all().aggregate(ret=Count(Case(When(color='blue',
then=1))))['ret']
obj2 = Widget.a().num_blue()
obj2 = Widget.objects.all().num_blue()

self.assertEqual(obj1, obj2)

def test_num_blue_filtered(self):
obj1 = Widget.a().blue().not_small() \
obj1 = Widget.objects.all().blue().not_small() \
.aggregate(ret=Count(Case(When(color='blue', then=1))))['ret']
obj2 = Widget.a().blue().not_small().num_blue()
obj2 = Widget.objects.all().blue().not_small().num_blue()

self.assertEqual(obj1, obj2)

def test_no_blue_filtered(self):
obj1 = Widget.a().not_blue() \
obj1 = Widget.objects.all().not_blue() \
.aggregate(ret=Count(Case(When(color='blue', then=1))))['ret']

obj2 = Widget.a().not_blue().num_blue()
obj2 = Widget.objects.all().not_blue().num_blue()

self.assertEqual(obj1, obj2)

def test_multiple_aggs(self):
obj1 = Widget.a().basic_query_widget() \
obj1 = Widget.objects.all().basic_query_widget() \
.aggregate(ret=Count(Case(When(color='blue',
size='small',
then=1))))['ret']

obj2 = Widget.a().basic_query_widget().num_blue_small()
obj2 = Widget.objects.all().basic_query_widget().num_blue_small()

self.assertEqual(obj1, obj2)

obj3 = Widget.a().not_circle().before_y2k() \
obj3 = Widget.objects.all().not_circle().before_y2k() \
.aggregate(ret=Count(Case(When(color='blue',
size='small',
then=1))))['ret']

obj4 = Widget.a().not_circle().before_y2k().num_blue_small()
obj4 = Widget.objects.all().not_circle().before_y2k().num_blue_small()

self.assertEqual(obj3, obj4)

def test_passing_kwags(self):
obj1 = Widget.a().basic_query_widget() \
obj1 = Widget.objects.all().basic_query_widget() \
.aggregate(ret=Count(Case(When(color='blue',
size='small',
then=1))))['ret']

obj2 = Widget.a().basic_query_widget().num_kwargs(color='blue',
size='small')
obj2 = Widget.objects.all().basic_query_widget().num_kwargs(color='blue',
size='small')

self.assertEqual(obj1, obj2)

obj3 = Widget.a().not_circle().before_y2k() \
obj3 = Widget.objects.all().not_circle().before_y2k() \
.aggregate(ret=Count(Case(When(color='blue',
size='small',
then=1))))['ret']

obj4 = Widget.a().not_circle().before_y2k().num_blue_small()
obj4 = Widget.objects.all().not_circle().before_y2k().num_blue_small()

self.assertEqual(obj3, obj4)
39 changes: 18 additions & 21 deletions easy_scoping/tests/test_scoping.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,18 @@ def test_db_data_loaded(self):
obj = Widget.objects.all()
self.assertEqual(obj.count(), 336)

obj = obj.a()
self.assertEqual(obj.count(), 336)

def test_passing_many_kwargs(self):
obj1 = Widget.objects.filter(color='blue',
size='small')
obj2 = Widget.a().take_kwargs(color='blue', size='small')
obj2 = Widget.objects.all().take_kwargs(color='blue', size='small')

self.assertQuerysetEqual(obj1,
obj2,
transform=lambda x: x,
ordered=False)
self.assertEqual(obj1.count(), obj2.count())

obj3 = Widget.a().take_kwargs(size='small', color='blue')
obj3 = Widget.objects.all().take_kwargs(size='small', color='blue')

self.assertQuerysetEqual(obj1,
obj3,
Expand All @@ -37,7 +34,7 @@ def test_passing_many_kwargs(self):

def test_passing_kwargs(self):
obj1 = Widget.objects.filter(color='blue')
obj2 = Widget.a().take_kwargs(color='blue')
obj2 = Widget.objects.all().take_kwargs(color='blue')

self.assertQuerysetEqual(obj1,
obj2,
Expand All @@ -48,7 +45,7 @@ def test_passing_kwargs(self):
def test_passing_many_args(self):
obj1 = Widget.objects.filter(color='blue',
size='small')
obj2 = Widget.a().take_more_args('blue', 'small')
obj2 = Widget.objects.all().take_more_args('blue', 'small')

self.assertQuerysetEqual(obj1,
obj2,
Expand All @@ -58,7 +55,7 @@ def test_passing_many_args(self):

def test_passing_args(self):
obj1 = Widget.objects.filter(color='blue')
obj2 = Widget.a().take_args('blue')
obj2 = Widget.objects.all().take_args('blue')

self.assertQuerysetEqual(obj1,
obj2,
Expand All @@ -68,10 +65,10 @@ def test_passing_args(self):

def test_no_scope_registered(self):
with self.assertRaises(AttributeError):
Widget.a().not_a_scope()
Widget.objects.all().not_a_scope()

def test_redundant_chain(self):
obj1 = Widget.a().basic_query_widget()
obj1 = Widget.objects.all().basic_query_widget()
obj2 = obj1.basic_query_widget()

self.assertQuerysetEqual(obj1,
Expand All @@ -81,7 +78,7 @@ def test_redundant_chain(self):
self.assertEqual(obj1.count(), obj2.count())
self.assertEqual(obj1.get(), obj2.get())

not_obj1 = Widget.a().not_basic_query_widget()
not_obj1 = Widget.objects.all().not_basic_query_widget()
not_obj2 = not_obj1.not_basic_query_widget()

self.assertQuerysetEqual(not_obj1,
Expand All @@ -94,7 +91,7 @@ def test_query_widget(self):
obj1 = Widget.objects.filter(color='blue',
size='small',
shape='circle')
obj2 = Widget.a().basic_query_widget()
obj2 = Widget.objects.all().basic_query_widget()

self.assertQuerysetEqual(obj1,
obj2,
Expand All @@ -106,7 +103,7 @@ def test_query_widget(self):
not_obj1 = Widget.objects.exclude(color='blue',
size='small',
shape='circle')
not_obj2 = Widget.a().not_basic_query_widget()
not_obj2 = Widget.objects.all().not_basic_query_widget()

self.assertQuerysetEqual(not_obj1,
not_obj2,
Expand All @@ -118,7 +115,7 @@ def test_query_chaining(self):
obj1 = Widget.objects.filter(color='blue') \
.filter(size='small') \
.filter(shape='circle')
obj2 = Widget.a().blue().small().circle()
obj2 = Widget.objects.all().blue().small().circle()

self.assertQuerysetEqual(obj1,
obj2,
Expand All @@ -130,7 +127,7 @@ def test_query_chaining(self):
not_obj1 = Widget.objects.exclude(color='blue') \
.exclude(size='small') \
.exclude(shape='circle')
not_obj2 = Widget.a().not_blue().not_small().not_circle()
not_obj2 = Widget.objects.all().not_blue().not_small().not_circle()

self.assertQuerysetEqual(not_obj1,
not_obj2,
Expand All @@ -140,7 +137,7 @@ def test_query_chaining(self):

def test_query_blue(self):
obj1 = Widget.objects.filter(color='blue')
obj2 = Widget.a().blue()
obj2 = Widget.objects.all().blue()

self.assertEqual(obj1.count(), obj2.count())
self.assertQuerysetEqual(obj1,
Expand All @@ -149,7 +146,7 @@ def test_query_blue(self):
ordered=False)

not_obj1 = Widget.objects.exclude(color='blue')
not_obj2 = Widget.a().not_blue()
not_obj2 = Widget.objects.all().not_blue()

self.assertEqual(not_obj1.count(), not_obj2.count())
self.assertQuerysetEqual(not_obj1,
Expand All @@ -161,7 +158,7 @@ def test_before_y2k(self):
import datetime

obj1 = Widget.objects.filter(used_on__lte=datetime.date(2000, 1, 1))
obj2 = Widget.a().before_y2k()
obj2 = Widget.objects.all().before_y2k()

self.assertQuerysetEqual(obj1,
obj2,
Expand All @@ -170,7 +167,7 @@ def test_before_y2k(self):
self.assertEqual(obj1.count(), obj2.count())

not_obj1 = Widget.objects.exclude(used_on__lte=datetime.date(2000, 1, 1))
not_obj2 = Widget.a().not_before_y2k()
not_obj2 = Widget.objects.all().not_before_y2k()

self.assertQuerysetEqual(not_obj1,
not_obj2,
Expand All @@ -182,7 +179,7 @@ def test_after_y2k(self):
import datetime

obj1 = Widget.objects.filter(used_on__gte=datetime.date(2000, 1, 1))
obj2 = Widget.a().after_y2k()
obj2 = Widget.objects.all().after_y2k()

self.assertQuerysetEqual(obj1,
obj2,
Expand All @@ -191,7 +188,7 @@ def test_after_y2k(self):
self.assertEqual(obj1.count(), obj2.count())

not_obj1 = Widget.objects.exclude(used_on__gte=datetime.date(2000, 1, 1))
not_obj2 = Widget.a().not_after_y2k()
not_obj2 = Widget.objects.all().not_after_y2k()

self.assertQuerysetEqual(not_obj1,
not_obj2,
Expand Down
Loading

0 comments on commit 06e5656

Please sign in to comment.