compression_support.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # Copyright 2018 MongoDB, Inc.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import warnings
  15. try:
  16. import snappy
  17. _HAVE_SNAPPY = True
  18. except ImportError:
  19. # python-snappy isn't available.
  20. _HAVE_SNAPPY = False
  21. try:
  22. import zlib
  23. _HAVE_ZLIB = True
  24. except ImportError:
  25. # Python built without zlib support.
  26. _HAVE_ZLIB = False
  27. try:
  28. from zstandard import ZstdCompressor, ZstdDecompressor
  29. _HAVE_ZSTD = True
  30. except ImportError:
  31. _HAVE_ZSTD = False
  32. from pymongo.hello_compat import HelloCompat
  33. from pymongo.monitoring import _SENSITIVE_COMMANDS
  # Compressor names that validate_compressors() will accept.
  _SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"}

  # Command names gathered from the hello commands plus the sensitive
  # commands; presumably these are the commands sent without compression
  # (confirm against the code that consumes _NO_COMPRESSION).
  _NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD}
  _NO_COMPRESSION.update(_SENSITIVE_COMMANDS)
  def validate_compressors(dummy, value):
      """Reduce a requested compressor list to the ones that can be used.

      :Parameters:
        - `dummy`: Ignored by this function.
        - `value`: Either a comma-separated string of compressor names or
          an iterable of compressor names.

      Returns a new list holding, in their original order, the names that
      are both supported and backed by an importable module. Every name
      that is dropped triggers a warning through ``warnings.warn``.
      """
      try:
          # A string is split on commas.
          requested = value.split(",")
      except AttributeError:
          # Anything without ``split`` is treated as an iterable of names.
          requested = list(value)

      usable = []
      for name in requested:
          if name not in _SUPPORTED_COMPRESSORS:
              warnings.warn("Unsupported compressor: %s" % (name,))
              continue
          if name == "snappy" and not _HAVE_SNAPPY:
              warnings.warn(
                  "Wire protocol compression with snappy is not available. "
                  "You must install the python-snappy module for snappy support.")
              continue
          if name == "zlib" and not _HAVE_ZLIB:
              warnings.warn(
                  "Wire protocol compression with zlib is not available. "
                  "The zlib module is not available.")
              continue
          if name == "zstd" and not _HAVE_ZSTD:
              warnings.warn(
                  "Wire protocol compression with zstandard is not available. "
                  "You must install the zstandard module for zstandard support.")
              continue
          usable.append(name)
      return usable
  def validate_zlib_compression_level(option, value):
      """Validate and return a zlib compression level.

      :Parameters:
        - `option`: Name of the option being validated; used only in
          error messages.
        - `value`: The requested level; anything ``int()`` accepts.

      Returns the level as an ``int``.

      Raises :class:`TypeError` if `value` cannot be converted to an
      integer and :class:`ValueError` if the integer is outside the
      range -1 to 9 inclusive.
      """
      try:
          level = int(value)
      except Exception:
          # Catch Exception rather than using a bare ``except`` so that
          # KeyboardInterrupt and SystemExit raised during the conversion
          # propagate instead of being reported as a TypeError.
          raise TypeError("%s must be an integer, not %r." % (option, value))
      if level < -1 or level > 9:
          raise ValueError(
              "%s must be between -1 and 9, not %d." % (option, level))
      return level
  class CompressionSettings(object):
      """Holds the configured compressors and the zlib compression level."""

      def __init__(self, compressors, zlib_compression_level):
          """Store both settings unchanged on the instance."""
          self.compressors = compressors
          self.zlib_compression_level = zlib_compression_level

      def get_compression_context(self, compressors):
          """Return a compression context for the first listed compressor.

          :Parameters:
            - `compressors`: A sequence of compressor names; only the
              first entry is examined.

          Returns a SnappyContext, ZlibContext (built with this
          instance's zlib level) or ZstdContext, or None when
          `compressors` is empty/None or its first entry is not one of
          "snappy", "zlib" or "zstd".
          """
          if not compressors:
              return None

          preferred = compressors[0]
          if preferred == "snappy":
              return SnappyContext()
          if preferred == "zlib":
              return ZlibContext(self.zlib_compression_level)
          if preferred == "zstd":
              return ZstdContext()
          return None
  86. def _zlib_no_compress(data):
  87. """Compress data with zlib level 0."""
  88. cobj = zlib.compressobj(0)
  89. return b"".join([cobj.compress(data), cobj.flush()])
  class SnappyContext(object):
      """Compression context backed by the python-snappy module."""

      # Identifier that decompress() matches against for snappy payloads.
      compressor_id = 1

      @staticmethod
      def compress(data):
          """Return `data` compressed with ``snappy.compress``."""
          compressed = snappy.compress(data)
          return compressed
  class ZlibContext(object):
      """Compression context backed by the zlib module.

      The constructor binds ``self.compress`` to a callable that takes
      the data to compress and returns the compressed bytes.
      """

      # Identifier that decompress() matches against for zlib payloads.
      compressor_id = 2

      def __init__(self, level):
          """Pick the compress callable for the given zlib `level`."""
          if level == -1:
              # Jython zlib.compress doesn't support -1, so call it
              # without an explicit level.
              compress = zlib.compress
          elif level == 0:
              # Jython zlib.compress also doesn't support 0.
              compress = _zlib_no_compress
          else:
              def compress(data):
                  return zlib.compress(data, level)
          self.compress = compress
  class ZstdContext(object):
      """Compression context backed by the zstandard module."""

      # Identifier that decompress() matches against for zstd payloads.
      compressor_id = 3

      @staticmethod
      def compress(data):
          """Return `data` compressed by a freshly created ZstdCompressor."""
          # ZstdCompressor is not thread safe, so a new one is built for
          # every call.
          # TODO: Use a pool?
          compressor = ZstdCompressor()
          return compressor.compress(data)
  def decompress(data, compressor_id):
      """Decompress `data` with the codec identified by `compressor_id`.

      :Parameters:
        - `data`: The compressed payload.
        - `compressor_id`: Compared against the ``compressor_id`` of
          SnappyContext, ZlibContext and ZstdContext to pick the codec.

      Returns the decompressed data.

      Raises :class:`ValueError` if `compressor_id` matches none of the
      three contexts.
      """
      if compressor_id == SnappyContext.compressor_id:
          # python-snappy doesn't support the buffer interface.
          # https://github.com/andrix/python-snappy/issues/65
          # The bytes() copy only matters when data is a memoryview since
          # id(bytes(data)) == id(data) when data is a bytes.
          # NOTE: bytes(memoryview) returns the memoryview repr
          # in Python 2.7. The right thing to do in 2.7 is call
          # memoryview.tobytes(), but we currently only use
          # memoryview in Python 3.x.
          raw = bytes(data)
          return snappy.uncompress(raw)

      if compressor_id == ZlibContext.compressor_id:
          return zlib.decompress(data)

      if compressor_id == ZstdContext.compressor_id:
          # ZstdDecompressor is not thread safe, so a new one is built
          # for every call.
          # TODO: Use a pool?
          decompressor = ZstdDecompressor()
          return decompressor.decompress(data)

      raise ValueError("Unknown compressorId %d" % (compressor_id,))