utils.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. from __future__ import unicode_literals
  2. import datetime
  3. import decimal
  4. import hashlib
  5. import logging
  6. from time import time
  7. from django.conf import settings
  8. from django.utils.encoding import force_bytes
  9. from django.utils.timezone import utc
  10. logger = logging.getLogger('django.db.backends')
  11. class CursorWrapper(object):
  12. def __init__(self, cursor, db):
  13. self.cursor = cursor
  14. self.db = db
  15. WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset'])
  16. def __getattr__(self, attr):
  17. cursor_attr = getattr(self.cursor, attr)
  18. if attr in CursorWrapper.WRAP_ERROR_ATTRS:
  19. return self.db.wrap_database_errors(cursor_attr)
  20. else:
  21. return cursor_attr
  22. def __iter__(self):
  23. return iter(self.cursor)
  24. def __enter__(self):
  25. return self
  26. def __exit__(self, type, value, traceback):
  27. # Ticket #17671 - Close instead of passing thru to avoid backend
  28. # specific behavior. Catch errors liberally because errors in cleanup
  29. # code aren't useful.
  30. try:
  31. self.close()
  32. except self.db.Database.Error:
  33. pass
  34. # The following methods cannot be implemented in __getattr__, because the
  35. # code must run when the method is invoked, not just when it is accessed.
  36. def callproc(self, procname, params=None):
  37. self.db.validate_no_broken_transaction()
  38. self.db.set_dirty()
  39. with self.db.wrap_database_errors:
  40. if params is None:
  41. return self.cursor.callproc(procname)
  42. else:
  43. return self.cursor.callproc(procname, params)
  44. def execute(self, sql, params=None):
  45. self.db.validate_no_broken_transaction()
  46. self.db.set_dirty()
  47. with self.db.wrap_database_errors:
  48. if params is None:
  49. return self.cursor.execute(sql)
  50. else:
  51. return self.cursor.execute(sql, params)
  52. def executemany(self, sql, param_list):
  53. self.db.validate_no_broken_transaction()
  54. self.db.set_dirty()
  55. with self.db.wrap_database_errors:
  56. return self.cursor.executemany(sql, param_list)
  57. class CursorDebugWrapper(CursorWrapper):
  58. # XXX callproc isn't instrumented at this time.
  59. def execute(self, sql, params=None):
  60. start = time()
  61. try:
  62. return super(CursorDebugWrapper, self).execute(sql, params)
  63. finally:
  64. stop = time()
  65. duration = stop - start
  66. sql = self.db.ops.last_executed_query(self.cursor, sql, params)
  67. self.db.queries.append({
  68. 'sql': sql,
  69. 'time': "%.3f" % duration,
  70. })
  71. logger.debug('(%.3f) %s; args=%s' % (duration, sql, params),
  72. extra={'duration': duration, 'sql': sql, 'params': params}
  73. )
  74. def executemany(self, sql, param_list):
  75. start = time()
  76. try:
  77. return super(CursorDebugWrapper, self).executemany(sql, param_list)
  78. finally:
  79. stop = time()
  80. duration = stop - start
  81. try:
  82. times = len(param_list)
  83. except TypeError: # param_list could be an iterator
  84. times = '?'
  85. self.db.queries.append({
  86. 'sql': '%s times: %s' % (times, sql),
  87. 'time': "%.3f" % duration,
  88. })
  89. logger.debug('(%.3f) %s; args=%s' % (duration, sql, param_list),
  90. extra={'duration': duration, 'sql': sql, 'params': param_list}
  91. )
###############################################
# Converters from database (string) to Python #
###############################################
  95. def typecast_date(s):
  96. return datetime.date(*map(int, s.split('-'))) if s else None # returns None if s is null
  97. def typecast_time(s): # does NOT store time zone information
  98. if not s:
  99. return None
  100. hour, minutes, seconds = s.split(':')
  101. if '.' in seconds: # check whether seconds have a fractional part
  102. seconds, microseconds = seconds.split('.')
  103. else:
  104. microseconds = '0'
  105. return datetime.time(int(hour), int(minutes), int(seconds), int(float('.' + microseconds) * 1000000))
  106. def typecast_timestamp(s): # does NOT store time zone information
  107. # "2005-07-29 15:48:00.590358-05"
  108. # "2005-07-29 09:56:00-05"
  109. if not s:
  110. return None
  111. if ' ' not in s:
  112. return typecast_date(s)
  113. d, t = s.split()
  114. # Extract timezone information, if it exists. Currently we just throw
  115. # it away, but in the future we may make use of it.
  116. if '-' in t:
  117. t, tz = t.split('-', 1)
  118. tz = '-' + tz
  119. elif '+' in t:
  120. t, tz = t.split('+', 1)
  121. tz = '+' + tz
  122. else:
  123. tz = ''
  124. dates = d.split('-')
  125. times = t.split(':')
  126. seconds = times[2]
  127. if '.' in seconds: # check whether seconds have a fractional part
  128. seconds, microseconds = seconds.split('.')
  129. else:
  130. microseconds = '0'
  131. tzinfo = utc if settings.USE_TZ else None
  132. return datetime.datetime(int(dates[0]), int(dates[1]), int(dates[2]),
  133. int(times[0]), int(times[1]), int(seconds),
  134. int((microseconds + '000000')[:6]), tzinfo)
  135. def typecast_decimal(s):
  136. if s is None or s == '':
  137. return None
  138. return decimal.Decimal(s)
###############################################
# Converters from Python to database (string) #
###############################################
  142. def rev_typecast_decimal(d):
  143. if d is None:
  144. return None
  145. return str(d)
  146. def truncate_name(name, length=None, hash_len=4):
  147. """Shortens a string to a repeatable mangled version with the given length.
  148. """
  149. if length is None or len(name) <= length:
  150. return name
  151. hsh = hashlib.md5(force_bytes(name)).hexdigest()[:hash_len]
  152. return '%s%s' % (name[:length - hash_len], hsh)
  153. def format_number(value, max_digits, decimal_places):
  154. """
  155. Formats a number into a string with the requisite number of digits and
  156. decimal places.
  157. """
  158. if isinstance(value, decimal.Decimal):
  159. context = decimal.getcontext().copy()
  160. context.prec = max_digits
  161. return "{0:f}".format(value.quantize(decimal.Decimal(".1") ** decimal_places, context=context))
  162. else:
  163. return "%.*f" % (decimal_places, value)