Skip to content

Commit 41da78a

Browse files
committed
Deep copy schema with directive with arg of custom type (#210)
1 parent 41058d7 commit 41da78a

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

src/graphql/type/schema.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
GraphQLAbstractType,
2222
GraphQLCompositeType,
2323
GraphQLField,
24+
GraphQLInputType,
2425
GraphQLInterfaceType,
2526
GraphQLNamedType,
2627
GraphQLObjectType,
@@ -293,6 +294,8 @@ def __deepcopy__(self, memo_: dict) -> GraphQLSchema:
293294
directive if is_specified_directive(directive) else copy(directive)
294295
for directive in self.directives
295296
]
297+
for directive in directives:
298+
remap_directive(directive, type_map)
296299
return self.__class__(
297300
self.query_type and cast(GraphQLObjectType, type_map[self.query_type.name]),
298301
self.mutation_type
@@ -458,11 +461,7 @@ def remapped_type(type_: GraphQLType, type_map: TypeMap) -> GraphQLType:
458461

459462
def remap_named_type(type_: GraphQLNamedType, type_map: TypeMap) -> None:
460463
"""Change all references in the given named type to use this type map."""
461-
if is_union_type(type_):
462-
type_.types = [
463-
type_map.get(member_type.name, member_type) for member_type in type_.types
464-
]
465-
elif is_object_type(type_) or is_interface_type(type_):
464+
if is_object_type(type_) or is_interface_type(type_):
466465
type_.interfaces = [
467466
type_map.get(interface_type.name, interface_type)
468467
for interface_type in type_.interfaces
@@ -477,9 +476,22 @@ def remap_named_type(type_: GraphQLNamedType, type_map: TypeMap) -> None:
477476
arg.type = remapped_type(arg.type, type_map)
478477
args[arg_name] = arg
479478
fields[field_name] = field
479+
elif is_union_type(type_):
480+
type_.types = [
481+
type_map.get(member_type.name, member_type) for member_type in type_.types
482+
]
480483
elif is_input_object_type(type_):
481484
fields = type_.fields
482485
for field_name, field in fields.items():
483486
field = copy(field) # noqa: PLW2901
484487
field.type = remapped_type(field.type, type_map)
485488
fields[field_name] = field
489+
490+
491+
def remap_directive(directive: GraphQLDirective, type_map: TypeMap) -> None:
492+
"""Change all references in the given directive to use this type map."""
493+
args = directive.args
494+
for arg_name, arg in args.items():
495+
arg = copy(arg) # noqa: PLW2901
496+
arg.type = cast(GraphQLInputType, remapped_type(arg.type, type_map))
497+
args[arg_name] = arg

tests/utilities/test_build_ast_schema.py

+19
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,25 @@ def can_deep_copy_schema():
12221222
# check that printing the copied schema gives the same SDL
12231223
assert print_schema(copied) == sdl
12241224

1225+
def can_deep_copy_schema_with_directive_using_args_of_custom_type():
1226+
sdl = dedent("""
1227+
directive @someDirective(someArg: SomeEnum) on FIELD_DEFINITION
1228+
1229+
enum SomeEnum {
1230+
ONE
1231+
TWO
1232+
}
1233+
1234+
type Query {
1235+
someField: String @someDirective(someArg: ONE)
1236+
}
1237+
""")
1238+
schema = build_schema(sdl)
1239+
copied = deepcopy(schema)
1240+
# custom directives on field definitions cannot be reproduced
1241+
expected_sdl = sdl.replace(" @someDirective(someArg: ONE)", "")
1242+
assert print_schema(copied) == expected_sdl
1243+
12251244
def can_pickle_and_unpickle_star_wars_schema():
12261245
# create a schema from the star wars SDL
12271246
schema = build_schema(sdl, assume_valid_sdl=True)

0 commit comments

Comments
 (0)