sessions.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. from bson import json_util
  2. from django.conf import settings
  3. from django.contrib.sessions.backends.base import SessionBase, CreateError
  4. from django.core.exceptions import SuspiciousOperation
  5. try:
  6. from django.utils.encoding import force_unicode
  7. except ImportError:
  8. from django.utils.encoding import force_text as force_unicode
  9. from mongoengine.document import Document
  10. from mongoengine import fields
  11. from mongoengine.queryset import OperationError
  12. from mongoengine.connection import DEFAULT_CONNECTION_NAME
  13. from .utils import datetime_now
  14. MONGOENGINE_SESSION_DB_ALIAS = getattr(
  15. settings, 'MONGOENGINE_SESSION_DB_ALIAS',
  16. DEFAULT_CONNECTION_NAME)
  17. # a setting for the name of the collection used to store sessions
  18. MONGOENGINE_SESSION_COLLECTION = getattr(
  19. settings, 'MONGOENGINE_SESSION_COLLECTION',
  20. 'django_session')
  21. # a setting for whether session data is stored encoded or not
  22. MONGOENGINE_SESSION_DATA_ENCODE = getattr(
  23. settings, 'MONGOENGINE_SESSION_DATA_ENCODE',
  24. True)
  25. class MongoSession(Document):
  26. session_key = fields.StringField(primary_key=True, max_length=40)
  27. session_data = fields.StringField() if MONGOENGINE_SESSION_DATA_ENCODE \
  28. else fields.DictField()
  29. expire_date = fields.DateTimeField()
  30. meta = {
  31. 'collection': MONGOENGINE_SESSION_COLLECTION,
  32. 'db_alias': MONGOENGINE_SESSION_DB_ALIAS,
  33. 'allow_inheritance': False,
  34. 'indexes': [
  35. {
  36. 'fields': ['expire_date'],
  37. 'expireAfterSeconds': 0
  38. }
  39. ]
  40. }
  41. def get_decoded(self):
  42. return SessionStore().decode(self.session_data)
  43. class SessionStore(SessionBase):
  44. """A MongoEngine-based session store for Django.
  45. """
  46. def _get_session(self, *args, **kwargs):
  47. sess = super(SessionStore, self)._get_session(*args, **kwargs)
  48. if sess.get('_auth_user_id', None):
  49. sess['_auth_user_id'] = str(sess.get('_auth_user_id'))
  50. return sess
  51. def load(self):
  52. try:
  53. s = MongoSession.objects(session_key=self.session_key,
  54. expire_date__gt=datetime_now)[0]
  55. if MONGOENGINE_SESSION_DATA_ENCODE:
  56. return self.decode(force_unicode(s.session_data))
  57. else:
  58. return s.session_data
  59. except (IndexError, SuspiciousOperation):
  60. self.create()
  61. return {}
  62. def exists(self, session_key):
  63. return bool(MongoSession.objects(session_key=session_key).first())
  64. def create(self):
  65. while True:
  66. self._session_key = self._get_new_session_key()
  67. try:
  68. self.save(must_create=True)
  69. except CreateError:
  70. continue
  71. self.modified = True
  72. self._session_cache = {}
  73. return
  74. def save(self, must_create=False):
  75. if self.session_key is None:
  76. self._session_key = self._get_new_session_key()
  77. s = MongoSession(session_key=self.session_key)
  78. if MONGOENGINE_SESSION_DATA_ENCODE:
  79. s.session_data = self.encode(self._get_session(no_load=must_create))
  80. else:
  81. s.session_data = self._get_session(no_load=must_create)
  82. s.expire_date = self.get_expiry_date()
  83. try:
  84. s.save(force_insert=must_create)
  85. except OperationError:
  86. if must_create:
  87. raise CreateError
  88. raise
  89. def delete(self, session_key=None):
  90. if session_key is None:
  91. if self.session_key is None:
  92. return
  93. session_key = self.session_key
  94. MongoSession.objects(session_key=session_key).delete()
  95. class BSONSerializer(object):
  96. """
  97. Serializer that can handle BSON types (eg ObjectId).
  98. """
  99. def dumps(self, obj):
  100. return json_util.dumps(obj, separators=(',', ':')).encode('ascii')
  101. def loads(self, data):
  102. return json_util.loads(data.decode('ascii'))