Skip to content

Commit 804a973

Browse files
gh-103365: [Enum] STRICT boundary corrections (GH-103494)
STRICT boundary: - fix bitwise operations - make default for Flag (cherry picked from commit 2194071) Co-authored-by: Ethan Furman <[email protected]>
1 parent e643412 commit 804a973

File tree

4 files changed

+82
-38
lines changed

4 files changed

+82
-38
lines changed

Doc/library/enum.rst

+3-2
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,8 @@ Data Types
692692

693693
.. attribute:: STRICT
694694

695-
Out-of-range values cause a :exc:`ValueError` to be raised::
695+
Out-of-range values cause a :exc:`ValueError` to be raised. This is the
696+
default for :class:`Flag`::
696697

697698
>>> from enum import Flag, STRICT, auto
698699
>>> class StrictFlag(Flag, boundary=STRICT):
@@ -709,7 +710,7 @@ Data Types
709710
.. attribute:: CONFORM
710711

711712
Out-of-range values have invalid values removed, leaving a valid *Flag*
712-
value. This is the default for :class:`Flag`::
713+
value::
713714

714715
>>> from enum import Flag, CONFORM, auto
715716
>>> class ConformFlag(Flag, boundary=CONFORM):

Lib/enum.py

+39-28
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,13 @@ def __set_name__(self, enum_class, member_name):
273273
enum_member.__objclass__ = enum_class
274274
enum_member.__init__(*args)
275275
enum_member._sort_order_ = len(enum_class._member_names_)
276+
277+
if Flag is not None and issubclass(enum_class, Flag):
278+
enum_class._flag_mask_ |= value
279+
if _is_single_bit(value):
280+
enum_class._singles_mask_ |= value
281+
enum_class._all_bits_ = 2 ** ((enum_class._flag_mask_).bit_length()) - 1
282+
276283
# If another member with the same value was already defined, the
277284
# new member becomes an alias to the existing one.
278285
try:
@@ -525,12 +532,8 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
525532
classdict['_use_args_'] = use_args
526533
#
527534
# convert future enum members into temporary _proto_members
528-
# and record integer values in case this will be a Flag
529-
flag_mask = 0
530535
for name in member_names:
531536
value = classdict[name]
532-
if isinstance(value, int):
533-
flag_mask |= value
534537
classdict[name] = _proto_member(value)
535538
#
536539
# house-keeping structures
@@ -547,8 +550,9 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
547550
boundary
548551
or getattr(first_enum, '_boundary_', None)
549552
)
550-
classdict['_flag_mask_'] = flag_mask
551-
classdict['_all_bits_'] = 2 ** ((flag_mask).bit_length()) - 1
553+
classdict['_flag_mask_'] = 0
554+
classdict['_singles_mask_'] = 0
555+
classdict['_all_bits_'] = 0
552556
classdict['_inverted_'] = None
553557
try:
554558
exc = None
@@ -637,21 +641,10 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
637641
):
638642
delattr(enum_class, '_boundary_')
639643
delattr(enum_class, '_flag_mask_')
644+
delattr(enum_class, '_singles_mask_')
640645
delattr(enum_class, '_all_bits_')
641646
delattr(enum_class, '_inverted_')
642647
elif Flag is not None and issubclass(enum_class, Flag):
643-
# ensure _all_bits_ is correct and there are no missing flags
644-
single_bit_total = 0
645-
multi_bit_total = 0
646-
for flag in enum_class._member_map_.values():
647-
flag_value = flag._value_
648-
if _is_single_bit(flag_value):
649-
single_bit_total |= flag_value
650-
else:
651-
# multi-bit flags are considered aliases
652-
multi_bit_total |= flag_value
653-
enum_class._flag_mask_ = single_bit_total
654-
#
655648
# set correct __iter__
656649
member_list = [m._value_ for m in enum_class]
657650
if member_list != sorted(member_list):
@@ -1303,8 +1296,8 @@ def _reduce_ex_by_global_name(self, proto):
13031296
class FlagBoundary(StrEnum):
13041297
"""
13051298
control how out of range values are handled
1306-
"strict" -> error is raised
1307-
"conform" -> extra bits are discarded [default for Flag]
1299+
"strict" -> error is raised [default for Flag]
1300+
"conform" -> extra bits are discarded
13081301
"eject" -> lose flag status
13091302
"keep" -> keep flag status and all bits [default for IntFlag]
13101303
"""
@@ -1315,7 +1308,7 @@ class FlagBoundary(StrEnum):
13151308
STRICT, CONFORM, EJECT, KEEP = FlagBoundary
13161309

