"""Tests for traitlets.traitlets.""" # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. # # Adapted from enthought.traits, Copyright (c) Enthought, Inc., # also under the terms of the Modified BSD License. import pickle import re from unittest import TestCase import pytest from traitlets import ( All, Any, BaseDescriptor, Bool, Bytes, Callable, CBytes, CFloat, CInt, CLong, Complex, CRegExp, CUnicode, Dict, DottedObjectName, Enum, Float, ForwardDeclaredInstance, ForwardDeclaredType, HasDescriptors, HasTraits, Instance, Int, Integer, List, Long, MetaHasTraits, ObjectName, Set, TCPAddress, This, TraitError, TraitType, Tuple, Type, Undefined, Unicode, Union, default, directional_link, link, observe, observe_compat, traitlets, validate, ) from traitlets.utils import cast_unicode from ._warnings import expected_warnings def change_dict(*ordered_values): change_names = ("name", "old", "new", "owner", "type") return dict(zip(change_names, ordered_values)) # ----------------------------------------------------------------------------- # Helper classes for testing # ----------------------------------------------------------------------------- class HasTraitsStub(HasTraits): def notify_change(self, change): self._notify_name = change["name"] self._notify_old = change["old"] self._notify_new = change["new"] self._notify_type = change["type"] # ----------------------------------------------------------------------------- # Test classes # ----------------------------------------------------------------------------- class TestTraitType(TestCase): def test_get_undefined(self): class A(HasTraits): a = TraitType a = A() assert a.a is Undefined def test_set(self): class A(HasTraitsStub): a = TraitType a = A() a.a = 10 self.assertEqual(a.a, 10) self.assertEqual(a._notify_name, "a") self.assertEqual(a._notify_old, Undefined) self.assertEqual(a._notify_new, 10) def test_validate(self): class MyTT(TraitType): def validate(self, inst, 
value): return -1 class A(HasTraitsStub): tt = MyTT a = A() a.tt = 10 self.assertEqual(a.tt, -1) def test_default_validate(self): class MyIntTT(TraitType): def validate(self, obj, value): if isinstance(value, int): return value self.error(obj, value) class A(HasTraits): tt = MyIntTT(10) a = A() self.assertEqual(a.tt, 10) # Defaults are validated when the HasTraits is instantiated class B(HasTraits): tt = MyIntTT("bad default") self.assertRaises(TraitError, getattr, B(), "tt") def test_info(self): class A(HasTraits): tt = TraitType a = A() self.assertEqual(A.tt.info(), "any value") def test_error(self): class A(HasTraits): tt = TraitType() a = A() self.assertRaises(TraitError, A.tt.error, a, 10) def test_deprecated_dynamic_initializer(self): class A(HasTraits): x = Int(10) def _x_default(self): return 11 class B(A): x = Int(20) class C(A): def _x_default(self): return 21 a = A() self.assertEqual(a._trait_values, {}) self.assertEqual(a.x, 11) self.assertEqual(a._trait_values, {"x": 11}) b = B() self.assertEqual(b.x, 20) self.assertEqual(b._trait_values, {"x": 20}) c = C() self.assertEqual(c._trait_values, {}) self.assertEqual(c.x, 21) self.assertEqual(c._trait_values, {"x": 21}) # Ensure that the base class remains unmolested when the _default # initializer gets overridden in a subclass. 
a = A() c = C() self.assertEqual(a._trait_values, {}) self.assertEqual(a.x, 11) self.assertEqual(a._trait_values, {"x": 11}) def test_deprecated_method_warnings(self): with expected_warnings([]): class ShouldntWarn(HasTraits): x = Integer() @default("x") def _x_default(self): return 10 @validate("x") def _x_validate(self, proposal): return proposal.value @observe("x") def _x_changed(self, change): pass obj = ShouldntWarn() obj.x = 5 assert obj.x == 5 with expected_warnings(["@validate", "@observe"]) as w: class ShouldWarn(HasTraits): x = Integer() def _x_default(self): return 10 def _x_validate(self, value, _): return value def _x_changed(self): pass obj = ShouldWarn() obj.x = 5 assert obj.x == 5 def test_dynamic_initializer(self): class A(HasTraits): x = Int(10) @default("x") def _default_x(self): return 11 class B(A): x = Int(20) class C(A): @default("x") def _default_x(self): return 21 a = A() self.assertEqual(a._trait_values, {}) self.assertEqual(a.x, 11) self.assertEqual(a._trait_values, {"x": 11}) b = B() self.assertEqual(b.x, 20) self.assertEqual(b._trait_values, {"x": 20}) c = C() self.assertEqual(c._trait_values, {}) self.assertEqual(c.x, 21) self.assertEqual(c._trait_values, {"x": 21}) # Ensure that the base class remains unmolested when the _default # initializer gets overridden in a subclass. 
a = A() c = C() self.assertEqual(a._trait_values, {}) self.assertEqual(a.x, 11) self.assertEqual(a._trait_values, {"x": 11}) def test_tag_metadata(self): class MyIntTT(TraitType): metadata = {"a": 1, "b": 2} a = MyIntTT(10).tag(b=3, c=4) self.assertEqual(a.metadata, {"a": 1, "b": 3, "c": 4}) def test_metadata_localized_instance(self): class MyIntTT(TraitType): metadata = {"a": 1, "b": 2} a = MyIntTT(10) b = MyIntTT(10) a.metadata["c"] = 3 # make sure that changing a's metadata didn't change b's metadata self.assertNotIn("c", b.metadata) def test_union_metadata(self): class Foo(HasTraits): bar = (Int().tag(ta=1) | Dict().tag(ta=2, ti="b")).tag(ti="a") foo = Foo() # At this point, no value has been set for bar, so value-specific # is not set. self.assertEqual(foo.trait_metadata("bar", "ta"), None) self.assertEqual(foo.trait_metadata("bar", "ti"), "a") foo.bar = {} self.assertEqual(foo.trait_metadata("bar", "ta"), 2) self.assertEqual(foo.trait_metadata("bar", "ti"), "b") foo.bar = 1 self.assertEqual(foo.trait_metadata("bar", "ta"), 1) self.assertEqual(foo.trait_metadata("bar", "ti"), "a") def test_union_default_value(self): class Foo(HasTraits): bar = Union([Dict(), Int()], default_value=1) foo = Foo() self.assertEqual(foo.bar, 1) def test_union_validation_priority(self): class Foo(HasTraits): bar = Union([CInt(), Unicode()]) foo = Foo() foo.bar = "1" # validation in order of the TraitTypes given self.assertEqual(foo.bar, 1) def test_union_trait_default_value(self): class Foo(HasTraits): bar = Union([Dict(), Int()]) self.assertEqual(Foo().bar, {}) def test_deprecated_metadata_access(self): class MyIntTT(TraitType): metadata = {"a": 1, "b": 2} a = MyIntTT(10) with expected_warnings(["use the instance .metadata dictionary directly"] * 2): a.set_metadata("key", "value") v = a.get_metadata("key") self.assertEqual(v, "value") with expected_warnings(["use the instance .help string directly"] * 2): a.set_metadata("help", "some help") v = a.get_metadata("help") 
self.assertEqual(v, "some help") def test_trait_types_deprecated(self): with expected_warnings(["Traits should be given as instances"]): class C(HasTraits): t = Int def test_trait_types_list_deprecated(self): with expected_warnings(["Traits should be given as instances"]): class C(HasTraits): t = List(Int) def test_trait_types_tuple_deprecated(self): with expected_warnings(["Traits should be given as instances"]): class C(HasTraits): t = Tuple(Int) def test_trait_types_dict_deprecated(self): with expected_warnings(["Traits should be given as instances"]): class C(HasTraits): t = Dict(Int) class TestHasDescriptorsMeta(TestCase): def test_metaclass(self): self.assertEqual(type(HasTraits), MetaHasTraits) class A(HasTraits): a = Int() a = A() self.assertEqual(type(a.__class__), MetaHasTraits) self.assertEqual(a.a, 0) a.a = 10 self.assertEqual(a.a, 10) class B(HasTraits): b = Int() b = B() self.assertEqual(b.b, 0) b.b = 10 self.assertEqual(b.b, 10) class C(HasTraits): c = Int(30) c = C() self.assertEqual(c.c, 30) c.c = 10 self.assertEqual(c.c, 10) def test_this_class(self): class A(HasTraits): t = This() tt = This() class B(A): tt = This() ttt = This() self.assertEqual(A.t.this_class, A) self.assertEqual(B.t.this_class, A) self.assertEqual(B.tt.this_class, B) self.assertEqual(B.ttt.this_class, B) class TestHasDescriptors(TestCase): def test_setup_instance(self): class FooDescriptor(BaseDescriptor): def instance_init(self, inst): foo = inst.foo # instance should have the attr class HasFooDescriptors(HasDescriptors): fd = FooDescriptor() def setup_instance(self, *args, **kwargs): self.foo = kwargs.get("foo", None) super().setup_instance(*args, **kwargs) hfd = HasFooDescriptors(foo="bar") class TestHasTraitsNotify(TestCase): def setUp(self): self._notify1 = [] self._notify2 = [] def notify1(self, name, old, new): self._notify1.append((name, old, new)) def notify2(self, name, old, new): self._notify2.append((name, old, new)) def test_notify_all(self): class A(HasTraits): a 
= Int() b = Float() a = A() a.on_trait_change(self.notify1) a.a = 0 self.assertEqual(len(self._notify1), 0) a.b = 0.0 self.assertEqual(len(self._notify1), 0) a.a = 10 self.assertTrue(("a", 0, 10) in self._notify1) a.b = 10.0 self.assertTrue(("b", 0.0, 10.0) in self._notify1) self.assertRaises(TraitError, setattr, a, "a", "bad string") self.assertRaises(TraitError, setattr, a, "b", "bad string") self._notify1 = [] a.on_trait_change(self.notify1, remove=True) a.a = 20 a.b = 20.0 self.assertEqual(len(self._notify1), 0) def test_notify_one(self): class A(HasTraits): a = Int() b = Float() a = A() a.on_trait_change(self.notify1, "a") a.a = 0 self.assertEqual(len(self._notify1), 0) a.a = 10 self.assertTrue(("a", 0, 10) in self._notify1) self.assertRaises(TraitError, setattr, a, "a", "bad string") def test_subclass(self): class A(HasTraits): a = Int() class B(A): b = Float() b = B() self.assertEqual(b.a, 0) self.assertEqual(b.b, 0.0) b.a = 100 b.b = 100.0 self.assertEqual(b.a, 100) self.assertEqual(b.b, 100.0) def test_notify_subclass(self): class A(HasTraits): a = Int() class B(A): b = Float() b = B() b.on_trait_change(self.notify1, "a") b.on_trait_change(self.notify2, "b") b.a = 0 b.b = 0.0 self.assertEqual(len(self._notify1), 0) self.assertEqual(len(self._notify2), 0) b.a = 10 b.b = 10.0 self.assertTrue(("a", 0, 10) in self._notify1) self.assertTrue(("b", 0.0, 10.0) in self._notify2) def test_static_notify(self): class A(HasTraits): a = Int() _notify1 = [] def _a_changed(self, name, old, new): self._notify1.append((name, old, new)) a = A() a.a = 0 # This is broken!!! 
self.assertEqual(len(a._notify1), 0) a.a = 10 self.assertTrue(("a", 0, 10) in a._notify1) class B(A): b = Float() _notify2 = [] def _b_changed(self, name, old, new): self._notify2.append((name, old, new)) b = B() b.a = 10 b.b = 10.0 self.assertTrue(("a", 0, 10) in b._notify1) self.assertTrue(("b", 0.0, 10.0) in b._notify2) def test_notify_args(self): def callback0(): self.cb = () def callback1(name): self.cb = (name,) def callback2(name, new): self.cb = (name, new) def callback3(name, old, new): self.cb = (name, old, new) def callback4(name, old, new, obj): self.cb = (name, old, new, obj) class A(HasTraits): a = Int() a = A() a.on_trait_change(callback0, "a") a.a = 10 self.assertEqual(self.cb, ()) a.on_trait_change(callback0, "a", remove=True) a.on_trait_change(callback1, "a") a.a = 100 self.assertEqual(self.cb, ("a",)) a.on_trait_change(callback1, "a", remove=True) a.on_trait_change(callback2, "a") a.a = 1000 self.assertEqual(self.cb, ("a", 1000)) a.on_trait_change(callback2, "a", remove=True) a.on_trait_change(callback3, "a") a.a = 10000 self.assertEqual(self.cb, ("a", 1000, 10000)) a.on_trait_change(callback3, "a", remove=True) a.on_trait_change(callback4, "a") a.a = 100000 self.assertEqual(self.cb, ("a", 10000, 100000, a)) self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1) a.on_trait_change(callback4, "a", remove=True) self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0) def test_notify_only_once(self): class A(HasTraits): listen_to = ["a"] a = Int(0) b = 0 def __init__(self, **kwargs): super().__init__(**kwargs) self.on_trait_change(self.listener1, ["a"]) def listener1(self, name, old, new): self.b += 1 class B(A): c = 0 d = 0 def __init__(self, **kwargs): super().__init__(**kwargs) self.on_trait_change(self.listener2) def listener2(self, name, old, new): self.c += 1 def _a_changed(self, name, old, new): self.d += 1 b = B() b.a += 1 self.assertEqual(b.b, b.c) self.assertEqual(b.b, b.d) b.a += 1 self.assertEqual(b.b, b.c) self.assertEqual(b.b, 
b.d) class TestObserveDecorator(TestCase): def setUp(self): self._notify1 = [] self._notify2 = [] def notify1(self, change): self._notify1.append(change) def notify2(self, change): self._notify2.append(change) def test_notify_all(self): class A(HasTraits): a = Int() b = Float() a = A() a.observe(self.notify1) a.a = 0 self.assertEqual(len(self._notify1), 0) a.b = 0.0 self.assertEqual(len(self._notify1), 0) a.a = 10 change = change_dict("a", 0, 10, a, "change") self.assertTrue(change in self._notify1) a.b = 10.0 change = change_dict("b", 0.0, 10.0, a, "change") self.assertTrue(change in self._notify1) self.assertRaises(TraitError, setattr, a, "a", "bad string") self.assertRaises(TraitError, setattr, a, "b", "bad string") self._notify1 = [] a.unobserve(self.notify1) a.a = 20 a.b = 20.0 self.assertEqual(len(self._notify1), 0) def test_notify_one(self): class A(HasTraits): a = Int() b = Float() a = A() a.observe(self.notify1, "a") a.a = 0 self.assertEqual(len(self._notify1), 0) a.a = 10 change = change_dict("a", 0, 10, a, "change") self.assertTrue(change in self._notify1) self.assertRaises(TraitError, setattr, a, "a", "bad string") def test_subclass(self): class A(HasTraits): a = Int() class B(A): b = Float() b = B() self.assertEqual(b.a, 0) self.assertEqual(b.b, 0.0) b.a = 100 b.b = 100.0 self.assertEqual(b.a, 100) self.assertEqual(b.b, 100.0) def test_notify_subclass(self): class A(HasTraits): a = Int() class B(A): b = Float() b = B() b.observe(self.notify1, "a") b.observe(self.notify2, "b") b.a = 0 b.b = 0.0 self.assertEqual(len(self._notify1), 0) self.assertEqual(len(self._notify2), 0) b.a = 10 b.b = 10.0 change = change_dict("a", 0, 10, b, "change") self.assertTrue(change in self._notify1) change = change_dict("b", 0.0, 10.0, b, "change") self.assertTrue(change in self._notify2) def test_static_notify(self): class A(HasTraits): a = Int() b = Int() _notify1 = [] _notify_any = [] @observe("a") def _a_changed(self, change): self._notify1.append(change) @observe(All) 
def _any_changed(self, change): self._notify_any.append(change) a = A() a.a = 0 self.assertEqual(len(a._notify1), 0) a.a = 10 change = change_dict("a", 0, 10, a, "change") self.assertTrue(change in a._notify1) a.b = 1 self.assertEqual(len(a._notify_any), 2) change = change_dict("b", 0, 1, a, "change") self.assertTrue(change in a._notify_any) class B(A): b = Float() _notify2 = [] @observe("b") def _b_changed(self, change): self._notify2.append(change) b = B() b.a = 10 b.b = 10.0 change = change_dict("a", 0, 10, b, "change") self.assertTrue(change in b._notify1) change = change_dict("b", 0.0, 10.0, b, "change") self.assertTrue(change in b._notify2) def test_notify_args(self): def callback0(): self.cb = () def callback1(change): self.cb = change class A(HasTraits): a = Int() a = A() a.on_trait_change(callback0, "a") a.a = 10 self.assertEqual(self.cb, ()) a.unobserve(callback0, "a") a.observe(callback1, "a") a.a = 100 change = change_dict("a", 10, 100, a, "change") self.assertEqual(self.cb, change) self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1) a.unobserve(callback1, "a") self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0) def test_notify_only_once(self): class A(HasTraits): listen_to = ["a"] a = Int(0) b = 0 def __init__(self, **kwargs): super().__init__(**kwargs) self.observe(self.listener1, ["a"]) def listener1(self, change): self.b += 1 class B(A): c = 0 d = 0 def __init__(self, **kwargs): super().__init__(**kwargs) self.observe(self.listener2) def listener2(self, change): self.c += 1 @observe("a") def _a_changed(self, change): self.d += 1 b = B() b.a += 1 self.assertEqual(b.b, b.c) self.assertEqual(b.b, b.d) b.a += 1 self.assertEqual(b.b, b.c) self.assertEqual(b.b, b.d) class TestHasTraits(TestCase): def test_trait_names(self): class A(HasTraits): i = Int() f = Float() a = A() self.assertEqual(sorted(a.trait_names()), ["f", "i"]) self.assertEqual(sorted(A.class_trait_names()), ["f", "i"]) self.assertTrue(a.has_trait("f")) 
self.assertFalse(a.has_trait("g")) def test_trait_has_value(self): class A(HasTraits): i = Int() f = Float() a = A() self.assertFalse(a.trait_has_value("f")) self.assertFalse(a.trait_has_value("g")) a.i = 1 a.f self.assertTrue(a.trait_has_value("i")) self.assertTrue(a.trait_has_value("f")) def test_trait_metadata_deprecated(self): with expected_warnings([r"metadata should be set using the \.tag\(\) method"]): class A(HasTraits): i = Int(config_key="MY_VALUE") a = A() self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE") def test_trait_metadata(self): class A(HasTraits): i = Int().tag(config_key="MY_VALUE") a = A() self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE") def test_trait_metadata_default(self): class A(HasTraits): i = Int() a = A() self.assertEqual(a.trait_metadata("i", "config_key"), None) self.assertEqual(a.trait_metadata("i", "config_key", "default"), "default") def test_traits(self): class A(HasTraits): i = Int() f = Float() a = A() self.assertEqual(a.traits(), dict(i=A.i, f=A.f)) self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f)) def test_traits_metadata(self): class A(HasTraits): i = Int().tag(config_key="VALUE1", other_thing="VALUE2") f = Float().tag(config_key="VALUE3", other_thing="VALUE2") j = Int(0) a = A() self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j)) traits = a.traits(config_key="VALUE1", other_thing="VALUE2") self.assertEqual(traits, dict(i=A.i)) # This passes, but it shouldn't because I am replicating a bug in # traits. 
traits = a.traits(config_key=lambda v: True) self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j)) def test_traits_metadata_deprecated(self): with expected_warnings([r"metadata should be set using the \.tag\(\) method"] * 2): class A(HasTraits): i = Int(config_key="VALUE1", other_thing="VALUE2") f = Float(config_key="VALUE3", other_thing="VALUE2") j = Int(0) a = A() self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j)) traits = a.traits(config_key="VALUE1", other_thing="VALUE2") self.assertEqual(traits, dict(i=A.i)) # This passes, but it shouldn't because I am replicating a bug in # traits. traits = a.traits(config_key=lambda v: True) self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j)) def test_init(self): class A(HasTraits): i = Int() x = Float() a = A(i=1, x=10.0) self.assertEqual(a.i, 1) self.assertEqual(a.x, 10.0) def test_positional_args(self): class A(HasTraits): i = Int(0) def __init__(self, i): super().__init__() self.i = i a = A(5) self.assertEqual(a.i, 5) # should raise TypeError if no positional arg given self.assertRaises(TypeError, A) # ----------------------------------------------------------------------------- # Tests for specific trait types # ----------------------------------------------------------------------------- class TestType(TestCase): def test_default(self): class B: pass class A(HasTraits): klass = Type(allow_none=True) a = A() self.assertEqual(a.klass, object) a.klass = B self.assertEqual(a.klass, B) self.assertRaises(TraitError, setattr, a, "klass", 10) def test_default_options(self): class B: pass class C(B): pass class A(HasTraits): # Different possible combinations of options for default_value # and klass. default_value=None is only valid with allow_none=True. 
k1 = Type() k2 = Type(None, allow_none=True) k3 = Type(B) k4 = Type(klass=B) k5 = Type(default_value=None, klass=B, allow_none=True) k6 = Type(default_value=C, klass=B) self.assertIs(A.k1.default_value, object) self.assertIs(A.k1.klass, object) self.assertIs(A.k2.default_value, None) self.assertIs(A.k2.klass, object) self.assertIs(A.k3.default_value, B) self.assertIs(A.k3.klass, B) self.assertIs(A.k4.default_value, B) self.assertIs(A.k4.klass, B) self.assertIs(A.k5.default_value, None) self.assertIs(A.k5.klass, B) self.assertIs(A.k6.default_value, C) self.assertIs(A.k6.klass, B) a = A() self.assertIs(a.k1, object) self.assertIs(a.k2, None) self.assertIs(a.k3, B) self.assertIs(a.k4, B) self.assertIs(a.k5, None) self.assertIs(a.k6, C) def test_value(self): class B: pass class C: pass class A(HasTraits): klass = Type(B) a = A() self.assertEqual(a.klass, B) self.assertRaises(TraitError, setattr, a, "klass", C) self.assertRaises(TraitError, setattr, a, "klass", object) a.klass = B def test_allow_none(self): class B: pass class C(B): pass class A(HasTraits): klass = Type(B) a = A() self.assertEqual(a.klass, B) self.assertRaises(TraitError, setattr, a, "klass", None) a.klass = C self.assertEqual(a.klass, C) def test_validate_klass(self): class A(HasTraits): klass = Type("no strings allowed") self.assertRaises(ImportError, A) class A(HasTraits): klass = Type("rub.adub.Duck") self.assertRaises(ImportError, A) def test_validate_default(self): class B: pass class A(HasTraits): klass = Type("bad default", B) self.assertRaises(ImportError, A) class C(HasTraits): klass = Type(None, B) self.assertRaises(TraitError, getattr, C(), "klass") def test_str_klass(self): class A(HasTraits): klass = Type("traitlets.config.Config") from traitlets.config import Config a = A() a.klass = Config self.assertEqual(a.klass, Config) self.assertRaises(TraitError, setattr, a, "klass", 10) def test_set_str_klass(self): class A(HasTraits): klass = Type() a = A(klass="traitlets.config.Config") from 
traitlets.config import Config self.assertEqual(a.klass, Config) class TestInstance(TestCase): def test_basic(self): class Foo: pass class Bar(Foo): pass class Bah: pass class A(HasTraits): inst = Instance(Foo, allow_none=True) a = A() self.assertTrue(a.inst is None) a.inst = Foo() self.assertTrue(isinstance(a.inst, Foo)) a.inst = Bar() self.assertTrue(isinstance(a.inst, Foo)) self.assertRaises(TraitError, setattr, a, "inst", Foo) self.assertRaises(TraitError, setattr, a, "inst", Bar) self.assertRaises(TraitError, setattr, a, "inst", Bah()) def test_default_klass(self): class Foo: pass class Bar(Foo): pass class Bah: pass class FooInstance(Instance): klass = Foo class A(HasTraits): inst = FooInstance(allow_none=True) a = A() self.assertTrue(a.inst is None) a.inst = Foo() self.assertTrue(isinstance(a.inst, Foo)) a.inst = Bar() self.assertTrue(isinstance(a.inst, Foo)) self.assertRaises(TraitError, setattr, a, "inst", Foo) self.assertRaises(TraitError, setattr, a, "inst", Bar) self.assertRaises(TraitError, setattr, a, "inst", Bah()) def test_unique_default_value(self): class Foo: pass class A(HasTraits): inst = Instance(Foo, (), {}) a = A() b = A() self.assertTrue(a.inst is not b.inst) def test_args_kw(self): class Foo: def __init__(self, c): self.c = c class Bar: pass class Bah: def __init__(self, c, d): self.c = c self.d = d class A(HasTraits): inst = Instance(Foo, (10,)) a = A() self.assertEqual(a.inst.c, 10) class B(HasTraits): inst = Instance(Bah, args=(10,), kw=dict(d=20)) b = B() self.assertEqual(b.inst.c, 10) self.assertEqual(b.inst.d, 20) class C(HasTraits): inst = Instance(Foo, allow_none=True) c = C() self.assertTrue(c.inst is None) def test_bad_default(self): class Foo: pass class A(HasTraits): inst = Instance(Foo) a = A() with self.assertRaises(TraitError): a.inst def test_instance(self): class Foo: pass def inner(): class A(HasTraits): inst = Instance(Foo()) self.assertRaises(TraitError, inner) class TestThis(TestCase): def test_this_class(self): class 
Foo(HasTraits): this = This() f = Foo() self.assertEqual(f.this, None) g = Foo() f.this = g self.assertEqual(f.this, g) self.assertRaises(TraitError, setattr, f, "this", 10) def test_this_inst(self): class Foo(HasTraits): this = This() f = Foo() f.this = Foo() self.assertTrue(isinstance(f.this, Foo)) def test_subclass(self): class Foo(HasTraits): t = This() class Bar(Foo): pass f = Foo() b = Bar() f.t = b b.t = f self.assertEqual(f.t, b) self.assertEqual(b.t, f) def test_subclass_override(self): class Foo(HasTraits): t = This() class Bar(Foo): t = This() f = Foo() b = Bar() f.t = b self.assertEqual(f.t, b) self.assertRaises(TraitError, setattr, b, "t", f) def test_this_in_container(self): class Tree(HasTraits): value = Unicode() leaves = List(This()) tree = Tree(value="foo", leaves=[Tree(value="bar"), Tree(value="buzz")]) with self.assertRaises(TraitError): tree.leaves = [1, 2] class TraitTestBase(TestCase): """A best testing class for basic trait types.""" def assign(self, value): self.obj.value = value def coerce(self, value): return value def test_good_values(self): if hasattr(self, "_good_values"): for value in self._good_values: self.assign(value) self.assertEqual(self.obj.value, self.coerce(value)) def test_bad_values(self): if hasattr(self, "_bad_values"): for value in self._bad_values: try: self.assertRaises(TraitError, self.assign, value) except AssertionError: assert False, value def test_default_value(self): if hasattr(self, "_default_value"): self.assertEqual(self._default_value, self.obj.value) def test_allow_none(self): if ( hasattr(self, "_bad_values") and hasattr(self, "_good_values") and None in self._bad_values ): trait = self.obj.traits()["value"] try: trait.allow_none = True self._bad_values.remove(None) # skip coerce. Allow None casts None to None. 
self.assign(None) self.assertEqual(self.obj.value, None) self.test_good_values() self.test_bad_values() finally: # tear down trait.allow_none = False self._bad_values.append(None) def tearDown(self): # restore default value after tests, if set if hasattr(self, "_default_value"): self.obj.value = self._default_value class AnyTrait(HasTraits): value = Any() class AnyTraitTest(TraitTestBase): obj = AnyTrait() _default_value = None _good_values = [10.0, "ten", [10], {"ten": 10}, (10,), None, 1j] _bad_values = [] class UnionTrait(HasTraits): value = Union([Type(), Bool()]) class UnionTraitTest(TraitTestBase): obj = UnionTrait(value="traitlets.config.Config") _good_values = [int, float, True] _bad_values = [[], (0,), 1j] class CallableTrait(HasTraits): value = Callable() class CallableTraitTest(TraitTestBase): obj = CallableTrait(value=lambda x: type(x)) _good_values = [int, sorted, lambda x: print(x)] _bad_values = [[], 1, ""] class OrTrait(HasTraits): value = Bool() | Unicode() class OrTraitTest(TraitTestBase): obj = OrTrait() _good_values = [True, False, "ten"] _bad_values = [[], (0,), 1j] class IntTrait(HasTraits): value = Int(99, min=-100) class TestInt(TraitTestBase): obj = IntTrait() _default_value = 99 _good_values = [10, -10] _bad_values = [ "ten", [10], {"ten": 10}, (10,), None, 1j, 10.1, -10.1, "10L", "-10L", "10.1", "-10.1", "10", "-10", -200, ] class CIntTrait(HasTraits): value = CInt("5") class TestCInt(TraitTestBase): obj = CIntTrait() _default_value = 5 _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1] _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"] def coerce(self, n): return int(n) class MinBoundCIntTrait(HasTraits): value = CInt("5", min=3) class TestMinBoundCInt(TestCInt): obj = MinBoundCIntTrait() _default_value = 5 _good_values = [3, 3.0, "3"] _bad_values = [2.6, 2, -3, -3.0] class LongTrait(HasTraits): value = Long(99) class TestLong(TraitTestBase): obj = LongTrait() _default_value = 99 _good_values = [10, -10] _bad_values = [ 
"ten", [10], {"ten": 10}, (10,), None, 1j, 10.1, -10.1, "10", "-10", "10L", "-10L", "10.1", "-10.1", ] class MinBoundLongTrait(HasTraits): value = Long(99, min=5) class TestMinBoundLong(TraitTestBase): obj = MinBoundLongTrait() _default_value = 99 _good_values = [5, 10] _bad_values = [4, -10] class MaxBoundLongTrait(HasTraits): value = Long(5, max=10) class TestMaxBoundLong(TraitTestBase): obj = MaxBoundLongTrait() _default_value = 5 _good_values = [10, -2] _bad_values = [11, 20] class CLongTrait(HasTraits): value = CLong("5") class TestCLong(TraitTestBase): obj = CLongTrait() _default_value = 5 _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1] _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"] def coerce(self, n): return int(n) class MaxBoundCLongTrait(HasTraits): value = CLong("5", max=10) class TestMaxBoundCLong(TestCLong): obj = MaxBoundCLongTrait() _default_value = 5 _good_values = [10, "10", 10.3] _bad_values = [11.0, "11"] class IntegerTrait(HasTraits): value = Integer(1) class TestInteger(TestLong): obj = IntegerTrait() _default_value = 1 def coerce(self, n): return int(n) class MinBoundIntegerTrait(HasTraits): value = Integer(5, min=3) class TestMinBoundInteger(TraitTestBase): obj = MinBoundIntegerTrait() _default_value = 5 _good_values = 3, 20 _bad_values = [2, -10] class MaxBoundIntegerTrait(HasTraits): value = Integer(1, max=3) class TestMaxBoundInteger(TraitTestBase): obj = MaxBoundIntegerTrait() _default_value = 1 _good_values = 3, -2 _bad_values = [4, 10] class FloatTrait(HasTraits): value = Float(99.0, max=200.0) class TestFloat(TraitTestBase): obj = FloatTrait() _default_value = 99.0 _good_values = [10, -10, 10.1, -10.1] _bad_values = [ "ten", [10], {"ten": 10}, (10,), None, 1j, "10", "-10", "10L", "-10L", "10.1", "-10.1", 201.0, ] class CFloatTrait(HasTraits): value = CFloat("99.0", max=200.0) class TestCFloat(TraitTestBase): obj = CFloatTrait() _default_value = 99.0 _good_values = [10, 10.0, 10.5, "10.0", "10", "-10"] _bad_values 
= ["ten", [10], {"ten": 10}, (10,), None, 1j, 200.1, "200.1"] def coerce(self, v): return float(v) class ComplexTrait(HasTraits): value = Complex(99.0 - 99.0j) class TestComplex(TraitTestBase): obj = ComplexTrait() _default_value = 99.0 - 99.0j _good_values = [ 10, -10, 10.1, -10.1, 10j, 10 + 10j, 10 - 10j, 10.1j, 10.1 + 10.1j, 10.1 - 10.1j, ] _bad_values = ["10L", "-10L", "ten", [10], {"ten": 10}, (10,), None] class BytesTrait(HasTraits): value = Bytes(b"string") class TestBytes(TraitTestBase): obj = BytesTrait() _default_value = b"string" _good_values = [b"10", b"-10", b"10L", b"-10L", b"10.1", b"-10.1", b"string"] _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None, "string"] class UnicodeTrait(HasTraits): value = Unicode("unicode") class TestUnicode(TraitTestBase): obj = UnicodeTrait() _default_value = "unicode" _good_values = ["10", "-10", "10L", "-10L", "10.1", "-10.1", "", "string", "€", b"bytestring"] _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None] def coerce(self, v): return cast_unicode(v) class ObjectNameTrait(HasTraits): value = ObjectName("abc") class TestObjectName(TraitTestBase): obj = ObjectNameTrait() _default_value = "abc" _good_values = ["a", "gh", "g9", "g_", "_G", "a345_"] _bad_values = [ 1, "", "€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]", None, object(), object, ] _good_values.append("þ") # þ=1 is valid in Python 3 (PEP 3131). 
class DottedObjectNameTrait(HasTraits):
    value = DottedObjectName("a.b")


