# test_django.py (11 KB)
from __future__ import absolute_import

import importlib
import os
from contextlib import contextmanager

from mock import Mock, patch

from celery.fixups.django import (
    _maybe_close_fd,
    fixup,
    DjangoFixup,
)
from celery.tests.case import AppCase, patch_many, patch_modules, mask_modules
  11. class test_DjangoFixup(AppCase):
  12. def test_fixup(self):
  13. with patch('celery.fixups.django.DjangoFixup') as Fixup:
  14. with patch.dict(os.environ, DJANGO_SETTINGS_MODULE=''):
  15. fixup(self.app)
  16. self.assertFalse(Fixup.called)
  17. with patch.dict(os.environ, DJANGO_SETTINGS_MODULE='settings'):
  18. with mask_modules('django'):
  19. with self.assertWarnsRegex(UserWarning, 'but Django is'):
  20. fixup(self.app)
  21. self.assertFalse(Fixup.called)
  22. with patch_modules('django'):
  23. fixup(self.app)
  24. self.assertTrue(Fixup.called)
  25. @contextmanager
  26. def fixup_context(self, app):
  27. with patch('celery.fixups.django.import_module') as import_module:
  28. with patch('celery.fixups.django.symbol_by_name') as symbyname:
  29. f = DjangoFixup(app)
  30. yield f, import_module, symbyname
  31. def test_maybe_close_fd(self):
  32. with patch('os.close'):
  33. _maybe_close_fd(Mock())
  34. _maybe_close_fd(object())
  35. def test_init(self):
  36. with self.fixup_context(self.app) as (f, importmod, sym):
  37. self.assertTrue(f)
  38. def se(name):
  39. if name == 'django.utils.timezone:now':
  40. raise ImportError()
  41. return Mock()
  42. sym.side_effect = se
  43. self.assertTrue(DjangoFixup(self.app)._now)
  44. def se2(name):
  45. if name == 'django.db:close_old_connections':
  46. raise ImportError()
  47. return Mock()
  48. sym.side_effect = se2
  49. self.assertIsNone(DjangoFixup(self.app)._close_old_connections)
  50. def test_install(self):
  51. self.app.conf = {'CELERY_DB_REUSE_MAX': None}
  52. self.app.loader = Mock()
  53. with self.fixup_context(self.app) as (f, _, _):
  54. with patch_many('os.getcwd', 'sys.path',
  55. 'celery.fixups.django.signals') as (cw, p, sigs):
  56. cw.return_value = '/opt/vandelay'
  57. f.install()
  58. sigs.beat_embedded_init.connect.assert_called_with(
  59. f.close_database,
  60. )
  61. sigs.worker_ready.connect.assert_called_with(f.on_worker_ready)
  62. sigs.task_prerun.connect.assert_called_with(f.on_task_prerun)
  63. sigs.task_postrun.connect.assert_called_with(f.on_task_postrun)
  64. sigs.worker_init.connect.assert_called_with(f.on_worker_init)
  65. sigs.worker_process_init.connect.assert_called_with(
  66. f.on_worker_process_init,
  67. )
  68. self.assertEqual(self.app.loader.now, f.now)
  69. self.assertEqual(self.app.loader.mail_admins, f.mail_admins)
  70. p.append.assert_called_with('/opt/vandelay')
  71. def test_now(self):
  72. with self.fixup_context(self.app) as (f, _, _):
  73. self.assertTrue(f.now(utc=True))
  74. self.assertFalse(f._now.called)
  75. self.assertTrue(f.now(utc=False))
  76. self.assertTrue(f._now.called)
  77. def test_mail_admins(self):
  78. with self.fixup_context(self.app) as (f, _, _):
  79. f.mail_admins('sub', 'body', True)
  80. f._mail_admins.assert_called_with(
  81. 'sub', 'body', fail_silently=True,
  82. )
  83. def test_on_worker_init(self):
  84. with self.fixup_context(self.app) as (f, _, _):
  85. f.close_database = Mock()
  86. f.close_cache = Mock()
  87. f.on_worker_init()
  88. f.close_database.assert_called_with()
  89. f.close_cache.assert_called_with()
  90. def test_on_worker_process_init(self):
  91. with self.fixup_context(self.app) as (f, _, _):
  92. with patch('celery.fixups.django._maybe_close_fd') as mcf:
  93. _all = f._db.connections.all = Mock()
  94. conns = _all.return_value = [
  95. Mock(), Mock(),
  96. ]
  97. conns[0].connection = None
  98. with patch.object(f, 'close_cache'):
  99. with patch.object(f, '_close_database'):
  100. f.on_worker_process_init()
  101. mcf.assert_called_with(conns[1].connection)
  102. f.close_cache.assert_called_with()
  103. f._close_database.assert_called_with()
  104. mcf.reset_mock()
  105. _all.side_effect = AttributeError()
  106. f.on_worker_process_init()
  107. mcf.assert_called_with(f._db.connection.connection)
  108. f._db.connection = None
  109. f.on_worker_process_init()
  110. def test_on_task_prerun(self):
  111. task = Mock()
  112. with self.fixup_context(self.app) as (f, _, _):
  113. task.request.is_eager = False
  114. with patch.object(f, 'close_database'):
  115. f.on_task_prerun(task)
  116. f.close_database.assert_called_with()
  117. task.request.is_eager = True
  118. with patch.object(f, 'close_database'):
  119. f.on_task_prerun(task)
  120. self.assertFalse(f.close_database.called)
  121. def test_on_task_postrun(self):
  122. task = Mock()
  123. with self.fixup_context(self.app) as (f, _, _):
  124. with patch.object(f, 'close_cache'):
  125. task.request.is_eager = False
  126. with patch.object(f, 'close_database'):
  127. f.on_task_postrun(task)
  128. self.assertTrue(f.close_database.called)
  129. self.assertTrue(f.close_cache.called)
  130. # when a task is eager, do not close connections
  131. with patch.object(f, 'close_cache'):
  132. task.request.is_eager = True
  133. with patch.object(f, 'close_database'):
  134. f.on_task_postrun(task)
  135. self.assertFalse(f.close_database.called)
  136. self.assertFalse(f.close_cache.called)
  137. def test_close_database(self):
  138. with self.fixup_context(self.app) as (f, _, _):
  139. f._close_old_connections = Mock()
  140. f.close_database()
  141. f._close_old_connections.assert_called_with()
  142. f._close_old_connections = None
  143. with patch.object(f, '_close_database') as _close:
  144. f.db_reuse_max = None
  145. f.close_database()
  146. _close.assert_called_with()
  147. _close.reset_mock()
  148. f.db_reuse_max = 10
  149. f._db_recycles = 3
  150. f.close_database()
  151. self.assertFalse(_close.called)
  152. self.assertEqual(f._db_recycles, 4)
  153. _close.reset_mock()
  154. f._db_recycles = 20
  155. f.close_database()
  156. _close.assert_called_with()
  157. self.assertEqual(f._db_recycles, 1)
  158. def test__close_database(self):
  159. with self.fixup_context(self.app) as (f, _, _):
  160. conns = f._db.connections = [Mock(), Mock(), Mock()]
  161. conns[1].close.side_effect = KeyError('already closed')
  162. f.database_errors = (KeyError, )
  163. f._close_database()
  164. conns[0].close.assert_called_with()
  165. conns[1].close.assert_called_with()
  166. conns[2].close.assert_called_with()
  167. conns[1].close.side_effect = KeyError('omg')
  168. with self.assertRaises(KeyError):
  169. f._close_database()
  170. class Object(object):
  171. pass
  172. o = Object()
  173. o.close_connection = Mock()
  174. f._db = o
  175. f._close_database()
  176. o.close_connection.assert_called_with()
  177. def test_close_cache(self):
  178. with self.fixup_context(self.app) as (f, _, _):
  179. f.close_cache()
  180. f._cache.cache.close.assert_called_with()
  181. f._cache.cache.close.side_effect = TypeError()
  182. f.close_cache()
  183. def test_on_worker_ready(self):
  184. with self.fixup_context(self.app) as (f, _, _):
  185. f._settings.DEBUG = False
  186. f.on_worker_ready()
  187. with self.assertWarnsRegex(UserWarning, r'leads to a memory leak'):
  188. f._settings.DEBUG = True
  189. f.on_worker_ready()
  190. def test_mysql_errors(self):
  191. with patch_modules('MySQLdb'):
  192. import MySQLdb as mod
  193. mod.DatabaseError = Mock()
  194. mod.InterfaceError = Mock()
  195. mod.OperationalError = Mock()
  196. with self.fixup_context(self.app) as (f, _, _):
  197. self.assertIn(mod.DatabaseError, f.database_errors)
  198. self.assertIn(mod.InterfaceError, f.database_errors)
  199. self.assertIn(mod.OperationalError, f.database_errors)
  200. with mask_modules('MySQLdb'):
  201. with self.fixup_context(self.app):
  202. pass
  203. def test_pg_errors(self):
  204. with patch_modules('psycopg2'):
  205. import psycopg2 as mod
  206. mod.DatabaseError = Mock()
  207. mod.InterfaceError = Mock()
  208. mod.OperationalError = Mock()
  209. with self.fixup_context(self.app) as (f, _, _):
  210. self.assertIn(mod.DatabaseError, f.database_errors)
  211. self.assertIn(mod.InterfaceError, f.database_errors)
  212. self.assertIn(mod.OperationalError, f.database_errors)
  213. with mask_modules('psycopg2'):
  214. with self.fixup_context(self.app):
  215. pass
  216. def test_sqlite_errors(self):
  217. with patch_modules('sqlite3'):
  218. import sqlite3 as mod
  219. mod.DatabaseError = Mock()
  220. mod.InterfaceError = Mock()
  221. mod.OperationalError = Mock()
  222. with self.fixup_context(self.app) as (f, _, _):
  223. self.assertIn(mod.DatabaseError, f.database_errors)
  224. self.assertIn(mod.InterfaceError, f.database_errors)
  225. self.assertIn(mod.OperationalError, f.database_errors)
  226. with mask_modules('sqlite3'):
  227. with self.fixup_context(self.app):
  228. pass
  229. def test_oracle_errors(self):
  230. with patch_modules('cx_Oracle'):
  231. import cx_Oracle as mod
  232. mod.DatabaseError = Mock()
  233. mod.InterfaceError = Mock()
  234. mod.OperationalError = Mock()
  235. with self.fixup_context(self.app) as (f, _, _):
  236. self.assertIn(mod.DatabaseError, f.database_errors)
  237. self.assertIn(mod.InterfaceError, f.database_errors)
  238. self.assertIn(mod.OperationalError, f.database_errors)
  239. with mask_modules('cx_Oracle'):
  240. with self.fixup_context(self.app):
  241. pass