Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-44633: Fix parameter substitution of the union type with wrong types. #27218

Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
bpo-44633: Fix parameter substitution of the union type with wrong ty…
…pes.

A TypeError is now raised instead of returning NotImplemented.
  • Loading branch information
serhiy-storchaka committed Jul 17, 2021
commit 2e278301f86bde39a288c04a314eca36c735fea2
6 changes: 6 additions & 0 deletions Lib/test/test_types.py
Original file line number Diff line number Diff line change
@@ -755,6 +755,12 @@ def test_union_parameter_chaining(self):
self.assertEqual((list[T] | list[S])[int, T], list[int] | list[T])
self.assertEqual((list[T] | list[S])[int, int], list[int])

def test_union_parameter_substitution_errors(self):
T = typing.TypeVar("T")
x = int | T
with self.assertRaises(TypeError):
x[42]

def test_or_type_operator_with_forward(self):
T = typing.TypeVar('T')
ForwardAfter = T | 'Forward'
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Parameter substitution of the union type with wrong types raises now
``TypeError`` instead of returning ``NotImplemented``.
59 changes: 36 additions & 23 deletions Objects/unionobject.c
Original file line number Diff line number Diff line change
@@ -5,6 +5,9 @@
#include "structmember.h"


static PyObject *make_union(PyObject *);


typedef struct {
PyObject_HEAD
PyObject *args;
@@ -337,13 +340,25 @@ is_unionable(PyObject *obj)
}

PyObject *
_Py_union_type_or(PyObject* self, PyObject* param)
_Py_union_type_or(PyObject* self, PyObject* other)
{
PyObject *tuple = PyTuple_Pack(2, self, param);
int r = is_unionable(self);
if (r > 0) {
r = is_unionable(other);
}
if (r < 0) {
return NULL;
}
if (!r) {
Py_RETURN_NOTIMPLEMENTED;
}

PyObject *tuple = PyTuple_Pack(2, self, other);
if (tuple == NULL) {
return NULL;
}
PyObject *new_union = _Py_Union(tuple);

PyObject *new_union = make_union(tuple);
Py_DECREF(tuple);
return new_union;
}
@@ -471,7 +486,22 @@ union_getitem(PyObject *self, PyObject *item)
return NULL;
}

PyObject *res = _Py_Union(newargs);
// Check arguments are unionable.
Py_ssize_t nargs = PyTuple_GET_SIZE(newargs);
for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) {
PyObject *arg = PyTuple_GET_ITEM(newargs, iarg);
int is_arg_unionable = is_unionable(arg);
if (is_arg_unionable <= 0) {
Py_DECREF(newargs);
if (is_arg_unionable == 0) {
PyErr_Format(PyExc_TypeError,
"Each union arg must be a type, got %.100R", arg);
}
return NULL;
}
}

PyObject *res = make_union(newargs);

Py_DECREF(newargs);
return res;
@@ -527,30 +557,13 @@ PyTypeObject _Py_UnionType = {
.tp_getset = union_properties,
};

PyObject *
_Py_Union(PyObject *args)
static PyObject *
make_union(PyObject *args)
{
assert(PyTuple_CheckExact(args));

unionobject* result = NULL;

// Check arguments are unionable.
Py_ssize_t nargs = PyTuple_GET_SIZE(args);
for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) {
PyObject *arg = PyTuple_GET_ITEM(args, iarg);
if (arg == NULL) {
return NULL;
}
int is_arg_unionable = is_unionable(arg);
if (is_arg_unionable < 0) {
return NULL;
}
if (!is_arg_unionable) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
}

args = dedup_and_flatten_args(args);
if (args == NULL) {
return NULL;