class TestDottedObjectName(TraitTestBase):
    obj = DottedObjectNameTrait()

    _default_value = "a.b"
    _good_values = ["A", "y.t", "y765.__repr__", "os.path.join"]
    _bad_values = [1, "abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None]

    _good_values.append("t.þ")


class TCPAddressTrait(HasTraits):
    value = TCPAddress()


class TestTCPAddress(TraitTestBase):
    obj = TCPAddressTrait()

    _default_value = ("127.0.0.1", 0)
    _good_values = [("localhost", 0), ("192.168.0.1", 1000), ("www.google.com", 80)]
    _bad_values = [(0, 0), ("localhost", 10.0), ("localhost", -1), None]


class ListTrait(HasTraits):
    value = List(Int())


class TestList(TraitTestBase):
    obj = ListTrait()

    _default_value = []
    _good_values = [[], [1], list(range(10)), (1, 2)]
    _bad_values = [10, [1, "a"], "a"]

    def coerce(self, value):
        """Lists accept any iterable of ints; compare against the list form."""
        if value is not None:
            value = list(value)
        return value


class Foo:
    pass


class NoneInstanceListTrait(HasTraits):
    value = List(Instance(Foo))


class TestNoneInstanceList(TraitTestBase):
    obj = NoneInstanceListTrait()

    _default_value = []
    _good_values = [[Foo(), Foo()], []]
    _bad_values = [[None], [Foo(), None]]


class InstanceListTrait(HasTraits):
    value = List(Instance(__name__ + ".Foo"))


class TestInstanceList(TraitTestBase):
    obj = InstanceListTrait()

    _default_value = []
    _good_values = [[Foo(), Foo()], []]
    _bad_values = [["1", 2], "1", [Foo], None]

    def test_klass(self):
        """Test that the instance klass is properly assigned."""
        self.assertIs(self.obj.traits()["value"]._trait.klass, Foo)


class UnionListTrait(HasTraits):
    value = List(Int() | Bool())


class TestUnionListTrait(TraitTestBase):
    obj = UnionListTrait()

    _default_value = []
    _good_values = [[True, 1], [False, True]]
    _bad_values = [[1, "True"], False]


class LenListTrait(HasTraits):
    value = List(Int(), [0], minlen=1, maxlen=2)


class TestLenList(TraitTestBase):
    obj = LenListTrait()

    _default_value = [0]
    _good_values = [[1], [1, 2], (1, 2)]
    _bad_values = [10, [1, "a"], "a", [], list(range(3))]

    def coerce(self, value):
        """Lists accept any iterable of ints; compare against the list form."""
        if value is not None:
            value = list(value)
        return value


class TupleTrait(HasTraits):
    value = Tuple(Int(allow_none=True), default_value=(1,))


class TestTupleTrait(TraitTestBase):
    obj = TupleTrait()

    _default_value = (1,)
    _good_values = [(1,), (0,), [1]]
    # "a" is a plain string (not a tuple); ("a",) is a 1-tuple whose element
    # fails the Int() element trait.
    _bad_values = [10, (1, 2), "a", ("a",), (), None]

    def coerce(self, value):
        """Tuples accept lists too; compare against the tuple form."""
        if value is not None:
            value = tuple(value)
        return value

    def test_invalid_args(self):
        self.assertRaises(TypeError, Tuple, 5)
        self.assertRaises(TypeError, Tuple, default_value="hello")
        # A default value matching the element traits must be accepted.
        Tuple(Int(), CBytes(), default_value=(1, 5))


class LooseTupleTrait(HasTraits):
    value = Tuple((1, 2, 3))


class TestLooseTupleTrait(TraitTestBase):
    obj = LooseTupleTrait()

    _default_value = (1, 2, 3)
    _good_values = [(1,), [1], (0,), tuple(range(5)), tuple("hello"), ("a", 5), ()]
    _bad_values = [10, "hello", {}, None]

    def coerce(self, value):
        """Tuples accept lists too; compare against the tuple form."""
        if value is not None:
            value = tuple(value)
        return value

    def test_invalid_args(self):
        self.assertRaises(TypeError, Tuple, 5)
        self.assertRaises(TypeError, Tuple, default_value="hello")
        # A default value matching the element traits must be accepted.
        Tuple(Int(), CBytes(), default_value=(1, 5))


class MultiTupleTrait(HasTraits):
    value = Tuple(Int(), Bytes(), default_value=[99, b"bottles"])


class TestMultiTuple(TraitTestBase):
    obj = MultiTupleTrait()

    _default_value = (99, b"bottles")
    _good_values = [(1, b"a"), (2, b"b")]
    _bad_values = ((), 10, b"a", (1, b"a", 3), (b"a", 1), (1, "a"))


@pytest.mark.parametrize(
    "Trait",
    (
        List,
        Tuple,
        Set,
        Dict,
        Integer,
        Unicode,
    ),
)
def test_allow_none_default_value(Trait):
    class C(HasTraits):
        t = Trait(default_value=None, allow_none=True)

    # test default value
    c = C()
    assert c.t is None

    # and in constructor
    c = C(t=None)
    assert c.t is None


@pytest.mark.parametrize(
    "Trait, default_value",
    ((List, []), (Tuple, ()), (Set, set()), (Dict, {}), (Integer, 0), (Unicode, "")),
)
def test_default_value(Trait, default_value):
    class C(HasTraits):
        t = Trait()

    # test default value
    c = C()
    assert type(c.t) is type(default_value)
    assert c.t == default_value


@pytest.mark.parametrize(
    "Trait, default_value",
    ((List, []), (Tuple, ()), (Set, set())),
)
def test_subclass_default_value(Trait, default_value):
    """Test deprecated default_value=None behavior for Container subclass traits"""

    class SubclassTrait(Trait):
        def __init__(self, default_value=None):
            super().__init__(default_value=default_value)

    class C(HasTraits):
        t = SubclassTrait()

    # test default value
    c = C()
    assert type(c.t) is type(default_value)
    assert c.t == default_value


class CRegExpTrait(HasTraits):
    value = CRegExp(r"")


class TestCRegExp(TraitTestBase):
    def coerce(self, value):
        """CRegExp compiles strings; compare against the compiled pattern."""
        return re.compile(value)

    obj = CRegExpTrait()
    _default_value = re.compile(r"")
    _good_values = [r"\d+", re.compile(r"\d+")]
    _bad_values = ["(", None, ()]


class DictTrait(HasTraits):
    value = Dict()


def test_dict_assignment():
    """Assigning a dict stores that very object, not a copy."""
    d = {}
    c = DictTrait()
    c.value = d
    d["a"] = 5
    assert d == c.value
    assert c.value is d


class UniformlyValueValidatedDictTrait(HasTraits):
    value = Dict(trait=Unicode(), default_value={"foo": "1"})


class TestInstanceUniformlyValueValidatedDict(TraitTestBase):
    obj = UniformlyValueValidatedDictTrait()

    _default_value = {"foo": "1"}
    _good_values = [{"foo": "0", "bar": "1"}]
    _bad_values = [{"foo": 0, "bar": "1"}]


class NonuniformlyValueValidatedDictTrait(HasTraits):
    value = Dict(traits={"foo": Int()}, default_value={"foo": 1})


class TestInstanceNonuniformlyValueValidatedDict(TraitTestBase):
    obj = NonuniformlyValueValidatedDictTrait()

    _default_value = {"foo": 1}
    _good_values = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": 1}]
    _bad_values = [{"foo": "0", "bar": "1"}]


class KeyValidatedDictTrait(HasTraits):
    value = Dict(key_trait=Unicode(), default_value={"foo": "1"})


class TestInstanceKeyValidatedDict(TraitTestBase):
    obj = KeyValidatedDictTrait()

    _default_value = {"foo": "1"}
    _good_values = [{"foo": "0", "bar": "1"}]
    _bad_values = [{"foo": "0", 0: "1"}]


class FullyValidatedDictTrait(HasTraits):
    value = Dict(
        trait=Unicode(),
        key_trait=Unicode(),
        traits={"foo": Int()},
        default_value={"foo": 1},
    )


class TestInstanceFullyValidatedDict(TraitTestBase):
    obj = FullyValidatedDictTrait()

    _default_value = {"foo": 1}
    _good_values = [{"foo": 0, "bar": "1"}, {"foo": 1, "bar": "2"}]
    _bad_values = [{"foo": 0, "bar": 1}, {"foo": "0", "bar": "1"}, {"foo": 0, 0: "1"}]


def test_dict_default_value():
    """Check that the `{}` default value of the Dict traitlet constructor is
    actually copied."""

    class Foo(HasTraits):
        d1 = Dict()
        d2 = Dict()

    foo = Foo()
    assert foo.d1 == {}
    assert foo.d2 == {}
    assert foo.d1 is not foo.d2


class TestValidationHook(TestCase):
    def test_parity_trait(self):
        """Verify that the early validation hook is effective"""

        class Parity(HasTraits):
            value = Int(0)
            parity = Enum(["odd", "even"], default_value="even")

            @validate("value")
            def _value_validate(self, proposal):
                value = proposal["value"]
                if self.parity == "even" and value % 2:
                    raise TraitError("Expected an even number")
                if self.parity == "odd" and (value % 2 == 0):
                    raise TraitError("Expected an odd number")
                return value

        u = Parity()
        u.parity = "odd"
        u.value = 1  # OK
        with self.assertRaises(TraitError):
            u.value = 2  # Trait Error

        u.parity = "even"
        u.value = 2  # OK

    def test_multiple_validate(self):
        """Verify that we can register the same validator to multiple names"""

        class OddEven(HasTraits):
            odd = Int(1)
            even = Int(0)

            @validate("odd", "even")
            def check_valid(self, proposal):
                if proposal["trait"].name == "odd" and not proposal["value"] % 2:
                    raise TraitError("odd should be odd")
                if proposal["trait"].name == "even" and proposal["value"] % 2:
                    raise TraitError("even should be even")
                # A @validate handler's return value is what gets stored, so
                # the accepted value must be handed back (not implicit None).
                return proposal["value"]

        u = OddEven()
        u.odd = 3  # OK
        with self.assertRaises(TraitError):
            u.odd = 2  # Trait Error
        u.even = 2  # OK
        with self.assertRaises(TraitError):
            u.even = 3  # Trait Error

        # Rejected assignments must leave the last accepted values in place.
        self.assertEqual(u.odd, 3)
        self.assertEqual(u.even, 2)


class TestLink(TestCase):
    def test_connect_same(self):
        """Verify two traitlets of the same type can be linked together using link."""

        # Create two simple classes with Int traitlets.
        class A(HasTraits):
            value = Int()

        a = A(value=9)
        b = A(value=8)

        # Connect the two classes.
        c = link((a, "value"), (b, "value"))  # noqa: F841

        # Make sure the values are the same at the point of linking.
        self.assertEqual(a.value, b.value)

        # Change one of the values to make sure they stay in sync.
        a.value = 5
        self.assertEqual(a.value, b.value)
        b.value = 6
        self.assertEqual(a.value, b.value)

    def test_link_different(self):
        """Verify two traitlets of different types can be linked together using link."""

        # Create two simple classes with Int traitlets.
        class A(HasTraits):
            value = Int()

        class B(HasTraits):
            count = Int()

        a = A(value=9)
        b = B(count=8)

        # Connect the two classes.
        c = link((a, "value"), (b, "count"))  # noqa: F841

        # Make sure the values are the same at the point of linking.
        self.assertEqual(a.value, b.count)

        # Change one of the values to make sure they stay in sync.
        a.value = 5
        self.assertEqual(a.value, b.count)
        b.count = 4
        self.assertEqual(a.value, b.count)

    def test_unlink_link(self):
        """Verify two linked traitlets can be unlinked and relinked."""

        # Create two simple classes with Int traitlets.
        class A(HasTraits):
            value = Int()

        a = A(value=9)
        b = A(value=8)

        # Connect the two classes.
        c = link((a, "value"), (b, "value"))
        a.value = 4
        c.unlink()

        # Change one of the values to make sure they don't stay in sync.
        a.value = 5
        self.assertNotEqual(a.value, b.value)
        c.link()
        self.assertEqual(a.value, b.value)
        a.value += 1
        self.assertEqual(a.value, b.value)

    def test_callbacks(self):
        """Verify two linked traitlets have their callbacks called once."""

        # Create two simple classes with Int traitlets.
        class A(HasTraits):
            value = Int()

        class B(HasTraits):
            count = Int()

        a = A(value=9)
        b = B(count=8)

        # Register callbacks that count.
        callback_count = []

        def a_callback(name, old, new):
            callback_count.append("a")

        a.on_trait_change(a_callback, "value")

        def b_callback(name, old, new):
            callback_count.append("b")

        b.on_trait_change(b_callback, "count")

        # Connect the two classes.
        c = link((a, "value"), (b, "count"))  # noqa: F841

        # Make sure b's count was set to a's value once.
        self.assertEqual("".join(callback_count), "b")
        del callback_count[:]

        # Make sure a's value was set to b's count once.
        b.count = 5
        self.assertEqual("".join(callback_count), "ba")
        del callback_count[:]

        # Make sure b's count was set to a's value once.
        a.value = 4
        self.assertEqual("".join(callback_count), "ab")
        del callback_count[:]

    def test_tranform(self):
        """Test transform link."""

        # Create two simple classes with Int traitlets.
        class A(HasTraits):
            value = Int()

        a = A(value=9)
        b = A(value=8)

        # Connect the two classes.
        c = link(  # noqa: F841
            (a, "value"),
            (b, "value"),
            transform=(lambda x: 2 * x, lambda x: int(x / 2.0)),
        )

        # Make sure the values are correct at the point of linking.
        self.assertEqual(b.value, 2 * a.value)

        # Change the value of the source and check that it modifies the target.
        a.value = 5
        self.assertEqual(b.value, 10)
        # Change the value of the target and check that it modifies the
        # source.
        b.value = 6
        self.assertEqual(a.value, 3)

    def test_link_broken_at_source(self):
        class MyClass(HasTraits):
            i = Int()
            j = Int()

            @observe("j")
            def another_update(self, change):
                self.i = change.new * 2

        mc = MyClass()
        l = link((mc, "i"), (mc, "j"))  # noqa
        self.assertRaises(TraitError, setattr, mc, "i", 2)

    def test_link_broken_at_target(self):
        class MyClass(HasTraits):
            i = Int()
            j = Int()

            @observe("i")
            def another_update(self, change):
                self.j = change.new * 2

        mc = MyClass()
        l = link((mc, "i"), (mc, "j"))  # noqa
        self.assertRaises(TraitError, setattr, mc, "j", 2)


class TestDirectionalLink(TestCase):
    def test_connect_same(self):
        """Verify two traitlets of the same type can be linked together using directional_link."""

        # Create two simple classes with Int traitlets.
        class A(HasTraits):
            value = Int()

        a = A(value=9)
        b = A(value=8)

        # Connect the two classes.
        c = directional_link((a, "value"), (b, "value"))  # noqa: F841

        # Make sure the values are the same at the point of linking.
        self.assertEqual(a.value, b.value)

        # Change the value of the source and check that it synchronizes the target.
        a.value = 5
        self.assertEqual(b.value, 5)
        # Change the value of the target and check that it has no impact on the source
        b.value = 6
        self.assertEqual(a.value, 5)

    def test_tranform(self):
        """Test transform link."""

        # Create two simple classes with Int traitlets.
        class A(HasTraits):
            value = Int()

        a = A(value=9)
        b = A(value=8)

        # Connect the two classes.
        c = directional_link((a, "value"), (b, "value"), lambda x: 2 * x)  # noqa: F841

        # Make sure the values are correct at the point of linking.
        self.assertEqual(b.value, 2 * a.value)

        # Change the value of the source and check that it modifies the target.
        a.value = 5
        self.assertEqual(b.value, 10)
        # Change the value of the target and check that it has no impact on the source
        b.value = 6
        self.assertEqual(a.value, 5)

    def test_link_different(self):
        """Verify two traitlets of different types can be linked together using directional_link."""

        # Create two simple classes with Int traitlets.
        class A(HasTraits):
            value = Int()

        class B(HasTraits):
            count = Int()

        a = A(value=9)
        b = B(count=8)

        # Connect the two classes.
        c = directional_link((a, "value"), (b, "count"))  # noqa: F841

        # Make sure the values are the same at the point of linking.
        self.assertEqual(a.value, b.count)

        # Change the value of the source and check that it synchronizes the target.
        a.value = 5
        self.assertEqual(b.count, 5)
        # Change the value of the target (the linked `count` trait -- B has no
        # `value` trait) and check that it has no impact on the source.
        b.count = 6
        self.assertEqual(b.count, 6)
        self.assertEqual(a.value, 5)

    def test_unlink_link(self):
        """Verify two linked traitlets can be unlinked and relinked."""

        # Create two simple classes with Int traitlets.
        class A(HasTraits):
            value = Int()

        a = A(value=9)
        b = A(value=8)

        # Connect the two classes.
        c = directional_link((a, "value"), (b, "value"))
        a.value = 4
        c.unlink()

        # Change one of the values to make sure they don't stay in sync.
        a.value = 5
        self.assertNotEqual(a.value, b.value)
        c.link()
        self.assertEqual(a.value, b.value)
        a.value += 1
        self.assertEqual(a.value, b.value)


class Pickleable(HasTraits):
    i = Int()

    @observe("i")
    def _i_changed(self, change):
        pass

    @validate("i")
    def _i_validate(self, commit):
        return commit["value"]

    j = Int()

    def __init__(self):
        with self.hold_trait_notifications():
            self.i = 1
        self.on_trait_change(self._i_changed, "i")


def test_pickle_hastraits():
    """A HasTraits instance with observers/validators round-trips through
    every available pickle protocol, before and after a trait change."""
    c = Pickleable()

    def assert_roundtrips():
        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
            c2 = pickle.loads(pickle.dumps(c, protocol))
            assert c2.i == c.i
            assert c2.j == c.j

    assert_roundtrips()
    c.i = 5
    assert_roundtrips()


def test_hold_trait_notifications():
    changes = []

    class Test(HasTraits):
        a = Integer(0)
        b = Integer(0)

        def _a_changed(self, name, old, new):
            changes.append((old, new))

        def _b_validate(self, value, trait):
            if value != 0:
                raise TraitError("Only 0 is a valid value")
            return value

    # Test context manager and nesting
    t = Test()
    with t.hold_trait_notifications():
        with t.hold_trait_notifications():
            t.a = 1
            assert t.a == 1
            assert changes == []

        t.a = 2
        assert t.a == 2

        with t.hold_trait_notifications():
            t.a = 3
            assert t.a == 3
            assert changes == []

            t.a = 4
            assert t.a == 4
            assert changes == []

        t.a = 4
        assert t.a == 4
        assert changes == []

    # Held notifications collapse into a single change once fully released.
    assert changes == [(0, 4)]

    # Test roll-back: the deferred validation error must surface (not be
    # silently swallowed) and the held assignment must be undone.
    with pytest.raises(TraitError):
        with t.hold_trait_notifications():
            t.b = 1  # raises a Trait error
    assert t.b == 0


class RollBack(HasTraits):
    bar = Int()

    def _bar_validate(self, value, trait):
        if value:
            raise TraitError("foobar")
        return value


class TestRollback(TestCase):
    def test_roll_back(self):
        def assign_rollback():
            RollBack(bar=1)

        self.assertRaises(TraitError, assign_rollback)


class CacheModification(HasTraits):
    foo = Int()
    bar = Int()

    def _bar_validate(self, value, trait):
        self.foo = value
        return value

    def _foo_validate(self, value, trait):
        self.bar = value
        return value


def test_cache_modification():
    CacheModification(foo=1)
    CacheModification(bar=1)


class OrderTraits(HasTraits):
    notified = Dict()

    a = Unicode()
    b = Unicode()
    c = Unicode()
    d = Unicode()
    e = Unicode()
    f = Unicode()
    g = Unicode()
    h = Unicode()
    i = Unicode()
    j = Unicode()
    k = Unicode()
    l = Unicode()  # noqa

    def _notify(self, name, old, new):
        """check the value of all traits when each trait change is triggered

        This verifies that the values are not sensitive
        to dict ordering when loaded from kwargs
        """
        # check the value of the other traits
        # when a given trait change notification fires
        self.notified[name] = {c: getattr(self, c) for c in "abcdefghijkl"}

    def __init__(self, **kwargs):
        self.on_trait_change(self._notify)
        super().__init__(**kwargs)


def test_notification_order():
    d = {c: c for c in "abcdefghijkl"}
    obj = OrderTraits()
    assert obj.notified == {}
    obj = OrderTraits(**d)
    notifications = {c: d for c in "abcdefghijkl"}
    assert obj.notified == notifications


###
# Traits for Forward Declaration Tests
###
class ForwardDeclaredInstanceTrait(HasTraits):
    value = ForwardDeclaredInstance("ForwardDeclaredBar", allow_none=True)


class ForwardDeclaredTypeTrait(HasTraits):
    value = ForwardDeclaredType("ForwardDeclaredBar", allow_none=True)


class ForwardDeclaredInstanceListTrait(HasTraits):
    value = List(ForwardDeclaredInstance("ForwardDeclaredBar"))


class ForwardDeclaredTypeListTrait(HasTraits):
    value = List(ForwardDeclaredType("ForwardDeclaredBar"))


###
# End Traits for Forward Declaration Tests
###


###
# Classes for Forward Declaration Tests
###
class ForwardDeclaredBar:
    pass


class ForwardDeclaredBarSub(ForwardDeclaredBar):
    pass


###
# End Classes for Forward Declaration Tests
###


###
# Forward Declaration Tests
###
class TestForwardDeclaredInstanceTrait(TraitTestBase):
    obj = ForwardDeclaredInstanceTrait()
    _default_value = None
    _good_values = [None, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
    _bad_values = ["foo", 3, ForwardDeclaredBar, ForwardDeclaredBarSub]


class TestForwardDeclaredTypeTrait(TraitTestBase):
    obj = ForwardDeclaredTypeTrait()
    _default_value = None
    _good_values = [None, ForwardDeclaredBar, ForwardDeclaredBarSub]
    _bad_values = ["foo", 3, ForwardDeclaredBar(), ForwardDeclaredBarSub()]


class TestForwardDeclaredInstanceList(TraitTestBase):
    obj = ForwardDeclaredInstanceListTrait()

    def test_klass(self):
        """Test that the instance klass is properly assigned."""
        self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar)

    _default_value = []
    _good_values = [
        [ForwardDeclaredBar(), ForwardDeclaredBarSub()],
        [],
    ]
    _bad_values = [
        ForwardDeclaredBar(),
        [ForwardDeclaredBar(), 3, None],
        "1",
        # Note that this is the type, not an instance.
        [ForwardDeclaredBar],
        [None],
        None,
    ]


class TestForwardDeclaredTypeList(TraitTestBase):
    obj = ForwardDeclaredTypeListTrait()

    def test_klass(self):
        """Test that the instance klass is properly assigned."""
        self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar)

    _default_value = []
    _good_values = [
        [ForwardDeclaredBar, ForwardDeclaredBarSub],
        [],
    ]
    _bad_values = [
        ForwardDeclaredBar,
        [ForwardDeclaredBar, 3],
        "1",
        # Note that this is an instance, not the type.
        [ForwardDeclaredBar()],
        [None],
        None,
    ]


