ares.pyx 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. # Copyright (c) 2011-2012 Denis Bilenko. See LICENSE for details.
  2. cimport cares
  3. import sys
  4. from python cimport *
  5. from _socket import gaierror
  6. __all__ = ['channel']
  7. cdef object string_types
  8. cdef object text_type
  9. if sys.version_info[0] >= 3:
  10. string_types = str,
  11. text_type = str
  12. else:
  13. string_types = __builtins__.basestring,
  14. text_type = __builtins__.unicode
  15. TIMEOUT = 1
  16. DEF EV_READ = 1
  17. DEF EV_WRITE = 2
# C-level declarations taken from dnshelper.c, which this extern block includes.
cdef extern from "dnshelper.c":
    # Address family constants.
    int AF_INET
    int AF_INET6

    # Only the hostent fields read directly in this file are declared.
    struct hostent:
        char* h_name
        int h_addrtype

    # Opaque C structs; this file only uses them through pointers.
    struct sockaddr_t "sockaddr":
        pass

    struct ares_channeldata:
        pass

    # hostent -> Python object helpers. gevent_ares_host_callback packs the three
    # parse_h_* results into the (name, aliases, addresses) tuple of ares_host_result.
    object parse_h_name(hostent*)
    object parse_h_aliases(hostent*)
    object parse_h_addr_list(hostent*)
    # Not referenced in the code visible in this file.
    void* create_object_from_hostent(void*)

    # this imports _socket lazily
    # NOTE(review): which declaration the comment above refers to is unclear from
    # this file (presumably the dnshelper.c hostent helpers) — confirm.
    object PyUnicode_FromString(char*)
    int PyTuple_Check(object)
    int PyArg_ParseTuple(object, char*, ...) except 0
    struct sockaddr_in6:
        pass
    # Presumably fills *sa6* from host/port/flowinfo/scope_id and returns the
    # sockaddr length; channel._getnameinfo treats a result <= 0 as an invalid host.
    int gevent_make_sockaddr(char* hostp, int port, int flowinfo, int scope_id, sockaddr_in6* sa6)

    # C memory helpers (note: sizes are declared as int here, not size_t).
    void* malloc(int)
    void free(void*)
    void memset(void*, int, int)