13171310

1318-
class Flag(Enum, boundary=CONFORM):
1311+
class Flag(Enum, boundary=STRICT):
13191312
"""
13201313
Support for flags
13211314
"""
@@ -1393,6 +1386,7 @@ def _missing_(cls, value):
13931386
# - value must not include any skipped flags (e.g. if bit 2 is not
13941387
# defined, then 0d10 is invalid)
13951388
flag_mask = cls._flag_mask_
1389+
singles_mask = cls._singles_mask_
13961390
all_bits = cls._all_bits_
13971391
neg_value = None
13981392
if (
@@ -1424,7 +1418,8 @@ def _missing_(cls, value):
14241418
value = all_bits + 1 + value
14251419
# get members and unknown
14261420
unknown = value & ~flag_mask
1427-
member_value = value & flag_mask
1421+
aliases = value & ~singles_mask
1422+
member_value = value & singles_mask
14281423
if unknown and cls._boundary_ is not KEEP:
14291424
raise ValueError(
14301425
'%s(%r) --> unknown values %r [%s]'
@@ -1438,11 +1433,25 @@ def _missing_(cls, value):
14381433
pseudo_member = cls._member_type_.__new__(cls, value)
14391434
if not hasattr(pseudo_member, '_value_'):
14401435
pseudo_member._value_ = value
1441-
if member_value:
1442-
pseudo_member._name_ = '|'.join([
1443-
m._name_ for m in cls._iter_member_(member_value)
1444-
])
1445-
if unknown:
1436+
if member_value or aliases:
1437+
members = []
1438+
combined_value = 0
1439+
for m in cls._iter_member_(member_value):
1440+
members.append(m)
1441+
combined_value |= m._value_
1442+
if aliases:
1443+
value = member_value | aliases
1444+
for n, pm in cls._member_map_.items():
1445+
if pm not in members and pm._value_ and pm._value_ & value == pm._value_:
1446+
members.append(pm)
1447+
combined_value |= pm._value_
1448+
unknown = value ^ combined_value
1449+
pseudo_member._name_ = '|'.join([m._name_ for m in members])
1450+
if not combined_value:
1451+
pseudo_member._name_ = None
1452+
elif unknown and cls._boundary_ is STRICT:
1453+
raise ValueError('%r: no members with value %r' % (cls, unknown))
1454+
elif unknown:
14461455
pseudo_member._name_ += '|%s' % cls._numeric_repr_(unknown)
14471456
else:
14481457
pseudo_member._name_ = None
@@ -1671,6 +1680,7 @@ def convert_class(cls):
16711680
body['_boundary_'] = boundary or etype._boundary_
16721681
body['_flag_mask_'] = None
16731682
body['_all_bits_'] = None
1683+
body['_singles_mask_'] = None
16741684
body['_inverted_'] = None
16751685
body['__or__'] = Flag.__or__
16761686
body['__xor__'] = Flag.__xor__
@@ -1743,7 +1753,8 @@ def convert_class(cls):
17431753
else:
17441754
multi_bits |= value
17451755
gnv_last_values.append(value)
1746-
enum_class._flag_mask_ = single_bits
1756+
enum_class._flag_mask_ = single_bits | multi_bits
1757+
enum_class._singles_mask_ = single_bits
17471758
enum_class._all_bits_ = 2 ** ((single_bits|multi_bits).bit_length()) - 1
17481759
# set correct __iter__
17491760
member_list = [m._value_ for m in enum_class]

Lib/test/test_enum.py

+39-8
Original file line numberDiff line numberDiff line change
@@ -2758,6 +2758,8 @@ def __new__(cls, c):
27582758
#
27592759
a = ord('a')
27602760
#
2761+
self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343)
2762+
self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672)
27612763
self.assertEqual(FlagFromChar.a, 158456325028528675187087900672)
27622764
self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673)
27632765
#
@@ -2772,6 +2774,8 @@ def __new__(cls, c):
27722774
a = ord('a')
27732775
z = 1
27742776
#
2777+
self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343)
2778+
self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900674)
27752779
self.assertEqual(FlagFromChar.a.value, 158456325028528675187087900672)
27762780
self.assertEqual((FlagFromChar.a|FlagFromChar.z).value, 158456325028528675187087900674)
27772781
#
@@ -2785,6 +2789,8 @@ def __new__(cls, c):
27852789
#
27862790
a = ord('a')
27872791
#
2792+
self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343)
2793+
self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672)
27882794
self.assertEqual(FlagFromChar.a, 158456325028528675187087900672)
27892795
self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673)
27902796

