connection.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040
  1. from __future__ import with_statement
  2. from distutils.version import StrictVersion
  3. from itertools import chain
  4. from select import select
  5. import os
  6. import socket
  7. import sys
  8. import threading
  9. import warnings
  10. try:
  11. import ssl
  12. ssl_available = True
  13. except ImportError:
  14. ssl_available = False
  15. from redis._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long,
  16. BytesIO, nativestr, basestring, iteritems,
  17. LifoQueue, Empty, Full, urlparse, parse_qs,
  18. unquote)
  19. from redis.exceptions import (
  20. RedisError,
  21. ConnectionError,
  22. TimeoutError,
  23. BusyLoadingError,
  24. ResponseError,
  25. InvalidResponse,
  26. AuthenticationError,
  27. NoScriptError,
  28. ExecAbortError,
  29. ReadOnlyError
  30. )
  31. from redis.utils import HIREDIS_AVAILABLE
  32. if HIREDIS_AVAILABLE:
  33. import hiredis
  34. hiredis_version = StrictVersion(hiredis.__version__)
  35. HIREDIS_SUPPORTS_CALLABLE_ERRORS = \
  36. hiredis_version >= StrictVersion('0.1.3')
  37. HIREDIS_SUPPORTS_BYTE_BUFFER = \
  38. hiredis_version >= StrictVersion('0.1.4')
  39. if not HIREDIS_SUPPORTS_BYTE_BUFFER:
  40. msg = ("redis-py works best with hiredis >= 0.1.4. You're running "
  41. "hiredis %s. Please consider upgrading." % hiredis.__version__)
  42. warnings.warn(msg)
  43. HIREDIS_USE_BYTE_BUFFER = True
  44. # only use byte buffer if hiredis supports it and the Python version
  45. # is >= 2.7
  46. if not HIREDIS_SUPPORTS_BYTE_BUFFER or (
  47. sys.version_info[0] == 2 and sys.version_info[1] < 7):
  48. HIREDIS_USE_BYTE_BUFFER = False
  49. SYM_STAR = b('*')
  50. SYM_DOLLAR = b('$')
  51. SYM_CRLF = b('\r\n')
  52. SYM_EMPTY = b('')
  53. SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."
  54. class Token(object):
  55. """
  56. Literal strings in Redis commands, such as the command names and any
  57. hard-coded arguments are wrapped in this class so we know not to apply
  58. and encoding rules on them.
  59. """
  60. def __init__(self, value):
  61. if isinstance(value, Token):
  62. value = value.value
  63. self.value = value
  64. def __repr__(self):
  65. return self.value
  66. def __str__(self):
  67. return self.value
  68. class BaseParser(object):
  69. EXCEPTION_CLASSES = {
  70. 'ERR': {
  71. 'max number of clients reached': ConnectionError
  72. },
  73. 'EXECABORT': ExecAbortError,
  74. 'LOADING': BusyLoadingError,
  75. 'NOSCRIPT': NoScriptError,
  76. 'READONLY': ReadOnlyError,
  77. }
  78. def parse_error(self, response):
  79. "Parse an error response"
  80. error_code = response.split(' ')[0]
  81. if error_code in self.EXCEPTION_CLASSES:
  82. response = response[len(error_code) + 1:]
  83. exception_class = self.EXCEPTION_CLASSES[error_code]
  84. if isinstance(exception_class, dict):
  85. exception_class = exception_class.get(response, ResponseError)
  86. return exception_class(response)
  87. return ResponseError(response)
  88. class SocketBuffer(object):
  89. def __init__(self, socket, socket_read_size):
  90. self._sock = socket
  91. self.socket_read_size = socket_read_size
  92. self._buffer = BytesIO()
  93. # number of bytes written to the buffer from the socket
  94. self.bytes_written = 0
  95. # number of bytes read from the buffer
  96. self.bytes_read = 0
  97. @property
  98. def length(self):
  99. return self.bytes_written - self.bytes_read
  100. def _read_from_socket(self, length=None):
  101. socket_read_size = self.socket_read_size
  102. buf = self._buffer
  103. buf.seek(self.bytes_written)
  104. marker = 0
  105. try:
  106. while True:
  107. data = self._sock.recv(socket_read_size)
  108. # an empty string indicates the server shutdown the socket
  109. if isinstance(data, bytes) and len(data) == 0:
  110. raise socket.error(SERVER_CLOSED_CONNECTION_ERROR)
  111. buf.write(data)
  112. data_length = len(data)
  113. self.bytes_written += data_length
  114. marker += data_length
  115. if length is not None and length > marker:
  116. continue
  117. break
  118. except socket.timeout:
  119. raise TimeoutError("Timeout reading from socket")
  120. except socket.error:
  121. e = sys.exc_info()[1]
  122. raise ConnectionError("Error while reading from socket: %s" %
  123. (e.args,))
  124. def read(self, length):
  125. length = length + 2 # make sure to read the \r\n terminator
  126. # make sure we've read enough data from the socket
  127. if length > self.length:
  128. self._read_from_socket(length - self.length)
  129. self._buffer.seek(self.bytes_read)
  130. data = self._buffer.read(length)
  131. self.bytes_read += len(data)
  132. # purge the buffer when we've consumed it all so it doesn't
  133. # grow forever
  134. if self.bytes_read == self.bytes_written:
  135. self.purge()
  136. return data[:-2]
  137. def readline(self):
  138. buf = self._buffer
  139. buf.seek(self.bytes_read)
  140. data = buf.readline()
  141. while not data.endswith(SYM_CRLF):
  142. # there's more data in the socket that we need
  143. self._read_from_socket()
  144. buf.seek(self.bytes_read)
  145. data = buf.readline()
  146. self.bytes_read += len(data)
  147. # purge the buffer when we've consumed it all so it doesn't
  148. # grow forever
  149. if self.bytes_read == self.bytes_written:
  150. self.purge()
  151. return data[:-2]
  152. def purge(self):
  153. self._buffer.seek(0)
  154. self._buffer.truncate()
  155. self.bytes_written = 0
  156. self.bytes_read = 0
  157. def close(self):
  158. try:
  159. self.purge()
  160. self._buffer.close()
  161. except:
  162. # issue #633 suggests the purge/close somehow raised a
  163. # BadFileDescriptor error. Perhaps the client ran out of
  164. # memory or something else? It's probably OK to ignore
  165. # any error being raised from purge/close since we're
  166. # removing the reference to the instance below.
  167. pass
  168. self._buffer = None
  169. self._sock = None
  170. class PythonParser(BaseParser):
  171. "Plain Python parsing class"
  172. encoding = None
  173. def __init__(self, socket_read_size):
  174. self.socket_read_size = socket_read_size
  175. self._sock = None
  176. self._buffer = None
  177. def __del__(self):
  178. try:
  179. self.on_disconnect()
  180. except Exception:
  181. pass
  182. def on_connect(self, connection):
  183. "Called when the socket connects"
  184. self._sock = connection._sock
  185. self._buffer = SocketBuffer(self._sock, self.socket_read_size)
  186. if connection.decode_responses:
  187. self.encoding = connection.encoding
  188. def on_disconnect(self):
  189. "Called when the socket disconnects"
  190. if self._sock is not None:
  191. self._sock.close()
  192. self._sock = None
  193. if self._buffer is not None:
  194. self._buffer.close()
  195. self._buffer = None
  196. self.encoding = None
  197. def can_read(self):
  198. return self._buffer and bool(self._buffer.length)
  199. def read_response(self):
  200. response = self._buffer.readline()
  201. if not response:
  202. raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
  203. byte, response = byte_to_chr(response[0]), response[1:]
  204. if byte not in ('-', '+', ':', '$', '*'):
  205. raise InvalidResponse("Protocol Error: %s, %s" %
  206. (str(byte), str(response)))
  207. # server returned an error
  208. if byte == '-':
  209. response = nativestr(response)
  210. error = self.parse_error(response)
  211. # if the error is a ConnectionError, raise immediately so the user
  212. # is notified
  213. if isinstance(error, ConnectionError):
  214. raise error
  215. # otherwise, we're dealing with a ResponseError that might belong
  216. # inside a pipeline response. the connection's read_response()
  217. # and/or the pipeline's execute() will raise this error if
  218. # necessary, so just return the exception instance here.
  219. return error
  220. # single value
  221. elif byte == '+':
  222. pass
  223. # int value
  224. elif byte == ':':
  225. response = long(response)
  226. # bulk response
  227. elif byte == '$':
  228. length = int(response)
  229. if length == -1:
  230. return None
  231. response = self._buffer.read(length)
  232. # multi-bulk response
  233. elif byte == '*':
  234. length = int(response)
  235. if length == -1:
  236. return None
  237. response = [self.read_response() for i in xrange(length)]
  238. if isinstance(response, bytes) and self.encoding:
  239. response = response.decode(self.encoding)
  240. return response
  241. class HiredisParser(BaseParser):
  242. "Parser class for connections using Hiredis"
  243. def __init__(self, socket_read_size):
  244. if not HIREDIS_AVAILABLE:
  245. raise RedisError("Hiredis is not installed")
  246. self.socket_read_size = socket_read_size
  247. if HIREDIS_USE_BYTE_BUFFER:
  248. self._buffer = bytearray(socket_read_size)
  249. def __del__(self):
  250. try:
  251. self.on_disconnect()
  252. except Exception:
  253. pass
  254. def on_connect(self, connection):
  255. self._sock = connection._sock
  256. kwargs = {
  257. 'protocolError': InvalidResponse,
  258. 'replyError': self.parse_error,
  259. }
  260. # hiredis < 0.1.3 doesn't support functions that create exceptions
  261. if not HIREDIS_SUPPORTS_CALLABLE_ERRORS:
  262. kwargs['replyError'] = ResponseError
  263. if connection.decode_responses:
  264. kwargs['encoding'] = connection.encoding
  265. self._reader = hiredis.Reader(**kwargs)
  266. self._next_response = False
  267. def on_disconnect(self):
  268. self._sock = None
  269. self._reader = None
  270. self._next_response = False
  271. def can_read(self):
  272. if not self._reader:
  273. raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
  274. if self._next_response is False:
  275. self._next_response = self._reader.gets()
  276. return self._next_response is not False
  277. def read_response(self):
  278. if not self._reader:
  279. raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
  280. # _next_response might be cached from a can_read() call
  281. if self._next_response is not False:
  282. response = self._next_response
  283. self._next_response = False
  284. return response
  285. response = self._reader.gets()
  286. socket_read_size = self.socket_read_size
  287. while response is False:
  288. try:
  289. if HIREDIS_USE_BYTE_BUFFER:
  290. bufflen = self._sock.recv_into(self._buffer)
  291. if bufflen == 0:
  292. raise socket.error(SERVER_CLOSED_CONNECTION_ERROR)
  293. else:
  294. buffer = self._sock.recv(socket_read_size)
  295. # an empty string indicates the server shutdown the socket
  296. if not isinstance(buffer, bytes) or len(buffer) == 0:
  297. raise socket.error(SERVER_CLOSED_CONNECTION_ERROR)
  298. except socket.timeout:
  299. raise TimeoutError("Timeout reading from socket")
  300. except socket.error:
  301. e = sys.exc_info()[1]
  302. raise ConnectionError("Error while reading from socket: %s" %
  303. (e.args,))
  304. if HIREDIS_USE_BYTE_BUFFER:
  305. self._reader.feed(self._buffer, 0, bufflen)
  306. else:
  307. self._reader.feed(buffer)
  308. response = self._reader.gets()
  309. # if an older version of hiredis is installed, we need to attempt
  310. # to convert ResponseErrors to their appropriate types.
  311. if not HIREDIS_SUPPORTS_CALLABLE_ERRORS:
  312. if isinstance(response, ResponseError):
  313. response = self.parse_error(response.args[0])
  314. elif isinstance(response, list) and response and \
  315. isinstance(response[0], ResponseError):
  316. response[0] = self.parse_error(response[0].args[0])
  317. # if the response is a ConnectionError or the response is a list and
  318. # the first item is a ConnectionError, raise it as something bad
  319. # happened
  320. if isinstance(response, ConnectionError):
  321. raise response
  322. elif isinstance(response, list) and response and \
  323. isinstance(response[0], ConnectionError):
  324. raise response[0]
  325. return response
  326. if HIREDIS_AVAILABLE:
  327. DefaultParser = HiredisParser
  328. else:
  329. DefaultParser = PythonParser
  330. class Connection(object):
  331. "Manages TCP communication to and from a Redis server"
  332. description_format = "Connection<host=%(host)s,port=%(port)s,db=%(db)s>"
  333. def __init__(self, host='localhost', port=6379, db=0, password=None,
  334. socket_timeout=None, socket_connect_timeout=None,
  335. socket_keepalive=False, socket_keepalive_options=None,
  336. retry_on_timeout=False, encoding='utf-8',
  337. encoding_errors='strict', decode_responses=False,
  338. parser_class=DefaultParser, socket_read_size=65536):
  339. self.pid = os.getpid()
  340. self.host = host
  341. self.port = int(port)
  342. self.db = db
  343. self.password = password
  344. self.socket_timeout = socket_timeout
  345. self.socket_connect_timeout = socket_connect_timeout or socket_timeout
  346. self.socket_keepalive = socket_keepalive
  347. self.socket_keepalive_options = socket_keepalive_options or {}
  348. self.retry_on_timeout = retry_on_timeout
  349. self.encoding = encoding
  350. self.encoding_errors = encoding_errors
  351. self.decode_responses = decode_responses
  352. self._sock = None
  353. self._parser = parser_class(socket_read_size=socket_read_size)
  354. self._description_args = {
  355. 'host': self.host,
  356. 'port': self.port,
  357. 'db': self.db,
  358. }
  359. self._connect_callbacks = []
  360. def __repr__(self):
  361. return self.description_format % self._description_args
  362. def __del__(self):
  363. try:
  364. self.disconnect()
  365. except Exception:
  366. pass
  367. def register_connect_callback(self, callback):
  368. self._connect_callbacks.append(callback)
  369. def clear_connect_callbacks(self):
  370. self._connect_callbacks = []
  371. def connect(self):
  372. "Connects to the Redis server if not already connected"
  373. if self._sock:
  374. return
  375. try:
  376. sock = self._connect()
  377. except socket.error:
  378. e = sys.exc_info()[1]
  379. raise ConnectionError(self._error_message(e))
  380. self._sock = sock
  381. try:
  382. self.on_connect()
  383. except RedisError:
  384. # clean up after any error in on_connect
  385. self.disconnect()
  386. raise
  387. # run any user callbacks. right now the only internal callback
  388. # is for pubsub channel/pattern resubscription
  389. for callback in self._connect_callbacks:
  390. callback(self)
  391. def _connect(self):
  392. "Create a TCP socket connection"
  393. # we want to mimic what socket.create_connection does to support
  394. # ipv4/ipv6, but we want to set options prior to calling
  395. # socket.connect()
  396. err = None
  397. for res in socket.getaddrinfo(self.host, self.port, 0,
  398. socket.SOCK_STREAM):
  399. family, socktype, proto, canonname, socket_address = res
  400. sock = None
  401. try:
  402. sock = socket.socket(family, socktype, proto)
  403. # TCP_NODELAY
  404. sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  405. # TCP_KEEPALIVE
  406. if self.socket_keepalive:
  407. sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
  408. for k, v in iteritems(self.socket_keepalive_options):
  409. sock.setsockopt(socket.SOL_TCP, k, v)
  410. # set the socket_connect_timeout before we connect
  411. sock.settimeout(self.socket_connect_timeout)
  412. # connect
  413. sock.connect(socket_address)
  414. # set the socket_timeout now that we're connected
  415. sock.settimeout(self.socket_timeout)
  416. return sock
  417. except socket.error as _:
  418. err = _
  419. if sock is not None:
  420. sock.close()
  421. if err is not None:
  422. raise err
  423. raise socket.error("socket.getaddrinfo returned an empty list")
  424. def _error_message(self, exception):
  425. # args for socket.error can either be (errno, "message")
  426. # or just "message"
  427. if len(exception.args) == 1:
  428. return "Error connecting to %s:%s. %s." % \
  429. (self.host, self.port, exception.args[0])
  430. else:
  431. return "Error %s connecting to %s:%s. %s." % \
  432. (exception.args[0], self.host, self.port, exception.args[1])
  433. def on_connect(self):
  434. "Initialize the connection, authenticate and select a database"
  435. self._parser.on_connect(self)
  436. # if a password is specified, authenticate
  437. if self.password:
  438. self.send_command('AUTH', self.password)
  439. if nativestr(self.read_response()) != 'OK':
  440. raise AuthenticationError('Invalid Password')
  441. # if a database is specified, switch to it
  442. if self.db:
  443. self.send_command('SELECT', self.db)
  444. if nativestr(self.read_response()) != 'OK':
  445. raise ConnectionError('Invalid Database')
  446. def disconnect(self):
  447. "Disconnects from the Redis server"
  448. self._parser.on_disconnect()
  449. if self._sock is None:
  450. return
  451. try:
  452. self._sock.shutdown(socket.SHUT_RDWR)
  453. self._sock.close()
  454. except socket.error:
  455. pass
  456. self._sock = None
  457. def send_packed_command(self, command):
  458. "Send an already packed command to the Redis server"
  459. if not self._sock:
  460. self.connect()
  461. try:
  462. if isinstance(command, str):
  463. command = [command]
  464. for item in command:
  465. self._sock.sendall(item)
  466. except socket.timeout:
  467. self.disconnect()
  468. raise TimeoutError("Timeout writing to socket")
  469. except socket.error:
  470. e = sys.exc_info()[1]
  471. self.disconnect()
  472. if len(e.args) == 1:
  473. errno, errmsg = 'UNKNOWN', e.args[0]
  474. else:
  475. errno = e.args[0]
  476. errmsg = e.args[1]
  477. raise ConnectionError("Error %s while writing to socket. %s." %
  478. (errno, errmsg))
  479. except:
  480. self.disconnect()
  481. raise
  482. def send_command(self, *args):
  483. "Pack and send a command to the Redis server"
  484. self.send_packed_command(self.pack_command(*args))
  485. def can_read(self, timeout=0):
  486. "Poll the socket to see if there's data that can be read."
  487. sock = self._sock
  488. if not sock:
  489. self.connect()
  490. sock = self._sock
  491. return self._parser.can_read() or \
  492. bool(select([sock], [], [], timeout)[0])
  493. def read_response(self):
  494. "Read the response from a previously sent command"
  495. try:
  496. response = self._parser.read_response()
  497. except:
  498. self.disconnect()
  499. raise
  500. if isinstance(response, ResponseError):
  501. raise response
  502. return response
  503. def encode(self, value):
  504. "Return a bytestring representation of the value"
  505. if isinstance(value, Token):
  506. return b(value.value)
  507. elif isinstance(value, bytes):
  508. return value
  509. elif isinstance(value, (int, long)):
  510. value = b(str(value))
  511. elif isinstance(value, float):
  512. value = b(repr(value))
  513. elif not isinstance(value, basestring):
  514. value = unicode(value)
  515. if isinstance(value, unicode):
  516. value = value.encode(self.encoding, self.encoding_errors)
  517. return value
  518. def pack_command(self, *args):
  519. "Pack a series of arguments into the Redis protocol"
  520. output = []
  521. # the client might have included 1 or more literal arguments in
  522. # the command name, e.g., 'CONFIG GET'. The Redis server expects these
  523. # arguments to be sent separately, so split the first argument
  524. # manually. All of these arguements get wrapped in the Token class
  525. # to prevent them from being encoded.
  526. command = args[0]
  527. if ' ' in command:
  528. args = tuple([Token(s) for s in command.split(' ')]) + args[1:]
  529. else:
  530. args = (Token(command),) + args[1:]
  531. buff = SYM_EMPTY.join(
  532. (SYM_STAR, b(str(len(args))), SYM_CRLF))
  533. for arg in imap(self.encode, args):
  534. # to avoid large string mallocs, chunk the command into the
  535. # output list if we're sending large values
  536. if len(buff) > 6000 or len(arg) > 6000:
  537. buff = SYM_EMPTY.join(
  538. (buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF))
  539. output.append(buff)
  540. output.append(arg)
  541. buff = SYM_CRLF
  542. else:
  543. buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))),
  544. SYM_CRLF, arg, SYM_CRLF))
  545. output.append(buff)
  546. return output
  547. def pack_commands(self, commands):
  548. "Pack multiple commands into the Redis protocol"
  549. output = []
  550. pieces = []
  551. buffer_length = 0
  552. for cmd in commands:
  553. for chunk in self.pack_command(*cmd):
  554. pieces.append(chunk)
  555. buffer_length += len(chunk)
  556. if buffer_length > 6000:
  557. output.append(SYM_EMPTY.join(pieces))
  558. buffer_length = 0
  559. pieces = []
  560. if pieces:
  561. output.append(SYM_EMPTY.join(pieces))
  562. return output
  563. class SSLConnection(Connection):
  564. description_format = "SSLConnection<host=%(host)s,port=%(port)s,db=%(db)s>"
  565. def __init__(self, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
  566. ssl_ca_certs=None, **kwargs):
  567. if not ssl_available:
  568. raise RedisError("Python wasn't built with SSL support")
  569. super(SSLConnection, self).__init__(**kwargs)
  570. self.keyfile = ssl_keyfile
  571. self.certfile = ssl_certfile
  572. if ssl_cert_reqs is None:
  573. ssl_cert_reqs = ssl.CERT_NONE
  574. elif isinstance(ssl_cert_reqs, basestring):
  575. CERT_REQS = {
  576. 'none': ssl.CERT_NONE,
  577. 'optional': ssl.CERT_OPTIONAL,
  578. 'required': ssl.CERT_REQUIRED
  579. }
  580. if ssl_cert_reqs not in CERT_REQS:
  581. raise RedisError(
  582. "Invalid SSL Certificate Requirements Flag: %s" %
  583. ssl_cert_reqs)
  584. ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
  585. self.cert_reqs = ssl_cert_reqs
  586. self.ca_certs = ssl_ca_certs
  587. def _connect(self):
  588. "Wrap the socket with SSL support"
  589. sock = super(SSLConnection, self)._connect()
  590. sock = ssl.wrap_socket(sock,
  591. cert_reqs=self.cert_reqs,
  592. keyfile=self.keyfile,
  593. certfile=self.certfile,
  594. ca_certs=self.ca_certs)
  595. return sock
  596. class UnixDomainSocketConnection(Connection):
  597. description_format = "UnixDomainSocketConnection<path=%(path)s,db=%(db)s>"
  598. def __init__(self, path='', db=0, password=None,
  599. socket_timeout=None, encoding='utf-8',
  600. encoding_errors='strict', decode_responses=False,
  601. retry_on_timeout=False,
  602. parser_class=DefaultParser, socket_read_size=65536):
  603. self.pid = os.getpid()
  604. self.path = path
  605. self.db = db
  606. self.password = password
  607. self.socket_timeout = socket_timeout
  608. self.retry_on_timeout = retry_on_timeout
  609. self.encoding = encoding
  610. self.encoding_errors = encoding_errors
  611. self.decode_responses = decode_responses
  612. self._sock = None
  613. self._parser = parser_class(socket_read_size=socket_read_size)
  614. self._description_args = {
  615. 'path': self.path,
  616. 'db': self.db,
  617. }
  618. self._connect_callbacks = []
  619. def _connect(self):
  620. "Create a Unix domain socket connection"
  621. sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  622. sock.settimeout(self.socket_timeout)
  623. sock.connect(self.path)
  624. return sock
  625. def _error_message(self, exception):
  626. # args for socket.error can either be (errno, "message")
  627. # or just "message"
  628. if len(exception.args) == 1:
  629. return "Error connecting to unix socket: %s. %s." % \
  630. (self.path, exception.args[0])
  631. else:
  632. return "Error %s connecting to unix socket: %s. %s." % \
  633. (exception.args[0], self.path, exception.args[1])
  634. class ConnectionPool(object):
  635. "Generic connection pool"
  636. @classmethod
  637. def from_url(cls, url, db=None, decode_components=False, **kwargs):
  638. """
  639. Return a connection pool configured from the given URL.
  640. For example::
  641. redis://[:password]@localhost:6379/0
  642. rediss://[:password]@localhost:6379/0
  643. unix://[:password]@/path/to/socket.sock?db=0
  644. Three URL schemes are supported:
  645. redis:// creates a normal TCP socket connection
  646. rediss:// creates a SSL wrapped TCP socket connection
  647. unix:// creates a Unix Domain Socket connection
  648. There are several ways to specify a database number. The parse function
  649. will return the first specified option:
  650. 1. A ``db`` querystring option, e.g. redis://localhost?db=0
  651. 2. If using the redis:// scheme, the path argument of the url, e.g.
  652. redis://localhost/0
  653. 3. The ``db`` argument to this function.
  654. If none of these options are specified, db=0 is used.
  655. The ``decode_components`` argument allows this function to work with
  656. percent-encoded URLs. If this argument is set to ``True`` all ``%xx``
  657. escapes will be replaced by their single-character equivalents after
  658. the URL has been parsed. This only applies to the ``hostname``,
  659. ``path``, and ``password`` components.
  660. Any additional querystring arguments and keyword arguments will be
  661. passed along to the ConnectionPool class's initializer. In the case
  662. of conflicting arguments, querystring arguments always win.
  663. """
  664. url_string = url
  665. url = urlparse(url)
  666. qs = ''
  667. # in python2.6, custom URL schemes don't recognize querystring values
  668. # they're left as part of the url.path.
  669. if '?' in url.path and not url.query:
  670. # chop the querystring including the ? off the end of the url
  671. # and reparse it.
  672. qs = url.path.split('?', 1)[1]
  673. url = urlparse(url_string[:-(len(qs) + 1)])
  674. else:
  675. qs = url.query
  676. url_options = {}
  677. for name, value in iteritems(parse_qs(qs)):
  678. if value and len(value) > 0:
  679. url_options[name] = value[0]
  680. if decode_components:
  681. password = unquote(url.password) if url.password else None
  682. path = unquote(url.path) if url.path else None
  683. hostname = unquote(url.hostname) if url.hostname else None
  684. else:
  685. password = url.password
  686. path = url.path
  687. hostname = url.hostname
  688. # We only support redis:// and unix:// schemes.
  689. if url.scheme == 'unix':
  690. url_options.update({
  691. 'password': password,
  692. 'path': path,
  693. 'connection_class': UnixDomainSocketConnection,
  694. })
  695. else:
  696. url_options.update({
  697. 'host': hostname,
  698. 'port': int(url.port or 6379),
  699. 'password': password,
  700. })
  701. # If there's a path argument, use it as the db argument if a
  702. # querystring value wasn't specified
  703. if 'db' not in url_options and path:
  704. try:
  705. url_options['db'] = int(path.replace('/', ''))
  706. except (AttributeError, ValueError):
  707. pass
  708. if url.scheme == 'rediss':
  709. url_options['connection_class'] = SSLConnection
  710. # last shot at the db value
  711. url_options['db'] = int(url_options.get('db', db or 0))
  712. # update the arguments from the URL values
  713. kwargs.update(url_options)
  714. # backwards compatability
  715. if 'charset' in kwargs:
  716. warnings.warn(DeprecationWarning(
  717. '"charset" is deprecated. Use "encoding" instead'))
  718. kwargs['encoding'] = kwargs.pop('charset')
  719. if 'errors' in kwargs:
  720. warnings.warn(DeprecationWarning(
  721. '"errors" is deprecated. Use "encoding_errors" instead'))
  722. kwargs['encoding_errors'] = kwargs.pop('errors')
  723. return cls(**kwargs)
  724. def __init__(self, connection_class=Connection, max_connections=None,
  725. **connection_kwargs):
  726. """
  727. Create a connection pool. If max_connections is set, then this
  728. object raises redis.ConnectionError when the pool's limit is reached.
  729. By default, TCP connections are created connection_class is specified.
  730. Use redis.UnixDomainSocketConnection for unix sockets.
  731. Any additional keyword arguments are passed to the constructor of
  732. connection_class.
  733. """
  734. max_connections = max_connections or 2 ** 31
  735. if not isinstance(max_connections, (int, long)) or max_connections < 0:
  736. raise ValueError('"max_connections" must be a positive integer')
  737. self.connection_class = connection_class
  738. self.connection_kwargs = connection_kwargs
  739. self.max_connections = max_connections
  740. self.reset()
  741. def __repr__(self):
  742. return "%s<%s>" % (
  743. type(self).__name__,
  744. self.connection_class.description_format % self.connection_kwargs,
  745. )
  746. def reset(self):
  747. self.pid = os.getpid()
  748. self._created_connections = 0
  749. self._available_connections = []
  750. self._in_use_connections = set()
  751. self._check_lock = threading.Lock()
  752. def _checkpid(self):
  753. if self.pid != os.getpid():
  754. with self._check_lock:
  755. if self.pid == os.getpid():
  756. # another thread already did the work while we waited
  757. # on the lock.
  758. return
  759. self.disconnect()
  760. self.reset()
  761. def get_connection(self, command_name, *keys, **options):
  762. "Get a connection from the pool"
  763. self._checkpid()
  764. try:
  765. connection = self._available_connections.pop()
  766. except IndexError:
  767. connection = self.make_connection()
  768. self._in_use_connections.add(connection)
  769. return connection
  770. def make_connection(self):
  771. "Create a new connection"
  772. if self._created_connections >= self.max_connections:
  773. raise ConnectionError("Too many connections")
  774. self._created_connections += 1
  775. return self.connection_class(**self.connection_kwargs)
  776. def release(self, connection):
  777. "Releases the connection back to the pool"
  778. self._checkpid()
  779. if connection.pid != self.pid:
  780. return
  781. self._in_use_connections.remove(connection)
  782. self._available_connections.append(connection)
  783. def disconnect(self):
  784. "Disconnects all connections in the pool"
  785. all_conns = chain(self._available_connections,
  786. self._in_use_connections)
  787. for connection in all_conns:
  788. connection.disconnect()
  789. class BlockingConnectionPool(ConnectionPool):
  790. """
  791. Thread-safe blocking connection pool::
  792. >>> from redis.client import Redis
  793. >>> client = Redis(connection_pool=BlockingConnectionPool())
  794. It performs the same function as the default
  795. ``:py:class: ~redis.connection.ConnectionPool`` implementation, in that,
  796. it maintains a pool of reusable connections that can be shared by
  797. multiple redis clients (safely across threads if required).
  798. The difference is that, in the event that a client tries to get a
  799. connection from the pool when all of connections are in use, rather than
  800. raising a ``:py:class: ~redis.exceptions.ConnectionError`` (as the default
  801. ``:py:class: ~redis.connection.ConnectionPool`` implementation does), it
  802. makes the client wait ("blocks") for a specified number of seconds until
  803. a connection becomes available.
  804. Use ``max_connections`` to increase / decrease the pool size::
  805. >>> pool = BlockingConnectionPool(max_connections=10)
  806. Use ``timeout`` to tell it either how many seconds to wait for a connection
  807. to become available, or to block forever:
  808. # Block forever.
  809. >>> pool = BlockingConnectionPool(timeout=None)
  810. # Raise a ``ConnectionError`` after five seconds if a connection is
  811. # not available.
  812. >>> pool = BlockingConnectionPool(timeout=5)
  813. """
  814. def __init__(self, max_connections=50, timeout=20,
  815. connection_class=Connection, queue_class=LifoQueue,
  816. **connection_kwargs):
  817. self.queue_class = queue_class
  818. self.timeout = timeout
  819. super(BlockingConnectionPool, self).__init__(
  820. connection_class=connection_class,
  821. max_connections=max_connections,
  822. **connection_kwargs)
  823. def reset(self):
  824. self.pid = os.getpid()
  825. self._check_lock = threading.Lock()
  826. # Create and fill up a thread safe queue with ``None`` values.
  827. self.pool = self.queue_class(self.max_connections)
  828. while True:
  829. try:
  830. self.pool.put_nowait(None)
  831. except Full:
  832. break
  833. # Keep a list of actual connection instances so that we can
  834. # disconnect them later.
  835. self._connections = []
  836. def make_connection(self):
  837. "Make a fresh connection."
  838. connection = self.connection_class(**self.connection_kwargs)
  839. self._connections.append(connection)
  840. return connection
  841. def get_connection(self, command_name, *keys, **options):
  842. """
  843. Get a connection, blocking for ``self.timeout`` until a connection
  844. is available from the pool.
  845. If the connection returned is ``None`` then creates a new connection.
  846. Because we use a last-in first-out queue, the existing connections
  847. (having been returned to the pool after the initial ``None`` values
  848. were added) will be returned before ``None`` values. This means we only
  849. create new connections when we need to, i.e.: the actual number of
  850. connections will only increase in response to demand.
  851. """
  852. # Make sure we haven't changed process.
  853. self._checkpid()
  854. # Try and get a connection from the pool. If one isn't available within
  855. # self.timeout then raise a ``ConnectionError``.
  856. connection = None
  857. try:
  858. connection = self.pool.get(block=True, timeout=self.timeout)
  859. except Empty:
  860. # Note that this is not caught by the redis client and will be
  861. # raised unless handled by application code. If you want never to
  862. raise ConnectionError("No connection available.")
  863. # If the ``connection`` is actually ``None`` then that's a cue to make
  864. # a new connection to add to the pool.
  865. if connection is None:
  866. connection = self.make_connection()
  867. return connection
  868. def release(self, connection):
  869. "Releases the connection back to the pool."
  870. # Make sure we haven't changed process.
  871. self._checkpid()
  872. if connection.pid != self.pid:
  873. return
  874. # Put the connection back into the pool.
  875. try:
  876. self.pool.put_nowait(connection)
  877. except Full:
  878. # perhaps the pool has been reset() after a fork? regardless,
  879. # we don't want this connection
  880. pass
  881. def disconnect(self):
  882. "Disconnects all connections in the pool."
  883. for connection in self._connections:
  884. connection.disconnect()