__init__.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591
  1. # -*- coding: utf-8 -*-
  2. """
  3. Django Extensions additional model fields
  4. Some fields might require additional dependencies to be installed.
  5. """
  6. import re
  7. import six
  8. import string
  9. try:
  10. import uuid
  11. HAS_UUID = True
  12. except ImportError:
  13. HAS_UUID = False
  14. try:
  15. import shortuuid
  16. HAS_SHORT_UUID = True
  17. except ImportError:
  18. HAS_SHORT_UUID = False
  19. from django.conf import settings
  20. from django.core.exceptions import ImproperlyConfigured
  21. from django.db.models import DateTimeField, CharField, SlugField, Q
  22. from django.db.models.constants import LOOKUP_SEP
  23. from django.template.defaultfilters import slugify
  24. from django.utils.crypto import get_random_string
  25. from django.utils.encoding import force_str
  # Default cap on how many candidate values AutoSlugField and RandomCharField
  # try while searching for an unused value. Projects can override it through
  # the EXTENSIONS_MAX_UNIQUE_QUERY_ATTEMPTS setting; otherwise it is 100.
  MAX_UNIQUE_QUERY_ATTEMPTS = getattr(
      settings,
      'EXTENSIONS_MAX_UNIQUE_QUERY_ATTEMPTS',
      100,
  )
  class UniqueFieldMixin:
      """
      Mixin for model fields that must pick a value not yet used in the table.

      Subclasses hand ``find_unique`` an iterator of candidate values. The first
      candidate that is non-empty and does not collide with an existing row is
      stored on the model instance and returned. Collisions are scoped by the
      other members of any ``unique_together`` tuple or ``Meta.constraints``
      entry that includes this field.
      """

      def check_is_bool(self, attrname):
          """Raise ``ValueError`` unless ``getattr(self, attrname)`` is a bool."""
          if not isinstance(getattr(self, attrname), bool):
              raise ValueError("'{}' argument must be True or False".format(attrname))

      @staticmethod
      def _get_fields(model_cls):
          """
          Return ``(field, owner)`` pairs for the fields of ``model_cls``.

          ``owner`` is ``field.model`` when that differs from ``model_cls``,
          for example a field inherited from a parent model. Otherwise it is
          ``None``. The following are kept:

          - non-relational fields;
          - one-to-one relations;
          - many-to-one relations that have a ``related_model``.

          All other relations are skipped.
          """
          return [
              (f, f.model if f.model != model_cls else None)
              for f in model_cls._meta.get_fields()
              if not f.is_relation or f.one_to_one or (f.many_to_one and f.related_model)
          ]

      def get_queryset(self, model_cls, slug_field):
          """
          Return the queryset that candidate values are checked against.

          When ``slug_field`` belongs to a model other than ``model_cls``, that
          model's default manager is used so the lookup runs against the model
          the field is attached to. Otherwise ``model_cls``'s default manager is
          used.
          """
          for field, model in self._get_fields(model_cls):
              if model and field == slug_field:
                  return model._default_manager.all()
          return model_cls._default_manager.all()

      def find_unique(self, model_instance, field, iterator, *args):
          """
          Assign and return the first usable candidate produced by ``iterator``.

          :param model_instance: instance being saved; its own row is excluded
              from the collision search.
          :param field: the model field object this mixin is attached to.
          :param iterator: yields candidate values. Whatever it raises when
              exhausted (e.g. ``RuntimeError``) propagates to the caller.
          :param args: unused; accepted for backwards compatibility.
          :returns: the chosen value, also set on ``model_instance``.
          """
          # Exclude the instance itself so re-saving never collides with its own row.
          queryset = self.get_queryset(model_instance.__class__, field)
          if model_instance.pk:
              queryset = queryset.exclude(pk=model_instance.pk)

          # Scope the search to rows sharing the other unique_together values.
          kwargs = {}
          for params in model_instance._meta.unique_together:
              if self.attname in params:
                  for param in params:
                      kwargs[param] = getattr(model_instance, param, None)

          # Same for Meta.constraints (Django 2.2+). Only UniqueConstraint has a
          # ``fields`` attribute; CheckConstraint has none. An expression-based
          # UniqueConstraint has empty ``fields``, so read it defensively.
          query = Q()
          for constraint in getattr(model_instance._meta, 'constraints', None) or ():
              constraint_fields = getattr(constraint, 'fields', None) or ()
              if self.attname in constraint_fields:
                  condition = {
                      name: getattr(model_instance, name, None)
                      for name in constraint_fields
                      if name != self.attname
                  }
                  query &= Q(**condition)

          new = next(iterator)
          kwargs[self.attname] = new
          # Empty candidates are never accepted. exists() asks the database a
          # yes/no question instead of fetching every matching row.
          while not new or queryset.filter(query, **kwargs).exists():
              new = next(iterator)
              kwargs[self.attname] = new
          setattr(model_instance, self.attname, new)
          return new
  class AutoSlugField(UniqueFieldMixin, SlugField):
      """
      AutoSlugField

      By default, sets editable=False, blank=True.

      Required arguments:

      populate_from
          Specifies which field, list of fields, or model method
          the slug will be populated from.

          populate_from can traverse a ForeignKey relationship
          by using Django ORM syntax:
              populate_from = 'related_model__field'

      Optional arguments:

      separator
          Defines the used separator (default: '-')

      overwrite
          If set to True, overwrites the slug on every save (default: False)

      overwrite_on_add
          If set to True, a slug that was already assigned is still
          regenerated when the instance is first saved (default: True)

      allow_duplicates
          If set to True, the generated slug is used without checking for
          other rows with the same slug (default: False)

      max_unique_query_attempts
          Numeric suffixes are tried from 2 up to, but not including, this
          value; once exhausted a RuntimeError is raised (default: the
          EXTENSIONS_MAX_UNIQUE_QUERY_ATTEMPTS setting, or 100)

      slugify_function
          Defines the function which will be used to "slugify" a content
          (default: :py:func:`~django.template.defaultfilters.slugify` )

      It is possible to provide custom "slugify" function with
      the ``slugify_function`` function in a model class.

      ``slugify_function`` function in a model class takes priority over
      ``slugify_function`` given as an argument to :py:class:`~AutoSlugField`.

      Example

      .. code-block:: python

          # models.py

          from django.db import models
          from django_extensions.db.fields import AutoSlugField


          class MyModel(models.Model):
              def slugify_function(self, content):
                  return content.replace('_', '-').lower()

              title = models.CharField(max_length=42)
              slug = AutoSlugField(populate_from='title')

      Inspired by SmileyChris' Unique Slugify snippet:
      http://www.djangosnippets.org/snippets/690/
      """

      def __init__(self, *args, **kwargs):
          kwargs.setdefault('blank', True)
          kwargs.setdefault('editable', False)

          populate_from = kwargs.pop('populate_from', None)
          if populate_from is None:
              raise ValueError("missing 'populate_from' argument")
          # Keep the value exactly as given so deconstruct() can round-trip it.
          self._populate_from = populate_from

          if not callable(populate_from):
              if not isinstance(populate_from, (list, tuple)):
                  populate_from = (populate_from, )
              if not all(isinstance(e, str) for e in populate_from):
                  # Wrap the value in a 1-tuple. Passing a tuple straight to
                  # '%' would spread its items over the format string and
                  # raise "not all arguments converted" instead of this error.
                  raise TypeError("'populate_from' must be str or list[str] or tuple[str], found `%s`" % (self._populate_from, ))

          self.slugify_function = kwargs.pop('slugify_function', slugify)
          self.separator = kwargs.pop('separator', '-')
          self.overwrite = kwargs.pop('overwrite', False)
          self.check_is_bool('overwrite')
          self.overwrite_on_add = kwargs.pop('overwrite_on_add', True)
          self.check_is_bool('overwrite_on_add')
          self.allow_duplicates = kwargs.pop('allow_duplicates', False)
          self.check_is_bool('allow_duplicates')
          self.max_unique_query_attempts = kwargs.pop('max_unique_query_attempts', MAX_UNIQUE_QUERY_ATTEMPTS)
          super().__init__(*args, **kwargs)

      def _slug_strip(self, value):
          """
          Clean up a slug by removing slug separator characters that occur at
          the beginning or end of a slug.

          Runs of separators are collapsed into one. If an alternate separator
          is used, it will also replace any instances of the default '-'
          separator with the new separator.
          """
          re_sep = '(?:-|%s)' % re.escape(self.separator)
          value = re.sub('%s+' % re_sep, self.separator, value)
          return re.sub(r'^%s+|%s+$' % (re_sep, re_sep), '', value)

      @staticmethod
      def slugify_func(content, slugify_function):
          """Slugify ``content`` with ``slugify_function``; falsy content gives ''."""
          if content:
              return slugify_function(content)
          return ''

      def slug_generator(self, original_slug, start):
          """
          Yield ``original_slug`` followed by numbered variants.

          Each variant is ``<slug><separator><i>`` for ``i`` in
          ``range(start, self.max_unique_query_attempts)``. The base is
          truncated (and re-stripped) so the variant fits in ``self.slug_len``,
          which ``create_slug`` sets before using this generator. Raises
          ``RuntimeError`` once all variants are used up.
          """
          yield original_slug
          for i in range(start, self.max_unique_query_attempts):
              slug = original_slug
              end = '%s%s' % (self.separator, i)
              end_len = len(end)
              if self.slug_len and len(slug) + end_len > self.slug_len:
                  slug = slug[:self.slug_len - end_len]
                  slug = self._slug_strip(slug)
              slug = '%s%s' % (slug, end)
              yield slug
          raise RuntimeError('max slug attempts for %s exceeded (%s)' % (original_slug, self.max_unique_query_attempts))

      def create_slug(self, model_instance, add):
          """Return the slug to store, generating a (unique) one when needed."""
          slug = getattr(model_instance, self.attname)
          use_existing_slug = False
          if slug and not self.overwrite:
              # Existing slug and not configured to overwrite - short-circuit
              # here to prevent slug generation when not required.
              use_existing_slug = True

          if self.overwrite_on_add and add:
              use_existing_slug = False

          if use_existing_slug:
              return slug

          # get fields to populate from and slug field to set
          populate_from = self._populate_from
          if not isinstance(populate_from, (list, tuple)):
              populate_from = (populate_from, )

          slug_field = model_instance._meta.get_field(self.attname)
          # A slugify_function defined on the model takes priority over the
          # one given to the field.
          slugify_function = getattr(model_instance, 'slugify_function', self.slugify_function)

          def slug_for_field(lookup_value):
              """Slugify the value that ``lookup_value`` resolves to."""
              return self.slugify_func(
                  self.get_slug_fields(model_instance, lookup_value),
                  slugify_function=slugify_function,
              )

          # slugify the original field content and set next step to 2
          slug = self.separator.join(map(slug_for_field, populate_from))
          start = 2

          # strip slug depending on max_length attribute of the slug field
          # and clean-up
          self.slug_len = slug_field.max_length
          if self.slug_len:
              slug = slug[:self.slug_len]
          slug = self._slug_strip(slug)
          original_slug = slug

          if self.allow_duplicates:
              setattr(model_instance, self.attname, slug)
              return slug

          return self.find_unique(
              model_instance, slug_field, self.slug_generator(original_slug, start))

      def get_slug_fields(self, model_instance, lookup_value):
          """
          Resolve one ``populate_from`` entry against ``model_instance``.

          A callable is called with the instance. A string is followed as a
          ``__``-separated attribute path, and a callable found at the end of
          the path is called with no arguments.
          """
          if callable(lookup_value):
              # A function has been provided
              return "%s" % lookup_value(model_instance)

          lookup_value_path = lookup_value.split(LOOKUP_SEP)
          attr = model_instance
          for elem in lookup_value_path:
              try:
                  attr = getattr(attr, elem)
              except AttributeError as exc:
                  raise AttributeError(
                      "value {} in AutoSlugField's 'populate_from' argument {} returned an error - {} has no attribute {}".format(
                          elem, lookup_value, attr, elem)) from exc
          if callable(attr):
              return "%s" % attr()

          return attr

      def pre_save(self, model_instance, add):
          value = force_str(self.create_slug(model_instance, add))
          return value

      def get_internal_type(self):
          return "SlugField"

      def deconstruct(self):
          """
          Return constructor arguments that recreate this field.

          Options are only emitted when they differ from their defaults.
          ``overwrite_on_add`` and ``max_unique_query_attempts`` are included
          so that clones (e.g. migration state) keep the configured behaviour.
          """
          name, path, args, kwargs = super().deconstruct()
          kwargs['populate_from'] = self._populate_from
          if self.separator != '-':
              kwargs['separator'] = self.separator
          if self.overwrite is not False:
              kwargs['overwrite'] = True
          if self.overwrite_on_add is not True:
              kwargs['overwrite_on_add'] = False
          if self.allow_duplicates is not False:
              kwargs['allow_duplicates'] = True
          if self.max_unique_query_attempts != MAX_UNIQUE_QUERY_ATTEMPTS:
              kwargs['max_unique_query_attempts'] = self.max_unique_query_attempts
          return name, path, args, kwargs
  class RandomCharField(UniqueFieldMixin, CharField):
      """
      RandomCharField

      By default, sets editable=False, blank=True, unique=False.

      Required arguments:

      length
          Specifies the length of the field

      Optional arguments:

      unique
          If set to True, duplicate entries are not allowed (default: False)

      lowercase
          If set to True, lowercase the alpha characters (default: False)

      uppercase
          If set to True, uppercase the alpha characters (default: False).
          Mutually exclusive with ``lowercase``.

      include_alpha
          If set to True, include alpha characters (default: True)

      include_digits
          If set to True, include digit characters (default: True)

      include_punctuation
          If set to True, include punctuation characters (default: False)

      max_unique_query_attempts
          Number of random values tried when a unique value is required,
          before a RuntimeError is raised (default: the
          EXTENSIONS_MAX_UNIQUE_QUERY_ATTEMPTS setting, or 100)
      """

      def __init__(self, *args, **kwargs):
          kwargs.setdefault('blank', True)
          kwargs.setdefault('editable', False)

          self.length = kwargs.pop('length', None)
          if self.length is None:
              raise ValueError("missing 'length' argument")
          kwargs['max_length'] = self.length

          self.lowercase = kwargs.pop('lowercase', False)
          self.check_is_bool('lowercase')
          self.uppercase = kwargs.pop('uppercase', False)
          self.check_is_bool('uppercase')
          if self.uppercase and self.lowercase:
              raise ValueError("the 'lowercase' and 'uppercase' arguments are mutually exclusive")
          self.include_digits = kwargs.pop('include_digits', True)
          self.check_is_bool('include_digits')
          self.include_alpha = kwargs.pop('include_alpha', True)
          self.check_is_bool('include_alpha')
          self.include_punctuation = kwargs.pop('include_punctuation', False)
          self.check_is_bool('include_punctuation')
          self.max_unique_query_attempts = kwargs.pop('max_unique_query_attempts', MAX_UNIQUE_QUERY_ATTEMPTS)

          # Set unique=False unless it's been set manually.
          if 'unique' not in kwargs:
              kwargs['unique'] = False

          super().__init__(*args, **kwargs)

      def random_char_generator(self, chars):
          """Yield up to ``max_unique_query_attempts`` random strings, then raise."""
          for _ in range(self.max_unique_query_attempts):
              # get_random_string already returns a str; no join needed.
              yield get_random_string(self.length, chars)
          raise RuntimeError('max random character attempts exceeded (%s)' % self.max_unique_query_attempts)

      def in_unique_together(self, model_instance):
          """
          Return True if this field takes part in a multi-field uniqueness rule.

          Both ``Meta.unique_together`` and ``Meta.constraints`` are checked,
          matching what ``find_unique`` scopes its search by. Constraints
          without a ``fields`` attribute (e.g. CheckConstraint) are ignored.
          """
          meta = model_instance._meta
          if any(self.attname in params for params in meta.unique_together):
              return True
          for constraint in getattr(meta, 'constraints', None) or ():
              if self.attname in (getattr(constraint, 'fields', None) or ()):
                  return True
          return False

      def pre_save(self, model_instance, add):
          """Keep an existing non-empty value; otherwise generate a new one."""
          if not add and getattr(model_instance, self.attname) != '':
              return getattr(model_instance, self.attname)

          population = ''
          if self.include_alpha:
              if self.lowercase:
                  population += string.ascii_lowercase
              elif self.uppercase:
                  population += string.ascii_uppercase
              else:
                  population += string.ascii_letters

          if self.include_digits:
              population += string.digits

          if self.include_punctuation:
              population += string.punctuation

          random_chars = self.random_char_generator(population)
          # No uniqueness requirement: the first random value is good enough.
          if not self.unique and not self.in_unique_together(model_instance):
              new = next(random_chars)
              setattr(model_instance, self.attname, new)
              return new

          return self.find_unique(
              model_instance,
              model_instance._meta.get_field(self.attname),
              random_chars,
          )

      def get_internal_type(self):
          return "CharField"

      def internal_type(self):
          """Backwards-compatible alias of :meth:`get_internal_type`."""
          return self.get_internal_type()

      def deconstruct(self):
          """Return constructor arguments that recreate this field."""
          name, path, args, kwargs = super().deconstruct()
          kwargs['length'] = self.length
          # max_length is derived from length in __init__.
          del kwargs['max_length']
          if self.lowercase is True:
              kwargs['lowercase'] = self.lowercase
          if self.uppercase is True:
              kwargs['uppercase'] = self.uppercase
          if self.include_alpha is False:
              kwargs['include_alpha'] = self.include_alpha
          if self.include_digits is False:
              kwargs['include_digits'] = self.include_digits
          if self.include_punctuation is True:
              kwargs['include_punctuation'] = self.include_punctuation
          if self.unique is True:
              kwargs['unique'] = self.unique
          if self.max_unique_query_attempts != MAX_UNIQUE_QUERY_ATTEMPTS:
              kwargs['max_unique_query_attempts'] = self.max_unique_query_attempts
          return name, path, args, kwargs
  class CreationDateTimeField(DateTimeField):
      """
      CreationDateTimeField

      By default, sets editable=False, blank=True, auto_now_add=True
      """

      def __init__(self, *args, **kwargs):
          kwargs.setdefault('editable', False)
          kwargs.setdefault('blank', True)
          kwargs.setdefault('auto_now_add', True)
          DateTimeField.__init__(self, *args, **kwargs)

      def get_internal_type(self):
          return "DateTimeField"

      def deconstruct(self):
          """
          Return constructor arguments that recreate this field.

          This class's defaults (editable=False, blank=True, auto_now_add=True)
          differ from DateTimeField's. Explicit opposite values therefore have
          to be written out, or a reconstructed field would fall back to these
          defaults.
          """
          name, path, args, kwargs = super().deconstruct()
          if self.editable is not False:
              kwargs['editable'] = True
          if self.blank is not True:
              kwargs['blank'] = False
          if self.auto_now_add is not False:
              kwargs['auto_now_add'] = True
          elif not self.auto_now:
              # Record an explicit opt-out of auto_now_add. This is skipped when
              # auto_now is set: ModificationDateTimeField (auto_now=True) reuses
              # this method and never defaults auto_now_add to True.
              kwargs['auto_now_add'] = False
          return name, path, args, kwargs
  class ModificationDateTimeField(CreationDateTimeField):
      """
      ModificationDateTimeField

      By default, sets editable=False, blank=True, auto_now=True

      Sets value to now every time the object is saved, unless the model
      instance has a falsy ``update_modified`` attribute.
      """

      def __init__(self, *args, **kwargs):
          kwargs.setdefault('auto_now', True)
          # Calls DateTimeField.__init__ directly, skipping
          # CreationDateTimeField.__init__, so auto_now_add is not defaulted
          # to True here.
          DateTimeField.__init__(self, *args, **kwargs)

      def get_internal_type(self):
          return "DateTimeField"

      def deconstruct(self):
          """
          Return constructor arguments that recreate this field.

          auto_now defaults to True for this class, so an explicit
          auto_now=False must be recorded for a reconstructed field (e.g. in
          migration state) to keep it.
          """
          name, path, args, kwargs = super().deconstruct()
          if self.auto_now is not False:
              kwargs['auto_now'] = True
          else:
              kwargs['auto_now'] = False
          return name, path, args, kwargs

      def pre_save(self, model_instance, add):
          # Instances can opt out of the timestamp bump by setting
          # update_modified to a falsy value before saving.
          if not getattr(model_instance, 'update_modified', True):
              return getattr(model_instance, self.attname)
          return super().pre_save(model_instance, add)
  class UUIDVersionError(Exception):
      """Raised when a UUID field is asked for a UUID version it cannot produce."""
  class UUIDFieldMixin:
      """
      UUIDFieldMixin

      By default uses UUID version 4 (randomly generated UUID).

      The field supports all uuid versions which are natively supported by the
      uuid python module, except version 2.
      For more information see: https://docs.python.org/3/library/uuid.html
      """

      DEFAULT_MAX_LENGTH = 36

      def __init__(self, verbose_name=None, name=None, auto=True, version=4,
                   node=None, clock_seq=None, namespace=None, uuid_name=None, *args,
                   **kwargs):
          if not HAS_UUID:
              raise ImproperlyConfigured("'uuid' module is required for UUIDField. (Do you have Python 2.5 or higher installed ?)")
          kwargs.setdefault('max_length', self.DEFAULT_MAX_LENGTH)
          if auto:
              # Auto-generated fields are filled in on save, so they are never
              # required in forms and are not editable by default.
              self.empty_strings_allowed = False
              kwargs['blank'] = True
              kwargs.setdefault('editable', False)
          self.auto = auto
          self.version = version
          self.node = node
          self.clock_seq = clock_seq
          self.namespace = namespace
          # Used as the name for version 3/5 UUIDs; falls back to ``name``.
          self.uuid_name = uuid_name or name
          super().__init__(verbose_name=verbose_name, *args, **kwargs)

      def create_uuid(self):
          """
          Build a UUID according to ``self.version``.

          Versions 3 and 5 use ``self.namespace`` and ``self.uuid_name``;
          version 1 uses ``self.node`` and ``self.clock_seq``. A falsy version
          is treated as 4.

          :raises UUIDVersionError: for version 2 or any unknown version.
          """
          if not self.version or self.version == 4:
              return uuid.uuid4()
          elif self.version == 1:
              return uuid.uuid1(self.node, self.clock_seq)
          elif self.version == 2:
              raise UUIDVersionError("UUID version 2 is not supported.")
          elif self.version == 3:
              return uuid.uuid3(self.namespace, self.uuid_name)
          elif self.version == 5:
              return uuid.uuid5(self.namespace, self.uuid_name)
          else:
              raise UUIDVersionError("UUID version %s is not valid." % self.version)

      def pre_save(self, model_instance, add):
          """Generate and assign a UUID string when ``auto`` is on and no value is set."""
          value = super().pre_save(model_instance, add)
          # A missing value (None or empty) is replaced whether the instance
          # is being added or updated.
          if self.auto and not value:
              value = force_str(self.create_uuid())
              setattr(model_instance, self.attname, value)
          return value

      def formfield(self, **kwargs):
          # Auto-generated values are never edited through forms.
          if self.auto:
              return None
          return super().formfield(**kwargs)

      def deconstruct(self):
          """Return constructor arguments that recreate this field."""
          name, path, args, kwargs = super().deconstruct()
          if kwargs.get('max_length', None) == self.DEFAULT_MAX_LENGTH:
              del kwargs['max_length']
          if self.auto is not True:
              kwargs['auto'] = self.auto
          if self.version != 4:
              kwargs['version'] = self.version
          if self.node is not None:
              kwargs['node'] = self.node
          if self.clock_seq is not None:
              kwargs['clock_seq'] = self.clock_seq
          if self.namespace is not None:
              kwargs['namespace'] = self.namespace
          if self.uuid_name is not None:
              # Emit the configured uuid_name, not the field's attribute name.
              kwargs['uuid_name'] = self.uuid_name
          return name, path, args, kwargs
  class ShortUUIDField(UUIDFieldMixin, CharField):
      """
      ShortUUIDField

      Generates concise (22 characters instead of 36), unambiguous, URL-safe UUIDs.

      Based on `shortuuid`: https://github.com/stochastic-technologies/shortuuid
      """

      DEFAULT_MAX_LENGTH = 22

      def __init__(self, *args, **kwargs):
          # Check the optional dependency before any field state is set up.
          if not HAS_SHORT_UUID:
              raise ImproperlyConfigured("'shortuuid' module is required for ShortUUIDField. (Do you have the shortuuid package installed ?)")
          # Must happen before super().__init__ to have any effect. The mixin
          # also applies self.DEFAULT_MAX_LENGTH, so the resulting default is 22.
          kwargs.setdefault('max_length', self.DEFAULT_MAX_LENGTH)
          super().__init__(*args, **kwargs)

      def create_uuid(self):
          """
          Build a short UUID string according to ``self.version``.

          NOTE(review): version 1 returns the same random value as version 4.
          Version 5 passes ``self.namespace`` (not ``uuid_name``) as the
          shortuuid name. Both look surprising; they are kept unchanged so that
          existing generated values stay the same. Confirm intent before
          changing.

          :raises UUIDVersionError: for versions 2, 3, or any unknown version.
          """
          if not self.version or self.version == 4:
              return shortuuid.uuid()
          elif self.version == 1:
              return shortuuid.uuid()
          elif self.version == 2:
              raise UUIDVersionError("UUID version 2 is not supported.")
          elif self.version == 3:
              raise UUIDVersionError("UUID version 3 is not supported.")
          elif self.version == 5:
              return shortuuid.uuid(name=self.namespace)
          else:
              raise UUIDVersionError("UUID version %s is not valid." % self.version)