Skip to content

Commit

Permalink
Fetch task template
Browse files Browse the repository at this point in the history
Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix committed Sep 29, 2023
1 parent 03b87f2 commit f2eb4b7
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import flyteidl.admin.LaunchPlanOuterClass;
import flyteidl.admin.TaskOuterClass;
import flyteidl.admin.WorkflowOuterClass;
import flyteidl.core.IdentifierOuterClass;
import flyteidl.service.AdminServiceGrpc;
import io.grpc.Channel;
import io.grpc.ClientInterceptor;
Expand Down Expand Up @@ -185,34 +184,38 @@ public TaskIdentifier fetchLatestTaskId(NamedEntityIdentifier taskId) {
return fetchLatestResource(
taskId,
request -> stub.listTasks(request).getTasksList(),
TaskOuterClass.Task::getId,
ProtoUtil::deserializeTaskId);
task -> ProtoUtil.deserializeTaskId(task.getId()));
}

@Nullable
public TaskTemplate fetchLatestTaskTemplate(NamedEntityIdentifier taskId) {
return fetchLatestResource(
taskId,
request -> stub.listTasks(request).getTasksList(),
task -> ProtoUtil.deserialize(task.getClosure().getCompiledTask().getTemplate()));
}

@Nullable
public WorkflowIdentifier fetchLatestWorkflowId(NamedEntityIdentifier workflowId) {
return fetchLatestResource(
workflowId,
request -> stub.listWorkflows(request).getWorkflowsList(),
WorkflowOuterClass.Workflow::getId,
ProtoUtil::deserializeWorkflowId);
workflow -> ProtoUtil.deserializeWorkflowId(workflow.getId()));
}

@Nullable
public LaunchPlanIdentifier fetchLatestLaunchPlanId(NamedEntityIdentifier launchPlanId) {
return fetchLatestResource(
launchPlanId,
request -> stub.listLaunchPlans(request).getLaunchPlansList(),
LaunchPlanOuterClass.LaunchPlan::getId,
ProtoUtil::deserializeLaunchPlanId);
launchPlan -> ProtoUtil.deserializeLaunchPlanId(launchPlan.getId()));
}

@Nullable
private <T, RespT> T fetchLatestResource(
NamedEntityIdentifier nameId,
Function<ResourceListRequest, List<RespT>> performRequestFn,
Function<RespT, IdentifierOuterClass.Identifier> extractIdFn,
Function<IdentifierOuterClass.Identifier, T> deserializeFn) {
Function<RespT, T> deserializeFn) {
ResourceListRequest request =
ResourceListRequest.newBuilder()
.setLimit(1)
Expand All @@ -230,8 +233,7 @@ private <T, RespT> T fetchLatestResource(
return null;
}

IdentifierOuterClass.Identifier id = extractIdFn.apply(list.get(0));
return deserializeFn.apply(id);
return deserializeFn.apply(list.get(0));
}

