diff --git a/.github/workflows/template-setup-e2e-test/action.yaml b/.github/workflows/template-setup-e2e-test/action.yaml
index c1b988f8b57..75ee040aea2 100644
--- a/.github/workflows/template-setup-e2e-test/action.yaml
+++ b/.github/workflows/template-setup-e2e-test/action.yaml
@@ -37,7 +37,7 @@ runs:
         version: ${{ inputs.kubernetes-version }}

     - name: Setup Minikube Cluster
-      uses: medyagh/setup-minikube@v0.0.16
+      uses: medyagh/setup-minikube@v0.0.18
       with:
         network-plugin: cni
         cni: flannel
diff --git a/hack/gen-python-sdk/post_gen.py b/hack/gen-python-sdk/post_gen.py
index 70eab3a2595..1803bb20430 100644
--- a/hack/gen-python-sdk/post_gen.py
+++ b/hack/gen-python-sdk/post_gen.py
@@ -41,6 +41,10 @@ def _rewrite_helper(input_file, output_file, rewrite_rules):
     if output_file == "sdk/python/v1beta1/kubeflow/katib/__init__.py":
         lines.append("# Import Katib API client.\n")
         lines.append("from kubeflow.katib.api.katib_client import KatibClient\n")
+        lines.append("# Import Katib TrainerResources class.\n")
+        lines.append(
+            "from kubeflow.katib.types.trainer_resources import TrainerResources\n"
+        )
         lines.append("# Import Katib report metrics functions\n")
         lines.append("from kubeflow.katib.api.report_metrics import report_metrics\n")
         lines.append("# Import Katib helper functions.\n")
diff --git a/sdk/python/v1beta1/kubeflow/katib/__init__.py b/sdk/python/v1beta1/kubeflow/katib/__init__.py
index 7aef4c9897d..bafe7befea3 100644
--- a/sdk/python/v1beta1/kubeflow/katib/__init__.py
+++ b/sdk/python/v1beta1/kubeflow/katib/__init__.py
@@ -71,6 +71,8 @@

 # Import Katib API client.
 from kubeflow.katib.api.katib_client import KatibClient
+# Import Katib TrainerResources class.
+from kubeflow.katib.types.trainer_resources import TrainerResources
 # Import Katib report metrics functions
 from kubeflow.katib.api.report_metrics import report_metrics
 # Import Katib helper functions.
diff --git a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
index 78808d17f05..05fd1405a3f 100644
--- a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
+++ b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
@@ -12,10 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import inspect
+import json
 import logging
 import multiprocessing
-import textwrap
 import time
 from typing import Any, Callable, Dict, List, Optional, Union

@@ -25,6 +24,7 @@
 from kubeflow.katib import models
 from kubeflow.katib.api_client import ApiClient
 from kubeflow.katib.constants import constants
+from kubeflow.katib.types.trainer_resources import TrainerResources
 from kubeflow.katib.utils import utils
 from kubernetes import client, config

