Skip to content

Commit

Permalink
fix(search): fix default entities for aggregation filters
Browse files Browse the repository at this point in the history
* Co-located aggregation functions in AggregationQueryBuilder
* Default list of entity specs for aggregation should include query by default field (at least 1)
* Improved error message for field Display Name collisions
  • Loading branch information
david-leifker committed Feb 1, 2024
1 parent d52818d commit 8d56168
Show file tree
Hide file tree
Showing 14 changed files with 510 additions and 421 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
import com.linkedin.metadata.search.FilterValueArray;
import com.linkedin.metadata.search.ScrollResult;
import com.linkedin.metadata.search.SearchResult;
import com.linkedin.metadata.search.elasticsearch.query.request.AggregationQueryBuilder;
import com.linkedin.metadata.search.elasticsearch.query.request.AutocompleteRequestHandler;
import com.linkedin.metadata.search.elasticsearch.query.request.SearchAfterWrapper;
import com.linkedin.metadata.search.elasticsearch.query.request.SearchRequestHandler;
import com.linkedin.metadata.search.utils.QueryUtils;
import com.linkedin.metadata.utils.elasticsearch.IndexConvention;
import com.linkedin.metadata.utils.metrics.MetricUtils;
import io.opentelemetry.extension.annotations.WithSpan;
Expand Down Expand Up @@ -317,7 +319,7 @@ public Map<String, Long> aggregateByValue(
int limit) {
List<EntitySpec> entitySpecs;
if (entityNames == null || entityNames.isEmpty()) {
entitySpecs = new ArrayList<>(entityRegistry.getEntitySpecs().values());
entitySpecs = QueryUtils.getQueryByDefaultEntitySpecs(entityRegistry);
} else {
entitySpecs =
entityNames.stream().map(entityRegistry::getEntitySpec).collect(Collectors.toList());
Expand All @@ -341,7 +343,7 @@ public Map<String, Long> aggregateByValue(
MetricUtils.timer(this.getClass(), "aggregateByValue_search").time()) {
final SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);
// extract results, validated against document model as well
return SearchRequestHandler.extractAggregationsFromResponse(searchResponse, field);
return AggregationQueryBuilder.extractAggregationsFromResponse(searchResponse, field);
} catch (Exception e) {
log.error("Aggregation query failed", e);
throw new ESQueryException("Aggregation query failed:", e);
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.linkedin.metadata.recommendation.candidatesource;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.testng.Assert.assertEquals;
Expand All @@ -11,6 +12,7 @@
import com.linkedin.common.urn.CorpuserUrn;
import com.linkedin.common.urn.TestEntityUrn;
import com.linkedin.common.urn.Urn;
import com.linkedin.metadata.models.registry.EntityRegistry;
import com.linkedin.metadata.query.filter.Criterion;
import com.linkedin.metadata.recommendation.RecommendationContent;
import com.linkedin.metadata.recommendation.RecommendationParams;
Expand All @@ -29,6 +31,7 @@

public class EntitySearchAggregationCandidateSourceTest {
private EntitySearchService _entitySearchService = Mockito.mock(EntitySearchService.class);
private EntityRegistry entityRegistry = Mockito.mock(EntityRegistry.class);
private EntitySearchAggregationSource _valueBasedCandidateSource;
private EntitySearchAggregationSource _urnBasedCandidateSource;

Expand All @@ -45,7 +48,7 @@ public void setup() {

private EntitySearchAggregationSource buildCandidateSource(
String identifier, boolean isValueUrn) {
return new EntitySearchAggregationSource(_entitySearchService) {
return new EntitySearchAggregationSource(_entitySearchService, entityRegistry) {
@Override
protected String getSearchFieldName() {
return identifier;
Expand Down Expand Up @@ -98,8 +101,7 @@ public void testWhenSearchServiceReturnsEmpty() {
@Test
public void testWhenSearchServiceReturnsValueResults() {
// One result
Mockito.when(
_entitySearchService.aggregateByValue(eq(null), eq("testValue"), eq(null), anyInt()))
Mockito.when(_entitySearchService.aggregateByValue(any(), eq("testValue"), eq(null), anyInt()))
.thenReturn(ImmutableMap.of("value1", 1L));
List<RecommendationContent> candidates =
_valueBasedCandidateSource.getRecommendations(USER, CONTEXT);
Expand All @@ -120,8 +122,7 @@ public void testWhenSearchServiceReturnsValueResults() {
assertTrue(_valueBasedCandidateSource.getRecommendationModule(USER, CONTEXT).isPresent());

// Multiple result
Mockito.when(
_entitySearchService.aggregateByValue(eq(null), eq("testValue"), eq(null), anyInt()))
Mockito.when(_entitySearchService.aggregateByValue(any(), eq("testValue"), eq(null), anyInt()))
.thenReturn(ImmutableMap.of("value1", 1L, "value2", 2L, "value3", 3L));
candidates = _valueBasedCandidateSource.getRecommendations(USER, CONTEXT);
assertEquals(candidates.size(), 2);
Expand Down Expand Up @@ -160,7 +161,7 @@ public void testWhenSearchServiceReturnsUrnResults() {
Urn testUrn1 = new TestEntityUrn("testUrn1", "testUrn1", "testUrn1");
Urn testUrn2 = new TestEntityUrn("testUrn2", "testUrn2", "testUrn2");
Urn testUrn3 = new TestEntityUrn("testUrn3", "testUrn3", "testUrn3");
Mockito.when(_entitySearchService.aggregateByValue(eq(null), eq("testUrn"), eq(null), anyInt()))
Mockito.when(_entitySearchService.aggregateByValue(any(), eq("testUrn"), eq(null), anyInt()))
.thenReturn(ImmutableMap.of(testUrn1.toString(), 1L));
List<RecommendationContent> candidates =
_urnBasedCandidateSource.getRecommendations(USER, CONTEXT);
Expand All @@ -181,7 +182,7 @@ public void testWhenSearchServiceReturnsUrnResults() {
assertTrue(_urnBasedCandidateSource.getRecommendationModule(USER, CONTEXT).isPresent());

// Multiple result
Mockito.when(_entitySearchService.aggregateByValue(eq(null), eq("testUrn"), eq(null), anyInt()))
Mockito.when(_entitySearchService.aggregateByValue(any(), eq("testUrn"), eq(null), anyInt()))
.thenReturn(
ImmutableMap.of(
testUrn1.toString(), 1L, testUrn2.toString(), 2L, testUrn3.toString(), 3L));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package com.linkedin.metadata.search.query.request;

import static com.linkedin.metadata.utils.SearchUtil.*;
import static org.mockito.Mockito.mock;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.linkedin.metadata.config.search.SearchConfiguration;
import com.linkedin.metadata.models.EntitySpec;
import com.linkedin.metadata.models.annotation.SearchableAnnotation;
import com.linkedin.metadata.search.elasticsearch.query.request.AggregationQueryBuilder;
import java.util.Collections;
Expand Down Expand Up @@ -42,7 +45,8 @@ public void testGetDefaultAggregationsHasFields() {
config.setMaxTermBucketSize(25);

AggregationQueryBuilder builder =
new AggregationQueryBuilder(config, ImmutableList.of(annotation));
new AggregationQueryBuilder(
config, ImmutableMap.of(mock(EntitySpec.class), ImmutableList.of(annotation)));

List<AggregationBuilder> aggs = builder.getAggregations();

Expand Down Expand Up @@ -73,7 +77,8 @@ public void testGetDefaultAggregationsFields() {
config.setMaxTermBucketSize(25);

AggregationQueryBuilder builder =
new AggregationQueryBuilder(config, ImmutableList.of(annotation));
new AggregationQueryBuilder(
config, ImmutableMap.of(mock(EntitySpec.class), ImmutableList.of(annotation)));

List<AggregationBuilder> aggs = builder.getAggregations();

Expand Down Expand Up @@ -120,7 +125,9 @@ public void testGetSpecificAggregationsHasFields() {
config.setMaxTermBucketSize(25);

AggregationQueryBuilder builder =
new AggregationQueryBuilder(config, ImmutableList.of(annotation1, annotation2));
new AggregationQueryBuilder(
config,
ImmutableMap.of(mock(EntitySpec.class), ImmutableList.of(annotation1, annotation2)));

// Case 1: Ask for fields that should exist.
List<AggregationBuilder> aggs =
Expand All @@ -139,7 +146,9 @@ public void testAggregateOverStructuredProperty() {
SearchConfiguration config = new SearchConfiguration();
config.setMaxTermBucketSize(25);

AggregationQueryBuilder builder = new AggregationQueryBuilder(config, List.of());
AggregationQueryBuilder builder =
new AggregationQueryBuilder(
config, ImmutableMap.of(mock(EntitySpec.class), ImmutableList.of()));

List<AggregationBuilder> aggs =
builder.getAggregations(List.of("structuredProperties.ab.fgh.ten"));
Expand Down Expand Up @@ -202,7 +211,9 @@ public void testAggregateOverFieldsAndStructProp() {
config.setMaxTermBucketSize(25);

AggregationQueryBuilder builder =
new AggregationQueryBuilder(config, ImmutableList.of(annotation1, annotation2));
new AggregationQueryBuilder(
config,
ImmutableMap.of(mock(EntitySpec.class), ImmutableList.of(annotation1, annotation2)));

// Aggregate over fields and structured properties
List<AggregationBuilder> aggs =
Expand Down Expand Up @@ -252,7 +263,8 @@ public void testMissingAggregation() {
config.setMaxTermBucketSize(25);

AggregationQueryBuilder builder =
new AggregationQueryBuilder(config, ImmutableList.of(annotation));
new AggregationQueryBuilder(
config, ImmutableMap.of(mock(EntitySpec.class), ImmutableList.of(annotation)));

List<AggregationBuilder> aggs = builder.getAggregations();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.linkedin.gms.factory.recommendation.candidatesource;

import com.linkedin.gms.factory.search.EntitySearchServiceFactory;
import com.linkedin.metadata.models.registry.EntityRegistry;
import com.linkedin.metadata.recommendation.candidatesource.DomainsCandidateSource;
import com.linkedin.metadata.search.EntitySearchService;
import javax.annotation.Nonnull;
Expand All @@ -20,7 +21,7 @@ public class DomainsCandidateSourceFactory {

@Bean(name = "domainsCandidateSource")
@Nonnull
protected DomainsCandidateSource getInstance() {
return new DomainsCandidateSource(entitySearchService);
protected DomainsCandidateSource getInstance(final EntityRegistry entityRegistry) {
return new DomainsCandidateSource(entitySearchService, entityRegistry);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.linkedin.gms.factory.recommendation.candidatesource;

import com.linkedin.gms.factory.search.EntitySearchServiceFactory;
import com.linkedin.metadata.entity.EntityService;
import com.linkedin.metadata.recommendation.candidatesource.TopTagsSource;
import com.linkedin.metadata.search.EntitySearchService;
import javax.annotation.Nonnull;
Expand All @@ -20,7 +21,7 @@ public class TopTagsCandidateSourceFactory {

@Bean(name = "topTagsCandidateSource")
@Nonnull
protected TopTagsSource getInstance() {
return new TopTagsSource(entitySearchService);
protected TopTagsSource getInstance(final EntityService<?> entityService) {
return new TopTagsSource(entitySearchService, entityService);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.linkedin.gms.factory.recommendation.candidatesource;

import com.linkedin.gms.factory.search.EntitySearchServiceFactory;
import com.linkedin.metadata.entity.EntityService;
import com.linkedin.metadata.recommendation.candidatesource.TopTermsSource;
import com.linkedin.metadata.search.EntitySearchService;
import javax.annotation.Nonnull;
Expand All @@ -20,7 +21,7 @@ public class TopTermsCandidateSourceFactory {

@Bean(name = "topTermsCandidateSource")
@Nonnull
protected TopTermsSource getInstance() {
return new TopTermsSource(entitySearchService);
protected TopTermsSource getInstance(final EntityService<?> entityService) {
return new TopTermsSource(entitySearchService, entityService);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.linkedin.metadata.recommendation.candidatesource;

import com.linkedin.common.urn.Urn;
import com.linkedin.metadata.models.registry.EntityRegistry;
import com.linkedin.metadata.recommendation.RecommendationRenderType;
import com.linkedin.metadata.recommendation.RecommendationRequestContext;
import com.linkedin.metadata.recommendation.ScenarioType;
Expand All @@ -13,8 +14,9 @@ public class DomainsCandidateSource extends EntitySearchAggregationSource {

private static final String DOMAINS = "domains";

public DomainsCandidateSource(EntitySearchService entitySearchService) {
super(entitySearchService);
public DomainsCandidateSource(
EntitySearchService entitySearchService, EntityRegistry entityRegistry) {
super(entitySearchService, entityRegistry);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import com.google.common.collect.ImmutableList;
import com.linkedin.common.urn.Urn;
import com.linkedin.metadata.models.EntitySpec;
import com.linkedin.metadata.models.registry.EntityRegistry;
import com.linkedin.metadata.query.filter.Criterion;
import com.linkedin.metadata.query.filter.CriterionArray;
import com.linkedin.metadata.recommendation.ContentParams;
Expand All @@ -10,6 +12,7 @@
import com.linkedin.metadata.recommendation.RecommendationRequestContext;
import com.linkedin.metadata.recommendation.SearchParams;
import com.linkedin.metadata.search.EntitySearchService;
import com.linkedin.metadata.search.utils.QueryUtils;
import io.opentelemetry.extension.annotations.WithSpan;
import java.net.URISyntaxException;
import java.util.Collections;
Expand All @@ -35,7 +38,8 @@
@Slf4j
@RequiredArgsConstructor
public abstract class EntitySearchAggregationSource implements RecommendationSource {
private final EntitySearchService _entitySearchService;
private final EntitySearchService entitySearchService;
private final EntityRegistry entityRegistry;

/** Field to aggregate on */
protected abstract String getSearchFieldName();
Expand Down Expand Up @@ -69,8 +73,8 @@ protected <T> boolean isValidCandidate(T candidate) {
public List<RecommendationContent> getRecommendations(
@Nonnull Urn userUrn, @Nullable RecommendationRequestContext requestContext) {
Map<String, Long> aggregationResult =
_entitySearchService.aggregateByValue(
getEntityNames(), getSearchFieldName(), null, getMaxContent());
entitySearchService.aggregateByValue(
getEntityNames(entityRegistry), getSearchFieldName(), null, getMaxContent());

if (aggregationResult.isEmpty()) {
return Collections.emptyList();
Expand Down Expand Up @@ -110,9 +114,11 @@ public List<RecommendationContent> getRecommendations(
.collect(Collectors.toList());
}

protected List<String> getEntityNames() {
protected List<String> getEntityNames(EntityRegistry entityRegistry) {
// By default, no list is applied which means searching across entities.
return null;
return QueryUtils.getQueryByDefaultEntitySpecs(entityRegistry).stream()
.map(EntitySpec::getName)
.collect(Collectors.toList());
}

// Get top K entries with the most count
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ public class TopPlatformsSource extends EntitySearchAggregationSource {
Constants.CONTAINER_ENTITY_NAME,
Constants.NOTEBOOK_ENTITY_NAME);

private final EntityService _entityService;
private final EntityService<?> _entityService;
private static final String PLATFORM = "platform";

public TopPlatformsSource(EntityService entityService, EntitySearchService entitySearchService) {
super(entitySearchService);
public TopPlatformsSource(
EntityService<?> entityService, EntitySearchService entitySearchService) {
super(entitySearchService, entityService.getEntityRegistry());
_entityService = entityService;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.linkedin.metadata.recommendation.candidatesource;

import com.linkedin.common.urn.Urn;
import com.linkedin.metadata.entity.EntityService;
import com.linkedin.metadata.recommendation.RecommendationRenderType;
import com.linkedin.metadata.recommendation.RecommendationRequestContext;
import com.linkedin.metadata.recommendation.ScenarioType;
Expand All @@ -13,8 +14,8 @@ public class TopTagsSource extends EntitySearchAggregationSource {

private static final String TAGS = "tags";

public TopTagsSource(EntitySearchService entitySearchService) {
super(entitySearchService);
public TopTagsSource(EntitySearchService entitySearchService, EntityService<?> entityService) {
super(entitySearchService, entityService.getEntityRegistry());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.linkedin.metadata.recommendation.candidatesource;

import com.linkedin.common.urn.Urn;
import com.linkedin.metadata.entity.EntityService;
import com.linkedin.metadata.recommendation.RecommendationRenderType;
import com.linkedin.metadata.recommendation.RecommendationRequestContext;
import com.linkedin.metadata.recommendation.ScenarioType;
Expand All @@ -13,8 +14,8 @@ public class TopTermsSource extends EntitySearchAggregationSource {

private static final String TERMS = "glossaryTerms";

public TopTermsSource(EntitySearchService entitySearchService) {
super(entitySearchService);
public TopTermsSource(EntitySearchService entitySearchService, EntityService<?> entityService) {
super(entitySearchService, entityService.getEntityRegistry());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
import com.linkedin.data.template.RecordTemplate;
import com.linkedin.data.template.StringArray;
import com.linkedin.metadata.aspect.AspectVersion;
import com.linkedin.metadata.models.EntitySpec;
import com.linkedin.metadata.models.SearchableFieldSpec;
import com.linkedin.metadata.models.annotation.SearchableAnnotation;
import com.linkedin.metadata.models.registry.EntityRegistry;
import com.linkedin.metadata.query.filter.Condition;
import com.linkedin.metadata.query.filter.ConjunctiveCriterion;
import com.linkedin.metadata.query.filter.ConjunctiveCriterionArray;
Expand All @@ -15,6 +19,7 @@
import com.linkedin.metadata.query.filter.Filter;
import com.linkedin.metadata.query.filter.RelationshipDirection;
import com.linkedin.metadata.query.filter.RelationshipFilter;
import com.linkedin.util.Pair;
import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -174,4 +179,20 @@ public static Filter getFilterFromCriteria(List<Criterion> criteria) {
new ConjunctiveCriterionArray(
new ConjunctiveCriterion().setAnd(new CriterionArray(criteria))));
}

public static List<EntitySpec> getQueryByDefaultEntitySpecs(EntityRegistry entityRegistry) {
return entityRegistry.getEntitySpecs().values().stream()
.map(
spec ->
Pair.of(
spec,
spec.getSearchableFieldSpecs().stream()
.map(SearchableFieldSpec::getSearchableAnnotation)
.collect(Collectors.toList())))
.filter(
specPair ->
specPair.getSecond().stream().anyMatch(SearchableAnnotation::isQueryByDefault))
.map(Pair::getFirst)
.collect(Collectors.toList());
}
}

0 comments on commit 8d56168

Please sign in to comment.