-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make plugin task first-class citizen (#272)
* Revert "Revert "Make plugin task first-class citizen (#268)" (#271)" This reverts commit 3d9ab4e. Signed-off-by: Hongxin Liang <[email protected]> * Drop isSyncPlugin for now Signed-off-by: Hongxin Liang <[email protected]> * Skip staging and merging custom where applicable Signed-off-by: Hongxin Liang <[email protected]> * Skip container Signed-off-by: Hongxin Liang <[email protected]> * More description of SdkPluginTask Signed-off-by: Hongxin Liang <[email protected]> --------- Signed-off-by: Hongxin Liang <[email protected]>
- Loading branch information
Showing
11 changed files
with
608 additions
and
32 deletions.
There are no files selected for viewing
20 changes: 20 additions & 0 deletions
20
flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
/* | ||
* Copyright 2023 Flyte Authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
package org.flyte.api.v1; | ||
|
||
/** A task that is handled by a Flyte backend plugin instead of run as a container. */ | ||
public interface PluginTask extends Task {} |
20 changes: 20 additions & 0 deletions
20
flytekit-api/src/main/java/org/flyte/api/v1/PluginTaskRegistrar.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
/* | ||
* Copyright 2023 Flyte Authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
package org.flyte.api.v1; | ||
|
||
/** A registrar that creates {@link PluginTask} instances. */ | ||
public abstract class PluginTaskRegistrar implements Registrar<TaskIdentifier, PluginTask> {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
34 changes: 34 additions & 0 deletions
34
flytekit-examples/src/main/java/org/flyte/examples/NoopPluginTask.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
/* | ||
* Copyright 2023 Flyte Authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
package org.flyte.examples; | ||
|
||
import com.google.auto.service.AutoService; | ||
import org.flyte.flytekit.SdkPluginTask; | ||
import org.flyte.flytekit.SdkTypes; | ||
|
||
@AutoService(SdkPluginTask.class) | ||
public class NoopPluginTask extends SdkPluginTask<Void, Void> { | ||
|
||
public NoopPluginTask() { | ||
super(SdkTypes.nulls(), SdkTypes.nulls()); | ||
} | ||
|
||
@Override | ||
public String getType() { | ||
return "noop"; | ||
} | ||
} |
112 changes: 112 additions & 0 deletions
112
flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
/* | ||
* Copyright 2023 Flyte Authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
package org.flyte.flytekit; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
import javax.annotation.Nullable; | ||
import org.flyte.api.v1.PartialTaskIdentifier; | ||
|
||
/** | ||
* A task that is handled by a Flyte backend plugin instead of run as a container. Note that a | ||
* plugin task template does not have a container defined, neither all the jars captured in | ||
* classpath, so if this is a requirement, one should use SdkRunnableTask overriding run method to | ||
* simply return null. | ||
*/ | ||
public abstract class SdkPluginTask<InputT, OutputT> extends SdkTransform<InputT, OutputT> { | ||
|
||
private final SdkType<InputT> inputType; | ||
private final SdkType<OutputT> outputType; | ||
|
||
/** | ||
* Called by subclasses passing the {@link SdkType}s for inputs and outputs. | ||
* | ||
* @param inputType type for inputs. | ||
* @param outputType type for outputs. | ||
*/ | ||
public SdkPluginTask(SdkType<InputT> inputType, SdkType<OutputT> outputType) { | ||
this.inputType = inputType; | ||
this.outputType = outputType; | ||
} | ||
|
||
public abstract String getType(); | ||
|
||
@Override | ||
public SdkType<InputT> getInputType() { | ||
return inputType; | ||
} | ||
|
||
@Override | ||
public SdkType<OutputT> getOutputType() { | ||
return outputType; | ||
} | ||
|
||
/** Specifies custom data that can be read by the backend plugin. */ | ||
public SdkStruct getCustom() { | ||
return SdkStruct.empty(); | ||
} | ||
|
||
/** | ||
* Number of retries. Retries will be consumed when the task fails with a recoverable error. The | ||
* number of retries must be less than or equals to 10. | ||
* | ||
* @return number of retries | ||
*/ | ||
public int getRetries() { | ||
return 0; | ||
} | ||
|
||
/** | ||
* Indicates whether the system should attempt to look up this task's output to avoid duplication | ||
* of work. | ||
*/ | ||
public boolean isCached() { | ||
return false; | ||
} | ||
|
||
/** Indicates a logical version to apply to this task for the purpose of cache. */ | ||
public String getCacheVersion() { | ||
return null; | ||
} | ||
|
||
/** | ||
* Indicates whether the system should attempt to execute cached instances in serial to avoid | ||
* duplicate work. | ||
*/ | ||
public boolean isCacheSerializable() { | ||
return false; | ||
} | ||
|
||
@Override | ||
SdkNode<OutputT> apply( | ||
SdkWorkflowBuilder builder, | ||
String nodeId, | ||
List<String> upstreamNodeIds, | ||
@Nullable SdkNodeMetadata metadata, | ||
Map<String, SdkBindingData<?>> inputs) { | ||
PartialTaskIdentifier taskId = PartialTaskIdentifier.builder().name(getName()).build(); | ||
List<CompilerError> errors = | ||
Compiler.validateApply(nodeId, inputs, getInputType().getVariableMap()); | ||
|
||
if (!errors.isEmpty()) { | ||
throw new CompilerException(errors); | ||
} | ||
|
||
return new SdkTaskNode<>( | ||
builder, nodeId, taskId, upstreamNodeIds, metadata, inputs, outputType); | ||
} | ||
} |
141 changes: 141 additions & 0 deletions
141
flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTaskRegistrar.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
/* | ||
* Copyright 2023 Flyte Authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
package org.flyte.flytekit; | ||
|
||
import com.google.auto.service.AutoService; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
import java.util.ServiceLoader; | ||
import java.util.logging.Level; | ||
import java.util.logging.Logger; | ||
import org.flyte.api.v1.PluginTask; | ||
import org.flyte.api.v1.PluginTaskRegistrar; | ||
import org.flyte.api.v1.RetryStrategy; | ||
import org.flyte.api.v1.Struct; | ||
import org.flyte.api.v1.TaskIdentifier; | ||
import org.flyte.api.v1.TypedInterface; | ||
|
||
/** | ||
* Default implementation of a {@link PluginTaskRegistrar} that discovers {@link SdkPluginTask}s | ||
* implementation via {@link ServiceLoader} mechanism. Plugin tasks implementations must use | ||
* {@code @AutoService(SdkPluginTask.class)} or manually add their fully qualifies name to the | ||
* corresponding file. | ||
* | ||
* @see ServiceLoader | ||
*/ | ||
@AutoService(PluginTaskRegistrar.class) | ||
public class SdkPluginTaskRegistrar extends PluginTaskRegistrar { | ||
private static final Logger LOG = Logger.getLogger(SdkPluginTaskRegistrar.class.getName()); | ||
|
||
static { | ||
// enable all levels for the actual handler to pick up | ||
LOG.setLevel(Level.ALL); | ||
} | ||
|
||
private static class PluginTaskImpl<InputT, OutputT> implements PluginTask { | ||
private final SdkPluginTask<InputT, OutputT> sdkTask; | ||
|
||
private PluginTaskImpl(SdkPluginTask<InputT, OutputT> sdkTask) { | ||
this.sdkTask = sdkTask; | ||
} | ||
|
||
@Override | ||
public String getType() { | ||
return sdkTask.getType(); | ||
} | ||
|
||
@Override | ||
public Struct getCustom() { | ||
return sdkTask.getCustom().struct(); | ||
} | ||
|
||
@Override | ||
public TypedInterface getInterface() { | ||
return TypedInterface.builder() | ||
.inputs(sdkTask.getInputType().getVariableMap()) | ||
.outputs(sdkTask.getOutputType().getVariableMap()) | ||
.build(); | ||
} | ||
|
||
@Override | ||
public RetryStrategy getRetries() { | ||
return RetryStrategy.builder().retries(sdkTask.getRetries()).build(); | ||
} | ||
|
||
@Override | ||
public boolean isCached() { | ||
return sdkTask.isCached(); | ||
} | ||
|
||
@Override | ||
public String getCacheVersion() { | ||
return sdkTask.getCacheVersion(); | ||
} | ||
|
||
@Override | ||
public boolean isCacheSerializable() { | ||
return sdkTask.isCacheSerializable(); | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return sdkTask.getName(); | ||
} | ||
} | ||
|
||
/** | ||
* Load {@link SdkPluginTask}s using {@link ServiceLoader}. | ||
* | ||
* @param env env vars in a map that would be used to pick up the project, domain and version for | ||
* the discovered tasks. | ||
* @param classLoader class loader to use when discovering the task using {@link | ||
* ServiceLoader#load(Class, ClassLoader)} | ||
* @return a map of {@link SdkPluginTask}s by its task identifier. | ||
*/ | ||
@Override | ||
@SuppressWarnings("rawtypes") | ||
public Map<TaskIdentifier, PluginTask> load(Map<String, String> env, ClassLoader classLoader) { | ||
ServiceLoader<SdkPluginTask> loader = ServiceLoader.load(SdkPluginTask.class, classLoader); | ||
|
||
LOG.fine("Discovering SdkPluginTask"); | ||
|
||
Map<TaskIdentifier, PluginTask> tasks = new HashMap<>(); | ||
SdkConfig sdkConfig = SdkConfig.load(env); | ||
|
||
for (SdkPluginTask<?, ?> sdkTask : loader) { | ||
String name = sdkTask.getName(); | ||
TaskIdentifier taskId = | ||
TaskIdentifier.builder() | ||
.domain(sdkConfig.domain()) | ||
.project(sdkConfig.project()) | ||
.name(name) | ||
.version(sdkConfig.version()) | ||
.build(); | ||
LOG.fine(String.format("Discovered [%s]", name)); | ||
|
||
PluginTask task = new PluginTaskImpl<>(sdkTask); | ||
PluginTask previous = tasks.put(taskId, task); | ||
|
||
if (previous != null) { | ||
throw new IllegalArgumentException( | ||
String.format("Discovered a duplicate task [%s] [%s] [%s]", name, task, previous)); | ||
} | ||
} | ||
|
||
return tasks; | ||
} | ||
} |
Oops, something went wrong.