lookups.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. from copy import copy
  2. from itertools import repeat
  3. import inspect
  4. from django.conf import settings
  5. from django.utils import timezone
  6. from django.utils.functional import cached_property
  7. from django.utils.six.moves import xrange
  8. class RegisterLookupMixin(object):
  9. def _get_lookup(self, lookup_name):
  10. try:
  11. return self.class_lookups[lookup_name]
  12. except KeyError:
  13. # To allow for inheritance, check parent class' class_lookups.
  14. for parent in inspect.getmro(self.__class__):
  15. if 'class_lookups' not in parent.__dict__:
  16. continue
  17. if lookup_name in parent.class_lookups:
  18. return parent.class_lookups[lookup_name]
  19. except AttributeError:
  20. # This class didn't have any class_lookups
  21. pass
  22. return None
  23. def get_lookup(self, lookup_name):
  24. found = self._get_lookup(lookup_name)
  25. if found is None and hasattr(self, 'output_field'):
  26. return self.output_field.get_lookup(lookup_name)
  27. if found is not None and not issubclass(found, Lookup):
  28. return None
  29. return found
  30. def get_transform(self, lookup_name):
  31. found = self._get_lookup(lookup_name)
  32. if found is None and hasattr(self, 'output_field'):
  33. return self.output_field.get_transform(lookup_name)
  34. if found is not None and not issubclass(found, Transform):
  35. return None
  36. return found
  37. @classmethod
  38. def register_lookup(cls, lookup):
  39. if 'class_lookups' not in cls.__dict__:
  40. cls.class_lookups = {}
  41. cls.class_lookups[lookup.lookup_name] = lookup
  42. @classmethod
  43. def _unregister_lookup(cls, lookup):
  44. """
  45. Removes given lookup from cls lookups. Meant to be used in
  46. tests only.
  47. """
  48. del cls.class_lookups[lookup.lookup_name]
  49. class Transform(RegisterLookupMixin):
  50. def __init__(self, lhs, lookups):
  51. self.lhs = lhs
  52. self.init_lookups = lookups[:]
  53. def as_sql(self, qn, connection):
  54. raise NotImplementedError
  55. @cached_property
  56. def output_field(self):
  57. return self.lhs.output_field
  58. def relabeled_clone(self, relabels):
  59. return self.__class__(self.lhs.relabeled_clone(relabels))
  60. def get_group_by_cols(self):
  61. return self.lhs.get_group_by_cols()
  62. class Lookup(RegisterLookupMixin):
  63. lookup_name = None
  64. def __init__(self, lhs, rhs):
  65. self.lhs, self.rhs = lhs, rhs
  66. self.rhs = self.get_prep_lookup()
  67. def get_prep_lookup(self):
  68. return self.lhs.output_field.get_prep_lookup(self.lookup_name, self.rhs)
  69. def get_db_prep_lookup(self, value, connection):
  70. return (
  71. '%s', self.lhs.output_field.get_db_prep_lookup(
  72. self.lookup_name, value, connection, prepared=True))
  73. def process_lhs(self, qn, connection, lhs=None):
  74. lhs = lhs or self.lhs
  75. return qn.compile(lhs)
  76. def process_rhs(self, qn, connection):
  77. value = self.rhs
  78. # Due to historical reasons there are a couple of different
  79. # ways to produce sql here. get_compiler is likely a Query
  80. # instance, _as_sql QuerySet and as_sql just something with
  81. # as_sql. Finally the value can of course be just plain
  82. # Python value.
  83. if hasattr(value, 'get_compiler'):
  84. value = value.get_compiler(connection=connection)
  85. if hasattr(value, 'as_sql'):
  86. sql, params = qn.compile(value)
  87. return '(' + sql + ')', params
  88. if hasattr(value, '_as_sql'):
  89. sql, params = value._as_sql(connection=connection)
  90. return '(' + sql + ')', params
  91. else:
  92. return self.get_db_prep_lookup(value, connection)
  93. def rhs_is_direct_value(self):
  94. return not(
  95. hasattr(self.rhs, 'as_sql') or
  96. hasattr(self.rhs, '_as_sql') or
  97. hasattr(self.rhs, 'get_compiler'))
  98. def relabeled_clone(self, relabels):
  99. new = copy(self)
  100. new.lhs = new.lhs.relabeled_clone(relabels)
  101. if hasattr(new.rhs, 'relabeled_clone'):
  102. new.rhs = new.rhs.relabeled_clone(relabels)
  103. return new
  104. def get_group_by_cols(self):
  105. cols = self.lhs.get_group_by_cols()
  106. if hasattr(self.rhs, 'get_group_by_cols'):
  107. cols.extend(self.rhs.get_group_by_cols())
  108. return cols
  109. def as_sql(self, qn, connection):
  110. raise NotImplementedError
  111. class BuiltinLookup(Lookup):
  112. def process_lhs(self, qn, connection, lhs=None):
  113. lhs_sql, params = super(BuiltinLookup, self).process_lhs(
  114. qn, connection, lhs)
  115. field_internal_type = self.lhs.output_field.get_internal_type()
  116. db_type = self.lhs.output_field.db_type(connection=connection)
  117. lhs_sql = connection.ops.field_cast_sql(
  118. db_type, field_internal_type) % lhs_sql
  119. lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql
  120. return lhs_sql, params
  121. def as_sql(self, qn, connection):
  122. lhs_sql, params = self.process_lhs(qn, connection)
  123. rhs_sql, rhs_params = self.process_rhs(qn, connection)
  124. params.extend(rhs_params)
  125. rhs_sql = self.get_rhs_op(connection, rhs_sql)
  126. return '%s %s' % (lhs_sql, rhs_sql), params
  127. def get_rhs_op(self, connection, rhs):
  128. return connection.operators[self.lookup_name] % rhs
  129. default_lookups = {}
  130. class Exact(BuiltinLookup):
  131. lookup_name = 'exact'
  132. default_lookups['exact'] = Exact
  133. class IExact(BuiltinLookup):
  134. lookup_name = 'iexact'
  135. default_lookups['iexact'] = IExact
  136. class Contains(BuiltinLookup):
  137. lookup_name = 'contains'
  138. default_lookups['contains'] = Contains
  139. class IContains(BuiltinLookup):
  140. lookup_name = 'icontains'
  141. default_lookups['icontains'] = IContains
  142. class GreaterThan(BuiltinLookup):
  143. lookup_name = 'gt'
  144. default_lookups['gt'] = GreaterThan
  145. class GreaterThanOrEqual(BuiltinLookup):
  146. lookup_name = 'gte'
  147. default_lookups['gte'] = GreaterThanOrEqual
  148. class LessThan(BuiltinLookup):
  149. lookup_name = 'lt'
  150. default_lookups['lt'] = LessThan
  151. class LessThanOrEqual(BuiltinLookup):
  152. lookup_name = 'lte'
  153. default_lookups['lte'] = LessThanOrEqual
  154. class In(BuiltinLookup):
  155. lookup_name = 'in'
  156. def get_db_prep_lookup(self, value, connection):
  157. params = self.lhs.output_field.get_db_prep_lookup(
  158. self.lookup_name, value, connection, prepared=True)
  159. if not params:
  160. # TODO: check why this leads to circular import
  161. from django.db.models.sql.datastructures import EmptyResultSet
  162. raise EmptyResultSet
  163. placeholder = '(' + ', '.join('%s' for p in params) + ')'
  164. return (placeholder, params)
  165. def get_rhs_op(self, connection, rhs):
  166. return 'IN %s' % rhs
  167. def as_sql(self, qn, connection):
  168. max_in_list_size = connection.ops.max_in_list_size()
  169. if self.rhs_is_direct_value() and (max_in_list_size and
  170. len(self.rhs) > max_in_list_size):
  171. rhs, rhs_params = self.process_rhs(qn, connection)
  172. lhs, lhs_params = self.process_lhs(qn, connection)
  173. in_clause_elements = ['(']
  174. params = []
  175. for offset in xrange(0, len(rhs_params), max_in_list_size):
  176. if offset > 0:
  177. in_clause_elements.append(' OR ')
  178. in_clause_elements.append('%s IN (' % lhs)
  179. params.extend(lhs_params)
  180. group_size = min(len(rhs_params) - offset, max_in_list_size)
  181. param_group = ', '.join(repeat('%s', group_size))
  182. in_clause_elements.append(param_group)
  183. in_clause_elements.append(')')
  184. params.extend(rhs_params[offset: offset + max_in_list_size])
  185. in_clause_elements.append(')')
  186. return ''.join(in_clause_elements), params
  187. else:
  188. return super(In, self).as_sql(qn, connection)
  189. default_lookups['in'] = In
  190. class PatternLookup(BuiltinLookup):
  191. def get_rhs_op(self, connection, rhs):
  192. # Assume we are in startswith. We need to produce SQL like:
  193. # col LIKE %s, ['thevalue%']
  194. # For python values we can (and should) do that directly in Python,
  195. # but if the value is for example reference to other column, then
  196. # we need to add the % pattern match to the lookup by something like
  197. # col LIKE othercol || '%%'
  198. # So, for Python values we don't need any special pattern, but for
  199. # SQL reference values we need the correct pattern added.
  200. value = self.rhs
  201. if (hasattr(value, 'get_compiler') or hasattr(value, 'as_sql')
  202. or hasattr(value, '_as_sql')):
  203. return connection.pattern_ops[self.lookup_name] % rhs
  204. else:
  205. return super(PatternLookup, self).get_rhs_op(connection, rhs)
  206. class StartsWith(PatternLookup):
  207. lookup_name = 'startswith'
  208. default_lookups['startswith'] = StartsWith
  209. class IStartsWith(PatternLookup):
  210. lookup_name = 'istartswith'
  211. default_lookups['istartswith'] = IStartsWith
  212. class EndsWith(BuiltinLookup):
  213. lookup_name = 'endswith'
  214. default_lookups['endswith'] = EndsWith
  215. class IEndsWith(BuiltinLookup):
  216. lookup_name = 'iendswith'
  217. default_lookups['iendswith'] = IEndsWith
  218. class Between(BuiltinLookup):
  219. def get_rhs_op(self, connection, rhs):
  220. return "BETWEEN %s AND %s" % (rhs, rhs)
  221. class Year(Between):
  222. lookup_name = 'year'
  223. default_lookups['year'] = Year
  224. class Range(Between):
  225. lookup_name = 'range'
  226. default_lookups['range'] = Range
  227. class DateLookup(BuiltinLookup):
  228. def process_lhs(self, qn, connection, lhs=None):
  229. from django.db.models import DateTimeField
  230. lhs, params = super(DateLookup, self).process_lhs(qn, connection, lhs)
  231. if isinstance(self.lhs.output_field, DateTimeField):
  232. tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
  233. sql, tz_params = connection.ops.datetime_extract_sql(self.extract_type, lhs, tzname)
  234. return connection.ops.lookup_cast(self.lookup_name) % sql, tz_params
  235. else:
  236. return connection.ops.date_extract_sql(self.lookup_name, lhs), []
  237. def get_rhs_op(self, connection, rhs):
  238. return '= %s' % rhs
  239. class Month(DateLookup):
  240. lookup_name = 'month'
  241. extract_type = 'month'
  242. default_lookups['month'] = Month
  243. class Day(DateLookup):
  244. lookup_name = 'day'
  245. extract_type = 'day'
  246. default_lookups['day'] = Day
  247. class WeekDay(DateLookup):
  248. lookup_name = 'week_day'
  249. extract_type = 'week_day'
  250. default_lookups['week_day'] = WeekDay
  251. class Hour(DateLookup):
  252. lookup_name = 'hour'
  253. extract_type = 'hour'
  254. default_lookups['hour'] = Hour
  255. class Minute(DateLookup):
  256. lookup_name = 'minute'
  257. extract_type = 'minute'
  258. default_lookups['minute'] = Minute
  259. class Second(DateLookup):
  260. lookup_name = 'second'
  261. extract_type = 'second'
  262. default_lookups['second'] = Second
  263. class IsNull(BuiltinLookup):
  264. lookup_name = 'isnull'
  265. def as_sql(self, qn, connection):
  266. sql, params = qn.compile(self.lhs)
  267. if self.rhs:
  268. return "%s IS NULL" % sql, params
  269. else:
  270. return "%s IS NOT NULL" % sql, params
  271. default_lookups['isnull'] = IsNull
  272. class Search(BuiltinLookup):
  273. lookup_name = 'search'
  274. def as_sql(self, qn, connection):
  275. lhs, lhs_params = self.process_lhs(qn, connection)
  276. rhs, rhs_params = self.process_rhs(qn, connection)
  277. sql_template = connection.ops.fulltext_search_sql(field_name=lhs)
  278. return sql_template, lhs_params + rhs_params
  279. default_lookups['search'] = Search
  280. class Regex(BuiltinLookup):
  281. lookup_name = 'regex'
  282. def as_sql(self, qn, connection):
  283. if self.lookup_name in connection.operators:
  284. return super(Regex, self).as_sql(qn, connection)
  285. else:
  286. lhs, lhs_params = self.process_lhs(qn, connection)
  287. rhs, rhs_params = self.process_rhs(qn, connection)
  288. sql_template = connection.ops.regex_lookup(self.lookup_name)
  289. return sql_template % (lhs, rhs), lhs_params + rhs_params
  290. default_lookups['regex'] = Regex
  291. class IRegex(Regex):
  292. lookup_name = 'iregex'
  293. default_lookups['iregex'] = IRegex