Skip to content

Commit

Permalink
Merge pull request #20 from tumido/train-stage
Browse files Browse the repository at this point in the history
feat(kfp): add train stage
  • Loading branch information
cooktheryan authored Sep 13, 2024
2 parents 03e813d + 38cd301 commit 0fb529e
Show file tree
Hide file tree
Showing 14 changed files with 1,213 additions and 25 deletions.
138 changes: 131 additions & 7 deletions pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,48 @@
from typing import List, Literal, Optional
import click
from kfp import dsl, compiler
from kfp.kubernetes import use_config_map_as_env, use_secret_as_env
from kfp.kubernetes import (
use_config_map_as_env,
use_secret_as_env,
CreatePVC,
DeletePVC,
mount_pvc,
)

K8S_NAME = "kfp-model-server"
MOCKED_STAGES = ['sdg', 'train', 'eval']
MOCKED_STAGES = ["sdg", "train", "eval"]


def pipeline_wrapper(mock: List[Literal[MOCKED_STAGES]]):
"""Wrapper for KFP pipeline, which allows for mocking individual stages."""
if 'sdg' in mock:

# Imports for SDG stage
if "sdg" in mock:
from sdg.faked import git_clone_op, sdg_op
else:
from sdg import git_clone_op, sdg_op

# Imports for Training stage
if "train" in mock:
from training.faked import pytorchjob_manifest_op
from utils.faked import (
kubectl_apply_op,
kubectl_wait_for_op,
huggingface_importer_op,
pvc_to_artifact_op,
pvc_to_model_op
)
from utils import artifact_to_pvc_op
else:
from training import pytorchjob_manifest_op
from utils import (
kubectl_apply_op,
kubectl_wait_for_op,
artifact_to_pvc_op,
huggingface_importer_op,
pvc_to_artifact_op,
pvc_to_model_op
)

@dsl.pipeline(
display_name="InstructLab",
Expand All @@ -26,7 +56,11 @@ def pipeline(
repo_url: str = "https://github.com/instructlab/taxonomy.git",
repo_branch: Optional[str] = None,
repo_pr: Optional[int] = None,
storage_class_name: str = "ocs-external-storagecluster-ceph-rbd",
base_model: str = "ibm-granite/granite-7b-base",
):

# SDG stage
git_clone_task = git_clone_op(
repo_branch=repo_branch, repo_pr=repo_pr, repo_url=repo_url
)
Expand All @@ -37,16 +71,106 @@ def pipeline(
repo_branch=repo_branch,
repo_pr=repo_pr,
)

# For example on K8S object to populate see kfp-model-server.yaml
use_config_map_as_env(sdg_task, K8S_NAME, dict(endpoint="endpoint", model="model"))
use_config_map_as_env(
sdg_task, K8S_NAME, dict(endpoint="endpoint", model="model")
)
use_secret_as_env(sdg_task, K8S_NAME, {"api_key": "api_key"})

# Training stage

# We need to pass storage_class_name as "" to use the default StorageClass, if left empty, KFP uses "standard" StorageClass.
# 'standard' != default StorageClass
# https://github.com/kubeflow/pipelines/blob/1cded35cf5e93d8c8d32fefbddceb2eed8de9a0a/backend/src/v2/driver/driver.go#L1428-L1436
# At least we made it a pipeline parameter
model_pvc_task = CreatePVC(
pvc_name_suffix="-model-cache",
access_modes=["ReadWriteOnce"],
size="50Gi",
storage_class_name=storage_class_name,
)
model_to_artifact = huggingface_importer_op(repo_name=base_model)
model_to_pvc_task = artifact_to_pvc_op(
data=model_to_artifact.outputs["model"], pvc_path="/model"
)
model_to_pvc_task.set_caching_options(False)
mount_pvc(
task=model_to_pvc_task, pvc_name=model_pvc_task.output, mount_path="/model"
)

sdg_input_pvc_task = CreatePVC(
pvc_name_suffix="-sdg",
access_modes=["ReadWriteOnce"],
size="1Gi",
storage_class_name=storage_class_name,
)
sdg_to_pvc_task = artifact_to_pvc_op(
data=sdg_task.outputs["sdg"], pvc_path="/data"
)
sdg_to_pvc_task.set_caching_options(False)
mount_pvc(
task=sdg_to_pvc_task, pvc_name=sdg_input_pvc_task.output, mount_path="/data"
)

output_pvc_task = CreatePVC(
pvc_name_suffix="-output",
access_modes=["ReadWriteOnce"],
size="50Gi",
storage_class_name=storage_class_name,
)

# Using pvc_create_task.output as PyTorchJob name since dsl.PIPELINE_* global variables do not template/work in KFP v2
# https://github.com/kubeflow/pipelines/issues/10453
pytorchjob_manifest_task = pytorchjob_manifest_op(
model_pvc_name=model_pvc_task.output,
input_pvc_name=sdg_input_pvc_task.output,
name_suffix=sdg_input_pvc_task.output,
output_pvc_name=output_pvc_task.output,
)
pytorchjob_manifest_task.set_caching_options(False)

kubectl_apply_task = kubectl_apply_op(
manifest=pytorchjob_manifest_task.outputs["manifest"]
)
kubectl_apply_task.after(sdg_to_pvc_task, model_to_pvc_task)
kubectl_apply_task.set_caching_options(False)

kubectl_wait_task = kubectl_wait_for_op(
condition="condition=Succeeded",
kind="pytorchjobs",
name=pytorchjob_manifest_task.outputs["name"],
)
kubectl_wait_task.after(kubectl_apply_task)
kubectl_wait_task.set_caching_options(False)

sdg_pvc_delete_task = DeletePVC(pvc_name=sdg_input_pvc_task.output)
sdg_pvc_delete_task.after(kubectl_wait_task)

model_pvc_delete_task = DeletePVC(pvc_name=model_pvc_task.output)
model_pvc_delete_task.after(kubectl_wait_task)

output_model_task = pvc_to_artifact_op(pvc_path="/output/data")
output_model_task.after(kubectl_wait_task)
output_model_task.set_caching_options(False)
output_data_task = pvc_to_model_op(pvc_path="/output/model")
output_data_task.after(kubectl_wait_task)
output_model_task.set_caching_options(False)

output_pvc_delete_task = DeletePVC(pvc_name=output_pvc_task.output)
output_pvc_delete_task.after(output_model_task, output_data_task)

return

return pipeline


@click.command()
@click.option('--mock', type=click.Choice(MOCKED_STAGES, case_sensitive=False), help="Mock part of the pipeline", multiple=True, default=[])
@click.option(
"--mock",
type=click.Choice(MOCKED_STAGES, case_sensitive=False),
help="Mock part of the pipeline",
multiple=True,
default=[],
)
def cli(mock):

p = pipeline_wrapper(mock)
Expand Down
Loading

0 comments on commit 0fb529e

Please sign in to comment.