###
# End Forward Declaration Tests
###


class TestDynamicTraits(TestCase):
    def setUp(self):
        self._notify1 = []

    def notify1(self, name, old, new):
        self._notify1.append((name, old, new))

    def test_notify_all(self):
        class A(HasTraits):
            pass

        a = A()
        self.assertFalse(hasattr(a, "x"))
        self.assertFalse(hasattr(a, "y"))

        # Dynamically add trait x.
        a.add_traits(x=Int())
        self.assertTrue(hasattr(a, "x"))
        self.assertIsInstance(a, A)

        # Dynamically add trait y.
        a.add_traits(y=Float())
        self.assertTrue(hasattr(a, "y"))
        self.assertIsInstance(a, A)
        self.assertEqual(a.__class__.__name__, A.__name__)

        # Create a new instance and verify that x and y
        # aren't defined.
        b = A()
        self.assertFalse(hasattr(b, "x"))
        self.assertFalse(hasattr(b, "y"))

        # Verify that notification works like normal.
        a.on_trait_change(self.notify1)
        a.x = 0
        self.assertEqual(len(self._notify1), 0)
        a.y = 0.0
        self.assertEqual(len(self._notify1), 0)
        a.x = 10
        self.assertIn(("x", 0, 10), self._notify1)
        a.y = 10.0
        self.assertIn(("y", 0.0, 10.0), self._notify1)
        self.assertRaises(TraitError, setattr, a, "x", "bad string")
        self.assertRaises(TraitError, setattr, a, "y", "bad string")
        self._notify1 = []
        a.on_trait_change(self.notify1, remove=True)
        a.x = 20
        a.y = 20.0
        self.assertEqual(len(self._notify1), 0)


