sreg.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518
  """Simple registration request and response parsing and object representation

  This module contains objects representing simple registration requests
  and responses that can be used with both OpenID relying parties and
  OpenID providers.

    1. The relying party creates a request object and adds it to the
       C{L{AuthRequest<openid.consumer.consumer.AuthRequest>}} object
       before making the C{checkid_} request to the OpenID provider::

          auth_request.addExtension(SRegRequest(required=['email']))

    2. The OpenID provider extracts the simple registration request from
       the OpenID request using C{L{SRegRequest.fromOpenIDRequest}},
       gets the user's approval and data, creates a C{L{SRegResponse}}
       object and adds it to the C{id_res} response::

          sreg_req = SRegRequest.fromOpenIDRequest(checkid_request)
          # [ get the user's approval and data, informing the user that
          #   the fields in sreg_req were requested ]
          sreg_resp = SRegResponse.extractResponse(sreg_req, user_data)
          sreg_resp.toMessage(openid_response.fields)

    3. The relying party uses C{L{SRegResponse.fromSuccessResponse}} to
       extract the data from the OpenID response::

          sreg_resp = SRegResponse.fromSuccessResponse(success_response)

  @since: 2.0

  @var data_fields: The names of the data fields that are listed in
      the sreg spec, and a description of them in English

  @var ns_uri: The preferred URI to use for the simple registration
      namespace and XRD Type value
  """
  27. from openid.message import registerNamespaceAlias, \
  28. NamespaceAliasRegistrationError
  29. from openid.extension import Extension
  30. from openid import oidutil
  # Make sure a ``basestring`` name exists. SRegRequest.requestFields uses
  # it with isinstance() to reject a single string passed as a field list.
  try:
      # Evaluated only so that NameError is raised when the builtin is missing.
      basestring #pylint:disable-msg=W0104
  except NameError:
      # For Python 2.2, which has no basestring builtin: a tuple of both
      # string types is accepted as the second argument of isinstance().
      basestring = (str, unicode) #pylint:disable-msg=W0622
  # Public API of this module, i.e. the names bound by a wildcard import.
  __all__ = ['SRegRequest', 'SRegResponse', 'data_fields',
             'ns_uri', 'ns_uri_1_0', 'ns_uri_1_1', 'supportsSReg']
  # Every field name defined by the sreg spec, mapped to an English
  # description of that field.
  data_fields = {
      'fullname': 'Full Name',
      'nickname': 'Nickname',
      'dob': 'Date of Birth',
      'email': 'E-mail Address',
      'gender': 'Gender',
      'postcode': 'Postal Code',
      'country': 'Country',
      'language': 'Language',
      'timezone': 'Time Zone',
  }


  def checkFieldName(field_name):
      """Verify that C{field_name} names a simple registration field.

      @param field_name: candidate field name; the valid names are the
          keys of C{data_fields}
      @returns: None
      @raise ValueError: if C{field_name} is not a key of C{data_fields}
      """
      if field_name in data_fields:
          return
      message = '%r is not a defined simple registration field' % (field_name,)
      raise ValueError(message)
  # URI used in the wild for Yadis documents advertising simple
  # registration support
  ns_uri_1_0 = 'http://openid.net/sreg/1.0'

  # URI in the draft specification for simple registration 1.1
  # <http://openid.net/specs/openid-simple-registration-extension-1_1-01.html>
  ns_uri_1_1 = 'http://openid.net/extensions/sreg/1.1'

  # This attribute will always hold the preferred URI to use when adding
  # sreg support to an XRDS file or in an OpenID namespace declaration.
  ns_uri = ns_uri_1_1

  # Runs at import time: ask openid.message to associate the 'sreg' alias
  # with the 1.1 URI. Only the 1.1 URI is registered. A registration
  # failure is not fatal; it is logged and the import continues.
  try:
      registerNamespaceAlias(ns_uri_1_1, 'sreg')
  except NamespaceAliasRegistrationError, e:
      oidutil.log('registerNamespaceAlias(%r, %r) failed: %s' % (ns_uri_1_1,
                                                                 'sreg', str(e),))
  def supportsSReg(endpoint):
      """Report whether C{endpoint} advertises simple registration support.

      The 1.1 type URI is checked first; the 1.0 URI is consulted only
      when the 1.1 check is false.

      @param endpoint: The endpoint object as returned by OpenID discovery
      @type endpoint: openid.consumer.discover.OpenIDEndpoint

      @returns: the result of C{endpoint.usesExtension} for the 1.1 URI
          when it is true, otherwise the result for the 1.0 URI
      @rtype: bool
      """
      supported = endpoint.usesExtension(ns_uri_1_1)
      if not supported:
          supported = endpoint.usesExtension(ns_uri_1_0)
      return supported
  class SRegNamespaceError(ValueError):
      """Raised when a message declares no simple registration namespace
      and one cannot be declared under the expected alias, because the
      alias 'sreg' is already bound to some other namespace.

      Under OpenID 2 this is not I{illegal}, but it probably indicates a
      problem: other extensions are not expected to re-use the alias
      that is in use for OpenID 1.

      For an OpenID 1 request there is no recourse. This should not
      happen unless some code has modified the namespaces of the message
      that is being processed.
      """
      pass
  101. def getSRegNS(message):
  102. """Extract the simple registration namespace URI from the given
  103. OpenID message. Handles OpenID 1 and 2, as well as both sreg
  104. namespace URIs found in the wild, as well as missing namespace
  105. definitions (for OpenID 1)
  106. @param message: The OpenID message from which to parse simple
  107. registration fields. This may be a request or response message.
  108. @type message: C{L{openid.message.Message}}
  109. @returns: the sreg namespace URI for the supplied message. The
  110. message may be modified to define a simple registration
  111. namespace.
  112. @rtype: C{str}
  113. @raise ValueError: when using OpenID 1 if the message defines
  114. the 'sreg' alias to be something other than a simple
  115. registration type.
  116. """
  117. # See if there exists an alias for one of the two defined simple
  118. # registration types.
  119. for sreg_ns_uri in [ns_uri_1_1, ns_uri_1_0]:
  120. alias = message.namespaces.getAlias(sreg_ns_uri)
  121. if alias is not None:
  122. break
  123. else:
  124. # There is no alias for either of the types, so try to add
  125. # one. We default to using the modern value (1.1)
  126. sreg_ns_uri = ns_uri_1_1
  127. try:
  128. message.namespaces.addAlias(ns_uri_1_1, 'sreg')
  129. except KeyError, why:
  130. # An alias for the string 'sreg' already exists, but it's
  131. # defined for something other than simple registration
  132. raise SRegNamespaceError(why[0])
  133. # we know that sreg_ns_uri defined, because it's defined in the
  134. # else clause of the loop as well, so disable the warning
  135. return sreg_ns_uri #pylint:disable-msg=W0631
  class SRegRequest(Extension):
      """An object to hold the state of a simple registration request.

      @ivar required: A list of the required fields in this simple
          registration request
      @type required: [str]

      @ivar optional: A list of the optional fields in this simple
          registration request
      @type optional: [str]

      @ivar policy_url: The policy URL that was provided with the request
      @type policy_url: str or NoneType

      @group Consumer: requestField, requestFields, getExtensionArgs, addToOpenIDRequest
      @group Server: fromOpenIDRequest, parseExtensionArgs
      """

      ns_alias = 'sreg'

      def __init__(self, required=None, optional=None, policy_url=None,
                   sreg_ns_uri=ns_uri):
          """Initialize a simple registration request.

          @param required: field names to request as required, or None
          @type required: [str] or NoneType

          @param optional: field names to request as optional, or None
          @type optional: [str] or NoneType

          @param policy_url: the policy URL to send with the request
          @type policy_url: str or NoneType

          @param sreg_ns_uri: the sreg namespace URI to use; defaults to
              the module's preferred C{ns_uri}

          @raise ValueError: if a field name is not a simple registration
              field, or a field appears more than once across
              C{required} and C{optional}
          @raise TypeError: if C{required} or C{optional} is a non-empty
              string rather than a list of strings
          """
          Extension.__init__(self)
          self.required = []
          self.optional = []
          self.policy_url = policy_url
          self.ns_uri = sreg_ns_uri

          # strict=True: a field repeated within or across the two lists
          # raises ValueError instead of being silently merged.
          if required:
              self.requestFields(required, required=True, strict=True)
          if optional:
              self.requestFields(optional, required=False, strict=True)

      # Assign getSRegNS to a static method so that it can be
      # overridden for testing.
      _getSRegNS = staticmethod(getSRegNS)

      def fromOpenIDRequest(cls, request):
          """Create a simple registration request that contains the
          fields that were requested in the OpenID request.

          The arguments are parsed non-strictly, so field names that the
          simple registration specification does not define are ignored.
          The request's message is copied first, so C{request} itself is
          not modified.

          @param request: The OpenID request
          @type request: openid.server.CheckIDRequest

          @returns: The newly created simple registration request
          @rtype: C{L{SRegRequest}}

          @raise SRegNamespaceError: propagated from C{_getSRegNS} when
              no sreg namespace is declared and the 'sreg' alias is taken
          """
          self = cls()

          # Since we're going to mess with namespace URI mapping, don't
          # mutate the object that was passed in.
          message = request.message.copy()

          self.ns_uri = self._getSRegNS(message)
          args = message.getArgs(self.ns_uri)
          self.parseExtensionArgs(args)

          return self

      fromOpenIDRequest = classmethod(fromOpenIDRequest)

      def parseExtensionArgs(self, args, strict=False):
          """Parse the unqualified simple registration request
          parameters and add them to this object.

          This method is essentially the inverse of
          C{L{getExtensionArgs}}. This method restores the serialized simple
          registration request fields.

          If you are extracting arguments from a standard OpenID
          checkid_* request, you probably want to use C{L{fromOpenIDRequest}},
          which will extract the sreg namespace and arguments from the
          OpenID request. This method is intended for cases where the
          OpenID server needs more control over how the arguments are
          parsed than that method provides. For example::

              args = message.getArgs(ns_uri)
              request.parseExtensionArgs(args)

          @param args: The unqualified simple registration arguments
          @type args: {str:str}

          @param strict: If true, raise C{ValueError} for a field name
              that is not defined in the simple registration
              specification, or for a field that has already been
              requested. If false, undefined field names are ignored,
              and a repeated field is kept only once (an optional field
              repeated as required becomes required).
          @type strict: bool

          @returns: None; updates this object. C{policy_url} is always
              overwritten with the value from C{args}, or None if absent.
          """
          for list_name in ['required', 'optional']:
              required = (list_name == 'required')
              items = args.get(list_name)
              if items:
                  for field_name in items.split(','):
                      try:
                          self.requestField(field_name, required, strict)
                      except ValueError:
                          # In non-strict mode, field names rejected by
                          # checkFieldName are skipped.
                          if strict:
                              raise

          self.policy_url = args.get('policy_url')

      def allRequestedFields(self):
          """A list of all of the simple registration fields that were
          requested, whether they were required or optional.

          Required fields come first. The result is a new list.

          @rtype: [str]
          """
          return self.required + self.optional

      def wereFieldsRequested(self):
          """Have any simple registration fields been requested?

          @rtype: bool
          """
          return bool(self.allRequestedFields())

      def __contains__(self, field_name):
          """Was this field in the request, as required or optional?

          Unlike C{requestField}, this does not validate the field name.
          """
          return (field_name in self.required or
                  field_name in self.optional)

      def requestField(self, field_name, required=False, strict=False):
          """Request the specified field from the OpenID user

          @param field_name: the unqualified simple registration field name
          @type field_name: str

          @param required: whether the given field should be presented
              to the user as being required to successfully complete
              the request

          @param strict: whether to raise an exception when a field is
              added to a request more than once

          @raise ValueError: when the field requested is not a simple
              registration field or strict is set and the field was
              requested more than once
          """
          checkFieldName(field_name)

          if strict:
              if field_name in self.required or field_name in self.optional:
                  raise ValueError('That field has already been requested')
          else:
              if field_name in self.required:
                  # Already required; requesting it again (even as
                  # optional) never downgrades it.
                  return

              if field_name in self.optional:
                  if required:
                      # Upgrade from optional to required.
                      self.optional.remove(field_name)
                  else:
                      return

          if required:
              self.required.append(field_name)
          else:
              self.optional.append(field_name)

      def requestFields(self, field_names, required=False, strict=False):
          """Add the given list of fields to the request

          Fields are added one at a time. If one of them raises
          C{ValueError}, the fields before it remain requested.

          @param field_names: The simple registration data fields to request
          @type field_names: [str]

          @param required: Whether these values should be presented to
              the user as required

          @param strict: whether to raise an exception when a field is
              added to a request more than once

          @raise TypeError: if C{field_names} is a single string (which
              would otherwise be iterated character by character)
          @raise ValueError: when a field requested is not a simple
              registration field or strict is set and a field was
              requested more than once
          """
          if isinstance(field_names, basestring):
              raise TypeError('Fields should be passed as a list of '
                              'strings (not %r)' % (type(field_names),))

          for field_name in field_names:
              self.requestField(field_name, required, strict=strict)

      def getExtensionArgs(self):
          """Get a dictionary of unqualified simple registration
          arguments representing this request.

          This method is essentially the inverse of
          C{L{parseExtensionArgs}}. This method serializes the simple
          registration request fields. Empty field lists and an empty or
          missing policy URL are left out of the result.

          @rtype: {str:str}
          """
          args = {}

          if self.required:
              args['required'] = ','.join(self.required)

          if self.optional:
              args['optional'] = ','.join(self.optional)

          if self.policy_url:
              args['policy_url'] = self.policy_url

          return args
  class SRegResponse(Extension):
      """Represents the data returned in a simple registration response
      inside of an OpenID C{id_res} response. This object will be
      created by the OpenID server, added to the C{id_res} response
      object, and then extracted from the C{id_res} message by the
      Consumer.

      @ivar data: The simple registration data, keyed by the unqualified
          simple registration name of the field (i.e. nickname is keyed
          by C{'nickname'})

      @ivar ns_uri: The URI under which the simple registration data was
          stored in the response message.

      @group Server: extractResponse
      @group Consumer: fromSuccessResponse
      @group Read-only dictionary interface: keys, iterkeys, items, iteritems,
          __iter__, get, __getitem__, __contains__, has_key
      """

      ns_alias = 'sreg'

      def __init__(self, data=None, sreg_ns_uri=ns_uri):
          """Initialize a response holding C{data}.

          @param data: initial sreg data, or None for an empty response.
              A dict passed here is stored as-is, not copied.
          @type data: {str:str} or NoneType

          @param sreg_ns_uri: the sreg namespace URI to use; defaults to
              the module's preferred C{ns_uri}
          """
          Extension.__init__(self)
          if data is None:
              self.data = {}
          else:
              self.data = data

          self.ns_uri = sreg_ns_uri

      def extractResponse(cls, request, data):
          """Take a C{L{SRegRequest}} and a dictionary of simple
          registration values and create a C{L{SRegResponse}}
          object containing that data.

          Only fields that the request asked for (required or optional)
          and whose value in C{data} is not None are copied. The
          response uses the request's namespace URI.

          @param request: The simple registration request object
          @type request: SRegRequest

          @param data: The simple registration data for this
              response, as a dictionary from unqualified simple
              registration field name to string (unicode) value. For
              instance, the nickname should be stored under the key
              'nickname'.
          @type data: {str:str}

          @returns: a simple registration response object
          @rtype: SRegResponse
          """
          self = cls()
          self.ns_uri = request.ns_uri
          for field in request.allRequestedFields():
              value = data.get(field)
              if value is not None:
                  self.data[field] = value
          return self

      extractResponse = classmethod(extractResponse)

      # Assign getSRegNS to a static method so that it can be
      # overridden for testing
      _getSRegNS = staticmethod(getSRegNS)

      def fromSuccessResponse(cls, success_response, signed_only=True):
          """Create a C{L{SRegResponse}} object from a successful OpenID
          library response
          (C{L{openid.consumer.consumer.SuccessResponse}}) response
          message

          Arguments whose names are not simple registration fields are
          ignored.

          NOTE(review): unlike C{SRegRequest.fromOpenIDRequest}, the
          message is not copied before C{_getSRegNS} runs, so
          C{success_response.message} may gain an 'sreg' alias.

          @param success_response: A SuccessResponse from consumer.complete()
          @type success_response: C{L{openid.consumer.consumer.SuccessResponse}}

          @param signed_only: Whether to process only data that was
              signed in the id_res message from the server.
          @type signed_only: bool

          @rtype: SRegResponse or NoneType
          @returns: A simple registration response containing the data
              that was supplied with the C{id_res} response, or None if
              no sreg arguments were found. When C{signed_only} is true,
              only signed sreg arguments count.
          """
          self = cls()
          self.ns_uri = self._getSRegNS(success_response.message)

          if signed_only:
              args = success_response.getSignedNS(self.ns_uri)
          else:
              args = success_response.message.getArgs(self.ns_uri)

          # Any false value (an empty mapping, or None if the lookup can
          # return it -- TODO confirm) means there is no sreg data.
          if not args:
              return None

          for field_name in data_fields:
              if field_name in args:
                  self.data[field_name] = args[field_name]

          return self

      fromSuccessResponse = classmethod(fromSuccessResponse)

      def getExtensionArgs(self):
          """Get the fields to put in the simple registration namespace
          when adding them to an id_res message.

          This returns this object's C{data} dict itself, not a copy.

          @see: openid.extension
          """
          return self.data

      # Read-only dictionary interface

      def get(self, field_name, default=None):
          """Like dict.get, except that it checks that the field name is
          defined by the simple registration specification

          @raise ValueError: if C{field_name} is not a simple
              registration field
          """
          checkFieldName(field_name)
          return self.data.get(field_name, default)

      def items(self):
          """All of the data values in this simple registration response
          """
          return self.data.items()

      def iteritems(self):
          """Iterate over (field name, value) pairs in C{data}."""
          return self.data.iteritems()

      def keys(self):
          """The field names present in C{data}."""
          return self.data.keys()

      def iterkeys(self):
          """Iterate over the field names present in C{data}."""
          return self.data.iterkeys()

      def has_key(self, key):
          """Same as C{key in self}.

          @raise ValueError: if C{key} is not a simple registration field
          """
          return key in self

      def __contains__(self, field_name):
          """Is C{field_name} present in C{data}?

          @raise ValueError: if C{field_name} is not a simple
              registration field
          """
          checkFieldName(field_name)
          return field_name in self.data

      def __iter__(self):
          """Iterate over the field names in C{data} (no validation)."""
          return iter(self.data)

      def __getitem__(self, field_name):
          """Return C{data[field_name]}.

          @raise ValueError: if C{field_name} is not a simple
              registration field
          @raise KeyError: if the field is valid but absent
          """
          checkFieldName(field_name)
          return self.data[field_name]

      def __nonzero__(self):
          """True when this response holds any data."""
          return bool(self.data)