From 0ede85e75724325c14698e7d02b056fda20c4385 Mon Sep 17 00:00:00 2001 From: chen__h Date: Thu, 15 Dec 2022 17:04:20 +0800 Subject: [PATCH] DisjunctionMaxQuery shouldn't depend on disjunct order for equals checks --- .../Xml/TestParser.cs | 22 +- .../Search/TestDisjunctionMaxQuery.cs | 13 ++ src/Lucene.Net.Tests/Search/TestMultiset.cs | 114 +++++++++++ src/Lucene.Net/Search/DisjunctionMaxQuery.cs | 61 +++--- src/Lucene.Net/Search/Multiset.cs | 191 ++++++++++++++++++ 5 files changed, 359 insertions(+), 42 deletions(-) create mode 100644 src/Lucene.Net.Tests/Search/TestMultiset.cs create mode 100644 src/Lucene.Net/Search/Multiset.cs diff --git a/src/Lucene.Net.Tests.QueryParser/Xml/TestParser.cs b/src/Lucene.Net.Tests.QueryParser/Xml/TestParser.cs index c759ddc003..bc2175cc32 100644 --- a/src/Lucene.Net.Tests.QueryParser/Xml/TestParser.cs +++ b/src/Lucene.Net.Tests.QueryParser/Xml/TestParser.cs @@ -107,13 +107,21 @@ public void TestBooleanQueryXML() public void TestDisjunctionMaxQueryXML() { Query q = Parse("DisjunctionMaxQuery.xml"); - assertTrue(q is DisjunctionMaxQuery); - DisjunctionMaxQuery d = (DisjunctionMaxQuery)q; - assertEquals(0.0f, d.TieBreakerMultiplier, 0.0001f); - assertEquals(2, d.Disjuncts.size()); - DisjunctionMaxQuery ndq = (DisjunctionMaxQuery)d.Disjuncts[1]; - assertEquals(1.2f, ndq.TieBreakerMultiplier, 0.0001f); - assertEquals(1, ndq.Disjuncts.size()); + // assertTrue(q is DisjunctionMaxQuery); + // DisjunctionMaxQuery d = (DisjunctionMaxQuery)q; + // assertEquals(0.0f, d.TieBreakerMultiplier, 0.0001f); + // assertEquals(2, d.Disjuncts.size()); + // DisjunctionMaxQuery ndq = (DisjunctionMaxQuery)d.Disjuncts[1]; + // assertEquals(1.2f, ndq.TieBreakerMultiplier, 0.0001f); + // assertEquals(1, ndq.Disjuncts.size()); + Query expected = + new DisjunctionMaxQuery( + new List{ + new TermQuery(new Term("a", "merger")), + new DisjunctionMaxQuery( + new List{new TermQuery(new Term("b", "verger"))}, 1.2f)}, + 0.0f); + assertEquals(expected, q); } [Test] diff --git a/src/Lucene.Net.Tests/Search/TestDisjunctionMaxQuery.cs b/src/Lucene.Net.Tests/Search/TestDisjunctionMaxQuery.cs index b1c6961176..31c5effe8e 100644 --- a/src/Lucene.Net.Tests/Search/TestDisjunctionMaxQuery.cs +++ b/src/Lucene.Net.Tests/Search/TestDisjunctionMaxQuery.cs @@ -2,7 +2,9 @@ using System.Globalization; using Lucene.Net.Documents; using Lucene.Net.Index.Extensions; +using Lucene.Net.Support; using NUnit.Framework; +using System.Collections.Generic; using Assert = Lucene.Net.TestFramework.Assert; using Console = Lucene.Net.Util.SystemConsole; @@ -535,6 +537,17 @@ public virtual void TestBooleanSpanQuery() directory.Dispose(); } + [Test] + public void TestDisjunctOrderAndEquals() + { + // the order that disjuncts are provided in should not matter for equals() comparisons + Query sub1 = Tq("hed", "albino"); + Query sub2 = Tq("hed", "elephant"); + Query q1 = new DisjunctionMaxQuery(new List{sub1, sub2}, 1.0f); + Query q2 = new DisjunctionMaxQuery(new List{sub2, sub1}, 1.0f); + assertEquals(q1, q2); + } + /// /// macro protected internal virtual Query Tq(string f, string t) diff --git a/src/Lucene.Net.Tests/Search/TestMultiset.cs b/src/Lucene.Net.Tests/Search/TestMultiset.cs new file mode 100644 index 0000000000..813f7c2725 --- /dev/null +++ b/src/Lucene.Net.Tests/Search/TestMultiset.cs @@ -0,0 +1,114 @@ +using Lucene.Net.Search; +using Lucene.Net.Support; +using Lucene.Net.Util; +using NUnit.Framework; +using J2N.Collections.Generic; + +namespace Lucene.Net +{ + /* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for Additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + public class TestMultiset : LuceneTestCase + { + [Test] + public void TestDuplicatesMatter() { + Multiset s1 = new Multiset(); + Multiset s2 = new Multiset(); + assertEquals(s1.size(), s2.size()); + assertEquals(s1, s2); + + s1.Add(42); + s2.Add(42); + assertEquals(s1, s2); + + s2.Add(42); + assertFalse(s1.equals(s2)); + + s1.Add(43); + s1.Add(43); + s2.Add(43); + assertEquals(s1.size(), s2.size()); + assertFalse(s1.equals(s2)); + } + + private static Dictionary ToCountMap(Multiset set) { + Dictionary map = new(); + int recomputedSize = 0; + + foreach (T element in set) { + Add(map, element); + recomputedSize += 1; + } + assertEquals(set.toString(), recomputedSize, set.size()); + return map; + } + + private static void Add(Dictionary map, T element) + { + map.TryGetValue(element, out int value); + map.Put(element, value + 1); + } + + private static void Remove(Dictionary map, T element) { + if (element is null) + { + return; + } + + map.TryGetValue(element, out int cnt); + switch (cnt) + { + case 0: + return; + case 1: + map.Remove(element); + break; + default: + map.Put((T)element, cnt - 1); + break; + } + } + + [Test] + public void TestRandom() { + Dictionary reference = new(); + Multiset multiset = new(); + int iters = AtLeast(100); + for (int i = 0; i < iters; ++i) { + int value = Random.Next(10); + switch (Random.Next(10)) { + case 0: + case 1: + case 2: + Remove(reference, value); + multiset.Remove(value); + break; + case 3: + reference.Clear(); + multiset.Clear(); + break; + default: + Add(reference, value); + multiset.Add(value); + break; + } + assertEquals(reference, ToCountMap(multiset)); + } + } + } +} \ No newline at end of file diff --git a/src/Lucene.Net/Search/DisjunctionMaxQuery.cs b/src/Lucene.Net/Search/DisjunctionMaxQuery.cs index 8ad478d27f..1e94cf9381 100644 --- a/src/Lucene.Net/Search/DisjunctionMaxQuery.cs +++ b/src/Lucene.Net/Search/DisjunctionMaxQuery.cs @@ -48,10 +48,10 @@ namespace Lucene.Net.Search /// /// Collection initializer note: To create and populate a /// in a single statement, you can use the following example as a guide: - /// + /// /// /// var disjunctionMaxQuery = new DisjunctionMaxQuery(0.1f) { - /// new TermQuery(new Term("field1", "albino")), + /// new TermQuery(new Term("field1", "albino")), /// new TermQuery(new Term("field2", "elephant")) /// }; /// @@ -61,7 +61,8 @@ public class DisjunctionMaxQuery : Query, IEnumerable /// /// The subqueries /// - private IList disjuncts = new JCG.List(); + // private IList disjuncts = new JCG.List(); + private Multiset disjuncts = new(); /// /// Multiple of the non-max disjunct scores added into our final score. Non-zero values support tie-breaking. @@ -119,7 +120,7 @@ IEnumerator IEnumerable.GetEnumerator() } /// The disjuncts. - public virtual IList Disjuncts => disjuncts; + public virtual Multiset Disjuncts => disjuncts; /// Tie breaker value for multiple matches. public virtual float TieBreakerMultiplier => tieBreakerMultiplier; @@ -244,9 +245,12 @@ public override Weight CreateWeight(IndexSearcher searcher) public override Query Rewrite(IndexReader reader) { int numDisjunctions = disjuncts.Count; + var it = disjuncts.GetEnumerator(); if (numDisjunctions == 1) { - Query singleton = disjuncts[0]; + it.MoveNext(); + Query singleton = it.Current; + it.Dispose(); Query result = singleton.Rewrite(reader); if (Boost != 1.0f) { @@ -258,28 +262,17 @@ public override Query Rewrite(IndexReader reader) } return result; } - DisjunctionMaxQuery clone = null; - for (int i = 0; i < numDisjunctions; i++) + // DisjunctionMaxQuery clone = null; + bool actuallyRewritten = false; + IList rewrittenDisjuncts = new JCG.List();; + foreach (var sub in disjuncts) { - Query clause = disjuncts[i]; - Query rewrite = clause.Rewrite(reader); - if (rewrite != clause) - { - if (clone is null) - { - clone = (DisjunctionMaxQuery)this.Clone(); - } - clone.disjuncts[i] = rewrite; - } - } - if (clone != null) - { - return clone; - } - else - { - return this; + Query rewrittenSub = sub.Rewrite(reader); + actuallyRewritten |= !rewrittenSub.Equals(sub); + rewrittenDisjuncts.Add(rewrittenSub); } + + return actuallyRewritten ? new DisjunctionMaxQuery(rewrittenDisjuncts, tieBreakerMultiplier) : this; } /// @@ -288,7 +281,7 @@ public override Query Rewrite(IndexReader reader) public override object Clone() { DisjunctionMaxQuery clone = (DisjunctionMaxQuery)base.Clone(); - clone.disjuncts = new JCG.List(this.disjuncts); + clone.disjuncts = new Multiset(this.disjuncts); return clone; } @@ -313,10 +306,8 @@ public override string ToString(string field) { StringBuilder buffer = new StringBuilder(); buffer.Append('('); - int numDisjunctions = disjuncts.Count; - for (int i = 0; i < numDisjunctions; i++) + foreach (var subquery in disjuncts) { - Query subquery = disjuncts[i]; if (subquery is BooleanQuery) // wrap sub-bools in parens { buffer.Append('('); @@ -327,11 +318,11 @@ public override string ToString(string field) { buffer.Append(subquery.ToString(field)); } - if (i != numDisjunctions - 1) - { - buffer.Append(" | "); - } + + buffer.Append(" | "); } + + buffer.Remove(buffer.Length - 3, 3); buffer.Append(')'); if (tieBreakerMultiplier != 0.0f) { @@ -368,8 +359,8 @@ public override bool Equals(object o) /// the hash code public override int GetHashCode() { - return J2N.BitConversion.SingleToInt32Bits(Boost) - + J2N.BitConversion.SingleToInt32Bits(tieBreakerMultiplier) + return J2N.BitConversion.SingleToInt32Bits(Boost) + + J2N.BitConversion.SingleToInt32Bits(tieBreakerMultiplier) + disjuncts.GetHashCode(); } } diff --git a/src/Lucene.Net/Search/Multiset.cs b/src/Lucene.Net/Search/Multiset.cs new file mode 100644 index 0000000000..dba29f4499 --- /dev/null +++ b/src/Lucene.Net/Search/Multiset.cs @@ -0,0 +1,191 @@ +using J2N.Collections.Generic.Extensions; +using Lucene.Net.Support; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Data; +using System.Linq; + +namespace Lucene.Net.Search +{ + /* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + public class Multiset : ICollection + { + private readonly J2N.Collections.Generic.Dictionary map = new(); + + private int count = 0; + + public Multiset() + { + } + + public Multiset(Multiset multiset) + { + if (multiset == null) + throw new ArgumentNullException(nameof(multiset)); + + foreach (var element in multiset) + { + Add(element); + } + } + + + + public IEnumerator GetEnumerator() => new MultiSetEnumerator(map); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + private class MultiSetEnumerator : IEnumerator + { + private int remaining; + private T current; + private IEnumerator> mapEnumerator; + + public MultiSetEnumerator(J2N.Collections.Generic.Dictionary map) + { + mapEnumerator = map.GetEnumerator(); + current = mapEnumerator.Current.Key; + remaining = mapEnumerator.Current.Value; + } + + public T Current => current; + + object IEnumerator.Current => current; + + public void Dispose() { } + + public bool MoveNext() + { + if (remaining == 0) + { + if (mapEnumerator.MoveNext()) + { + current = mapEnumerator.Current.Key; + remaining = mapEnumerator.Current.Value; + } + else + { + return false; + } + } + + if (remaining <= 0) + { + throw AssertionError.Create("Inner error: remaining count should be positive"); + } + --remaining; + return true; + } + + public void Reset() + { + mapEnumerator.Reset(); + current = mapEnumerator.Current.Key; + remaining = mapEnumerator.Current.Value; + } + + } + + public void Add(T item) + { + if (item is null) + { + return; + } + + map.TryGetValue(item, out int preValue); + map.Put(item, preValue + 1); + count += 1; + } + + public void AddRange(ICollection items) + { + if (items == null) return; + foreach (var item in items) + { + Add(item); + } + } + + public void Clear() + { + map.Clear(); + count = 0; + } + + public bool Contains(T item) + { + return item is not null && map.ContainsKey(item); + } + + + public void CopyTo(T[] array, int arrayIndex) + { + if (array.Length < arrayIndex + count) + throw new IndexOutOfRangeException("Array is too small"); + + int i = 0; + foreach (var item in this) + array[arrayIndex + i++] = item; + } + + public bool Remove(T item) + { + if (item is null) + { + return false; + } + + map.TryGetValue(item, out int cnt); + switch (cnt) + { + case 0: + return false; + case 1: + map.Remove(item); + break; + default: + map.Put((T)item, cnt - 1); + break; + } + + count -= 1; + return true; + } + + public int Count => count; + + public bool IsReadOnly => false; + + public override bool Equals(object obj) + { + if (obj == null || obj.GetType() != GetType()) { + return false; + } + + Multiset that = (Multiset) obj; + return count == that.count // not necessary but helps escaping early + && map.Equals(that.map); + } + + public override int GetHashCode() + { + return GetType().GetHashCode() + map.GetHashCode(); + } + } +} \ No newline at end of file