def test_enum_no_default():
    class C(HasTraits):
        t = Enum(["a", "b"])

    c = C()
    c.t = "a"
    assert c.t == "a"

    c = C()

    # An Enum with no default has no valid value until one is assigned.
    with pytest.raises(TraitError):
        _ = c.t

    c = C(t="b")
    assert c.t == "b"


def test_default_value_repr():
    class C(HasTraits):
        t = Type("traitlets.HasTraits")
        t2 = Type(HasTraits)
        n = Integer(0)
        lis = List()
        d = Dict()

    assert C.t.default_value_repr() == "'traitlets.HasTraits'"
    assert C.t2.default_value_repr() == "'traitlets.traitlets.HasTraits'"
    assert C.n.default_value_repr() == "0"
    assert C.lis.default_value_repr() == "[]"
    assert C.d.default_value_repr() == "{}"


class TransitionalClass(HasTraits):
    d = Any()

    @default("d")
    def _d_default(self):
        return TransitionalClass

    parent_super = False
    calls_super = Integer(0)

    @default("calls_super")
    def _calls_super_default(self):
        return -1

    @observe("calls_super")
    @observe_compat
    def _calls_super_changed(self, change):
        self.parent_super = change

    parent_override = False
    overrides = Integer(0)

    @observe("overrides")
    @observe_compat
    def _overrides_changed(self, change):
        self.parent_override = change


