""" SQLAlchemy support. """
from __future__ import absolute_import

import datetime
from types import GeneratorType

import decimal
from sqlalchemy import func
# from sqlalchemy.orm.interfaces import MANYTOONE
from sqlalchemy.orm.collections import InstrumentedList
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql.type_api import TypeDecorator

try:
    from sqlalchemy.orm.relationships import RelationshipProperty
except ImportError:
    from sqlalchemy.orm.properties import RelationshipProperty

from sqlalchemy.types import (
    BIGINT, BOOLEAN, BigInteger, Boolean, CHAR, DATE, DATETIME, DECIMAL, Date,
    DateTime, FLOAT, Float, INT, INTEGER, Integer, NCHAR, NVARCHAR, NUMERIC,
    Numeric, SMALLINT, SmallInteger, String, TEXT, TIME, Text, Time, Unicode,
    UnicodeText, VARCHAR, Enum)

from .. import mix_types as t
from ..main import (
    SKIP_VALUE, LOGGER, TypeMixer as BaseTypeMixer, GenFactory as BaseFactory,
    Mixer as BaseMixer, partial, faker)


class GenFactory(BaseFactory):

    """ Map SQLAlchemy column types to simple types. """

    types = {
        (String, VARCHAR, Unicode, NVARCHAR, NCHAR, CHAR): str,
        (Text, UnicodeText, TEXT): t.Text,
        (Boolean, BOOLEAN): bool,
        (Date, DATE): datetime.date,
        (DateTime, DATETIME): datetime.datetime,
        (Time, TIME): datetime.time,
        (DECIMAL, Numeric, NUMERIC): decimal.Decimal,
        (Float, FLOAT): float,
        (Integer, INTEGER, INT): int,
        (BigInteger, BIGINT): t.BigInteger,
        (SmallInteger, SMALLINT): t.SmallInteger,
    }

    # Enum values are produced in TypeMixer.make_fabric from the column's
    # declared enums, so no generic generator is registered here.
    generators = {
        Enum: None
    }


