123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
- import datetime
- import logging
- import sys
- import time
- import unittest
- from huey import RedisHuey
- from huey.api import Huey
- from huey.consumer import Consumer
- from huey.registry import registry
- from huey.storage import BaseStorage
- from huey.storage import RedisStorage
def b(s):
    """Return ``s`` as a byte string.

    On Python 3 and later, a text string is UTF-8 encoded and a value that is
    already ``bytes`` is returned unchanged (encoding it again would fail).
    On Python 2 the value is returned as-is, as before.
    """
    if isinstance(s, bytes):
        return s
    # ``>= 3`` rather than ``== 3`` so the helper keeps working on any later
    # major version.
    if sys.version_info[0] >= 3:
        return s.encode('utf-8')
    return s
class DummyHuey(Huey):
    """Huey subclass backed by a plain ``BaseStorage`` instance.

    Any storage keyword arguments handed to ``get_storage`` are ignored.
    """

    def get_storage(self, **kwargs):
        storage = BaseStorage()
        return storage
class BrokenRedisStorage(RedisStorage):
    """RedisStorage variant whose ``dequeue`` always raises ``ValueError``.

    Useful for driving the consumer's error-handling paths in tests.
    """

    def dequeue(self):
        error_message = 'broken redis dequeue'
        raise ValueError(error_message)


# Single shared instance, returned by BrokenHuey.get_storage.
broken_redis_storage = BrokenRedisStorage()
class BrokenHuey(Huey):
    """Huey subclass whose storage is the shared ``broken_redis_storage``.

    That storage's ``dequeue`` always raises ``ValueError``.
    """

    def get_storage(self, **kwargs):
        # Accept (and ignore) storage keyword arguments, matching
        # DummyHuey.get_storage, so that constructing BrokenHuey with storage
        # options does not fail with a TypeError.
        return broken_redis_storage
# Module-level Huey instances shared by the tests: one on a bare BaseStorage,
# one on Redis (non-blocking, with a short read timeout).
dummy_huey = DummyHuey()
test_huey = RedisHuey('testing', read_timeout=0.1, blocking=False)

# Logger used by the consumer. A NullHandler is attached so the logger always
# has at least one handler, presumably to keep logging from warning about
# missing handlers when no test has attached its own.
logger = logging.getLogger('huey.consumer')
logger.addHandler(logging.NullHandler())
# Create a log handler that will track messages generated by the consumer.
class CaptureLogs(logging.Handler):
    """Logging handler that records consumer log messages while active.

    Use it as a context manager. On entry it attaches itself to the
    module-level ``logger`` and sets that logger's level to INFO. On exit it
    detaches and restores whatever level the logger had before entry, so the
    level change does not leak into later tests.
    """

    def __init__(self, *args, **kwargs):
        # Formatted message strings, in the order they were emitted.
        self.messages = []
        # Logger level saved by __enter__ so __exit__ can restore it.
        self._previous_level = None
        logging.Handler.__init__(self, *args, **kwargs)

    def emit(self, record):
        self.messages.append(record.getMessage())

    def __enter__(self):
        logger.addHandler(self)
        self._previous_level = logger.level
        logger.setLevel(logging.INFO)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        logger.removeHandler(self)
        if self._previous_level is not None:
            logger.setLevel(self._previous_level)
            self._previous_level = None
        # Returning None means exceptions raised in the body still propagate.
class BaseTestCase(unittest.TestCase):
    """Common root for the test cases in this module.

    Adds nothing to ``unittest.TestCase`` yet; presumably kept as a shared
    hook for suite-wide test behaviour.
    """
class HueyTestCase(BaseTestCase):
    """Base class for tests that drive a Huey consumer in-process.

    ``setUp`` builds a Huey instance and a two-worker consumer, swaps the
    global periodic-task registry for ``get_periodic_tasks()`` and replaces
    ``time.sleep`` with a no-op. ``tearDown`` undoes all of that, even if
    shutting the consumer down fails. Subclasses customise the fixture by
    overriding ``get_huey``, ``get_consumer`` and ``get_periodic_tasks``.
    """

    def setUp(self):
        self.huey = self.get_huey()
        self.consumer = self.get_consumer(workers=2, scheduler_interval=10)
        # Iterator over ``self.huey.storage``; assertTaskEvents reads
        # successive items from it as event dicts.
        self.events = iter(self.huey.storage)
        # Swap in this test's periodic tasks, remembering the originals.
        self._periodic_tasks = registry._periodic_tasks
        registry._periodic_tasks = self.get_periodic_tasks()
        # Turn ``time.sleep`` into a no-op for code that calls it through the
        # ``time`` module, so the tests do not actually wait.
        self._sleep = time.sleep
        time.sleep = lambda x: None

    def tearDown(self):
        # Nested try/finally: a failure while stopping the consumer or
        # flushing storage must not leave the global monkeypatches in place
        # for the rest of the suite.
        try:
            if self.consumer is not None:
                self.consumer.stop()
        finally:
            try:
                self.huey.flush()
            finally:
                registry._periodic_tasks = self._periodic_tasks
                time.sleep = self._sleep

    def get_huey(self):
        """Return the Huey instance under test (module-level ``test_huey``)."""
        return test_huey

    def get_consumer(self, **kwargs):
        """Build a ``Consumer`` for ``self.huey`` with the given options."""
        return Consumer(self.huey, **kwargs)

    def get_periodic_tasks(self):
        """Return the periodic tasks to install in the registry (none)."""
        return []

    def assertTaskEvents(self, *states):
        """Assert the next storage events match ``states`` in order.

        Each element of ``states`` is a ``(status, task)`` pair. The matching
        event must have that ``status`` and ``task.task_id`` as its ``id``.
        """
        missing = object()
        for status, task in states:
            event_data = next(self.events, missing)
            if event_data is missing:
                # Report exhaustion as a test failure instead of letting a
                # bare StopIteration escape from the test method.
                self.fail('expected %r event for task %s, but no more events '
                          'were available' % (status, task.task_id))
            self.assertEqual(event_data['status'], status)
            self.assertEqual(event_data['id'], task.task_id)

    def assertLogs(self, capture, expected):
        """Assert ``capture.messages`` start with ``expected``, pairwise.

        NOTE: this shadows ``unittest.TestCase.assertLogs`` (Python 3.4+)
        with a different signature.
        """
        self.assertEqual(len(capture.messages), len(expected))
        for log, msg in zip(capture.messages, expected):
            self.assertTrue(log.startswith(msg),
                            '%r does not start with %r' % (log, msg))

    def worker(self, task, ts=None):
        """Run ``task`` through a freshly created consumer worker.

        ``ts`` defaults to ``datetime.datetime.utcnow()``. Returns the worker.
        """
        worker = self.consumer._create_worker()
        if ts is None:
            ts = datetime.datetime.utcnow()
        worker.handle_task(task, ts)
        return worker

    def scheduler(self, ts=None, periodic=False):
        """Run one scheduler loop at ``ts`` and return the scheduler.

        ``ts`` defaults to ``datetime.datetime.utcnow()``.
        """
        scheduler = self.consumer._create_scheduler()
        if ts is None:
            ts = datetime.datetime.utcnow()
        if periodic:
            # NOTE(review): presumably primes the scheduler's counter so this
            # loop iteration also runs periodic tasks; confirm against the
            # consumer's scheduler implementation.
            scheduler._counter = scheduler._q
        scheduler.loop(ts)
        return scheduler
|