Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional feature for aggregate count query support on plural entity object types #488

Merged
merged 11 commits into from
May 27, 2024
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@
<version>3.6.3</version>
<configuration>
<source>${java.version}</source>
<doclint>none</doclint>
</configuration>
<executions>
<execution>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ private static NoSuchElementException noSuchElementException(Class<?> containerC
/**
* Returns a String which capitalizes the first letter of the string.
*/
private static String capitalize(String name) {
public static String capitalize(String name) {
if (name == null || name.length() == 0) {
return name;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,28 @@
import static com.introproventures.graphql.jpa.query.schema.impl.GraphQLJpaSchemaBuilder.PAGE_TOTAL_PARAM_NAME;
import static com.introproventures.graphql.jpa.query.schema.impl.GraphQLJpaSchemaBuilder.QUERY_SELECT_PARAM_NAME;
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.extractPageArgument;
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.findArgument;
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.getAliasOrName;
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.getFields;
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.getPageArgument;
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.getSelectionField;
import static com.introproventures.graphql.jpa.query.support.GraphQLSupport.searchByFieldName;

import com.introproventures.graphql.jpa.query.schema.JavaScalars;
import graphql.GraphQLException;
import graphql.language.Argument;
import graphql.language.EnumValue;
import graphql.language.Field;
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.GraphQLScalarType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -65,6 +76,7 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
Optional<Field> pagesSelection = getSelectionField(rootNode, PAGE_PAGES_PARAM_NAME);
Optional<Field> totalSelection = getSelectionField(rootNode, PAGE_TOTAL_PARAM_NAME);
Optional<Field> recordsSelection = searchByFieldName(rootNode, QUERY_SELECT_PARAM_NAME);
Optional<Field> aggregateSelection = getSelectionField(rootNode, "aggregate");

final int firstResult = page.getOffset();
final int maxResults = Integer.min(page.getLimit(), defaultMaxResults); // Limit max results to avoid OoM
Expand Down Expand Up @@ -98,9 +110,155 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
pagedResult.withTotal(total);
}

aggregateSelection.ifPresent(aggregateField -> {
Map<String, Object> aggregate = new LinkedHashMap<>();

getFields(aggregateField.getSelectionSet(), "count")
.forEach(countField -> {
getCountOfArgument(countField)
.ifPresentOrElse(
argument ->
aggregate.put(
getAliasOrName(countField),
queryFactory.queryAggregateCount(argument, environment, restrictedKeys)
),
() ->
aggregate.put(
getAliasOrName(countField),
queryFactory.queryTotalCount(environment, restrictedKeys)
)
);
});

getFields(aggregateField.getSelectionSet(), "group")
.forEach(groupField -> {
var countField = getFields(groupField.getSelectionSet(), "count")
.stream()
.findFirst()
.orElseThrow(() -> new GraphQLException("Missing aggregate count for group: " + groupField));

var countOfArgumentValue = getCountOfArgument(countField);

Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), "by")
.stream()
.map(GraphQLJpaQueryDataFetcher::groupByFieldEntry)
.toArray(Map.Entry[]::new);

if (groupings.length == 0) {
throw new GraphQLException("At least one field is required for aggregate group: " + groupField);
}

var resultList = queryFactory
.queryAggregateGroupByCount(
getAliasOrName(countField),
countOfArgumentValue,
environment,
restrictedKeys,
groupings
)
.stream()
.peek(map ->
Stream
.of(groupings)
.forEach(group -> {
var value = map.get(group.getKey());

Optional
.ofNullable(value)
.map(Object::getClass)
.map(JavaScalars::of)
.map(GraphQLScalarType::getCoercing)
.ifPresent(coercing -> map.put(group.getKey(), coercing.serialize(value)));
})
)
.toList();

aggregate.put(getAliasOrName(groupField), resultList);
});

aggregateField
.getSelectionSet()
.getSelections()
.stream()
.filter(Field.class::isInstance)
.map(Field.class::cast)
.filter(it -> !Arrays.asList("count", "group").contains(it.getName()))
.forEach(groupField -> {
var countField = getFields(groupField.getSelectionSet(), "count")
.stream()
.findFirst()
.orElseThrow(() -> new GraphQLException("Missing aggregate count for group: " + groupField));

Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), "by")
.stream()
.map(GraphQLJpaQueryDataFetcher::groupByFieldEntry)
.toArray(Map.Entry[]::new);

if (groupings.length == 0) {
throw new GraphQLException("At least one field is required for aggregate group: " + groupField);
}

var resultList = queryFactory
.queryAggregateGroupByAssociationCount(
getAliasOrName(countField),
groupField.getName(),
environment,
restrictedKeys,
groupings
)
.stream()
.peek(map ->
Stream
.of(groupings)
.forEach(group -> {
var value = map.get(group.getKey());

Optional
.ofNullable(value)
.map(Object::getClass)
.map(JavaScalars::of)
.map(GraphQLScalarType::getCoercing)
.ifPresent(coercing -> map.put(group.getKey(), coercing.serialize(value)));
})
)
.toList();

aggregate.put(getAliasOrName(groupField), resultList);
});

pagedResult.withAggregate(aggregate);
});

return pagedResult.build();
}

static Map.Entry<String, String> groupByFieldEntry(Field selectedField) {
String key = Optional.ofNullable(selectedField.getAlias()).orElse(selectedField.getName());

String value = findArgument(selectedField, "field")
.map(Argument::getValue)
.map(EnumValue.class::cast)
.map(EnumValue::getName)
.orElseThrow(() -> new GraphQLException("group by argument is required."));

return Map.entry(key, value);
}

static Map.Entry<String, String> countFieldEntry(Field selectedField) {
String key = Optional.ofNullable(selectedField.getAlias()).orElse(selectedField.getName());

String value = getCountOfArgument(selectedField).orElse(selectedField.getName());

return Map.entry(key, value);
}

static Optional<String> getCountOfArgument(Field selectedField) {
return findArgument(selectedField, "of")
.map(Argument::getValue)
.map(EnumValue.class::cast)
.map(EnumValue::getName);
}

public int getDefaultMaxResults() {
return defaultMaxResults;
}
Expand Down
Loading
Loading