diff --git a/vertex-registry-and-deployer/.assets/cloud_mcp.png b/vertex-registry-and-deployer/.assets/cloud_mcp.png new file mode 100644 index 00000000..81197e95 Binary files /dev/null and b/vertex-registry-and-deployer/.assets/cloud_mcp.png differ diff --git a/vertex-registry-and-deployer/.assets/cloud_mcp_predictions.png b/vertex-registry-and-deployer/.assets/cloud_mcp_predictions.png new file mode 100644 index 00000000..a6bf7c90 Binary files /dev/null and b/vertex-registry-and-deployer/.assets/cloud_mcp_predictions.png differ diff --git a/vertex-registry-and-deployer/.assets/cloud_mcp_screenshot.png b/vertex-registry-and-deployer/.assets/cloud_mcp_screenshot.png new file mode 100644 index 00000000..8f56defa Binary files /dev/null and b/vertex-registry-and-deployer/.assets/cloud_mcp_screenshot.png differ diff --git a/vertex-registry-and-deployer/.assets/feature_engineering_pipeline.png b/vertex-registry-and-deployer/.assets/feature_engineering_pipeline.png new file mode 100644 index 00000000..db301913 Binary files /dev/null and b/vertex-registry-and-deployer/.assets/feature_engineering_pipeline.png differ diff --git a/vertex-registry-and-deployer/.assets/inference_pipeline.png b/vertex-registry-and-deployer/.assets/inference_pipeline.png new file mode 100644 index 00000000..358d5537 Binary files /dev/null and b/vertex-registry-and-deployer/.assets/inference_pipeline.png differ diff --git a/vertex-registry-and-deployer/.assets/pipeline_overview.png b/vertex-registry-and-deployer/.assets/pipeline_overview.png new file mode 100644 index 00000000..609e97d2 Binary files /dev/null and b/vertex-registry-and-deployer/.assets/pipeline_overview.png differ diff --git a/vertex-registry-and-deployer/.assets/training_pipeline.png b/vertex-registry-and-deployer/.assets/training_pipeline.png new file mode 100644 index 00000000..a2e6a7d0 Binary files /dev/null and b/vertex-registry-and-deployer/.assets/training_pipeline.png differ diff --git 
a/vertex-registry-and-deployer/.copier-answers.yml b/vertex-registry-and-deployer/.copier-answers.yml new file mode 100644 index 00000000..8b1fb818 --- /dev/null +++ b/vertex-registry-and-deployer/.copier-answers.yml @@ -0,0 +1,8 @@ +# Changes here will be overwritten by Copier +_commit: 2024.09.24 +_src_path: gh:zenml-io/template-starter +email: info@zenml.io +full_name: ZenML GmbH +open_source_license: apache +project_name: ZenML Starter +version: 0.1.0 diff --git a/vertex-registry-and-deployer/.dockerignore b/vertex-registry-and-deployer/.dockerignore new file mode 100644 index 00000000..455f4d7a --- /dev/null +++ b/vertex-registry-and-deployer/.dockerignore @@ -0,0 +1,2 @@ +.venv* +.requirements* \ No newline at end of file diff --git a/vertex-registry-and-deployer/LICENSE b/vertex-registry-and-deployer/LICENSE new file mode 100644 index 00000000..75d01fb4 --- /dev/null +++ b/vertex-registry-and-deployer/LICENSE @@ -0,0 +1,15 @@ +Apache Software License 2.0 + +Copyright (c) ZenML GmbH 2024. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/vertex-registry-and-deployer/README.md b/vertex-registry-and-deployer/README.md new file mode 100644 index 00000000..b8612492 --- /dev/null +++ b/vertex-registry-and-deployer/README.md @@ -0,0 +1,87 @@ +# 🚀 Deploying ML Models with ZenML on Vertex AI + + +Welcome to your ZenML project for deploying ML models using Google Cloud's Vertex AI! 
This project provides a hands-on experience with MLOps pipelines using ZenML and Vertex AI. It contains a collection of ZenML steps, pipelines, and other artifacts to help you efficiently deploy your machine learning models. + +Using these pipelines, you can run data preparation, model training, registration, and deployment with a single command while using YAML files for [configuration](https://docs.zenml.io/user-guide/production-guide/configure-pipeline). ZenML takes care of tracking your metadata and [containerizing your pipelines](https://docs.zenml.io/how-to/customize-docker-builds). + + +## 🏃 How to run + +In this project, we will train and deploy a classification model to [Vertex AI](https://cloud.google.com/vertex-ai). Before running any pipelines, set up your environment as follows: + +```bash +# Set up a Python virtual environment, if you haven't already +python3 -m venv .venv +source .venv/bin/activate + +# Install requirements +pip install -r requirements.txt +``` + +We will need to set up access to Google Cloud and Vertex AI. You can follow the instructions in the [ZenML documentation](https://docs.zenml.io/how-to/auth-management/gcp-service-connector) +to register a service connector and set up your Google Cloud credentials. 
+ +Once you have set up your Google Cloud credentials, we can create a stack and run the deployment pipeline: + +```bash +# Register the artifact store +zenml artifact-store register gs_store -f gcp --path=gs://bucket-name +zenml artifact-store connect gs_store --connector gcp + +# Register the model registry +zenml model-registry register vertex_registry --flavor=vertex --location=us-central1 +zenml model-registry connect vertex_registry --connector gcp + +# Register Model Deployer +zenml model-deployer register vertex_deployer --flavor=vertex --location=us-central1 +zenml model-deployer connect vertex_deployer --connector vertex_deployer_connector + +# Register the stack +zenml stack register vertex_stack --orchestrator default --artifact-store gs_store --model-registry vertex_registry --model-deployer vertex_deployer +``` + +Now that we have set up our stack, we can run the training pipeline, which will train the model, register it in the Vertex AI model registry, and deploy it to a Vertex AI endpoint. + +```bash +python run.py --training-pipeline +``` + +Once the pipeline has completed, you can check the status of the model in the Vertex AI model registry and the deployed model in the Vertex AI endpoint. + +```bash +# List models in the model registry +zenml model-registry models list + +# List deployed models +zenml model-deployer models list +``` + +You can also run the inference pipeline separately: + +```bash +python run.py --inference-pipeline +``` + + +## 📜 Project Structure + +The project loosely follows [the recommended ZenML project structure](https://docs.zenml.io/how-to/setting-up-a-project-repository/best-practices): + +``` +. 
+├── configs # Pipeline configuration files +│ ├── training_sgd.yaml # Configuration for training pipeline +│ ├── inference.yaml # Configuration for inference pipeline +├── pipelines # `zenml.pipeline` implementations +│ ├── training.py # Training pipeline +│ ├── inference.py # Inference pipeline +├── steps # `zenml.step` implementations +│ ├── model_trainer.py # Model training step +│ ├── model_register.py # Model registration step +│ ├── model_promoter.py # Model promotion step +│ ├── model_deployer.py # Model deployment step to Vertex AI +├── README.md # This file +├── requirements.txt # Extra Python dependencies +└── run.py # CLI tool to run pipelines on ZenML Stack +``` \ No newline at end of file diff --git a/vertex-registry-and-deployer/configs/inference.yaml b/vertex-registry-and-deployer/configs/inference.yaml new file mode 100644 index 00000000..8f73d762 --- /dev/null +++ b/vertex-registry-and-deployer/configs/inference.yaml @@ -0,0 +1,16 @@ +# environment configuration +settings: + docker: + required_integrations: + - sklearn + - pandas + requirements: + - pyarrow + +# configuration of the Model Control Plane +model: + name: "breast_cancer_classifier" + version: "production" + license: Apache 2.0 + description: A breast cancer classifier + tags: ["breast_cancer", "classifier"] \ No newline at end of file diff --git a/vertex-registry-and-deployer/configs/training_sgd.yaml b/vertex-registry-and-deployer/configs/training_sgd.yaml new file mode 100644 index 00000000..f90ca0e9 --- /dev/null +++ b/vertex-registry-and-deployer/configs/training_sgd.yaml @@ -0,0 +1,16 @@ +# environment configuration +settings: + docker: + required_integrations: + - sklearn + - pandas + requirements: + - pyarrow + +# configuration of the Model Control Plane +model: + name: breast_cancer_classifier + version: sgd + license: Apache 2.0 + description: A breast cancer classifier + tags: ["breast_cancer", "classifier"] \ No newline at end of file diff 
--git a/vertex-registry-and-deployer/pipelines/__init__.py b/vertex-registry-and-deployer/pipelines/__init__.py new file mode 100644 index 00000000..a8464d67 --- /dev/null +++ b/vertex-registry-and-deployer/pipelines/__init__.py @@ -0,0 +1,19 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .inference import inference +from .training import training diff --git a/vertex-registry-and-deployer/pipelines/inference.py b/vertex-registry-and-deployer/pipelines/inference.py new file mode 100644 index 00000000..b19e5b45 --- /dev/null +++ b/vertex-registry-and-deployer/pipelines/inference.py @@ -0,0 +1,45 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@pipeline
def inference(random_state: int, target: str):
    """Inference pipeline skeleton.

    Looks up the ``sklearn_classifier`` and ``preprocess_pipeline`` artifacts
    on the model attached to the pipeline context. No steps are invoked yet,
    so neither the fetched artifacts nor the parameters are consumed here.

    Args:
        random_state: Random state for reproducibility.
        target: Name of target column in dataset.
    """
    context_model = get_pipeline_context().model

    # Classifier artifact of the model version this pipeline is configured for.
    model = context_model.get_artifact("sklearn_classifier")

    # Preprocessing pipeline artifact belonging to the same model version.
    preprocess_pipeline = context_model.get_artifact("preprocess_pipeline")

    # Steps would be chained here, feeding each step's output into the next.
@pipeline
def training(
    target: Optional[str] = "target",
):
    """Model training pipeline.

    Calls ``model_trainer`` to obtain a model and its accuracy, hands the
    accuracy to ``model_promoter`` and then calls ``model_register`` followed
    by ``model_deployer``, which receives the URI returned by the
    registration step.

    Args:
        target: Name of target column in dataset.
    """
    # Link all the steps together by calling them and passing the output
    # of one step as the input of the next step.

    model, accuracy = model_trainer(target=target)
    is_promoted = model_promoter(accuracy=accuracy)
    # NOTE(review): `is_promoted` is the value returned by invoking a step
    # inside a pipeline definition; presumably it is a handle to the step
    # output rather than a concrete bool at this point, in which case this
    # branch is always taken. Confirm against the ZenML pipeline semantics.
    if is_promoted:
        model_registry_uri = model_register()
        model_deployer(model_registry_uri=model_registry_uri)
@click.command(
    help="""
ZenML Starter project.

Run the ZenML starter project with basic options.

Examples:

  \b
  # Run the training pipeline
    python run.py --training-pipeline

  \b
  # Run the training pipeline with caching disabled
    python run.py --training-pipeline --no-cache

  \b
  # Run the inference pipeline
    python run.py --inference-pipeline

"""
)
@click.option(
    "--training-pipeline",
    is_flag=True,
    default=False,
    help="Whether to run the pipeline that trains the model.",
)
@click.option(
    "--inference-pipeline",
    is_flag=True,
    default=False,
    help="Whether to run the pipeline that performs inference.",
)
@click.option(
    "--no-cache",
    is_flag=True,
    default=False,
    help="Disable caching for the pipeline run.",
)
def main(
    training_pipeline: bool = False,
    inference_pipeline: bool = False,
    no_cache: bool = False,
):
    """Main entry point for the pipeline execution.

    This entrypoint is where everything comes together:

      * configuring pipeline with the required parameters
        (some of which may come from command line arguments, but most
        of which comes from the YAML config files)
      * launching the pipeline

    Args:
        training_pipeline: Whether to run the pipeline that trains the model.
        inference_pipeline: Whether to run the pipeline that performs inference.
        no_cache: If `True` cache will be disabled.

    Raises:
        RuntimeError: If the inference pipeline is requested but the
            configured model version has no `preprocess_pipeline` artifact.
    """
    client = Client()

    if not (training_pipeline or inference_pipeline):
        # Without this hint the command exits silently and looks like a no-op.
        logger.warning(
            "No pipeline selected. Pass --training-pipeline and/or "
            "--inference-pipeline."
        )
        return

    config_folder = os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        "configs",
    )

    # Execute Training Pipeline
    if training_pipeline:
        run_args_train = {}

        # Run the SGD pipeline
        pipeline_args = {}
        if no_cache:
            pipeline_args["enable_cache"] = False
        pipeline_args["config_path"] = os.path.join(
            config_folder, "training_sgd.yaml"
        )
        training.with_options(**pipeline_args)(**run_args_train)
        logger.info("Training pipeline with SGD finished successfully!\n\n")

    if inference_pipeline:
        run_args_inference = {}
        # Inference always runs uncached, regardless of --no-cache.
        pipeline_args = {"enable_cache": False}
        pipeline_args["config_path"] = os.path.join(
            config_folder, "inference.yaml"
        )

        # Configure the pipeline
        inference_configured = inference.with_options(**pipeline_args)

        # Fetch the model version named in the inference config
        with open(pipeline_args["config_path"], "r") as f:
            config = yaml.safe_load(f)
        zenml_model = client.get_model_version(
            config["model"]["name"], config["model"]["version"]
        )
        preprocess_pipeline_artifact = zenml_model.get_artifact(
            "preprocess_pipeline"
        )
        if preprocess_pipeline_artifact is None:
            # Fail with an actionable message instead of an AttributeError
            # on the metadata look-up below.
            raise RuntimeError(
                f"Model '{config['model']['name']}' version "
                f"'{config['model']['version']}' has no "
                "'preprocess_pipeline' artifact; cannot configure the "
                "inference pipeline."
            )

        # Use the metadata of the preprocessing pipeline artifact
        # to get the random state and target column
        random_state = preprocess_pipeline_artifact.run_metadata[
            "random_state"
        ].value
        target = preprocess_pipeline_artifact.run_metadata["target"].value
        run_args_inference["random_state"] = random_state
        run_args_inference["target"] = target

        # Run the pipeline
        inference_configured(**run_args_inference)
        logger.info("Inference pipeline finished successfully!")


