base.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import datetime
  2. import logging
  3. import sys
  4. import time
  5. import unittest
  6. from huey import RedisHuey
  7. from huey.api import Huey
  8. from huey.consumer import Consumer
  9. from huey.registry import registry
  10. from huey.storage import BaseStorage
  11. from huey.storage import RedisStorage
  12. def b(s):
  13. if sys.version_info[0] == 3:
  14. return s.encode('utf-8')
  15. return s
  16. class DummyHuey(Huey):
  17. def get_storage(self, **kwargs):
  18. return BaseStorage()
  19. class BrokenRedisStorage(RedisStorage):
  20. def dequeue(self):
  21. raise ValueError('broken redis dequeue')
  22. broken_redis_storage = BrokenRedisStorage()
  23. class BrokenHuey(Huey):
  24. def get_storage(self):
  25. return broken_redis_storage
# Module-level Huey instances shared by the tests.
dummy_huey = DummyHuey()
# Redis-backed Huey named 'testing'; HueyTestCase.get_huey() returns it.
# NOTE(review): tests using it presumably need a reachable Redis server — confirm.
test_huey = RedisHuey('testing', blocking=False, read_timeout=0.1)
# Logger used by the consumer.
logger = logging.getLogger('huey.consumer')
# Keep at least one handler attached at all times. NOTE(review): presumably
# this suppresses Python 2's "No handlers could be found" warning — confirm.
logger.addHandler(logging.NullHandler())
  31. # Create a log handler that will track messages generated by the consumer.
  32. class CaptureLogs(logging.Handler):
  33. def __init__(self, *args, **kwargs):
  34. self.messages = []
  35. logging.Handler.__init__(self, *args, **kwargs)
  36. def emit(self, record):
  37. self.messages.append(record.getMessage())
  38. def __enter__(self):
  39. logger.addHandler(self)
  40. logger.setLevel(logging.INFO)
  41. return self
  42. def __exit__(self, exc_type, exc_val, exc_tb):
  43. logger.removeHandler(self)
  44. class BaseTestCase(unittest.TestCase):
  45. pass
  46. class HueyTestCase(BaseTestCase):
  47. def setUp(self):
  48. self.huey = self.get_huey()
  49. self.consumer = self.get_consumer(workers=2, scheduler_interval=10)
  50. self.events = iter(self.huey.storage)
  51. self._periodic_tasks = registry._periodic_tasks
  52. registry._periodic_tasks = self.get_periodic_tasks()
  53. self._sleep = time.sleep
  54. time.sleep = lambda x: None
  55. def tearDown(self):
  56. if self.consumer is not None:
  57. self.consumer.stop()
  58. self.huey.flush()
  59. registry._periodic_tasks = self._periodic_tasks
  60. time.sleep = self._sleep
  61. def get_huey(self):
  62. return test_huey
  63. def get_consumer(self, **kwargs):
  64. return Consumer(self.huey, **kwargs)
  65. def get_periodic_tasks(self):
  66. return []
  67. def assertTaskEvents(self, *states):
  68. for (status, task) in states:
  69. event_data = next(self.events)
  70. self.assertEqual(event_data['status'], status)
  71. self.assertEqual(event_data['id'], task.task_id)
  72. def assertLogs(self, capture, expected):
  73. self.assertEqual(len(capture.messages), len(expected))
  74. for (log, msg) in zip(capture.messages, expected):
  75. self.assertTrue(log.startswith(msg))
  76. def worker(self, task, ts=None):
  77. worker = self.consumer._create_worker()
  78. ts = ts or datetime.datetime.utcnow()
  79. worker.handle_task(task, ts)
  80. return worker
  81. def scheduler(self, ts=None, periodic=False):
  82. scheduler = self.consumer._create_scheduler()
  83. ts = ts or datetime.datetime.utcnow()
  84. if periodic:
  85. scheduler._counter = scheduler._q
  86. scheduler.loop(ts)
  87. return scheduler