huffman.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. # -*- coding: utf-8 -*-
  2. """
  3. hpack/huffman_decoder
  4. ~~~~~~~~~~~~~~~~~~~~~
  5. An implementation of a bitwise prefix tree specially built for decoding
  6. Huffman-coded content where we already know the Huffman table.
  7. """
  8. from .compat import to_byte, decode_hex
  9. from .exceptions import HPACKDecodingError
  10. def _pad_binary(bin_str, req_len=8):
  11. """
  12. Given a binary string (returned by bin()), pad it to a full byte length.
  13. """
  14. bin_str = bin_str[2:] # Strip the 0b prefix
  15. return max(0, req_len - len(bin_str)) * '0' + bin_str
  16. def _hex_to_bin_str(hex_string):
  17. """
  18. Given a Python bytestring, returns a string representing those bytes in
  19. unicode form.
  20. """
  21. unpadded_bin_string_list = (bin(to_byte(c)) for c in hex_string)
  22. padded_bin_string_list = map(_pad_binary, unpadded_bin_string_list)
  23. bitwise_message = "".join(padded_bin_string_list)
  24. return bitwise_message
  25. class HuffmanDecoder(object):
  26. """
  27. Decodes a Huffman-coded bytestream according to the Huffman table laid out
  28. in the HPACK specification.
  29. """
  30. class _Node(object):
  31. def __init__(self, data):
  32. self.data = data
  33. self.mapping = {}
  34. def __init__(self, huffman_code_list, huffman_code_list_lengths):
  35. self.root = self._Node(None)
  36. for index, (huffman_code, code_length) in enumerate(zip(huffman_code_list, huffman_code_list_lengths)):
  37. self._insert(huffman_code, code_length, index)
  38. def _insert(self, hex_number, hex_length, letter):
  39. """
  40. Inserts a Huffman code point into the tree.
  41. """
  42. hex_number = _pad_binary(bin(hex_number), hex_length)
  43. cur_node = self.root
  44. for digit in hex_number:
  45. if digit not in cur_node.mapping:
  46. cur_node.mapping[digit] = self._Node(None)
  47. cur_node = cur_node.mapping[digit]
  48. cur_node.data = letter
  49. def decode(self, encoded_string):
  50. """
  51. Decode the given Huffman coded string.
  52. """
  53. number = _hex_to_bin_str(encoded_string)
  54. cur_node = self.root
  55. decoded_message = bytearray()
  56. try:
  57. for digit in number:
  58. cur_node = cur_node.mapping[digit]
  59. if cur_node.data is not None:
  60. # If we get EOS, everything else is padding.
  61. if cur_node.data == 256:
  62. break
  63. decoded_message.append(cur_node.data)
  64. cur_node = self.root
  65. except KeyError:
  66. # We have a Huffman-coded string that doesn't match our trie. This
  67. # is pretty bad: raise a useful exception.
  68. raise HPACKDecodingError("Invalid Huffman-coded string received.")
  69. return bytes(decoded_message)
  70. class HuffmanEncoder(object):
  71. """
  72. Encodes a string according to the Huffman encoding table defined in the
  73. HPACK specification.
  74. """
  75. def __init__(self, huffman_code_list, huffman_code_list_lengths):
  76. self.huffman_code_list = huffman_code_list
  77. self.huffman_code_list_lengths = huffman_code_list_lengths
  78. def encode(self, bytes_to_encode):
  79. """
  80. Given a string of bytes, encodes them according to the HPACK Huffman
  81. specification.
  82. """
  83. # If handed the empty string, just immediately return.
  84. if not bytes_to_encode:
  85. return b''
  86. final_num = 0
  87. final_int_len = 0
  88. # Turn each byte into its huffman code. These codes aren't necessarily
  89. # octet aligned, so keep track of how far through an octet we are. To
  90. # handle this cleanly, just use a single giant integer.
  91. for char in bytes_to_encode:
  92. byte = to_byte(char)
  93. bin_int_len = self.huffman_code_list_lengths[byte]
  94. bin_int = self.huffman_code_list[byte] & (2 ** (bin_int_len + 1) - 1)
  95. final_num <<= bin_int_len
  96. final_num |= bin_int
  97. final_int_len += bin_int_len
  98. # Pad out to an octet with ones.
  99. bits_to_be_padded = (8 - (final_int_len % 8)) % 8
  100. final_num <<= bits_to_be_padded
  101. final_num |= (1 << (bits_to_be_padded)) - 1
  102. # Convert the number to hex and strip off the leading '0x' and the
  103. # trailing 'L', if present.
  104. final_num = hex(final_num)[2:].rstrip('L')
  105. # If this is odd, prepend a zero.
  106. final_num = '0' + final_num if len(final_num) % 2 != 0 else final_num
  107. # This number should have twice as many digits as bytes. If not, we're
  108. # missing some leading zeroes. Work out how many bytes we want and how
  109. # many digits we have, then add the missing zero digits to the front.
  110. total_bytes = (final_int_len + bits_to_be_padded) // 8
  111. expected_digits = total_bytes * 2
  112. if len(final_num) != expected_digits:
  113. missing_digits = expected_digits - len(final_num)
  114. final_num = ('0' * missing_digits) + final_num
  115. return decode_hex(final_num)