@@ -1789,6 +1789,30 @@ def Elem(*args):
1789
1789
1790
1790
Union [Elem , str ] # Nor should this
1791
1791
1792
+ def test_union_of_literals (self ):
1793
+ self .assertEqual (Union [Literal [1 ], Literal [2 ]].__args__ ,
1794
+ (Literal [1 ], Literal [2 ]))
1795
+ self .assertEqual (Union [Literal [1 ], Literal [1 ]],
1796
+ Literal [1 ])
1797
+
1798
+ self .assertEqual (Union [Literal [False ], Literal [0 ]].__args__ ,
1799
+ (Literal [False ], Literal [0 ]))
1800
+ self .assertEqual (Union [Literal [True ], Literal [1 ]].__args__ ,
1801
+ (Literal [True ], Literal [1 ]))
1802
+
1803
+ import enum
1804
+ class Ints (enum .IntEnum ):
1805
+ A = 0
1806
+ B = 1
1807
+
1808
+ self .assertEqual (Union [Literal [Ints .A ], Literal [Ints .B ]].__args__ ,
1809
+ (Literal [Ints .A ], Literal [Ints .B ]))
1810
+
1811
+ self .assertEqual (Union [Literal [0 ], Literal [Ints .A ], Literal [False ]].__args__ ,
1812
+ (Literal [0 ], Literal [Ints .A ], Literal [False ]))
1813
+ self .assertEqual (Union [Literal [1 ], Literal [Ints .B ], Literal [True ]].__args__ ,
1814
+ (Literal [1 ], Literal [Ints .B ], Literal [True ]))
1815
+
1792
1816
1793
1817
class TupleTests (BaseTestCase ):
1794
1818
@@ -2156,6 +2180,13 @@ def test_basics(self):
2156
2180
Literal [Literal [1 , 2 ], Literal [4 , 5 ]]
2157
2181
Literal [b"foo" , u"bar" ]
2158
2182
2183
+ def test_enum (self ):
2184
+ import enum
2185
+ class My (enum .Enum ):
2186
+ A = 'A'
2187
+
2188
+ self .assertEqual (Literal [My .A ].__args__ , (My .A ,))
2189
+
2159
2190
def test_illegal_parameters_do_not_raise_runtime_errors (self ):
2160
2191
# Type checkers should reject these types, but we do not
2161
2192
# raise errors at runtime to maintain maximum flexibility.
@@ -2245,6 +2276,20 @@ def test_flatten(self):
2245
2276
self .assertEqual (l , Literal [1 , 2 , 3 ])
2246
2277
self .assertEqual (l .__args__ , (1 , 2 , 3 ))
2247
2278
2279
+ def test_does_not_flatten_enum (self ):
2280
+ import enum
2281
+ class Ints (enum .IntEnum ):
2282
+ A = 1
2283
+ B = 2
2284
+
2285
+ l = Literal [
2286
+ Literal [Ints .A ],
2287
+ Literal [Ints .B ],
2288
+ Literal [1 ],
2289
+ Literal [2 ],
2290
+ ]
2291
+ self .assertEqual (l .__args__ , (Ints .A , Ints .B , 1 , 2 ))
2292
+
2248
2293
2249
2294
XK = TypeVar ('XK' , str , bytes )
2250
2295
XV = TypeVar ('XV' )
0 commit comments