Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for federation data fetcher correct support namespaces #2172

Merged
merged 5 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
Expand All @@ -29,6 +30,7 @@
import graphql.schema.GraphQLNamedSchemaElement;
import graphql.schema.GraphQLNonNull;
import graphql.schema.GraphQLObjectType;
import graphql.schema.GraphQLSchemaElement;
import graphql.schema.GraphQLType;
import io.smallrye.graphql.spi.config.Config;

Expand All @@ -37,6 +39,7 @@ public class FederationDataFetcher implements DataFetcher<CompletableFuture<List
public static final String TYPENAME = "__typename";
private final GraphQLObjectType queryType;
private final GraphQLCodeRegistry codeRegistry;
private final HashMap<TypeAndArgumentNames, TypeFieldWrapper> cache = new HashMap<>();

public FederationDataFetcher(GraphQLObjectType queryType, GraphQLCodeRegistry codeRegistry) {
this.queryType = queryType;
Expand All @@ -55,17 +58,13 @@ public CompletableFuture<List<Object>> get(DataFetchingEnvironment environment)
var repsWithPositionPerType = representations.stream().collect(Collectors.groupingBy(r -> r.typeAndArgumentNames));
//then we search for the field definition to resolve the objects
var fieldDefinitions = repsWithPositionPerType.keySet().stream()
.collect(Collectors.toMap(Function.identity(), typeAndArgumentNames -> {
var batchDefinition = findBatchFieldDefinition(typeAndArgumentNames);
if (batchDefinition == null) {
return findFieldDefinition(typeAndArgumentNames);
} else {
return batchDefinition;
}
}));
.collect(Collectors.toMap(Function.identity(), typeAndArgumentNames -> cache.computeIfAbsent(
typeAndArgumentNames, type -> Objects.requireNonNullElseGet(
findBatchFieldDefinition(type),
() -> findFieldDefinition(type)))));
return sequence(repsWithPositionPerType.entrySet().stream().map(e -> {
var fieldDefinition = fieldDefinitions.get(e.getKey());
if (getGraphqlTypeFromField(fieldDefinition) instanceof GraphQLList) {
if (getGraphqlTypeFromField(fieldDefinition.getField()) instanceof GraphQLList) {
//use batch loader if available
return executeList(fieldDefinition, environment, e.getValue());
} else {
Expand All @@ -79,35 +78,67 @@ public CompletableFuture<List<Object>> get(DataFetchingEnvironment environment)
.sorted(Comparator.comparingInt(r -> r.position)).map(r -> r.Result).collect(Collectors.toList()));

}
Map<TypeAndArgumentNames, GraphQLFieldDefinition> cache = new HashMap<>();
return sequence(representations.stream()
.map(rep -> fetchEntities(environment, rep,
cache.computeIfAbsent(rep.typeAndArgumentNames, this::findFieldDefinition)))
.collect(Collectors.toList())).thenApply(l -> l.stream().map(r -> r.Result).collect(Collectors.toList()));
}

private GraphQLFieldDefinition findBatchFieldDefinition(TypeAndArgumentNames typeAndArgumentNames) {
private TypeFieldWrapper findRecursiveFieldDefinition(TypeAndArgumentNames typeAndArgumentNames,
GraphQLFieldDefinition field, BiFunction<GraphQLFieldDefinition, String, Boolean> matchesReturnType) {
if (field.getType() instanceof GraphQLObjectType) {
for (GraphQLSchemaElement child : field.getType().getChildren()) {
if (child instanceof GraphQLFieldDefinition) {
GraphQLFieldDefinition definition = (GraphQLFieldDefinition) child;
if (matchesReturnType.apply(definition, typeAndArgumentNames.type)
&& matchesArguments(typeAndArgumentNames, definition)) {
return new TypeFieldWrapper((GraphQLObjectType) field.getType(), definition);
} else if (definition.getType() instanceof GraphQLObjectType) {
return findRecursiveFieldDefinition(typeAndArgumentNames, definition, matchesReturnType);
}
}
}
}
return null;
}

private TypeFieldWrapper findBatchFieldDefinition(TypeAndArgumentNames typeAndArgumentNames) {
for (GraphQLFieldDefinition field : queryType.getFields()) {
if (matchesReturnTypeList(field, typeAndArgumentNames.type) && matchesArguments(typeAndArgumentNames, field)) {
return field;
return new TypeFieldWrapper(queryType, field);
}
}
for (GraphQLFieldDefinition field : queryType.getFields()) {
TypeFieldWrapper typeFieldWrapper = findRecursiveFieldDefinition(typeAndArgumentNames, field,
this::matchesReturnTypeList);
if (typeFieldWrapper != null) {
return typeFieldWrapper;
}
}
return null;
}

private GraphQLFieldDefinition findFieldDefinition(TypeAndArgumentNames typeAndArgumentNames) {
private TypeFieldWrapper findFieldDefinition(TypeAndArgumentNames typeAndArgumentNames) {
for (GraphQLFieldDefinition field : queryType.getFields()) {
if (matchesReturnType(field, typeAndArgumentNames.type) && matchesArguments(typeAndArgumentNames, field)) {
return field;
return new TypeFieldWrapper(queryType, field);
}
}
for (GraphQLFieldDefinition field : queryType.getFields()) {
TypeFieldWrapper typeFieldWrapper = findRecursiveFieldDefinition(typeAndArgumentNames, field,
this::matchesReturnType);
if (typeFieldWrapper != null) {
return typeFieldWrapper;
}
}

throw new RuntimeException(
"no query found for " + typeAndArgumentNames.type + " by " + typeAndArgumentNames.argumentNames);
}

private CompletableFuture<ResultObject> fetchEntities(DataFetchingEnvironment env, Representation representation,
GraphQLFieldDefinition field) {
return execute(field, env, representation);
TypeFieldWrapper wrapper) {
return execute(wrapper, env, representation);
}

private boolean matchesReturnType(GraphQLFieldDefinition field, String typename) {
Expand Down Expand Up @@ -140,9 +171,9 @@ private boolean matchesArguments(TypeAndArgumentNames typeAndArgumentNames, Grap
return argumentNames.equals(typeAndArgumentNames.argumentNames);
}

private CompletableFuture<List<ResultObject>> executeList(GraphQLFieldDefinition field, DataFetchingEnvironment env,
private CompletableFuture<List<ResultObject>> executeList(TypeFieldWrapper wrapper, DataFetchingEnvironment env,
List<Representation> representations) {
DataFetcher<?> dataFetcher = codeRegistry.getDataFetcher(queryType, field);
DataFetcher<?> dataFetcher = codeRegistry.getDataFetcher(wrapper.getType(), wrapper.getField());
Map<String, List<Object>> arguments = new HashMap<>();
representations.forEach(r -> {
r.arguments.forEach((argumentName, argumentValue) -> {
Expand Down Expand Up @@ -183,7 +214,7 @@ public <T> T getArgumentOrDefault(String name, T defaultValue) {
resultList = (List<Object>) results;
} else {
throw new IllegalStateException(
"Result of batchDataFetcher for Field " + field.getName() + " needs to be a list"
"Result of batchDataFetcher for Field " + wrapper.getField().getName() + " needs to be a list"
+ results.toString());
}

Expand All @@ -197,13 +228,13 @@ public <T> T getArgumentOrDefault(String name, T defaultValue) {
.collect(Collectors.toList());
});
} catch (Exception e) {
throw new RuntimeException("can't fetch data from " + field, e);
throw new RuntimeException("can't fetch data from " + wrapper.getField(), e);
}
}

private CompletableFuture<ResultObject> execute(GraphQLFieldDefinition field, DataFetchingEnvironment env,
private CompletableFuture<ResultObject> execute(TypeFieldWrapper wrapper, DataFetchingEnvironment env,
Representation representation) {
DataFetcher<?> dataFetcher = codeRegistry.getDataFetcher(queryType, field);
DataFetcher<?> dataFetcher = codeRegistry.getDataFetcher(wrapper.getType(), wrapper.getField());
DataFetchingEnvironment argsEnv = new DelegatingDataFetchingEnvironment(env) {
@Override
public Map<String, Object> getArguments() {
Expand All @@ -230,7 +261,7 @@ public <T> T getArgumentOrDefault(String name, T defaultValue) {
return Async.toCompletableFuture(dataFetcher.get(argsEnv))
.thenApply(o -> new ResultObject(o, representation.position));
} catch (Exception e) {
throw new RuntimeException("can't fetch data from " + field, e);
throw new RuntimeException("can't fetch data from " + wrapper.getField(), e);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package io.smallrye.graphql.bootstrap;

import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLObjectType;

class TypeFieldWrapper {
private final GraphQLObjectType type;
private final GraphQLFieldDefinition field;

public TypeFieldWrapper(GraphQLObjectType type, GraphQLFieldDefinition field) {
this.type = type;
this.field = field;
}

public GraphQLObjectType getType() {
return type;
}

public GraphQLFieldDefinition getField() {
return field;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package io.smallrye.graphql.execution;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import java.io.IOException;
import java.io.InputStream;
import java.util.function.BiFunction;
import java.util.stream.Stream;

import jakarta.json.JsonObject;
import jakarta.json.JsonString;
import jakarta.json.JsonValue;

import org.jboss.jandex.IndexView;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import graphql.schema.GraphQLSchema;
import io.smallrye.graphql.api.Directive;
import io.smallrye.graphql.api.federation.External;
import io.smallrye.graphql.api.federation.Key;
import io.smallrye.graphql.bootstrap.Bootstrap;
import io.smallrye.graphql.schema.SchemaBuilder;
import io.smallrye.graphql.schema.model.Schema;
import io.smallrye.graphql.spi.config.Config;
import io.smallrye.graphql.test.namespace.NamedNamespaceModel;
import io.smallrye.graphql.test.namespace.NamedNamespaceTestApi;
import io.smallrye.graphql.test.namespace.NamedNamespaceWIthGroupingKeyModel;
import io.smallrye.graphql.test.namespace.NamedNamespaceWithGroupingKeyTestApi;
import io.smallrye.graphql.test.namespace.SourceNamespaceModel;
import io.smallrye.graphql.test.namespace.SourceNamespaceTestApi;
import io.smallrye.graphql.test.namespace.UnamedModel;
import io.smallrye.graphql.test.namespace.UnnamedTestApi;

/**
* Test for Federated namespaces
*/
public class FederatedNamespaceTest {
private static final TestConfig config = (TestConfig) Config.get();
private static ExecutionService executionService;

@AfterAll
static void afterAll() {
config.reset();
config.federationEnabled = false;
System.setProperty("smallrye.graphql.federation.enabled", "false");
}

@BeforeAll
static void beforeAll() {
config.federationEnabled = true;
System.setProperty("smallrye.graphql.federation.enabled", "true");

IndexView index = buildIndex(Directive.class, Key.class, External.class, Key.Keys.class,
NamedNamespaceModel.class, NamedNamespaceTestApi.class,
NamedNamespaceWIthGroupingKeyModel.class, NamedNamespaceWithGroupingKeyTestApi.class,
SourceNamespaceModel.class, SourceNamespaceTestApi.class,
SourceNamespaceTestApi.First.class, SourceNamespaceTestApi.Second.class,
UnamedModel.class, UnnamedTestApi.class);

GraphQLSchema graphQLSchema = createGraphQLSchema(index);
Schema schema = SchemaBuilder.build(index);
executionService = new ExecutionService(graphQLSchema, schema);
}

private static IndexView buildIndex(Class<?>... classes) {
org.jboss.jandex.Indexer indexer = new org.jboss.jandex.Indexer();
Stream.of(classes).forEach(cls -> index(indexer, cls));
return indexer.complete();
}

private static InputStream getResourceStream(Class<?> type) {
String name = type.getName().replace(".", "/") + ".class";
return Thread.currentThread().getContextClassLoader().getResourceAsStream(name);
}

private static void index(org.jboss.jandex.Indexer indexer, Class<?> cls) {
try {
indexer.index(getResourceStream(cls));
} catch (IOException e) {
throw new RuntimeException(e);
}
}

private static GraphQLSchema createGraphQLSchema(IndexView index) {
Schema schema = SchemaBuilder.build(index);
assertNotNull(schema, "Schema should not be null");
GraphQLSchema graphQLSchema = Bootstrap.bootstrap(schema, true);
assertNotNull(graphQLSchema, "GraphQLSchema should not be null");
return graphQLSchema;
}

private static JsonObject executeAndGetResult(String graphQL) {
JsonObjectResponseWriter jsonObjectResponseWriter = new JsonObjectResponseWriter(graphQL);
jsonObjectResponseWriter.logInput();
executionService.executeSync(jsonObjectResponseWriter.getInput(), jsonObjectResponseWriter);
jsonObjectResponseWriter.logOutput();
return jsonObjectResponseWriter.getOutput();
}

private void test(String type, String id) {
JsonObject jsonObject = executeAndGetResult(TEST_QUERY.apply(type, id));
assertNotNull(jsonObject);

JsonValue jsonValue = jsonObject.getJsonObject("data")
.getJsonArray("_entities")
.getJsonObject(0)
.get("value");
String value = ((JsonString) jsonValue).getString();
assertEquals(value, id);
}

@Test
public void findEntityWithoutNamespace() {
test(UnamedModel.class.getSimpleName(), "unnamed_id");
}

@Test
public void findEntityWithNameNamespace() {
test(NamedNamespaceModel.class.getSimpleName(), "named_id");
}

@Test
public void findEntityWithSourceNamespace() {
test(SourceNamespaceModel.class.getSimpleName(), "source_id");
}

@Test
public void findEntityWithWithGroupedKeyAndNamespace() {
String id = "grouped_key";

JsonObject jsonObject = executeAndGetResult(GROUPED_KEY_QUERY.apply(
NamedNamespaceWIthGroupingKeyModel.class.getSimpleName(),
id));
assertNotNull(jsonObject);

JsonValue jsonValue = jsonObject.getJsonObject("data")
.getJsonArray("_entities")
.getJsonObject(0)
.get("value");
String value = ((JsonString) jsonValue).getString();
assertEquals(value, id);

jsonValue = jsonObject.getJsonObject("data")
.getJsonArray("_entities")
.getJsonObject(0)
.get("anotherId");
String anotherId = ((JsonString) jsonValue).getString();
assertEquals(anotherId, "otherKey_" + id);
}

private static final BiFunction<String, String, String> GROUPED_KEY_QUERY = (type, id) -> "query {\n" +
"_entities(\n" +
" representations: { id: \"" + id + "\", anotherId : \"otherKey_" + id + "\", __typename: \"" + type + "\" }\n" +
") {\n" +
" __typename\n" +
" ... on " + type + " {\n" +
" id\n" +
" anotherId\n" +
" value\n" +
" }\n" +
" }\n" +
"}";

private static final BiFunction<String, String, String> TEST_QUERY = (type, id) -> "query {\n" +
"_entities(\n" +
" representations: { id: \"" + id + "\", __typename: \"" + type + "\" }\n" +
") {\n" +
" __typename\n" +
" ... on " + type + " {\n" +
" id\n" +
" value\n" +
" }\n" +
" }\n" +
"}";
}
Loading
Loading