Skip to content

Commit

Permalink
YARN-11684. Fix general contract violation in PriorityQueueComparator. (
Browse files Browse the repository at this point in the history
apache#6725) Contributed by Tamas Domok.

Signed-off-by: Shilun Fan <[email protected]>
  • Loading branch information
tomicooler authored Apr 19, 2024
1 parent e8b2c28 commit a386ac1
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.apache.hadoop.classification.VisibleForTesting;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.yarn.api.records.Priority;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.server.resourcemanager.nodelabels
.RMNodeLabelsManager;
Expand All @@ -32,7 +33,6 @@
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
Expand All @@ -54,17 +54,7 @@
public class PriorityUtilizationQueueOrderingPolicy
implements QueueOrderingPolicy {
private List<CSQueue> queues;
private boolean respectPriority;

// This makes multiple threads can sort queues at the same time
// For different partitions.
private static ThreadLocal<String> partitionToLookAt =
ThreadLocal.withInitial(new Supplier<String>() {
@Override
public String get() {
return RMNodeLabelsManager.NO_LABEL;
}
});
private final boolean respectPriority;

/**
* Compare two queues with possibly different priority and assigned capacity,
Expand Down Expand Up @@ -101,15 +91,21 @@ public static int compare(double relativeAssigned1, double relativeAssigned2,
/**
* Comparator that both looks at priority and utilization
*/
private class PriorityQueueComparator
final private class PriorityQueueComparator
implements Comparator<PriorityQueueResourcesForSorting> {

final private String partition;

private PriorityQueueComparator(String partition) {
this.partition = partition;
}

@Override
public int compare(PriorityQueueResourcesForSorting q1Sort,
PriorityQueueResourcesForSorting q2Sort) {
String p = partitionToLookAt.get();

int rc = compareQueueAccessToPartition(q1Sort.queue, q2Sort.queue, p);
int rc = compareQueueAccessToPartition(
q1Sort.nodeLabelAccessible,
q2Sort.nodeLabelAccessible);
if (0 != rc) {
return rc;
}
Expand All @@ -133,17 +129,17 @@ public int compare(PriorityQueueResourcesForSorting q1Sort,
float used2 = q2Sort.absoluteUsedCapacity;

return compare(q1Sort, q2Sort, used1, used2,
q1Sort.queue.getPriority().
getPriority(), q2Sort.queue.getPriority().getPriority());
q1Sort.priority.
getPriority(), q2Sort.priority.getPriority());
} else{
// both q1 has positive abs capacity and q2 has positive abs
// capacity
float used1 = q1Sort.usedCapacity;
float used2 = q2Sort.usedCapacity;

return compare(q1Sort, q2Sort, used1, used2,
q1Sort.queue.getPriority().getPriority(),
q2Sort.queue.getPriority().getPriority());
q1Sort.priority.getPriority(),
q2Sort.priority.getPriority());
}
}

Expand Down Expand Up @@ -181,8 +177,7 @@ private int compare(PriorityQueueResourcesForSorting q1Sort,
return rc;
}

private int compareQueueAccessToPartition(CSQueue q1, CSQueue q2,
String partition) {
private int compareQueueAccessToPartition(boolean q1Accessible, boolean q2Accessible) {
// Everybody has access to default partition
if (StringUtils.equals(partition, RMNodeLabelsManager.NO_LABEL)) {
return 0;
Expand All @@ -192,14 +187,6 @@ private int compareQueueAccessToPartition(CSQueue q1, CSQueue q2,
* Check accessible to given partition, if one queue accessible and
* the other not, accessible queue goes first.
*/
boolean q1Accessible =
q1.getAccessibleNodeLabels() != null && q1.getAccessibleNodeLabels()
.contains(partition) || q1.getAccessibleNodeLabels().contains(
RMNodeLabelsManager.ANY);
boolean q2Accessible =
q2.getAccessibleNodeLabels() != null && q2.getAccessibleNodeLabels()
.contains(partition) || q2.getAccessibleNodeLabels().contains(
RMNodeLabelsManager.ANY);
if (q1Accessible && !q2Accessible) {
return -1;
} else if (!q1Accessible && q2Accessible) {
Expand All @@ -218,22 +205,32 @@ public static class PriorityQueueResourcesForSorting {
private final float usedCapacity;
private final Resource configuredMinResource;
private final float absoluteCapacity;
private final Priority priority;
private final boolean nodeLabelAccessible;
private final CSQueue queue;

PriorityQueueResourcesForSorting(CSQueue queue) {
PriorityQueueResourcesForSorting(CSQueue queue, String partition) {
this.queue = queue;
this.absoluteUsedCapacity =
queue.getQueueCapacities().
getAbsoluteUsedCapacity(partitionToLookAt.get());
getAbsoluteUsedCapacity(partition);
this.usedCapacity =
queue.getQueueCapacities().
getUsedCapacity(partitionToLookAt.get());
getUsedCapacity(partition);
this.absoluteCapacity =
queue.getQueueCapacities().
getAbsoluteCapacity(partitionToLookAt.get());
getAbsoluteCapacity(partition);
this.configuredMinResource =
queue.getQueueResourceQuotas().
getConfiguredMinResource(partitionToLookAt.get());
getConfiguredMinResource(partition);
this.priority = queue.getPriority();
this.nodeLabelAccessible = queue.getAccessibleNodeLabels() != null &&
queue.getAccessibleNodeLabels().contains(partition) ||
queue.getAccessibleNodeLabels().contains(RMNodeLabelsManager.ANY);
}

static PriorityQueueResourcesForSorting create(CSQueue queue, String partition) {
return new PriorityQueueResourcesForSorting(queue, partition);
}

public CSQueue getQueue() {
Expand All @@ -252,14 +249,13 @@ public void setQueues(List<CSQueue> queues) {

@Override
public Iterator<CSQueue> getAssignmentIterator(String partition) {
// partitionToLookAt is a thread local variable, therefore it is safe to mutate it.
PriorityUtilizationQueueOrderingPolicy.partitionToLookAt.set(partition);

// Copy (for thread safety) and sort the snapshot of the queues in order to avoid breaking
// the prerequisites of TimSort. See YARN-10178 for details.
return new ArrayList<>(queues).stream().map(PriorityQueueResourcesForSorting::new).sorted(
new PriorityQueueComparator()).map(PriorityQueueResourcesForSorting::getQueue).collect(
Collectors.toList()).iterator();
return new ArrayList<>(queues).stream()
.map(queue -> PriorityQueueResourcesForSorting.create(queue, partition))
.sorted(new PriorityQueueComparator(partition))
.map(PriorityQueueResourcesForSorting::getQueue)
.collect(Collectors.toList()).iterator();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,21 @@
import org.apache.hadoop.thirdparty.com.google.common.collect.ImmutableSet;

import org.apache.hadoop.yarn.api.records.Priority;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.QueueResourceQuotas;
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.capacity.CSQueue;
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.capacity.QueueCapacities;
import org.junit.Assert;
import org.junit.Test;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -250,4 +255,90 @@ public void testPriorityUtilizationOrdering() {
verifyOrder(policy, "x", new String[] { "e", "c", "d", "b", "a" });

}

@Test
public void testComparatorDoesNotValidateGeneralContract() {
final String[] nodeLabels = {"x", "y", "z"};
PriorityUtilizationQueueOrderingPolicy policy =
new PriorityUtilizationQueueOrderingPolicy(true);

final String partition = nodeLabels[randInt(0, nodeLabels.length - 1)];
List<CSQueue> list = new ArrayList<>();
for (int i = 0; i < 1000; i++) {
CSQueue q = mock(CSQueue.class);
when(q.getQueuePath()).thenReturn(String.format("%d", i));

// simulating change in queueCapacities
when(q.getQueueCapacities())
.thenReturn(randomQueueCapacities(partition))
.thenReturn(randomQueueCapacities(partition))
.thenReturn(randomQueueCapacities(partition))
.thenReturn(randomQueueCapacities(partition))
.thenReturn(randomQueueCapacities(partition));

// simulating change in the priority
when(q.getPriority())
.thenReturn(Priority.newInstance(randInt(0, 10)))
.thenReturn(Priority.newInstance(randInt(0, 10)))
.thenReturn(Priority.newInstance(randInt(0, 10)))
.thenReturn(Priority.newInstance(randInt(0, 10)))
.thenReturn(Priority.newInstance(randInt(0, 10)));

if (randInt(0, nodeLabels.length) == 1) {
// simulating change in nodeLabels
when(q.getAccessibleNodeLabels())
.thenReturn(randomNodeLabels(nodeLabels))
.thenReturn(randomNodeLabels(nodeLabels))
.thenReturn(randomNodeLabels(nodeLabels))
.thenReturn(randomNodeLabels(nodeLabels))
.thenReturn(randomNodeLabels(nodeLabels));
}

// simulating change in configuredMinResource
when(q.getQueueResourceQuotas())
.thenReturn(randomResourceQuotas(partition))
.thenReturn(randomResourceQuotas(partition))
.thenReturn(randomResourceQuotas(partition))
.thenReturn(randomResourceQuotas(partition))
.thenReturn(randomResourceQuotas(partition));
list.add(q);
}

policy.setQueues(list);
// java.lang.IllegalArgumentException: Comparison method violates its general contract!
assertDoesNotThrow(() -> policy.getAssignmentIterator(partition));
}

private QueueCapacities randomQueueCapacities(String partition) {
QueueCapacities qc = new QueueCapacities(false);
qc.setAbsoluteCapacity(partition, (float) randFloat(0.0d, 100.0d));
qc.setUsedCapacity(partition, (float) randFloat(0.0d, 100.0d));
qc.setAbsoluteUsedCapacity(partition, (float) randFloat(0.0d, 100.0d));
return qc;
}

private Set<String> randomNodeLabels(String[] availableNodeLabels) {
Set<String> nodeLabels = new HashSet<>();
for (String label : availableNodeLabels) {
if (randInt(0, 1) == 1) {
nodeLabels.add(label);
}
}
return nodeLabels;
}

private QueueResourceQuotas randomResourceQuotas(String partition) {
QueueResourceQuotas qr = new QueueResourceQuotas();
qr.setConfiguredMinResource(partition,
Resource.newInstance(randInt(1, 10) * 1024, randInt(1, 10)));
return qr;
}

private static double randFloat(double min, double max) {
return min + ThreadLocalRandom.current().nextFloat() * (max - min);
}

private static int randInt(int min, int max) {
return ThreadLocalRandom.current().nextInt(min, max + 1);
}
}

0 comments on commit a386ac1

Please sign in to comment.