private <T> void idempotentCreate(String label, Object id, GrpcRetries.Retryable<T> retryable) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.flyte.api.v1.Container;
Expand Down Expand Up @@ -295,8 +296,8 @@ static void checkCycles(Map<WorkflowIdentifier, WorkflowTemplate> allWorkflows)
checkCycles(
workflowId,
allWorkflows,
/*beingVisited=*/ new HashSet<>(),
/*visited=*/ new HashSet<>()))
/* beingVisited= */ new HashSet<>(),
/* visited= */ new HashSet<>()))
.findFirst();
if (cycle.isPresent()) {
throw new IllegalArgumentException(
Expand Down Expand Up @@ -374,8 +375,10 @@ public static Map<WorkflowIdentifier, WorkflowTemplate> collectSubWorkflows(
.collect(toUnmodifiableMap());
}

public static Map<TaskIdentifier, TaskTemplate> collectTasks(
List<Node> rewrittenNodes, Map<TaskIdentifier, TaskTemplate> allTasks) {
public static Map<TaskIdentifier, TaskTemplate> collectDynamicWorkflowTasks(
List<Node> rewrittenNodes,
Map<TaskIdentifier, TaskTemplate> allTasks,
Function<TaskIdentifier, TaskTemplate> remoteTaskTemplateFetcher) {
return collectTaskIds(rewrittenNodes).stream()
// all identifiers should be rewritten at this point
.map(
Expand All @@ -389,7 +392,9 @@ public static Map<TaskIdentifier, TaskTemplate> collectTasks(
.distinct()
.map(
taskId -> {
TaskTemplate taskTemplate = allTasks.get(taskId);
TaskTemplate taskTemplate =
Optional.ofNullable(allTasks.get(taskId))
.orElseGet(() -> remoteTaskTemplateFetcher.apply(taskId));

if (taskTemplate == null) {
throw new NoSuchElementException("Can't find referenced task " + taskId);
Expand Down
62 changes: 62 additions & 0 deletions jflyte-utils/src/test/java/org/flyte/jflyte/utils/Fixtures.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright 2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.jflyte.utils;

import static java.util.Collections.emptyMap;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.flyte.api.v1.Container;
import org.flyte.api.v1.KeyValuePair;
import org.flyte.api.v1.RetryStrategy;
import org.flyte.api.v1.SimpleType;
import org.flyte.api.v1.Struct;
import org.flyte.api.v1.TaskTemplate;
import org.flyte.api.v1.TypedInterface;

final class Fixtures {
static final String IMAGE_NAME = "alpine:latest";
static final String COMMAND = "date";

static final Container CONTAINER =
Container.builder()
.command(ImmutableList.of(COMMAND))
.args(ImmutableList.of())
.image(IMAGE_NAME)
.env(ImmutableList.of(KeyValuePair.of("key", "value")))
.build();
static final TypedInterface INTERFACE_ =
TypedInterface.builder()
.inputs(ImmutableMap.of("x", ApiUtils.createVar(SimpleType.STRING)))
.outputs(ImmutableMap.of("y", ApiUtils.createVar(SimpleType.INTEGER)))
.build();
static final RetryStrategy RETRIES = RetryStrategy.builder().retries(4).build();
static final TaskTemplate TASK_TEMPLATE =
TaskTemplate.builder()
.container(CONTAINER)
.type("custom-task")
.interface_(INTERFACE_)
.custom(Struct.of(emptyMap()))
.retries(RETRIES)
.discoverable(false)
.cacheSerializable(false)
.build();

private Fixtures() {
throw new UnsupportedOperationException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
package org.flyte.jflyte.utils;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static org.flyte.jflyte.utils.Fixtures.COMMAND;
import static org.flyte.jflyte.utils.Fixtures.IMAGE_NAME;
import static org.flyte.jflyte.utils.Fixtures.TASK_TEMPLATE;
import static org.flyte.jflyte.utils.FlyteAdminClient.TRIGGERING_PRINCIPAL;
import static org.flyte.jflyte.utils.FlyteAdminClient.USER_TRIGGERED_EXECUTION_NESTING;
import static org.hamcrest.MatcherAssert.assertThat;
Expand All @@ -32,7 +34,10 @@
import flyteidl.admin.LaunchPlanOuterClass;
import flyteidl.admin.ScheduleOuterClass;
import flyteidl.admin.TaskOuterClass;
import flyteidl.admin.TaskOuterClass.Task;
import flyteidl.admin.TaskOuterClass.TaskClosure;
import flyteidl.admin.WorkflowOuterClass;
import flyteidl.core.Compiler.CompiledTask;
import flyteidl.core.IdentifierOuterClass;
import flyteidl.core.IdentifierOuterClass.ResourceType;
import flyteidl.core.Interface;
Expand All @@ -51,9 +56,7 @@
import java.util.Collections;
import org.flyte.api.v1.Binding;
import org.flyte.api.v1.BindingData;
import org.flyte.api.v1.Container;
import org.flyte.api.v1.CronSchedule;
import org.flyte.api.v1.KeyValuePair;
import org.flyte.api.v1.LaunchPlan;
import org.flyte.api.v1.LaunchPlanIdentifier;
import org.flyte.api.v1.Literal;
Expand All @@ -64,10 +67,8 @@
import org.flyte.api.v1.PartialTaskIdentifier;
import org.flyte.api.v1.PartialWorkflowIdentifier;
import org.flyte.api.v1.Primitive;
import org.flyte.api.v1.RetryStrategy;
import org.flyte.api.v1.Scalar;
import org.flyte.api.v1.SimpleType;
import org.flyte.api.v1.Struct;
import org.flyte.api.v1.TaskIdentifier;
import org.flyte.api.v1.TaskNode;
import org.flyte.api.v1.TaskTemplate;
Expand Down Expand Up @@ -95,8 +96,6 @@ public class FlyteAdminClientTest {
private static final String WF_NAME = "workflow-foo";
private static final String WF_VERSION = "version-wf-foo";
private static final String WF_OLD_VERSION = "version-wf-bar";
private static final String IMAGE_NAME = "alpine:latest";
private static final String COMMAND = "date";

private FlyteAdminClient client;
private TestAdminService stubService;
Expand Down Expand Up @@ -138,33 +137,7 @@ public void shouldPropagateCreateTaskToStub() {
.version(TASK_VERSION)
.build();

TypedInterface interface_ =
TypedInterface.builder()
.inputs(ImmutableMap.of("x", ApiUtils.createVar(SimpleType.STRING)))
.outputs(ImmutableMap.of("y", ApiUtils.createVar(SimpleType.INTEGER)))
.build();

Container container =
Container.builder()
.command(ImmutableList.of(COMMAND))
.args(ImmutableList.of())
.image(IMAGE_NAME)
.env(ImmutableList.of(KeyValuePair.of("key", "value")))
.build();

RetryStrategy retries = RetryStrategy.builder().retries(4).build();
TaskTemplate template =
TaskTemplate.builder()
.container(container)
.type("custom-task")
.interface_(interface_)
.custom(Struct.of(emptyMap()))
.retries(retries)
.discoverable(false)
.cacheSerializable(false)
.build();

client.createTask(identifier, template);
client.createTask(identifier, TASK_TEMPLATE);

assertThat(
stubService.createTaskRequest,
Expand Down Expand Up @@ -397,6 +370,35 @@ public void fetchLatestTaskIdShouldReturnFirstTaskFromList() {
.build()));
}

@Test
public void fetchLatestTaskShouldReturnFirstTaskFromList() {
stubService.taskLists =
Arrays.asList(
Task.newBuilder()
.setId(newIdentifier(ResourceType.TASK, TASK_NAME, TASK_VERSION))
.setClosure(
TaskClosure.newBuilder()
.setCompiledTask(
CompiledTask.newBuilder()
.setTemplate(ProtoUtil.serialize(TASK_TEMPLATE))
.build())
.build())
.build(),
TaskOuterClass.Task.newBuilder()
.setId(newIdentifier(ResourceType.TASK, TASK_NAME, TASK_OLD_VERSION))
.build());

TaskTemplate fetchLatestTaskTemplate =
client.fetchLatestTaskTemplate(
NamedEntityIdentifier.builder()
.project(PROJECT)
.domain(DOMAIN)
.name(TASK_NAME)
.build());

assertThat(fetchLatestTaskTemplate, equalTo(TASK_TEMPLATE));
}

@Test
public void fetchLatestTaskIdShouldReturnNullWhenEmptyList() {
stubService.taskLists = Collections.emptyList();
Expand Down Expand Up @@ -517,7 +519,7 @@ private TaskOuterClass.TaskSpec newTaskSpec() {
Tasks.TaskTemplate.newBuilder()
.setContainer(
Tasks.Container.newBuilder()
.setImage(FlyteAdminClientTest.IMAGE_NAME)
.setImage(IMAGE_NAME)
.addCommand(COMMAND)
.addEnv(
Literals.KeyValuePair.newBuilder()
Expand Down
Loading

0 comments on commit f2eb4b7

Please sign in to comment.