if __name__ == "__main__":
    main()
@step(enable_cache=False)
def model_deployer(
    model_registry_uri: str,
    location: str = "europe-west1",
) -> Annotated[
    VertexDeploymentService,
    ArtifactConfig(name="vertex_deployment", is_deployment_artifact=True),
]:
    """Model deployer step.

    Builds a Vertex AI deployment configuration for the model attached to
    the step context and hands it to the model deployer of the active stack.

    Args:
        model_registry_uri: URI of the model in the model registry.
        location: Location passed to the deployment configuration.

    Returns:
        The deployed model service.

    Raises:
        RuntimeError: If the active stack has no model deployer.
    """
    zenml_client = Client()
    current_model = get_step_context().model

    # Named `deployer` so the local does not shadow this step's own name.
    deployer = zenml_client.active_stack.model_deployer
    if deployer is None:
        raise RuntimeError(
            "The active stack has no model deployer; register one and add "
            "it to the stack before running this step."
        )

    vertex_deployment_config = VertexAIDeploymentConfig(
        location=location,
        name="zenml-vertex-quickstart",
        model_name=current_model.name,
        description="An example of deploying a model using the Vertex AI Model Deployer",
        model_id=model_registry_uri,
    )
    service = deployer.deploy_model(
        config=vertex_deployment_config,
        service_type=VertexDeploymentService.SERVICE_TYPE,
    )

    logger.info(
        f"The deployed service info: {deployer.get_model_server_info(service)}"
    )
    return service
