import datetime
import logging
import sys
import time
import unittest

from huey import RedisHuey
from huey.api import Huey
from huey.consumer import Consumer
from huey.registry import registry
from huey.storage import BaseStorage
from huey.storage import RedisStorage


def b(s):
    """Return *s* encoded as UTF-8 bytes on Python 3; return it unchanged on Python 2."""
    if sys.version_info[0] == 3:
        return s.encode('utf-8')
    return s


class DummyHuey(Huey):
    """Huey subclass whose storage is a bare ``BaseStorage`` instance."""

    def get_storage(self, **kwargs):
        # Any storage keyword arguments are accepted and ignored.
        return BaseStorage()


class BrokenRedisStorage(RedisStorage):
    """Redis storage whose ``dequeue`` always raises ``ValueError``."""

    def dequeue(self):
        raise ValueError('broken redis dequeue')


broken_redis_storage = BrokenRedisStorage()


class BrokenHuey(Huey):
    """Huey subclass that always uses the shared ``broken_redis_storage``."""

    def get_storage(self, **kwargs):
        # Accept and ignore keyword arguments, like DummyHuey.get_storage, so
        # both test Huey subclasses expose the same get_storage signature.
        return broken_redis_storage


dummy_huey = DummyHuey()
test_huey = RedisHuey('testing', blocking=False, read_timeout=0.1)

# Logger used by the consumer.
logger = logging.getLogger('huey.consumer')
logger.addHandler(logging.NullHandler())


# Create a log handler that will track messages generated by the consumer.
class CaptureLogs(logging.Handler):
    """Log handler and context manager that records consumer log messages.

    On entry the handler is attached to the ``huey.consumer`` logger and the
    logger level is set to INFO.  Each record's ``getMessage()`` text is
    appended to ``self.messages``.  On exit the handler is detached and the
    logger's previous level is restored.
    """

    def __init__(self, *args, **kwargs):
        self.messages = []
        logging.Handler.__init__(self, *args, **kwargs)

    def emit(self, record):
        self.messages.append(record.getMessage())

    def __enter__(self):
        # Remember the current level so __exit__ can restore it; otherwise
        # each capture would leave the shared logger permanently at INFO.
        self._previous_level = logger.level
        logger.addHandler(self)
        logger.setLevel(logging.INFO)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        logger.removeHandler(self)
        logger.setLevel(self._previous_level)


class BaseTestCase(unittest.TestCase):
    pass


class HueyTestCase(BaseTestCase):
    """Base test case wiring up a Huey instance, a consumer and patched globals.

    ``setUp`` replaces ``registry._periodic_tasks`` with
    ``get_periodic_tasks()`` and replaces ``time.sleep`` with a no-op.
    ``tearDown`` always puts both back.
    """

    def setUp(self):
        self.huey = self.get_huey()
        self.consumer = self.get_consumer(workers=2, scheduler_interval=10)
        # Iterator over the storage's event stream, consumed by
        # assertTaskEvents.
        self.events = iter(self.huey.storage)

        self._periodic_tasks = registry._periodic_tasks
        registry._periodic_tasks = self.get_periodic_tasks()

        # Make sleeps instantaneous during tests.
        self._sleep = time.sleep
        time.sleep = lambda x: None

    def tearDown(self):
        # The globals patched in setUp must be restored even if stopping the
        # consumer or flushing storage raises; otherwise later tests would
        # inherit a no-op time.sleep and a fake periodic-task registry.
        try:
            if self.consumer is not None:
                self.consumer.stop()
        finally:
            try:
                self.huey.flush()
            finally:
                registry._periodic_tasks = self._periodic_tasks
                time.sleep = self._sleep

    def get_huey(self):
        """Return the Huey instance under test (override in subclasses)."""
        return test_huey

    def get_consumer(self, **kwargs):
        """Build a ``Consumer`` for ``self.huey`` with the given options."""
        return Consumer(self.huey, **kwargs)

    def get_periodic_tasks(self):
        """Return the periodic tasks to install in the registry for the test."""
        return []

    def assertTaskEvents(self, *states):
        """Assert the next events match ``(status, task)`` pairs, in order."""
        for (status, task) in states:
            event_data = next(self.events)
            self.assertEqual(event_data['status'], status)
            self.assertEqual(event_data['id'], task.task_id)

    # NOTE: this name shadows unittest.TestCase.assertLogs (Python 3.4+) with
    # a different signature; kept as-is because existing tests call it.
    def assertLogs(self, capture, expected):
        """Assert captured messages pairwise start with the *expected* prefixes.

        :param capture: a ``CaptureLogs`` instance.
        :param expected: sequence of message prefixes, one per captured message.
        """
        self.assertEqual(len(capture.messages), len(expected))
        for (log, msg) in zip(capture.messages, expected):
            self.assertTrue(
                log.startswith(msg),
                '%r does not start with %r' % (log, msg))

    def worker(self, task, ts=None):
        """Create a consumer worker, have it handle *task* at *ts*, return it.

        *ts* defaults to the current naive UTC time.
        """
        worker = self.consumer._create_worker()
        ts = ts or datetime.datetime.utcnow()
        worker.handle_task(task, ts)
        return worker

    def scheduler(self, ts=None, periodic=False):
        """Create a consumer scheduler, run one ``loop(ts)``, and return it.

        When *periodic* is true, ``_counter`` is set equal to ``_q`` before
        looping.  NOTE(review): presumably this makes the periodic-task
        check fire on this iteration; confirm against the scheduler
        implementation.
        """
        scheduler = self.consumer._create_scheduler()
        ts = ts or datetime.datetime.utcnow()
        if periodic:
            scheduler._counter = scheduler._q
        scheduler.loop(ts)
        return scheduler