# -*- coding: utf-8 -*-

"""
oss2.resumable
~~~~~~~~~~~~~~

This module contains functions and classes for resumable upload and download.
"""

import os

from . import utils
from .utils import b64encode_as_string, b64decode_from_string
from . import iterators
from . import exceptions
from . import defaults
from . import http
from . import models

from .crypto_bucket import CryptoBucket
from . import Bucket
from .iterators import PartIterator
from .models import PartInfo
from .compat import json, stringify, to_unicode, to_string
from .task_queue import TaskQueue
from .headers import *

import functools
import threading
import random
import string

import logging

logger = logging.getLogger(__name__)


def resumable_upload(bucket, key, filename,
                     store=None,
                     headers=None,
                     multipart_threshold=None,
                     part_size=None,
                     progress_callback=None,
                     num_threads=None,
                     params=None):
    """Upload a local file, resuming from a checkpoint if one exists.

    The implementation uses multipart upload. The default concurrency is
    `oss2.defaults.multipart_num_threads`, and information about the parts already uploaded is
    persisted on the local disk. If the upload is interrupted for any reason, re-uploading the
    same file (i.e. the same source file and the same destination path) only uploads the missing
    parts.

    By default, the checkpoint information is saved under the user's `HOME` directory. As long as
    the local file has not changed and the destination key is the same, the upload resumes from
    the recorded checkpoint.

    Notes on using this function:

    #. With a CryptoBucket, the function falls back to a plain upload.

    :param bucket: a :class:`Bucket <oss2.Bucket>` or :class:`CryptoBucket <oss2.CryptoBucket>` object
    :param key: object key in the user's bucket
    :param filename: name of the local file to upload

    :param store: persistent store for checkpoint information; see the :class:`ResumableStore`
        interface. Defaults to `ResumableStore` if not specified.

    :param headers: HTTP headers
        # The full headers are passed through to put_object and init_multipart_upload.
        # upload_part currently only receives OSS_REQUEST_PAYER and OSS_TRAFFIC_LIMIT.
        # complete_multipart_upload currently only receives OSS_REQUEST_PAYER and OSS_OBJECT_ACL.
    :type headers: dict, preferably oss2.CaseInsensitiveDict

    :param multipart_threshold: use multipart upload when the file is longer than this value.
    :param part_size: size of each part of the multipart upload; computed automatically if not specified.
    :param progress_callback: upload progress callback; see :ref:`progress_callback`.
    :param num_threads: number of concurrent upload threads; defaults to `oss2.defaults.multipart_num_threads`.

    :param params: HTTP request parameters
        # Only the 'sequential' parameter is forwarded to init_multipart_upload;
        # all other parameters are treated as invalid and are not forwarded.
    :type params: dict
    """
    logger.debug("Start to resumable upload, bucket: {0}, key: {1}, filename: {2}, headers: {3}, "
                 "multipart_threshold: {4}, part_size: {5}, num_threads: {6}".format(bucket.bucket_name, to_string(key),
                                                                                     filename, headers, multipart_threshold,
                                                                                     part_size, num_threads))
    size = os.path.getsize(filename)
    multipart_threshold = defaults.get(multipart_threshold, defaults.multipart_threshold)

    logger.debug("The size of file to upload is: {0}, multipart_threshold: {1}".format(size, multipart_threshold))
    if size >= multipart_threshold:
        uploader = _ResumableUploader(bucket, key, filename, size, store,
                                      part_size=part_size,
                                      headers=headers,
                                      progress_callback=progress_callback,
                                      num_threads=num_threads,
                                      params=params)
        result = uploader.upload()
    else:
        with open(to_unicode(filename), 'rb') as f:
            result = bucket.put_object(key, f, headers=headers, progress_callback=progress_callback)

    return result
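
# Example (illustrative sketch, not executed here): assuming ``bucket`` is a
# configured ``oss2.Bucket(auth, endpoint, 'my-bucket')`` instance and
# 'local.bin' is a hypothetical local file, a resumable upload with a 10 MB
# multipart threshold and four worker threads would look like:
#
#   resumable_upload(bucket, 'remote/key.bin', 'local.bin',
#                    multipart_threshold=10 * 1024 * 1024,
#                    num_threads=4)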


def resumable_download(bucket, key, filename,
                       multiget_threshold=None,
                       part_size=None,
                       progress_callback=None,
                       num_threads=None,
                       store=None,
                       params=None,
                       headers=None):
    """Download an object, resuming from a checkpoint if one exists.

    The approach is:

    #. Create a temporary local file, named after the original file plus a random suffix;
    #. Read ranges of the OSS object concurrently via the `Range` request header and write them
       to the corresponding positions in the temporary file;
    #. Once everything is done, rename the temporary file to the target file (i.e. `filename`).

    During this process the checkpoint information, i.e. the ranges already completed, is saved
    on disk. If the download is interrupted for any reason, a later download of the same file
    (same source object and same target file) first reads the checkpoint and only downloads the
    missing parts.

    By default, checkpoint information is saved in a subdirectory of the `HOME` directory. The
    location can be changed with the `store` parameter.

    Notes on using this function:

    #. For the same source object and target file, avoid calling this function concurrently from
       multiple programs (or threads): the checkpoint files would overwrite each other on disk,
       and the temporary file names could collide.
    #. Avoid ranges (parts) that are too small: `part_size` should be at least
       `oss2.defaults.multiget_part_size`.
    #. If the target file already exists, this function overwrites it.
    #. With a CryptoBucket, the function falls back to a plain download.

    :param bucket: a :class:`Bucket <oss2.Bucket>` or :class:`CryptoBucket <oss2.CryptoBucket>` object
    :param str key: key of the remote object to download
    :param str filename: name of the local target file
    :param int multiget_threshold: use resumable download when the object is longer than this value.
    :param int part_size: preferred part size, i.e. bytes fetched per request; the actual part size may differ.
    :param progress_callback: download progress callback; see :ref:`progress_callback`.
    :param num_threads: number of concurrent download threads; defaults to `oss2.defaults.multiget_num_threads`.

    :param store: persistent store for checkpoint information; the directory holding the
        checkpoints can be specified here.
    :type store: `ResumableDownloadStore`

    :param dict params: download parameters; pass versionId to download a specific version of the object.

    :param headers: HTTP headers
        # head_object currently only receives OSS_REQUEST_PAYER.
        # get_object_to_file and get_object currently receive OSS_REQUEST_PAYER and OSS_TRAFFIC_LIMIT.
    :type headers: dict, preferably oss2.CaseInsensitiveDict

    :raises: :class:`NotFound <oss2.exceptions.NotFound>` if the OSS object does not exist; other
        download-related exceptions may also be raised.
    """
    logger.debug("Start to resumable download, bucket: {0}, key: {1}, filename: {2}, multiget_threshold: {3}, "
                 "part_size: {4}, num_threads: {5}".format(bucket.bucket_name, to_string(key), filename,
                                                           multiget_threshold, part_size, num_threads))
    multiget_threshold = defaults.get(multiget_threshold, defaults.multiget_threshold)

    valid_headers = _populate_valid_headers(headers, [OSS_REQUEST_PAYER, OSS_TRAFFIC_LIMIT])
    result = bucket.head_object(key, params=params, headers=valid_headers)
    logger.debug("The size of object to download is: {0}, multiget_threshold: {1}".format(result.content_length,
                                                                                          multiget_threshold))
    if result.content_length >= multiget_threshold:
        downloader = _ResumableDownloader(bucket, key, filename, _ObjectInfo.make(result), part_size=part_size,
                                          progress_callback=progress_callback, num_threads=num_threads, store=store,
                                          params=params, headers=valid_headers)
        downloader.download(result.server_crc)
    else:
        bucket.get_object_to_file(key, filename, progress_callback=progress_callback, params=params,
                                  headers=valid_headers)
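
# Example (illustrative sketch, not executed here): downloading a specific
# version of an object with a simple progress callback; ``bucket`` and the
# '<version-id>' value are hypothetical placeholders.
#
#   def percentage(consumed, total):
#       print('{0}%'.format(100 * consumed // total))
#
#   resumable_download(bucket, 'remote/key.bin', 'local.bin',
#                      params={'versionId': '<version-id>'},
#                      progress_callback=percentage)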


_MAX_MULTIGET_PART_COUNT = 100


def determine_part_size(total_size,
                        preferred_size=None):
    """Determine the part size for a multipart upload.

    :param int total_size: total length to upload
    :param int preferred_size: preferred part size; defaults to defaults.part_size if not specified

    :return: part size
    """
    if not preferred_size:
        preferred_size = defaults.part_size

    return _determine_part_size_internal(total_size, preferred_size, defaults.max_part_count)


def _determine_part_size_internal(total_size, preferred_size, max_count):
    if total_size < preferred_size:
        return total_size

    while preferred_size * max_count < total_size or preferred_size < defaults.min_part_size:
        preferred_size = preferred_size * 2

    return preferred_size
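
# Worked example (figures assume defaults.max_part_count == 10000 and
# defaults.min_part_size == 100 * 1024): for total_size = 1 TiB and
# preferred_size = 100 MiB, 100 MiB * 10000 < 1 TiB, so the size is doubled
# once; 200 MiB * 10000 >= 1 TiB, so 200 MiB is returned and the upload
# stays within the part-count limit.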


def _split_to_parts(total_size, part_size):
    parts = []
    num_parts = utils.how_many(total_size, part_size)

    for i in range(num_parts):
        if i == num_parts - 1:
            start = i * part_size
            end = total_size
        else:
            start = i * part_size
            end = part_size + start

        parts.append(_PartToProcess(i + 1, start, end))

    return parts
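
# For instance, _split_to_parts(10, 4) yields three _PartToProcess entries
# covering [0, 4), [4, 8) and [8, 10): every part is exactly part_size bytes
# except the last one, which absorbs the remainder.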


def _populate_valid_headers(headers=None, valid_keys=None):
    """Build an HTTP header dict containing only the valid keys.

    :param headers: headers to filter
    :type headers: dict, preferably oss2.CaseInsensitiveDict

    :param valid_keys: list of valid keys
    :type valid_keys: list

    :return: HTTP headers containing only the valid keys, type: oss2.CaseInsensitiveDict
    """
    if headers is None or valid_keys is None:
        return None

    headers = http.CaseInsensitiveDict(headers)
    valid_headers = http.CaseInsensitiveDict()

    for key in valid_keys:
        if headers.get(key) is not None:
            valid_headers[key] = headers[key]

    if len(valid_headers) == 0:
        valid_headers = None

    return valid_headers


def _filter_invalid_headers(headers=None, invalid_keys=None):
    """Filter the invalid keys out of an HTTP header dict.

    :param headers: headers to filter
    :type headers: dict, preferably oss2.CaseInsensitiveDict

    :param invalid_keys: list of invalid keys
    :type invalid_keys: list

    :return: HTTP headers with the invalid keys removed, type: oss2.CaseInsensitiveDict
    """
    if headers is None or invalid_keys is None:
        return None

    headers = http.CaseInsensitiveDict(headers)
    valid_headers = headers.copy()

    for key in invalid_keys:
        if valid_headers.get(key) is not None:
            valid_headers.pop(key)

    if len(valid_headers) == 0:
        valid_headers = None

    return valid_headers


def _populate_valid_params(params=None, valid_keys=None):
    """Build a params dict containing only the valid keys.

    :param params: params to filter
    :type params: dict

    :param valid_keys: list of valid keys
    :type valid_keys: list

    :return: params containing only the valid keys
    """
    if params is None or valid_keys is None:
        return None

    valid_params = dict()

    for key in valid_keys:
        if params.get(key) is not None:
            valid_params[key] = params[key]

    if len(valid_params) == 0:
        valid_params = None

    return valid_params
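
# For instance (hypothetical header values): given
#   headers = {OSS_REQUEST_PAYER: 'requester', 'x-oss-meta-foo': 'bar'}
# _populate_valid_headers(headers, [OSS_REQUEST_PAYER]) keeps only the
# request-payer entry (a whitelist), while _filter_invalid_headers(headers,
# [OSS_REQUEST_PAYER]) returns everything else (a blacklist); both return
# None rather than an empty dict when nothing survives the filtering.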


class _ResumableOperation(object):
    def __init__(self, bucket, key, filename, size, store,
                 progress_callback=None, versionid=None):
        self.bucket = bucket
        self.key = to_string(key)
        self.filename = filename
        self.size = size

        self._abspath = os.path.abspath(filename)

        self.__store = store

        if versionid is None:
            self.__record_key = self.__store.make_store_key(bucket.bucket_name, self.key, self._abspath)
        else:
            self.__record_key = self.__store.make_store_key(bucket.bucket_name, self.key, self._abspath, versionid)

        logger.debug("Init _ResumableOperation, record_key: {0}".format(self.__record_key))

        # protect self.__progress_callback
        self.__plock = threading.Lock()
        self.__progress_callback = progress_callback

    def _del_record(self):
        self.__store.delete(self.__record_key)

    def _put_record(self, record):
        self.__store.put(self.__record_key, record)

    def _get_record(self):
        return self.__store.get(self.__record_key)

    def _report_progress(self, consumed_size):
        if self.__progress_callback:
            with self.__plock:
                self.__progress_callback(consumed_size, self.size)


class _ObjectInfo(object):
    def __init__(self):
        self.size = None
        self.etag = None
        self.mtime = None

    @staticmethod
    def make(head_object_result):
        objectInfo = _ObjectInfo()
        objectInfo.size = head_object_result.content_length
        objectInfo.etag = head_object_result.etag
        objectInfo.mtime = head_object_result.last_modified

        return objectInfo


class _ResumableDownloader(_ResumableOperation):
    def __init__(self, bucket, key, filename, objectInfo,
                 part_size=None,
                 store=None,
                 progress_callback=None,
                 num_threads=None,
                 params=None,
                 headers=None):
        versionid = None
        if params is not None and params.get('versionId') is not None:
            versionid = params.get('versionId')
        super(_ResumableDownloader, self).__init__(bucket, key, filename, objectInfo.size,
                                                   store or ResumableDownloadStore(),
                                                   progress_callback=progress_callback,
                                                   versionid=versionid)
        self.objectInfo = objectInfo
        self.__op = 'ResumableDownload'

        self.__part_size = defaults.get(part_size, defaults.multiget_part_size)
        self.__part_size = _determine_part_size_internal(self.size, self.__part_size, _MAX_MULTIGET_PART_COUNT)

        self.__tmp_file = None
        self.__num_threads = defaults.get(num_threads, defaults.multiget_num_threads)
        self.__finished_parts = None
        self.__finished_size = None
        self.__params = params
        self.__headers = headers

        # protect record
        self.__lock = threading.Lock()
        self.__record = None
        logger.debug("Init _ResumableDownloader, bucket: {0}, key: {1}, part_size: {2}, num_thread: {3}".format(
            bucket.bucket_name, to_string(key), self.__part_size, self.__num_threads))

    def download(self, server_crc=None):
        self.__load_record()

        parts_to_download = self.__get_parts_to_download()
        logger.debug("Parts need to download: {0}".format(parts_to_download))

        # create the tmp file if it does not exist
        open(self.__tmp_file, 'a').close()

        q = TaskQueue(functools.partial(self.__producer, parts_to_download=parts_to_download),
                      [self.__consumer] * self.__num_threads)
        q.run()

        if self.bucket.enable_crc:
            parts = sorted(self.__finished_parts, key=lambda p: p.part_number)
            object_crc = utils.calc_obj_crc_from_parts(parts)
            utils.check_crc('resume download', object_crc, server_crc, None)

        utils.force_rename(self.__tmp_file, self.filename)

        self._report_progress(self.size)
        self._del_record()

    def __producer(self, q, parts_to_download=None):
        for part in parts_to_download:
            q.put(part)

    def __consumer(self, q):
        while q.ok():
            part = q.get()
            if part is None:
                break

            self.__download_part(part)

    def __download_part(self, part):
        self._report_progress(self.__finished_size)

        with open(self.__tmp_file, 'rb+') as f:
            f.seek(part.start, os.SEEK_SET)

            headers = _populate_valid_headers(self.__headers, [OSS_REQUEST_PAYER, OSS_TRAFFIC_LIMIT])
            if headers is None:
                headers = http.CaseInsensitiveDict()
            headers[IF_MATCH] = self.objectInfo.etag
            headers[IF_UNMODIFIED_SINCE] = utils.http_date(self.objectInfo.mtime)

            result = self.bucket.get_object(self.key, byte_range=(part.start, part.end - 1), headers=headers, params=self.__params)
            utils.copyfileobj_and_verify(result, f, part.end - part.start, request_id=result.request_id)
            part.part_crc = result.client_crc
            logger.debug("Download part succeeded, add part info to record, part_number: {0}, start: {1}, end: {2}".format(
                part.part_number, part.start, part.end))

        self.__finish_part(part)

    def __load_record(self):
        record = self._get_record()
        logger.debug("Load record return {0}".format(record))

        if record and not self.__is_record_sane(record):
            logger.warn("The content of record is invalid, delete the record")
            self._del_record()
            record = None

        if record and not os.path.exists(self.filename + record['tmp_suffix']):
            logger.warn("Temp file: {0} does not exist, delete the record".format(
                self.filename + record['tmp_suffix']))
            self._del_record()
            record = None

        if record and self.__is_remote_changed(record):
            logger.warn("Object: {0} has been overwritten, delete the record and tmp file".format(self.key))
            utils.silently_remove(self.filename + record['tmp_suffix'])
            self._del_record()
            record = None

        if not record:
            record = {'op_type': self.__op, 'bucket': self.bucket.bucket_name, 'key': self.key,
                      'size': self.objectInfo.size, 'mtime': self.objectInfo.mtime, 'etag': self.objectInfo.etag,
                      'part_size': self.__part_size, 'file_path': self._abspath, 'tmp_suffix': self.__gen_tmp_suffix(),
                      'parts': []}
            logger.debug('Add new record, bucket: {0}, key: {1}, part_size: {2}'.format(
                self.bucket.bucket_name, self.key, self.__part_size))
            self._put_record(record)

        self.__tmp_file = self.filename + record['tmp_suffix']
        self.__part_size = record['part_size']
        self.__finished_parts = list(
            _PartToProcess(p['part_number'], p['start'], p['end'], p['part_crc']) for p in record['parts'])
        self.__finished_size = sum(p.size for p in self.__finished_parts)
        self.__record = record

    def __get_parts_to_download(self):
        assert self.__record

        all_set = set(_split_to_parts(self.size, self.__part_size))
        finished_set = set(self.__finished_parts)

        return sorted(list(all_set - finished_set), key=lambda p: p.part_number)

    def __is_record_sane(self, record):
        try:
            if record['op_type'] != self.__op:
                logger.error('op_type in record is invalid: {0}'.format(record['op_type']))
                return False

            for key in ('etag', 'tmp_suffix', 'file_path', 'bucket', 'key'):
                if not isinstance(record[key], str):
                    logger.error('{0} is not a string: {1}'.format(key, record[key]))
                    return False

            for key in ('part_size', 'size', 'mtime'):
                if not isinstance(record[key], int):
                    logger.error('{0} is not an integer: {1}'.format(key, record[key]))
                    return False

            if not isinstance(record['parts'], list):
                logger.error('parts is not a list: {0}'.format(record['parts']))
                return False
        except KeyError as e:
            logger.error('Key not found: {0}'.format(e.args))
            return False

        return True

    def __is_remote_changed(self, record):
        return (record['mtime'] != self.objectInfo.mtime or
                record['size'] != self.objectInfo.size or
                record['etag'] != self.objectInfo.etag)

    def __finish_part(self, part):
        with self.__lock:
            self.__finished_parts.append(part)
            self.__finished_size += part.size

            self.__record['parts'].append({'part_number': part.part_number,
                                           'start': part.start,
                                           'end': part.end,
                                           'part_crc': part.part_crc})
            self._put_record(self.__record)

    def __gen_tmp_suffix(self):
        return '.tmp-' + ''.join(random.choice(string.ascii_lowercase) for i in range(12))


class _ResumableUploader(_ResumableOperation):
    """Upload a file using resumable upload.

    :param bucket: a :class:`Bucket <oss2.Bucket>` object
    :param key: object key
    :param filename: name of the file to upload
    :param size: total length of the file
    :param store: persistent store for the upload progress
    :param headers: HTTP headers passed to `init_multipart_upload`
    :param part_size: part size. A user-supplied value takes precedence. If none is given, a
        reasonable value is computed for a new upload, and the size of the first part is reused
        for an existing upload.
    :param progress_callback: upload progress callback; see :ref:`progress_callback`.
    """
    def __init__(self, bucket, key, filename, size,
                 store=None,
                 headers=None,
                 part_size=None,
                 progress_callback=None,
                 num_threads=None,
                 params=None):
        super(_ResumableUploader, self).__init__(bucket, key, filename, size,
                                                 store or ResumableStore(),
                                                 progress_callback=progress_callback)

        self.__op = 'ResumableUpload'
        self.__headers = headers
        self.__part_size = defaults.get(part_size, defaults.part_size)

        self.__mtime = os.path.getmtime(filename)

        self.__num_threads = defaults.get(num_threads, defaults.multipart_num_threads)

        self.__upload_id = None

        self.__params = params

        # protect below fields
        self.__lock = threading.Lock()
        self.__record = None
        self.__finished_size = 0
        self.__finished_parts = None

        self.__encryption = False
        self.__record_upload_context = False
        self.__upload_context = None

        if isinstance(self.bucket, CryptoBucket):
            self.__encryption = True
            self.__record_upload_context = True

        logger.debug("Init _ResumableUploader, bucket: {0}, key: {1}, part_size: {2}, num_thread: {3}".format(
            bucket.bucket_name, to_string(key), self.__part_size, self.__num_threads))

    def upload(self):
        self.__load_record()

        parts_to_upload = self.__get_parts_to_upload(self.__finished_parts)
        parts_to_upload = sorted(parts_to_upload, key=lambda p: p.part_number)
        logger.debug("Parts need to upload: {0}".format(parts_to_upload))

        q = TaskQueue(functools.partial(self.__producer, parts_to_upload=parts_to_upload),
                      [self.__consumer] * self.__num_threads)
        q.run()

        self._report_progress(self.size)

        headers = _populate_valid_headers(self.__headers, [OSS_REQUEST_PAYER, OSS_OBJECT_ACL])
        result = self.bucket.complete_multipart_upload(self.key, self.__upload_id, self.__finished_parts, headers=headers)
        self._del_record()

        return result

    def __producer(self, q, parts_to_upload=None):
        for part in parts_to_upload:
            q.put(part)

    def __consumer(self, q):
        while True:
            part = q.get()
            if part is None:
                break

            self.__upload_part(part)

    def __upload_part(self, part):
        with open(to_unicode(self.filename), 'rb') as f:
            self._report_progress(self.__finished_size)

            f.seek(part.start, os.SEEK_SET)
            headers = _populate_valid_headers(self.__headers, [OSS_REQUEST_PAYER, OSS_TRAFFIC_LIMIT])
            if self.__encryption:
                result = self.bucket.upload_part(self.key, self.__upload_id, part.part_number,
                                                 utils.SizedFileAdapter(f, part.size), headers=headers,
                                                 upload_context=self.__upload_context)
            else:
                result = self.bucket.upload_part(self.key, self.__upload_id, part.part_number,
                                                 utils.SizedFileAdapter(f, part.size), headers=headers)

            logger.debug("Upload part success, add part info to record, part_number: {0}, etag: {1}, size: {2}".format(
                part.part_number, result.etag, part.size))
            self.__finish_part(PartInfo(part.part_number, result.etag, size=part.size, part_crc=result.crc))

    def __finish_part(self, part_info):
        with self.__lock:
            self.__finished_parts.append(part_info)
            self.__finished_size += part_info.size

    def __load_record(self):
        record = self._get_record()
        logger.debug("Load record return {0}".format(record))

        if record and not self.__is_record_sane(record):
            logger.warn("The content of record is invalid, delete the record")
            self._del_record()
            record = None

        if record and self.__file_changed(record):
            logger.warn("File: {0} has been changed, delete the record".format(self.filename))
            self._del_record()
            record = None

        if record and not self.__upload_exists(record['upload_id']):
            logger.warn('Multipart upload: {0} does not exist, delete the record'.format(record['upload_id']))
            self._del_record()
            record = None

        if not record:
            params = _populate_valid_params(self.__params, [Bucket.SEQUENTIAL])
            part_size = determine_part_size(self.size, self.__part_size)
            logger.debug("Upload File size: {0}, User-specify part_size: {1}, Calculated part_size: {2}".format(
                self.size, self.__part_size, part_size))
            if self.__encryption:
                upload_context = models.MultipartUploadCryptoContext(self.size, part_size)
                upload_id = self.bucket.init_multipart_upload(self.key, self.__headers, params,
                                                              upload_context).upload_id
                if self.__record_upload_context:
                    material = upload_context.content_crypto_material
                    material_record = {'wrap_alg': material.wrap_alg, 'cek_alg': material.cek_alg,
                                       'encrypted_key': b64encode_as_string(material.encrypted_key),
                                       'encrypted_iv': b64encode_as_string(material.encrypted_iv),
                                       'mat_desc': material.mat_desc}
            else:
                upload_id = self.bucket.init_multipart_upload(self.key, self.__headers, params).upload_id

            record = {'op_type': self.__op, 'upload_id': upload_id, 'file_path': self._abspath, 'size': self.size,
                      'mtime': self.__mtime, 'bucket': self.bucket.bucket_name, 'key': self.key, 'part_size': part_size}

            if self.__record_upload_context:
                record['content_crypto_material'] = material_record

            logger.debug('Add new record, bucket: {0}, key: {1}, upload_id: {2}, part_size: {3}'.format(
                self.bucket.bucket_name, self.key, upload_id, part_size))

            self._put_record(record)

        self.__record = record
        self.__part_size = self.__record['part_size']
        self.__upload_id = self.__record['upload_id']
        if self.__record_upload_context:
            if 'content_crypto_material' in self.__record:
                material_record = self.__record['content_crypto_material']
                wrap_alg = material_record['wrap_alg']
                cek_alg = material_record['cek_alg']
                if cek_alg != self.bucket.crypto_provider.cipher.alg or wrap_alg != self.bucket.crypto_provider.wrap_alg:
                    err_msg = 'Envelope or data encryption/decryption algorithm is inconsistent'
                    raise exceptions.InconsistentError(err_msg, self)
                content_crypto_material = models.ContentCryptoMaterial(self.bucket.crypto_provider.cipher,
                                                                       material_record['wrap_alg'],
                                                                       b64decode_from_string(
                                                                           material_record['encrypted_key']),
                                                                       b64decode_from_string(
                                                                           material_record['encrypted_iv']),
                                                                       material_record['mat_desc'])
                self.__upload_context = models.MultipartUploadCryptoContext(self.size, self.__part_size,
                                                                            content_crypto_material)
            else:
                err_msg = 'If the record_upload_context flag is true, content_crypto_material must be in the record'
                raise exceptions.InconsistentError(err_msg, self)
        else:
            if 'content_crypto_material' in self.__record:
                err_msg = 'content_crypto_material is in the record, but the record_upload_context flag is false'
                raise exceptions.InvalidEncryptionRequest(err_msg, self)

        self.__finished_parts = self.__get_finished_parts()
        self.__finished_size = sum(p.size for p in self.__finished_parts)

    def __get_finished_parts(self):
        parts = []

        valid_headers = _filter_invalid_headers(self.__headers,
                                                [OSS_SERVER_SIDE_ENCRYPTION, OSS_SERVER_SIDE_DATA_ENCRYPTION])

        for part in PartIterator(self.bucket, self.key, self.__upload_id, headers=valid_headers):
            parts.append(part)

        return parts

    def __upload_exists(self, upload_id):
        try:
            valid_headers = _filter_invalid_headers(self.__headers,
                                                    [OSS_SERVER_SIDE_ENCRYPTION, OSS_SERVER_SIDE_DATA_ENCRYPTION])
            list(iterators.PartIterator(self.bucket, self.key, upload_id, '0', max_parts=1, headers=valid_headers))
        except exceptions.NoSuchUpload:
            return False
        else:
            return True

    def __file_changed(self, record):
        return record['mtime'] != self.__mtime or record['size'] != self.size

    def __get_parts_to_upload(self, parts_uploaded):
        all_parts = _split_to_parts(self.size, self.__part_size)
        if not parts_uploaded:
            return all_parts

        all_parts_map = dict((p.part_number, p) for p in all_parts)

        for uploaded in parts_uploaded:
            if uploaded.part_number in all_parts_map:
                del all_parts_map[uploaded.part_number]

        return all_parts_map.values()

    def __is_record_sane(self, record):
        try:
            if record['op_type'] != self.__op:
                logger.error('op_type in record is invalid: {0}'.format(record['op_type']))
                return False

            for key in ('upload_id', 'file_path', 'bucket', 'key'):
                if not isinstance(record[key], str):
                    logger.error('Type Error, {0} in record is not a string type: {1}'.format(key, record[key]))
                    return False

            for key in ('size', 'part_size'):
                if not isinstance(record[key], int):
                    logger.error('Type Error, {0} in record is not an integer type: {1}'.format(key, record[key]))
                    return False

            if not isinstance(record['mtime'], int) and not isinstance(record['mtime'], float):
                logger.error(
                    'Type Error, mtime in record is not a float or an integer type: {0}'.format(record['mtime']))
                return False
        except KeyError as e:
            logger.error('Key not found: {0}'.format(e.args))
            return False

        return True


_UPLOAD_TEMP_DIR = '.py-oss-upload'
_DOWNLOAD_TEMP_DIR = '.py-oss-download'


class _ResumableStoreBase(object):
    def __init__(self, root, dir):
        logger.debug("Init ResumableStoreBase, root path: {0}, temp dir: {1}".format(root, dir))
        self.dir = os.path.join(root, dir)

        if os.path.isdir(self.dir):
            return

        utils.makedir_p(self.dir)

    def get(self, key):
        pathname = self.__path(key)

        logger.debug('ResumableStoreBase: get key: {0} from file path: {1}'.format(key, pathname))
        if not os.path.exists(pathname):
            logger.debug("file {0} does not exist".format(pathname))
            return None

        # json.load() always returns unicode; on Python 2 we convert it to str.
        try:
            with open(to_unicode(pathname), 'r') as f:
                content = json.load(f)
        except ValueError:
            os.remove(pathname)
            return None
        else:
            return stringify(content)

    def put(self, key, value):
        pathname = self.__path(key)

        with open(to_unicode(pathname), 'w') as f:
            json.dump(value, f)
        logger.debug('ResumableStoreBase: put key: {0} to file path: {1}, value: {2}'.format(key, pathname, value))

    def delete(self, key):
        pathname = self.__path(key)
        os.remove(pathname)
        logger.debug('ResumableStoreBase: delete key: {0}, file path: {1}'.format(key, pathname))

    def __path(self, key):
        return os.path.join(self.dir, key)


def _normalize_path(path):
    return os.path.normpath(os.path.normcase(path))


class ResumableStore(_ResumableStoreBase):
    """Store for resumable upload checkpoints.

    The checkpoint of each upload is saved in a file under `root/dir/`.

    :param str root: parent directory, defaults to HOME
    :param str dir: subdirectory, defaults to `_UPLOAD_TEMP_DIR`
    """
    def __init__(self, root=None, dir=None):
        super(ResumableStore, self).__init__(root or os.path.expanduser('~'), dir or _UPLOAD_TEMP_DIR)

    @staticmethod
    def make_store_key(bucket_name, key, filename):
        filepath = _normalize_path(filename)

        oss_pathname = 'oss://{0}/{1}'.format(bucket_name, key)
        return utils.md5_string(oss_pathname) + '--' + utils.md5_string(filepath)


class ResumableDownloadStore(_ResumableStoreBase):
    """Store for resumable download checkpoints.

    The checkpoint of each download is saved in a file under `root/dir/`.

    :param str root: parent directory, defaults to HOME
    :param str dir: subdirectory, defaults to `_DOWNLOAD_TEMP_DIR`
    """
    def __init__(self, root=None, dir=None):
        super(ResumableDownloadStore, self).__init__(root or os.path.expanduser('~'), dir or _DOWNLOAD_TEMP_DIR)

    @staticmethod
    def make_store_key(bucket_name, key, filename, version_id=None):
        filepath = _normalize_path(filename)

        if version_id is None:
            oss_pathname = 'oss://{0}/{1}'.format(bucket_name, key)
        else:
            oss_pathname = 'oss://{0}/{1}?versionid={2}'.format(bucket_name, key, version_id)
        return utils.md5_string(oss_pathname) + '--' + utils.md5_string(filepath)
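
# For instance, for bucket 'my-bucket', key 'a/b.txt' and local path
# '/tmp/b.txt', the checkpoint record is stored in a file named
# md5('oss://my-bucket/a/b.txt') + '--' + md5 of the normalized local path,
# under ~/.py-oss-upload for uploads and ~/.py-oss-download for downloads,
# so distinct (object, local file) pairs never collide; when a version_id is
# given, it is folded into the object half of the key.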


def make_upload_store(root=None, dir=None):
    return ResumableStore(root=root, dir=dir)


def make_download_store(root=None, dir=None):
    return ResumableDownloadStore(root=root, dir=dir)


class _PartToProcess(object):
    def __init__(self, part_number, start, end, part_crc=None):
        self.part_number = part_number
        self.start = start
        self.end = end
        self.part_crc = part_crc

    @property
    def size(self):
        return self.end - self.start

    def __hash__(self):
        return hash(self.__key)

    def __eq__(self, other):
        return self.__key == other.__key

    @property
    def __key(self):
        return self.part_number, self.start, self.end
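

if __name__ == '__main__':
    # Minimal local sanity checks (an illustrative addition, not part of the
    # public API): they exercise only the pure helpers above and make no
    # network calls.
    parts = _split_to_parts(100, 30)
    assert [(p.part_number, p.start, p.end) for p in parts] == \
        [(1, 0, 30), (2, 30, 60), (3, 60, 90), (4, 90, 100)]

    # _PartToProcess hashes on (part_number, start, end), so finished parts
    # can be removed with plain set arithmetic, exactly as
    # __get_parts_to_download does.
    finished = {_PartToProcess(1, 0, 30), _PartToProcess(3, 60, 90)}
    remaining = sorted(set(parts) - finished, key=lambda p: p.part_number)
    assert [p.part_number for p in remaining] == [2, 4]

    # determine_part_size keeps the part count within defaults.max_part_count.
    size = determine_part_size(1024 * 1024 * 1024)
    assert utils.how_many(1024 * 1024 * 1024, size) <= defaults.max_part_count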