class SubClass(TransitionalClass):
    def _d_default(self):
        return SubClass

    subclass_super = False

    def _calls_super_changed(self, name, old, new):
        self.subclass_super = True
        super()._calls_super_changed(name, old, new)

    subclass_override = False

    def _overrides_changed(self, name, old, new):
        self.subclass_override = True


def test_subclass_compat():
    obj = SubClass()
    obj.calls_super = 5
    assert obj.parent_super
    assert obj.subclass_super
    obj.overrides = 5
    assert obj.subclass_override
    assert not obj.parent_override
    assert obj.d is SubClass


class DefinesHandler(HasTraits):
    parent_called = False

    trait = Integer()

    @observe("trait")
    def handler(self, change):
        self.parent_called = True


class OverridesHandler(DefinesHandler):
    child_called = False

    @observe("trait")
    def handler(self, change):
        self.child_called = True


def test_subclass_override_observer():
    obj = OverridesHandler()
    obj.trait = 5
    assert obj.child_called
    assert not obj.parent_called


class DoesntRegisterHandler(DefinesHandler):
    child_called = False

    def handler(self, change):
        self.child_called = True


def test_subclass_override_not_registered():
    """Subclass that overrides observer and doesn't re-register unregisters both"""
    obj = DoesntRegisterHandler()
    obj.trait = 5
    assert not obj.child_called
    assert not obj.parent_called


