roundrobin.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. from __future__ import absolute_import
  2. from .base import Partitioner
  3. class RoundRobinPartitioner(Partitioner):
  4. def __init__(self, partitions=None):
  5. self.partitions_iterable = CachedPartitionCycler(partitions)
  6. if partitions:
  7. self._set_partitions(partitions)
  8. else:
  9. self.partitions = None
  10. def __call__(self, key, all_partitions=None, available_partitions=None):
  11. if available_partitions:
  12. cur_partitions = available_partitions
  13. else:
  14. cur_partitions = all_partitions
  15. if not self.partitions:
  16. self._set_partitions(cur_partitions)
  17. elif cur_partitions != self.partitions_iterable.partitions and cur_partitions is not None:
  18. self._set_partitions(cur_partitions)
  19. return next(self.partitions_iterable)
  20. def _set_partitions(self, available_partitions):
  21. self.partitions = available_partitions
  22. self.partitions_iterable.set_partitions(available_partitions)
  23. def partition(self, key, all_partitions=None, available_partitions=None):
  24. return self.__call__(key, all_partitions, available_partitions)
  25. class CachedPartitionCycler(object):
  26. def __init__(self, partitions=None):
  27. self.partitions = partitions
  28. if partitions:
  29. assert type(partitions) is list
  30. self.cur_pos = None
  31. def __next__(self):
  32. return self.next()
  33. @staticmethod
  34. def _index_available(cur_pos, partitions):
  35. return cur_pos < len(partitions)
  36. def set_partitions(self, partitions):
  37. if self.cur_pos:
  38. if not self._index_available(self.cur_pos, partitions):
  39. self.cur_pos = 0
  40. self.partitions = partitions
  41. return None
  42. self.partitions = partitions
  43. next_item = self.partitions[self.cur_pos]
  44. if next_item in partitions:
  45. self.cur_pos = partitions.index(next_item)
  46. else:
  47. self.cur_pos = 0
  48. return None
  49. self.partitions = partitions
  50. def next(self):
  51. assert self.partitions is not None
  52. if self.cur_pos is None or not self._index_available(self.cur_pos, self.partitions):
  53. self.cur_pos = 1
  54. return self.partitions[0]
  55. cur_item = self.partitions[self.cur_pos]
  56. self.cur_pos += 1
  57. return cur_item