Skip to content
This repository has been archived by the owner on May 14, 2024. It is now read-only.

Added method QueryBuilder().remove_where(column) #311

Open
wants to merge 2 commits into
base: 0.9
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions orator/orm/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,9 +827,7 @@ def _add_has_where(self, has_query, relation, operator, count, boolean):
if isinstance(count, basestring) and count.isdigit():
count = QueryExpression(count)

return self.where(
QueryExpression("(%s)" % has_query.to_sql()), operator, count, boolean
)
return self._query.where_expression(has_query._query, operator, count, boolean)

def _merge_model_defined_relation_wheres_to_has_query(self, has_query, relation):
"""
Expand All @@ -843,9 +841,7 @@ def _merge_model_defined_relation_wheres_to_has_query(self, has_query, relation)
"""
relation_query = relation.get_base_query()

has_query.merge_wheres(relation_query.wheres, relation_query.get_bindings())

self._query.add_binding(has_query.get_query().get_bindings(), "where")
has_query.merge_wheres(relation_query.wheres)

def _get_has_relation_query(self, relation):
"""
Expand Down
157 changes: 127 additions & 30 deletions orator/query/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def __init__(self, connection, grammar, processor):
for type in ["select", "join", "where", "having", "order"]:
self._bindings[type] = []

self._settable_bindings = ["select", "from", "join", "having", "order"]

self.aggregate_ = None
self.columns = []
self.distinct_ = False
Expand Down Expand Up @@ -403,18 +405,18 @@ def where(self, column, operator=Null(), value=None, boolean="and"):

type = "basic"

self.wheres.append(
{
"type": type,
"column": column,
"operator": operator,
"value": value,
"boolean": boolean,
}
)
where = {
"type": type,
"column": column,
"operator": operator,
"value": value,
"boolean": boolean,
}

if not isinstance(value, QueryExpression):
self.add_binding(value, "where")
where["bindings"] = value

self.wheres.append(where)

return self

Expand All @@ -429,9 +431,9 @@ def _invalid_operator_and_value(self, operator, value):
def where_raw(self, sql, bindings=None, boolean="and"):
type = "raw"

self.wheres.append({"type": type, "sql": sql, "boolean": boolean})

self.add_binding(bindings, "where")
self.wheres.append(
{"type": type, "sql": sql, "boolean": boolean, "bindings": bindings}
)

return self

Expand All @@ -442,11 +444,15 @@ def where_between(self, column, values, boolean="and", negate=False):
type = "between"

self.wheres.append(
{"column": column, "type": type, "boolean": boolean, "not": negate}
{
"column": column,
"type": type,
"boolean": boolean,
"not": negate,
"bindings": values,
}
)

self.add_binding(values, "where")

return self

def or_where_between(self, column, values):
Expand Down Expand Up @@ -500,6 +506,46 @@ def _where_sub(self, column, operator, query, boolean):

return self

def where_expression(
self, left_expression, operator, right_expression, boolean="and"
):
type = "expression"

bindings = []

if isinstance(left_expression, QueryBuilder):
self.merge_bindings(left_expression)
bindings += left_expression.get_bindings()
left_expression = QueryExpression("(%s)" % left_expression.to_sql())
elif not isinstance(left_expression, QueryExpression):
if not isinstance(left_expression, list):
bindings.append(left_expression)
else:
bindings += left_expression

if isinstance(right_expression, QueryBuilder):
self.merge_bindings(right_expression)
bindings += right_expression.get_bindings()
right_expression = QueryExpression("(%s)" % right_expression.to_sql())
elif not isinstance(right_expression, QueryExpression):
if not isinstance(right_expression, list):
bindings.append(right_expression)
else:
bindings += right_expression

self.wheres.append(
{
"type": type,
"lhs": left_expression,
"operator": operator,
"rhs": right_expression,
"boolean": boolean,
"bindings": bindings,
}
)

return self

def where_exists(self, query, boolean="and", negate=False):
"""
Add an exists clause to the query.
Expand Down Expand Up @@ -574,11 +620,15 @@ def where_in(self, column, values, boolean="and", negate=False):
values = values.all()