class AddsHandler(DefinesHandler):
    child_called = False

    @observe("trait")
    def child_handler(self, change):
        self.child_called = True


def test_subclass_add_observer():
    obj = AddsHandler()
    obj.trait = 5
    assert obj.child_called
    assert obj.parent_called


def test_observe_iterables():
    class C(HasTraits):
        i = Integer()
        s = Unicode()

    c = C()
    recorded = {}

    def record(change):
        recorded["change"] = change

    # observe with names=set
    c.observe(record, names={"i", "s"})
    c.i = 5
    assert recorded["change"].name == "i"
    assert recorded["change"].new == 5
    c.s = "hi"
    assert recorded["change"].name == "s"
    assert recorded["change"].new == "hi"

    # observe with names=custom container with iter, contains
    class MyContainer:
        def __init__(self, container):
            self.container = container

        def __iter__(self):
            return iter(self.container)

        def __contains__(self, key):
            return key in self.container

    c.observe(record, names=MyContainer({"i", "s"}))
    c.i = 10
    assert recorded["change"].name == "i"
    assert recorded["change"].new == 10
    c.s = "ok"
    assert recorded["change"].name == "s"
    assert recorded["change"].new == "ok"


def test_super_args():
    class SuperRecorder:
        def __init__(self, *args, **kwargs):
            self.super_args = args
            self.super_kwargs = kwargs

    class SuperHasTraits(HasTraits, SuperRecorder):
        i = Integer()

    obj = SuperHasTraits("a1", "a2", b=10, i=5, c="x")
    assert obj.i == 5
    assert not hasattr(obj, "b")
    assert not hasattr(obj, "c")
    assert obj.super_args == ("a1", "a2")
    assert obj.super_kwargs == {"b": 10, "c": "x"}