@@ -2962,18 +2968,18 @@ def test_bool(self):
29622968
self.assertEqual(bool(f.value), bool(f))
29632969

29642970
def test_boundary(self):
2965-
self.assertIs(enum.Flag._boundary_, CONFORM)
2966-
class Iron(Flag, boundary=STRICT):
2971+
self.assertIs(enum.Flag._boundary_, STRICT)
2972+
class Iron(Flag, boundary=CONFORM):
29672973
ONE = 1
29682974
TWO = 2
29692975
EIGHT = 8
2970-
self.assertIs(Iron._boundary_, STRICT)
2976+
self.assertIs(Iron._boundary_, CONFORM)
29712977
#
2972-
class Water(Flag, boundary=CONFORM):
2978+
class Water(Flag, boundary=STRICT):
29732979
ONE = 1
29742980
TWO = 2
29752981
EIGHT = 8
2976-
self.assertIs(Water._boundary_, CONFORM)
2982+
self.assertIs(Water._boundary_, STRICT)
29772983
#
29782984
class Space(Flag, boundary=EJECT):
29792985
ONE = 1
@@ -2986,17 +2992,42 @@ class Bizarre(Flag, boundary=KEEP):
29862992
c = 4
29872993
d = 6
29882994
#
2989-
self.assertRaisesRegex(ValueError, 'invalid value 7', Iron, 7)
2995+
self.assertRaisesRegex(ValueError, 'invalid value 7', Water, 7)
29902996
#
2991-
self.assertIs(Water(7), Water.ONE|Water.TWO)
2992-
self.assertIs(Water(~9), Water.TWO)
2997+
self.assertIs(Iron(7), Iron.ONE|Iron.TWO)
2998+
self.assertIs(Iron(~9), Iron.TWO)
29932999
#
29943000
self.assertEqual(Space(7), 7)
29953001
self.assertTrue(type(Space(7)) is int)
29963002
#
29973003
self.assertEqual(list(Bizarre), [Bizarre.c])
29983004
self.assertIs(Bizarre(3), Bizarre.b)
29993005
self.assertIs(Bizarre(6), Bizarre.d)
3006+
#
3007+
class SkipFlag(enum.Flag):
3008+
A = 1
3009+
B = 2
3010+
C = 4 | B
3011+
#
3012+
self.assertTrue(SkipFlag.C in (SkipFlag.A|SkipFlag.C))
3013+
self.assertRaisesRegex(ValueError, 'SkipFlag.. invalid value 42', SkipFlag, 42)
3014+
#
3015+
class SkipIntFlag(enum.IntFlag):
3016+
A = 1
3017+
B = 2
3018+
C = 4 | B
3019+
#
3020+
self.assertTrue(SkipIntFlag.C in (SkipIntFlag.A|SkipIntFlag.C))
3021+
self.assertEqual(SkipIntFlag(42).value, 42)
3022+
#
3023+
class MethodHint(Flag):
3024+
HiddenText = 0x10
3025+
DigitsOnly = 0x01
3026+
LettersOnly = 0x02
3027+
OnlyMask = 0x0f
3028+
#
3029+
self.assertEqual(str(MethodHint.HiddenText|MethodHint.OnlyMask), 'MethodHint.HiddenText|DigitsOnly|LettersOnly|OnlyMask')
3030+
30003031

30013032
def test_iter(self):
30023033
Color = self.Color
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Set default Flag boundary to ``STRICT`` and fix bitwise operations.

0 commit comments

Comments
 (0)