@step
def model_promoter(accuracy: float, stage: str = "production") -> bool:
    """Model promoter step.

    This is an example of a step that conditionally promotes a model. It takes
    in the accuracy of the model and the stage to promote the model to. If the
    accuracy is below 80%, the model is not promoted. If it is above 80%, the
    model is promoted to the stage indicated in the parameters. If there is
    already a model in the indicated stage, the model with the higher accuracy
    is promoted.

    Args:
        accuracy: Accuracy of the model.
        stage: Which stage to promote the model to.

    Returns:
        Whether the model was promoted or not.
    """
    if accuracy < 0.8:
        logger.info(
            f"Model accuracy {accuracy*100:.2f}% is below 80% ! Not promoting model."
        )
        return False

    # Get the model in the current context
    current_model = get_step_context().model

    # Compare against the model currently occupying the requested stage
    client = Client()
    try:
        stage_model = client.get_model_version(current_model.name, stage)
        prod_accuracy = (
            stage_model.get_artifact("sklearn_classifier")
            .run_metadata["test_accuracy"]
            .value
        )
    except KeyError:
        # No comparable model in that stage: the current one is promoted.
        is_promoted = True
    else:
        # Only a strictly better model replaces the one already in the stage.
        is_promoted = accuracy > float(prod_accuracy)

    if is_promoted:
        current_model.set_stage(stage, force=True)
        logger.info(f"Model promoted to {stage}!")
    else:
        logger.info(
            f"Model accuracy {accuracy*100:.2f}% does not exceed the accuracy "
            f"of the current {stage} model. Not promoting model."
        )
    return is_promoted