# c-ares status codes, re-exported as module-level names (these are the codes
# carried by the gaierror instances raised/delivered by this module).
ARES_SUCCESS = cares.ARES_SUCCESS
ARES_ENODATA = cares.ARES_ENODATA
ARES_EFORMERR = cares.ARES_EFORMERR
ARES_ESERVFAIL = cares.ARES_ESERVFAIL
ARES_ENOTFOUND = cares.ARES_ENOTFOUND
ARES_ENOTIMP = cares.ARES_ENOTIMP
ARES_EREFUSED = cares.ARES_EREFUSED
ARES_EBADQUERY = cares.ARES_EBADQUERY
ARES_EBADNAME = cares.ARES_EBADNAME
ARES_EBADFAMILY = cares.ARES_EBADFAMILY
ARES_EBADRESP = cares.ARES_EBADRESP
ARES_ECONNREFUSED = cares.ARES_ECONNREFUSED
ARES_ETIMEOUT = cares.ARES_ETIMEOUT
ARES_EOF = cares.ARES_EOF
ARES_EFILE = cares.ARES_EFILE
ARES_ENOMEM = cares.ARES_ENOMEM
ARES_EDESTRUCTION = cares.ARES_EDESTRUCTION
ARES_EBADSTR = cares.ARES_EBADSTR
ARES_EBADFLAGS = cares.ARES_EBADFLAGS
ARES_ENONAME = cares.ARES_ENONAME
ARES_EBADHINTS = cares.ARES_EBADHINTS
ARES_ENOTINITIALIZED = cares.ARES_ENOTINITIALIZED
ARES_ELOADIPHLPAPI = cares.ARES_ELOADIPHLPAPI
ARES_EADDRGETNETWORKPARAMS = cares.ARES_EADDRGETNETWORKPARAMS
ARES_ECANCELLED = cares.ARES_ECANCELLED
# c-ares option flags; a combination of these can be passed as the ``flags``
# argument of channel(), which forwards it with ARES_OPT_FLAGS.
ARES_FLAG_USEVC = cares.ARES_FLAG_USEVC
ARES_FLAG_PRIMARY = cares.ARES_FLAG_PRIMARY
ARES_FLAG_IGNTC = cares.ARES_FLAG_IGNTC
ARES_FLAG_NORECURSE = cares.ARES_FLAG_NORECURSE
ARES_FLAG_STAYOPEN = cares.ARES_FLAG_STAYOPEN
ARES_FLAG_NOSEARCH = cares.ARES_FLAG_NOSEARCH
ARES_FLAG_NOALIASES = cares.ARES_FLAG_NOALIASES
ARES_FLAG_NOCHECKRESP = cares.ARES_FLAG_NOCHECKRESP
  75. _ares_errors = dict([
  76. (cares.ARES_SUCCESS, 'ARES_SUCCESS'),
  77. (cares.ARES_ENODATA, 'ARES_ENODATA'),
  78. (cares.ARES_EFORMERR, 'ARES_EFORMERR'),
  79. (cares.ARES_ESERVFAIL, 'ARES_ESERVFAIL'),
  80. (cares.ARES_ENOTFOUND, 'ARES_ENOTFOUND'),
  81. (cares.ARES_ENOTIMP, 'ARES_ENOTIMP'),
  82. (cares.ARES_EREFUSED, 'ARES_EREFUSED'),
  83. (cares.ARES_EBADQUERY, 'ARES_EBADQUERY'),
  84. (cares.ARES_EBADNAME, 'ARES_EBADNAME'),
  85. (cares.ARES_EBADFAMILY, 'ARES_EBADFAMILY'),
  86. (cares.ARES_EBADRESP, 'ARES_EBADRESP'),
  87. (cares.ARES_ECONNREFUSED, 'ARES_ECONNREFUSED'),
  88. (cares.ARES_ETIMEOUT, 'ARES_ETIMEOUT'),
  89. (cares.ARES_EOF, 'ARES_EOF'),
  90. (cares.ARES_EFILE, 'ARES_EFILE'),
  91. (cares.ARES_ENOMEM, 'ARES_ENOMEM'),
  92. (cares.ARES_EDESTRUCTION, 'ARES_EDESTRUCTION'),
  93. (cares.ARES_EBADSTR, 'ARES_EBADSTR'),
  94. (cares.ARES_EBADFLAGS, 'ARES_EBADFLAGS'),
  95. (cares.ARES_ENONAME, 'ARES_ENONAME'),
  96. (cares.ARES_EBADHINTS, 'ARES_EBADHINTS'),
  97. (cares.ARES_ENOTINITIALIZED, 'ARES_ENOTINITIALIZED'),
  98. (cares.ARES_ELOADIPHLPAPI, 'ARES_ELOADIPHLPAPI'),
  99. (cares.ARES_EADDRGETNETWORKPARAMS, 'ARES_EADDRGETNETWORKPARAMS'),
  100. (cares.ARES_ECANCELLED, 'ARES_ECANCELLED')])
  101. # maps c-ares flag to _socket module flag
  102. _cares_flag_map = None
  103. cdef _prepare_cares_flag_map():
  104. global _cares_flag_map
  105. import _socket
  106. _cares_flag_map = [
  107. (getattr(_socket, 'NI_NUMERICHOST', 1), cares.ARES_NI_NUMERICHOST),
  108. (getattr(_socket, 'NI_NUMERICSERV', 2), cares.ARES_NI_NUMERICSERV),
  109. (getattr(_socket, 'NI_NOFQDN', 4), cares.ARES_NI_NOFQDN),
  110. (getattr(_socket, 'NI_NAMEREQD', 8), cares.ARES_NI_NAMEREQD),
  111. (getattr(_socket, 'NI_DGRAM', 16), cares.ARES_NI_DGRAM)]
  112. cpdef _convert_cares_flags(int flags, int default=cares.ARES_NI_LOOKUPHOST|cares.ARES_NI_LOOKUPSERVICE):
  113. if _cares_flag_map is None:
  114. _prepare_cares_flag_map()
  115. for socket_flag, cares_flag in _cares_flag_map:
  116. if socket_flag & flags:
  117. default |= cares_flag
  118. flags &= ~socket_flag
  119. if not flags:
  120. return default
  121. raise gaierror(-1, "Bad value for ai_flags: 0x%x" % flags)
  122. cpdef strerror(code):
  123. return '%s: %s' % (_ares_errors.get(code) or code, cares.ares_strerror(code))
  124. class InvalidIP(ValueError):
  125. pass
  126. cdef void gevent_sock_state_callback(void *data, int s, int read, int write):
  127. if not data:
  128. return
  129. cdef channel ch = <channel>data
  130. ch._sock_state_callback(s, read, write)
  131. cdef class result:
  132. cdef public object value
  133. cdef public object exception
  134. def __init__(self, object value=None, object exception=None):
  135. self.value = value
  136. self.exception = exception
  137. def __repr__(self):
  138. if self.exception is None:
  139. return '%s(%r)' % (self.__class__.__name__, self.value)
  140. elif self.value is None:
  141. return '%s(exception=%r)' % (self.__class__.__name__, self.exception)
  142. else:
  143. return '%s(value=%r, exception=%r)' % (self.__class__.__name__, self.value, self.exception)
  144. # add repr_recursive precaution
  145. def successful(self):
  146. return self.exception is None
  147. def get(self):
  148. if self.exception is not None:
  149. raise self.exception
  150. return self.value
  151. class ares_host_result(tuple):
  152. def __new__(cls, family, iterable):
  153. cdef object self = tuple.__new__(cls, iterable)
  154. self.family = family
  155. return self
  156. def __getnewargs__(self):
  157. return (self.family, tuple(self))
