diff --git a/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/ESSearchDAO.java b/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/ESSearchDAO.java index 7de2770626ae34..76153a8d2adb3f 100644 --- a/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/ESSearchDAO.java +++ b/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/ESSearchDAO.java @@ -22,9 +22,11 @@ import com.linkedin.metadata.search.FilterValueArray; import com.linkedin.metadata.search.ScrollResult; import com.linkedin.metadata.search.SearchResult; +import com.linkedin.metadata.search.elasticsearch.query.request.AggregationQueryBuilder; import com.linkedin.metadata.search.elasticsearch.query.request.AutocompleteRequestHandler; import com.linkedin.metadata.search.elasticsearch.query.request.SearchAfterWrapper; import com.linkedin.metadata.search.elasticsearch.query.request.SearchRequestHandler; +import com.linkedin.metadata.search.utils.QueryUtils; import com.linkedin.metadata.utils.elasticsearch.IndexConvention; import com.linkedin.metadata.utils.metrics.MetricUtils; import io.opentelemetry.extension.annotations.WithSpan; @@ -317,7 +319,7 @@ public Map aggregateByValue( int limit) { List entitySpecs; if (entityNames == null || entityNames.isEmpty()) { - entitySpecs = new ArrayList<>(entityRegistry.getEntitySpecs().values()); + entitySpecs = QueryUtils.getQueryByDefaultEntitySpecs(entityRegistry); } else { entitySpecs = entityNames.stream().map(entityRegistry::getEntitySpec).collect(Collectors.toList()); @@ -341,7 +343,7 @@ public Map aggregateByValue( MetricUtils.timer(this.getClass(), "aggregateByValue_search").time()) { final SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT); // extract results, validated against document model as well - return SearchRequestHandler.extractAggregationsFromResponse(searchResponse, field); + return AggregationQueryBuilder.extractAggregationsFromResponse(searchResponse, field); } catch (Exception e) { log.error("Aggregation query failed", e); throw new ESQueryException("Aggregation query failed:", e); diff --git a/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/request/AggregationQueryBuilder.java b/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/request/AggregationQueryBuilder.java index bdc0332b040df9..887d4b22f37e24 100644 --- a/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/request/AggregationQueryBuilder.java +++ b/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/request/AggregationQueryBuilder.java @@ -1,36 +1,71 @@ package com.linkedin.metadata.search.elasticsearch.query.request; import static com.linkedin.metadata.Constants.*; +import static com.linkedin.metadata.search.utils.ESUtils.toFacetField; import static com.linkedin.metadata.utils.SearchUtil.*; +import com.linkedin.data.template.LongMap; import com.linkedin.metadata.config.search.SearchConfiguration; +import com.linkedin.metadata.models.EntitySpec; import com.linkedin.metadata.models.StructuredPropertyUtils; import com.linkedin.metadata.models.annotation.SearchableAnnotation; +import com.linkedin.metadata.query.filter.ConjunctiveCriterion; +import com.linkedin.metadata.query.filter.ConjunctiveCriterionArray; +import com.linkedin.metadata.query.filter.Criterion; +import com.linkedin.metadata.query.filter.CriterionArray; +import com.linkedin.metadata.query.filter.Filter; +import com.linkedin.metadata.search.AggregationMetadata; +import com.linkedin.metadata.search.FilterValueArray; import com.linkedin.metadata.search.utils.ESUtils; +import com.linkedin.metadata.utils.SearchUtil; +import com.linkedin.util.Pair; +import io.opentelemetry.extension.annotations.WithSpan; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.BinaryOperator; import java.util.stream.Collectors; import javax.annotation.Nonnull; import javax.annotation.Nullable; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang.StringUtils; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.search.aggregations.Aggregation; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.missing.ParsedMissing; +import org.opensearch.search.aggregations.bucket.terms.ParsedTerms; +import org.opensearch.search.aggregations.bucket.terms.Terms; @Slf4j public class AggregationQueryBuilder { + private static final String URN_FILTER = "urn"; - private final SearchConfiguration _configs; - private final Set _defaultFacetFields; - private final Set _allFacetFields; + private final SearchConfiguration configs; + private final Set defaultFacetFields; + private final Set allFacetFields; + private final Map> entitySearchAnnotations; + + private Map filtersToDisplayName; public AggregationQueryBuilder( @Nonnull final SearchConfiguration configs, - @Nonnull final List annotations) { - this._configs = Objects.requireNonNull(configs, "configs must not be null"); - this._defaultFacetFields = getDefaultFacetFields(annotations); - this._allFacetFields = getAllFacetFields(annotations); + @Nonnull Map> entitySearchAnnotations) { + this.configs = Objects.requireNonNull(configs, "configs must not be null"); + this.entitySearchAnnotations = entitySearchAnnotations; + + List annotations = + this.entitySearchAnnotations.values().stream() + .flatMap(List::stream) + .collect(Collectors.toList()); + this.defaultFacetFields = getDefaultFacetFields(annotations); + this.allFacetFields = getAllFacetFields(annotations); } /** Get the set of default aggregations, across all facets. */ @@ -48,7 +83,7 @@ public List getAggregations(@Nullable List facets) { facetsToAggregate = facets.stream().filter(this::isValidAggregate).collect(Collectors.toSet()); } else { - facetsToAggregate = _defaultFacetFields; + facetsToAggregate = defaultFacetFields; } return facetsToAggregate.stream() .map(this::facetToAggregationBuilder) @@ -79,13 +114,13 @@ private boolean isValidAggregate(final String inputFacet) { !facets.isEmpty() && ((facets.size() == 1 && facets.get(0).startsWith(STRUCTURED_PROPERTY_MAPPING_FIELD + ".")) - || _allFacetFields.containsAll(facets)); + || allFacetFields.containsAll(facets)); if (!isValid) { log.warn( String.format( "Requested facet for search filter aggregations that isn't part of the filters. " + "Provided: %s; Available: %s", - inputFacet, _allFacetFields)); + inputFacet, allFacetFields)); } return isValid; } @@ -122,11 +157,11 @@ private AggregationBuilder facetToAggregationBuilder(final String inputFacet) { facet.equalsIgnoreCase(INDEX_VIRTUAL_FIELD) ? AggregationBuilders.terms(inputFacet) .field(getAggregationField("_index")) - .size(_configs.getMaxTermBucketSize()) + .size(configs.getMaxTermBucketSize()) .minDocCount(0) : AggregationBuilders.terms(inputFacet) .field(getAggregationField(facet)) - .size(_configs.getMaxTermBucketSize()); + .size(configs.getMaxTermBucketSize()); } if (lastAggBuilder != null) { aggBuilder = aggBuilder.subAggregation(lastAggBuilder); @@ -173,4 +208,365 @@ List getAllFacetFieldsFromAnnotation(final SearchableAnnotation annotati } return facetsFromAnnotation; } + + private String computeDisplayName(String name) { + if (getFacetToDisplayNames().containsKey(name)) { + return getFacetToDisplayNames().get(name); + } else if (name.contains(AGGREGATION_SEPARATOR_CHAR)) { + return Arrays.stream(name.split(AGGREGATION_SEPARATOR_CHAR)) + .map(i -> getFacetToDisplayNames().get(i)) + .collect(Collectors.joining(AGGREGATION_SEPARATOR_CHAR)); + } + return name; + } + + List extractAggregationMetadata( + @Nonnull SearchResponse searchResponse, @Nullable Filter filter) { + final List aggregationMetadataList = new ArrayList<>(); + if (searchResponse.getAggregations() == null) { + return addFiltersToAggregationMetadata(aggregationMetadataList, filter); + } + for (Map.Entry entry : + searchResponse.getAggregations().getAsMap().entrySet()) { + if (entry.getValue() instanceof ParsedTerms) { + processTermAggregations(entry, aggregationMetadataList); + } + if (entry.getValue() instanceof ParsedMissing) { + processMissingAggregations(entry, aggregationMetadataList); + } + } + return addFiltersToAggregationMetadata(aggregationMetadataList, filter); + } + + private void processTermAggregations( + final Map.Entry entry, + final List aggregationMetadataList) { + final Map oneTermAggResult = + extractTermAggregations( + (ParsedTerms) entry.getValue(), entry.getKey().equals(INDEX_VIRTUAL_FIELD)); + if (oneTermAggResult.isEmpty()) { + return; + } + final AggregationMetadata aggregationMetadata = + new AggregationMetadata() + .setName(entry.getKey()) + .setDisplayName(computeDisplayName(entry.getKey())) + .setAggregations(new LongMap(oneTermAggResult)) + .setFilterValues( + new FilterValueArray( + SearchUtil.convertToFilters(oneTermAggResult, Collections.emptySet()))); + aggregationMetadataList.add(aggregationMetadata); + } + + /** + * Adds nested sub-aggregation values to the aggregated results + * + * @param aggs The aggregations to traverse. Could be null (base case) + * @return A map from names to aggregation count values + */ + @Nonnull + private static Map recursivelyAddNestedSubAggs(@Nullable Aggregations aggs) { + final Map aggResult = new HashMap<>(); + + if (aggs != null) { + for (Map.Entry entry : aggs.getAsMap().entrySet()) { + if (entry.getValue() instanceof ParsedTerms) { + recurseTermsAgg((ParsedTerms) entry.getValue(), aggResult, false); + } else if (entry.getValue() instanceof ParsedMissing) { + recurseMissingAgg((ParsedMissing) entry.getValue(), aggResult); + } else { + throw new UnsupportedOperationException( + "Unsupported aggregation type: " + entry.getValue().getClass().getName()); + } + } + } + return aggResult; + } + + private static void recurseTermsAgg( + ParsedTerms terms, Map aggResult, boolean includeZeroes) { + List bucketList = terms.getBuckets(); + bucketList.forEach(bucket -> processTermBucket(bucket, aggResult, includeZeroes)); + } + + private static void processTermBucket( + Terms.Bucket bucket, Map aggResult, boolean includeZeroes) { + String key = bucket.getKeyAsString(); + // Gets filtered sub aggregation doc count if exist + Map subAggs = recursivelyAddNestedSubAggs(bucket.getAggregations()); + subAggs.forEach( + (entryKey, entryValue) -> + aggResult.put( + String.format("%s%s%s", key, AGGREGATION_SEPARATOR_CHAR, entryKey), entryValue)); + long docCount = bucket.getDocCount(); + if (includeZeroes || docCount > 0) { + aggResult.put(key, docCount); + } + } + + private static void recurseMissingAgg(ParsedMissing missing, Map aggResult) { + Map subAggs = recursivelyAddNestedSubAggs(missing.getAggregations()); + subAggs.forEach( + (key, value) -> + aggResult.put( + String.format("%s%s%s", missing.getName(), AGGREGATION_SEPARATOR_CHAR, key), + value)); + long docCount = missing.getDocCount(); + if (docCount > 0) { + aggResult.put(missing.getName(), docCount); + } + } + + /** + * Extracts term aggregations give a parsed term. + * + * @param terms an abstract parse term, input can be either ParsedStringTerms ParsedLongTerms + * @return a map with aggregation key and corresponding doc counts + */ + @Nonnull + private static Map extractTermAggregations( + @Nonnull ParsedTerms terms, boolean includeZeroes) { + + final Map aggResult = new HashMap<>(); + recurseTermsAgg(terms, aggResult, includeZeroes); + + return aggResult; + } + + /** Injects the missing conjunctive filters into the aggregations list. */ + public List addFiltersToAggregationMetadata( + @Nonnull final List originalMetadata, @Nullable final Filter filter) { + if (filter == null) { + return originalMetadata; + } + if (filter.getOr() != null) { + addOrFiltersToAggregationMetadata(filter.getOr(), originalMetadata); + } else if (filter.getCriteria() != null) { + addCriteriaFiltersToAggregationMetadata(filter.getCriteria(), originalMetadata); + } + return originalMetadata; + } + + void addOrFiltersToAggregationMetadata( + @Nonnull final ConjunctiveCriterionArray or, + @Nonnull final List originalMetadata) { + for (ConjunctiveCriterion conjunction : or) { + // For each item in the conjunction, inject an empty aggregation if necessary + addCriteriaFiltersToAggregationMetadata(conjunction.getAnd(), originalMetadata); + } + } + + private void addCriteriaFiltersToAggregationMetadata( + @Nonnull final CriterionArray criteria, + @Nonnull final List originalMetadata) { + for (Criterion criterion : criteria) { + addCriterionFiltersToAggregationMetadata(criterion, originalMetadata); + } + } + + private void addCriterionFiltersToAggregationMetadata( + @Nonnull final Criterion criterion, + @Nonnull final List aggregationMetadata) { + + // We should never see duplicate aggregation for the same field in aggregation metadata list. + final Map aggregationMetadataMap = + aggregationMetadata.stream() + .collect(Collectors.toMap(AggregationMetadata::getName, agg -> agg)); + + // Map a filter criterion to a facet field (e.g. domains.keyword -> domains) + final String finalFacetField = toFacetField(criterion.getField()); + + if (finalFacetField == null) { + log.warn( + String.format( + "Found invalid filter field for entity search. Invalid or unrecognized facet %s", + criterion.getField())); + return; + } + + // We don't want to add urn filters to the aggregations we return as a sidecar to search + // results. + // They are automatically added by searchAcrossLineage and we dont need them to show up in the + // filter panel. + if (finalFacetField.equals(URN_FILTER)) { + return; + } + + if (aggregationMetadataMap.containsKey(finalFacetField)) { + /* + * If we already have aggregations for the facet field, simply inject any missing values counts into the set. + * If there are no results for a particular facet value, it will NOT be in the original aggregation set returned by + * Elasticsearch. + */ + AggregationMetadata originalAggMetadata = aggregationMetadataMap.get(finalFacetField); + if (criterion.hasValues()) { + criterion + .getValues() + .forEach( + value -> + addMissingAggregationValueToAggregationMetadata(value, originalAggMetadata)); + } else { + addMissingAggregationValueToAggregationMetadata(criterion.getValue(), originalAggMetadata); + } + } else { + /* + * If we do not have ANY aggregation for the facet field, then inject a new aggregation metadata object for the + * facet field. + * If there are no results for a particular facet, it will NOT be in the original aggregation set returned by + * Elasticsearch. + */ + aggregationMetadata.add( + buildAggregationMetadata( + finalFacetField, + getFacetToDisplayNames().getOrDefault(finalFacetField, finalFacetField), + new LongMap( + criterion.getValues().stream().collect(Collectors.toMap(i -> i, i -> 0L))), + new FilterValueArray( + criterion.getValues().stream() + .map(value -> createFilterValue(value, 0L, true)) + .collect(Collectors.toList())))); + } + } + + private void addMissingAggregationValueToAggregationMetadata( + @Nonnull final String value, @Nonnull final AggregationMetadata originalMetadata) { + if (originalMetadata.getAggregations().entrySet().stream() + .noneMatch(entry -> value.equals(entry.getKey())) + || originalMetadata.getFilterValues().stream() + .noneMatch(entry -> entry.getValue().equals(value))) { + // No aggregation found for filtered value -- inject one! + originalMetadata.getAggregations().put(value, 0L); + originalMetadata.getFilterValues().add(createFilterValue(value, 0L, true)); + } + } + + private AggregationMetadata buildAggregationMetadata( + @Nonnull final String facetField, + @Nonnull final String displayName, + @Nonnull final LongMap aggValues, + @Nonnull final FilterValueArray filterValues) { + return new AggregationMetadata() + .setName(facetField) + .setDisplayName(displayName) + .setAggregations(aggValues) + .setFilterValues(filterValues); + } + + private List>> getFacetFieldDisplayNameFromAnnotation( + @Nonnull EntitySpec entitySpec, @Nonnull final SearchableAnnotation annotation) { + final List>> facetsFromAnnotation = new ArrayList<>(); + // Case 1: Default Keyword field + if (annotation.isAddToFilters()) { + facetsFromAnnotation.add( + Pair.of( + annotation.getFieldName(), + Pair.of(entitySpec.getName(), annotation.getFilterName()))); + } + // Case 2: HasX boolean field + if (annotation.isAddHasValuesToFilters() && annotation.getHasValuesFieldName().isPresent()) { + facetsFromAnnotation.add( + Pair.of( + annotation.getHasValuesFieldName().get(), + Pair.of(entitySpec.getName(), annotation.getHasValuesFilterName()))); + } + return facetsFromAnnotation; + } + + @WithSpan + public static Map extractAggregationsFromResponse( + @Nonnull SearchResponse searchResponse, @Nonnull String aggregationName) { + if (searchResponse.getAggregations() == null) { + return Collections.emptyMap(); + } + + Aggregation aggregation = searchResponse.getAggregations().get(aggregationName); + if (aggregation == null) { + return Collections.emptyMap(); + } + if (aggregation instanceof ParsedTerms) { + return extractTermAggregations( + (ParsedTerms) aggregation, aggregationName.equals("_entityType")); + } else if (aggregation instanceof ParsedMissing) { + return Collections.singletonMap( + aggregation.getName(), ((ParsedMissing) aggregation).getDocCount()); + } + throw new UnsupportedOperationException( + "Unsupported aggregation type: " + aggregation.getClass().getName()); + } + + /** + * Only used in aggregation queries, lazy load + * + * @return map of field name to facet display names + */ + private Map getFacetToDisplayNames() { + if (filtersToDisplayName == null) { + // Validate field names + Map>>> validateFieldMap = + entitySearchAnnotations.entrySet().stream() + .flatMap( + entry -> + entry.getValue().stream() + .flatMap( + annotation -> + getFacetFieldDisplayNameFromAnnotation(entry.getKey(), annotation) + .stream())) + .collect(Collectors.groupingBy(Pair::getFirst, Collectors.toSet())); + for (Map.Entry>>> entry : + validateFieldMap.entrySet()) { + if (entry.getValue().stream().map(i -> i.getSecond().getSecond()).distinct().count() > 1) { + Map>> displayNameEntityMap = + entry.getValue().stream() + .map(Pair::getSecond) + .collect(Collectors.groupingBy(Pair::getSecond, Collectors.toSet())); + throw new IllegalStateException( + String.format( + "Facet field collision on field `%s`. Incompatible Display Name across entities. Multiple Display Names detected: %s", + entry.getKey(), displayNameEntityMap)); + } + } + + filtersToDisplayName = + entitySearchAnnotations.entrySet().stream() + .flatMap( + entry -> + entry.getValue().stream() + .flatMap( + annotation -> + getFacetFieldDisplayNameFromAnnotation(entry.getKey(), annotation) + .stream())) + .collect( + Collectors.toMap(Pair::getFirst, p -> p.getSecond().getSecond(), mapMerger())); + filtersToDisplayName.put(INDEX_VIRTUAL_FIELD, "Type"); + } + + return filtersToDisplayName; + } + + private void processMissingAggregations( + final Map.Entry entry, + final List aggregationMetadataList) { + ParsedMissing parsedMissing = (ParsedMissing) entry.getValue(); + Long docCount = parsedMissing.getDocCount(); + LongMap longMap = new LongMap(); + longMap.put(entry.getKey(), docCount); + final AggregationMetadata aggregationMetadata = + new AggregationMetadata() + .setName(entry.getKey()) + .setDisplayName(computeDisplayName(entry.getKey())) + .setAggregations(longMap) + .setFilterValues( + new FilterValueArray(SearchUtil.convertToFilters(longMap, Collections.emptySet()))); + aggregationMetadataList.add(aggregationMetadata); + } + + // If values are not equal, throw error + private BinaryOperator mapMerger() { + return (s1, s2) -> { + if (!StringUtils.equals(s1, s2)) { + throw new IllegalStateException(String.format("Unable to merge values %s and %s", s1, s2)); + } + return s1; + }; + } } diff --git a/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/request/SearchRequestHandler.java b/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/request/SearchRequestHandler.java index 277e15e1334d56..3ac05ed122cd70 100644 --- a/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/request/SearchRequestHandler.java +++ b/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/request/SearchRequestHandler.java @@ -1,7 +1,6 @@ package com.linkedin.metadata.search.elasticsearch.query.request; import static com.linkedin.metadata.search.utils.ESUtils.NAME_SUGGESTION; -import static com.linkedin.metadata.search.utils.ESUtils.toFacetField; import static com.linkedin.metadata.search.utils.SearchUtils.applyDefaultSearchFlags; import static com.linkedin.metadata.utils.SearchUtil.*; @@ -10,22 +9,16 @@ import com.google.common.collect.ImmutableMap; import com.linkedin.common.urn.Urn; import com.linkedin.data.template.DoubleMap; -import com.linkedin.data.template.LongMap; import com.linkedin.metadata.config.search.SearchConfiguration; import com.linkedin.metadata.config.search.custom.CustomSearchConfiguration; import com.linkedin.metadata.models.EntitySpec; import com.linkedin.metadata.models.SearchableFieldSpec; import com.linkedin.metadata.models.annotation.SearchableAnnotation; import com.linkedin.metadata.query.SearchFlags; -import com.linkedin.metadata.query.filter.ConjunctiveCriterion; -import com.linkedin.metadata.query.filter.ConjunctiveCriterionArray; -import com.linkedin.metadata.query.filter.Criterion; -import com.linkedin.metadata.query.filter.CriterionArray; import com.linkedin.metadata.query.filter.Filter; import com.linkedin.metadata.query.filter.SortCriterion; import com.linkedin.metadata.search.AggregationMetadata; import com.linkedin.metadata.search.AggregationMetadataArray; -import com.linkedin.metadata.search.FilterValueArray; import com.linkedin.metadata.search.MatchedField; import com.linkedin.metadata.search.MatchedFieldArray; import com.linkedin.metadata.search.ScrollResult; @@ -37,13 +30,11 @@ import com.linkedin.metadata.search.SearchSuggestionArray; import com.linkedin.metadata.search.features.Features; import com.linkedin.metadata.search.utils.ESUtils; -import com.linkedin.metadata.utils.SearchUtil; import com.linkedin.util.Pair; import io.opentelemetry.extension.annotations.WithSpan; import java.net.URISyntaxException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -51,13 +42,11 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.function.BinaryOperator; import java.util.stream.Collectors; import java.util.stream.Stream; import javax.annotation.Nonnull; import javax.annotation.Nullable; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang.StringUtils; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.common.text.Text; @@ -66,12 +55,7 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; -import org.opensearch.search.aggregations.Aggregation; import org.opensearch.search.aggregations.AggregationBuilders; -import org.opensearch.search.aggregations.Aggregations; -import org.opensearch.search.aggregations.bucket.missing.ParsedMissing; -import org.opensearch.search.aggregations.bucket.terms.ParsedTerms; -import org.opensearch.search.aggregations.bucket.terms.Terms; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; import org.opensearch.search.fetch.subphase.highlight.HighlightField; @@ -88,11 +72,9 @@ public class SearchRequestHandler { .setSkipHighlighting(false); private static final Map, SearchRequestHandler> REQUEST_HANDLER_BY_ENTITY_NAME = new ConcurrentHashMap<>(); - private static final String URN_FILTER = "urn"; private final List _entitySpecs; private final Set _defaultQueryFieldNames; private final HighlightBuilder _highlights; - private final Map _filtersToDisplayName; private final SearchConfiguration _configs; private final SearchQueryBuilder _searchQueryBuilder; @@ -111,16 +93,16 @@ private SearchRequestHandler( @Nonnull SearchConfiguration configs, @Nullable CustomSearchConfiguration customSearchConfiguration) { _entitySpecs = entitySpecs; - List annotations = getSearchableAnnotations(); + Map> entitySearchAnnotations = + getSearchableAnnotations(); + List annotations = + entitySearchAnnotations.values().stream() + .flatMap(List::stream) + .collect(Collectors.toList()); _defaultQueryFieldNames = getDefaultQueryFieldNames(annotations); - _filtersToDisplayName = - annotations.stream() - .flatMap(annotation -> getFacetFieldDisplayNameFromAnnotation(annotation).stream()) - .collect(Collectors.toMap(Pair::getFirst, Pair::getSecond, mapMerger())); - _filtersToDisplayName.put(INDEX_VIRTUAL_FIELD, "Type"); _highlights = getHighlights(); _searchQueryBuilder = new SearchQueryBuilder(configs, customSearchConfiguration); - _aggregationQueryBuilder = new AggregationQueryBuilder(configs, annotations); + _aggregationQueryBuilder = new AggregationQueryBuilder(configs, entitySearchAnnotations); _configs = configs; searchableFieldTypes = _entitySpecs.stream() @@ -153,12 +135,16 @@ public static SearchRequestHandler getBuilder( k -> new SearchRequestHandler(entitySpecs, configs, customSearchConfiguration)); } - private List getSearchableAnnotations() { + private Map> getSearchableAnnotations() { return _entitySpecs.stream() - .map(EntitySpec::getSearchableFieldSpecs) - .flatMap(List::stream) - .map(SearchableFieldSpec::getSearchableAnnotation) - .collect(Collectors.toList()); + .map( + spec -> + Pair.of( + spec, + spec.getSearchableFieldSpecs().stream() + .map(SearchableFieldSpec::getSearchableAnnotation) + .collect(Collectors.toList()))) + .collect(Collectors.toMap(Pair::getKey, Pair::getValue)); } @VisibleForTesting @@ -171,16 +157,6 @@ private Set getDefaultQueryFieldNames(List annotat .collect(Collectors.toSet()); } - // If values are not equal, throw error - private BinaryOperator mapMerger() { - return (s1, s2) -> { - if (!StringUtils.equals(s1, s2)) { - throw new IllegalStateException(String.format("Unable to merge values %s and %s", s1, s2)); - } - return s1; - }; - } - public BoolQueryBuilder getFilterQuery(@Nullable Filter filter) { return getFilterQuery(filter, searchableFieldTypes); } @@ -327,42 +303,6 @@ public SearchRequest getFilterRequest( return searchRequest; } - /** - * Returns a {@link SearchRequest} given filters to be applied to search query and sort criterion - * to be applied to search results. - * - *

TODO: Used in batch ingestion from ingestion scheduler - * - * @param filters {@link Filter} list of conditions with fields and values - * @param sortCriterion {@link SortCriterion} to be applied to the search results - * @param sort sort values from last result of previous request - * @param pitId the Point In Time Id of the previous request - * @param keepAlive string representation of time to keep point in time alive - * @param size the number of search hits to return - * @return {@link SearchRequest} that contains the filtered query - */ - @Nonnull - public SearchRequest getFilterRequest( - @Nullable Filter filters, - @Nullable SortCriterion sortCriterion, - @Nullable Object[] sort, - @Nullable String pitId, - @Nonnull String keepAlive, - int size) { - SearchRequest searchRequest = new SearchRequest(); - - BoolQueryBuilder filterQuery = getFilterQuery(filters); - final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(filterQuery); - searchSourceBuilder.size(size); - - ESUtils.setSearchAfter(searchSourceBuilder, sort, pitId, keepAlive); - ESUtils.buildSortOrder(searchSourceBuilder, sortCriterion, _entitySpecs); - searchRequest.source(searchSourceBuilder); - - return searchRequest; - } - /** * Get search request to aggregate and get document counts per field value * @@ -558,7 +498,7 @@ private SearchResultMetadata extractSearchResultMetadata( new SearchResultMetadata().setAggregations(new AggregationMetadataArray()); final List aggregationMetadataList = - extractAggregationMetadata(searchResponse, filter); + _aggregationQueryBuilder.extractAggregationMetadata(searchResponse, filter); searchResultMetadata.setAggregations(new AggregationMetadataArray(aggregationMetadataList)); final List searchSuggestions = extractSearchSuggestions(searchResponse); @@ -588,301 +528,4 @@ private List extractSearchSuggestions(@Nonnull SearchResponse } return searchSuggestions; } - - private String computeDisplayName(String name) { - if (_filtersToDisplayName.containsKey(name)) { - return _filtersToDisplayName.get(name); - } else if (name.contains(AGGREGATION_SEPARATOR_CHAR)) { - return Arrays.stream(name.split(AGGREGATION_SEPARATOR_CHAR)) - .map(_filtersToDisplayName::get) - .collect(Collectors.joining(AGGREGATION_SEPARATOR_CHAR)); - } - return name; - } - - private List extractAggregationMetadata( - @Nonnull SearchResponse searchResponse, @Nullable Filter filter) { - final List aggregationMetadataList = new ArrayList<>(); - if (searchResponse.getAggregations() == null) { - return addFiltersToAggregationMetadata(aggregationMetadataList, filter); - } - for (Map.Entry entry : - searchResponse.getAggregations().getAsMap().entrySet()) { - if (entry.getValue() instanceof ParsedTerms) { - processTermAggregations(entry, aggregationMetadataList); - } - if (entry.getValue() instanceof ParsedMissing) { - processMissingAggregations(entry, aggregationMetadataList); - } - } - return addFiltersToAggregationMetadata(aggregationMetadataList, filter); - } - - private void processTermAggregations( - final Map.Entry entry, - final List aggregationMetadataList) { - final Map oneTermAggResult = - extractTermAggregations( - (ParsedTerms) entry.getValue(), entry.getKey().equals(INDEX_VIRTUAL_FIELD)); - if (oneTermAggResult.isEmpty()) { - return; - } - final AggregationMetadata aggregationMetadata = - new AggregationMetadata() - .setName(entry.getKey()) - .setDisplayName(computeDisplayName(entry.getKey())) - .setAggregations(new LongMap(oneTermAggResult)) - .setFilterValues( - new FilterValueArray( - SearchUtil.convertToFilters(oneTermAggResult, Collections.emptySet()))); - aggregationMetadataList.add(aggregationMetadata); - } - - private void processMissingAggregations( - final Map.Entry entry, - final List aggregationMetadataList) { - ParsedMissing parsedMissing = (ParsedMissing) entry.getValue(); - Long docCount = parsedMissing.getDocCount(); - LongMap longMap = new LongMap(); - longMap.put(entry.getKey(), docCount); - final AggregationMetadata aggregationMetadata = - new AggregationMetadata() - .setName(entry.getKey()) - .setDisplayName(computeDisplayName(entry.getKey())) - .setAggregations(longMap) - .setFilterValues( - new FilterValueArray(SearchUtil.convertToFilters(longMap, Collections.emptySet()))); - aggregationMetadataList.add(aggregationMetadata); - } - - @WithSpan - public static Map extractAggregationsFromResponse( - @Nonnull SearchResponse searchResponse, @Nonnull String aggregationName) { - if (searchResponse.getAggregations() == null) { - return Collections.emptyMap(); - } - - Aggregation aggregation = searchResponse.getAggregations().get(aggregationName); - if (aggregation == null) { - return Collections.emptyMap(); - } - if (aggregation instanceof ParsedTerms) { - return extractTermAggregations( - (ParsedTerms) aggregation, aggregationName.equals("_entityType")); - } else if (aggregation instanceof ParsedMissing) { - return Collections.singletonMap( - aggregation.getName(), ((ParsedMissing) aggregation).getDocCount()); - } - throw new UnsupportedOperationException( - "Unsupported aggregation type: " + aggregation.getClass().getName()); - } - - /** - * Adds nested sub-aggregation values to the aggregated results - * - * @param aggs The aggregations to traverse. Could be null (base case) - * @return A map from names to aggregation count values - */ - @Nonnull - private static Map recursivelyAddNestedSubAggs(@Nullable Aggregations aggs) { - final Map aggResult = new HashMap<>(); - - if (aggs != null) { - for (Map.Entry entry : aggs.getAsMap().entrySet()) { - if (entry.getValue() instanceof ParsedTerms) { - recurseTermsAgg((ParsedTerms) entry.getValue(), aggResult, false); - } else if (entry.getValue() instanceof ParsedMissing) { - recurseMissingAgg((ParsedMissing) entry.getValue(), aggResult); - } else { - throw new UnsupportedOperationException( - "Unsupported aggregation type: " + entry.getValue().getClass().getName()); - } - } - } - return aggResult; - } - - private static void recurseTermsAgg( - ParsedTerms terms, Map aggResult, boolean includeZeroes) { - List bucketList = terms.getBuckets(); - bucketList.forEach(bucket -> processTermBucket(bucket, aggResult, includeZeroes)); - } - - private static void processTermBucket( - Terms.Bucket bucket, Map aggResult, boolean includeZeroes) { - String key = bucket.getKeyAsString(); - // Gets filtered sub aggregation doc count if exist - Map subAggs = recursivelyAddNestedSubAggs(bucket.getAggregations()); - subAggs.forEach( - (entryKey, entryValue) -> - aggResult.put( - String.format("%s%s%s", key, AGGREGATION_SEPARATOR_CHAR, entryKey), entryValue)); - long docCount = bucket.getDocCount(); - if (includeZeroes || docCount > 0) { - aggResult.put(key, docCount); - } - } - - private static void recurseMissingAgg(ParsedMissing missing, Map aggResult) { - Map subAggs = recursivelyAddNestedSubAggs(missing.getAggregations()); - subAggs.forEach( - (key, value) -> - aggResult.put( - String.format("%s%s%s", missing.getName(), AGGREGATION_SEPARATOR_CHAR, key), - value)); - long docCount = missing.getDocCount(); - if (docCount > 0) { - aggResult.put(missing.getName(), docCount); - } - } - - /** - * Extracts term aggregations give a parsed term. - * - * @param terms an abstract parse term, input can be either ParsedStringTerms ParsedLongTerms - * @return a map with aggregation key and corresponding doc counts - */ - @Nonnull - private static Map extractTermAggregations( - @Nonnull ParsedTerms terms, boolean includeZeroes) { - - final Map aggResult = new HashMap<>(); - recurseTermsAgg(terms, aggResult, includeZeroes); - - return aggResult; - } - - /** Injects the missing conjunctive filters into the aggregations list. */ - public List addFiltersToAggregationMetadata( - @Nonnull final List originalMetadata, @Nullable final Filter filter) { - if (filter == null) { - return originalMetadata; - } - if (filter.getOr() != null) { - addOrFiltersToAggregationMetadata(filter.getOr(), originalMetadata); - } else if (filter.getCriteria() != null) { - addCriteriaFiltersToAggregationMetadata(filter.getCriteria(), originalMetadata); - } - return originalMetadata; - } - - void addOrFiltersToAggregationMetadata( - @Nonnull final ConjunctiveCriterionArray or, - @Nonnull final List originalMetadata) { - for (ConjunctiveCriterion conjunction : or) { - // For each item in the conjunction, inject an empty aggregation if necessary - addCriteriaFiltersToAggregationMetadata(conjunction.getAnd(), originalMetadata); - } - } - - private void addCriteriaFiltersToAggregationMetadata( - @Nonnull final CriterionArray criteria, - @Nonnull final List originalMetadata) { - for (Criterion criterion : criteria) { - addCriterionFiltersToAggregationMetadata(criterion, originalMetadata); - } - } - - private void addCriterionFiltersToAggregationMetadata( - @Nonnull final Criterion criterion, - @Nonnull final List aggregationMetadata) { - - // We should never see duplicate aggregation for the same field in aggregation metadata list. - final Map aggregationMetadataMap = - aggregationMetadata.stream() - .collect(Collectors.toMap(AggregationMetadata::getName, agg -> agg)); - - // Map a filter criterion to a facet field (e.g. domains.keyword -> domains) - final String finalFacetField = toFacetField(criterion.getField()); - - if (finalFacetField == null) { - log.warn( - String.format( - "Found invalid filter field for entity search. Invalid or unrecognized facet %s", - criterion.getField())); - return; - } - - // We don't want to add urn filters to the aggregations we return as a sidecar to search - // results. - // They are automatically added by searchAcrossLineage and we dont need them to show up in the - // filter panel. - if (finalFacetField.equals(URN_FILTER)) { - return; - } - - if (aggregationMetadataMap.containsKey(finalFacetField)) { - /* - * If we already have aggregations for the facet field, simply inject any missing values counts into the set. - * If there are no results for a particular facet value, it will NOT be in the original aggregation set returned by - * Elasticsearch. - */ - AggregationMetadata originalAggMetadata = aggregationMetadataMap.get(finalFacetField); - if (criterion.hasValues()) { - criterion - .getValues() - .forEach( - value -> - addMissingAggregationValueToAggregationMetadata(value, originalAggMetadata)); - } else { - addMissingAggregationValueToAggregationMetadata(criterion.getValue(), originalAggMetadata); - } - } else { - /* - * If we do not have ANY aggregation for the facet field, then inject a new aggregation metadata object for the - * facet field. - * If there are no results for a particular facet, it will NOT be in the original aggregation set returned by - * Elasticsearch. - */ - aggregationMetadata.add( - buildAggregationMetadata( - finalFacetField, - _filtersToDisplayName.getOrDefault(finalFacetField, finalFacetField), - new LongMap( - criterion.getValues().stream().collect(Collectors.toMap(i -> i, i -> 0L))), - new FilterValueArray( - criterion.getValues().stream() - .map(value -> createFilterValue(value, 0L, true)) - .collect(Collectors.toList())))); - } - } - - private void addMissingAggregationValueToAggregationMetadata( - @Nonnull final String value, @Nonnull final AggregationMetadata originalMetadata) { - if (originalMetadata.getAggregations().entrySet().stream() - .noneMatch(entry -> value.equals(entry.getKey())) - || originalMetadata.getFilterValues().stream() - .noneMatch(entry -> entry.getValue().equals(value))) { - // No aggregation found for filtered value -- inject one! - originalMetadata.getAggregations().put(value, 0L); - originalMetadata.getFilterValues().add(createFilterValue(value, 0L, true)); - } - } - - private AggregationMetadata buildAggregationMetadata( - @Nonnull final String facetField, - @Nonnull final String displayName, - @Nonnull final LongMap aggValues, - @Nonnull final FilterValueArray filterValues) { - return new AggregationMetadata() - .setName(facetField) - .setDisplayName(displayName) - .setAggregations(aggValues) - .setFilterValues(filterValues); - } - - private List> getFacetFieldDisplayNameFromAnnotation( - @Nonnull final SearchableAnnotation annotation) { - final List> facetsFromAnnotation = new ArrayList<>(); - // Case 1: Default Keyword field - if (annotation.isAddToFilters()) { - facetsFromAnnotation.add(Pair.of(annotation.getFieldName(), annotation.getFilterName())); - } - // Case 2: HasX boolean field - if (annotation.isAddHasValuesToFilters() && annotation.getHasValuesFieldName().isPresent()) { - facetsFromAnnotation.add( - Pair.of(annotation.getHasValuesFieldName().get(), annotation.getHasValuesFilterName())); - } - return facetsFromAnnotation; - } } diff --git a/metadata-io/src/test/java/com/linkedin/metadata/recommendation/candidatesource/EntitySearchAggregationCandidateSourceTest.java b/metadata-io/src/test/java/com/linkedin/metadata/recommendation/candidatesource/EntitySearchAggregationCandidateSourceTest.java index dcc59d06329544..2d60f3202b69f5 100644 --- a/metadata-io/src/test/java/com/linkedin/metadata/recommendation/candidatesource/EntitySearchAggregationCandidateSourceTest.java +++ b/metadata-io/src/test/java/com/linkedin/metadata/recommendation/candidatesource/EntitySearchAggregationCandidateSourceTest.java @@ -1,5 +1,6 @@ package com.linkedin.metadata.recommendation.candidatesource; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.eq; import static org.testng.Assert.assertEquals; @@ -11,6 +12,7 @@ import com.linkedin.common.urn.CorpuserUrn; import com.linkedin.common.urn.TestEntityUrn; import com.linkedin.common.urn.Urn; +import com.linkedin.metadata.models.registry.EntityRegistry; import com.linkedin.metadata.query.filter.Criterion; import com.linkedin.metadata.recommendation.RecommendationContent; import com.linkedin.metadata.recommendation.RecommendationParams; @@ -29,6 +31,7 @@ public class EntitySearchAggregationCandidateSourceTest { private EntitySearchService _entitySearchService = Mockito.mock(EntitySearchService.class); + private EntityRegistry entityRegistry = Mockito.mock(EntityRegistry.class); private EntitySearchAggregationSource _valueBasedCandidateSource; private EntitySearchAggregationSource _urnBasedCandidateSource; @@ -45,7 +48,7 @@ public void setup() { private EntitySearchAggregationSource buildCandidateSource( String identifier, boolean isValueUrn) { - return new EntitySearchAggregationSource(_entitySearchService) { + return new EntitySearchAggregationSource(_entitySearchService, entityRegistry) { @Override protected String getSearchFieldName() { return identifier; @@ -98,8 +101,7 @@ public void testWhenSearchServiceReturnsEmpty() { @Test public void testWhenSearchServiceReturnsValueResults() { // One result - Mockito.when( - _entitySearchService.aggregateByValue(eq(null), eq("testValue"), eq(null), anyInt())) + Mockito.when(_entitySearchService.aggregateByValue(any(), eq("testValue"), eq(null), anyInt())) .thenReturn(ImmutableMap.of("value1", 1L)); List candidates = _valueBasedCandidateSource.getRecommendations(USER, CONTEXT); @@ -120,8 +122,7 @@ public void testWhenSearchServiceReturnsValueResults() { assertTrue(_valueBasedCandidateSource.getRecommendationModule(USER, CONTEXT).isPresent()); // Multiple result - Mockito.when( - _entitySearchService.aggregateByValue(eq(null), eq("testValue"), eq(null), anyInt())) + Mockito.when(_entitySearchService.aggregateByValue(any(), eq("testValue"), eq(null), anyInt())) .thenReturn(ImmutableMap.of("value1", 1L, "value2", 2L, "value3", 3L)); candidates = _valueBasedCandidateSource.getRecommendations(USER, CONTEXT); assertEquals(candidates.size(), 2); @@ -160,7 +161,7 @@ public void testWhenSearchServiceReturnsUrnResults() { Urn testUrn1 = new TestEntityUrn("testUrn1", "testUrn1", "testUrn1"); Urn testUrn2 = new TestEntityUrn("testUrn2", "testUrn2", "testUrn2"); Urn testUrn3 = new TestEntityUrn("testUrn3", "testUrn3", "testUrn3"); - Mockito.when(_entitySearchService.aggregateByValue(eq(null), eq("testUrn"), eq(null), anyInt())) + Mockito.when(_entitySearchService.aggregateByValue(any(), eq("testUrn"), eq(null), anyInt())) .thenReturn(ImmutableMap.of(testUrn1.toString(), 1L)); List candidates = _urnBasedCandidateSource.getRecommendations(USER, CONTEXT); @@ -181,7 +182,7 @@ public void testWhenSearchServiceReturnsUrnResults() { assertTrue(_urnBasedCandidateSource.getRecommendationModule(USER, CONTEXT).isPresent()); // Multiple result - Mockito.when(_entitySearchService.aggregateByValue(eq(null), eq("testUrn"), eq(null), anyInt())) + Mockito.when(_entitySearchService.aggregateByValue(any(), eq("testUrn"), eq(null), anyInt())) .thenReturn( ImmutableMap.of( testUrn1.toString(), 1L, testUrn2.toString(), 2L, testUrn3.toString(), 3L)); diff --git a/metadata-io/src/test/java/com/linkedin/metadata/search/query/request/AggregationQueryBuilderTest.java b/metadata-io/src/test/java/com/linkedin/metadata/search/query/request/AggregationQueryBuilderTest.java index 9e8855622ced4b..ed4c9db5db6430 100644 --- a/metadata-io/src/test/java/com/linkedin/metadata/search/query/request/AggregationQueryBuilderTest.java +++ b/metadata-io/src/test/java/com/linkedin/metadata/search/query/request/AggregationQueryBuilderTest.java @@ -1,10 +1,13 @@ package com.linkedin.metadata.search.query.request; import static com.linkedin.metadata.utils.SearchUtil.*; +import static org.mockito.Mockito.mock; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.linkedin.metadata.config.search.SearchConfiguration; +import com.linkedin.metadata.models.EntitySpec; import com.linkedin.metadata.models.annotation.SearchableAnnotation; import com.linkedin.metadata.search.elasticsearch.query.request.AggregationQueryBuilder; import java.util.Collections; @@ -42,7 +45,8 @@ public void testGetDefaultAggregationsHasFields() { config.setMaxTermBucketSize(25); AggregationQueryBuilder builder = - new AggregationQueryBuilder(config, ImmutableList.of(annotation)); + new AggregationQueryBuilder( + config, ImmutableMap.of(mock(EntitySpec.class), ImmutableList.of(annotation))); List aggs = builder.getAggregations(); @@ -73,7 +77,8 @@ public void testGetDefaultAggregationsFields() { config.setMaxTermBucketSize(25); AggregationQueryBuilder builder = - new AggregationQueryBuilder(config, ImmutableList.of(annotation)); + new AggregationQueryBuilder( + config, ImmutableMap.of(mock(EntitySpec.class), ImmutableList.of(annotation))); List aggs = builder.getAggregations(); @@ -120,7 +125,9 @@ public void testGetSpecificAggregationsHasFields() { config.setMaxTermBucketSize(25); AggregationQueryBuilder builder = - new AggregationQueryBuilder(config, ImmutableList.of(annotation1, annotation2)); + new AggregationQueryBuilder( + config, + ImmutableMap.of(mock(EntitySpec.class), ImmutableList.of(annotation1, annotation2))); // Case 1: Ask for fields that should exist. List aggs = @@ -139,7 +146,9 @@ public void testAggregateOverStructuredProperty() { SearchConfiguration config = new SearchConfiguration(); config.setMaxTermBucketSize(25); - AggregationQueryBuilder builder = new AggregationQueryBuilder(config, List.of()); + AggregationQueryBuilder builder = + new AggregationQueryBuilder( + config, ImmutableMap.of(mock(EntitySpec.class), ImmutableList.of())); List aggs = builder.getAggregations(List.of("structuredProperties.ab.fgh.ten")); @@ -202,7 +211,9 @@ public void testAggregateOverFieldsAndStructProp() { config.setMaxTermBucketSize(25); AggregationQueryBuilder builder = - new AggregationQueryBuilder(config, ImmutableList.of(annotation1, annotation2)); + new AggregationQueryBuilder( + config, + ImmutableMap.of(mock(EntitySpec.class), ImmutableList.of(annotation1, annotation2))); // Aggregate over fields and structured properties List aggs = @@ -252,7 +263,8 @@ public void testMissingAggregation() { config.setMaxTermBucketSize(25); AggregationQueryBuilder builder = - new AggregationQueryBuilder(config, ImmutableList.of(annotation)); + new AggregationQueryBuilder( + config, ImmutableMap.of(mock(EntitySpec.class), ImmutableList.of(annotation))); List aggs = builder.getAggregations(); diff --git a/metadata-service/factories/src/main/java/com/linkedin/gms/factory/recommendation/candidatesource/DomainsCandidateSourceFactory.java b/metadata-service/factories/src/main/java/com/linkedin/gms/factory/recommendation/candidatesource/DomainsCandidateSourceFactory.java index fbfd80f85ff4d2..a7c2dde8b7d25e 100644 --- a/metadata-service/factories/src/main/java/com/linkedin/gms/factory/recommendation/candidatesource/DomainsCandidateSourceFactory.java +++ b/metadata-service/factories/src/main/java/com/linkedin/gms/factory/recommendation/candidatesource/DomainsCandidateSourceFactory.java @@ -1,6 +1,7 @@ package com.linkedin.gms.factory.recommendation.candidatesource; import com.linkedin.gms.factory.search.EntitySearchServiceFactory; +import com.linkedin.metadata.models.registry.EntityRegistry; import com.linkedin.metadata.recommendation.candidatesource.DomainsCandidateSource; import com.linkedin.metadata.search.EntitySearchService; import javax.annotation.Nonnull; @@ -20,7 +21,7 @@ public class DomainsCandidateSourceFactory { @Bean(name = "domainsCandidateSource") @Nonnull - protected DomainsCandidateSource getInstance() { - return new DomainsCandidateSource(entitySearchService); + protected DomainsCandidateSource getInstance(final EntityRegistry entityRegistry) { + return new DomainsCandidateSource(entitySearchService, entityRegistry); } } diff --git a/metadata-service/factories/src/main/java/com/linkedin/gms/factory/recommendation/candidatesource/TopTagsCandidateSourceFactory.java b/metadata-service/factories/src/main/java/com/linkedin/gms/factory/recommendation/candidatesource/TopTagsCandidateSourceFactory.java index fe5c2d03d19071..bc2520c2b4617d 100644 --- a/metadata-service/factories/src/main/java/com/linkedin/gms/factory/recommendation/candidatesource/TopTagsCandidateSourceFactory.java +++ b/metadata-service/factories/src/main/java/com/linkedin/gms/factory/recommendation/candidatesource/TopTagsCandidateSourceFactory.java @@ -1,6 +1,7 @@ package com.linkedin.gms.factory.recommendation.candidatesource; import com.linkedin.gms.factory.search.EntitySearchServiceFactory; +import com.linkedin.metadata.entity.EntityService; import com.linkedin.metadata.recommendation.candidatesource.TopTagsSource; import com.linkedin.metadata.search.EntitySearchService; import javax.annotation.Nonnull; @@ -20,7 +21,7 @@ public class TopTagsCandidateSourceFactory { @Bean(name = "topTagsCandidateSource") @Nonnull - protected TopTagsSource getInstance() { - return new TopTagsSource(entitySearchService); + protected TopTagsSource getInstance(final EntityService entityService) { + return new TopTagsSource(entitySearchService, entityService); } } diff --git a/metadata-service/factories/src/main/java/com/linkedin/gms/factory/recommendation/candidatesource/TopTermsCandidateSourceFactory.java b/metadata-service/factories/src/main/java/com/linkedin/gms/factory/recommendation/candidatesource/TopTermsCandidateSourceFactory.java index 36c53936094ff5..c8ad276eb3d862 100644 --- a/metadata-service/factories/src/main/java/com/linkedin/gms/factory/recommendation/candidatesource/TopTermsCandidateSourceFactory.java +++ b/metadata-service/factories/src/main/java/com/linkedin/gms/factory/recommendation/candidatesource/TopTermsCandidateSourceFactory.java @@ -1,6 +1,7 @@ package com.linkedin.gms.factory.recommendation.candidatesource; import com.linkedin.gms.factory.search.EntitySearchServiceFactory; +import com.linkedin.metadata.entity.EntityService; import com.linkedin.metadata.recommendation.candidatesource.TopTermsSource; import com.linkedin.metadata.search.EntitySearchService; import javax.annotation.Nonnull; @@ -20,7 +21,7 @@ public class TopTermsCandidateSourceFactory { @Bean(name = "topTermsCandidateSource") @Nonnull - protected TopTermsSource getInstance() { - return new TopTermsSource(entitySearchService); + protected TopTermsSource getInstance(final EntityService entityService) { + return new TopTermsSource(entitySearchService, entityService); } } diff --git a/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/DomainsCandidateSource.java b/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/DomainsCandidateSource.java index 9392f50b4749eb..e34fa8ff1bde57 100644 --- a/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/DomainsCandidateSource.java +++ b/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/DomainsCandidateSource.java @@ -1,6 +1,7 @@ package com.linkedin.metadata.recommendation.candidatesource; import com.linkedin.common.urn.Urn; +import com.linkedin.metadata.models.registry.EntityRegistry; import com.linkedin.metadata.recommendation.RecommendationRenderType; import com.linkedin.metadata.recommendation.RecommendationRequestContext; import com.linkedin.metadata.recommendation.ScenarioType; @@ -13,8 +14,9 @@ public class DomainsCandidateSource extends EntitySearchAggregationSource { private static final String DOMAINS = "domains"; - public DomainsCandidateSource(EntitySearchService entitySearchService) { - super(entitySearchService); + public DomainsCandidateSource( + EntitySearchService entitySearchService, EntityRegistry entityRegistry) { + super(entitySearchService, entityRegistry); } @Override diff --git a/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/EntitySearchAggregationSource.java b/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/EntitySearchAggregationSource.java index a19909576d25ba..8d6ccb22660fb2 100644 --- a/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/EntitySearchAggregationSource.java +++ b/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/EntitySearchAggregationSource.java @@ -2,6 +2,8 @@ import com.google.common.collect.ImmutableList; import com.linkedin.common.urn.Urn; +import com.linkedin.metadata.models.EntitySpec; +import com.linkedin.metadata.models.registry.EntityRegistry; import com.linkedin.metadata.query.filter.Criterion; import com.linkedin.metadata.query.filter.CriterionArray; import com.linkedin.metadata.recommendation.ContentParams; @@ -10,6 +12,7 @@ import com.linkedin.metadata.recommendation.RecommendationRequestContext; import com.linkedin.metadata.recommendation.SearchParams; import com.linkedin.metadata.search.EntitySearchService; +import com.linkedin.metadata.search.utils.QueryUtils; import io.opentelemetry.extension.annotations.WithSpan; import java.net.URISyntaxException; import java.util.Collections; @@ -35,7 +38,8 @@ @Slf4j @RequiredArgsConstructor public abstract class EntitySearchAggregationSource implements RecommendationSource { - private final EntitySearchService _entitySearchService; + private final EntitySearchService entitySearchService; + private final EntityRegistry entityRegistry; /** Field to aggregate on */ protected abstract String getSearchFieldName(); @@ -69,8 +73,8 @@ protected boolean isValidCandidate(T candidate) { public List getRecommendations( @Nonnull Urn userUrn, @Nullable RecommendationRequestContext requestContext) { Map aggregationResult = - _entitySearchService.aggregateByValue( - getEntityNames(), getSearchFieldName(), null, getMaxContent()); + entitySearchService.aggregateByValue( + getEntityNames(entityRegistry), getSearchFieldName(), null, getMaxContent()); if (aggregationResult.isEmpty()) { return Collections.emptyList(); @@ -110,9 +114,11 @@ public List getRecommendations( .collect(Collectors.toList()); } - protected List getEntityNames() { + protected List getEntityNames(EntityRegistry entityRegistry) { // By default, no list is applied which means searching across entities. - return null; + return QueryUtils.getQueryByDefaultEntitySpecs(entityRegistry).stream() + .map(EntitySpec::getName) + .collect(Collectors.toList()); } // Get top K entries with the most count diff --git a/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/TopPlatformsSource.java b/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/TopPlatformsSource.java index 3012e35baa607a..aecd9bbbf769c3 100644 --- a/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/TopPlatformsSource.java +++ b/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/TopPlatformsSource.java @@ -37,11 +37,12 @@ public class TopPlatformsSource extends EntitySearchAggregationSource { Constants.CONTAINER_ENTITY_NAME, Constants.NOTEBOOK_ENTITY_NAME); - private final EntityService _entityService; + private final EntityService _entityService; private static final String PLATFORM = "platform"; - public TopPlatformsSource(EntityService entityService, EntitySearchService entitySearchService) { - super(entitySearchService); + public TopPlatformsSource( + EntityService entityService, EntitySearchService entitySearchService) { + super(entitySearchService, entityService.getEntityRegistry()); _entityService = entityService; } diff --git a/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/TopTagsSource.java b/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/TopTagsSource.java index 317f956e1ca8ab..0897d441335fac 100644 --- a/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/TopTagsSource.java +++ b/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/TopTagsSource.java @@ -1,6 +1,7 @@ package com.linkedin.metadata.recommendation.candidatesource; import com.linkedin.common.urn.Urn; +import com.linkedin.metadata.entity.EntityService; import com.linkedin.metadata.recommendation.RecommendationRenderType; import com.linkedin.metadata.recommendation.RecommendationRequestContext; import com.linkedin.metadata.recommendation.ScenarioType; @@ -13,8 +14,8 @@ public class TopTagsSource extends EntitySearchAggregationSource { private static final String TAGS = "tags"; - public TopTagsSource(EntitySearchService entitySearchService) { - super(entitySearchService); + public TopTagsSource(EntitySearchService entitySearchService, EntityService entityService) { + super(entitySearchService, entityService.getEntityRegistry()); } @Override diff --git a/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/TopTermsSource.java b/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/TopTermsSource.java index 6cdb5fdb659113..0fab9a28b51ea4 100644 --- a/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/TopTermsSource.java +++ b/metadata-service/services/src/main/java/com/linkedin/metadata/recommendation/candidatesource/TopTermsSource.java @@ -1,6 +1,7 @@ package com.linkedin.metadata.recommendation.candidatesource; import com.linkedin.common.urn.Urn; +import com.linkedin.metadata.entity.EntityService; import com.linkedin.metadata.recommendation.RecommendationRenderType; import com.linkedin.metadata.recommendation.RecommendationRequestContext; import com.linkedin.metadata.recommendation.ScenarioType; @@ -13,8 +14,8 @@ public class TopTermsSource extends EntitySearchAggregationSource { private static final String TERMS = "glossaryTerms"; - public TopTermsSource(EntitySearchService entitySearchService) { - super(entitySearchService); + public TopTermsSource(EntitySearchService entitySearchService, EntityService entityService) { + super(entitySearchService, entityService.getEntityRegistry()); } @Override diff --git a/metadata-service/services/src/main/java/com/linkedin/metadata/search/utils/QueryUtils.java b/metadata-service/services/src/main/java/com/linkedin/metadata/search/utils/QueryUtils.java index 842cc51e117775..a148a45b20e0c7 100644 --- a/metadata-service/services/src/main/java/com/linkedin/metadata/search/utils/QueryUtils.java +++ b/metadata-service/services/src/main/java/com/linkedin/metadata/search/utils/QueryUtils.java @@ -7,6 +7,10 @@ import com.linkedin.data.template.RecordTemplate; import com.linkedin.data.template.StringArray; import com.linkedin.metadata.aspect.AspectVersion; +import com.linkedin.metadata.models.EntitySpec; +import com.linkedin.metadata.models.SearchableFieldSpec; +import com.linkedin.metadata.models.annotation.SearchableAnnotation; +import com.linkedin.metadata.models.registry.EntityRegistry; import com.linkedin.metadata.query.filter.Condition; import com.linkedin.metadata.query.filter.ConjunctiveCriterion; import com.linkedin.metadata.query.filter.ConjunctiveCriterionArray; @@ -15,6 +19,7 @@ import com.linkedin.metadata.query.filter.Filter; import com.linkedin.metadata.query.filter.RelationshipDirection; import com.linkedin.metadata.query.filter.RelationshipFilter; +import com.linkedin.util.Pair; import java.util.Collections; import java.util.List; import java.util.Map; @@ -174,4 +179,20 @@ public static Filter getFilterFromCriteria(List criteria) { new ConjunctiveCriterionArray( new ConjunctiveCriterion().setAnd(new CriterionArray(criteria)))); } + + public static List getQueryByDefaultEntitySpecs(EntityRegistry entityRegistry) { + return entityRegistry.getEntitySpecs().values().stream() + .map( + spec -> + Pair.of( + spec, + spec.getSearchableFieldSpecs().stream() + .map(SearchableFieldSpec::getSearchableAnnotation) + .collect(Collectors.toList()))) + .filter( + specPair -> + specPair.getSecond().stream().anyMatch(SearchableAnnotation::isQueryByDefault)) + .map(Pair::getFirst) + .collect(Collectors.toList()); + } }