Skip to content

Commit 473adf4

Browse files
fix(types): handle more discriminated union shapes (#78)
1 parent 520ba3a commit 473adf4

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

src/contextual/_models.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
from ._constants import RAW_RESPONSE_HEADER
6666

6767
if TYPE_CHECKING:
68-
from pydantic_core.core_schema import ModelField, LiteralSchema, ModelFieldsSchema
68+
from pydantic_core.core_schema import ModelField, ModelSchema, LiteralSchema, ModelFieldsSchema
6969

7070
__all__ = ["BaseModel", "GenericModel"]
7171

@@ -646,15 +646,18 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
646646

647647
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
648648
schema = model.__pydantic_core_schema__
649+
if schema["type"] == "definitions":
650+
schema = schema["schema"]
651+
649652
if schema["type"] != "model":
650653
return None
651654

655+
schema = cast("ModelSchema", schema)
652656
fields_schema = schema["schema"]
653657
if fields_schema["type"] != "model-fields":
654658
return None
655659

656660
fields_schema = cast("ModelFieldsSchema", fields_schema)
657-
658661
field = fields_schema["fields"].get(field_name)
659662
if not field:
660663
return None

tests/test_models.py

+32
Original file line numberDiff line numberDiff line change
@@ -854,3 +854,35 @@ class Model(BaseModel):
854854
m = construct_type(value={"cls": "foo"}, type_=Model)
855855
assert isinstance(m, Model)
856856
assert isinstance(m.cls, str)
857+
858+
859+
def test_discriminated_union_case() -> None:
860+
class A(BaseModel):
861+
type: Literal["a"]
862+
863+
data: bool
864+
865+
class B(BaseModel):
866+
type: Literal["b"]
867+
868+
data: List[Union[A, object]]
869+
870+
class ModelA(BaseModel):
871+
type: Literal["modelA"]
872+
873+
data: int
874+
875+
class ModelB(BaseModel):
876+
type: Literal["modelB"]
877+
878+
required: str
879+
880+
data: Union[A, B]
881+
882+
# when constructing ModelA | ModelB, value data doesn't match ModelB exactly - missing `required`
883+
m = construct_type(
884+
value={"type": "modelB", "data": {"type": "a", "data": True}},
885+
type_=cast(Any, Annotated[Union[ModelA, ModelB], PropertyInfo(discriminator="type")]),
886+
)
887+
888+
assert isinstance(m, ModelB)

0 commit comments

Comments
 (0)