#-*- coding: ISO-8859-1 -*-
# pysqlite2/test/factory.py: tests for the various factories in pysqlite
#
# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty. In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
#    claim that you wrote the original software. If you use this software
#    in a product, an acknowledgment in the product documentation would be
#    appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
#    misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
  23. import unittest
  24. import pysqlcipher.dbapi2 as sqlite
  25. class MyConnection(sqlite.Connection):
  26. def __init__(self, *args, **kwargs):
  27. sqlite.Connection.__init__(self, *args, **kwargs)
  28. def dict_factory(cursor, row):
  29. d = {}
  30. for idx, col in enumerate(cursor.description):
  31. d[col[0]] = row[idx]
  32. return d
  33. class MyCursor(sqlite.Cursor):
  34. def __init__(self, *args, **kwargs):
  35. sqlite.Cursor.__init__(self, *args, **kwargs)
  36. self.row_factory = dict_factory
  37. class ConnectionFactoryTests(unittest.TestCase):
  38. def setUp(self):
  39. self.con = sqlite.connect(":memory:", factory=MyConnection)
  40. def tearDown(self):
  41. self.con.close()
  42. def CheckIsInstance(self):
  43. self.assertTrue(isinstance(self.con,
  44. MyConnection),
  45. "connection is not instance of MyConnection")
  46. class CursorFactoryTests(unittest.TestCase):
  47. def setUp(self):
  48. self.con = sqlite.connect(":memory:")
  49. def tearDown(self):
  50. self.con.close()
  51. def CheckIsInstance(self):
  52. cur = self.con.cursor(factory=MyCursor)
  53. self.assertTrue(isinstance(cur,
  54. MyCursor),
  55. "cursor is not instance of MyCursor")
  56. class RowFactoryTestsBackwardsCompat(unittest.TestCase):
  57. def setUp(self):
  58. self.con = sqlite.connect(":memory:")
  59. def CheckIsProducedByFactory(self):
  60. cur = self.con.cursor(factory=MyCursor)
  61. cur.execute("select 4+5 as foo")
  62. row = cur.fetchone()
  63. self.assertTrue(isinstance(row,
  64. dict),
  65. "row is not instance of dict")
  66. cur.close()
  67. def tearDown(self):
  68. self.con.close()
  69. class RowFactoryTests(unittest.TestCase):
  70. def setUp(self):
  71. self.con = sqlite.connect(":memory:")
  72. def CheckCustomFactory(self):
  73. self.con.row_factory = lambda cur, row: list(row)
  74. row = self.con.execute("select 1, 2").fetchone()
  75. self.assertTrue(isinstance(row,
  76. list),
  77. "row is not instance of list")
  78. def CheckSqliteRowIndex(self):
  79. self.con.row_factory = sqlite.Row
  80. row = self.con.execute("select 1 as a, 2 as b").fetchone()
  81. self.assertTrue(isinstance(row,
  82. sqlite.Row),
  83. "row is not instance of sqlite.Row")
  84. col1, col2 = row["a"], row["b"]
  85. self.assertTrue(col1 == 1, "by name: wrong result for column 'a'")
  86. self.assertTrue(col2 == 2, "by name: wrong result for column 'a'")
  87. col1, col2 = row["A"], row["B"]
  88. self.assertTrue(col1 == 1, "by name: wrong result for column 'A'")
  89. self.assertTrue(col2 == 2, "by name: wrong result for column 'B'")
  90. col1, col2 = row[0], row[1]
  91. self.assertTrue(col1 == 1, "by index: wrong result for column 0")
  92. self.assertTrue(col2 == 2, "by index: wrong result for column 1")
  93. def CheckSqliteRowIter(self):
  94. """Checks if the row object is iterable"""
  95. self.con.row_factory = sqlite.Row
  96. row = self.con.execute("select 1 as a, 2 as b").fetchone()
  97. for col in row:
  98. pass
  99. def CheckSqliteRowAsTuple(self):
  100. """Checks if the row object can be converted to a tuple"""
  101. self.con.row_factory = sqlite.Row
  102. row = self.con.execute("select 1 as a, 2 as b").fetchone()
  103. t = tuple(row)
  104. def CheckSqliteRowAsDict(self):
  105. """Checks if the row object can be correctly converted to a dictionary"""
  106. self.con.row_factory = sqlite.Row
  107. row = self.con.execute("select 1 as a, 2 as b").fetchone()
  108. d = dict(row)
  109. self.assertEqual(d["a"], row["a"])
  110. self.assertEqual(d["b"], row["b"])
  111. def CheckSqliteRowHashCmp(self):
  112. """Checks if the row object compares and hashes correctly"""
  113. self.con.row_factory = sqlite.Row
  114. row_1 = self.con.execute("select 1 as a, 2 as b").fetchone()
  115. row_2 = self.con.execute("select 1 as a, 2 as b").fetchone()
  116. row_3 = self.con.execute("select 1 as a, 3 as b").fetchone()
  117. self.assertTrue(row_1 == row_1)
  118. self.assertTrue(row_1 == row_2)
  119. self.assertTrue(row_2 != row_3)
  120. self.assertFalse(row_1 != row_1)
  121. self.assertFalse(row_1 != row_2)
  122. self.assertFalse(row_2 == row_3)
  123. self.assertEqual(row_1, row_2)
  124. self.assertEqual(hash(row_1), hash(row_2))
  125. self.assertNotEqual(row_1, row_3)
  126. self.assertNotEqual(hash(row_1), hash(row_3))
  127. def tearDown(self):
  128. self.con.close()
  129. class TextFactoryTests(unittest.TestCase):
  130. def setUp(self):
  131. self.con = sqlite.connect(":memory:")
  132. def CheckUnicode(self):
  133. austria = unicode("Österreich", "latin1")
  134. row = self.con.execute("select ?", (austria,)).fetchone()
  135. self.assertTrue(type(row[0]) == unicode, "type of row[0] must be unicode")
  136. def CheckString(self):
  137. self.con.text_factory = str
  138. austria = unicode("Österreich", "latin1")
  139. row = self.con.execute("select ?", (austria,)).fetchone()
  140. self.assertTrue(type(row[0]) == str, "type of row[0] must be str")
  141. self.assertTrue(row[0] == austria.encode("utf-8"), "column must equal original data in UTF-8")
  142. def CheckCustom(self):
  143. self.con.text_factory = lambda x: unicode(x, "utf-8", "ignore")
  144. austria = unicode("Österreich", "latin1")
  145. row = self.con.execute("select ?", (austria.encode("latin1"),)).fetchone()
  146. self.assertTrue(type(row[0]) == unicode, "type of row[0] must be unicode")
  147. self.assertTrue(row[0].endswith(u"reich"), "column must contain original data")
  148. def CheckOptimizedUnicode(self):
  149. self.con.text_factory = sqlite.OptimizedUnicode
  150. austria = unicode("Österreich", "latin1")
  151. germany = unicode("Deutchland")
  152. a_row = self.con.execute("select ?", (austria,)).fetchone()
  153. d_row = self.con.execute("select ?", (germany,)).fetchone()
  154. self.assertTrue(type(a_row[0]) == unicode, "type of non-ASCII row must be unicode")
  155. self.assertTrue(type(d_row[0]) == str, "type of ASCII-only row must be str")
  156. def tearDown(self):
  157. self.con.close()
  158. def suite():
  159. connection_suite = unittest.makeSuite(ConnectionFactoryTests, "Check")
  160. cursor_suite = unittest.makeSuite(CursorFactoryTests, "Check")
  161. row_suite_compat = unittest.makeSuite(RowFactoryTestsBackwardsCompat, "Check")
  162. row_suite = unittest.makeSuite(RowFactoryTests, "Check")
  163. text_suite = unittest.makeSuite(TextFactoryTests, "Check")
  164. return unittest.TestSuite((connection_suite, cursor_suite, row_suite_compat, row_suite, text_suite))
  165. def test():
  166. runner = unittest.TextTestRunner()
  167. runner.run(suite())
# Run the tests only when this file is executed as a script.
if __name__ == "__main__":
    test()