diff --git a/orator/orm/builder.py b/orator/orm/builder.py index 72a7441e..4bbe8ce3 100644 --- a/orator/orm/builder.py +++ b/orator/orm/builder.py @@ -827,9 +827,7 @@ def _add_has_where(self, has_query, relation, operator, count, boolean): if isinstance(count, basestring) and count.isdigit(): count = QueryExpression(count) - return self.where( - QueryExpression("(%s)" % has_query.to_sql()), operator, count, boolean - ) + return self._query.where_expression(has_query._query, operator, count, boolean) def _merge_model_defined_relation_wheres_to_has_query(self, has_query, relation): """ @@ -843,9 +841,7 @@ def _merge_model_defined_relation_wheres_to_has_query(self, has_query, relation) """ relation_query = relation.get_base_query() - has_query.merge_wheres(relation_query.wheres, relation_query.get_bindings()) - - self._query.add_binding(has_query.get_query().get_bindings(), "where") + has_query.merge_wheres(relation_query.wheres) def _get_has_relation_query(self, relation): """ diff --git a/orator/query/builder.py b/orator/query/builder.py index 6e8aaa6c..f0952ca4 100644 --- a/orator/query/builder.py +++ b/orator/query/builder.py @@ -66,6 +66,8 @@ def __init__(self, connection, grammar, processor): for type in ["select", "join", "where", "having", "order"]: self._bindings[type] = [] + self._settable_bindings = ["select", "from", "join", "having", "order"] + self.aggregate_ = None self.columns = [] self.distinct_ = False @@ -403,18 +405,18 @@ def where(self, column, operator=Null(), value=None, boolean="and"): type = "basic" - self.wheres.append( - { - "type": type, - "column": column, - "operator": operator, - "value": value, - "boolean": boolean, - } - ) + where = { + "type": type, + "column": column, + "operator": operator, + "value": value, + "boolean": boolean, + } if not isinstance(value, QueryExpression): - self.add_binding(value, "where") + where["bindings"] = value + + self.wheres.append(where) return self @@ -429,9 +431,9 @@ def _invalid_operator_and_value(self, operator, value): def where_raw(self, sql, bindings=None, boolean="and"): type = "raw" - self.wheres.append({"type": type, "sql": sql, "boolean": boolean}) - - self.add_binding(bindings, "where") + self.wheres.append( + {"type": type, "sql": sql, "boolean": boolean, "bindings": bindings} + ) return self @@ -442,11 +444,15 @@ def where_between(self, column, values, boolean="and", negate=False): type = "between" self.wheres.append( - {"column": column, "type": type, "boolean": boolean, "not": negate} + { + "column": column, + "type": type, + "boolean": boolean, + "not": negate, + "bindings": values, + } ) - self.add_binding(values, "where") - return self def or_where_between(self, column, values): @@ -500,6 +506,46 @@ def _where_sub(self, column, operator, query, boolean): return self + def where_expression( + self, left_expression, operator, right_expression, boolean="and" + ): + type = "expression" + + bindings = [] + + if isinstance(left_expression, QueryBuilder): + self.merge_bindings(left_expression) + bindings += left_expression.get_bindings() + left_expression = QueryExpression("(%s)" % left_expression.to_sql()) + elif not isinstance(left_expression, QueryExpression): + if not isinstance(left_expression, list): + bindings.append(left_expression) + else: + bindings += left_expression + + if isinstance(right_expression, QueryBuilder): + self.merge_bindings(right_expression) + bindings += right_expression.get_bindings() + right_expression = QueryExpression("(%s)" % right_expression.to_sql()) + elif not isinstance(right_expression, QueryExpression): + if not isinstance(right_expression, list): + bindings.append(right_expression) + else: + bindings += right_expression + + self.wheres.append( + { + "type": type, + "lhs": left_expression, + "operator": operator, + "rhs": right_expression, + "boolean": boolean, + "bindings": bindings, + } + ) + + return self + def where_exists(self, query, boolean="and", negate=False): """ Add an exists clause to the query. @@ -574,11 +620,15 @@ def where_in(self, column, values, boolean="and", negate=False): values = values.all() self.wheres.append( - {"type": type, "column": column, "values": values, "boolean": boolean} + { + "type": type, + "column": column, + "values": values, + "boolean": boolean, + "bindings": values, + } ) - self.add_binding(values, "where") - return self def or_where_in(self, column, values): @@ -661,11 +711,10 @@ def _add_date_based_where(self, type, column, operator, value, boolean="and"): "boolean": boolean, "operator": operator, "value": value, + "bindings": value, } ) - self.add_binding(value, "where") - def dynamic_where(self, method): finder = method[6:] @@ -691,6 +740,38 @@ def dynamic_where(*parameters): def _add_dynamic(self, segment, connector, parameters, index): self.where(segment, "=", parameters[index], connector) + def remove_where(self, column, operator=None): + """ + Remove where clauses referencing a specific column + + :param column: The column for which to remove the clauses + :type column: str + + :param operator: If specified, will only remove a where clause matching the operator + :type operator: str + + :return: The current QueryBuilder instance + :rtype: QueryBuilder + """ + remove_indexes = [] + + for i, w in enumerate(self.wheres): + if w.get("column") != column: + continue + if operator and ( + w.get("operator") != operator and w.get("type") != operator + ): + continue + remove_indexes.append(i) + + removed = 0 + for i in remove_indexes: + idx = i - removed + del self.wheres[idx] + removed = removed + 1 + + return self + def group_by(self, *columns): """ Add a "group by" clause to the query @@ -1507,20 +1588,16 @@ def new_query(self): """ return QueryBuilder(self._connection, self._grammar, self._processor) - def merge_wheres(self, wheres, bindings): + def merge_wheres(self, wheres): """ Merge a list of where clauses and bindings :param wheres: A list of where clauses :type wheres: list - :param bindings: A list of bindings - :type bindings: list - :rtype: None """ self.wheres = self.wheres + wheres - self._bindings["where"] = self._bindings["where"] + bindings def _clean_bindings(self, bindings): """ @@ -1548,6 +1625,8 @@ def raw(self, value): def get_bindings(self): bindings = [] + self._bindings["where"] = self.get_where_bindings() + for value in chain(*self._bindings.values()): if isinstance(value, datetime.date): value = value.strftime(self._grammar.get_date_format()) @@ -1556,22 +1635,40 @@ def get_bindings(self): return bindings + def get_where_bindings(self): + bindings = [] + for where in self.wheres: + if "bindings" in where: + value = where.get("bindings") + + if isinstance(value, (list, tuple)): + bindings += value + else: + bindings.append(value) + elif isinstance(where.get("query"), QueryBuilder): + bindings += where["query"].get_where_bindings() + + for union in self.unions: + bindings += union["query"].get_where_bindings() + + return bindings + def get_raw_bindings(self): return self._bindings - def set_bindings(self, bindings, type="where"): - if type not in self._bindings: + def set_bindings(self, bindings, type): + if type not in self._settable_bindings: raise ArgumentError("Invalid binding type: %s" % type) self._bindings[type] = bindings return self - def add_binding(self, value, type="where"): + def add_binding(self, value, type): if value is None: return self - if type not in self._bindings: + if type not in self._settable_bindings: raise ArgumentError("Invalid binding type: %s" % type) if isinstance(value, (list, tuple)): diff --git a/orator/query/grammars/grammar.py b/orator/query/grammars/grammar.py index ffd1d74d..43d4c704 100644 --- a/orator/query/grammars/grammar.py +++ b/orator/query/grammars/grammar.py @@ -157,6 +157,21 @@ def _where_sub(self, query, where): return "%s %s (%s)" % (self.wrap(where["column"]), where["operator"], select) + def _where_expression(self, query, where): + lhs = where["lhs"] + rhs = where["rhs"] + if isinstance(lhs, list): + lhs = "(%s)" % self.parameterize(lhs) + else: + lhs = self.parameter(lhs) + + if isinstance(rhs, list): + rhs = "(%s)" % self.parameterize(rhs) + else: + rhs = self.parameter(rhs) + + return "%s %s %s" % (lhs, where["operator"], rhs) + def _where_basic(self, query, where): value = self.parameter(where["value"]) diff --git a/tests/query/test_query_builder.py b/tests/query/test_query_builder.py index 26c1c5e3..5f921c8f 100644 --- a/tests/query/test_query_builder.py +++ b/tests/query/test_query_builder.py @@ -1462,10 +1462,16 @@ def test_mysql_wrapping(self): def test_merge_wheres_can_merge_wheres_and_bindings(self): builder = self.get_builder() - builder.wheres = ["foo"] - builder.merge_wheres(["wheres"], ["foo", "bar"]) - self.assertEqual(["foo", "wheres"], builder.wheres) - self.assertEqual(["foo", "bar"], builder.get_bindings()) + builder.wheres = [{"column": "foo", "bindings": ["foo"]}] + builder.merge_wheres([{"column": "wheres", "bindings": ["bar"]}]) + self.assertEqual( + [ + {"column": "foo", "bindings": ["foo"]}, + {"column": "wheres", "bindings": ["bar"]}, + ], + builder.wheres, + ) + self.assertEqual(["foo", "bar"], builder.get_where_bindings()) def test_where_with_null_second_parameter(self): builder = self.get_builder() @@ -1579,22 +1585,22 @@ def test_binding_order(self): def test_add_binding_with_list_merges_bindings(self): builder = self.get_builder() - builder.add_binding(["foo", "bar"]) - builder.add_binding(["baz"]) + builder.add_binding(["foo", "bar"], "select") + builder.add_binding(["baz"], "select") self.assertEqual(["foo", "bar", "baz"], builder.get_bindings()) def test_add_binding_with_list_merges_bindings_in_correct_order(self): builder = self.get_builder() builder.add_binding(["bar", "baz"], "having") - builder.add_binding(["foo"], "where") + builder.add_binding(["foo"], "select") self.assertEqual(["foo", "bar", "baz"], builder.get_bindings()) def test_merge_builders(self): builder = self.get_builder() - builder.add_binding("foo", "where") + builder.add_binding("foo", "select") builder.add_binding("baz", "having") other_builder = self.get_builder() - other_builder.add_binding("bar", "where") + other_builder.add_binding("bar", "select") builder.merge_bindings(other_builder) self.assertEqual(["foo", "bar", "baz"], builder.get_bindings()) @@ -1676,6 +1682,71 @@ def test_merge(self): self.assertEqual(["boom", "bar"], b1.get_bindings()) + def test_remove_where(self): + builder = self.get_builder().from_("a") + marker = builder.get_grammar().get_marker() + builder.where("not_removed_1", "1") + builder.where("col_1", "a").or_where("col_1", "b").or_where("col_1", "!=", "c") + builder.where("not_removed_2", "2") + builder.or_where_in("col_1", [1, 2, 3]) + builder.or_where_between("col_1", (1, 2)) + builder.or_where_null("col_1") + + builder.remove_where("col_1") + + self.assertEqual( + 'SELECT * FROM "a" WHERE "not_removed_1" = %s AND "not_removed_2" = %s' + % (marker, marker), + builder.to_sql(), + ) + + self.assertEqual(["1", "2"], builder.get_bindings()) + + def test_remove_where_specific_operator(self): + builder = self.get_builder().from_("a") + marker = builder.get_grammar().get_marker() + builder.where("not_removed_1", "1") + builder.where("col_1", "a").or_where("col_1", "b").or_where("col_1", "!=", "c") + builder.where("not_removed_2", "2") + builder.or_where_in("col_1", [1, 2, 3]) + builder.or_where_between("col_1", (100, 101)) + builder.or_where_null("col_1") + + builder.remove_where("col_1", "=") + + expected_sql = ( + 'SELECT * FROM "a" WHERE "not_removed_1" = %s OR "col_1" != %s AND "not_removed_2" = %s' + ' OR "col_1" IN (%s, %s, %s)' + ' OR "col_1" BETWEEN %s AND %s OR "col_1" IS NULL' + % (marker, marker, marker, marker, marker, marker, marker, marker) + ) + + self.assertEqual(expected_sql, builder.to_sql()) + self.assertEqual(["1", "c", "2", 1, 2, 3, 100, 101], builder.get_bindings()) + + builder.remove_where("col_1", "null") + + expected_sql = ( + 'SELECT * FROM "a" WHERE "not_removed_1" = %s OR "col_1" != %s AND "not_removed_2" = %s' + ' OR "col_1" IN (%s, %s, %s)' + ' OR "col_1" BETWEEN %s AND %s' + % (marker, marker, marker, marker, marker, marker, marker, marker) + ) + + self.assertEqual(expected_sql, builder.to_sql()) + self.assertEqual(["1", "c", "2", 1, 2, 3, 100, 101], builder.get_bindings()) + + builder.remove_where("col_1", "between") + + expected_sql = ( + 'SELECT * FROM "a" WHERE "not_removed_1" = %s OR "col_1" != %s AND "not_removed_2" = %s' + ' OR "col_1" IN (%s, %s, %s)' + % (marker, marker, marker, marker, marker, marker) + ) + + self.assertEqual(expected_sql, builder.to_sql()) + self.assertEqual(["1", "c", "2", 1, 2, 3], builder.get_bindings()) + def get_mysql_builder(self): grammar = MySQLQueryGrammar() processor = MockProcessor().prepare_mock()