fields.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651
  1. import operator
  2. import warnings
  3. import weakref
  4. from bson import DBRef, ObjectId, SON
  5. import pymongo
  6. import six
  7. from mongoengine.base.common import UPDATE_OPERATORS
  8. from mongoengine.base.datastructures import (BaseDict, BaseList,
  9. EmbeddedDocumentList)
  10. from mongoengine.common import _import_class
  11. from mongoengine.errors import ValidationError
  12. __all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField',
  13. 'GeoJsonBaseField')
  14. class BaseField(object):
  15. """A base class for fields in a MongoDB document. Instances of this class
  16. may be added to subclasses of `Document` to define a document's schema.
  17. .. versionchanged:: 0.5 - added verbose and help text
  18. """
  19. name = None
  20. _geo_index = False
  21. _auto_gen = False # Call `generate` to generate a value
  22. _auto_dereference = True
  23. # These track each time a Field instance is created. Used to retain order.
  24. # The auto_creation_counter is used for fields that MongoEngine implicitly
  25. # creates, creation_counter is used for all user-specified fields.
  26. creation_counter = 0
  27. auto_creation_counter = -1
  28. def __init__(self, db_field=None, name=None, required=False, default=None,
  29. unique=False, unique_with=None, primary_key=False,
  30. validation=None, choices=None, null=False, sparse=False,
  31. **kwargs):
  32. """
  33. :param db_field: The database field to store this field in
  34. (defaults to the name of the field)
  35. :param name: Deprecated - use db_field
  36. :param required: If the field is required. Whether it has to have a
  37. value or not. Defaults to False.
  38. :param default: (optional) The default value for this field if no value
  39. has been set (or if the value has been unset). It can be a
  40. callable.
  41. :param unique: Is the field value unique or not. Defaults to False.
  42. :param unique_with: (optional) The other field this field should be
  43. unique with.
  44. :param primary_key: Mark this field as the primary key. Defaults to False.
  45. :param validation: (optional) A callable to validate the value of the
  46. field. Generally this is deprecated in favour of the
  47. `FIELD.validate` method
  48. :param choices: (optional) The valid choices
  49. :param null: (optional) If the field value can be null. If no and there is a default value
  50. then the default value is set
  51. :param sparse: (optional) `sparse=True` combined with `unique=True` and `required=False`
  52. means that uniqueness won't be enforced for `None` values
  53. :param **kwargs: (optional) Arbitrary indirection-free metadata for
  54. this field can be supplied as additional keyword arguments and
  55. accessed as attributes of the field. Must not conflict with any
  56. existing attributes. Common metadata includes `verbose_name` and
  57. `help_text`.
  58. """
  59. self.db_field = (db_field or name) if not primary_key else '_id'
  60. if name:
  61. msg = 'Field\'s "name" attribute deprecated in favour of "db_field"'
  62. warnings.warn(msg, DeprecationWarning)
  63. self.required = required or primary_key
  64. self.default = default
  65. self.unique = bool(unique or unique_with)
  66. self.unique_with = unique_with
  67. self.primary_key = primary_key
  68. self.validation = validation
  69. self.choices = choices
  70. self.null = null
  71. self.sparse = sparse
  72. self._owner_document = None
  73. # Make sure db_field is a string (if it's explicitly defined).
  74. if (
  75. self.db_field is not None and
  76. not isinstance(self.db_field, six.string_types)
  77. ):
  78. raise TypeError('db_field should be a string.')
  79. # Make sure db_field doesn't contain any forbidden characters.
  80. if isinstance(self.db_field, six.string_types) and (
  81. '.' in self.db_field or
  82. '\0' in self.db_field or
  83. self.db_field.startswith('$')
  84. ):
  85. raise ValueError(
  86. 'field names cannot contain dots (".") or null characters '
  87. '("\\0"), and they must not start with a dollar sign ("$").'
  88. )
  89. # Detect and report conflicts between metadata and base properties.
  90. conflicts = set(dir(self)) & set(kwargs)
  91. if conflicts:
  92. raise TypeError('%s already has attribute(s): %s' % (
  93. self.__class__.__name__, ', '.join(conflicts)))
  94. # Assign metadata to the instance
  95. # This efficient method is available because no __slots__ are defined.
  96. self.__dict__.update(kwargs)
  97. # Adjust the appropriate creation counter, and save our local copy.
  98. if self.db_field == '_id':
  99. self.creation_counter = BaseField.auto_creation_counter
  100. BaseField.auto_creation_counter -= 1
  101. else:
  102. self.creation_counter = BaseField.creation_counter
  103. BaseField.creation_counter += 1
  104. def __get__(self, instance, owner):
  105. """Descriptor for retrieving a value from a field in a document.
  106. """
  107. if instance is None:
  108. # Document class being used rather than a document object
  109. return self
  110. # Get value from document instance if available
  111. return instance._data.get(self.name)
  112. def __set__(self, instance, value):
  113. """Descriptor for assigning a value to a field in a document.
  114. """
  115. # If setting to None and there is a default
  116. # Then set the value to the default value
  117. if value is None:
  118. if self.null:
  119. value = None
  120. elif self.default is not None:
  121. value = self.default
  122. if callable(value):
  123. value = value()
  124. if instance._initialised:
  125. try:
  126. if (self.name not in instance._data or
  127. instance._data[self.name] != value):
  128. instance._mark_as_changed(self.name)
  129. except Exception:
  130. # Values cant be compared eg: naive and tz datetimes
  131. # So mark it as changed
  132. instance._mark_as_changed(self.name)
  133. EmbeddedDocument = _import_class('EmbeddedDocument')
  134. if isinstance(value, EmbeddedDocument):
  135. value._instance = weakref.proxy(instance)
  136. elif isinstance(value, (list, tuple)):
  137. for v in value:
  138. if isinstance(v, EmbeddedDocument):
  139. v._instance = weakref.proxy(instance)
  140. instance._data[self.name] = value
  141. def error(self, message='', errors=None, field_name=None):
  142. """Raise a ValidationError."""
  143. field_name = field_name if field_name else self.name
  144. raise ValidationError(message, errors=errors, field_name=field_name)
  145. def to_python(self, value):
  146. """Convert a MongoDB-compatible type to a Python type."""
  147. return value
  148. def to_mongo(self, value):
  149. """Convert a Python type to a MongoDB-compatible type."""
  150. return self.to_python(value)
  151. def _to_mongo_safe_call(self, value, use_db_field=True, fields=None):
  152. """Helper method to call to_mongo with proper inputs."""
  153. f_inputs = self.to_mongo.__code__.co_varnames
  154. ex_vars = {}
  155. if 'fields' in f_inputs:
  156. ex_vars['fields'] = fields
  157. if 'use_db_field' in f_inputs:
  158. ex_vars['use_db_field'] = use_db_field
  159. return self.to_mongo(value, **ex_vars)
  160. def prepare_query_value(self, op, value):
  161. """Prepare a value that is being used in a query for PyMongo."""
  162. if op in UPDATE_OPERATORS:
  163. self.validate(value)
  164. return value
  165. def validate(self, value, clean=True):
  166. """Perform validation on a value."""
  167. pass
  168. def _validate_choices(self, value):
  169. Document = _import_class('Document')
  170. EmbeddedDocument = _import_class('EmbeddedDocument')
  171. choice_list = self.choices
  172. if isinstance(next(iter(choice_list)), (list, tuple)):
  173. # next(iter) is useful for sets
  174. choice_list = [k for k, _ in choice_list]
  175. # Choices which are other types of Documents
  176. if isinstance(value, (Document, EmbeddedDocument)):
  177. if not any(isinstance(value, c) for c in choice_list):
  178. self.error(
  179. 'Value must be an instance of %s' % (
  180. six.text_type(choice_list)
  181. )
  182. )
  183. # Choices which are types other than Documents
  184. else:
  185. values = value if isinstance(value, (list, tuple)) else [value]
  186. if len(set(values) - set(choice_list)):
  187. self.error('Value must be one of %s' % six.text_type(choice_list))
  188. def _validate(self, value, **kwargs):
  189. # Check the Choices Constraint
  190. if self.choices:
  191. self._validate_choices(value)
  192. # check validation argument
  193. if self.validation is not None:
  194. if callable(self.validation):
  195. if not self.validation(value):
  196. self.error('Value does not match custom validation method')
  197. else:
  198. raise ValueError('validation argument for "%s" must be a '
  199. 'callable.' % self.name)
  200. self.validate(value, **kwargs)
  201. @property
  202. def owner_document(self):
  203. return self._owner_document
  204. def _set_owner_document(self, owner_document):
  205. self._owner_document = owner_document
  206. @owner_document.setter
  207. def owner_document(self, owner_document):
  208. self._set_owner_document(owner_document)
  209. class ComplexBaseField(BaseField):
  210. """Handles complex fields, such as lists / dictionaries.
  211. Allows for nesting of embedded documents inside complex types.
  212. Handles the lazy dereferencing of a queryset by lazily dereferencing all
  213. items in a list / dict rather than one at a time.
  214. .. versionadded:: 0.5
  215. """
  216. field = None
  217. def __get__(self, instance, owner):
  218. """Descriptor to automatically dereference references."""
  219. if instance is None:
  220. # Document class being used rather than a document object
  221. return self
  222. ReferenceField = _import_class('ReferenceField')
  223. GenericReferenceField = _import_class('GenericReferenceField')
  224. EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')
  225. auto_dereference = instance._fields[self.name]._auto_dereference
  226. dereference = (auto_dereference and
  227. (self.field is None or isinstance(self.field,
  228. (GenericReferenceField, ReferenceField))))
  229. _dereference = _import_class('DeReference')()
  230. if instance._initialised and dereference and instance._data.get(self.name):
  231. instance._data[self.name] = _dereference(
  232. instance._data.get(self.name), max_depth=1, instance=instance,
  233. name=self.name
  234. )
  235. value = super(ComplexBaseField, self).__get__(instance, owner)
  236. # Convert lists / values so we can watch for any changes on them
  237. if isinstance(value, (list, tuple)):
  238. if (issubclass(type(self), EmbeddedDocumentListField) and
  239. not isinstance(value, EmbeddedDocumentList)):
  240. value = EmbeddedDocumentList(value, instance, self.name)
  241. elif not isinstance(value, BaseList):
  242. value = BaseList(value, instance, self.name)
  243. instance._data[self.name] = value
  244. elif isinstance(value, dict) and not isinstance(value, BaseDict):
  245. value = BaseDict(value, instance, self.name)
  246. instance._data[self.name] = value
  247. if (auto_dereference and instance._initialised and
  248. isinstance(value, (BaseList, BaseDict)) and
  249. not value._dereferenced):
  250. value = _dereference(
  251. value, max_depth=1, instance=instance, name=self.name
  252. )
  253. value._dereferenced = True
  254. instance._data[self.name] = value
  255. return value
  256. def to_python(self, value):
  257. """Convert a MongoDB-compatible type to a Python type."""
  258. if isinstance(value, six.string_types):
  259. return value
  260. if hasattr(value, 'to_python'):
  261. return value.to_python()
  262. BaseDocument = _import_class('BaseDocument')
  263. if isinstance(value, BaseDocument):
  264. # Something is wrong, return the value as it is
  265. return value
  266. is_list = False
  267. if not hasattr(value, 'items'):
  268. try:
  269. is_list = True
  270. value = {idx: v for idx, v in enumerate(value)}
  271. except TypeError: # Not iterable return the value
  272. return value
  273. if self.field:
  274. self.field._auto_dereference = self._auto_dereference
  275. for key, item in value.items():
  276. self.field.to_python(item)
  277. value_dict = {key: self.field.to_python(item)
  278. for key, item in value.items()}
  279. else:
  280. Document = _import_class('Document')
  281. value_dict = {}
  282. for k, v in value.items():
  283. if isinstance(v, Document):
  284. # We need the id from the saved object to create the DBRef
  285. if v.pk is None:
  286. self.error('You can only reference documents once they'
  287. ' have been saved to the database')
  288. collection = v._get_collection_name()
  289. value_dict[k] = DBRef(collection, v.pk)
  290. elif hasattr(v, 'to_python'):
  291. value_dict[k] = v.to_python()
  292. else:
  293. value_dict[k] = self.to_python(v)
  294. if is_list: # Convert back to a list
  295. return [v for _, v in sorted(value_dict.items(),
  296. key=operator.itemgetter(0))]
  297. return value_dict
  298. def to_mongo(self, value, use_db_field=True, fields=None):
  299. """Convert a Python type to a MongoDB-compatible type."""
  300. Document = _import_class('Document')
  301. EmbeddedDocument = _import_class('EmbeddedDocument')
  302. GenericReferenceField = _import_class('GenericReferenceField')
  303. if isinstance(value, six.string_types):
  304. return value
  305. if hasattr(value, 'to_mongo'):
  306. if isinstance(value, Document):
  307. return GenericReferenceField().to_mongo(value)
  308. cls = value.__class__
  309. val = value.to_mongo(use_db_field, fields)
  310. # If it's a document that is not inherited add _cls
  311. if isinstance(value, EmbeddedDocument):
  312. val['_cls'] = cls.__name__
  313. return val
  314. is_list = False
  315. if not hasattr(value, 'items'):
  316. try:
  317. is_list = True
  318. value = {k: v for k, v in enumerate(value)}
  319. except TypeError: # Not iterable return the value
  320. return value
  321. if self.field:
  322. value_dict = {
  323. key: self.field._to_mongo_safe_call(item, use_db_field, fields)
  324. for key, item in value.iteritems()
  325. }
  326. else:
  327. value_dict = {}
  328. for k, v in value.iteritems():
  329. if isinstance(v, Document):
  330. # We need the id from the saved object to create the DBRef
  331. if v.pk is None:
  332. self.error('You can only reference documents once they'
  333. ' have been saved to the database')
  334. # If its a document that is not inheritable it won't have
  335. # any _cls data so make it a generic reference allows
  336. # us to dereference
  337. meta = getattr(v, '_meta', {})
  338. allow_inheritance = meta.get('allow_inheritance')
  339. if not allow_inheritance and not self.field:
  340. value_dict[k] = GenericReferenceField().to_mongo(v)
  341. else:
  342. collection = v._get_collection_name()
  343. value_dict[k] = DBRef(collection, v.pk)
  344. elif hasattr(v, 'to_mongo'):
  345. cls = v.__class__
  346. val = v.to_mongo(use_db_field, fields)
  347. # If it's a document that is not inherited add _cls
  348. if isinstance(v, (Document, EmbeddedDocument)):
  349. val['_cls'] = cls.__name__
  350. value_dict[k] = val
  351. else:
  352. value_dict[k] = self.to_mongo(v, use_db_field, fields)
  353. if is_list: # Convert back to a list
  354. return [v for _, v in sorted(value_dict.items(),
  355. key=operator.itemgetter(0))]
  356. return value_dict
  357. def validate(self, value):
  358. """If field is provided ensure the value is valid."""
  359. errors = {}
  360. if self.field:
  361. if hasattr(value, 'iteritems') or hasattr(value, 'items'):
  362. sequence = value.iteritems()
  363. else:
  364. sequence = enumerate(value)
  365. for k, v in sequence:
  366. try:
  367. self.field._validate(v)
  368. except ValidationError as error:
  369. errors[k] = error.errors or error
  370. except (ValueError, AssertionError) as error:
  371. errors[k] = error
  372. if errors:
  373. field_class = self.field.__class__.__name__
  374. self.error('Invalid %s item (%s)' % (field_class, value),
  375. errors=errors)
  376. # Don't allow empty values if required
  377. if self.required and not value:
  378. self.error('Field is required and cannot be empty')
  379. def prepare_query_value(self, op, value):
  380. return self.to_mongo(value)
  381. def lookup_member(self, member_name):
  382. if self.field:
  383. return self.field.lookup_member(member_name)
  384. return None
  385. def _set_owner_document(self, owner_document):
  386. if self.field:
  387. self.field.owner_document = owner_document
  388. self._owner_document = owner_document
  389. class ObjectIdField(BaseField):
  390. """A field wrapper around MongoDB's ObjectIds."""
  391. def to_python(self, value):
  392. try:
  393. if not isinstance(value, ObjectId):
  394. value = ObjectId(value)
  395. except Exception:
  396. pass
  397. return value
  398. def to_mongo(self, value):
  399. if not isinstance(value, ObjectId):
  400. try:
  401. return ObjectId(six.text_type(value))
  402. except Exception as e:
  403. # e.message attribute has been deprecated since Python 2.6
  404. self.error(six.text_type(e))
  405. return value
  406. def prepare_query_value(self, op, value):
  407. return self.to_mongo(value)
  408. def validate(self, value):
  409. try:
  410. ObjectId(six.text_type(value))
  411. except Exception:
  412. self.error('Invalid Object ID')
  413. class GeoJsonBaseField(BaseField):
  414. """A geo json field storing a geojson style object.
  415. .. versionadded:: 0.8
  416. """
  417. _geo_index = pymongo.GEOSPHERE
  418. _type = 'GeoBase'
  419. def __init__(self, auto_index=True, *args, **kwargs):
  420. """
  421. :param bool auto_index: Automatically create a '2dsphere' index.\
  422. Defaults to `True`.
  423. """
  424. self._name = '%sField' % self._type
  425. if not auto_index:
  426. self._geo_index = False
  427. super(GeoJsonBaseField, self).__init__(*args, **kwargs)
  428. def validate(self, value):
  429. """Validate the GeoJson object based on its type."""
  430. if isinstance(value, dict):
  431. if set(value.keys()) == {'type', 'coordinates'}:
  432. if value['type'] != self._type:
  433. self.error('%s type must be "%s"' %
  434. (self._name, self._type))
  435. return self.validate(value['coordinates'])
  436. else:
  437. self.error('%s can only accept a valid GeoJson dictionary'
  438. ' or lists of (x, y)' % self._name)
  439. return
  440. elif not isinstance(value, (list, tuple)):
  441. self.error('%s can only accept lists of [x, y]' % self._name)
  442. return
  443. validate = getattr(self, '_validate_%s' % self._type.lower())
  444. error = validate(value)
  445. if error:
  446. self.error(error)
  447. def _validate_polygon(self, value, top_level=True):
  448. if not isinstance(value, (list, tuple)):
  449. return 'Polygons must contain list of linestrings'
  450. # Quick and dirty validator
  451. try:
  452. value[0][0][0]
  453. except (TypeError, IndexError):
  454. return 'Invalid Polygon must contain at least one valid linestring'
  455. errors = []
  456. for val in value:
  457. error = self._validate_linestring(val, False)
  458. if not error and val[0] != val[-1]:
  459. error = 'LineStrings must start and end at the same point'
  460. if error and error not in errors:
  461. errors.append(error)
  462. if errors:
  463. if top_level:
  464. return 'Invalid Polygon:\n%s' % ', '.join(errors)
  465. else:
  466. return '%s' % ', '.join(errors)
  467. def _validate_linestring(self, value, top_level=True):
  468. """Validate a linestring."""
  469. if not isinstance(value, (list, tuple)):
  470. return 'LineStrings must contain list of coordinate pairs'
  471. # Quick and dirty validator
  472. try:
  473. value[0][0]
  474. except (TypeError, IndexError):
  475. return 'Invalid LineString must contain at least one valid point'
  476. errors = []
  477. for val in value:
  478. error = self._validate_point(val)
  479. if error and error not in errors:
  480. errors.append(error)
  481. if errors:
  482. if top_level:
  483. return 'Invalid LineString:\n%s' % ', '.join(errors)
  484. else:
  485. return '%s' % ', '.join(errors)
  486. def _validate_point(self, value):
  487. """Validate each set of coords"""
  488. if not isinstance(value, (list, tuple)):
  489. return 'Points must be a list of coordinate pairs'
  490. elif not len(value) == 2:
  491. return 'Value (%s) must be a two-dimensional point' % repr(value)
  492. elif (not isinstance(value[0], (float, int)) or
  493. not isinstance(value[1], (float, int))):
  494. return 'Both values (%s) in point must be float or int' % repr(value)
  495. def _validate_multipoint(self, value):
  496. if not isinstance(value, (list, tuple)):
  497. return 'MultiPoint must be a list of Point'
  498. # Quick and dirty validator
  499. try:
  500. value[0][0]
  501. except (TypeError, IndexError):
  502. return 'Invalid MultiPoint must contain at least one valid point'
  503. errors = []
  504. for point in value:
  505. error = self._validate_point(point)
  506. if error and error not in errors:
  507. errors.append(error)
  508. if errors:
  509. return '%s' % ', '.join(errors)
  510. def _validate_multilinestring(self, value, top_level=True):
  511. if not isinstance(value, (list, tuple)):
  512. return 'MultiLineString must be a list of LineString'
  513. # Quick and dirty validator
  514. try:
  515. value[0][0][0]
  516. except (TypeError, IndexError):
  517. return 'Invalid MultiLineString must contain at least one valid linestring'
  518. errors = []
  519. for linestring in value:
  520. error = self._validate_linestring(linestring, False)
  521. if error and error not in errors:
  522. errors.append(error)
  523. if errors:
  524. if top_level:
  525. return 'Invalid MultiLineString:\n%s' % ', '.join(errors)
  526. else:
  527. return '%s' % ', '.join(errors)
  528. def _validate_multipolygon(self, value):
  529. if not isinstance(value, (list, tuple)):
  530. return 'MultiPolygon must be a list of Polygon'
  531. # Quick and dirty validator
  532. try:
  533. value[0][0][0][0]
  534. except (TypeError, IndexError):
  535. return 'Invalid MultiPolygon must contain at least one valid Polygon'
  536. errors = []
  537. for polygon in value:
  538. error = self._validate_polygon(polygon, False)
  539. if error and error not in errors:
  540. errors.append(error)
  541. if errors:
  542. return 'Invalid MultiPolygon:\n%s' % ', '.join(errors)
  543. def to_mongo(self, value):
  544. if isinstance(value, dict):
  545. return value
  546. return SON([('type', self._type), ('coordinates', value)])