base.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. # -*- coding: utf-8 -*-
  2. """
  3. AipBase
  4. """
  5. import hmac
  6. import json
  7. import hashlib
  8. import datetime
  9. import base64
  10. import time
  11. import sys
  12. import requests
  13. requests.packages.urllib3.disable_warnings()
  14. if sys.version_info.major == 2:
  15. from urllib import urlencode
  16. from urllib import quote
  17. from urlparse import urlparse
  18. else:
  19. from urllib.parse import urlencode
  20. from urllib.parse import quote
  21. from urllib.parse import urlparse
  22. class AipBase(object):
  23. """
  24. AipBase
  25. """
  26. __accessTokenUrl = 'https://aip.baidubce.com/oauth/2.0/token'
  27. __reportUrl = 'https://aip.baidubce.com/rpc/2.0/feedback/v1/report'
  28. __scope = 'brain_all_scope'
  29. def __init__(self, appId, apiKey, secretKey):
  30. """
  31. AipBase(appId, apiKey, secretKey)
  32. """
  33. self._appId = appId.strip()
  34. self._apiKey = apiKey.strip()
  35. self._secretKey = secretKey.strip()
  36. self._authObj = {}
  37. self._isCloudUser = None
  38. self.__client = requests
  39. self.__connectTimeout = 60.0
  40. self.__socketTimeout = 60.0
  41. self._proxies = {}
  42. self.__version = '2_2_18'
  43. def getVersion(self):
  44. """
  45. version
  46. """
  47. return self.__version
  48. def setConnectionTimeoutInMillis(self, ms):
  49. """
  50. setConnectionTimeoutInMillis
  51. """
  52. self.__connectTimeout = ms / 1000.0
  53. def setSocketTimeoutInMillis(self, ms):
  54. """
  55. setSocketTimeoutInMillis
  56. """
  57. self.__socketTimeout = ms / 1000.0
  58. def setProxies(self, proxies):
  59. """
  60. proxies
  61. """
  62. self._proxies = proxies
  63. def _request(self, url, data, headers=None):
  64. """
  65. self._request('', {})
  66. """
  67. try:
  68. result = self._validate(url, data)
  69. if result != True:
  70. return result
  71. authObj = self._auth()
  72. params = self._getParams(authObj)
  73. data = self._proccessRequest(url, params, data, headers)
  74. headers = self._getAuthHeaders('POST', url, params, headers)
  75. response = self.__client.post(url, data=data, params=params,
  76. headers=headers, verify=False, timeout=(
  77. self.__connectTimeout,
  78. self.__socketTimeout,
  79. ), proxies=self._proxies
  80. )
  81. obj = self._proccessResult(response.content)
  82. if not self._isCloudUser and obj.get('error_code', '') == 110:
  83. authObj = self._auth(True)
  84. params = self._getParams(authObj)
  85. response = self.__client.post(url, data=data, params=params,
  86. headers=headers, verify=False, timeout=(
  87. self.__connectTimeout,
  88. self.__socketTimeout,
  89. ), proxies=self._proxies
  90. )
  91. obj = self._proccessResult(response.content)
  92. except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectTimeout) as e:
  93. return {
  94. 'error_code': 'SDK108',
  95. 'error_msg': 'connection or read data timeout',
  96. }
  97. return obj
  98. def _validate(self, url, data):
  99. """
  100. validate
  101. """
  102. return True
  103. def _proccessRequest(self, url, params, data, headers):
  104. """
  105. 参数处理
  106. """
  107. params['aipSdk'] = 'python'
  108. params['aipVersion'] = self.__version
  109. return data
  110. def _proccessResult(self, content):
  111. """
  112. formate result
  113. """
  114. if sys.version_info.major == 2:
  115. return json.loads(content) or {}
  116. else:
  117. return json.loads(content.decode()) or {}
  118. def _auth(self, refresh=False):
  119. """
  120. api access auth
  121. """
  122. if self._isCloudUser:
  123. return self._authObj
  124. #未过期
  125. if not refresh:
  126. tm = self._authObj.get('time', 0) + int(self._authObj.get('expires_in', 0)) - 30
  127. if tm > int(time.time()):
  128. return self._authObj
  129. obj = self.__client.get(self.__accessTokenUrl, verify=False, params={
  130. 'grant_type': 'client_credentials',
  131. 'client_id': self._apiKey,
  132. 'client_secret': self._secretKey,
  133. }, timeout=(
  134. self.__connectTimeout,
  135. self.__socketTimeout,
  136. ), proxies=self._proxies).json()
  137. self._isCloudUser = not self._isPermission(obj)
  138. obj['time'] = int(time.time())
  139. self._authObj = obj
  140. return obj
  141. def _isPermission(self, authObj):
  142. """
  143. check whether permission
  144. """
  145. scopes = authObj.get('scope', '')
  146. return self.__scope in scopes.split(' ')
  147. def _getParams(self, authObj):
  148. """
  149. api request http url params
  150. """
  151. params = {}
  152. if self._isCloudUser == False:
  153. params['access_token'] = authObj['access_token']
  154. return params
  155. def _getAuthHeaders(self, method, url, params=None, headers=None):
  156. """
  157. api request http headers
  158. """
  159. headers = headers or {}
  160. params = params or {}
  161. if self._isCloudUser == False:
  162. return headers
  163. urlResult = urlparse(url)
  164. for kv in urlResult.query.strip().split('&'):
  165. if kv:
  166. k, v = kv.split('=')
  167. params[k] = v
  168. # UTC timestamp
  169. timestamp = datetime.datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
  170. headers['Host'] = urlResult.hostname
  171. headers['x-bce-date'] = timestamp
  172. version, expire = '1', '1800'
  173. # 1 Generate SigningKey
  174. val = "bce-auth-v%s/%s/%s/%s" % (version, self._apiKey, timestamp, expire)
  175. signingKey = hmac.new(self._secretKey.encode('utf-8'), val.encode('utf-8'),
  176. hashlib.sha256
  177. ).hexdigest()
  178. # 2 Generate CanonicalRequest
  179. # 2.1 Genrate CanonicalURI
  180. canonicalUri = quote(urlResult.path)
  181. # 2.2 Generate CanonicalURI: not used here
  182. # 2.3 Generate CanonicalHeaders: only include host here
  183. canonicalHeaders = []
  184. for header, val in headers.items():
  185. canonicalHeaders.append(
  186. '%s:%s' % (
  187. quote(header.strip(), '').lower(),
  188. quote(val.strip(), '')
  189. )
  190. )
  191. canonicalHeaders = '\n'.join(sorted(canonicalHeaders))
  192. # 2.4 Generate CanonicalRequest
  193. canonicalRequest = '%s\n%s\n%s\n%s' % (
  194. method.upper(),
  195. canonicalUri,
  196. '&'.join(sorted(urlencode(params).split('&'))),
  197. canonicalHeaders
  198. )
  199. # 3 Generate Final Signature
  200. signature = hmac.new(signingKey.encode('utf-8'), canonicalRequest.encode('utf-8'),
  201. hashlib.sha256
  202. ).hexdigest()
  203. headers['authorization'] = 'bce-auth-v%s/%s/%s/%s/%s/%s' % (
  204. version,
  205. self._apiKey,
  206. timestamp,
  207. expire,
  208. ';'.join(headers.keys()).lower(),
  209. signature
  210. )
  211. return headers
  212. def report(self, feedback):
  213. """
  214. 数据反馈
  215. """
  216. data = {}
  217. data['feedback'] = feedback
  218. return self._request(self.__reportUrl, data)
  219. def post(self, url, data, headers=None):
  220. """
  221. self.post('', {})
  222. """
  223. return self._request(url, data, headers)