package filodb.coordinator

import scala.util.{Failure, Success, Try}

import akka.actor.{ActorRef, Address}
import com.typesafe.scalalogging.StrictLogging

import filodb.core.DatasetRef
/**
* Each FiloDB dataset is divided into a fixed number of shards for ingestion and distributed in-memory
* querying. The ShardMapper keeps track of the mapping between shards and nodes for a single dataset.
* It also keeps track of the status of each shard.
* - Given a partition hash, find the shard and node coordinator
* - Given a shard key hash and # bits, find the shards and node coordinators to query
* - Given a shard key hash and partition hash, # bits, compute the shard (for ingestion partitioning)
* - Register a node to given shard numbers
*
* It is not thread-safe for mutations (registrations), but concurrent reads are fine.
*
* Finding the shard for a given hash needs to be VERY fast; it is on the hot query and ingestion path.
*
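* A minimal usage sketch (illustrative only: the 256-shard count, coordinatorRef, and the hash
* values below are hypothetical, not values mandated by this class):
* {{{
* val mapper = new ShardMapper(256)                       // numShards must be a power of 2
* mapper.registerNode(Seq(0, 1, 2, 3), coordinatorRef)    // assign shards to a coordinator ActorRef
* val shardForRecord = mapper.ingestionShard(shardKeyHash, partKeyHash, spread = 2)
* val shardsToQuery  = mapper.queryShards(shardKeyHash, spread = 2)
* }}}
*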
* @param numShards number of shards. For this implementation, it needs to be a power of 2.
*
*/
class ShardMapper(val numShards: Int) extends Serializable {
import ShardMapper._
require((numShards & (numShards - 1)) == 0, s"numShards $numShards must be a power of two")
private final val log2NumShards = (scala.math.log10(numShards) / scala.math.log10(2)).round.toInt
private final val shardMap = Array.fill(numShards)(ActorRef.noSender)
private final val statusMap = Array.fill[ShardStatus](numShards)(ShardStatusUnassigned)
private final val log2NumShardsOneBits = (1 << log2NumShards) - 1 // a mask with the lower log2NumShards bits set to 1
// precomputed mask for shard key bits of shard for each spread value
// lower (log2NumShards-spread) bits of shard are devoted to the shard key and set to 1, rest of bits set to 0
// The spread is the array index.
private final val shardHashMask = Array.tabulate[Int](log2NumShards + 1) { i =>
(1 << (log2NumShards - i)) - 1
}
// precomputed mask for partition hash portion of the shard for each spread value
// upper (spread) bits of the shard are devoted to the partition hash to decide on final shard value
// The spread is the array index. Really it is the inverse of the shardHashMask within those bits.
private final val partHashMask = Array.tabulate[Int](log2NumShards + 1) { i =>
shardHashMask(i) ^ log2NumShardsOneBits
}
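// Illustrative values (assuming numShards = 256, so log2NumShards = 8): for spread = 2,
// shardHashMask(2) = 0x3f selects the lower 6 bits (shard key hash portion) and
// partHashMask(2) = 0xc0 selects the upper 2 bits (partition hash portion); together they
// cover all 8 shard bits.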
def copy(): ShardMapper = {
val shardMapperNew = new ShardMapper(numShards)
shardMap.copyToArray(shardMapperNew.shardMap)
statusMap.copyToArray(shardMapperNew.statusMap)
shardMapperNew
}
override def equals(other: Any): Boolean = other match {
case s: ShardMapper => s.numShards == numShards && s.shardValues == shardValues
case o: Any => false
}
override def hashCode: Int = shardValues.hashCode
override def toString: String = s"ShardMapper ${shardValues.zipWithIndex}"
def shardValues: Seq[(ActorRef, ShardStatus)] = shardMap.zip(statusMap).toBuffer
def statuses: Array[ShardStatus] = statusMap
/**
* Maps a partition hash to a shard number and a NodeCoordinator ActorRef
*/
def partitionToShardNode(partitionHash: Int): ShardAndNode = {
val shard = toShard(partitionHash, numShards) // TODO this is not right. Need to fix
ShardAndNode(shard, shardMap(shard))
}
def coordForShard(shardNum: Int): ActorRef = shardMap(shardNum)
def unassigned(shardNum: Int): Boolean = coordForShard(shardNum) == ActorRef.noSender
def statusForShard(shardNum: Int): ShardStatus = statusMap(shardNum)
def numAssignedCoords: Int = (shardMap.toSet - ActorRef.noSender).size
/**
* Use this function to identify the list of shards to query given the shard key hash.
*
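* Illustrative example (assuming numShards = 256, so log2NumShards = 8; mapper and the hash value
* are hypothetical):
* {{{
* // spread = 2: shardBase = 0x1234ABCD & 0x3f = 13, spacing = 1 << 6 = 64
* mapper.queryShards(0x1234ABCD, spread = 2)   // => shards 13, 77, 141, 205
* }}}
*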
* @param shardKeyHash This is the shard key hash, and is used to identify the shard group
* @param spread This is the 'spread' S assigned for a given appName. The data for every
* metric in the app is spread across 2^S^ shards. Example: if S=2, data
* is spread across 4 shards. If S=0, data is located in 1 shard. Bigger
* apps are assigned a bigger S and smaller apps a smaller S.
* @return The shard numbers that hold data for the given shardKeyHash
*/
def queryShards(shardKeyHash: Int, spread: Int): Seq[Int] = {
validateSpread(spread)
// lower (log2NumShards - spread) bits should go to shardKeyHash
val shardBase = shardKeyHash & shardHashMask(spread)
// create the shard for each possible partHash value portion of shard
val spacing = 1 << (log2NumShards - spread)
(shardBase until numShards by spacing)
}
private def validateSpread(spread: Int) = {
require(spread >= 0 && spread <= log2NumShards, s"Invalid spread $spread. log2NumShards is $log2NumShards")
}
/**
* Use this function to calculate the ingestion shard for a fully specified partition id.
* Ingestion code can use this function to route each incoming record to the right shard.
*
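* Illustrative example (same hypothetical 256-shard mapper and hash values as in queryShards above):
* {{{
* // spread = 2: lower 6 bits come from shardKeyHash (0x1234ABCD & 0x3f = 13),
* // upper 2 bits come from partitionHash (0x0FEDCB89 & 0xc0 = 0x80 = 128)
* mapper.ingestionShard(0x1234ABCD, 0x0FEDCB89, spread = 2)   // => 128 | 13 = 141
* }}}
*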
* @param shardKeyHash This is the shard key hash, and is used to identify the shard group
* @param partitionHash The 32-bit hash of the overall partition or time series key, containing all tags
* @param spread This is the 'spread' S assigned for a given appName. The data for every
* metric in the app is spread across 2^S^ shards. Example: if S=2, data
* is spread across 4 shards. If S=0, data is located in 1 shard. Bigger
* apps are assigned a bigger S and smaller apps a smaller S.
* @return The shard number that contains the partition for the record described by the given
* shardKeyHash and partitionHash
*/
def ingestionShard(shardKeyHash: Int, partitionHash: Int, spread: Int): Int = {
validateSpread(spread)
// Explanation for the one-liner:
// shardKeyHash forms the lower (log2NumShards - spread) bits of the shard, while partitionHash
// forms the upper (spread) bits. This way, for the same shard key, the remaining tags spread the
// data across the shard space (and thus across nodes), ensuring a more even distribution.
(shardKeyHash & shardHashMask(spread)) | (partitionHash & partHashMask(spread))
}
@deprecated(message = "Use ingestionShard() instead of this method", since = "0.7")
def hashToShard(shardHash: Int, partitionHash: Int, numShardBits: Int): Int = {
ingestionShard(shardHash, partitionHash, log2NumShards - numShardBits)
}
/**
* Returns all shards that match a given address - typically used to compare to cluster.selfAddress
* for that node's own shards
*/
def shardsForAddress(addr: Address): Seq[Int] =
shardMap.toSeq.zipWithIndex.collect {
case (ref, shardNum) if ref != ActorRef.noSender && ref.path.address == addr => shardNum
}
def shardsForCoord(coord: ActorRef): Seq[Int] =
shardMap.toSeq.zipWithIndex.collect {
case (ref, shardNum) if ref == coord => shardNum
}
def unassignShard(shard: Int): Try[Unit] = {
shardMap(shard) = ActorRef.noSender
Success(())
}
/**
* Returns all the shards that have not yet been assigned or are in the process of being assigned
*/
def unassignedShards: Seq[Int] =
shardMap.toSeq.zipWithIndex.collect { case (ActorRef.noSender, shard) => shard }
def assignedShards: Seq[Int] =
shardMap.toSeq.zipWithIndex.collect { case (ref, shard) if ref != ActorRef.noSender => shard }
def numAssignedShards: Int = numShards - unassignedShards.length
def isAnIngestionState(shard: Int): Boolean = statusMap(shard) match {
case ShardStatusStopped | ShardStatusDown => false
case _ => true
}
/**
* Find out if a shard is active (Normal or Recovery status) or filter a list of shards
*/
def activeOrRecoveringShard(shard: Int): Boolean =
statusMap(shard) == ShardStatusActive || statusMap(shard).isInstanceOf[ShardStatusRecovery]
def activeOrRecoveringShards(shards: Seq[Int]): Seq[Int] = shards.filter(activeOrRecoveringShard)
def isActiveShard(shard: Int): Boolean =
statusMap(shard) == ShardStatusActive
def activeShards(shards: Seq[Int]): Seq[Int] = shards.filter(isActiveShard)
def isHealthy(): Boolean = statusMap.forall(s => s == ShardStatusActive)
def notActiveShards(): Set[Int] = {
statusMap.zipWithIndex.filter(_._1 != ShardStatusActive).map(_._2).toSet
}
def activeShards(): Set[Int] = {
statusMap.zipWithIndex.filter(_._1 == ShardStatusActive).map(_._2).toSet
}
/**
* Returns a set of unique NodeCoordinator ActorRefs for all assigned shards
*/
def allNodes: Set[ActorRef] = shardMap.toSeq.filter(_ != ActorRef.noSender).toSet
/**
* The main API for updating a ShardMapper.
* If you want to throw if an update does not succeed, call updateFromEvent(ev).get
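* A usage sketch (datasetRef, shard 3, and coordinatorRef below are hypothetical):
* {{{
* mapper.updateFromEvent(IngestionStarted(datasetRef, 3, coordinatorRef)).get
* }}}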
*/
def updateFromEvent(event: ShardEvent): Try[Unit] = event match {
case e if e.shard < 0 || e.shard >= statusMap.length =>
Failure(ShardError(e, s"Invalid shard=${e.shard}, unable to update status."))
case ShardAssignmentStarted(_, shard, node) =>
statusMap(shard) = ShardStatusAssigned
registerNode(Seq(shard), node)
case IngestionStarted(_, shard, node) =>
statusMap(shard) = ShardStatusActive
registerNode(Seq(shard), node)
case RecoveryStarted(_, shard, node, progress) =>
statusMap(shard) = ShardStatusRecovery(progress)
registerNode(Seq(shard), node)
case RecoveryInProgress(_, shard, node, progress) =>
statusMap(shard) = ShardStatusRecovery(progress)
registerNode(Seq(shard), node)
case IngestionError(_, shard, _) =>
statusMap(shard) = ShardStatusError
unassignShard(shard)
case IngestionStopped(_, shard) =>
statusMap(shard) = ShardStatusStopped
Success(())
case ShardDown(_, shard, node) =>
statusMap(shard) = ShardStatusDown
unassignShard(shard)
case _ =>
Success(())
}
/**
* Returns the minimal set of events needed to reconstruct this ShardMapper
*/
def minimalEvents(ref: DatasetRef): Seq[ShardEvent] =
(0 until numShards).flatMap { shard =>
statusMap(shard).minimalEvents(ref, shard, shardMap(shard))
}
def mergeFrom(from: ShardMapper, ref: DatasetRef): ShardMapper = {
from.minimalEvents(ref).foreach(updateFromEvent)
this
}
/**
* Registers a new node to the given shards. Modifies state in place.
* Idempotent.
*/
def registerNode(shards: Seq[Int], coordinator: ActorRef): Try[Unit] = {
shards.foreach { shard =>
// We always override the mapping. Earlier code prevented the mapping from changing unless
// the shard was explicitly unassigned first, but functional tests uncovered that the
// member-down event is sometimes not received, so stale assignments were never removed.
shardMap(shard) = coordinator
}
Success(())
}
/**
* Removes a coordinator ref from all shards mapped to it. Resets the shards to no owner and
* returns the shards removed.
*/
private[coordinator] def removeNode(coordinator: ActorRef): Seq[Int] = {
shardMap.toSeq.zipWithIndex.collect {
case (ref, i) if ref == coordinator =>
shardMap(i) = ActorRef.noSender
i
}
}
private[coordinator] def clear(): Unit = {
for { i <- 0 until numShards } { shardMap(i) = ActorRef.noSender }
}
/**
* Gives a pretty grid-view summary of the status of each shard, plus a sorted view of shards owned by each
* coordinator.
*/
def prettyPrint: String = {
val sortedCoords = allNodes.toSeq.sorted
"Status legend: .=Unassigned N=Assigned A=Active E=Error R=Recovery S=Stopped D=Down\n----- Status Map-----\n" +
statusMap.toSeq.grouped(16).zipWithIndex.map { case (statGroup, i) =>
f" ${i * 16}%4d-${Math.min(i * 16 + 15, numShards)}%4d " +
statGroup.grouped(8).map(_.map(statusToLetter).mkString("")).mkString(" ")
}.mkString("\n") +
"\n----- Coordinators -----\n" +
sortedCoords.map { coord =>
f" $coord%40s\t${shardsForCoord(coord).mkString(", ")}"
}.mkString("\n")
}
}
private[filodb] object ShardMapper extends StrictLogging {
val default = new ShardMapper(1)
val log = logger
final case class ShardAndNode(shard: Int, coord: ActorRef)
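// Maps a 32-bit hash uniformly onto [0, numShards): treat the hash as an unsigned 32-bit value
// and scale it by numShards / 2^32 using 64-bit integer arithmetic.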
final def toShard(n: Int, numShards: Int): Int = (((n & 0xffffffffL) * numShards) >> 32).toInt
def copy(orig: ShardMapper, ref: DatasetRef): ShardMapper = {
val newMap = new ShardMapper(orig.numShards)
orig.minimalEvents(ref).foreach(newMap.updateFromEvent)
newMap
}
final case class ShardAlreadyAssigned(shard: Int, status: ShardStatus, assignedTo: ActorRef)
extends Exception(s"Shard [shard=$shard, status=$status, coordinator=$assignedTo] is already assigned.")
final case class ShardError(event: ShardEvent, context: String)
extends Exception(s"$context [shard=${event.shard}, event=$event]")
def statusToLetter(status: ShardStatus): String = status match {
case ShardStatusUnassigned => "."
case ShardStatusAssigned => "N"
case ShardStatusActive => "A"
case ShardStatusError => "E"
case s: ShardStatusRecovery => "R"
case ShardStatusStopped => "S"
case ShardStatusDown => "D"
}
}