cdef void gevent_ares_host_callback(void *arg, int status, int timeouts, hostent* host):
    # c-ares completion callback shared by channel.gethostbyname/gethostbyaddr.
    # *arg* is the (channel, callback) tuple that the issuing method INCREF'd.
    # The reference is released right after unpacking; the typed locals below
    # hold their own references to both items.
    # The user callback always receives a ``result``:
    #   - gaierror(status, strerror(status)) when status is non-zero or host is NULL,
    #   - the conversion error if the hostent could not be turned into Python objects,
    #   - otherwise ares_host_result(h_addrtype, (name, aliases, addresses)).
    # Exceptions raised by the callback itself go to channel.loop.handle_error,
    # because a void C callback cannot propagate them.
    cdef channel channel
    cdef object callback
    channel, callback = <tuple>arg
    Py_DECREF(<PyObjectPtr>arg)
    cdef object host_result
    try:
        if status or not host:
            callback(result(None, gaierror(status, strerror(status))))
        else:
            try:
                host_result = ares_host_result(host.h_addrtype, (parse_h_name(host), parse_h_aliases(host), parse_h_addr_list(host)))
            except:
                # Deliver parsing failures to the caller rather than the loop.
                callback(result(None, sys.exc_info()[1]))
            else:
                callback(result(host_result))
    except:
        channel.loop.handle_error(callback, *sys.exc_info())
  176. cdef void gevent_ares_nameinfo_callback(void *arg, int status, int timeouts, char *c_node, char *c_service):
  177. cdef channel channel
  178. cdef object callback
  179. channel, callback = <tuple>arg
  180. Py_DECREF(<PyObjectPtr>arg)
  181. cdef object node
  182. cdef object service
  183. try:
  184. if status:
  185. callback(result(None, gaierror(status, strerror(status))))
  186. else:
  187. if c_node:
  188. node = PyUnicode_FromString(c_node)
  189. else:
  190. node = None
  191. if c_service:
  192. service = PyUnicode_FromString(c_service)
  193. else:
  194. service = None
  195. callback(result((node, service)))
  196. except:
  197. channel.loop.handle_error(callback, *sys.exc_info())
  198. cdef public class channel [object PyGeventAresChannelObject, type PyGeventAresChannel_Type]:
  199. cdef public object loop
  200. cdef ares_channeldata* channel
  201. cdef public dict _watchers
  202. cdef public object _timer
  203. def __init__(self, object loop, flags=None, timeout=None, tries=None, ndots=None,
  204. udp_port=None, tcp_port=None, servers=None):
  205. cdef ares_channeldata* channel = NULL
  206. cdef cares.ares_options options
  207. memset(&options, 0, sizeof(cares.ares_options))
  208. cdef int optmask = cares.ARES_OPT_SOCK_STATE_CB
  209. options.sock_state_cb = <void*>gevent_sock_state_callback
  210. options.sock_state_cb_data = <void*>self
  211. if flags is not None:
  212. options.flags = int(flags)
  213. optmask |= cares.ARES_OPT_FLAGS
  214. if timeout is not None:
  215. options.timeout = int(float(timeout) * 1000)
  216. optmask |= cares.ARES_OPT_TIMEOUTMS
  217. if tries is not None:
  218. options.tries = int(tries)
  219. optmask |= cares.ARES_OPT_TRIES
  220. if ndots is not None:
  221. options.ndots = int(ndots)
  222. optmask |= cares.ARES_OPT_NDOTS
  223. if udp_port is not None:
  224. options.udp_port = int(udp_port)
  225. optmask |= cares.ARES_OPT_UDP_PORT
  226. if tcp_port is not None:
  227. options.tcp_port = int(tcp_port)
  228. optmask |= cares.ARES_OPT_TCP_PORT
  229. cdef int result = cares.ares_library_init(cares.ARES_LIB_INIT_ALL) # ARES_LIB_INIT_WIN32 -DUSE_WINSOCK?
  230. if result:
  231. raise gaierror(result, strerror(result))
  232. result = cares.ares_init_options(&channel, &options, optmask)
  233. if result:
  234. raise gaierror(result, strerror(result))
  235. self._timer = loop.timer(TIMEOUT, TIMEOUT)
  236. self._watchers = {}
  237. self.channel = channel
  238. try:
  239. if servers is not None:
  240. self.set_servers(servers)
  241. self.loop = loop
  242. except:
  243. self.destroy()
  244. raise
  245. def __repr__(self):
  246. args = (self.__class__.__name__, id(self), self._timer, len(self._watchers))
  247. return '<%s at 0x%x _timer=%r _watchers[%s]>' % args
  248. def destroy(self):
  249. if self.channel:
  250. # XXX ares_library_cleanup?
  251. cares.ares_destroy(self.channel)
  252. self.channel = NULL
  253. self._watchers.clear()
  254. self._timer.stop()
  255. self.loop = None
  256. def __dealloc__(self):
  257. if self.channel:
  258. # XXX ares_library_cleanup?
  259. cares.ares_destroy(self.channel)
  260. self.channel = NULL
  261. def set_servers(self, servers=None):
  262. if not self.channel:
  263. raise gaierror(cares.ARES_EDESTRUCTION, 'this ares channel has been destroyed')
  264. if not servers:
  265. servers = []
  266. if isinstance(servers, string_types):
  267. servers = servers.split(',')
  268. cdef int length = len(servers)
  269. cdef int result, index
  270. cdef char* string
  271. cdef cares.ares_addr_node* c_servers
  272. if length <= 0:
  273. result = cares.ares_set_servers(self.channel, NULL)
  274. else:
  275. c_servers = <cares.ares_addr_node*>malloc(sizeof(cares.ares_addr_node) * length)
  276. if not c_servers:
  277. raise MemoryError
  278. try:
  279. index = 0
  280. for server in servers:
  281. if isinstance(server, unicode):
  282. server = server.encode('ascii')
  283. string = <char*?>server
  284. if cares.ares_inet_pton(AF_INET, string, &c_servers[index].addr) > 0:
  285. c_servers[index].family = AF_INET
  286. elif cares.ares_inet_pton(AF_INET6, string, &c_servers[index].addr) > 0:
  287. c_servers[index].family = AF_INET6
  288. else:
  289. raise InvalidIP(repr(string))
  290. c_servers[index].next = &c_servers[index] + 1
  291. index += 1
  292. if index >= length:
  293. break
  294. c_servers[length - 1].next = NULL
  295. index = cares.ares_set_servers(self.channel, c_servers)
  296. if index:
  297. raise ValueError(strerror(index))
  298. finally:
  299. free(c_servers)
  300. # this crashes c-ares
  301. #def cancel(self):
  302. # cares.ares_cancel(self.channel)
  303. cdef _sock_state_callback(self, int socket, int read, int write):
  304. if not self.channel:
  305. return
  306. cdef object watcher = self._watchers.get(socket)
  307. cdef int events = 0
  308. if read:
  309. events |= EV_READ
  310. if write:
  311. events |= EV_WRITE
  312. if watcher is None:
  313. if not events:
  314. return
  315. watcher = self.loop.io(socket, events)
  316. self._watchers[socket] = watcher
  317. elif events:
  318. if watcher.events == events:
  319. return
  320. watcher.stop()
  321. watcher.events = events
  322. else:
  323. watcher.stop()
  324. self._watchers.pop(socket, None)
  325. if not self._watchers:
  326. self._timer.stop()
  327. return
  328. watcher.start(self._process_fd, watcher, pass_events=True)
  329. self._timer.again(self._on_timer)
  330. def _on_timer(self):
  331. cares.ares_process_fd(self.channel, cares.ARES_SOCKET_BAD, cares.ARES_SOCKET_BAD)
  332. def _process_fd(self, int events, object watcher):
  333. if not self.channel:
  334. return
  335. cdef int read_fd = watcher.fd
  336. cdef int write_fd = read_fd
  337. if not (events & EV_READ):
  338. read_fd = cares.ARES_SOCKET_BAD
  339. if not (events & EV_WRITE):
  340. write_fd = cares.ARES_SOCKET_BAD
  341. cares.ares_process_fd(self.channel, read_fd, write_fd)
  342. def gethostbyname(self, object callback, char* name, int family=AF_INET):
  343. if not self.channel:
  344. raise gaierror(cares.ARES_EDESTRUCTION, 'this ares channel has been destroyed')
  345. # note that for file lookups still AF_INET can be returned for AF_INET6 request
  346. cdef object arg = (self, callback)
  347. Py_INCREF(<PyObjectPtr>arg)
  348. cares.ares_gethostbyname(self.channel, name, family, <void*>gevent_ares_host_callback, <void*>arg)
  349. def gethostbyaddr(self, object callback, char* addr):
  350. if not self.channel:
  351. raise gaierror(cares.ARES_EDESTRUCTION, 'this ares channel has been destroyed')
  352. # will guess the family
  353. cdef char addr_packed[16]
  354. cdef int family
  355. cdef int length
  356. if cares.ares_inet_pton(AF_INET, addr, addr_packed) > 0:
  357. family = AF_INET
  358. length = 4
  359. elif cares.ares_inet_pton(AF_INET6, addr, addr_packed) > 0:
  360. family = AF_INET6
  361. length = 16
  362. else:
  363. raise InvalidIP(repr(addr))
  364. cdef object arg = (self, callback)
  365. Py_INCREF(<PyObjectPtr>arg)
  366. cares.ares_gethostbyaddr(self.channel, addr_packed, length, family, <void*>gevent_ares_host_callback, <void*>arg)
  367. cpdef _getnameinfo(self, object callback, tuple sockaddr, int flags):
  368. if not self.channel:
  369. raise gaierror(cares.ARES_EDESTRUCTION, 'this ares channel has been destroyed')
  370. cdef char* hostp = NULL
  371. cdef int port = 0
  372. cdef int flowinfo = 0
  373. cdef int scope_id = 0
  374. cdef sockaddr_in6 sa6
  375. if not PyTuple_Check(sockaddr):
  376. raise TypeError('expected a tuple, got %r' % (sockaddr, ))
  377. PyArg_ParseTuple(sockaddr, "si|ii", &hostp, &port, &flowinfo, &scope_id)
  378. if port < 0 or port > 65535:
  379. raise gaierror(-8, 'Invalid value for port: %r' % port)
  380. cdef int length = gevent_make_sockaddr(hostp, port, flowinfo, scope_id, &sa6)
  381. if length <= 0:
  382. raise InvalidIP(repr(hostp))
  383. cdef object arg = (self, callback)
  384. Py_INCREF(<PyObjectPtr>arg)
  385. cdef sockaddr_t* x = <sockaddr_t*>&sa6
  386. cares.ares_getnameinfo(self.channel, x, length, flags, <void*>gevent_ares_nameinfo_callback, <void*>arg)
  387. def getnameinfo(self, object callback, tuple sockaddr, int flags):
  388. try:
  389. flags = _convert_cares_flags(flags)
  390. except gaierror:
  391. # The stdlib just ignores bad flags
  392. flags = 0
  393. return self._getnameinfo(callback, sockaddr, flags)