Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional feature for aggregate count query support on plural entity object types #488

Merged
merged 11 commits into from
May 27, 2024
Prev Previous commit
Next Next commit
Apply prettier formatting
igdianov committed May 26, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 363c1413426402acf5258ff07bdf4892adec5002
Original file line number Diff line number Diff line change
@@ -115,11 +115,17 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
getFields(aggregateField.getSelectionSet(), "count")
.forEach(countField -> {
getCountOfArgument(countField)
.ifPresentOrElse(argument ->
aggregate.put(getAliasOrName(countField), queryFactory.queryAggregateCount(argument, environment, restrictedKeys))
,
.ifPresentOrElse(
argument ->
aggregate.put(
getAliasOrName(countField),
queryFactory.queryAggregateCount(argument, environment, restrictedKeys)
),
() ->
aggregate.put(getAliasOrName(countField), queryFactory.queryTotalCount(environment, restrictedKeys))
aggregate.put(
getAliasOrName(countField),
queryFactory.queryTotalCount(environment, restrictedKeys)
)
);
});

@@ -132,17 +138,23 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {

var countOfArgumentValue = getCountOfArgument(groupField);

Map.Entry<String, String>[] groupings =
getFields(groupField.getSelectionSet(), "by")
.stream()
.map(GraphQLJpaQueryDataFetcher::groupByFieldEntry)
.toArray(Map.Entry[]::new);
Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), "by")
.stream()
.map(GraphQLJpaQueryDataFetcher::groupByFieldEntry)
.toArray(Map.Entry[]::new);

if (groupings.length == 0) {
throw new GraphQLException("At least one field is required for aggregate group: " + groupField);
}

