diff --git a/core/src/main/scala/filodb.core/downsample/DownsampledTimeSeriesShard.scala b/core/src/main/scala/filodb.core/downsample/DownsampledTimeSeriesShard.scala
index 66174a3582..92acb548d9 100644
--- a/core/src/main/scala/filodb.core/downsample/DownsampledTimeSeriesShard.scala
+++ b/core/src/main/scala/filodb.core/downsample/DownsampledTimeSeriesShard.scala
@@ -169,7 +169,7 @@ class DownsampledTimeSeriesShard(rawDatasetRef: DatasetRef,
                            endTime: Long,
                            startTime: Long,
                            limit: Int): Iterator[Map[ZeroCopyUTF8String, ZeroCopyUTF8String]] = {
-    partKeyIndex.partKeyRecordsFromFilters(filter, startTime, endTime).iterator.take(limit).map { pk =>
+    partKeyIndex.partKeyRecordsFromFilters(filter, startTime, endTime, limit).iterator.map { pk =>
       val partKey = PartKeyWithTimes(pk.partKey, UnsafeUtils.arayOffset, pk.startTime, pk.endTime)
       schemas.part.binSchema.toStringPairs(partKey.base, partKey.offset).map(pair => {
         pair._1.utf8 -> pair._2.utf8
diff --git a/core/src/main/scala/filodb.core/memstore/PartKeyLuceneIndex.scala b/core/src/main/scala/filodb.core/memstore/PartKeyLuceneIndex.scala
index 2e129664d3..c9eb673428 100644
--- a/core/src/main/scala/filodb.core/memstore/PartKeyLuceneIndex.scala
+++ b/core/src/main/scala/filodb.core/memstore/PartKeyLuceneIndex.scala
@@ -429,7 +429,7 @@ class PartKeyLuceneIndex(ref: DatasetRef,
    * @return matching partIds
    */
   def partIdsEndedBefore(endedBefore: Long): debox.Buffer[Int] = {
-    val collector = new PartIdCollector()
+    val collector = new PartIdCollector(Int.MaxValue)
     val deleteQuery = LongPoint.newRangeQuery(PartKeyLuceneIndex.END_TIME, 0, endedBefore)
     withNewSearcher(s => s.search(deleteQuery, collector))
@@ -900,8 +900,9 @@ class PartKeyLuceneIndex(ref: DatasetRef,
   //scalastyle:on method.length
   def partIdsFromFilters(columnFilters: Seq[ColumnFilter],
                          startTime: Long,
-                         endTime: Long): debox.Buffer[Int] = {
-    val collector = new PartIdCollector() // passing zero for unlimited results
+                         endTime: Long,
+                         limit: Int = Int.MaxValue): debox.Buffer[Int] = {
+    val collector = new PartIdCollector(limit)
     searchFromFilters(columnFilters, startTime, endTime, collector)
     collector.result
   }
@@ -910,7 +911,7 @@ class PartKeyLuceneIndex(ref: DatasetRef,
                          startTime: Long,
                          endTime: Long): Option[Array[Byte]] = {
-    val collector = new SinglePartKeyCollector() // passing zero for unlimited results
+    val collector = new SinglePartKeyCollector
     searchFromFilters(columnFilters, startTime, endTime, collector)
     val pkBytesRef = collector.singleResult
     if (pkBytesRef == null)
@@ -929,8 +930,9 @@ class PartKeyLuceneIndex(ref: DatasetRef,
   def partKeyRecordsFromFilters(columnFilters: Seq[ColumnFilter],
                                 startTime: Long,
-                                endTime: Long): Seq[PartKeyLuceneIndexRecord] = {
-    val collector = new PartKeyRecordCollector()
+                                endTime: Long,
+                                limit: Int = Int.MaxValue): Seq[PartKeyLuceneIndexRecord] = {
+    val collector = new PartKeyRecordCollector(limit)
     searchFromFilters(columnFilters, startTime, endTime, collector)
     collector.records
   }
@@ -1175,7 +1177,7 @@ class TopKPartIdsCollector(limit: Int) extends Collector with StrictLogging {
   }
 }
 
-class PartIdCollector extends SimpleCollector {
+class PartIdCollector(limit: Int) extends SimpleCollector {
   val result: debox.Buffer[Int] = debox.Buffer.empty[Int]
   private var partIdDv: NumericDocValues = _
@@ -1187,7 +1189,9 @@ class PartIdCollector(limit: Int) extends SimpleCollector {
   }
 
   override def collect(doc: Int): Unit = {
-    if (partIdDv.advanceExact(doc)) {
+    if (result.length >= limit) {
+      throw new CollectionTerminatedException
+    } else if (partIdDv.advanceExact(doc)) {
       result += partIdDv.longValue().toInt
     } else {
       throw new IllegalStateException("This shouldn't happen since every document should have a partIdDv")
@@ -1219,7 +1223,7 @@ class PartIdStartTimeCollector extends SimpleCollector {
   }
 }
 
-class PartKeyRecordCollector extends SimpleCollector {
+class PartKeyRecordCollector(limit: Int) extends SimpleCollector {
   val records = new ArrayBuffer[PartKeyLuceneIndexRecord]
   private var partKeyDv: BinaryDocValues = _
   private var startTimeDv: NumericDocValues = _
@@ -1234,7 +1238,9 @@ class PartKeyRecordCollector(limit: Int) extends SimpleCollector {
   }
 
   override def collect(doc: Int): Unit = {
-    if (partKeyDv.advanceExact(doc) && startTimeDv.advanceExact(doc) && endTimeDv.advanceExact(doc)) {
+    if (records.size >= limit) {
+      throw new CollectionTerminatedException
+    } else if (partKeyDv.advanceExact(doc) && startTimeDv.advanceExact(doc) && endTimeDv.advanceExact(doc)) {
       val pkBytesRef = partKeyDv.binaryValue()
       // Gotcha! make copy of array because lucene reuses bytesRef for next result
       val pkBytes = util.Arrays.copyOfRange(pkBytesRef.bytes, pkBytesRef.offset, pkBytesRef.offset + pkBytesRef.length)
diff --git a/core/src/main/scala/filodb.core/memstore/TimeSeriesShard.scala b/core/src/main/scala/filodb.core/memstore/TimeSeriesShard.scala
index 4293fef734..67c32906dd 100644
--- a/core/src/main/scala/filodb.core/memstore/TimeSeriesShard.scala
+++ b/core/src/main/scala/filodb.core/memstore/TimeSeriesShard.scala
@@ -1819,21 +1819,21 @@ class TimeSeriesShard(val ref: DatasetRef,
                           startTime: Long,
                           limit: Int): Iterator[Map[ZeroCopyUTF8String, ZeroCopyUTF8String]] = {
     if (fetchFirstLastSampleTimes) {
-      partKeyIndex.partKeyRecordsFromFilters(filter, startTime, endTime).iterator.map { pk =>
+      partKeyIndex.partKeyRecordsFromFilters(filter, startTime, endTime, limit).iterator.map { pk =>
         val partKeyMap = convertPartKeyWithTimesToMap(
           PartKeyWithTimes(pk.partKey, UnsafeUtils.arayOffset, pk.startTime, pk.endTime))
         partKeyMap ++ Map(
           ("_firstSampleTime_".utf8, pk.startTime.toString.utf8),
           ("_lastSampleTime_".utf8, pk.endTime.toString.utf8))
-      } take(limit)
+      }
     } else {
-      val partIds = partKeyIndex.partIdsFromFilters(filter, startTime, endTime)
+      val partIds = partKeyIndex.partIdsFromFilters(filter, startTime, endTime, limit)
       val inMem = InMemPartitionIterator2(partIds)
       val inMemPartKeys = inMem.map { p =>
         convertPartKeyWithTimesToMap(PartKeyWithTimes(p.partKeyBase, p.partKeyOffset, -1, -1))}
       val skippedPartKeys = inMem.skippedPartIDs.iterator().map(partId => {
         convertPartKeyWithTimesToMap(partKeyFromPartId(partId))})
-      (inMemPartKeys ++ skippedPartKeys).take(limit)
+      (inMemPartKeys ++ skippedPartKeys)
     }
   }
 
diff --git a/core/src/test/scala/filodb.core/memstore/PartKeyLuceneIndexSpec.scala b/core/src/test/scala/filodb.core/memstore/PartKeyLuceneIndexSpec.scala
index e776700b3f..5438099722 100644
--- a/core/src/test/scala/filodb.core/memstore/PartKeyLuceneIndexSpec.scala
+++ b/core/src/test/scala/filodb.core/memstore/PartKeyLuceneIndexSpec.scala
@@ -65,10 +65,18 @@ class PartKeyLuceneIndexSpec extends AnyFunSpec with Matchers with BeforeAndAfte
     val partNums1 = keyIndex.partIdsFromFilters(Nil, start, end)
     partNums1 shouldEqual debox.Buffer(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
 
+    // return only 4 partIds - empty filter
+    val partNumsLimit = keyIndex.partIdsFromFilters(Nil, start, end, 4)
+    partNumsLimit shouldEqual debox.Buffer(0, 1, 2, 3)
+
     val filter2 = ColumnFilter("Actor2Code", Equals("GOV".utf8))
     val partNums2 = keyIndex.partIdsFromFilters(Seq(filter2), start, end)
     partNums2 shouldEqual debox.Buffer(7, 8, 9)
 
+    // return only 2 partIds - with filter
+    val partNumsLimitFilter = keyIndex.partIdsFromFilters(Seq(filter2), start, end, 2)
+    partNumsLimitFilter shouldEqual debox.Buffer(7, 8)
+
     val filter3 = ColumnFilter("Actor2Name", Equals("REGIME".utf8))
     val partNums3 = keyIndex.partIdsFromFilters(Seq(filter3), start, end)
     partNums3 shouldEqual debox.Buffer(8, 9)
@@ -109,6 +117,24 @@ class PartKeyLuceneIndexSpec extends AnyFunSpec with Matchers with BeforeAndAfte
     result.map( p => (p.startTime, p.endTime)) shouldEqual expected.map( p => (p.startTime, p.endTime))
   }
 
+  it("should fetch only two part key records from filters") {
+    // Add the first ten keys and row numbers
+    val pkrs = partKeyFromRecords(dataset6, records(dataset6, readers.take(10)), Some(partBuilder))
+      .zipWithIndex.map { case (addr, i) =>
+        val pk = partKeyOnHeap(dataset6.partKeySchema, ZeroPointer, addr)
+        keyIndex.addPartKey(pk, i, i, i + 10)()
+        PartKeyLuceneIndexRecord(pk, i, i + 10)
+      }
+    keyIndex.refreshReadersBlocking()
+
+    val filter2 = ColumnFilter("Actor2Code", Equals("GOV".utf8))
+    val result = keyIndex.partKeyRecordsFromFilters(Seq(filter2), 0, Long.MaxValue, 2)
+    val expected = Seq(pkrs(7), pkrs(8))
+
+    result.map(_.partKey.toSeq) shouldEqual expected.map(_.partKey.toSeq)
+    result.map(p => (p.startTime, p.endTime)) shouldEqual expected.map(p => (p.startTime, p.endTime))
+  }
+
   it("should fetch part key iterator records from filters correctly") {
     // Add the first ten keys and row numbers
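
Note on the collector changes above: once `limit` hits have been buffered, the collectors throw Lucene's CollectionTerminatedException, which IndexSearcher catches to stop collecting further documents from the current segment. The index query therefore never materializes more than `limit` matches, which is why the shard-level callers can drop their `.take(limit)` calls. Below is a minimal, self-contained sketch of the same early-termination pattern; it is an illustration only, not FiloDB code, and the class name LimitedDocIdCollector plus the Lucene 8.x-style scoreMode() API are assumptions.

import scala.collection.mutable.ArrayBuffer

import org.apache.lucene.index.LeafReaderContext
import org.apache.lucene.search.{CollectionTerminatedException, ScoreMode, SimpleCollector}

// Hypothetical collector showing the limit / early-termination pattern used by
// PartIdCollector and PartKeyRecordCollector in the diff above.
class LimitedDocIdCollector(limit: Int) extends SimpleCollector {
  val docIds = new ArrayBuffer[Int]()
  private var docBase = 0

  // Doc ids are collected without scoring.
  override def scoreMode(): ScoreMode = ScoreMode.COMPLETE_NO_SCORES

  override def doSetNextReader(context: LeafReaderContext): Unit = {
    docBase = context.docBase // remember the segment's doc id offset
  }

  override def collect(doc: Int): Unit = {
    if (docIds.size >= limit) {
      // IndexSearcher catches this and stops feeding docs from the current
      // segment; any later segment hits it again on its first doc, so the
      // result stays capped at `limit`.
      throw new CollectionTerminatedException
    }
    docIds += (docBase + doc)
  }
}

// usage (hypothetical): val c = new LimitedDocIdCollector(100); searcher.search(query, c); c.docIds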