def test_super_bad_args():
    class SuperHasTraits(HasTraits):
        a = Integer()

    w = ["Passing unrecognized arguments"]
    with expected_warnings(w):
        obj = SuperHasTraits(a=1, b=2)
    assert obj.a == 1
    assert not hasattr(obj, "b")


def test_default_mro():
    """Verify that default values follow mro"""

    class Base(HasTraits):
        trait = Unicode("base")
        attr = "base"

    class A(Base):
        pass

    class B(Base):
        trait = Unicode("B")
        attr = "B"

    class AB(A, B):
        pass

    class BA(B, A):
        pass

    assert A().trait == "base"
    assert A().attr == "base"
    assert BA().trait == "B"
    assert BA().attr == "B"
    assert AB().trait == "B"
    assert AB().attr == "B"


def test_cls_self_argument():
    """Constructor parameters named `cls` and `self` must not clash."""

    class X(HasTraits):
        def __init__(__self, cls, self):
            pass

    X(cls=None, self=None)


def test_override_default():
    class C(HasTraits):
        a = Unicode("hard default")

        def _a_default(self):
            return "default method"

    C._a_default = lambda self: "overridden"
    c = C()
    assert c.a == "overridden"


def test_override_default_decorator():
    class C(HasTraits):
        a = Unicode("hard default")

        @default("a")
        def _a_default(self):
            return "default method"

    C._a_default = lambda self: "overridden"
    c = C()
    assert c.a == "overridden"


