Skip to content

Commit

Permalink
Fix branch node upstreams using LocalEngine
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Gomez Ferrer <[email protected]>
  • Loading branch information
andresgomezfrr committed Dec 19, 2023
1 parent 8f3eaae commit 7374b47
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,33 @@ ExecutionNode compile(Node node) {
String.format("Node [%s] must be a task, branch or workflow node", node.id()));
}

private static List<String> compileUpstreamNodeIds(Node node) {
List<String> upstreamNodeIds = new ArrayList<>();
node.inputs().stream()
private static List<String> getUpstreamsFromInputs(List<Binding> inputs) {
return inputs.stream()
.map(Binding::binding)
.flatMap(ExecutionNodeCompiler::unpackBindingData)
.filter(x -> x.kind() == BindingData.Kind.PROMISE)
.map(x -> x.promise().nodeId())
.forEach(upstreamNodeIds::add);
.collect(toList());
}

static List<String> compileUpstreamNodeIds(Node node) {
List<String> upstreamNodeIds = new ArrayList<>(getUpstreamsFromInputs(node.inputs()));

if (node.branchNode() != null) {
var ifElse = node.branchNode().ifElse();
upstreamNodeIds.addAll(getUpstreamsFromInputs(ifElse.case_().thenNode().inputs()));

if (ifElse.elseNode() != null) {
upstreamNodeIds.addAll(getUpstreamsFromInputs(ifElse.elseNode().inputs()));
}

ifElse
.other()
.forEach(
other -> {
upstreamNodeIds.addAll(getUpstreamsFromInputs(other.thenNode().inputs()));
});
}

upstreamNodeIds.addAll(node.upstreamNodeIds());
if (upstreamNodeIds.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,18 @@
import java.util.Map;
import org.flyte.api.v1.Binding;
import org.flyte.api.v1.BindingData;
import org.flyte.api.v1.BooleanExpression;
import org.flyte.api.v1.BranchNode;
import org.flyte.api.v1.ComparisonExpression;
import org.flyte.api.v1.ComparisonExpression.Operator;
import org.flyte.api.v1.IfBlock;
import org.flyte.api.v1.IfElseBlock;
import org.flyte.api.v1.Literal;
import org.flyte.api.v1.Node;
import org.flyte.api.v1.Operand;
import org.flyte.api.v1.OutputReference;
import org.flyte.api.v1.PartialTaskIdentifier;
import org.flyte.api.v1.Primitive;
import org.flyte.api.v1.RetryStrategy;
import org.flyte.api.v1.RunnableTask;
import org.flyte.api.v1.TaskNode;
Expand Down Expand Up @@ -133,6 +141,131 @@ void testSort_notConnected() {
assertEquals(ImmutableList.of("node-1", "node-3", "node-4", "node-2"), getNodeIds(sorted));
}

@Test
void testcompileUpstreamNodeIds_branchNode_ifOthersElse() {
var caseNode =
createNode(
"node-1",
ImmutableList.of(START_NODE_ID),
ImmutableList.of(createOutputReferenceBinding("input", "node-1", "node-1-output")));

var elseNode =
createNode(
"node-2",
ImmutableList.of(START_NODE_ID),
ImmutableList.of(createOutputReferenceBinding("input", "node-2", "node-2-output")));

var otherNode =
createNode(
"node-3",
ImmutableList.of(START_NODE_ID),
ImmutableList.of(createOutputReferenceBinding("input", "node-3", "node-3-output")));

var falseCondition =
BooleanExpression.ofComparison(
ComparisonExpression.builder()
.leftValue(Operand.ofPrimitive(Primitive.ofBooleanValue(true)))
.operator(Operator.EQ)
.rightValue(Operand.ofPrimitive(Primitive.ofBooleanValue(false)))
.build());

var ifElse =
IfElseBlock.builder()
.case_(IfBlock.builder().condition(falseCondition).thenNode(caseNode).build())
.elseNode(elseNode)
.other(
ImmutableList.of(
IfBlock.builder().condition(falseCondition).thenNode(otherNode).build()))
.build();

Node branchNode =
Node.builder()
.id("branch-node")
.branchNode(BranchNode.builder().ifElse(ifElse).build())
.upstreamNodeIds(ImmutableList.of())
.inputs(ImmutableList.of())
.build();

assertEquals(
ImmutableList.of("node-1", "node-2", "node-3"),
ExecutionNodeCompiler.compileUpstreamNodeIds(branchNode));
}

@Test
void testcompileUpstreamNodeIds_branchNode_ifNotElse() {
var caseNode =
createNode(
"node-2",
ImmutableList.of(START_NODE_ID),
ImmutableList.of(createOutputReferenceBinding("input", "node-2", "node-2-output")));

var falseCondition =
BooleanExpression.ofComparison(
ComparisonExpression.builder()
.leftValue(Operand.ofPrimitive(Primitive.ofBooleanValue(true)))
.operator(Operator.EQ)
.rightValue(Operand.ofPrimitive(Primitive.ofBooleanValue(false)))
.build());

var ifElse =
IfElseBlock.builder()
.case_(IfBlock.builder().condition(falseCondition).thenNode(caseNode).build())
.other(ImmutableList.of())
.build();

Node branchNode =
Node.builder()
.id("branch-node")
.branchNode(BranchNode.builder().ifElse(ifElse).build())
.upstreamNodeIds(ImmutableList.of())
.inputs(ImmutableList.of())
.build();

assertEquals(
ImmutableList.of("node-2"), ExecutionNodeCompiler.compileUpstreamNodeIds(branchNode));
}

@Test
void testcompileUpstreamNodeIds_branchNode_ifElse() {
var caseNode =
createNode(
"node-2",
ImmutableList.of(START_NODE_ID),
ImmutableList.of(createOutputReferenceBinding("input", "node-1", "node-1-output")));
var elseNode =
createNode(
"node-3",
ImmutableList.of(START_NODE_ID),
ImmutableList.of(createOutputReferenceBinding("input", "node-2", "node-2-output")));

var falseCondition =
BooleanExpression.ofComparison(
ComparisonExpression.builder()
.leftValue(Operand.ofPrimitive(Primitive.ofBooleanValue(true)))
.operator(Operator.EQ)
.rightValue(Operand.ofPrimitive(Primitive.ofBooleanValue(false)))
.build());

var ifElse =
IfElseBlock.builder()
.case_(IfBlock.builder().condition(falseCondition).thenNode(caseNode).build())
.elseNode(elseNode)
.other(ImmutableList.of())
.build();

Node branchNode =
Node.builder()
.id("branch-node")
.branchNode(BranchNode.builder().ifElse(ifElse).build())
.upstreamNodeIds(ImmutableList.of())
.inputs(ImmutableList.of())
.build();

assertEquals(
ImmutableList.of("node-1", "node-2"),
ExecutionNodeCompiler.compileUpstreamNodeIds(branchNode));
}

@Test
void testCompile_unknownTask() {
Node node = createNode("node-1", ImmutableList.of(START_NODE_ID));
Expand Down Expand Up @@ -196,6 +329,16 @@ void testCompile_inputCollection() {
assertEquals(ImmutableList.of("node-1", "node-2", "node-3"), execNode.upstreamNodeIds());
}

private static Binding createOutputReferenceBinding(
String varName, String referenceNodeId, String referenceVarName) {
return Binding.builder()
.var_(varName)
.binding(
BindingData.ofOutputReference(
OutputReference.builder().nodeId(referenceNodeId).var(referenceVarName).build()))
.build();
}

private static List<String> getNodeIds(List<ExecutionNode> nodes) {
return nodes.stream().map(ExecutionNode::nodeId).collect(toList());
}
Expand All @@ -211,14 +354,19 @@ private static ExecutionNode createExecutionNode(String nodeId, List<String> ups
}

private static Node createNode(String nodeId, List<String> upstreamNodeIds) {
return createNode(nodeId, upstreamNodeIds, ImmutableList.of());
}

private static Node createNode(
String nodeId, List<String> upstreamNodeIds, List<Binding> inputs) {
return Node.builder()
.id(nodeId)
.taskNode(
TaskNode.builder()
.referenceId(PartialTaskIdentifier.builder().name("unknownTask").build())
.build())
.upstreamNodeIds(upstreamNodeIds)
.inputs(ImmutableList.of())
.inputs(inputs)
.build();
}

Expand Down

0 comments on commit 7374b47

Please sign in to comment.