var resultList = queryFactory.queryAggregateGroupByCount(getAliasOrName(countField), countOfArgumentValue, environment, restrictedKeys, groupings)
var resultList = queryFactory
.queryAggregateGroupByCount(
getAliasOrName(countField),
countOfArgumentValue,
environment,
restrictedKeys,
groupings
)
.stream()
.peek(map ->
Stream
@@ -184,8 +196,7 @@ static Map.Entry<String, String> groupByFieldEntry(Field selectedField) {
static Map.Entry<String, String> countFieldEntry(Field selectedField) {
String key = Optional.ofNullable(selectedField.getAlias()).orElse(selectedField.getName());

String value = getCountOfArgument(selectedField)
.orElse(selectedField.getName());
String value = getCountOfArgument(selectedField).orElse(selectedField.getName());

return Map.entry(key, value);
}
@@ -197,7 +208,6 @@ static Optional<String> getCountOfArgument(Field selectedField) {
.map(EnumValue::getName);
}


public int getDefaultMaxResults() {
return defaultMaxResults;
}
Original file line number Diff line number Diff line change
@@ -356,7 +356,11 @@ public Long queryTotalCount(DataFetchingEnvironment environment, Optional<List<O
return 0L;
}

public Long queryAggregateCount(String aggregate, DataFetchingEnvironment environment, Optional<List<Object>> restrictedKeys) {
public Long queryAggregateCount(
String aggregate,
DataFetchingEnvironment environment,
Optional<List<Object>> restrictedKeys
) {
final MergedField queryField = flattenEmbeddedIdArguments(environment.getField());

final DataFetchingEnvironment queryEnvironment = getQueryEnvironment(environment, queryField);
@@ -379,7 +383,13 @@ public Long queryAggregateCount(String aggregate, DataFetchingEnvironment enviro
return 0L;
}

public List<Map> queryAggregateGroupByCount(String alias, Optional<String> countOf, DataFetchingEnvironment environment, Optional<List<Object>> restrictedKeys, Map.Entry<String,String>... groupings) {
public List<Map> queryAggregateGroupByCount(
String alias,
Optional<String> countOf,
DataFetchingEnvironment environment,
Optional<List<Object>> restrictedKeys,
Map.Entry<String, String>... groupings
) {
final MergedField queryField = flattenEmbeddedIdArguments(environment.getField());

final DataFetchingEnvironment queryEnvironment = getQueryEnvironment(environment, queryField);
@@ -451,11 +461,17 @@ protected TypedQuery<Long> getCountQuery(DataFetchingEnvironment environment, Fi
return entityManager.createQuery(query);
}

protected TypedQuery<Long> getAggregateCountQuery(DataFetchingEnvironment environment, Field field, String aggregate, List<Object> keys, String... groupings) {
protected TypedQuery<Long> getAggregateCountQuery(
DataFetchingEnvironment environment,
Field field,
String aggregate,
List<Object> keys,
String... groupings
) {
CriteriaBuilder cb = entityManager.getCriteriaBuilder();
CriteriaQuery<Long> query = cb.createQuery(Long.class);
Root<?> root = query.from(entityType);
Join<?,?> join = root.join(aggregate);
Join<?, ?> join = root.join(aggregate);

DataFetchingEnvironment queryEnvironment = DataFetchingEnvironmentBuilder
.newDataFetchingEnvironment(environment)
@@ -482,7 +498,14 @@ protected TypedQuery<Long> getAggregateCountQuery(DataFetchingEnvironment enviro
return entityManager.createQuery(query);
}

protected TypedQuery<Map> getAggregateGroupByCountQuery(DataFetchingEnvironment environment, Field field, String alias, Optional<String> countOfJoin, List<Object> keys, Map.Entry<String,String>... groupBy) {
protected TypedQuery<Map> getAggregateGroupByCountQuery(
DataFetchingEnvironment environment,
Field field,
String alias,
Optional<String> countOfJoin,
List<Object> keys,
Map.Entry<String, String>... groupBy
) {
final CriteriaBuilder cb = entityManager.getCriteriaBuilder();
final CriteriaQuery<Map> query = cb.createQuery(Map.class);
final Root<?> root = query.from(entityType);
@@ -494,20 +517,17 @@ protected TypedQuery<Map> getAggregateGroupByCountQuery(DataFetchingEnvironment

final List<Selection<?>> selections = new ArrayList<>();

Stream.of(groupBy)
.map(group -> root.get(group.getValue()).alias(group.getKey()))
.forEach(selections::add);
Stream.of(groupBy).map(group -> root.get(group.getValue()).alias(group.getKey())).forEach(selections::add);

final Expression<?>[] groupings = Stream
.of(groupBy)
.map(group -> root.get(group.getValue()))
.map(group -> root.get(group.getValue()))
.toArray(Expression[]::new);

countOfJoin
.ifPresentOrElse(
it ->selections.add(cb.count(root.join(it)).alias(alias)),
() -> selections.add(cb.count(root).alias(alias))
);
countOfJoin.ifPresentOrElse(
it -> selections.add(cb.count(root.join(it)).alias(alias)),
() -> selections.add(cb.count(root).alias(alias))
);

query.multiselect(selections).groupBy(groupings);

Original file line number Diff line number Diff line change
@@ -449,58 +449,63 @@ private GraphQLFieldDefinition getAggregateFieldDefinition(EntityType<?> entityT
.map(name -> newEnumValueDefinition().name(name).build())
.toList();

if (entityType.getAttributes()
.stream()
.anyMatch(Attribute::isAssociation)) {
countFieldDefinition
.argument(newArgument()
if (entityType.getAttributes().stream().anyMatch(Attribute::isAssociation)) {
countFieldDefinition.argument(
newArgument()
.name("of")
.type(newEnum()
.name(aggregateObjectTypeName.concat("CountOfAssociationsEnum"))
.values(associationEnumValueDefinitions)
.build()));
.type(
newEnum()
.name(aggregateObjectTypeName.concat("CountOfAssociationsEnum"))
.values(associationEnumValueDefinitions)
.build()
)
);
}


var groupFieldDefinition = newFieldDefinition()
.name("group")
.dataFetcher(aggregateDataFetcher)
.type(new GraphQLList(newObject()
.name(aggregateObjectTypeName.concat("GroupBy"))
.field(newFieldDefinition()
.name("by")
.dataFetcher(aggregateDataFetcher)
.argument(newArgument()
.name("field")
.type(newEnum()
.name(aggregateObjectTypeName.concat("GroupByFieldsEnum"))
.values(fieldsEnumValueDefinitions)
.build()))
.type(GraphQLString))
.field(newFieldDefinition()
.name("count")
.type(GraphQLInt))
.build()));

if (entityType.getAttributes()
.stream()
.anyMatch(Attribute::isAssociation)) {
groupFieldDefinition
.argument(newArgument()
.type(
new GraphQLList(
newObject()
.name(aggregateObjectTypeName.concat("GroupBy"))
.field(
newFieldDefinition()
.name("by")
.dataFetcher(aggregateDataFetcher)
.argument(
newArgument()
.name("field")
.type(
newEnum()
.name(aggregateObjectTypeName.concat("GroupByFieldsEnum"))
.values(fieldsEnumValueDefinitions)
.build()
)
)
.type(GraphQLString)
)
.field(newFieldDefinition().name("count").type(GraphQLInt))
.build()
)
);

if (entityType.getAttributes().stream().anyMatch(Attribute::isAssociation)) {
groupFieldDefinition.argument(
newArgument()
.name("of")
.type(newEnum()
.name(aggregateObjectTypeName.concat("GroupOfAssociationsEnum"))
.values(associationEnumValueDefinitions)
.build()));
.type(
newEnum()
.name(aggregateObjectTypeName.concat("GroupOfAssociationsEnum"))
.values(associationEnumValueDefinitions)
.build()
)
);
}

aggregateObjectType
.field(countFieldDefinition)
.field(groupFieldDefinition);
aggregateObjectType.field(countFieldDefinition).field(groupFieldDefinition);

var aggregateFieldDefinition = newFieldDefinition()
.name("aggregate")
.type(aggregateObjectType);
var aggregateFieldDefinition = newFieldDefinition().name("aggregate").type(aggregateObjectType);

return aggregateFieldDefinition.build();
}
@@ -1211,8 +1216,7 @@ private GraphQLFieldDefinition getObjectField(Attribute attribute, EntityType ba
DataFetcher dataFetcher = PropertyDataFetcher.fetching(attribute.getName());

// Only add the orderBy argument for basic attribute types
if (isBasic(attribute) && isNotIgnoredOrder(attribute)
) {
if (isBasic(attribute) && isNotIgnoredOrder(attribute)) {
arguments.add(
GraphQLArgument
.newArgument()
@@ -1235,9 +1239,7 @@ private GraphQLFieldDefinition getObjectField(Attribute attribute, EntityType ba
// to-one end could be optional
arguments.add(optionalArgument(singularAttribute.isOptional()));

GraphQLObjectType entityObjectType = newObject()
.name(resolveEntityObjectTypeName(baseEntity))
.build();
GraphQLObjectType entityObjectType = newObject().name(resolveEntityObjectTypeName(baseEntity)).build();

GraphQLJpaQueryFactory graphQLJpaQueryFactory = GraphQLJpaQueryFactory
.builder()
@@ -1273,9 +1275,7 @@ else if (isPlural(attribute)) {
// make it configurable via builder api
arguments.add(optionalArgument(toManyDefaultOptional));

GraphQLObjectType entityObjectType = newObject()
.name(resolveEntityObjectTypeName(baseEntity))
.build();
GraphQLObjectType entityObjectType = newObject().name(resolveEntityObjectTypeName(baseEntity)).build();

GraphQLJpaQueryFactory graphQLJpaQueryFactory = GraphQLJpaQueryFactory
.builder()
@@ -1420,8 +1420,10 @@ protected final boolean isEmbeddable(Attribute<?, ?> attribute) {
}

protected final boolean isBasic(Attribute<?, ?> attribute) {
return attribute instanceof SingularAttribute &&
attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.BASIC;
return (
attribute instanceof SingularAttribute &&
attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.BASIC
);
}

protected final boolean isElementCollection(Attribute<?, ?> attribute) {
@@ -1446,17 +1448,21 @@ protected final boolean isToOne(Attribute<?, ?> attribute) {
);
}

private boolean isPlural(Attribute<?,?> attribute) {
return attribute instanceof PluralAttribute &&
private boolean isPlural(Attribute<?, ?> attribute) {
return (
attribute instanceof PluralAttribute &&
(
attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.ONE_TO_MANY ||
attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.MANY_TO_MANY
);
attribute.getPersistentAttributeType() == Attribute.PersistentAttributeType.MANY_TO_MANY
)
);
}

private boolean isSingular(Attribute<?,?> attribute) {
return attribute instanceof SingularAttribute &&
attribute.getPersistentAttributeType() != Attribute.PersistentAttributeType.BASIC;
private boolean isSingular(Attribute<?, ?> attribute) {
return (
attribute instanceof SingularAttribute &&
attribute.getPersistentAttributeType() != Attribute.PersistentAttributeType.BASIC
);
}

protected final boolean isValidInput(Attribute<?, ?> attribute) {
Original file line number Diff line number Diff line change
@@ -139,10 +139,7 @@ public static final Optional<Field> getSelectionField(Field field, String fieldN
public static Optional<Argument> findArgument(Field selectedField, String name) {
return Optional
.ofNullable(selectedField.getArguments())
.flatMap(arguments -> arguments
.stream()
.filter(argument -> name.equals(argument.getName()))
.findFirst());
.flatMap(arguments -> arguments.stream().filter(argument -> name.equals(argument.getName())).findFirst());
}

public static List<Field> getFields(SelectionSet selections, String fieldName) {
Loading