struct.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. from __future__ import absolute_import
  2. from io import BytesIO
  3. from .abstract import AbstractType
  4. from .types import Schema
  5. from ..util import WeakMethod
  6. class Struct(AbstractType):
  7. SCHEMA = Schema()
  8. def __init__(self, *args, **kwargs):
  9. if len(args) == len(self.SCHEMA.fields):
  10. for i, name in enumerate(self.SCHEMA.names):
  11. self.__dict__[name] = args[i]
  12. elif len(args) > 0:
  13. raise ValueError('Args must be empty or mirror schema')
  14. else:
  15. for name in self.SCHEMA.names:
  16. self.__dict__[name] = kwargs.pop(name, None)
  17. if kwargs:
  18. raise ValueError('Keyword(s) not in schema %s: %s'
  19. % (list(self.SCHEMA.names),
  20. ', '.join(kwargs.keys())))
  21. # overloading encode() to support both class and instance
  22. # Without WeakMethod() this creates circular ref, which
  23. # causes instances to "leak" to garbage
  24. self.encode = WeakMethod(self._encode_self)
  25. @classmethod
  26. def encode(cls, item): # pylint: disable=E0202
  27. bits = []
  28. for i, field in enumerate(cls.SCHEMA.fields):
  29. bits.append(field.encode(item[i]))
  30. return b''.join(bits)
  31. def _encode_self(self):
  32. return self.SCHEMA.encode(
  33. [self.__dict__[name] for name in self.SCHEMA.names]
  34. )
  35. @classmethod
  36. def decode(cls, data):
  37. if isinstance(data, bytes):
  38. data = BytesIO(data)
  39. return cls(*[field.decode(data) for field in cls.SCHEMA.fields])
  40. def __repr__(self):
  41. key_vals = []
  42. for name, field in zip(self.SCHEMA.names, self.SCHEMA.fields):
  43. key_vals.append('%s=%s' % (name, field.repr(self.__dict__[name])))
  44. return self.__class__.__name__ + '(' + ', '.join(key_vals) + ')'
  45. def __hash__(self):
  46. return hash(self.encode())
  47. def __eq__(self, other):
  48. if self.SCHEMA != other.SCHEMA:
  49. return False
  50. for attr in self.SCHEMA.names:
  51. if self.__dict__[attr] != other.__dict__[attr]:
  52. return False
  53. return True
  54. """
  55. class MetaStruct(type):
  56. def __new__(cls, clsname, bases, dct):
  57. nt = namedtuple(clsname, [name for (name, _) in dct['SCHEMA']])
  58. bases = tuple([Struct, nt] + list(bases))
  59. return super(MetaStruct, cls).__new__(cls, clsname, bases, dct)
  60. """