Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[3.9] bpo-46615: Don't crash when set operations mutate the sets (GH-31120) #31312

Merged
merged 1 commit into from
Feb 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 186 additions & 0 deletions Lib/test/test_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -1799,6 +1799,192 @@ def __eq__(self, o):
s = {0}
s.update(other)


class TestOperationsMutating:
"""Regression test for bpo-46615"""

constructor1 = None
constructor2 = None

def make_sets_of_bad_objects(self):
class Bad:
def __eq__(self, other):
if not enabled:
return False
if randrange(20) == 0:
set1.clear()
if randrange(20) == 0:
set2.clear()
return bool(randrange(2))
def __hash__(self):
return randrange(2)
# Don't behave poorly during construction.
enabled = False
set1 = self.constructor1(Bad() for _ in range(randrange(50)))
set2 = self.constructor2(Bad() for _ in range(randrange(50)))
# Now start behaving poorly
enabled = True
return set1, set2

def check_set_op_does_not_crash(self, function):
for _ in range(100):
set1, set2 = self.make_sets_of_bad_objects()
try:
function(set1, set2)
except RuntimeError as e:
# Just make sure we don't crash here.
self.assertIn("changed size during iteration", str(e))


class TestBinaryOpsMutating(TestOperationsMutating):

def test_eq_with_mutation(self):
self.check_set_op_does_not_crash(lambda a, b: a == b)

def test_ne_with_mutation(self):
self.check_set_op_does_not_crash(lambda a, b: a != b)

def test_lt_with_mutation(self):
self.check_set_op_does_not_crash(lambda a, b: a < b)

def test_le_with_mutation(self):
self.check_set_op_does_not_crash(lambda a, b: a <= b)

def test_gt_with_mutation(self):
self.check_set_op_does_not_crash(lambda a, b: a > b)

def test_ge_with_mutation(self):
self.check_set_op_does_not_crash(lambda a, b: a >= b)

def test_and_with_mutation(self):
self.check_set_op_does_not_crash(lambda a, b: a & b)

def test_or_with_mutation(self):
self.check_set_op_does_not_crash(lambda a, b: a | b)

def test_sub_with_mutation(self):
self.check_set_op_does_not_crash(lambda a, b: a - b)

def test_xor_with_mutation(self):
self.check_set_op_does_not_crash(lambda a, b: a ^ b)

def test_iadd_with_mutation(self):
def f(a, b):
a &= b
self.check_set_op_does_not_crash(f)

def test_ior_with_mutation(self):
def f(a, b):
a |= b
self.check_set_op_does_not_crash(f)

def test_isub_with_mutation(self):
def f(a, b):
a -= b
self.check_set_op_does_not_crash(f)

def test_ixor_with_mutation(self):
def f(a, b):
a ^= b
self.check_set_op_does_not_crash(f)

def test_iteration_with_mutation(self):
def f1(a, b):
for x in a:
pass
for y in b:
pass
def f2(a, b):
for y in b:
pass
for x in a:
pass
def f3(a, b):
for x, y in zip(a, b):
pass
self.check_set_op_does_not_crash(f1)
self.check_set_op_does_not_crash(f2)
self.check_set_op_does_not_crash(f3)


class TestBinaryOpsMutating_Set_Set(TestBinaryOpsMutating, unittest.TestCase):
constructor1 = set
constructor2 = set

class TestBinaryOpsMutating_Subclass_Subclass(TestBinaryOpsMutating, unittest.TestCase):
constructor1 = SetSubclass
constructor2 = SetSubclass

class TestBinaryOpsMutating_Set_Subclass(TestBinaryOpsMutating, unittest.TestCase):
constructor1 = set
constructor2 = SetSubclass

class TestBinaryOpsMutating_Subclass_Set(TestBinaryOpsMutating, unittest.TestCase):
constructor1 = SetSubclass
constructor2 = set


class TestMethodsMutating(TestOperationsMutating):

def test_issubset_with_mutation(self):
self.check_set_op_does_not_crash(set.issubset)

def test_issuperset_with_mutation(self):
self.check_set_op_does_not_crash(set.issuperset)

def test_intersection_with_mutation(self):
self.check_set_op_does_not_crash(set.intersection)

def test_union_with_mutation(self):
self.check_set_op_does_not_crash(set.union)

def test_difference_with_mutation(self):
self.check_set_op_does_not_crash(set.difference)

def test_symmetric_difference_with_mutation(self):
self.check_set_op_does_not_crash(set.symmetric_difference)

def test_isdisjoint_with_mutation(self):
self.check_set_op_does_not_crash(set.isdisjoint)

def test_difference_update_with_mutation(self):
self.check_set_op_does_not_crash(set.difference_update)

def test_intersection_update_with_mutation(self):
self.check_set_op_does_not_crash(set.intersection_update)

def test_symmetric_difference_update_with_mutation(self):
self.check_set_op_does_not_crash(set.symmetric_difference_update)

def test_update_with_mutation(self):
self.check_set_op_does_not_crash(set.update)


class TestMethodsMutating_Set_Set(TestMethodsMutating, unittest.TestCase):
constructor1 = set
constructor2 = set

class TestMethodsMutating_Subclass_Subclass(TestMethodsMutating, unittest.TestCase):
constructor1 = SetSubclass
constructor2 = SetSubclass