# model_register.py


@step(enable_cache=False)
def model_register() -> Annotated[str, ArtifactConfig(name="model_registry_uri")]:
    """Model registration step.

    Registers the `sklearn_classifier` artifact of the model in the step
    context as a new version in the active stack's model registry.

    Returns:
        The `model_source_uri` of the registered model version.
    """
    # Get the current model from the context
    current_model = get_step_context().model

    client = Client()
    model_registry = client.active_stack.model_registry
    model_version = model_registry.register_model_version(
        name=current_model.name,
        version=str(current_model.version),
        model_source_uri=current_model.get_model_artifact("sklearn_classifier").uri,
        description="ZenML model registered after promotion",
    )
    logger.info(
        f"Model version {model_version.version} registered in Model Registry"
    )

    # NOTE(review): the output artifact is named "model_registry_uri" but the
    # value is the version's source URI; confirm this is what consumers expect.
    return model_version.model_source_uri
@step
def model_trainer(
    random_state: int = 42,
    test_size: float = 0.2,
    drop_na: bool = True,
    normalize: bool = True,
    target: str = "target",
    min_train_accuracy: float = 0.3,
    min_test_accuracy: float = 0.3,
) -> Tuple[
    Annotated[ClassifierMixin, ArtifactConfig(name="sklearn_classifier", is_model_artifact=True)],
    Annotated[float, ArtifactConfig(name="accuracy")],
]:
    """Train and evaluate an SGD classifier on the breast cancer dataset.

    Loads the scikit-learn breast cancer dataset, splits it into train and
    test parts, optionally drops NA rows and min-max scales the features,
    fits an `SGDClassifier` and logs train/test accuracy as metadata on the
    `sklearn_classifier` artifact.

    Args:
        random_state: Seed used for the train/test split.
        test_size: Fraction of the dataset held out for testing.
        drop_na: Whether to add the NA-dropping preprocessing stage.
        normalize: Whether to add the `MinMaxScaler` preprocessing stage.
        target: Name of target column in dataset.
        min_train_accuracy: Train accuracy below which a warning is logged.
        min_test_accuracy: Test accuracy below which a warning is logged.

    Returns:
        The fitted classifier and its accuracy on the test split.
    """
    # Load the dataset
    dataset = load_breast_cancer(as_frame=True).frame
    dataset.reset_index(drop=True, inplace=True)
    logger.info(f"Dataset with {len(dataset)} records loaded!")

    # Split the dataset
    dataset_trn, dataset_tst = train_test_split(
        dataset,
        test_size=test_size,
        random_state=random_state,
        shuffle=True,
    )

    # Separate features and target
    X_trn = dataset_trn.drop(columns=[target])
    y_trn = dataset_trn[target]
    X_tst = dataset_tst.drop(columns=[target])
    y_tst = dataset_tst[target]

    # Preprocess the data
    preprocess_steps = []
    if drop_na:
        preprocess_steps.append(("drop_na", NADropper()))
    if normalize:
        preprocess_steps.append(("normalize", MinMaxScaler()))
    preprocess_pipeline = Pipeline(preprocess_steps)

    X_trn = preprocess_pipeline.fit_transform(X_trn)
    X_tst = preprocess_pipeline.transform(X_tst)

    # Train the model
    model = SGDClassifier()
    logger.info(f"Training model {model}...")

    model.fit(X_trn, y_trn)

    # Evaluate the model
    trn_acc = model.score(X_trn, y_trn)
    tst_acc = model.score(X_tst, y_tst)
    logger.info(f"Train accuracy={trn_acc*100:.2f}%")
    logger.info(f"Test accuracy={tst_acc*100:.2f}%")

    # Collect every threshold violation first, then report all of them;
    # previously the warnings were emitted only when test accuracy was fine.
    messages = []
    if trn_acc < min_train_accuracy:
        messages.append(
            f"Train accuracy {trn_acc*100:.2f}% is below {min_train_accuracy*100:.2f}%!"
        )
    if tst_acc < min_test_accuracy:
        messages.append(
            f"Test accuracy {tst_acc*100:.2f}% is below {min_test_accuracy*100:.2f}%!"
        )
    for message in messages:
        logger.warning(message)

    log_artifact_metadata(
        metadata={
            "train_accuracy": float(trn_acc),
            "test_accuracy": float(tst_acc),
        },
        artifact_name="sklearn_classifier",
    )

    return model, float(tst_acc)


class NADropper:
    """Support class to drop NA values in sklearn Pipeline."""

    def fit(self, *args, **kwargs):  # noqa: D102
        # Stateless: nothing to learn, so fitting just returns the instance.
        return self

    def transform(self, X: Union[pd.DataFrame, pd.Series]):  # noqa: D102
        # NOTE(review): dropping rows here does not touch the target series
        # held by the caller — confirm inputs never contain NA values.
        return X.dropna()