test_traitlets.py 65 KB


  1. # encoding: utf-8
  2. """Tests for traitlets.traitlets."""
  3. # Copyright (c) IPython Development Team.
  4. # Distributed under the terms of the Modified BSD License.
  5. #
  6. # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
  7. # also under the terms of the Modified BSD License.
  8. import pickle
  9. import re
  10. import sys
  11. from ._warnings import expected_warnings
  12. from unittest import TestCase
  13. import pytest
  14. from pytest import mark
  15. from traitlets import (
  16. HasTraits, MetaHasTraits, TraitType, Any, Bool, CBytes, Dict, Enum,
  17. Int, CInt, Long, CLong, Integer, Float, CFloat, Complex, Bytes, Unicode,
  18. TraitError, Union, All, Undefined, Type, This, Instance, TCPAddress,
  19. List, Tuple, ObjectName, DottedObjectName, CRegExp, link, directional_link,
  20. ForwardDeclaredType, ForwardDeclaredInstance, validate, observe, default,
  21. observe_compat, BaseDescriptor, HasDescriptors,
  22. )
  23. import six
  24. def change_dict(*ordered_values):
  25. change_names = ('name', 'old', 'new', 'owner', 'type')
  26. return dict(zip(change_names, ordered_values))
  27. #-----------------------------------------------------------------------------
  28. # Helper classes for testing
  29. #-----------------------------------------------------------------------------
  30. class HasTraitsStub(HasTraits):
  31. def notify_change(self, change):
  32. self._notify_name = change['name']
  33. self._notify_old = change['old']
  34. self._notify_new = change['new']
  35. self._notify_type = change['type']
  36. #-----------------------------------------------------------------------------
  37. # Test classes
  38. #-----------------------------------------------------------------------------
  39. class TestTraitType(TestCase):
  40. def test_get_undefined(self):
  41. class A(HasTraits):
  42. a = TraitType
  43. a = A()
  44. with self.assertRaises(TraitError):
  45. a.a
  46. def test_set(self):
  47. class A(HasTraitsStub):
  48. a = TraitType
  49. a = A()
  50. a.a = 10
  51. self.assertEqual(a.a, 10)
  52. self.assertEqual(a._notify_name, 'a')
  53. self.assertEqual(a._notify_old, Undefined)
  54. self.assertEqual(a._notify_new, 10)
  55. def test_validate(self):
  56. class MyTT(TraitType):
  57. def validate(self, inst, value):
  58. return -1
  59. class A(HasTraitsStub):
  60. tt = MyTT
  61. a = A()
  62. a.tt = 10
  63. self.assertEqual(a.tt, -1)
  64. def test_default_validate(self):
  65. class MyIntTT(TraitType):
  66. def validate(self, obj, value):
  67. if isinstance(value, int):
  68. return value
  69. self.error(obj, value)
  70. class A(HasTraits):
  71. tt = MyIntTT(10)
  72. a = A()
  73. self.assertEqual(a.tt, 10)
  74. # Defaults are validated when the HasTraits is instantiated
  75. class B(HasTraits):
  76. tt = MyIntTT('bad default')
  77. self.assertRaises(TraitError, B)
  78. def test_info(self):
  79. class A(HasTraits):
  80. tt = TraitType
  81. a = A()
  82. self.assertEqual(A.tt.info(), 'any value')
  83. def test_error(self):
  84. class A(HasTraits):
  85. tt = TraitType
  86. a = A()
  87. self.assertRaises(TraitError, A.tt.error, a, 10)
  88. def test_deprecated_dynamic_initializer(self):
  89. class A(HasTraits):
  90. x = Int(10)
  91. def _x_default(self):
  92. return 11
  93. class B(A):
  94. x = Int(20)
  95. class C(A):
  96. def _x_default(self):
  97. return 21
  98. a = A()
  99. self.assertEqual(a._trait_values, {})
  100. self.assertEqual(a.x, 11)
  101. self.assertEqual(a._trait_values, {'x': 11})
  102. b = B()
  103. self.assertEqual(b.x, 20)
  104. self.assertEqual(b._trait_values, {'x': 20})
  105. c = C()
  106. self.assertEqual(c._trait_values, {})
  107. self.assertEqual(c.x, 21)
  108. self.assertEqual(c._trait_values, {'x': 21})
  109. # Ensure that the base class remains unmolested when the _default
  110. # initializer gets overridden in a subclass.
  111. a = A()
  112. c = C()
  113. self.assertEqual(a._trait_values, {})
  114. self.assertEqual(a.x, 11)
  115. self.assertEqual(a._trait_values, {'x': 11})
  116. def test_dynamic_initializer(self):
  117. class A(HasTraits):
  118. x = Int(10)
  119. @default('x')
  120. def _default_x(self):
  121. return 11
  122. class B(A):
  123. x = Int(20)
  124. class C(A):
  125. @default('x')
  126. def _default_x(self):
  127. return 21
  128. a = A()
  129. self.assertEqual(a._trait_values, {})
  130. self.assertEqual(a.x, 11)
  131. self.assertEqual(a._trait_values, {'x': 11})
  132. b = B()
  133. self.assertEqual(b.x, 20)
  134. self.assertEqual(b._trait_values, {'x': 20})
  135. c = C()
  136. self.assertEqual(c._trait_values, {})
  137. self.assertEqual(c.x, 21)
  138. self.assertEqual(c._trait_values, {'x': 21})
  139. # Ensure that the base class remains unmolested when the _default
  140. # initializer gets overridden in a subclass.
  141. a = A()
  142. c = C()
  143. self.assertEqual(a._trait_values, {})
  144. self.assertEqual(a.x, 11)
  145. self.assertEqual(a._trait_values, {'x': 11})
  146. def test_tag_metadata(self):
  147. class MyIntTT(TraitType):
  148. metadata = {'a': 1, 'b': 2}
  149. a = MyIntTT(10).tag(b=3, c=4)
  150. self.assertEqual(a.metadata, {'a': 1, 'b': 3, 'c': 4})
  151. def test_metadata_localized_instance(self):
  152. class MyIntTT(TraitType):
  153. metadata = {'a': 1, 'b': 2}
  154. a = MyIntTT(10)
  155. b = MyIntTT(10)
  156. a.metadata['c'] = 3
  157. # make sure that changing a's metadata didn't change b's metadata
  158. self.assertNotIn('c', b.metadata)
  159. def test_union_metadata(self):
  160. class Foo(HasTraits):
  161. bar = (Int().tag(ta=1) | Dict().tag(ta=2, ti='b')).tag(ti='a')
  162. foo = Foo()
  163. # At this point, no value has been set for bar, so value-specific
  164. # is not set.
  165. self.assertEqual(foo.trait_metadata('bar', 'ta'), None)
  166. self.assertEqual(foo.trait_metadata('bar', 'ti'), 'a')
  167. foo.bar = {}
  168. self.assertEqual(foo.trait_metadata('bar', 'ta'), 2)
  169. self.assertEqual(foo.trait_metadata('bar', 'ti'), 'b')
  170. foo.bar = 1
  171. self.assertEqual(foo.trait_metadata('bar', 'ta'), 1)
  172. self.assertEqual(foo.trait_metadata('bar', 'ti'), 'a')
  173. def test_union_default_value(self):
  174. class Foo(HasTraits):
  175. bar = Union([Dict(), Int()], default_value=1)
  176. foo = Foo()
  177. self.assertEqual(foo.bar, 1)
  178. def test_deprecated_metadata_access(self):
  179. class MyIntTT(TraitType):
  180. metadata = {'a': 1, 'b': 2}
  181. a = MyIntTT(10)
  182. with expected_warnings(["use the instance .metadata dictionary directly"]*2):
  183. a.set_metadata('key', 'value')
  184. v = a.get_metadata('key')
  185. self.assertEqual(v, 'value')
  186. with expected_warnings(["use the instance .help string directly"]*2):
  187. a.set_metadata('help', 'some help')
  188. v = a.get_metadata('help')
  189. self.assertEqual(v, 'some help')
  190. def test_trait_types_deprecated(self):
  191. with expected_warnings(["Traits should be given as instances"]):
  192. class C(HasTraits):
  193. t = Int
  194. def test_trait_types_list_deprecated(self):
  195. with expected_warnings(["Traits should be given as instances"]):
  196. class C(HasTraits):
  197. t = List(Int)
  198. def test_trait_types_tuple_deprecated(self):
  199. with expected_warnings(["Traits should be given as instances"]):
  200. class C(HasTraits):
  201. t = Tuple(Int)
  202. def test_trait_types_dict_deprecated(self):
  203. with expected_warnings(["Traits should be given as instances"]):
  204. class C(HasTraits):
  205. t = Dict(Int)
  206. class TestHasDescriptorsMeta(TestCase):
  207. def test_metaclass(self):
  208. self.assertEqual(type(HasTraits), MetaHasTraits)
  209. class A(HasTraits):
  210. a = Int()
  211. a = A()
  212. self.assertEqual(type(a.__class__), MetaHasTraits)
  213. self.assertEqual(a.a,0)
  214. a.a = 10
  215. self.assertEqual(a.a,10)
  216. class B(HasTraits):
  217. b = Int()
  218. b = B()
  219. self.assertEqual(b.b,0)
  220. b.b = 10
  221. self.assertEqual(b.b,10)
  222. class C(HasTraits):
  223. c = Int(30)
  224. c = C()
  225. self.assertEqual(c.c,30)
  226. c.c = 10
  227. self.assertEqual(c.c,10)
  228. def test_this_class(self):
  229. class A(HasTraits):
  230. t = This()
  231. tt = This()
  232. class B(A):
  233. tt = This()
  234. ttt = This()
  235. self.assertEqual(A.t.this_class, A)
  236. self.assertEqual(B.t.this_class, A)
  237. self.assertEqual(B.tt.this_class, B)
  238. self.assertEqual(B.ttt.this_class, B)
  239. class TestHasDescriptors(TestCase):
  240. def test_setup_instance(self):
  241. class FooDescriptor(BaseDescriptor):
  242. def instance_init(self, inst):
  243. foo = inst.foo # instance should have the attr
  244. class HasFooDescriptors(HasDescriptors):
  245. fd = FooDescriptor()
  246. def setup_instance(self, *args, **kwargs):
  247. self.foo = kwargs.get('foo', None)
  248. super(HasFooDescriptors, self).setup_instance(*args, **kwargs)
  249. hfd = HasFooDescriptors(foo='bar')
  250. class TestHasTraitsNotify(TestCase):
  251. def setUp(self):
  252. self._notify1 = []
  253. self._notify2 = []
  254. def notify1(self, name, old, new):
  255. self._notify1.append((name, old, new))
  256. def notify2(self, name, old, new):
  257. self._notify2.append((name, old, new))
  258. def test_notify_all(self):
  259. class A(HasTraits):
  260. a = Int()
  261. b = Float()
  262. a = A()
  263. a.on_trait_change(self.notify1)
  264. a.a = 0
  265. self.assertEqual(len(self._notify1),0)
  266. a.b = 0.0
  267. self.assertEqual(len(self._notify1),0)
  268. a.a = 10
  269. self.assertTrue(('a',0,10) in self._notify1)
  270. a.b = 10.0
  271. self.assertTrue(('b',0.0,10.0) in self._notify1)
  272. self.assertRaises(TraitError,setattr,a,'a','bad string')
  273. self.assertRaises(TraitError,setattr,a,'b','bad string')
  274. self._notify1 = []
  275. a.on_trait_change(self.notify1,remove=True)
  276. a.a = 20
  277. a.b = 20.0
  278. self.assertEqual(len(self._notify1),0)
  279. def test_notify_one(self):
  280. class A(HasTraits):
  281. a = Int()
  282. b = Float()
  283. a = A()
  284. a.on_trait_change(self.notify1, 'a')
  285. a.a = 0
  286. self.assertEqual(len(self._notify1),0)
  287. a.a = 10
  288. self.assertTrue(('a',0,10) in self._notify1)
  289. self.assertRaises(TraitError,setattr,a,'a','bad string')
  290. def test_subclass(self):
  291. class A(HasTraits):
  292. a = Int()
  293. class B(A):
  294. b = Float()
  295. b = B()
  296. self.assertEqual(b.a,0)
  297. self.assertEqual(b.b,0.0)
  298. b.a = 100
  299. b.b = 100.0
  300. self.assertEqual(b.a,100)
  301. self.assertEqual(b.b,100.0)
  302. def test_notify_subclass(self):
  303. class A(HasTraits):
  304. a = Int()
  305. class B(A):
  306. b = Float()
  307. b = B()
  308. b.on_trait_change(self.notify1, 'a')
  309. b.on_trait_change(self.notify2, 'b')
  310. b.a = 0
  311. b.b = 0.0
  312. self.assertEqual(len(self._notify1),0)
  313. self.assertEqual(len(self._notify2),0)
  314. b.a = 10
  315. b.b = 10.0
  316. self.assertTrue(('a',0,10) in self._notify1)
  317. self.assertTrue(('b',0.0,10.0) in self._notify2)
  318. def test_static_notify(self):
  319. class A(HasTraits):
  320. a = Int()
  321. _notify1 = []
  322. def _a_changed(self, name, old, new):
  323. self._notify1.append((name, old, new))
  324. a = A()
  325. a.a = 0
  326. # This is broken!!!
  327. self.assertEqual(len(a._notify1),0)
  328. a.a = 10
  329. self.assertTrue(('a',0,10) in a._notify1)
  330. class B(A):
  331. b = Float()
  332. _notify2 = []
  333. def _b_changed(self, name, old, new):
  334. self._notify2.append((name, old, new))
  335. b = B()
  336. b.a = 10
  337. b.b = 10.0
  338. self.assertTrue(('a',0,10) in b._notify1)
  339. self.assertTrue(('b',0.0,10.0) in b._notify2)
  340. def test_notify_args(self):
  341. def callback0():
  342. self.cb = ()
  343. def callback1(name):
  344. self.cb = (name,)
  345. def callback2(name, new):
  346. self.cb = (name, new)
  347. def callback3(name, old, new):
  348. self.cb = (name, old, new)
  349. def callback4(name, old, new, obj):
  350. self.cb = (name, old, new, obj)
  351. class A(HasTraits):
  352. a = Int()
  353. a = A()
  354. a.on_trait_change(callback0, 'a')
  355. a.a = 10
  356. self.assertEqual(self.cb,())
  357. a.on_trait_change(callback0, 'a', remove=True)
  358. a.on_trait_change(callback1, 'a')
  359. a.a = 100
  360. self.assertEqual(self.cb,('a',))
  361. a.on_trait_change(callback1, 'a', remove=True)
  362. a.on_trait_change(callback2, 'a')
  363. a.a = 1000
  364. self.assertEqual(self.cb,('a',1000))
  365. a.on_trait_change(callback2, 'a', remove=True)
  366. a.on_trait_change(callback3, 'a')
  367. a.a = 10000
  368. self.assertEqual(self.cb,('a',1000,10000))
  369. a.on_trait_change(callback3, 'a', remove=True)
  370. a.on_trait_change(callback4, 'a')
  371. a.a = 100000
  372. self.assertEqual(self.cb,('a',10000,100000,a))
  373. self.assertEqual(len(a._trait_notifiers['a']['change']), 1)
  374. a.on_trait_change(callback4, 'a', remove=True)
  375. self.assertEqual(len(a._trait_notifiers['a']['change']), 0)
  376. def test_notify_only_once(self):
  377. class A(HasTraits):
  378. listen_to = ['a']
  379. a = Int(0)
  380. b = 0
  381. def __init__(self, **kwargs):
  382. super(A, self).__init__(**kwargs)
  383. self.on_trait_change(self.listener1, ['a'])
  384. def listener1(self, name, old, new):
  385. self.b += 1
  386. class B(A):
  387. c = 0
  388. d = 0
  389. def __init__(self, **kwargs):
  390. super(B, self).__init__(**kwargs)
  391. self.on_trait_change(self.listener2)
  392. def listener2(self, name, old, new):
  393. self.c += 1
  394. def _a_changed(self, name, old, new):
  395. self.d += 1
  396. b = B()
  397. b.a += 1
  398. self.assertEqual(b.b, b.c)
  399. self.assertEqual(b.b, b.d)
  400. b.a += 1
  401. self.assertEqual(b.b, b.c)
  402. self.assertEqual(b.b, b.d)
  403. class TestObserveDecorator(TestCase):
  404. def setUp(self):
  405. self._notify1 = []
  406. self._notify2 = []
  407. def notify1(self, change):
  408. self._notify1.append(change)
  409. def notify2(self, change):
  410. self._notify2.append(change)
  411. def test_notify_all(self):
  412. class A(HasTraits):
  413. a = Int()
  414. b = Float()
  415. a = A()
  416. a.observe(self.notify1)
  417. a.a = 0
  418. self.assertEqual(len(self._notify1),0)
  419. a.b = 0.0
  420. self.assertEqual(len(self._notify1),0)
  421. a.a = 10
  422. change = change_dict('a', 0, 10, a, 'change')
  423. self.assertTrue(change in self._notify1)
  424. a.b = 10.0
  425. change = change_dict('b', 0.0, 10.0, a, 'change')
  426. self.assertTrue(change in self._notify1)
  427. self.assertRaises(TraitError,setattr,a,'a','bad string')
  428. self.assertRaises(TraitError,setattr,a,'b','bad string')
  429. self._notify1 = []
  430. a.unobserve(self.notify1)
  431. a.a = 20
  432. a.b = 20.0
  433. self.assertEqual(len(self._notify1),0)
  434. def test_notify_one(self):
  435. class A(HasTraits):
  436. a = Int()
  437. b = Float()
  438. a = A()
  439. a.observe(self.notify1, 'a')
  440. a.a = 0
  441. self.assertEqual(len(self._notify1),0)
  442. a.a = 10
  443. change = change_dict('a', 0, 10, a, 'change')
  444. self.assertTrue(change in self._notify1)
  445. self.assertRaises(TraitError,setattr,a,'a','bad string')
  446. def test_subclass(self):
  447. class A(HasTraits):
  448. a = Int()
  449. class B(A):
  450. b = Float()
  451. b = B()
  452. self.assertEqual(b.a,0)
  453. self.assertEqual(b.b,0.0)
  454. b.a = 100
  455. b.b = 100.0
  456. self.assertEqual(b.a,100)
  457. self.assertEqual(b.b,100.0)
  458. def test_notify_subclass(self):
  459. class A(HasTraits):
  460. a = Int()
  461. class B(A):
  462. b = Float()
  463. b = B()
  464. b.observe(self.notify1, 'a')
  465. b.observe(self.notify2, 'b')
  466. b.a = 0
  467. b.b = 0.0
  468. self.assertEqual(len(self._notify1),0)
  469. self.assertEqual(len(self._notify2),0)
  470. b.a = 10
  471. b.b = 10.0
  472. change = change_dict('a', 0, 10, b, 'change')
  473. self.assertTrue(change in self._notify1)
  474. change = change_dict('b', 0.0, 10.0, b, 'change')
  475. self.assertTrue(change in self._notify2)
  476. def test_static_notify(self):
  477. class A(HasTraits):
  478. a = Int()
  479. b = Int()
  480. _notify1 = []
  481. _notify_any = []
  482. @observe('a')
  483. def _a_changed(self, change):
  484. self._notify1.append(change)
  485. @observe(All)
  486. def _any_changed(self, change):
  487. self._notify_any.append(change)
  488. a = A()
  489. a.a = 0
  490. self.assertEqual(len(a._notify1),0)
  491. a.a = 10
  492. change = change_dict('a', 0, 10, a, 'change')
  493. self.assertTrue(change in a._notify1)
  494. a.b = 1
  495. self.assertEqual(len(a._notify_any), 2)
  496. change = change_dict('b', 0, 1, a, 'change')
  497. self.assertTrue(change in a._notify_any)
  498. class B(A):
  499. b = Float()
  500. _notify2 = []
  501. @observe('b')
  502. def _b_changed(self, change):
  503. self._notify2.append(change)
  504. b = B()
  505. b.a = 10
  506. b.b = 10.0
  507. change = change_dict('a', 0, 10, b, 'change')
  508. self.assertTrue(change in b._notify1)
  509. change = change_dict('b', 0.0, 10.0, b, 'change')
  510. self.assertTrue(change in b._notify2)
  511. def test_notify_args(self):
  512. def callback0():
  513. self.cb = ()
  514. def callback1(change):
  515. self.cb = change
  516. class A(HasTraits):
  517. a = Int()
  518. a = A()
  519. a.on_trait_change(callback0, 'a')
  520. a.a = 10
  521. self.assertEqual(self.cb,())
  522. a.unobserve(callback0, 'a')
  523. a.observe(callback1, 'a')
  524. a.a = 100
  525. change = change_dict('a', 10, 100, a, 'change')
  526. self.assertEqual(self.cb, change)
  527. self.assertEqual(len(a._trait_notifiers['a']['change']), 1)
  528. a.unobserve(callback1, 'a')
  529. self.assertEqual(len(a._trait_notifiers['a']['change']), 0)
  530. def test_notify_only_once(self):
  531. class A(HasTraits):
  532. listen_to = ['a']
  533. a = Int(0)
  534. b = 0
  535. def __init__(self, **kwargs):
  536. super(A, self).__init__(**kwargs)
  537. self.observe(self.listener1, ['a'])
  538. def listener1(self, change):
  539. self.b += 1
  540. class B(A):
  541. c = 0
  542. d = 0
  543. def __init__(self, **kwargs):
  544. super(B, self).__init__(**kwargs)
  545. self.observe(self.listener2)
  546. def listener2(self, change):
  547. self.c += 1
  548. @observe('a')
  549. def _a_changed(self, change):
  550. self.d += 1
  551. b = B()
  552. b.a += 1
  553. self.assertEqual(b.b, b.c)
  554. self.assertEqual(b.b, b.d)
  555. b.a += 1
  556. self.assertEqual(b.b, b.c)
  557. self.assertEqual(b.b, b.d)
  558. class TestHasTraits(TestCase):
  559. def test_trait_names(self):
  560. class A(HasTraits):
  561. i = Int()
  562. f = Float()
  563. a = A()
  564. self.assertEqual(sorted(a.trait_names()),['f','i'])
  565. self.assertEqual(sorted(A.class_trait_names()),['f','i'])
  566. self.assertTrue(a.has_trait('f'))
  567. self.assertFalse(a.has_trait('g'))
  568. def test_trait_metadata_deprecated(self):
  569. with expected_warnings(['metadata should be set using the \.tag\(\) method']):
  570. class A(HasTraits):
  571. i = Int(config_key='MY_VALUE')
  572. a = A()
  573. self.assertEqual(a.trait_metadata('i','config_key'), 'MY_VALUE')
  574. def test_trait_metadata(self):
  575. class A(HasTraits):
  576. i = Int().tag(config_key='MY_VALUE')
  577. a = A()
  578. self.assertEqual(a.trait_metadata('i','config_key'), 'MY_VALUE')
  579. def test_trait_metadata_default(self):
  580. class A(HasTraits):
  581. i = Int()
  582. a = A()
  583. self.assertEqual(a.trait_metadata('i', 'config_key'), None)
  584. self.assertEqual(a.trait_metadata('i', 'config_key', 'default'), 'default')
  585. def test_traits(self):
  586. class A(HasTraits):
  587. i = Int()
  588. f = Float()
  589. a = A()
  590. self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
  591. self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
  592. def test_traits_metadata(self):
  593. class A(HasTraits):
  594. i = Int().tag(config_key='VALUE1', other_thing='VALUE2')
  595. f = Float().tag(config_key='VALUE3', other_thing='VALUE2')
  596. j = Int(0)
  597. a = A()
  598. self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
  599. traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
  600. self.assertEqual(traits, dict(i=A.i))
  601. # This passes, but it shouldn't because I am replicating a bug in
  602. # traits.
  603. traits = a.traits(config_key=lambda v: True)
  604. self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
  605. def test_traits_metadata_deprecated(self):
  606. with expected_warnings(['metadata should be set using the \.tag\(\) method']*2):
  607. class A(HasTraits):
  608. i = Int(config_key='VALUE1', other_thing='VALUE2')
  609. f = Float(config_key='VALUE3', other_thing='VALUE2')
  610. j = Int(0)
  611. a = A()
  612. self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
  613. traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
  614. self.assertEqual(traits, dict(i=A.i))
  615. # This passes, but it shouldn't because I am replicating a bug in
  616. # traits.
  617. traits = a.traits(config_key=lambda v: True)
  618. self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
  619. def test_init(self):
  620. class A(HasTraits):
  621. i = Int()
  622. x = Float()
  623. a = A(i=1, x=10.0)
  624. self.assertEqual(a.i, 1)
  625. self.assertEqual(a.x, 10.0)
  626. def test_positional_args(self):
  627. class A(HasTraits):
  628. i = Int(0)
  629. def __init__(self, i):
  630. super(A, self).__init__()
  631. self.i = i
  632. a = A(5)
  633. self.assertEqual(a.i, 5)
  634. # should raise TypeError if no positional arg given
  635. self.assertRaises(TypeError, A)
  636. #-----------------------------------------------------------------------------
  637. # Tests for specific trait types
  638. #-----------------------------------------------------------------------------
  639. class TestType(TestCase):
  640. def test_default(self):
  641. class B(object): pass
  642. class A(HasTraits):
  643. klass = Type(allow_none=True)
  644. a = A()
  645. self.assertEqual(a.klass, object)
  646. a.klass = B
  647. self.assertEqual(a.klass, B)
  648. self.assertRaises(TraitError, setattr, a, 'klass', 10)
  649. def test_default_options(self):
  650. class B(object): pass
  651. class C(B): pass
  652. class A(HasTraits):
  653. # Different possible combinations of options for default_value
  654. # and klass. default_value=None is only valid with allow_none=True.
  655. k1 = Type()
  656. k2 = Type(None, allow_none=True)
  657. k3 = Type(B)
  658. k4 = Type(klass=B)
  659. k5 = Type(default_value=None, klass=B, allow_none=True)
  660. k6 = Type(default_value=C, klass=B)
  661. self.assertIs(A.k1.default_value, object)
  662. self.assertIs(A.k1.klass, object)
  663. self.assertIs(A.k2.default_value, None)
  664. self.assertIs(A.k2.klass, object)
  665. self.assertIs(A.k3.default_value, B)
  666. self.assertIs(A.k3.klass, B)
  667. self.assertIs(A.k4.default_value, B)
  668. self.assertIs(A.k4.klass, B)
  669. self.assertIs(A.k5.default_value, None)
  670. self.assertIs(A.k5.klass, B)
  671. self.assertIs(A.k6.default_value, C)
  672. self.assertIs(A.k6.klass, B)
  673. a = A()
  674. self.assertIs(a.k1, object)
  675. self.assertIs(a.k2, None)
  676. self.assertIs(a.k3, B)
  677. self.assertIs(a.k4, B)
  678. self.assertIs(a.k5, None)
  679. self.assertIs(a.k6, C)
  680. def test_value(self):
  681. class B(object): pass
  682. class C(object): pass
  683. class A(HasTraits):
  684. klass = Type(B)
  685. a = A()
  686. self.assertEqual(a.klass, B)
  687. self.assertRaises(TraitError, setattr, a, 'klass', C)
  688. self.assertRaises(TraitError, setattr, a, 'klass', object)
  689. a.klass = B
  690. def test_allow_none(self):
  691. class B(object): pass
  692. class C(B): pass
  693. class A(HasTraits):
  694. klass = Type(B)
  695. a = A()
  696. self.assertEqual(a.klass, B)
  697. self.assertRaises(TraitError, setattr, a, 'klass', None)
  698. a.klass = C
  699. self.assertEqual(a.klass, C)
  700. def test_validate_klass(self):
  701. class A(HasTraits):
  702. klass = Type('no strings allowed')
  703. self.assertRaises(ImportError, A)
  704. class A(HasTraits):
  705. klass = Type('rub.adub.Duck')
  706. self.assertRaises(ImportError, A)
  707. def test_validate_default(self):
  708. class B(object): pass
  709. class A(HasTraits):
  710. klass = Type('bad default', B)
  711. self.assertRaises(ImportError, A)
  712. class C(HasTraits):
  713. klass = Type(None, B)
  714. self.assertRaises(TraitError, C)
  715. def test_str_klass(self):
  716. class A(HasTraits):
  717. klass = Type('ipython_genutils.ipstruct.Struct')
  718. from ipython_genutils.ipstruct import Struct
  719. a = A()
  720. a.klass = Struct
  721. self.assertEqual(a.klass, Struct)
  722. self.assertRaises(TraitError, setattr, a, 'klass', 10)
  723. def test_set_str_klass(self):
  724. class A(HasTraits):
  725. klass = Type()
  726. a = A(klass='ipython_genutils.ipstruct.Struct')
  727. from ipython_genutils.ipstruct import Struct
  728. self.assertEqual(a.klass, Struct)
  729. class TestInstance(TestCase):
  730. def test_basic(self):
  731. class Foo(object): pass
  732. class Bar(Foo): pass
  733. class Bah(object): pass
  734. class A(HasTraits):
  735. inst = Instance(Foo, allow_none=True)
  736. a = A()
  737. self.assertTrue(a.inst is None)
  738. a.inst = Foo()
  739. self.assertTrue(isinstance(a.inst, Foo))
  740. a.inst = Bar()
  741. self.assertTrue(isinstance(a.inst, Foo))
  742. self.assertRaises(TraitError, setattr, a, 'inst', Foo)
  743. self.assertRaises(TraitError, setattr, a, 'inst', Bar)
  744. self.assertRaises(TraitError, setattr, a, 'inst', Bah())
  745. def test_default_klass(self):
  746. class Foo(object): pass
  747. class Bar(Foo): pass
  748. class Bah(object): pass
  749. class FooInstance(Instance):
  750. klass = Foo
  751. class A(HasTraits):
  752. inst = FooInstance(allow_none=True)
  753. a = A()
  754. self.assertTrue(a.inst is None)
  755. a.inst = Foo()
  756. self.assertTrue(isinstance(a.inst, Foo))
  757. a.inst = Bar()
  758. self.assertTrue(isinstance(a.inst, Foo))
  759. self.assertRaises(TraitError, setattr, a, 'inst', Foo)
  760. self.assertRaises(TraitError, setattr, a, 'inst', Bar)
  761. self.assertRaises(TraitError, setattr, a, 'inst', Bah())
  762. def test_unique_default_value(self):
  763. class Foo(object): pass
  764. class A(HasTraits):
  765. inst = Instance(Foo,(),{})
  766. a = A()
  767. b = A()
  768. self.assertTrue(a.inst is not b.inst)
  769. def test_args_kw(self):
  770. class Foo(object):
  771. def __init__(self, c): self.c = c
  772. class Bar(object): pass
  773. class Bah(object):
  774. def __init__(self, c, d):
  775. self.c = c; self.d = d
  776. class A(HasTraits):
  777. inst = Instance(Foo, (10,))
  778. a = A()
  779. self.assertEqual(a.inst.c, 10)
  780. class B(HasTraits):
  781. inst = Instance(Bah, args=(10,), kw=dict(d=20))
  782. b = B()
  783. self.assertEqual(b.inst.c, 10)
  784. self.assertEqual(b.inst.d, 20)
  785. class C(HasTraits):
  786. inst = Instance(Foo, allow_none=True)
  787. c = C()
  788. self.assertTrue(c.inst is None)
  789. def test_bad_default(self):
  790. class Foo(object): pass
  791. class A(HasTraits):
  792. inst = Instance(Foo)
  793. a = A()
  794. with self.assertRaises(TraitError):
  795. a.inst
  796. def test_instance(self):
  797. class Foo(object): pass
  798. def inner():
  799. class A(HasTraits):
  800. inst = Instance(Foo())
  801. self.assertRaises(TraitError, inner)
  802. class TestThis(TestCase):
  803. def test_this_class(self):
  804. class Foo(HasTraits):
  805. this = This()
  806. f = Foo()
  807. self.assertEqual(f.this, None)
  808. g = Foo()
  809. f.this = g
  810. self.assertEqual(f.this, g)
  811. self.assertRaises(TraitError, setattr, f, 'this', 10)
  812. def test_this_inst(self):
  813. class Foo(HasTraits):
  814. this = This()
  815. f = Foo()
  816. f.this = Foo()
  817. self.assertTrue(isinstance(f.this, Foo))
  818. def test_subclass(self):
  819. class Foo(HasTraits):
  820. t = This()
  821. class Bar(Foo):
  822. pass
  823. f = Foo()
  824. b = Bar()
  825. f.t = b
  826. b.t = f
  827. self.assertEqual(f.t, b)
  828. self.assertEqual(b.t, f)
  829. def test_subclass_override(self):
  830. class Foo(HasTraits):
  831. t = This()
  832. class Bar(Foo):
  833. t = This()
  834. f = Foo()
  835. b = Bar()
  836. f.t = b
  837. self.assertEqual(f.t, b)
  838. self.assertRaises(TraitError, setattr, b, 't', f)
  839. def test_this_in_container(self):
  840. class Tree(HasTraits):
  841. value = Unicode()
  842. leaves = List(This())
  843. tree = Tree(
  844. value='foo',
  845. leaves=[Tree(value='bar'), Tree(value='buzz')]
  846. )
  847. with self.assertRaises(TraitError):
  848. tree.leaves = [1, 2]
  849. class TraitTestBase(TestCase):
  850. """A best testing class for basic trait types."""
  851. def assign(self, value):
  852. self.obj.value = value
  853. def coerce(self, value):
  854. return value
  855. def test_good_values(self):
  856. if hasattr(self, '_good_values'):
  857. for value in self._good_values:
  858. self.assign(value)
  859. self.assertEqual(self.obj.value, self.coerce(value))
  860. def test_bad_values(self):
  861. if hasattr(self, '_bad_values'):
  862. for value in self._bad_values:
  863. try:
  864. self.assertRaises(TraitError, self.assign, value)
  865. except AssertionError:
  866. assert False, value
  867. def test_default_value(self):
  868. if hasattr(self, '_default_value'):
  869. self.assertEqual(self._default_value, self.obj.value)
  870. def test_allow_none(self):
  871. if (hasattr(self, '_bad_values') and hasattr(self, '_good_values') and
  872. None in self._bad_values):
  873. trait=self.obj.traits()['value']
  874. try:
  875. trait.allow_none = True
  876. self._bad_values.remove(None)
  877. #skip coerce. Allow None casts None to None.
  878. self.assign(None)
  879. self.assertEqual(self.obj.value,None)
  880. self.test_good_values()
  881. self.test_bad_values()
  882. finally:
  883. #tear down
  884. trait.allow_none = False
  885. self._bad_values.append(None)
  886. def tearDown(self):
  887. # restore default value after tests, if set
  888. if hasattr(self, '_default_value'):
  889. self.obj.value = self._default_value
  890. class AnyTrait(HasTraits):
  891. value = Any()
  892. class AnyTraitTest(TraitTestBase):
  893. obj = AnyTrait()
  894. _default_value = None
  895. _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
  896. _bad_values = []
  897. class UnionTrait(HasTraits):
  898. value = Union([Type(), Bool()])
  899. class UnionTraitTest(TraitTestBase):
  900. obj = UnionTrait(value='ipython_genutils.ipstruct.Struct')
  901. _good_values = [int, float, True]
  902. _bad_values = [[], (0,), 1j]
  903. class OrTrait(HasTraits):
  904. value = Bool() | Unicode()
  905. class OrTraitTest(TraitTestBase):
  906. obj = OrTrait()
  907. _good_values = [True, False, 'ten']
  908. _bad_values = [[], (0,), 1j]
  909. class IntTrait(HasTraits):
  910. value = Int(99, min=-100)
  911. class TestInt(TraitTestBase):
  912. obj = IntTrait()
  913. _default_value = 99
  914. _good_values = [10, -10]
  915. _bad_values = ['ten', u'ten', [10], {'ten': 10}, (10,), None, 1j,
  916. 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
  917. u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', -200]
  918. if not six.PY3:
  919. _bad_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
  920. class CIntTrait(HasTraits):
  921. value = CInt('5')
  922. class TestCInt(TraitTestBase):
  923. obj = CIntTrait()
  924. _default_value = 5
  925. _good_values = ['10', '-10', u'10', u'-10', 10, 10.0, -10.0, 10.1]
  926. _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,),
  927. None, 1j, '10.1', u'10.1']
  928. def coerce(self, n):
  929. return int(n)
  930. class MinBoundCIntTrait(HasTraits):
  931. value = CInt('5', min=3)
  932. class TestMinBoundCInt(TestCInt):
  933. obj = MinBoundCIntTrait()
  934. _default_value = 5
  935. _good_values = [3, 3.0, '3']
  936. _bad_values = [2.6, 2, -3, -3.0]
  937. class LongTrait(HasTraits):
  938. value = Long(99 if six.PY3 else long(99))
  939. class TestLong(TraitTestBase):
  940. obj = LongTrait()
  941. _default_value = 99 if six.PY3 else long(99)
  942. _good_values = [10, -10]
  943. _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,),
  944. None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
  945. '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
  946. u'-10.1']
  947. if not six.PY3:
  948. # maxint undefined on py3, because int == long
  949. _good_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
  950. _bad_values.extend([[long(10)], (long(10),)])
  951. @mark.skipif(six.PY3, reason="not relevant on py3")
  952. def test_cast_small(self):
  953. """Long casts ints to long"""
  954. self.obj.value = 10
  955. self.assertEqual(type(self.obj.value), long)
  956. class MinBoundLongTrait(HasTraits):
  957. value = Long(99 if six.PY3 else long(99), min=5)
  958. class TestMinBoundLong(TraitTestBase):
  959. obj = MinBoundLongTrait()
  960. _default_value = 99 if six.PY3 else long(99)
  961. _good_values = [5, 10]
  962. _bad_values = [4, -10]
  963. class MaxBoundLongTrait(HasTraits):
  964. value = Long(5 if six.PY3 else long(5), max=10)
  965. class TestMaxBoundLong(TraitTestBase):
  966. obj = MaxBoundLongTrait()
  967. _default_value = 5 if six.PY3 else long(5)
  968. _good_values = [10, -2]
  969. _bad_values = [11, 20]
  970. class CLongTrait(HasTraits):
  971. value = CLong('5')
  972. class TestCLong(TraitTestBase):
  973. obj = CLongTrait()
  974. _default_value = 5 if six.PY3 else long(5)
  975. _good_values = ['10', '-10', u'10', u'-10', 10, 10.0, -10.0, 10.1]
  976. _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,),
  977. None, 1j, '10.1', u'10.1']
  978. def coerce(self, n):
  979. return int(n) if six.PY3 else long(n)
  980. class MaxBoundCLongTrait(HasTraits):
  981. value = CLong('5', max=10)
  982. class TestMaxBoundCLong(TestCLong):
  983. obj = MaxBoundCLongTrait()
  984. _default_value = 5 if six.PY3 else long(5)
  985. _good_values = [10, '10', 10.3]
  986. _bad_values = [11.0, '11']
  987. class IntegerTrait(HasTraits):
  988. value = Integer(1)
  989. class TestInteger(TestLong):
  990. obj = IntegerTrait()
  991. _default_value = 1
  992. def coerce(self, n):
  993. return int(n)
  994. @mark.skipif(six.PY3, reason="not relevant on py3")
  995. def test_cast_small(self):
  996. """Integer casts small longs to int"""
  997. self.obj.value = long(100)
  998. self.assertEqual(type(self.obj.value), int)
  999. class MinBoundIntegerTrait(HasTraits):
  1000. value = Integer(5, min=3)
  1001. class TestMinBoundInteger(TraitTestBase):
  1002. obj = MinBoundIntegerTrait()
  1003. _default_value = 5
  1004. _good_values = 3, 20
  1005. _bad_values = [2, -10]
  1006. class MaxBoundIntegerTrait(HasTraits):
  1007. value = Integer(1, max=3)
  1008. class TestMaxBoundInteger(TraitTestBase):
  1009. obj = MaxBoundIntegerTrait()
  1010. _default_value = 1
  1011. _good_values = 3, -2
  1012. _bad_values = [4, 10]
  1013. class FloatTrait(HasTraits):
  1014. value = Float(99.0, max=200.0)
  1015. class TestFloat(TraitTestBase):
  1016. obj = FloatTrait()
  1017. _default_value = 99.0
  1018. _good_values = [10, -10, 10.1, -10.1]
  1019. _bad_values = ['ten', u'ten', [10], {'ten': 10}, (10,), None,
  1020. 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
  1021. u'-10', u'10L', u'-10L', u'10.1', u'-10.1', 201.0]
  1022. if not six.PY3:
  1023. _bad_values.extend([long(10), long(-10)])
  1024. class CFloatTrait(HasTraits):
  1025. value = CFloat('99.0', max=200.0)
  1026. class TestCFloat(TraitTestBase):
  1027. obj = CFloatTrait()
  1028. _default_value = 99.0
  1029. _good_values = [10, 10.0, 10.5, '10.0', '10', '-10', '10.0', u'10']
  1030. _bad_values = ['ten', u'ten', [10], {'ten': 10}, (10,), None, 1j,
  1031. 200.1, '200.1']
  1032. def coerce(self, v):
  1033. return float(v)
  1034. class ComplexTrait(HasTraits):
  1035. value = Complex(99.0-99.0j)
  1036. class TestComplex(TraitTestBase):
  1037. obj = ComplexTrait()
  1038. _default_value = 99.0-99.0j
  1039. _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
  1040. 10.1j, 10.1+10.1j, 10.1-10.1j]
  1041. _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
  1042. if not six.PY3:
  1043. _bad_values.extend([long(10), long(-10)])
  1044. class BytesTrait(HasTraits):
  1045. value = Bytes(b'string')
  1046. class TestBytes(TraitTestBase):
  1047. obj = BytesTrait()
  1048. _default_value = b'string'
  1049. _good_values = [b'10', b'-10', b'10L',
  1050. b'-10L', b'10.1', b'-10.1', b'string']
  1051. _bad_values = [10, -10, 10.1, -10.1, 1j, [10],
  1052. ['ten'],{'ten': 10},(10,), None, u'string']
  1053. if not six.PY3:
  1054. _bad_values.extend([long(10), long(-10)])
  1055. class UnicodeTrait(HasTraits):
  1056. value = Unicode(u'unicode')
  1057. class TestUnicode(TraitTestBase):
  1058. obj = UnicodeTrait()
  1059. _default_value = u'unicode'
  1060. _good_values = ['10', '-10', '10L', '-10L', '10.1',
  1061. '-10.1', '', u'', 'string', u'string', u"€"]
  1062. _bad_values = [10, -10, 10.1, -10.1, 1j,
  1063. [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
  1064. if not six.PY3:
  1065. _bad_values.extend([long(10), long(-10)])
  1066. class ObjectNameTrait(HasTraits):
  1067. value = ObjectName("abc")
  1068. class TestObjectName(TraitTestBase):
  1069. obj = ObjectNameTrait()
  1070. _default_value = "abc"
  1071. _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
  1072. _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
  1073. None, object(), object]
  1074. if sys.version_info[0] < 3:
  1075. _bad_values.append(u"þ")
  1076. else:
  1077. _good_values.append(u"þ") # þ=1 is valid in Python 3 (PEP 3131).
  1078. class DottedObjectNameTrait(HasTraits):
  1079. value = DottedObjectName("a.b")
  1080. class TestDottedObjectName(TraitTestBase):
  1081. obj = DottedObjectNameTrait()
  1082. _default_value = "a.b"
  1083. _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
  1084. _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None]
  1085. if sys.version_info[0] < 3:
  1086. _bad_values.append(u"t.þ")
  1087. else:
  1088. _good_values.append(u"t.þ")
  1089. class TCPAddressTrait(HasTraits):
  1090. value = TCPAddress()
  1091. class TestTCPAddress(TraitTestBase):
  1092. obj = TCPAddressTrait()
  1093. _default_value = ('127.0.0.1',0)
  1094. _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
  1095. _bad_values = [(0,0),('localhost',10.0),('localhost',-1), None]
  1096. class ListTrait(HasTraits):
  1097. value = List(Int())
  1098. class TestList(TraitTestBase):
  1099. obj = ListTrait()
  1100. _default_value = []
  1101. _good_values = [[], [1], list(range(10)), (1,2)]
  1102. _bad_values = [10, [1,'a'], 'a']
  1103. def coerce(self, value):
  1104. if value is not None:
  1105. value = list(value)
  1106. return value
  1107. class Foo(object):
  1108. pass
  1109. class NoneInstanceListTrait(HasTraits):
  1110. value = List(Instance(Foo))
  1111. class TestNoneInstanceList(TraitTestBase):
  1112. obj = NoneInstanceListTrait()
  1113. _default_value = []
  1114. _good_values = [[Foo(), Foo()], []]
  1115. _bad_values = [[None], [Foo(), None]]
  1116. class InstanceListTrait(HasTraits):
  1117. value = List(Instance(__name__+'.Foo'))
  1118. class TestInstanceList(TraitTestBase):
  1119. obj = InstanceListTrait()
  1120. def test_klass(self):
  1121. """Test that the instance klass is properly assigned."""
  1122. self.assertIs(self.obj.traits()['value']._trait.klass, Foo)
  1123. _default_value = []
  1124. _good_values = [[Foo(), Foo()], []]
  1125. _bad_values = [['1', 2,], '1', [Foo], None]
  1126. class UnionListTrait(HasTraits):
  1127. value = List(Int() | Bool())
  1128. class TestUnionListTrait(HasTraits):
  1129. obj = UnionListTrait()
  1130. _default_value = []
  1131. _good_values = [[True, 1], [False, True]]
  1132. _bad_values = [[1, 'True'], False]
  1133. class LenListTrait(HasTraits):
  1134. value = List(Int(), [0], minlen=1, maxlen=2)
  1135. class TestLenList(TraitTestBase):
  1136. obj = LenListTrait()
  1137. _default_value = [0]
  1138. _good_values = [[1], [1,2], (1,2)]
  1139. _bad_values = [10, [1,'a'], 'a', [], list(range(3))]
  1140. def coerce(self, value):
  1141. if value is not None:
  1142. value = list(value)
  1143. return value
  1144. class TupleTrait(HasTraits):
  1145. value = Tuple(Int(allow_none=True), default_value=(1,))
  1146. class TestTupleTrait(TraitTestBase):
  1147. obj = TupleTrait()
  1148. _default_value = (1,)
  1149. _good_values = [(1,), (0,), [1]]
  1150. _bad_values = [10, (1, 2), ('a'), (), None]
  1151. def coerce(self, value):
  1152. if value is not None:
  1153. value = tuple(value)
  1154. return value
  1155. def test_invalid_args(self):
  1156. self.assertRaises(TypeError, Tuple, 5)
  1157. self.assertRaises(TypeError, Tuple, default_value='hello')
  1158. t = Tuple(Int(), CBytes(), default_value=(1,5))
  1159. class LooseTupleTrait(HasTraits):
  1160. value = Tuple((1,2,3))
  1161. class TestLooseTupleTrait(TraitTestBase):
  1162. obj = LooseTupleTrait()
  1163. _default_value = (1,2,3)
  1164. _good_values = [(1,), [1], (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
  1165. _bad_values = [10, 'hello', {}, None]
  1166. def coerce(self, value):
  1167. if value is not None:
  1168. value = tuple(value)
  1169. return value
  1170. def test_invalid_args(self):
  1171. self.assertRaises(TypeError, Tuple, 5)
  1172. self.assertRaises(TypeError, Tuple, default_value='hello')
  1173. t = Tuple(Int(), CBytes(), default_value=(1,5))
  1174. class MultiTupleTrait(HasTraits):
  1175. value = Tuple(Int(), Bytes(), default_value=[99,b'bottles'])
  1176. class TestMultiTuple(TraitTestBase):
  1177. obj = MultiTupleTrait()
  1178. _default_value = (99,b'bottles')
  1179. _good_values = [(1,b'a'), (2,b'b')]
  1180. _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
  1181. class CRegExpTrait(HasTraits):
  1182. value = CRegExp(r'')
  1183. class TestCRegExp(TraitTestBase):
  1184. def coerce(self, value):
  1185. return re.compile(value)
  1186. obj = CRegExpTrait()
  1187. _default_value = re.compile(r'')
  1188. _good_values = [r'\d+', re.compile(r'\d+')]
  1189. _bad_values = ['(', None, ()]
  1190. class DictTrait(HasTraits):
  1191. value = Dict()
  1192. def test_dict_assignment():
  1193. d = dict()
  1194. c = DictTrait()
  1195. c.value = d
  1196. d['a'] = 5
  1197. assert d == c.value
  1198. assert c.value is d
  1199. class UniformlyValidatedDictTrait(HasTraits):
  1200. value = Dict(trait=Unicode(),
  1201. default_value={'foo': '1'})
  1202. class TestInstanceUniformlyValidatedDict(TraitTestBase):
  1203. obj = UniformlyValidatedDictTrait()
  1204. _default_value = {'foo': '1'}
  1205. _good_values = [{'foo': '0', 'bar': '1'}]
  1206. _bad_values = [{'foo': 0, 'bar': '1'}]
  1207. class KeyValidatedDictTrait(HasTraits):
  1208. value = Dict(traits={'foo': Int()},
  1209. default_value={'foo': 1})
  1210. class TestInstanceKeyValidatedDict(TraitTestBase):
  1211. obj = KeyValidatedDictTrait()
  1212. _default_value = {'foo': 1}
  1213. _good_values = [{'foo': 0, 'bar': '1'}, {'foo': 0, 'bar': 1}]
  1214. _bad_values = [{'foo': '0', 'bar': '1'}]
  1215. class FullyValidatedDictTrait(HasTraits):
  1216. value = Dict(trait=Unicode(),
  1217. traits={'foo': Int()},
  1218. default_value={'foo': 1})
  1219. class TestInstanceFullyValidatedDict(TraitTestBase):
  1220. obj = FullyValidatedDictTrait()
  1221. _default_value = {'foo': 1}
  1222. _good_values = [{'foo': 0, 'bar': '1'}, {'foo': 1, 'bar': '2'}]
  1223. _bad_values = [{'foo': 0, 'bar': 1}, {'foo': '0', 'bar': '1'}]
  1224. def test_dict_default_value():
  1225. """Check that the `{}` default value of the Dict traitlet constructor is
  1226. actually copied."""
  1227. class Foo(HasTraits):
  1228. d1 = Dict()
  1229. d2 = Dict()
  1230. foo = Foo()
  1231. assert foo.d1 == {}
  1232. assert foo.d2 == {}
  1233. assert foo.d1 is not foo.d2
  1234. class TestValidationHook(TestCase):
  1235. def test_parity_trait(self):
  1236. """Verify that the early validation hook is effective"""
  1237. class Parity(HasTraits):
  1238. value = Int(0)
  1239. parity = Enum(['odd', 'even'], default_value='even')
  1240. @validate('value')
  1241. def _value_validate(self, proposal):
  1242. value = proposal['value']
  1243. if self.parity == 'even' and value % 2:
  1244. raise TraitError('Expected an even number')
  1245. if self.parity == 'odd' and (value % 2 == 0):
  1246. raise TraitError('Expected an odd number')
  1247. return value
  1248. u = Parity()
  1249. u.parity = 'odd'
  1250. u.value = 1 # OK
  1251. with self.assertRaises(TraitError):
  1252. u.value = 2 # Trait Error
  1253. u.parity = 'even'
  1254. u.value = 2 # OK
  1255. def test_multiple_validate(self):
  1256. """Verify that we can register the same validator to multiple names"""
  1257. class OddEven(HasTraits):
  1258. odd = Int(1)
  1259. even = Int(0)
  1260. @validate('odd', 'even')
  1261. def check_valid(self, proposal):
  1262. if proposal['trait'].name == 'odd' and not proposal['value'] % 2:
  1263. raise TraitError('odd should be odd')
  1264. if proposal['trait'].name == 'even' and proposal['value'] % 2:
  1265. raise TraitError('even should be even')
  1266. u = OddEven()
  1267. u.odd = 3 # OK
  1268. with self.assertRaises(TraitError):
  1269. u.odd = 2 # Trait Error
  1270. u.even = 2 # OK
  1271. with self.assertRaises(TraitError):
  1272. u.even = 3 # Trait Error
  1273. class TestLink(TestCase):
  1274. def test_connect_same(self):
  1275. """Verify two traitlets of the same type can be linked together using link."""
  1276. # Create two simple classes with Int traitlets.
  1277. class A(HasTraits):
  1278. value = Int()
  1279. a = A(value=9)
  1280. b = A(value=8)
  1281. # Conenct the two classes.
  1282. c = link((a, 'value'), (b, 'value'))
  1283. # Make sure the values are the same at the point of linking.
  1284. self.assertEqual(a.value, b.value)
  1285. # Change one of the values to make sure they stay in sync.
  1286. a.value = 5
  1287. self.assertEqual(a.value, b.value)
  1288. b.value = 6
  1289. self.assertEqual(a.value, b.value)
  1290. def test_link_different(self):
  1291. """Verify two traitlets of different types can be linked together using link."""
  1292. # Create two simple classes with Int traitlets.
  1293. class A(HasTraits):
  1294. value = Int()
  1295. class B(HasTraits):
  1296. count = Int()
  1297. a = A(value=9)
  1298. b = B(count=8)
  1299. # Conenct the two classes.
  1300. c = link((a, 'value'), (b, 'count'))
  1301. # Make sure the values are the same at the point of linking.
  1302. self.assertEqual(a.value, b.count)
  1303. # Change one of the values to make sure they stay in sync.
  1304. a.value = 5
  1305. self.assertEqual(a.value, b.count)
  1306. b.count = 4
  1307. self.assertEqual(a.value, b.count)
  1308. def test_unlink(self):
  1309. """Verify two linked traitlets can be unlinked."""
  1310. # Create two simple classes with Int traitlets.
  1311. class A(HasTraits):
  1312. value = Int()
  1313. a = A(value=9)
  1314. b = A(value=8)
  1315. # Connect the two classes.
  1316. c = link((a, 'value'), (b, 'value'))
  1317. a.value = 4
  1318. c.unlink()
  1319. # Change one of the values to make sure they don't stay in sync.
  1320. a.value = 5
  1321. self.assertNotEqual(a.value, b.value)
  1322. def test_callbacks(self):
  1323. """Verify two linked traitlets have their callbacks called once."""
  1324. # Create two simple classes with Int traitlets.
  1325. class A(HasTraits):
  1326. value = Int()
  1327. class B(HasTraits):
  1328. count = Int()
  1329. a = A(value=9)
  1330. b = B(count=8)
  1331. # Register callbacks that count.
  1332. callback_count = []
  1333. def a_callback(name, old, new):
  1334. callback_count.append('a')
  1335. a.on_trait_change(a_callback, 'value')
  1336. def b_callback(name, old, new):
  1337. callback_count.append('b')
  1338. b.on_trait_change(b_callback, 'count')
  1339. # Connect the two classes.
  1340. c = link((a, 'value'), (b, 'count'))
  1341. # Make sure b's count was set to a's value once.
  1342. self.assertEqual(''.join(callback_count), 'b')
  1343. del callback_count[:]
  1344. # Make sure a's value was set to b's count once.
  1345. b.count = 5
  1346. self.assertEqual(''.join(callback_count), 'ba')
  1347. del callback_count[:]
  1348. # Make sure b's count was set to a's value once.
  1349. a.value = 4
  1350. self.assertEqual(''.join(callback_count), 'ab')
  1351. del callback_count[:]
  1352. class TestDirectionalLink(TestCase):
  1353. def test_connect_same(self):
  1354. """Verify two traitlets of the same type can be linked together using directional_link."""
  1355. # Create two simple classes with Int traitlets.
  1356. class A(HasTraits):
  1357. value = Int()
  1358. a = A(value=9)
  1359. b = A(value=8)
  1360. # Conenct the two classes.
  1361. c = directional_link((a, 'value'), (b, 'value'))
  1362. # Make sure the values are the same at the point of linking.
  1363. self.assertEqual(a.value, b.value)
  1364. # Change one the value of the source and check that it synchronizes the target.
  1365. a.value = 5
  1366. self.assertEqual(b.value, 5)
  1367. # Change one the value of the target and check that it has no impact on the source
  1368. b.value = 6
  1369. self.assertEqual(a.value, 5)
  1370. def test_tranform(self):
  1371. """Test transform link."""
  1372. # Create two simple classes with Int traitlets.
  1373. class A(HasTraits):
  1374. value = Int()
  1375. a = A(value=9)
  1376. b = A(value=8)
  1377. # Conenct the two classes.
  1378. c = directional_link((a, 'value'), (b, 'value'), lambda x: 2 * x)
  1379. # Make sure the values are correct at the point of linking.
  1380. self.assertEqual(b.value, 2 * a.value)
  1381. # Change one the value of the source and check that it modifies the target.
  1382. a.value = 5
  1383. self.assertEqual(b.value, 10)
  1384. # Change one the value of the target and check that it has no impact on the source
  1385. b.value = 6
  1386. self.assertEqual(a.value, 5)
  1387. def test_link_different(self):
  1388. """Verify two traitlets of different types can be linked together using link."""
  1389. # Create two simple classes with Int traitlets.
  1390. class A(HasTraits):
  1391. value = Int()
  1392. class B(HasTraits):
  1393. count = Int()
  1394. a = A(value=9)
  1395. b = B(count=8)
  1396. # Conenct the two classes.
  1397. c = directional_link((a, 'value'), (b, 'count'))
  1398. # Make sure the values are the same at the point of linking.
  1399. self.assertEqual(a.value, b.count)
  1400. # Change one the value of the source and check that it synchronizes the target.
  1401. a.value = 5
  1402. self.assertEqual(b.count, 5)
  1403. # Change one the value of the target and check that it has no impact on the source
  1404. b.value = 6
  1405. self.assertEqual(a.value, 5)
  1406. def test_unlink(self):
  1407. """Verify two linked traitlets can be unlinked."""
  1408. # Create two simple classes with Int traitlets.
  1409. class A(HasTraits):
  1410. value = Int()
  1411. a = A(value=9)
  1412. b = A(value=8)
  1413. # Connect the two classes.
  1414. c = directional_link((a, 'value'), (b, 'value'))
  1415. a.value = 4
  1416. c.unlink()
  1417. # Change one of the values to make sure they don't stay in sync.
  1418. a.value = 5
  1419. self.assertNotEqual(a.value, b.value)
class Pickleable(HasTraits):
    """HasTraits subclass with observers and a validator, pickled by
    ``test_pickle_hastraits`` to check instances survive a round-trip."""

    i = Int()
    @observe('i')
    def _i_changed(self, change): pass
    @validate('i')
    def _i_validate(self, commit):
        # Cross-validator that accepts any proposed value unchanged.
        return commit['value']

    j = Int()

    def __init__(self):
        # Set ``i`` while notifications are held, then additionally register
        # ``_i_changed`` via the legacy on_trait_change API. Presumably this is
        # meant to give the instance handler state that pickling must cope with.
        with self.hold_trait_notifications():
            self.i = 1
        self.on_trait_change(self._i_changed, 'i')
  1432. def test_pickle_hastraits():
  1433. c = Pickleable()
  1434. for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
  1435. p = pickle.dumps(c, protocol)
  1436. c2 = pickle.loads(p)
  1437. assert c2.i == c.i
  1438. assert c2.j == c.j
  1439. c.i = 5
  1440. for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
  1441. p = pickle.dumps(c, protocol)
  1442. c2 = pickle.loads(p)
  1443. assert c2.i == c.i
  1444. assert c2.j == c.j
  1445. def test_hold_trait_notifications():
  1446. changes = []
  1447. class Test(HasTraits):
  1448. a = Integer(0)
  1449. b = Integer(0)
  1450. def _a_changed(self, name, old, new):
  1451. changes.append((old, new))
  1452. def _b_validate(self, value, trait):
  1453. if value != 0:
  1454. raise TraitError('Only 0 is a valid value')
  1455. return value
  1456. # Test context manager and nesting
  1457. t = Test()
  1458. with t.hold_trait_notifications():
  1459. with t.hold_trait_notifications():
  1460. t.a = 1
  1461. assert t.a == 1
  1462. assert changes == []
  1463. t.a = 2
  1464. assert t.a == 2
  1465. with t.hold_trait_notifications():
  1466. t.a = 3
  1467. assert t.a == 3
  1468. assert changes == []
  1469. t.a = 4
  1470. assert t.a == 4
  1471. assert changes == []
  1472. t.a = 4
  1473. assert t.a == 4
  1474. assert changes == []
  1475. assert changes == [(0, 4)]
  1476. # Test roll-back
  1477. try:
  1478. with t.hold_trait_notifications():
  1479. t.b = 1 # raises a Trait error
  1480. except:
  1481. pass
  1482. assert t.b == 0
  1483. class RollBack(HasTraits):
  1484. bar = Int()
  1485. def _bar_validate(self, value, trait):
  1486. if value:
  1487. raise TraitError('foobar')
  1488. return value
  1489. class TestRollback(TestCase):
  1490. def test_roll_back(self):
  1491. def assign_rollback():
  1492. RollBack(bar=1)
  1493. self.assertRaises(TraitError, assign_rollback)
  1494. class CacheModification(HasTraits):
  1495. foo = Int()
  1496. bar = Int()
  1497. def _bar_validate(self, value, trait):
  1498. self.foo = value
  1499. return value
  1500. def _foo_validate(self, value, trait):
  1501. self.bar = value
  1502. return value
  1503. def test_cache_modification():
  1504. CacheModification(foo=1)
  1505. CacheModification(bar=1)
  1506. class OrderTraits(HasTraits):
  1507. notified = Dict()
  1508. a = Unicode()
  1509. b = Unicode()
  1510. c = Unicode()
  1511. d = Unicode()
  1512. e = Unicode()
  1513. f = Unicode()
  1514. g = Unicode()
  1515. h = Unicode()
  1516. i = Unicode()
  1517. j = Unicode()
  1518. k = Unicode()
  1519. l = Unicode()
  1520. def _notify(self, name, old, new):
  1521. """check the value of all traits when each trait change is triggered
  1522. This verifies that the values are not sensitive
  1523. to dict ordering when loaded from kwargs
  1524. """
  1525. # check the value of the other traits
  1526. # when a given trait change notification fires
  1527. self.notified[name] = {
  1528. c: getattr(self, c) for c in 'abcdefghijkl'
  1529. }
  1530. def __init__(self, **kwargs):
  1531. self.on_trait_change(self._notify)
  1532. super(OrderTraits, self).__init__(**kwargs)
  1533. def test_notification_order():
  1534. d = {c:c for c in 'abcdefghijkl'}
  1535. obj = OrderTraits()
  1536. assert obj.notified == {}
  1537. obj = OrderTraits(**d)
  1538. notifications = {
  1539. c: d for c in 'abcdefghijkl'
  1540. }
  1541. assert obj.notified == notifications
  1542. ###
  1543. # Traits for Forward Declaration Tests
  1544. ###
  1545. class ForwardDeclaredInstanceTrait(HasTraits):
  1546. value = ForwardDeclaredInstance('ForwardDeclaredBar', allow_none=True)
  1547. class ForwardDeclaredTypeTrait(HasTraits):
  1548. value = ForwardDeclaredType('ForwardDeclaredBar', allow_none=True)
  1549. class ForwardDeclaredInstanceListTrait(HasTraits):
  1550. value = List(ForwardDeclaredInstance('ForwardDeclaredBar'))
  1551. class ForwardDeclaredTypeListTrait(HasTraits):
  1552. value = List(ForwardDeclaredType('ForwardDeclaredBar'))
  1553. ###
  1554. # End Traits for Forward Declaration Tests
  1555. ###
  1556. ###
  1557. # Classes for Forward Declaration Tests
  1558. ###
  1559. class ForwardDeclaredBar(object):
  1560. pass
  1561. class ForwardDeclaredBarSub(ForwardDeclaredBar):
  1562. pass
  1563. ###
  1564. # End Classes for Forward Declaration Tests
  1565. ###
  1566. ###
  1567. # Forward Declaration Tests
  1568. ###
  1569. class TestForwardDeclaredInstanceTrait(TraitTestBase):
  1570. obj = ForwardDeclaredInstanceTrait()
  1571. _default_value = None
  1572. _good_values = [None, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
  1573. _bad_values = ['foo', 3, ForwardDeclaredBar, ForwardDeclaredBarSub]
  1574. class TestForwardDeclaredTypeTrait(TraitTestBase):
  1575. obj = ForwardDeclaredTypeTrait()
  1576. _default_value = None
  1577. _good_values = [None, ForwardDeclaredBar, ForwardDeclaredBarSub]
  1578. _bad_values = ['foo', 3, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
  1579. class TestForwardDeclaredInstanceList(TraitTestBase):
  1580. obj = ForwardDeclaredInstanceListTrait()
  1581. def test_klass(self):
  1582. """Test that the instance klass is properly assigned."""
  1583. self.assertIs(self.obj.traits()['value']._trait.klass, ForwardDeclaredBar)
  1584. _default_value = []
  1585. _good_values = [
  1586. [ForwardDeclaredBar(), ForwardDeclaredBarSub()],
  1587. [],
  1588. ]
  1589. _bad_values = [
  1590. ForwardDeclaredBar(),
  1591. [ForwardDeclaredBar(), 3, None],
  1592. '1',
  1593. # Note that this is the type, not an instance.
  1594. [ForwardDeclaredBar],
  1595. [None],
  1596. None,
  1597. ]
  1598. class TestForwardDeclaredTypeList(TraitTestBase):
  1599. obj = ForwardDeclaredTypeListTrait()
  1600. def test_klass(self):
  1601. """Test that the instance klass is properly assigned."""
  1602. self.assertIs(self.obj.traits()['value']._trait.klass, ForwardDeclaredBar)
  1603. _default_value = []
  1604. _good_values = [
  1605. [ForwardDeclaredBar, ForwardDeclaredBarSub],
  1606. [],
  1607. ]
  1608. _bad_values = [
  1609. ForwardDeclaredBar,
  1610. [ForwardDeclaredBar, 3],
  1611. '1',
  1612. # Note that this is an instance, not the type.
  1613. [ForwardDeclaredBar()],
  1614. [None],
  1615. None,
  1616. ]
  1617. ###
  1618. # End Forward Declaration Tests
  1619. ###
  1620. class TestDynamicTraits(TestCase):
  1621. def setUp(self):
  1622. self._notify1 = []
  1623. def notify1(self, name, old, new):
  1624. self._notify1.append((name, old, new))
  1625. def test_notify_all(self):
  1626. class A(HasTraits):
  1627. pass
  1628. a = A()
  1629. self.assertTrue(not hasattr(a, 'x'))
  1630. self.assertTrue(not hasattr(a, 'y'))
  1631. # Dynamically add trait x.
  1632. a.add_traits(x=Int())
  1633. self.assertTrue(hasattr(a, 'x'))
  1634. self.assertTrue(isinstance(a, (A, )))
  1635. # Dynamically add trait y.
  1636. a.add_traits(y=Float())
  1637. self.assertTrue(hasattr(a, 'y'))
  1638. self.assertTrue(isinstance(a, (A, )))
  1639. self.assertEqual(a.__class__.__name__, A.__name__)
  1640. # Create a new instance and verify that x and y
  1641. # aren't defined.
  1642. b = A()
  1643. self.assertTrue(not hasattr(b, 'x'))
  1644. self.assertTrue(not hasattr(b, 'y'))
  1645. # Verify that notification works like normal.
  1646. a.on_trait_change(self.notify1)
  1647. a.x = 0
  1648. self.assertEqual(len(self._notify1), 0)
  1649. a.y = 0.0
  1650. self.assertEqual(len(self._notify1), 0)
  1651. a.x = 10
  1652. self.assertTrue(('x', 0, 10) in self._notify1)
  1653. a.y = 10.0
  1654. self.assertTrue(('y', 0.0, 10.0) in self._notify1)
  1655. self.assertRaises(TraitError, setattr, a, 'x', 'bad string')
  1656. self.assertRaises(TraitError, setattr, a, 'y', 'bad string')
  1657. self._notify1 = []
  1658. a.on_trait_change(self.notify1, remove=True)
  1659. a.x = 20
  1660. a.y = 20.0
  1661. self.assertEqual(len(self._notify1), 0)
  1662. def test_enum_no_default():
  1663. class C(HasTraits):
  1664. t = Enum(['a', 'b'])
  1665. c = C()
  1666. c.t = 'a'
  1667. assert c.t == 'a'
  1668. c = C()
  1669. with pytest.raises(TraitError):
  1670. t = c.t
  1671. c = C(t='b')
  1672. assert c.t == 'b'
  1673. def test_default_value_repr():
  1674. class C(HasTraits):
  1675. t = Type('traitlets.HasTraits')
  1676. t2 = Type(HasTraits)
  1677. n = Integer(0)
  1678. lis = List()
  1679. d = Dict()
  1680. assert C.t.default_value_repr() == "'traitlets.HasTraits'"
  1681. assert C.t2.default_value_repr() == "'traitlets.traitlets.HasTraits'"
  1682. assert C.n.default_value_repr() == '0'
  1683. assert C.lis.default_value_repr() == '[]'
  1684. assert C.d.default_value_repr() == '{}'
class TransitionalClass(HasTraits):
    """Parent class for the old/new handler API compatibility test.

    Handlers are registered with the decorator API (``@default``,
    ``@observe`` + ``@observe_compat``); ``SubClass`` overrides them using
    legacy magic-name methods.
    """

    d = Any()
    @default('d')
    def _d_default(self):
        return TransitionalClass

    # Records the argument passed to the parent's handler (False until it runs).
    parent_super = False
    calls_super = Integer(0)
    @default('calls_super')
    def _calls_super_default(self):
        return -1

    @observe('calls_super')
    @observe_compat
    def _calls_super_changed(self, change):
        # NOTE(review): observe_compat presumably lets this accept the legacy
        # (name, old, new) call that SubClass makes via super() -- confirm.
        self.parent_super = change

    parent_override = False
    overrides = Integer(0)
    @observe('overrides')
    @observe_compat
    def _overrides_changed(self, change):
        self.parent_override = change


class SubClass(TransitionalClass):
    """Overrides the parent's handlers using legacy method signatures."""

    def _d_default(self):
        # Legacy-style default override (no @default decorator).
        return SubClass

    subclass_super = False
    def _calls_super_changed(self, name, old, new):
        # Legacy-signature override that also delegates to the parent handler.
        self.subclass_super = True
        super(SubClass, self)._calls_super_changed(name, old, new)

    subclass_override = False
    def _overrides_changed(self, name, old, new):
        # Legacy-signature override that does NOT call the parent handler.
        self.subclass_override = True
  1715. def test_subclass_compat():
  1716. obj = SubClass()
  1717. obj.calls_super = 5
  1718. assert obj.parent_super
  1719. assert obj.subclass_super
  1720. obj.overrides = 5
  1721. assert obj.subclass_override
  1722. assert not obj.parent_override
  1723. assert obj.d is SubClass
class DefinesHandler(HasTraits):
    """Base class that registers ``handler`` as an observer of ``trait``."""
    parent_called = False

    trait = Integer()
    @observe('trait')
    def handler(self, change):
        self.parent_called = True


class OverridesHandler(DefinesHandler):
    """Overrides ``handler`` and re-registers it with ``@observe``."""
    child_called = False

    @observe('trait')
    def handler(self, change):
        self.child_called = True


def test_subclass_override_observer():
    """A re-decorated override runs instead of the parent's handler."""
    obj = OverridesHandler()
    obj.trait = 5
    assert obj.child_called
    assert not obj.parent_called


class DoesntRegisterHandler(DefinesHandler):
    """Overrides ``handler`` without decorating it with ``@observe``."""
    child_called = False

    def handler(self, change):
        self.child_called = True


def test_subclass_override_not_registered():
    """Subclass that overrides observer and doesn't re-register unregisters both"""
    obj = DoesntRegisterHandler()
    obj.trait = 5
    assert not obj.child_called
    assert not obj.parent_called


class AddsHandler(DefinesHandler):
    """Adds a second, differently named observer alongside the inherited one."""
    child_called = False

    @observe('trait')
    def child_handler(self, change):
        self.child_called = True


def test_subclass_add_observer():
    """An observer added in a subclass runs in addition to the inherited one."""
    obj = AddsHandler()
    obj.trait = 5
    assert obj.child_called
    assert obj.parent_called
  1760. def test_observe_iterables():
  1761. class C(HasTraits):
  1762. i = Integer()
  1763. s = Unicode()
  1764. c = C()
  1765. recorded = {}
  1766. def record(change):
  1767. recorded['change'] = change
  1768. # observe with names=set
  1769. c.observe(record, names={'i', 's'})
  1770. c.i = 5
  1771. assert recorded['change'].name == 'i'
  1772. assert recorded['change'].new == 5
  1773. c.s = 'hi'
  1774. assert recorded['change'].name == 's'
  1775. assert recorded['change'].new == 'hi'
  1776. # observe with names=custom container with iter, contains
  1777. class MyContainer(object):
  1778. def __init__(self, container):
  1779. self.container = container
  1780. def __iter__(self):
  1781. return iter(self.container)
  1782. def __contains__(self, key):
  1783. return key in self.container
  1784. c.observe(record, names=MyContainer({'i', 's'}))
  1785. c.i = 10
  1786. assert recorded['change'].name == 'i'
  1787. assert recorded['change'].new == 10
  1788. c.s = 'ok'
  1789. assert recorded['change'].name == 's'
  1790. assert recorded['change'].new == 'ok'
  1791. def test_super_args():
  1792. class SuperRecorder(object):
  1793. def __init__(self, *args, **kwargs):
  1794. self.super_args = args
  1795. self.super_kwargs = kwargs
  1796. class SuperHasTraits(HasTraits, SuperRecorder):
  1797. i = Integer()
  1798. obj = SuperHasTraits('a1', 'a2', b=10, i=5, c='x')
  1799. assert obj.i == 5
  1800. assert not hasattr(obj, 'b')
  1801. assert not hasattr(obj, 'c')
  1802. assert obj.super_args == ('a1' , 'a2')
  1803. assert obj.super_kwargs == {'b': 10 , 'c': 'x'}
  1804. def test_super_bad_args():
  1805. class SuperHasTraits(HasTraits):
  1806. a = Integer()
  1807. if sys.version_info < (3,):
  1808. # Legacy Python, object.__init__ warns itself, instead of raising
  1809. w = ['object.__init__']
  1810. else:
  1811. w = ["Passing unrecoginized arguments"]
  1812. with expected_warnings(w):
  1813. obj = SuperHasTraits(a=1, b=2)
  1814. assert obj.a == 1
  1815. assert not hasattr(obj, 'b')