@@ -71,6 +71,7 @@ def __init__(
             k8s_client = client.ApiClient(client_configuration)

         self.custom_api = client.CustomObjectsApi(k8s_client)
+        self.core_api = client.CoreV1Api(k8s_client)
         self.api_client = ApiClient()
         self.namespace = namespace

@@ -166,9 +167,21 @@ def tune(
         self,
         # TODO (andreyvelich): How to be consistent with other APIs (name) ?
         name: str,
-        objective: Callable,
-        parameters: Dict[str, Any],
-        base_image: str = constants.BASE_IMAGE_TENSORFLOW,
+        model_provider_parameters: Optional[
+            "HuggingFaceModelParams"  # noqa: F821
+        ] = None,
+        dataset_provider_parameters: Optional[
+            Union["HuggingFaceDatasetParams", "S3DatasetParams"]  # noqa: F821
+        ] = None,
+        trainer_parameters: Optional["HuggingFaceTrainerParams"] = None,  # noqa: F821
+        storage_config: Optional[Dict[str, Optional[Union[str, List[str]]]]] = {
+            "size": constants.PVC_DEFAULT_SIZE,
+            "storage_class": None,
+            "access_modes": constants.PVC_DEFAULT_ACCESS_MODES,
+        },
+        objective: Optional[Callable] = None,
+        base_image: Optional[str] = constants.BASE_IMAGE_TENSORFLOW,
+        parameters: Optional[Dict[str, Any]] = None,
         namespace: Optional[str] = None,
         env_per_trial: Optional[
             Union[Dict[str, str], List[Union[client.V1EnvVar, client.V1EnvFromSource]]]
@@ -184,29 +197,77 @@ def tune(
         max_trial_count: int = None,
         parallel_trial_count: int = None,
         max_failed_trial_count: int = None,
-        resources_per_trial: Union[dict, client.V1ResourceRequirements, None] = None,
+        resources_per_trial: Optional[
+            Union[dict, client.V1ResourceRequirements, TrainerResources]
+        ] = None,
         retain_trials: bool = False,
         packages_to_install: List[str] = None,
         pip_index_url: str = "https://pypi.org/simple",
         metrics_collector_config: Dict[str, Any] = {"kind": "StdOut"},
     ):
-        """Create HyperParameter Tuning Katib Experiment from the objective function.
+        """
+        Create HyperParameter Tuning Katib Experiment using one of the following
+        options:
+
+        1. External models and datasets
+            Parameters: `model_provider_parameters` + `dataset_provider_parameters` +
+            `trainer_parameters`.
+            Usage: Specify both `model_provider_parameters` and
+            `dataset_provider_parameters` to download models and datasets from external
+            platforms (currently HuggingFace and Amazon S3 are supported) using the
+            Storage Initializer. The `trainer_parameters` should be of type
+            `HuggingFaceTrainerParams` to set the hyperparameter search space. This API
+            will automatically define the HuggingFace "Trainer" with the provided
+            parameters and use `Trainer.train()` from HuggingFace to obtain the metrics
+            for optimizing hyperparameters.
+
+        2. Custom objective function
+            Parameters: `objective` + `base_image` + `parameters`.
+            Usage: Specify the `objective` parameter to define your own objective function.
+            The `base_image` parameter will be used to execute the objective function. The
+            `parameters` should be a dictionary to define the search space for these
+            parameters.

         Args:
             name: Name for the Experiment.
-            objective: Objective function that Katib uses to train the model.
-                This function must be Callable and it must have only one dict argument.
-                Katib uses this argument to send HyperParameters to the function.
-                The function should not use any code declared outside of the function
-                definition. Import statements must be added inside the function.
-            parameters: Dict of HyperParameters to tune your Experiment. You
-                should use Katib SDK to define the search space for these parameters.
-
-                For example: `parameters = {"lr": katib.search.double(min=0.1, max=0.2)}`
-
-                Also, you can use these parameters to define input for your
-                objective function.
+            model_provider_parameters: Parameters for the model provider in the Storage
+                Initializer.
+                For example, the HuggingFace model name and the Transformer type for that
+                model, like: AutoModelForSequenceClassification.
+                This argument must be of type
+                `kubeflow.storage_initializer.hugging_face.HuggingFaceModelParams`.
+            dataset_provider_parameters: Parameters for the dataset provider in the
+                Storage Initializer.
+                For example, the name of the HuggingFace dataset or an AWS S3 configuration.
+                This argument must be of type `kubeflow.storage_initializer.hugging_face.
+                HuggingFaceDatasetParams` or `kubeflow.storage_initializer.s3.S3DatasetParams`.
+            trainer_parameters: Parameters for configuring the training process,
+                including settings for the hyperparameter search space. It should be of
+                type `HuggingFaceTrainerParams`. You should use the Katib SDK to define
+                the search space for these parameters. For example:
+                ```
+                trainer_parameters = HuggingFaceTrainerParams(
+                    training_parameters = transformers.TrainingArguments(
+                        learning_rate = katib.search.double(min=0.1, max=0.2),
+                    ),
+                ),
+                ```
+                Also, you can use these parameters to define input for training the
+                models.
+            storage_config: Configuration for the Storage Initializer PVC to download
+                the pre-trained model and dataset. You can configure the PVC size and
+                storage class name in this argument.
+            objective: Objective function that Katib uses to train the model. This
+                function must be Callable and it must have only one dict argument. Katib
+                uses this argument to send HyperParameters to the function. The function
+                should not use any code declared outside of the function definition.
+                Import statements must be added inside the function.
             base_image: Image to use when executing the objective function.
+            parameters: Dict of HyperParameters to tune your Experiment if you choose a
+                custom objective function. You should use the Katib SDK to define the
+                search space for these parameters. For example:
+                `parameters = {"lr": katib.search.double(min=0.1, max=0.2)}`
+
+                Also, you can use these parameters to define input for your objective function.
             namespace: Namespace for the Experiment.
             env_per_trial: Environment variable(s) to be attached to each trial container.
                 You can specify a dictionary as a mapping object representing the environment
@@ -230,24 +291,48 @@ def tune(
                 https://www.kubeflow.org/docs/components/katib/experiment/#configuration-spec.
             parallel_trial_count: Number of Trials that Experiment runs in parallel.
             max_failed_trial_count: Maximum number of Trials allowed to fail.
-            resources_per_trial: A parameter that lets you specify how much
-                resources each trial container should have. You can either specify a
-                kubernetes.client.V1ResourceRequirements object (documented here:
-                https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1ResourceRequirements.md)
-                or a dictionary that includes one or more of the following keys:
-                `cpu`, `memory`, or `gpu` (other keys will be ignored). Appropriate
-                values for these keys are documented here:
-                https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/.
-                For example:
+            resources_per_trial: A parameter that lets you specify how much resources
+                each trial container should have.
+                For the custom objective function, you can either specify a kubernetes.client.
+                V1ResourceRequirements object (documented here:
+                https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1ResourceRequirements.md)
+                or a dictionary that includes one or more of the following keys: `cpu`,
+                `memory`, or `gpu` (other keys will be ignored). Appropriate values
+                for these keys are documented here:
+                https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/.
+                For example:
                     {
                         "cpu": "1",
                         "gpu": "1",
                         "memory": "2Gi",
                     }
-                Please note, `gpu` specifies a resource request with a key of
-                `nvidia.com/gpu`, i.e. an NVIDIA GPU. If you need a different type
-                of GPU, pass in a V1ResourceRequirement instance instead, since it's
-                more flexible. This parameter is optional and defaults to None.
+                Please note, `gpu` specifies a resource request with a key of
+                `nvidia.com/gpu`, i.e. an NVIDIA GPU. If you need a different type of
+                GPU, pass in a V1ResourceRequirement instance instead, since it's more
+                flexible. This parameter is optional and defaults to None.
+
+                For external models and datasets, you can specify a TrainerResources object,
+                which includes `num_workers`, `num_procs_per_worker`, and `resources_per_worker`.
+                For example:
+                ```
+                resources_per_trial = TrainerResources(
+                    num_workers=4,
+                    num_procs_per_worker=2,
+                    resources_per_worker={
+                        "gpu": "2",
+                        "cpu": "5",
+                        "memory": "10Gi"
+                    }
+                )
+                ```
+                - num_workers: Number of PyTorchJob workers.
+                - num_procs_per_worker: Number of processes per PyTorchJob worker for the
+                `torchrun` CLI. You can use this parameter if you want to use more than 1 GPU
+                per PyTorchJob worker.
+                - resources_per_worker: A parameter that lets you specify how much resources
+                each PyTorchJob worker container should have. You can either specify
+                a kubernetes.client.V1ResourceRequirements object or a dictionary, the same as
+                the resources specified for the custom objective function option.
             retain_trials: Whether Trials' resources (e.g. pods) are deleted after
                 Succeeded state.
             packages_to_install: List of Python packages to install in addition
                 to the base image packages. These packages are installed before
@@ -263,6 +348,33 @@ def tune(
             RuntimeError: Failed to create Katib Experiment.
         """

+        if (
+            (
+                model_provider_parameters is not None
+                or dataset_provider_parameters is not None
+                or trainer_parameters is not None
+            )
+            and (objective is not None or parameters is not None)
+        ) or (
+            (
+                model_provider_parameters is None
+                and dataset_provider_parameters is None
+                and trainer_parameters is None
+            )
+            and (objective is None and parameters is None)
+        ):
+            raise ValueError(
+                "Invalid configuration for creating a Katib Experiment for hyperparameter "
+                "optimization. You should specify one of the following options:\n"
+                "1. Use external models and datasets: specify `model_provider_parameters`, "
+                "`dataset_provider_parameters` and `trainer_parameters`;\n"
+                "2. Use custom objective function: specify `objective`, `base_image` and "
+                "`parameters`."
+            )
+
+        if not name:
+            raise ValueError("Please specify name for the Experiment.")
+
         namespace = namespace or self.namespace

         # Create Katib Experiment template.
@@ -302,137 +414,296 @@ def tune(
         if max_failed_trial_count is not None:
             experiment.spec.max_failed_trial_count = max_failed_trial_count

-        # Validate objective function.
-        utils.validate_objective_function(objective)
-
-        # Extract objective function implementation.
-        objective_code = inspect.getsource(objective)
-
-        # Objective function might be defined in some indented scope
-        # (e.g. in another function). We need to dedent the function code.
-        objective_code = textwrap.dedent(objective_code)
-
-        # Iterate over input parameters.
-        input_params = {}
-        experiment_params = []
-        trial_params = []
-        for p_name, p_value in parameters.items():
-            # If input parameter value is Katib Experiment parameter sample.
-            if isinstance(p_value, models.V1beta1ParameterSpec):
-                # Wrap value for the function input.
-                input_params[p_name] = f"${{trialParameters.{p_name}}}"
-
-                # Add value to the Katib Experiment parameters.
-                p_value.name = p_name
-                experiment_params.append(p_value)
-
-                # Add value to the Katib Experiment's Trial parameters.
-                trial_params.append(
-                    models.V1beta1TrialParameterSpec(name=p_name, reference=p_name)
+        # If users choose to use a custom objective function.
+        if objective is not None:
+            # Add metrics collector to the Katib Experiment.
+            # Up to now, we only support parameter `kind`, of which default value
+            # is `StdOut`, to specify the kind of metrics collector.
+            experiment.spec.metrics_collector_spec = models.V1beta1MetricsCollectorSpec(
+                collector=models.V1beta1CollectorSpec(
+                    kind=metrics_collector_config["kind"]
                 )
-            else:
-                # Otherwise, add value to the function input.
-                input_params[p_name] = p_value
-
-        # Wrap objective function to execute it from the file. For example:
-        # def objective(parameters):
-        #     print(f'Parameters are {parameters}')
-        # objective({
-        #     'lr': '${trialParameters.lr}',
-        #     'epochs': '${trialParameters.epochs}',
-        #     'is_dist': False
-        # })
-        objective_code = f"{objective_code}\n{objective.__name__}({input_params})\n"
-
-        # Prepare execute script template.
-        exec_script = textwrap.dedent(
-            """
-            program_path=$(mktemp -d)
-            read -r -d '' SCRIPT << EOM\n
-            {objective_code}
-            EOM
-            printf "%s" "$SCRIPT" > $program_path/ephemeral_objective.py
-            python3 -u $program_path/ephemeral_objective.py"""
-        )
+            )

-        # Add objective code to the execute script.
-        exec_script = exec_script.format(objective_code=objective_code)
+            # Iterate over input parameters and do substitutions.
+            experiment_params = []
+            trial_params = []
+            input_params = utils.get_trial_substitutions_from_dict(
+                parameters, experiment_params, trial_params
+            )

-        # Install Python packages if that is required.
-        if packages_to_install is not None:
-            exec_script = (
-                utils.get_script_for_python_packages(packages_to_install, pip_index_url)
-                + exec_script
+            # Get the execution script from the objective function.
+            exec_script = utils.get_exec_script_from_objective(
+                objective, input_params, packages_to_install, pip_index_url
             )

-        if isinstance(resources_per_trial, dict):
-            if "gpu" in resources_per_trial:
-                resources_per_trial["nvidia.com/gpu"] = resources_per_trial.pop("gpu")
+            if isinstance(resources_per_trial, dict):
+                if "gpu" in resources_per_trial:
+                    resources_per_trial["nvidia.com/gpu"] = resources_per_trial.pop(
+                        "gpu"
+                    )

-            resources_per_trial = client.V1ResourceRequirements(
-                requests=resources_per_trial,
-                limits=resources_per_trial,
+                resources_per_trial = client.V1ResourceRequirements(
+                    requests=resources_per_trial,
+                    limits=resources_per_trial,
+                )
+
+            env = []
+            env_from = []
+            if isinstance(env_per_trial, dict):
+                env = [
+                    client.V1EnvVar(name=str(k), value=str(v))
+                    for k, v in env_per_trial.items()
+                ]
+            elif env_per_trial:
+                for x in env_per_trial:
+                    if isinstance(x, client.V1EnvVar):
+                        env.append(x)
+                    elif isinstance(x, client.V1EnvFromSource):
+                        env_from.append(x)
+                    else:
+                        raise ValueError(
+                            f"Incorrect value for env_per_trial: {env_per_trial}"
+                        )
+
+            # Create Trial specification.
+            trial_spec = client.V1Job(
+                api_version="batch/v1",
+                kind="Job",
+                spec=client.V1JobSpec(
+                    template=client.V1PodTemplateSpec(
+                        metadata=models.V1ObjectMeta(
+                            annotations={"sidecar.istio.io/inject": "false"}
+                        ),
+                        spec=client.V1PodSpec(
+                            restart_policy="Never",
+                            containers=[
+                                client.V1Container(
+                                    name=constants.DEFAULT_PRIMARY_CONTAINER_NAME,
+                                    image=base_image,
+                                    command=["bash", "-c"],
+                                    args=[exec_script],
+                                    env=env if env else None,
+                                    env_from=env_from if env_from else None,
+                                    resources=resources_per_trial,
+                                )
+                            ],
+                        ),
+                    )
+                ),
+            )

-        env = []
-        env_from = []
-        if isinstance(env_per_trial, dict):
-            env = [
-                client.V1EnvVar(name=str(k), value=str(v))
-                for k, v in env_per_trial.items()
-            ]
-        elif env_per_trial:
-            for x in env_per_trial:
-                if isinstance(x, client.V1EnvVar):
-                    env.append(x)
-                elif isinstance(x, client.V1EnvFromSource):
-                    env_from.append(x)
-                else:
-                    raise ValueError(
-                        f"Incorrect value for env_per_trial: {env_per_trial}"
+            # Create Trial template.
+            trial_template = models.V1beta1TrialTemplate(
+                primary_container_name=constants.DEFAULT_PRIMARY_CONTAINER_NAME,
+                retain=retain_trials,
+                trial_parameters=trial_params,
+                trial_spec=trial_spec,
+            )
+
+        # If users choose to use external models and datasets.
+        else:
+            if (
+                not model_provider_parameters
+                or not dataset_provider_parameters
+                or not trainer_parameters
+            ):
+                raise ValueError("One of the required parameters is None")
+
+            try:
+                from kubeflow.storage_initializer.constants import (
+                    VOLUME_PATH_DATASET,
+                    VOLUME_PATH_MODEL,
+                )
+                from kubeflow.storage_initializer.hugging_face import (
+                    HuggingFaceDatasetParams,
+                    HuggingFaceModelParams,
+                )
+                from kubeflow.storage_initializer.s3 import S3DatasetParams
+                from kubeflow.training import models as training_models
+                from kubeflow.training.constants.constants import (
+                    JOB_PARAMETERS,
+                    PYTORCHJOB_KIND,
+                    STORAGE_INITIALIZER,
+                    STORAGE_INITIALIZER_IMAGE,
+                    STORAGE_INITIALIZER_VOLUME_MOUNT,
+                    TRAINER_TRANSFORMER_IMAGE,
+                )
+                from kubeflow.training.utils import utils as training_utils
+            except ImportError:
+                raise ImportError(
+                    "LLM dependencies for Tune API are not installed. "
+                    + "Run: pip install -U 'kubeflow-katib[huggingface]' "
+                )
+
+            print(
+                "Thank you for using `tune` API for LLM hyperparameter optimization. This feature "
+                "is in the alpha stage. Kubeflow community is looking for your feedback. Please "
+                "share your experience via #kubeflow-katib Slack channel or the Kubeflow Katib "
+                "GitHub."
+            )
+
+            # Specify metrics format for the collector, for example: 'train_loss':0.846
+            experiment.spec.metrics_collector_spec = models.V1beta1MetricsCollectorSpec(
+                source=models.V1beta1SourceSpec(
+                    filter=models.V1beta1FilterSpec(
+                        metrics_format=[
+                            r"'([\w|-]+)'\s*:\s*([+-]?\d*(\.\d+)?([Ee][+-]?\d+)?)",
+                        ]
                     )
+                ),
+            )

-        # Add metrics collector to the Katib Experiment.
-        # Up to now, we only support parameter `kind`, of which default value
-        # is `StdOut`, to specify the kind of metrics collector.
-        experiment.spec.metrics_collector_spec = models.V1beta1MetricsCollectorSpec(
-            collector=models.V1beta1CollectorSpec(kind=metrics_collector_config["kind"])
-        )
+            # Create PVC for the Storage Initializer.
+            # TODO (helenxie-bit): PVC Creation should be part of Katib Controller.
+            try:
+                self.core_api.create_namespaced_persistent_volume_claim(
+                    namespace=namespace,
+                    body=training_utils.get_pvc_spec(
+                        pvc_name=name,
+                        namespace=namespace,
+                        storage_config=storage_config,
+                    ),
+                )
+            except Exception as e:
+                pvc_list = self.core_api.list_namespaced_persistent_volume_claim(
+                    namespace
+                )
+                # Check if the PVC with the specified name exists.
+                for pvc in pvc_list.items:
+                    if pvc.metadata.name == name:
+                        print(
+                            f"PVC '{name}' already exists in namespace " f"{namespace}."
+                        )
+                        break
+                else:
+                    raise RuntimeError(f"failed to create PVC. Error: {e}")
+
+            if isinstance(model_provider_parameters, HuggingFaceModelParams):
+                mp = "hf"
+            else:
+                raise ValueError(
+                    "Model provider parameters must be an instance of HuggingFaceModelParams."
+                )
+
+            if isinstance(dataset_provider_parameters, S3DatasetParams):
+                dp = "s3"
+            elif isinstance(dataset_provider_parameters, HuggingFaceDatasetParams):
+                dp = "hf"
+            else:
+                raise ValueError(
+                    "Dataset provider parameters must be an instance of S3DatasetParams "
+                    "or HuggingFaceDatasetParams."
+                )

-        # Create Trial specification.
-        trial_spec = client.V1Job(
-            api_version="batch/v1",
-            kind="Job",
-            spec=client.V1JobSpec(
-                template=client.V1PodTemplateSpec(
-                    metadata=models.V1ObjectMeta(
-                        annotations={"sidecar.istio.io/inject": "false"}
+            # Iterate over input parameters and do substitutions.
+            experiment_params = []
+            trial_params = []
+            training_args = utils.get_trial_substitutions_from_trainer(
+                trainer_parameters.training_parameters, experiment_params, trial_params
+            )
+            lora_config = utils.get_trial_substitutions_from_trainer(
+                trainer_parameters.lora_config, experiment_params, trial_params
+            )
+
+            # Create the init and the primary container.
+            init_container_spec = training_utils.get_container_spec(
+                name=STORAGE_INITIALIZER,
+                base_image=STORAGE_INITIALIZER_IMAGE,
+                args=[
+                    "--model_provider",
+                    mp,
+                    "--model_provider_parameters",
+                    json.dumps(
+                        model_provider_parameters.__dict__, cls=utils.SetEncoder
                     ),
-                    spec=client.V1PodSpec(
-                        restart_policy="Never",
-                        containers=[
-                            client.V1Container(
-                                name=constants.DEFAULT_PRIMARY_CONTAINER_NAME,
-                                image=base_image,
-                                command=["bash", "-c"],
-                                args=[exec_script],
-                                env=env if env else None,
-                                env_from=env_from if env_from else None,
-                                resources=resources_per_trial,
-                            )
-                        ],
+                    "--dataset_provider",
+                    dp,
+                    "--dataset_provider_parameters",
+                    json.dumps(dataset_provider_parameters.__dict__),
+                ],
+                volume_mounts=[STORAGE_INITIALIZER_VOLUME_MOUNT],
+            )
+
+            container_spec = training_utils.get_container_spec(
+                name=JOB_PARAMETERS[PYTORCHJOB_KIND]["container"],
+                base_image=TRAINER_TRANSFORMER_IMAGE,
+                args=[
+                    "--model_uri",
+                    model_provider_parameters.model_uri,
+                    "--transformer_type",
+                    model_provider_parameters.transformer_type.__name__,
+                    "--model_dir",
+                    VOLUME_PATH_MODEL,
+                    "--dataset_dir",
+                    VOLUME_PATH_DATASET,
+                    "--lora_config",
+                    f"'{lora_config}'",
+                    "--training_parameters",
+                    f"'{training_args}'",
+                ],
+                volume_mounts=[STORAGE_INITIALIZER_VOLUME_MOUNT],
+                resources=resources_per_trial.resources_per_worker,
+            )
+
+            # Create the worker and the master pod.
+            storage_initializer_volume = models.V1Volume(
+                name=STORAGE_INITIALIZER,
+                persistent_volume_claim=models.V1PersistentVolumeClaimVolumeSource(
+                    claim_name=name
+                ),
+            )
+
+            worker_pod_template_spec = training_utils.get_pod_template_spec(
+                containers=[container_spec],
+                volumes=[storage_initializer_volume],
+            )
+
+            master_pod_template_spec = training_utils.get_pod_template_spec(
+                containers=[container_spec],
+                init_containers=[init_container_spec],
+                volumes=[storage_initializer_volume],
+            )
+
+            # Create PyTorchJob.
+            pytorchjob = training_models.KubeflowOrgV1PyTorchJob(
+                api_version="kubeflow.org/v1",
+                kind="PyTorchJob",
+                spec=training_models.KubeflowOrgV1PyTorchJobSpec(
+                    run_policy=training_models.KubeflowOrgV1RunPolicy(
+                        clean_pod_policy=None
                     ),
+                    pytorch_replica_specs={},
+                ),
+            )
+
+            if resources_per_trial.num_procs_per_worker:
+                pytorchjob.spec.nproc_per_node = str(
+                    resources_per_trial.num_procs_per_worker
                 )
-            ),
-        )

-        # Create Trial template.
-        trial_template = models.V1beta1TrialTemplate(
-            primary_container_name=constants.DEFAULT_PRIMARY_CONTAINER_NAME,
-            retain=retain_trials,
-            trial_parameters=trial_params,
-            trial_spec=trial_spec,
-        )
+            pytorchjob.spec.pytorch_replica_specs["Master"] = (
+                training_models.KubeflowOrgV1ReplicaSpec(
+                    replicas=1,
+                    template=master_pod_template_spec,
+                )
+            )
+
+            if resources_per_trial.num_workers > 1:
+                pytorchjob.spec.pytorch_replica_specs["Worker"] = (
+                    training_models.KubeflowOrgV1ReplicaSpec(
+                        replicas=resources_per_trial.num_workers - 1,
+                        template=worker_pod_template_spec,
+                    )
+                )
+
+            # Create Trial template.
+            trial_template = models.V1beta1TrialTemplate(
+                primary_container_name=JOB_PARAMETERS[PYTORCHJOB_KIND]["container"],
+                retain=retain_trials,
+                trial_parameters=trial_params,
+                trial_spec=pytorchjob,
+            )

         # Add parameters to the Katib Experiment.
         experiment.spec.parameters = experiment_params
diff --git a/sdk/python/v1beta1/kubeflow/katib/constants/constants.py b/sdk/python/v1beta1/kubeflow/katib/constants/constants.py
index 5bdb9911f76..1e0478f48f8 100644
--- a/sdk/python/v1beta1/kubeflow/katib/constants/constants.py
+++ b/sdk/python/v1beta1/kubeflow/katib/constants/constants.py
@@ -60,3 +60,8 @@
 BASE_IMAGE_MXNET = "docker.io/mxnet/python:1.9.1_native_py3"

 DEFAULT_DB_MANAGER_ADDRESS = "katib-db-manager.kubeflow:6789"
+
+# The default value for dataset and model storage PVC.
+PVC_DEFAULT_SIZE = "10Gi"
+# The default value for PVC access modes.
+PVC_DEFAULT_ACCESS_MODES = ["ReadWriteOnce", "ReadOnlyMany"]
diff --git a/sdk/python/v1beta1/kubeflow/katib/types/trainer_resources.py b/sdk/python/v1beta1/kubeflow/katib/types/trainer_resources.py
new file mode 100644
index 00000000000..87bbbbf67fc
--- /dev/null
+++ b/sdk/python/v1beta1/kubeflow/katib/types/trainer_resources.py
@@ -0,0 +1,10 @@
+class TrainerResources(object):
+    def __init__(
+        self,
+        num_workers=None,
+        num_procs_per_worker=None,
+        resources_per_worker=None,
+    ):
+        self.num_workers = num_workers
+        self.num_procs_per_worker = num_procs_per_worker
+        self.resources_per_worker = resources_per_worker
diff --git a/sdk/python/v1beta1/kubeflow/katib/utils/utils.py b/sdk/python/v1beta1/kubeflow/katib/utils/utils.py
index c6e0734438f..28f3126bbfa 100644
--- a/sdk/python/v1beta1/kubeflow/katib/utils/utils.py
+++ b/sdk/python/v1beta1/kubeflow/katib/utils/utils.py
@@ -12,15 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import inspect
 import json
+import logging
 import os
 import textwrap
-from typing import Any, Callable
+from typing import Any, Callable, Dict, List, Optional, Union

 from kubeflow.katib import models
 from kubeflow.katib.constants import constants

+logger = logging.getLogger(__name__)
+

 def is_running_in_k8s():
     return os.path.isdir("/var/run/secrets/kubernetes.io/")
@@ -85,7 +89,6 @@ def validate_metrics_value(value: Any):


 def validate_objective_function(objective: Callable):
-
     # Check if objective function is callable.
     if not callable(objective):
         raise ValueError(
@@ -129,3 +132,138 @@ class FakeResponse:

     def __init__(self, obj):
         self.data = json.dumps(obj)
+
+
+class SetEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, set):
+            return list(obj)
+        if isinstance(obj, type):
+            return obj.__name__
+        return json.JSONEncoder.default(self, obj)
+
+
+def get_trial_substitutions_from_dict(
+    parameters: Dict[str, Any],
+    experiment_params: List[models.V1beta1ParameterSpec],
+    trial_params: List[models.V1beta1TrialParameterSpec],
+) -> Dict[str, str]:
+    for p_name, p_value in parameters.items():
+        # If input parameter value is Katib Experiment parameter sample.
+        if isinstance(p_value, models.V1beta1ParameterSpec):
+            # Wrap value for the function input.
+            parameters[p_name] = f"${{trialParameters.{p_name}}}"
+
+            # Add value to the Katib Experiment parameters.
+            p_value.name = p_name
+            experiment_params.append(p_value)
+
+            # Add value to the Katib Experiment's Trial parameters.
+            trial_params.append(
+                models.V1beta1TrialParameterSpec(name=p_name, reference=p_name)
+            )
+        else:
+            # Otherwise, add value to the function input.
+            parameters[p_name] = p_value
+
+    return parameters
+
+
+def get_trial_substitutions_from_trainer(
+    parameters: Union["TrainingArguments", "LoraConfig"],  # noqa: F821
+    experiment_params: List[models.V1beta1ParameterSpec],
+    trial_params: List[models.V1beta1TrialParameterSpec],
+) -> Dict[str, str]:
+    from peft import LoraConfig  # noqa: F401
+    from transformers import TrainingArguments  # noqa: F401
+
+    if isinstance(parameters, TrainingArguments):
+        parameters_dict = parameters.to_dict()
+    else:
+        parameters_dict = parameters.__dict__
+
+    for p_name, p_value in parameters_dict.items():
+        if not hasattr(parameters, p_name):
+            logger.warning(f"Training parameter {p_name} is not supported.")
+            continue
+
+        if isinstance(p_value, models.V1beta1ParameterSpec):
+            old_attr = getattr(parameters, p_name, None)
+            if old_attr is not None:
+                value = f"${{trialParameters.{p_name}}}"
+                setattr(parameters, p_name, value)
+            p_value.name = p_name
+            experiment_params.append(p_value)
+            trial_params.append(
+                models.V1beta1TrialParameterSpec(name=p_name, reference=p_name)
+            )
+        elif p_value is not None:
+            old_attr = getattr(parameters, p_name, None)
+            if old_attr is not None:
+                if isinstance(p_value, dict):
+                    # Update the existing dictionary without nesting
+                    value = copy.deepcopy(p_value)
+                else:
+                    value = type(old_attr)(p_value)
+                setattr(parameters, p_name, value)
+
+    if isinstance(parameters, TrainingArguments):
+        parameters = json.dumps(parameters.to_dict())
+    else:
+        parameters = json.dumps(parameters.__dict__, cls=SetEncoder)
+
+    return parameters
+
+
+def get_exec_script_from_objective(
+    objective: Callable,
+    input_params: Dict[str, Any] = None,
+    packages_to_install: Optional[List[str]] = None,
+    pip_index_url: str = "https://pypi.org/simple",
+) -> str:
+    """
+    Get executable script for container args from the given objective function and parameters.
+ """ + # Validate objective function. + validate_objective_function(objective) + + # Extract objective function implementation. + objective_code = inspect.getsource(objective) + + # Objective function might be defined in some indented scope + # (e.g. in another function). We need to dedent the function code. + objective_code = textwrap.dedent(objective_code) + + # Wrap objective function to execute it from the file. For example: + # def objective(parameters): + # print(f'Parameters are {parameters}') + # objective({ + # 'lr': '${trialParameters.lr}', + # 'epochs': '${trialParameters.epochs}', + # 'is_dist': False + # }) + objective_code = f"{objective_code}\n{objective.__name__}({input_params})\n" + + # Prepare execute script template. + exec_script = textwrap.dedent( + """ + program_path=$(mktemp -d) + read -r -d '' SCRIPT << EOM\n + {objective_code} + EOM + printf "%s" "$SCRIPT" > $program_path/ephemeral_objective.py + python3 -u $program_path/ephemeral_objective.py""" + ) + + # Add objective code to the execute script. + exec_script = exec_script.format(objective_code=objective_code) + + # Install Python packages if that is required. + if packages_to_install is not None: + exec_script = ( + get_script_for_python_packages(packages_to_install, pip_index_url) + + exec_script + ) + + # Return executable script to execute objective function. + return exec_script diff --git a/sdk/python/v1beta1/setup.py b/sdk/python/v1beta1/setup.py index 49c689a235c..78ae02aa739 100644 --- a/sdk/python/v1beta1/setup.py +++ b/sdk/python/v1beta1/setup.py @@ -85,4 +85,7 @@ "Topic :: Software Development :: Libraries :: Python Modules", ], install_requires=REQUIRES, + extras_require={ + "huggingface": ["kubeflow-training[huggingface]==1.8.0"], + }, )