self.wheres.append(
{"type": type, "column": column, "values": values, "boolean": boolean}
{
"type": type,
"column": column,
"values": values,
"boolean": boolean,
"bindings": values,
}
)

self.add_binding(values, "where")

return self

def or_where_in(self, column, values):
Expand Down Expand Up @@ -661,11 +711,10 @@ def _add_date_based_where(self, type, column, operator, value, boolean="and"):
"boolean": boolean,
"operator": operator,
"value": value,
"bindings": value,
}
)

self.add_binding(value, "where")

def dynamic_where(self, method):
finder = method[6:]

Expand All @@ -691,6 +740,38 @@ def dynamic_where(*parameters):
def _add_dynamic(self, segment, connector, parameters, index):
self.where(segment, "=", parameters[index], connector)

def remove_where(self, column, operator=None):
"""
Remove where clauses referencing a specific column

:param column: The column for which to remove the clauses
:type column: str

:param operator: If specified, will only remove a where clause matching the operator
:type operator: str

:return: The current QueryBuilder instance
:rtype: QueryBuilder
"""
remove_indexes = []

for i, w in enumerate(self.wheres):
if w.get("column") != column:
continue
if operator and (
w.get("operator") != operator and w.get("type") != operator
):
continue
remove_indexes.append(i)

removed = 0
for i in remove_indexes:
idx = i - removed
del self.wheres[idx]
removed = removed + 1

return self

def group_by(self, *columns):
"""
Add a "group by" clause to the query
Expand Down Expand Up @@ -1507,20 +1588,16 @@ def new_query(self):
"""
return QueryBuilder(self._connection, self._grammar, self._processor)

def merge_wheres(self, wheres, bindings):
def merge_wheres(self, wheres):
"""
Merge a list of where clauses and bindings

:param wheres: A list of where clauses
:type wheres: list

:param bindings: A list of bindings
:type bindings: list

:rtype: None
"""
self.wheres = self.wheres + wheres
self._bindings["where"] = self._bindings["where"] + bindings

def _clean_bindings(self, bindings):
"""
Expand Down Expand Up @@ -1548,6 +1625,8 @@ def raw(self, value):

def get_bindings(self):
bindings = []
self._bindings["where"] = self.get_where_bindings()

for value in chain(*self._bindings.values()):
if isinstance(value, datetime.date):
value = value.strftime(self._grammar.get_date_format())
Expand All @@ -1556,22 +1635,40 @@ def get_bindings(self):

return bindings

def get_where_bindings(self):
bindings = []
for where in self.wheres:
if "bindings" in where:
value = where.get("bindings")

if isinstance(value, (list, tuple)):
bindings += value
else:
bindings.append(value)
elif isinstance(where.get("query"), QueryBuilder):
bindings += where["query"].get_where_bindings()

for union in self.unions:
bindings += union["query"].get_where_bindings()

return bindings

def get_raw_bindings(self):
return self._bindings

def set_bindings(self, bindings, type="where"):
if type not in self._bindings:
def set_bindings(self, bindings, type):
if type not in self._settable_bindings:
raise ArgumentError("Invalid binding type: %s" % type)

self._bindings[type] = bindings

return self

def add_binding(self, value, type="where"):
def add_binding(self, value, type):
if value is None:
return self

if type not in self._bindings:
if type not in self._settable_bindings:
raise ArgumentError("Invalid binding type: %s" % type)

if isinstance(value, (list, tuple)):
Expand Down
15 changes: 15 additions & 0 deletions orator/query/grammars/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,21 @@ def _where_sub(self, query, where):

return "%s %s (%s)" % (self.wrap(where["column"]), where["operator"], select)

def _where_expression(self, query, where):
lhs = where["lhs"]
rhs = where["rhs"]
if isinstance(lhs, list):
lhs = "(%s)" % self.parameterize(lhs)
else:
lhs = self.parameter(lhs)

if isinstance(rhs, list):
rhs = "(%s)" % self.parameterize(rhs)
else:
rhs = self.parameter(rhs)

return "%s %s %s" % (lhs, where["operator"], rhs)

def _where_basic(self, query, where):
value = self.parameter(where["value"])

Expand Down
Loading