sqlalchemy.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. """ SQLAlchemy support. """
  2. from __future__ import absolute_import
  3. import datetime
  4. from types import GeneratorType
  5. import decimal
  6. from sqlalchemy import func
  7. # from sqlalchemy.orm.interfaces import MANYTOONE
  8. from sqlalchemy.orm.collections import InstrumentedList
  9. from sqlalchemy.orm.attributes import InstrumentedAttribute
  10. from sqlalchemy.sql.type_api import TypeDecorator
  11. try:
  12. from sqlalchemy.orm.relationships import RelationshipProperty
  13. except ImportError:
  14. from sqlalchemy.orm.properties import RelationshipProperty
  15. from sqlalchemy.types import (
  16. BIGINT, BOOLEAN, BigInteger, Boolean, CHAR, DATE, DATETIME, DECIMAL, Date,
  17. DateTime, FLOAT, Float, INT, INTEGER, Integer, NCHAR, NVARCHAR, NUMERIC,
  18. Numeric, SMALLINT, SmallInteger, String, TEXT, TIME, Text, Time, Unicode,
  19. UnicodeText, VARCHAR, Enum)
  20. from .. import mix_types as t
  21. from ..main import (
  22. SKIP_VALUE, LOGGER, TypeMixer as BaseTypeMixer, GenFactory as BaseFactory,
  23. Mixer as BaseMixer, partial, faker)
  24. class GenFactory(BaseFactory):
  25. """ Map a sqlalchemy classes to simple types. """
  26. types = {
  27. (String, VARCHAR, Unicode, NVARCHAR, NCHAR, CHAR): str,
  28. (Text, UnicodeText, TEXT): t.Text,
  29. (Boolean, BOOLEAN): bool,
  30. (Date, DATE): datetime.date,
  31. (DateTime, DATETIME): datetime.datetime,
  32. (Time, TIME): datetime.time,
  33. (DECIMAL, Numeric, NUMERIC): decimal.Decimal,
  34. (Float, FLOAT): float,
  35. (Integer, INTEGER, INT): int,
  36. (BigInteger, BIGINT): t.BigInteger,
  37. (SmallInteger, SMALLINT): t.SmallInteger,
  38. }
  39. generators = {
  40. Enum: None
  41. }
  42. class TypeMixer(BaseTypeMixer):
  43. """ TypeMixer for SQLAlchemy. """
  44. factory = GenFactory
  45. def __init__(self, cls, **params):
  46. """ Init TypeMixer and save the mapper. """
  47. super(TypeMixer, self).__init__(cls, **params)
  48. self.mapper = self.__scheme._sa_class_manager.mapper
  49. def postprocess(self, target, postprocess_values):
  50. """ Fill postprocess values. """
  51. mixed = []
  52. for name, deffered in postprocess_values:
  53. value = deffered.value
  54. if isinstance(value, GeneratorType):
  55. value = next(value)
  56. if isinstance(value, t.Mix):
  57. mixed.append((name, value))
  58. continue
  59. if isinstance(getattr(target, name), InstrumentedList) and not isinstance(value, list):
  60. value = [value]
  61. setattr(target, name, value)
  62. for name, mix in mixed:
  63. setattr(target, name, mix & target)
  64. if self.__mixer:
  65. target = self.__mixer.postprocess(target)
  66. return target
  67. @staticmethod
  68. def get_default(field):
  69. """ Get default value from field.
  70. :return value: A default value or NO_VALUE
  71. """
  72. column = field.scheme
  73. if isinstance(column, RelationshipProperty):
  74. column = column.local_remote_pairs[0][0]
  75. if column is None:
  76. return SKIP_VALUE
  77. if not column.default:
  78. return SKIP_VALUE
  79. if column.default.is_callable:
  80. return column.default.arg(None)
  81. return getattr(column.default, 'arg', SKIP_VALUE)
  82. def gen_select(self, field_name, select):
  83. """ Select exists value from database.
  84. :param field_name: Name of field for generation.
  85. :return : None or (name, value) for later use
  86. """
  87. if not self.__mixer or not self.__mixer.params.get('session'):
  88. return field_name, SKIP_VALUE
  89. relation = self.mapper.get_property(field_name)
  90. session = self.__mixer.params.get('session')
  91. value = session.query(
  92. relation.mapper.class_
  93. ).filter(*select.choices).order_by(func.random()).first()
  94. return self.get_value(field_name, value)
  95. @staticmethod
  96. def is_unique(field):
  97. """ Return True is field's value should be a unique.
  98. :return bool:
  99. """
  100. scheme = field.scheme
  101. if isinstance(scheme, RelationshipProperty):
  102. scheme = scheme.local_remote_pairs[0][0]
  103. if scheme is None:
  104. return False
  105. return scheme.unique
  106. def is_required(self, field):
  107. """ Return True is field's value should be defined.
  108. :return bool:
  109. """
  110. if field.params:
  111. return True
  112. column = field.scheme
  113. if isinstance(column, RelationshipProperty):
  114. column = column.local_remote_pairs[0][0]
  115. # According to the SQLAlchemy docs, autoincrement "only has an effect for columns which are
  116. # Integer derived (i.e. INT, SMALLINT, BIGINT) [and] Part of the primary key [...]".
  117. autoincrement = column.autoincrement and column.primary_key and \
  118. isinstance(column.type, Integer)
  119. return not (column.nullable or autoincrement)
  120. def get_value(self, field_name, field_value):
  121. """ Get `value` as `field_name`.
  122. :return : None or (name, value) for later use
  123. """
  124. field = self.__fields.get(field_name)
  125. if field and isinstance(field.scheme, RelationshipProperty):
  126. return field_name, t._Deffered(field_value, field.scheme)
  127. return super(TypeMixer, self).get_value(field_name, field_value)
  128. def make_fabric(self, column, field_name=None, fake=False, kwargs=None): # noqa
  129. """ Make values fabric for column.
  130. :param column: SqlAlchemy column
  131. :param field_name: Field name
  132. :param fake: Force fake data
  133. :return function:
  134. """
  135. kwargs = {} if kwargs is None else kwargs
  136. if column is None:
  137. column = getattr(self.__scheme, field_name, None)
  138. if isinstance(column, InstrumentedAttribute):
  139. column = column.prop
  140. if isinstance(column, RelationshipProperty):
  141. Mixer = type(self)
  142. Model = column.mapper.class_
  143. mixer = Mixer( Model, mixer=self.__mixer, fake=self.__fake, factory=self.__factory)
  144. return partial(mixer.blend, **kwargs)
  145. ftype = type(column.type)
  146. # augmented types created with TypeDecorator
  147. # don't directly inherit from the base types
  148. if TypeDecorator in ftype.__bases__:
  149. ftype = ftype.impl
  150. stype = self.__factory.cls_to_simple(ftype)
  151. if stype is str:
  152. fab = super(TypeMixer, self).make_fabric(
  153. stype, field_name=field_name, fake=fake, kwargs=kwargs)
  154. return lambda: fab()[:column.type.length]
  155. if ftype is Enum:
  156. return partial(faker.random_element, column.type.enums)
  157. return super(TypeMixer, self).make_fabric(
  158. stype, field_name=field_name, fake=fake, kwargs=kwargs)
  159. def guard(self, *args, **kwargs):
  160. """ Look objects in database.
  161. :returns: A finded object or False
  162. """
  163. try:
  164. session = self.__mixer.params.get('session')
  165. assert session
  166. except (AttributeError, AssertionError):
  167. raise ValueError('Cannot make request to DB.')
  168. qs = session.query(self.mapper).filter(*args, **kwargs)
  169. count = qs.count()
  170. if count == 1:
  171. return qs.first()
  172. if count:
  173. return qs.all()
  174. return False
  175. def reload(self, obj):
  176. """ Reload object from database. """
  177. try:
  178. session = self.__mixer.params.get('session')
  179. session.expire(obj)
  180. session.refresh(obj)
  181. return obj
  182. except (AttributeError, AssertionError):
  183. raise ValueError('Cannot make request to DB.')
  184. def populate_target(self, values):
  185. target = self.__scheme()
  186. for n, v in values:
  187. if isinstance(getattr(target, n, None), InstrumentedList) and not isinstance(v, list):
  188. v = [v]
  189. setattr(target, n, v)
  190. return target
  191. def __load_fields(self):
  192. """ Prepare SQLALchemyTypeMixer.
  193. Select columns and relations for data generation.
  194. """
  195. mapper = self.__scheme._sa_class_manager.mapper
  196. relations = set()
  197. if hasattr(mapper, 'relationships'):
  198. for rel in mapper.relationships:
  199. fkeys = any(c.foreign_keys for c in rel.local_columns)
  200. if not fkeys:
  201. continue
  202. relations |= rel.local_columns
  203. yield rel.key, t.Field(rel, rel.key)
  204. for key, column in mapper.columns.items():
  205. if column not in relations:
  206. yield key, t.Field(column, key)
  207. class Mixer(BaseMixer):
  208. """ Integration with SQLAlchemy. """
  209. type_mixer_cls = TypeMixer
  210. def __init__(self, session=None, commit=True, **params):
  211. """Initialize the SQLAlchemy Mixer.
  212. :param fake: (True) Generate fake data instead of random data.
  213. :param session: SQLAlchemy session. Using for commits.
  214. :param commit: (True) Commit instance to session after creation.
  215. """
  216. super(Mixer, self).__init__(**params)
  217. self.params['session'] = session
  218. self.params['commit'] = bool(session) and commit
  219. def postprocess(self, target):
  220. """ Save objects in db.
  221. :return value: A generated value
  222. """
  223. if self.params.get('commit'):
  224. session = self.params.get('session')
  225. if not session:
  226. LOGGER.warning("'commit' set true but session not initialized.")
  227. else:
  228. session.add(target)
  229. session.commit()
  230. return target
  231. # Default mixer
  232. mixer = Mixer()
  233. # pylama:ignore=E1120,E0611