def test_override_default_instance():
    class C(HasTraits):
        a = Unicode("hard default")

        @default("a")
        def _a_default(self):
            return "default method"

    c = C()
    c._a_default = lambda self: "overridden"
    assert c.a == "overridden"


def test_copy_HasTraits():
    from copy import copy

    class C(HasTraits):
        a = Int()

    c = C(a=1)
    assert c.a == 1

    cc = copy(c)
    cc.a = 2
    assert cc.a == 2
    assert c.a == 1


def _from_string_test(traittype, s, expected):
    """Run a test of trait.from_string

    traittype: a TraitType instance, or a TraitType subclass (instantiated
        with allow_none=True).
    s: a string (parsed with from_string) or a list of strings (parsed with
        from_string_list).
    expected: the expected parsed value, or an Exception subclass that
        parsing + validation must raise.
    """
    if isinstance(traittype, TraitType):
        trait = traittype
    else:
        trait = traittype(allow_none=True)
    if isinstance(s, list):
        cast = trait.from_string_list
    else:
        cast = trait.from_string
    if isinstance(expected, type) and issubclass(expected, Exception):
        with pytest.raises(expected):
            value = cast(s)
            trait.validate(None, value)
    else:
        value = cast(s)
        assert value == expected


@pytest.mark.parametrize(
    "s, expected",
    [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)],
)
def test_unicode_from_string(s, expected):
    _from_string_test(Unicode, s, expected)


@pytest.mark.parametrize(
    "s, expected",
    [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)],
)
def test_cunicode_from_string(s, expected):
    _from_string_test(CUnicode, s, expected)


@pytest.mark.parametrize(
    "s, expected",
    [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)],
)
def test_bytes_from_string(s, expected):
    _from_string_test(Bytes, s, expected)


@pytest.mark.parametrize(
    "s, expected",
    [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)],
)
def test_cbytes_from_string(s, expected):
    _from_string_test(CBytes, s, expected)


@pytest.mark.parametrize(
    "s, expected",
    [("x", ValueError), ("1", 1), ("123", 123), ("2.0", ValueError), ("None", None)],
)
def test_int_from_string(s, expected):
    _from_string_test(Integer, s, expected)


@pytest.mark.parametrize(
    "s, expected",
    [("x", ValueError), ("1", 1.0), ("123.5", 123.5), ("2.5", 2.5), ("None", None)],
)
def test_float_from_string(s, expected):
    _from_string_test(Float, s, expected)


@pytest.mark.parametrize(
    "s, expected",
    [
        ("x", ValueError),
        ("1", 1.0),
        ("123.5", 123.5),
        ("2.5", 2.5),
        ("1+2j", 1 + 2j),
        ("None", None),
    ],
)
def test_complex_from_string(s, expected):
    _from_string_test(Complex, s, expected)


@pytest.mark.parametrize(
    "s, expected",
    [
        ("true", True),
        ("TRUE", True),
        ("1", True),
        ("0", False),
        ("False", False),
        ("false", False),
        ("1.0", ValueError),
        ("None", None),
    ],
)
def test_bool_from_string(s, expected):
    _from_string_test(Bool, s, expected)


@pytest.mark.parametrize(
    "s, expected",
    [
        ("{}", {}),
        ("1", TraitError),
        ("{1: 2}", {1: 2}),
        ('{"key": "value"}', {"key": "value"}),
        ("x", TraitError),
        ("None", None),
    ],
)
def test_dict_from_string(s, expected):
    _from_string_test(Dict, s, expected)


@pytest.mark.parametrize(
    "s, expected",
    [
        ("[]", []),
        ('[1, 2, "x"]', [1, 2, "x"]),
        (["1", "x"], ["1", "x"]),
        (["None"], None),
    ],
)
def test_list_from_string(s, expected):
    _from_string_test(List, s, expected)


@pytest.mark.parametrize(
    "s, expected, value_trait",
    [
        (["1", "2", "3"], [1, 2, 3], Integer()),
        (["x"], ValueError, Integer()),
        (["1", "x"], ["1", "x"], Unicode()),
        (["None"], [None], Unicode(allow_none=True)),
        (["None"], ["None"], Unicode(allow_none=False)),
    ],
)
def test_list_items_from_string(s, expected, value_trait):
    _from_string_test(List(value_trait), s, expected)


@pytest.mark.parametrize(
    "s, expected",
    [
        ("[]", set()),
        ('[1, 2, "x"]', {1, 2, "x"}),
        ('{1, 2, "x"}', {1, 2, "x"}),
        (["1", "x"], {"1", "x"}),
        (["None"], None),
    ],
)
def test_set_from_string(s, expected):
    _from_string_test(Set, s, expected)


@pytest.mark.parametrize(
    "s, expected, value_trait",
    [
        (["1", "2", "3"], {1, 2, 3}, Integer()),
        (["x"], ValueError, Integer()),
        (["1", "x"], {"1", "x"}, Unicode()),
        (["None"], {None}, Unicode(allow_none=True)),
    ],
)
def test_set_items_from_string(s, expected, value_trait):
    _from_string_test(Set(value_trait), s, expected)


@pytest.mark.parametrize(
    "s, expected",
    [
        ("[]", ()),
        ("()", ()),
        ('[1, 2, "x"]', (1, 2, "x")),
        ('(1, 2, "x")', (1, 2, "x")),
        (["1", "x"], ("1", "x")),
        (["None"], None),
    ],
)
def test_tuple_from_string(s, expected):
    _from_string_test(Tuple, s, expected)


@pytest.mark.parametrize(
    "s, expected, value_traits",
    [
        (["1", "2", "3"], (1, 2, 3), [Integer(), Integer(), Integer()]),
        (["x"], ValueError, [Integer()]),
        (["1", "x"], ("1", "x"), [Unicode()]),
        (["None"], ("None",), [Unicode(allow_none=False)]),
        (["None"], (None,), [Unicode(allow_none=True)]),
    ],
)
def test_tuple_items_from_string(s, expected, value_traits):
    _from_string_test(Tuple(*value_traits), s, expected)


@pytest.mark.parametrize(
    "s, expected",
    [
        ("x", "x"),
        ("mod.submod", "mod.submod"),
        ("not an identifier", TraitError),
        ("1", "1"),
        ("None", None),
    ],
)
def test_object_from_string(s, expected):
    _from_string_test(DottedObjectName, s, expected)


@pytest.mark.parametrize(
    "s, expected",
    [
        ("127.0.0.1:8000", ("127.0.0.1", 8000)),
        ("host.tld:80", ("host.tld", 80)),
        ("host:notaport", ValueError),
        ("127.0.0.1", ValueError),
        ("None", None),
    ],
)
def test_tcp_from_string(s, expected):
    _from_string_test(TCPAddress, s, expected)


def test_all_attribute():
    """Verify all trait types are added to `traitlets.__all__`"""
    names = dir(traitlets)
    for name in names:
        value = getattr(traitlets, name)
        if not name.startswith("_") and isinstance(value, type) and issubclass(value, TraitType):
            if name not in traitlets.__all__:
                raise ValueError(f"{name} not in __all__")

    for name in traitlets.__all__:
        if name not in names:
            raise ValueError(f"{name} should be removed from __all__")