class TestMethodsMutating_Set_Subclass(TestMethodsMutating, unittest.TestCase):
constructor1 = set
constructor2 = SetSubclass

class TestMethodsMutating_Subclass_Set(TestMethodsMutating, unittest.TestCase):
constructor1 = SetSubclass
constructor2 = set

class TestMethodsMutating_Set_Dict(TestMethodsMutating, unittest.TestCase):
constructor1 = set
constructor2 = dict.fromkeys

class TestMethodsMutating_Set_List(TestMethodsMutating, unittest.TestCase):
constructor1 = set
constructor2 = list


# Application tests (based on David Eppstein's graph recipes ====================================

def powerset(U):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
When iterating over sets internally in ``setobject.c``, acquire strong references to the resulting items from the set. This prevents crashes in corner-cases of various set operations where the set gets mutated.
47 changes: 39 additions & 8 deletions Objects/setobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -1207,17 +1207,21 @@ set_intersection(PySetObject *so, PyObject *other)
while (set_next((PySetObject *)other, &pos, &entry)) {
key = entry->key;
hash = entry->hash;
Py_INCREF(key);
rv = set_contains_entry(so, key, hash);
if (rv < 0) {
Py_DECREF(result);
Py_DECREF(key);
return NULL;
}
if (rv) {
if (set_add_entry(result, key, hash)) {
Py_DECREF(result);
Py_DECREF(key);
return NULL;
}
}
Py_DECREF(key);
}
return (PyObject *)result;
}
Expand Down Expand Up @@ -1357,11 +1361,16 @@ set_isdisjoint(PySetObject *so, PyObject *other)
other = tmp;
}
while (set_next((PySetObject *)other, &pos, &entry)) {
rv = set_contains_entry(so, entry->key, entry->hash);
if (rv < 0)
PyObject *key = entry->key;
Py_INCREF(key);
rv = set_contains_entry(so, key, entry->hash);
Py_DECREF(key);
if (rv < 0) {
return NULL;
if (rv)
}
if (rv) {
Py_RETURN_FALSE;
}
}
Py_RETURN_TRUE;
}
Expand Down Expand Up @@ -1420,11 +1429,16 @@ set_difference_update_internal(PySetObject *so, PyObject *other)
Py_INCREF(other);
}

while (set_next((PySetObject *)other, &pos, &entry))
if (set_discard_entry(so, entry->key, entry->hash) < 0) {
while (set_next((PySetObject *)other, &pos, &entry)) {
PyObject *key = entry->key;
Py_INCREF(key);
if (set_discard_entry(so, key, entry->hash) < 0) {
Py_DECREF(other);
Py_DECREF(key);
return -1;
}
Py_DECREF(key);
}

Py_DECREF(other);
} else {
Expand Down Expand Up @@ -1515,17 +1529,21 @@ set_difference(PySetObject *so, PyObject *other)
while (set_next(so, &pos, &entry)) {
key = entry->key;
hash = entry->hash;
Py_INCREF(key);
rv = _PyDict_Contains(other, key, hash);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: cherry-picker only struggled because this is _PyDict_Contains_KnownHash() in main.

if (rv < 0) {
Py_DECREF(result);
Py_DECREF(key);
return NULL;
}
if (!rv) {
if (set_add_entry((PySetObject *)result, key, hash)) {
Py_DECREF(result);
Py_DECREF(key);
return NULL;
}
}
Py_DECREF(key);
}
return result;
}
Expand All @@ -1534,17 +1552,21 @@ set_difference(PySetObject *so, PyObject *other)
while (set_next(so, &pos, &entry)) {
key = entry->key;
hash = entry->hash;
Py_INCREF(key);
rv = set_contains_entry((PySetObject *)other, key, hash);
if (rv < 0) {
Py_DECREF(result);
Py_DECREF(key);
return NULL;
}
if (!rv) {
if (set_add_entry((PySetObject *)result, key, hash)) {
Py_DECREF(result);
Py_DECREF(key);
return NULL;
}
}
Py_DECREF(key);
}
return result;
}
Expand Down Expand Up @@ -1641,17 +1663,21 @@ set_symmetric_difference_update(PySetObject *so, PyObject *other)
while (set_next(otherset, &pos, &entry)) {
key = entry->key;
hash = entry->hash;
Py_INCREF(key);
rv = set_discard_entry(so, key, hash);
if (rv < 0) {
Py_DECREF(otherset);
Py_DECREF(key);
return NULL;
}
if (rv == DISCARD_NOTFOUND) {
if (set_add_entry(so, key, hash)) {
Py_DECREF(otherset);
Py_DECREF(key);
return NULL;
}
}
Py_DECREF(key);
}
Py_DECREF(otherset);
Py_RETURN_NONE;
Expand Down Expand Up @@ -1726,11 +1752,16 @@ set_issubset(PySetObject *so, PyObject *other)
Py_RETURN_FALSE;

while (set_next(so, &pos, &entry)) {
rv = set_contains_entry((PySetObject *)other, entry->key, entry->hash);
if (rv < 0)
PyObject *key = entry->key;
Py_INCREF(key);
rv = set_contains_entry((PySetObject *)other, key, entry->hash);
Py_DECREF(key);
if (rv < 0) {
return NULL;
if (!rv)
}
if (!rv) {
Py_RETURN_FALSE;
}
}
Py_RETURN_TRUE;
}
Expand Down