class TypeMixer(BaseTypeMixer):

    """ TypeMixer for SQLAlchemy. """

    factory = GenFactory

    def __init__(self, cls, **params):
        """ Init TypeMixer and save the mapper. """
        super(TypeMixer, self).__init__(cls, **params)
        self.mapper = self.__scheme._sa_class_manager.mapper

    def postprocess(self, target, postprocess_values):
        """ Fill postprocess values.

        Deferred values are resolved in order. ``t.Mix`` values are applied
        after all other values have been set, via ``mix & target``.

        :param target: The model instance being built.
        :param postprocess_values: Iterable of ``(name, deferred)`` pairs.
        :return: The target, passed through the attached mixer's
            ``postprocess`` when a mixer is set.
        """
        mixed = []
        for name, deferred in postprocess_values:
            value = deferred.value
            if isinstance(value, GeneratorType):
                value = next(value)

            if isinstance(value, t.Mix):
                mixed.append((name, value))
                continue

            # Collection relationships expect a list even for one item.
            if isinstance(getattr(target, name), InstrumentedList) and \
                    not isinstance(value, list):
                value = [value]

            setattr(target, name, value)

        for name, mix in mixed:
            setattr(target, name, mix & target)

        if self.__mixer:
            target = self.__mixer.postprocess(target)

        return target

    @staticmethod
    def get_default(field):
        """ Get default value from field.

        :return value: A default value or SKIP_VALUE
        """
        column = field.scheme

        if isinstance(column, RelationshipProperty):
            column = column.local_remote_pairs[0][0]

        if column is None:
            return SKIP_VALUE

        if not column.default:
            return SKIP_VALUE

        if column.default.is_callable:
            # NOTE(review): SQLAlchemy wraps callable defaults to accept an
            # execution context; None is passed as there is none here.
            return column.default.arg(None)

        return getattr(column.default, 'arg', SKIP_VALUE)

    def gen_select(self, field_name, select):
        """ Select an existing value from the database.

        :param field_name: Name of field for generation.
        :param select: Select object whose ``choices`` are used as filters.
        :return : None or (name, value) for later use
        """
        if not self.__mixer or not self.__mixer.params.get('session'):
            return field_name, SKIP_VALUE

        relation = self.mapper.get_property(field_name)
        session = self.__mixer.params.get('session')
        value = session.query(
            relation.mapper.class_
        ).filter(*select.choices).order_by(func.random()).first()
        return self.get_value(field_name, value)

    @staticmethod
    def is_unique(field):
        """ Return True if field's value should be unique.

        :return bool:
        """
        scheme = field.scheme
        if isinstance(scheme, RelationshipProperty):
            scheme = scheme.local_remote_pairs[0][0]

        if scheme is None:
            return False

        return scheme.unique

    def is_required(self, field):
        """ Return True if field's value should be defined.

        :return bool:
        """
        if field.params:
            return True

        column = field.scheme
        if isinstance(column, RelationshipProperty):
            column = column.local_remote_pairs[0][0]

        # Same guard as get_default/is_unique: without a local column there
        # is nothing to inspect, so treat the field as optional.
        if column is None:
            return False

        # According to the SQLAlchemy docs, autoincrement "only has an effect
        # for columns which are Integer derived (i.e. INT, SMALLINT, BIGINT)
        # [and] Part of the primary key [...]".
        autoincrement = column.autoincrement and column.primary_key and \
            isinstance(column.type, Integer)

        return not (column.nullable or autoincrement)

    def get_value(self, field_name, field_value):
        """ Get `value` as `field_name`.

        Values for relationship fields are wrapped in ``t._Deffered`` so they
        are applied later in ``postprocess``.

        :return : None or (name, value) for later use
        """
        field = self.__fields.get(field_name)
        if field and isinstance(field.scheme, RelationshipProperty):
            return field_name, t._Deffered(field_value, field.scheme)

        return super(TypeMixer, self).get_value(field_name, field_value)

    def make_fabric(self, column, field_name=None, fake=False, kwargs=None):  # noqa
        """ Make values fabric for column.

        :param column: SqlAlchemy column
        :param field_name: Field name
        :param fake: Force fake data
        :param kwargs: Extra keyword arguments forwarded to the fabric
        :return function:
        """
        kwargs = {} if kwargs is None else kwargs

        if column is None:
            column = getattr(self.__scheme, field_name, None)

        if isinstance(column, InstrumentedAttribute):
            column = column.prop

        if isinstance(column, RelationshipProperty):
            # Related objects are produced by a nested mixer for the
            # related model class.
            Mixer = type(self)
            Model = column.mapper.class_
            mixer = Mixer(
                Model, mixer=self.__mixer, fake=self.__fake,
                factory=self.__factory)
            return partial(mixer.blend, **kwargs)

        # Augmented types created with TypeDecorator don't inherit from the
        # base types. Unwrap (possibly nested) decorator instances down to
        # the type instance they delegate to; on an instance ``impl`` is
        # always a type instance, unlike the class-level attribute.
        column_type = column.type
        while isinstance(column_type, TypeDecorator):
            column_type = column_type.impl
        ftype = type(column_type)

        stype = self.__factory.cls_to_simple(ftype)

        if stype is str:
            fab = super(TypeMixer, self).make_fabric(
                stype, field_name=field_name, fake=fake, kwargs=kwargs)
            # Truncate to the column length ([:None] keeps the whole string).
            return lambda: fab()[:column_type.length]

        if ftype is Enum:
            return partial(faker.random_element, column_type.enums)

        return super(TypeMixer, self).make_fabric(
            stype, field_name=field_name, fake=fake, kwargs=kwargs)

    def guard(self, *args, **kwargs):
        """ Look objects in database.

        :returns: A found object, a list of objects when several match,
            or False when nothing matches.
        :raises ValueError: If no session is configured on the mixer.
        """
        try:
            session = self.__mixer.params.get('session')
        except AttributeError:
            session = None

        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would turn this into an AttributeError.
        if not session:
            raise ValueError('Cannot make request to DB.')

        qs = session.query(self.mapper).filter(*args, **kwargs)
        count = qs.count()

        if count == 1:
            return qs.first()

        if count:
            return qs.all()

        return False

    def reload(self, obj):
        """ Reload object from database.

        :param obj: Instance to expire and refresh.
        :return: The same instance, refreshed.
        :raises ValueError: If no session is configured on the mixer.
        """
        try:
            session = self.__mixer.params.get('session')
        except AttributeError:
            session = None

        if not session:
            raise ValueError('Cannot make request to DB.')

        # Kept outside the try so genuine errors from the session are not
        # misreported as a missing session.
        session.expire(obj)
        session.refresh(obj)
        return obj

    def populate_target(self, values):
        """ Instantiate the model and set the given attribute values.

        :param values: Iterable of ``(name, value)`` pairs.
        :return: The new model instance.
        """
        target = self.__scheme()
        for n, v in values:
            # Collection relationships expect a list even for one item.
            if isinstance(getattr(target, n, None), InstrumentedList) and \
                    not isinstance(v, list):
                v = [v]
            setattr(target, n, v)
        return target

    def __load_fields(self):
        """ Prepare SQLALchemyTypeMixer.

        Select columns and relations for data generation. Relationships whose
        local columns carry foreign keys are yielded as fields, and those
        local columns are then skipped when yielding plain columns.
        """
        mapper = self.__scheme._sa_class_manager.mapper
        relations = set()
        if hasattr(mapper, 'relationships'):
            for rel in mapper.relationships:
                fkeys = any(c.foreign_keys for c in rel.local_columns)
                if not fkeys:
                    continue
                relations |= rel.local_columns
                yield rel.key, t.Field(rel, rel.key)

        for key, column in mapper.columns.items():
            if column not in relations:
                yield key, t.Field(column, key)


class Mixer(BaseMixer):

    """ Integration with SQLAlchemy. """

    type_mixer_cls = TypeMixer

    def __init__(self, session=None, commit=True, **params):
        """Initialize the SQLAlchemy Mixer.

        :param fake: (True) Generate fake data instead of random data.
        :param session: SQLAlchemy session. Used for commits.
        :param commit: (True) Commit instance to session after creation.

        """
        super(Mixer, self).__init__(**params)
        self.params['session'] = session
        # Committing is only enabled when a session was actually provided.
        self.params['commit'] = bool(session) and commit

    def postprocess(self, target):
        """ Save objects in db.

        :return: The target, added to the session and committed when
            ``commit`` is enabled.
        """
        if self.params.get('commit'):
            session = self.params.get('session')
            if not session:
                LOGGER.warning("'commit' set true but session not initialized.")
            else:
                session.add(target)
                session.commit()

        return target


# Default mixer
mixer = Mixer()

# pylama:ignore=E1120,E0611