Skip to content

Commit 1b474e2

Browse files
approx: use exact comparison for bool (#13013)
Fixes #9353 (cherry picked from commit a16e8ea) Co-authored-by: Jakob van Santen <[email protected]>
1 parent b541721 commit 1b474e2

File tree

3 files changed

+49
-19
lines changed

3 files changed

+49
-19
lines changed

changelog/9353.bugfix.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:func:`pytest.approx` now uses strict equality when given booleans.

src/_pytest/python_api.py

+26-18
Original file line numberDiff line numberDiff line change
@@ -262,19 +262,22 @@ def _repr_compare(self, other_side: Mapping[object, float]) -> list[str]:
262262
):
263263
if approx_value != other_value:
264264
if approx_value.expected is not None and other_value is not None:
265-
max_abs_diff = max(
266-
max_abs_diff, abs(approx_value.expected - other_value)
267-
)
268-
if approx_value.expected == 0.0:
269-
max_rel_diff = math.inf
270-
else:
271-
max_rel_diff = max(
272-
max_rel_diff,
273-
abs(
274-
(approx_value.expected - other_value)
275-
/ approx_value.expected
276-
),
265+
try:
266+
max_abs_diff = max(
267+
max_abs_diff, abs(approx_value.expected - other_value)
277268
)
269+
if approx_value.expected == 0.0:
270+
max_rel_diff = math.inf
271+
else:
272+
max_rel_diff = max(
273+
max_rel_diff,
274+
abs(
275+
(approx_value.expected - other_value)
276+
/ approx_value.expected
277+
),
278+
)
279+
except ZeroDivisionError:
280+
pass
278281
different_ids.append(approx_key)
279282

280283
message_data = [
@@ -398,8 +401,10 @@ def __repr__(self) -> str:
398401
# Don't show a tolerance for values that aren't compared using
399402
# tolerances, i.e. non-numerics and infinities. Need to call abs to
400403
# handle complex numbers, e.g. (inf + 1j).
401-
if (not isinstance(self.expected, (Complex, Decimal))) or math.isinf(
402-
abs(self.expected)
404+
if (
405+
isinstance(self.expected, bool)
406+
or (not isinstance(self.expected, (Complex, Decimal)))
407+
or math.isinf(abs(self.expected) or isinstance(self.expected, bool))
403408
):
404409
return str(self.expected)
405410

@@ -427,14 +432,17 @@ def __eq__(self, actual) -> bool:
427432
# numpy<1.13. See #3748.
428433
return all(self.__eq__(a) for a in asarray.flat)
429434

430-
# Short-circuit exact equality.
431-
if actual == self.expected:
435+
# Short-circuit exact equality, except for bool
436+
if isinstance(self.expected, bool) and not isinstance(actual, bool):
437+
return False
438+
elif actual == self.expected:
432439
return True
433440

434441
# If either type is non-numeric, fall back to strict equality.
435442
# NB: we need Complex, rather than just Number, to ensure that __abs__,
436-
# __sub__, and __float__ are defined.
437-
if not (
443+
# __sub__, and __float__ are defined. Also, consider bool to be
444+
# nonnumeric, even though it has the required arithmetic.
445+
if isinstance(self.expected, bool) or not (
438446
isinstance(self.expected, (Complex, Decimal))
439447
and isinstance(actual, (Complex, Decimal))
440448
):

testing/python/approx.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,25 @@ def do_assert(lhs, rhs, expected_message, verbosity_level=0):
9090
return do_assert
9191

9292

93-
SOME_FLOAT = r"[+-]?([0-9]*[.])?[0-9]+\s*"
93+
SOME_FLOAT = r"[+-]?((?:([0-9]*[.])?[0-9]+(e-?[0-9]+)?)|inf|nan)\s*"
9494
SOME_INT = r"[0-9]+\s*"
9595

9696

9797
class TestApprox:
9898
def test_error_messages_native_dtypes(self, assert_approx_raises_regex):
99+
# Treat bool exactly.
100+
assert_approx_raises_regex(
101+
{"a": 1.0, "b": True},
102+
{"a": 1.0, "b": False},
103+
[
104+
"",
105+
" comparison failed. Mismatched elements: 1 / 2:",
106+
f" Max absolute difference: {SOME_FLOAT}",
107+
f" Max relative difference: {SOME_FLOAT}",
108+
r" Index\s+\| Obtained\s+\| Expected",
109+
r".*(True|False)\s+",
110+
],
111+
)
99112
assert_approx_raises_regex(
100113
2.0,
101114
1.0,
@@ -590,6 +603,13 @@ def test_complex(self):
590603
assert approx(x, rel=5e-6, abs=0) == a
591604
assert approx(x, rel=5e-7, abs=0) != a
592605

606+
def test_expecting_bool(self) -> None:
607+
assert True == approx(True) # noqa: E712
608+
assert False == approx(False) # noqa: E712
609+
assert True != approx(False) # noqa: E712
610+
assert True != approx(False, abs=2) # noqa: E712
611+
assert 1 != approx(True)
612+
593613
def test_list(self):
594614
actual = [1 + 1e-7, 2 + 1e-8]
595615
expected = [1, 2]
@@ -655,6 +675,7 @@ def test_dict_wrong_len(self):
655675
def test_dict_nonnumeric(self):
656676
assert {"a": 1.0, "b": None} == pytest.approx({"a": 1.0, "b": None})
657677
assert {"a": 1.0, "b": 1} != pytest.approx({"a": 1.0, "b": None})
678+
assert {"a": 1.0, "b": True} != pytest.approx({"a": 1.0, "b": False}, abs=2)
658679

659680
def test_dict_vs_other(self):
660681
assert 1 != approx({"a": 0})

0 commit comments

Comments
 (0)