# Current file: /usr/local/apps/python3/lib/python3.11/site-packages/django/db/models/fields/tuple_lookups.py
import itertools
from django.core.exceptions import EmptyResultSet
from django.db.models import Field
from django.db.models.expressions import (
ColPairs,
Func,
ResolvedOuterRef,
Subquery,
Value,
)
from django.db.models.lookups import (
Exact,
GreaterThan,
GreaterThanOrEqual,
In,
IsNull,
LessThan,
LessThanOrEqual,
)
from django.db.models.sql import Query
from django.db.models.sql.where import AND, OR, WhereNode
class Tuple(Func):
    """
    Expression that groups several source expressions into one composite value.

    ``function`` is empty, so the expressions are presumably rendered as a bare
    parenthesised list such as "(a, b)". NOTE(review): confirm against Func's
    template. len() and iteration both operate on the source expressions.
    """

    allows_composite_expressions = True
    function = ""
    # NOTE: a single Field() instance shared by every Tuple (class attribute).
    output_field = Field()

    def __iter__(self):
        return iter(self.source_expressions)

    def __len__(self):
        return len(self.source_expressions)
class TupleLookupMixin:
    """
    Adapt a builtin lookup to a composite (multi-column) left-hand side.

    List this mixin before the concrete lookup class in the bases, as in
    ``TupleExact(TupleLookupMixin, Exact)``. Its ``super()`` calls then reach
    that lookup's implementation.
    """

    allows_composite_expressions = True

    def get_prep_lookup(self):
        """
        Validate the right-hand side and return it.

        A direct value must be a tuple/list with one element per lhs column.
        Anything else must pass check_rhs_is_supported_expression(). After that
        check the parent lookup's get_prep_lookup() runs, and its return value
        is ignored.

        Raises ValueError when validation fails.
        """
        if not self.rhs_is_direct_value():
            self.check_rhs_is_supported_expression()
            super().get_prep_lookup()
            return self.rhs
        self.check_rhs_is_tuple_or_list()
        self.check_rhs_length_equals_lhs_length()
        return self.rhs

    def check_rhs_is_tuple_or_list(self):
        """Raise ValueError unless the right-hand side is a tuple or a list."""
        if isinstance(self.rhs, (tuple, list)):
            return
        lhs_str = self.get_lhs_str()
        raise ValueError(
            f"{self.lookup_name!r} lookup of {lhs_str} must be a tuple or a list"
        )

    def check_rhs_length_equals_lhs_length(self):
        """Raise ValueError unless rhs has exactly one value per lhs column."""
        expected = len(self.lhs)
        if len(self.rhs) == expected:
            return
        lhs_str = self.get_lhs_str()
        raise ValueError(
            f"{self.lookup_name!r} lookup of {lhs_str} must have {expected} elements"
        )

    def check_rhs_is_supported_expression(self):
        """Raise ValueError unless rhs is a ResolvedOuterRef or a Query."""
        if isinstance(self.rhs, (ResolvedOuterRef, Query)):
            return
        lhs_str = self.get_lhs_str()
        rhs_cls = self.rhs.__class__.__name__
        raise ValueError(
            f"{self.lookup_name!r} subquery lookup of {lhs_str} "
            f"only supports OuterRef and QuerySet objects (received {rhs_cls!r})"
        )

    def get_lhs_str(self):
        """
        Describe the lhs for use in error messages.

        A ColPairs lhs gives the repr of its field's name. Any other lhs is
        iterated, and the result is a parenthesised, comma-separated list of
        each element's ``name`` repr.
        """
        if not isinstance(self.lhs, ColPairs):
            joined = ", ".join(repr(part.name) for part in self.lhs)
            return f"({joined})"
        return repr(self.lhs.field.name)

    def get_prep_lhs(self):
        """Wrap a plain tuple/list lhs in a Tuple expression; otherwise defer."""
        lhs = self.lhs
        if not isinstance(lhs, (tuple, list)):
            return super().get_prep_lhs()
        return Tuple(*lhs)

    def process_lhs(self, compiler, connection, lhs=None):
        """
        Compile the lhs and wrap it in parentheses unless it is already a Tuple.

        A Tuple presumably renders its own parentheses (empty ``function``).
        """
        sql, params = super().process_lhs(compiler, connection, lhs)
        wrapped = sql if isinstance(self.lhs, Tuple) else f"({sql})"
        return wrapped, params

    def process_rhs(self, compiler, connection):
        """
        Compile the right-hand side into SQL and params.

        - A direct value becomes a Tuple of Values. Each Value is typed with
          the output_field of the matching lhs column.
        - A ColPairs rhs is compiled and wrapped in parentheses.
        - A Query rhs is handed to the parent lookup's process_rhs().

        Raises ValueError for any other expression.
        """
        if not self.rhs_is_direct_value():
            # Compiled before dispatching on type. For a Query the result is
            # discarded and the parent lookup compiles it again.
            sql, params = compiler.compile(self.rhs)
            if isinstance(self.rhs, ColPairs):
                return f"({sql})", params
            if isinstance(self.rhs, Query):
                return super().process_rhs(compiler, connection)
            raise ValueError(
                "Composite field lookups only work with composite expressions."
            )
        typed_values = [
            Value(val, output_field=col.output_field)
            for col, val in zip(self.lhs, self.rhs)
        ]
        return compiler.compile(Tuple(*typed_values))

    def get_fallback_sql(self, compiler, connection):
        """
        Build SQL for backends lacking ``supports_tuple_lookups``.

        Subclasses must override this. The base version raises
        NotImplementedError.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__}.get_fallback_sql() must be implemented "
            f"for backends that don't have the supports_tuple_lookups feature enabled."
        )

    def as_sql(self, compiler, connection):
        """Use the native row-value SQL when supported, else the fallback."""
        if connection.features.supports_tuple_lookups:
            return super().as_sql(compiler, connection)
        return self.get_fallback_sql(compiler, connection)
class TupleExact(TupleLookupMixin, Exact):
    """Composite ``exact`` lookup."""

    def get_fallback_sql(self, compiler, connection):
        """
        Expand (a, b, c) = (x, y, z) into ``a = x AND b = y AND c = z``.

        A Query rhs skips the expansion. It is compiled by the as_sql() of the
        class that follows TupleLookupMixin in the MRO (Exact).
        """
        if isinstance(self.rhs, Query):
            return super(TupleLookupMixin, self).as_sql(compiler, connection)
        # Called only to trigger rhs sanitization/validation; the compiled
        # output is discarded.
        self.process_rhs(compiler, connection)
        conjunction = WhereNode(
            [Exact(col, val) for col, val in zip(self.lhs, self.rhs)],
            connector=AND,
        )
        return conjunction.as_sql(compiler, connection)
class TupleIsNull(TupleLookupMixin, IsNull):
    """
    Composite ``isnull`` lookup.

    as_sql() is overridden, so this lookup always expands per column and never
    consults ``supports_tuple_lookups``.
    """

    def get_prep_lookup(self):
        """
        Return the boolean right-hand side.

        A one-element tuple/list is unwrapped first. Raises ValueError when the
        resulting value is not a bool.
        """
        value = self.rhs
        if isinstance(value, (tuple, list)) and len(value) == 1:
            value = value[0]
        if not isinstance(value, bool):
            raise ValueError(
                "The QuerySet value for an isnull lookup must be True or False."
            )
        return value

    def as_sql(self, compiler, connection):
        """
        Build per-column NULL checks.

        True:  a IS NULL OR b IS NULL OR c IS NULL
        False: a IS NOT NULL AND b IS NOT NULL AND c IS NOT NULL
        """
        is_null = self.rhs
        checks = [IsNull(col, is_null) for col in self.lhs]
        connector = OR if is_null else AND
        return WhereNode(checks, connector=connector).as_sql(compiler, connection)
class TupleGreaterThan(TupleLookupMixin, GreaterThan):
    """Composite ``gt`` lookup."""

    def get_fallback_sql(self, compiler, connection):
        """
        Expand a row-value ``>`` into nested per-column conditions.

        e.g.: (a, b, c) > (x, y, z) as SQL:
        WHERE a > x OR (a = x AND (b > y OR (b = y AND c > z)))
        """
        # Process right-hand-side to trigger sanitization.
        self.process_rhs(compiler, connection)
        # Each column contributes a "greater" step, OR-joined to what follows,
        # and an "equal" step, AND-joined. The final "equal" step is dropped
        # so the comparison stays strict.
        steps = list(
            itertools.chain.from_iterable(
                ((GreaterThan, OR, col, val), (Exact, AND, col, val))
                for col, val in zip(self.lhs, self.rhs)
            )
        )[:-1]
        remaining = iter(steps)
        lookup, connector, col, val = next(remaining)
        root = tail = WhereNode([lookup(col, val)], connector=connector)
        for lookup, connector, col, val in remaining:
            # Nest each step inside the previous one.
            child = WhereNode([lookup(col, val)], connector=connector)
            tail.children.append(child)
            tail = child
        return root.as_sql(compiler, connection)
class TupleGreaterThanOrEqual(TupleLookupMixin, GreaterThanOrEqual):
    """Composite ``gte`` lookup."""

    def get_fallback_sql(self, compiler, connection):
        """
        Expand a row-value ``>=`` into nested per-column conditions.

        e.g.: (a, b, c) >= (x, y, z) as SQL:
        WHERE a > x OR (a = x AND (b > y OR (b = y AND (c > z OR c = z))))
        """
        # Process right-hand-side to trigger sanitization.
        self.process_rhs(compiler, connection)
        # Each column contributes a "greater" step, OR-joined to what follows,
        # and an "equal" step, AND-joined. Keeping the last column's "equal"
        # step makes the comparison inclusive.
        steps = list(
            itertools.chain.from_iterable(
                ((GreaterThan, OR, col, val), (Exact, AND, col, val))
                for col, val in zip(self.lhs, self.rhs)
            )
        )
        remaining = iter(steps)
        lookup, connector, col, val = next(remaining)
        root = tail = WhereNode([lookup(col, val)], connector=connector)
        for lookup, connector, col, val in remaining:
            # Nest each step inside the previous one.
            child = WhereNode([lookup(col, val)], connector=connector)
            tail.children.append(child)
            tail = child
        return root.as_sql(compiler, connection)
class TupleLessThan(TupleLookupMixin, LessThan):
    """Composite ``lt`` lookup."""

    def get_fallback_sql(self, compiler, connection):
        """
        Expand a row-value ``<`` into nested per-column conditions.

        e.g.: (a, b, c) < (x, y, z) as SQL:
        WHERE a < x OR (a = x AND (b < y OR (b = y AND c < z)))
        """
        # Process right-hand-side to trigger sanitization.
        self.process_rhs(compiler, connection)
        # Each column contributes a "less" step, OR-joined to what follows,
        # and an "equal" step, AND-joined. The final "equal" step is dropped
        # so the comparison stays strict.
        steps = list(
            itertools.chain.from_iterable(
                ((LessThan, OR, col, val), (Exact, AND, col, val))
                for col, val in zip(self.lhs, self.rhs)
            )
        )[:-1]
        remaining = iter(steps)
        lookup, connector, col, val = next(remaining)
        root = tail = WhereNode([lookup(col, val)], connector=connector)
        for lookup, connector, col, val in remaining:
            # Nest each step inside the previous one.
            child = WhereNode([lookup(col, val)], connector=connector)
            tail.children.append(child)
            tail = child
        return root.as_sql(compiler, connection)
class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
    """Composite ``lte`` lookup."""

    def get_fallback_sql(self, compiler, connection):
        """
        Expand a row-value ``<=`` into nested per-column conditions.

        e.g.: (a, b, c) <= (x, y, z) as SQL:
        WHERE a < x OR (a = x AND (b < y OR (b = y AND (c < z OR c = z))))
        """
        # Process right-hand-side to trigger sanitization.
        self.process_rhs(compiler, connection)
        # Each column contributes a "less" step, OR-joined to what follows,
        # and an "equal" step, AND-joined. Keeping the last column's "equal"
        # step makes the comparison inclusive.
        steps = list(
            itertools.chain.from_iterable(
                ((LessThan, OR, col, val), (Exact, AND, col, val))
                for col, val in zip(self.lhs, self.rhs)
            )
        )
        remaining = iter(steps)
        lookup, connector, col, val = next(remaining)
        root = tail = WhereNode([lookup(col, val)], connector=connector)
        for lookup, connector, col, val in remaining:
            # Nest each step inside the previous one.
            child = WhereNode([lookup(col, val)], connector=connector)
            tail.children.append(child)
            tail = child
        return root.as_sql(compiler, connection)
class TupleIn(TupleLookupMixin, In):
    """Composite ``in`` lookup: a row compared against rows or a subquery."""

    def get_prep_lookup(self):
        """
        Validate the right-hand side and return it.

        This replaces the mixin's checks. Direct values must be a tuple/list
        of tuples/lists, each with one element per lhs column. Otherwise the
        rhs must be a Query or Subquery, and In.get_prep_lookup() runs with
        its return value ignored.

        Raises ValueError when validation fails.
        """
        if not self.rhs_is_direct_value():
            self.check_rhs_is_query()
            super(TupleLookupMixin, self).get_prep_lookup()
            return self.rhs  # skip checks from mixin
        self.check_rhs_is_tuple_or_list()
        self.check_rhs_is_collection_of_tuples_or_lists()
        self.check_rhs_elements_length_equals_lhs_length()
        return self.rhs  # skip checks from mixin

    def check_rhs_is_collection_of_tuples_or_lists(self):
        """Raise ValueError unless every rhs element is a tuple or list."""
        if any(not isinstance(row, (tuple, list)) for row in self.rhs):
            lhs_str = self.get_lhs_str()
            raise ValueError(
                f"{self.lookup_name!r} lookup of {lhs_str} "
                "must be a collection of tuples or lists"
            )

    def check_rhs_elements_length_equals_lhs_length(self):
        """Raise ValueError unless every rhs row has one value per lhs column."""
        width = len(self.lhs)
        if any(len(row) != width for row in self.rhs):
            lhs_str = self.get_lhs_str()
            raise ValueError(
                f"{self.lookup_name!r} lookup of {lhs_str} "
                f"must have {width} elements each"
            )

    def check_rhs_is_query(self):
        """Raise ValueError unless rhs is a Query or a Subquery."""
        if isinstance(self.rhs, (Query, Subquery)):
            return
        lhs_str = self.get_lhs_str()
        rhs_cls = self.rhs.__class__.__name__
        raise ValueError(
            f"{self.lookup_name!r} subquery lookup of {lhs_str} "
            f"must be a Query object (received {rhs_cls!r})"
        )

    def process_rhs(self, compiler, connection):
        """
        Compile the right-hand side.

        A non-direct rhs is delegated to In.process_rhs(). A direct rhs becomes
        a Tuple of Tuples of typed Values, e.g.
        (a, b, c) IN ((x1, y1, z1), (x2, y2, z2)).

        Raises EmptyResultSet when nothing could match.
        """
        if not self.rhs_is_direct_value():
            return super(TupleLookupMixin, self).process_rhs(compiler, connection)
        rhs = self.rhs
        if not rhs:
            raise EmptyResultSet
        lhs = self.lhs
        # Rows containing None are dropped, since NULL never equals anything.
        rows = [
            Tuple(
                *[
                    Value(val, output_field=col.output_field)
                    for col, val in zip(lhs, vals)
                ]
            )
            for vals in rhs
            if all(val is not None for val in vals)
        ]
        if not rows:
            raise EmptyResultSet
        return compiler.compile(Tuple(*rows))

    def get_fallback_sql(self, compiler, connection):
        """
        Expand the IN into an OR of per-row AND conditions.

        e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
        WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)

        A non-direct rhs is compiled by In.as_sql(). Raises EmptyResultSet
        when nothing could match.
        """
        rhs = self.rhs
        if not rhs:
            raise EmptyResultSet
        if not self.rhs_is_direct_value():
            return super(TupleLookupMixin, self).as_sql(compiler, connection)
        lhs = self.lhs
        # Rows containing None are dropped, since NULL never equals anything.
        alternatives = [
            WhereNode([Exact(col, val) for col, val in zip(lhs, vals)], connector=AND)
            for vals in rhs
            if all(val is not None for val in vals)
        ]
        if not alternatives:
            raise EmptyResultSet
        return WhereNode(alternatives, connector=OR).as_sql(compiler, connection)
# Maps each lookup name supported on composite fields to its implementation.
tuple_lookups = dict(
    (
        ("exact", TupleExact),
        ("gt", TupleGreaterThan),
        ("gte", TupleGreaterThanOrEqual),
        ("lt", TupleLessThan),
        ("lte", TupleLessThanOrEqual),
        ("in", TupleIn),
        ("isnull", TupleIsNull),
    )
)