Skip to content

Commit

Permalink
Azure sdk support for network acls (Azure#38511)
Browse files Browse the repository at this point in the history
* refactor code

* refactor code

* Adding IP based access control to SDK

* Added test and change logs

* Examples to choose one of three Public network access settings

* resolved circular dependency

* resolved circular dependency

* Added Ip based access control support to hub workspace

* updated changelog file

* removed ipallowlist dependencey

* resolved ManagedServiceIdentity version icompatibility issue

* fixed breaking test

* reformatted code

* reformatted code

* removed doc example

* code refactor

* Fixed Generate API Stubs issue

* Fixed Generate API Stubs issue

* refactor code

* add doc string

* add doc string

* add doc string

* add doc string

* add doc string

* add doc string
  • Loading branch information
mohitsinghnegi2 authored Nov 27, 2024
1 parent 2faf4cd commit 75b4da9
Show file tree
Hide file tree
Showing 12 changed files with 280 additions and 91 deletions.
1 change: 1 addition & 0 deletions sdk/ml/azure-ai-ml/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
## 1.23.0 (unreleased)

### Features Added
- Added support for IP-based access control to default and hub workspaces.
- Add support for additional include in spark component.

### Bugs Fixed
Expand Down
63 changes: 63 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/network_acls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from marshmallow import ValidationError, fields, post_load, validates_schema

from azure.ai.ml._schema.core.schema import PathAwareSchema
from azure.ai.ml.entities._workspace.network_acls import DefaultActionType, IPRule, NetworkAcls


class IPRuleSchema(PathAwareSchema):
"""Schema for IPRule."""

value = fields.Str(required=True)

@post_load
def make(self, data, **kwargs): # pylint: disable=unused-argument
"""Create an IPRule object from the marshmallow schema.
:param data: The data from which the IPRule is being loaded.
:type data: OrderedDict[str, Any]
:returns: An IPRule object.
:rtype: azure.ai.ml.entities._workspace.network_acls.NetworkAcls.IPRule
"""
return IPRule(**data)


class NetworkAclsSchema(PathAwareSchema):
"""Schema for NetworkAcls.
:param default_action: Specifies the default action when no IP rules are matched.
:type default_action: str
:param ip_rules: Rules governing the accessibility of a resource from a specific IP address or IP range.
:type ip_rules: Optional[List[IPRule]]
"""

default_action = fields.Str(required=True)
ip_rules = fields.List(fields.Nested(IPRuleSchema), allow_none=True)

@post_load
def make(self, data, **kwargs): # pylint: disable=unused-argument
"""Create a NetworkAcls object from the marshmallow schema.
:param data: The data from which the NetworkAcls is being loaded.
:type data: OrderedDict[str, Any]
:returns: A NetworkAcls object.
:rtype: azure.ai.ml.entities._workspace.network_acls.NetworkAcls
"""
return NetworkAcls(**data)

@validates_schema
def validate_schema(self, data, **kwargs): # pylint: disable=unused-argument
"""Validate the NetworkAcls schema.
:param data: The data to validate.
:type data: OrderedDict[str, Any]
:raises ValidationError: If the schema is invalid.
"""
if data["default_action"] not in set([DefaultActionType.DENY, DefaultActionType.ALLOW]):
raise ValidationError("Invalid value for default_action. Must be 'Deny' or 'Allow'.")

if data["default_action"] == DefaultActionType.DENY and not data.get("ip_rules"):
raise ValidationError("ip_rules must be provided when default_action is 'Deny'.")
2 changes: 2 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from azure.ai.ml._schema.core.schema import PathAwareSchema
from azure.ai.ml._schema.workspace.customer_managed_key import CustomerManagedKeySchema
from azure.ai.ml._schema.workspace.identity import IdentitySchema
from azure.ai.ml._schema.workspace.network_acls import NetworkAclsSchema
from azure.ai.ml._schema.workspace.networking import ManagedNetworkSchema
from azure.ai.ml._schema.workspace.serverless_compute import ServerlessComputeSettingsSchema
from azure.ai.ml._utils.utils import snake_to_pascal
Expand Down Expand Up @@ -36,6 +37,7 @@ class WorkspaceSchema(PathAwareSchema):
allowed_values=[PublicNetworkAccess.DISABLED, PublicNetworkAccess.ENABLED],
casing_transform=snake_to_pascal,
)
network_acls = NestedField(NetworkAclsSchema)
system_datastores_auth_mode = fields.Str()
identity = NestedField(IdentitySchema)
primary_user_assigned_identity = fields.Str()
Expand Down
119 changes: 31 additions & 88 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,12 @@
from ._assets.asset import Asset
from ._assets.environment import BuildContext, Environment
from ._assets.intellectual_property import IntellectualProperty
from ._assets.workspace_asset_reference import (
WorkspaceAssetReference as WorkspaceModelReference,
)
from ._assets.workspace_asset_reference import WorkspaceAssetReference as WorkspaceModelReference
from ._autogen_entities.models import (
AzureOpenAIDeployment,
MarketplacePlan,
MarketplaceSubscription,
ServerlessEndpoint,
MarketplacePlan,
)
from ._builders import Command, Parallel, Pipeline, Spark, Sweep
from ._component.command_component import CommandComponent
Expand All @@ -54,41 +52,21 @@
from ._component.pipeline_component import PipelineComponent
from ._component.spark_component import SparkComponent
from ._compute._aml_compute_node_info import AmlComputeNodeInfo
from ._compute._custom_applications import (
CustomApplications,
EndpointsSettings,
ImageSettings,
VolumeSettings,
)
from ._compute._custom_applications import CustomApplications, EndpointsSettings, ImageSettings, VolumeSettings
from ._compute._image_metadata import ImageMetadata
from ._compute._schedule import (
ComputePowerAction,
ComputeSchedules,
ComputeStartStopSchedule,
ScheduleState,
)
from ._compute._schedule import ComputePowerAction, ComputeSchedules, ComputeStartStopSchedule, ScheduleState
from ._compute._setup_scripts import ScriptReference, SetupScripts
from ._compute._usage import Usage, UsageName
from ._compute._vm_size import VmSize
from ._compute.aml_compute import AmlCompute, AmlComputeSshSettings
from ._compute.compute import Compute, NetworkSettings
from ._compute.compute_instance import (
AssignedUserConfiguration,
ComputeInstance,
ComputeInstanceSshSettings,
)
from ._compute.compute_instance import AssignedUserConfiguration, ComputeInstance, ComputeInstanceSshSettings
from ._compute.kubernetes_compute import KubernetesCompute
from ._compute.synapsespark_compute import (
AutoPauseSettings,
AutoScaleSettings,
SynapseSparkCompute,
)
from ._compute.synapsespark_compute import AutoPauseSettings, AutoScaleSettings, SynapseSparkCompute
from ._compute.unsupported_compute import UnsupportedCompute
from ._compute.virtual_machine_compute import (
VirtualMachineCompute,
VirtualMachineSshSettings,
)
from ._compute.virtual_machine_compute import VirtualMachineCompute, VirtualMachineSshSettings
from ._credentials import (
AadCredentialConfiguration,
AccessKeyConfiguration,
AccountKeyConfiguration,
AmlTokenConfiguration,
Expand All @@ -97,7 +75,6 @@
IdentityConfiguration,
ManagedIdentityConfiguration,
NoneCredentialConfiguration,
AadCredentialConfiguration,
PatTokenConfiguration,
SasTokenConfiguration,
ServicePrincipalConfiguration,
Expand All @@ -107,11 +84,7 @@
from ._data_import.data_import import DataImport
from ._data_import.schedule import ImportDataSchedule
from ._datastore.adls_gen1 import AzureDataLakeGen1Datastore
from ._datastore.azure_storage import (
AzureBlobDatastore,
AzureDataLakeGen2Datastore,
AzureFileDatastore,
)
from ._datastore.azure_storage import AzureBlobDatastore, AzureDataLakeGen2Datastore, AzureFileDatastore
from ._datastore.datastore import Datastore
from ._datastore.one_lake import OneLakeArtifact, OneLakeDatastore
from ._deployment.batch_deployment import BatchDeployment
Expand All @@ -121,11 +94,7 @@
from ._deployment.data_asset import DataAsset
from ._deployment.data_collector import DataCollector
from ._deployment.deployment_collection import DeploymentCollection
from ._deployment.deployment_settings import (
BatchRetrySettings,
OnlineRequestSettings,
ProbeSettings,
)
from ._deployment.deployment_settings import BatchRetrySettings, OnlineRequestSettings, ProbeSettings
from ._deployment.model_batch_deployment import ModelBatchDeployment
from ._deployment.model_batch_deployment_settings import ModelBatchDeploymentSettings
from ._deployment.online_deployment import (
Expand All @@ -134,22 +103,16 @@
ManagedOnlineDeployment,
OnlineDeployment,
)
from ._deployment.pipeline_component_batch_deployment import (
PipelineComponentBatchDeployment,
)
from ._deployment.pipeline_component_batch_deployment import PipelineComponentBatchDeployment
from ._deployment.request_logging import RequestLogging
from ._deployment.resource_requirements_settings import ResourceRequirementsSettings
from ._deployment.scale_settings import (
DefaultScaleSettings,
OnlineScaleSettings,
TargetUtilizationScaleSettings,
)
from ._deployment.scale_settings import DefaultScaleSettings, OnlineScaleSettings, TargetUtilizationScaleSettings
from ._endpoint.batch_endpoint import BatchEndpoint
from ._endpoint.endpoint import Endpoint
from ._endpoint.online_endpoint import (
EndpointAadToken,
EndpointAuthKeys,
EndpointAuthToken,
EndpointAadToken,
KubernetesOnlineEndpoint,
ManagedOnlineEndpoint,
OnlineEndpoint,
Expand All @@ -158,41 +121,26 @@
from ._feature_set.feature import Feature
from ._feature_set.feature_set_backfill_metadata import FeatureSetBackfillMetadata
from ._feature_set.feature_set_backfill_request import FeatureSetBackfillRequest
from ._feature_set.feature_set_materialization_metadata import (
FeatureSetMaterializationMetadata,
)
from ._feature_set.feature_set_materialization_metadata import FeatureSetMaterializationMetadata
from ._feature_set.feature_set_specification import FeatureSetSpecification
from ._feature_set.feature_window import FeatureWindow
from ._feature_set.materialization_compute_resource import (
MaterializationComputeResource,
)
from ._feature_set.materialization_compute_resource import MaterializationComputeResource
from ._feature_set.materialization_settings import MaterializationSettings
from ._feature_set.materialization_type import MaterializationType
from ._feature_store.feature_store import FeatureStore
from ._feature_store.materialization_store import MaterializationStore
from ._feature_store_entity.data_column import DataColumn
from ._feature_store_entity.data_column_type import DataColumnType
from ._feature_store_entity.feature_store_entity import FeatureStoreEntity
from ._indexes import (
AzureAISearchConfig,
IndexDataSource,
GitSource,
LocalSource,
)
from ._indexes import AzureAISearchConfig, GitSource, IndexDataSource, LocalSource
from ._indexes import ModelConfiguration as IndexModelConfiguration
from ._job.command_job import CommandJob
from ._job.compute_configuration import ComputeConfiguration
from ._job.input_port import InputPort
from ._job.job import Job
from ._job.job_limits import CommandJobLimits
from ._job.job_resource_configuration import JobResourceConfiguration
from ._job.job_service import (
JobService,
JupyterLabJobService,
SshJobService,
TensorBoardJobService,
VsCodeJobService,
)
from ._job.job_service import JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService
from ._job.parallel.parallel_task import ParallelTask
from ._job.parallel.retry_settings import RetrySettings
from ._job.parameterized_command import ParameterizedCommand
Expand All @@ -208,12 +156,7 @@
from ._monitoring.alert_notification import AlertNotification
from ._monitoring.compute import ServerlessSparkCompute
from ._monitoring.definition import MonitorDefinition
from ._monitoring.input_data import (
FixedInputData,
MonitorInputData,
StaticInputData,
TrailingInputData,
)
from ._monitoring.input_data import FixedInputData, MonitorInputData, StaticInputData, TrailingInputData
from ._monitoring.schedule import MonitorSchedule
from ._monitoring.signals import (
BaselineDataRange,
Expand Down Expand Up @@ -261,22 +204,24 @@
from ._schedule.trigger import CronTrigger, RecurrencePattern, RecurrenceTrigger
from ._system_data import SystemData
from ._validation import ValidationResult
from ._workspace._ai_workspaces.hub import Hub
from ._workspace._ai_workspaces.project import Project
from ._workspace.compute_runtime import ComputeRuntime
from ._workspace.connections.workspace_connection import WorkspaceConnection
from ._workspace.connections.connection_subtypes import (
AzureBlobStoreConnection,
MicrosoftOneLakeConnection,
AzureOpenAIConnection,
AzureAIServicesConnection,
APIKeyConnection,
AzureAISearchConnection,
AzureAIServicesConnection,
AzureBlobStoreConnection,
AzureContentSafetyConnection,
AzureOpenAIConnection,
AzureSpeechServicesConnection,
APIKeyConnection,
MicrosoftOneLakeConnection,
OpenAIConnection,
SerpConnection,
ServerlessConnection,
)
from ._workspace.connections.one_lake_artifacts import OneLakeConnectionArtifact
from ._workspace.connections.workspace_connection import WorkspaceConnection
from ._workspace.customer_managed_key import CustomerManagedKey
from ._workspace.diagnose import (
DiagnoseRequestProperties,
Expand All @@ -286,6 +231,7 @@
DiagnoseWorkspaceParameters,
)
from ._workspace.feature_store_settings import FeatureStoreSettings
from ._workspace.network_acls import DefaultActionType, IPRule, NetworkAcls
from ._workspace.networking import (
FqdnDestination,
IsolationMode,
Expand All @@ -298,13 +244,7 @@
from ._workspace.private_endpoint import EndpointConnection, PrivateEndpoint
from ._workspace.serverless_compute import ServerlessComputeSettings
from ._workspace.workspace import Workspace
from ._workspace._ai_workspaces.hub import Hub
from ._workspace._ai_workspaces.project import Project
from ._workspace.workspace_keys import (
ContainerRegistryCredential,
NotebookAccessKeys,
WorkspaceKeys,
)
from ._workspace.workspace_keys import ContainerRegistryCredential, NotebookAccessKeys, WorkspaceKeys

__all__ = [
"Resource",
Expand Down Expand Up @@ -357,6 +297,9 @@
"Model",
"ModelBatchDeployment",
"ModelBatchDeploymentSettings",
"IPRule",
"DefaultActionType",
"NetworkAcls",
"Workspace",
"WorkspaceKeys",
"WorkspaceConnection",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from azure.ai.ml._schema.workspace import HubSchema
from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml.constants._common import WorkspaceKind
from azure.ai.ml.entities import CustomerManagedKey, Workspace
from azure.ai.ml.entities._credentials import IdentityConfiguration
from azure.ai.ml.entities._workspace.customer_managed_key import CustomerManagedKey
from azure.ai.ml.entities._workspace.network_acls import NetworkAcls
from azure.ai.ml.entities._workspace.networking import ManagedNetwork
from azure.ai.ml.entities._workspace.workspace import Workspace


@experimental
Expand Down Expand Up @@ -54,6 +56,8 @@ class Hub(Workspace):
:param public_network_access: Whether to allow public endpoint connectivity.
when a workspace is private link enabled.
:type public_network_access: str
:param network_acls: The network access control list (ACL) settings of the workspace.
:type network_acls: ~azure.ai.ml.entities.NetworkAcls
:param identity: The hub's Managed Identity (user assigned, or system assigned).
:type identity: ~azure.ai.ml.entities.IdentityConfiguration
:param primary_user_assigned_identity: The hub's primary user assigned identity.
Expand Down Expand Up @@ -92,6 +96,7 @@ def __init__(
container_registry: Optional[str] = None,
customer_managed_key: Optional[CustomerManagedKey] = None,
public_network_access: Optional[str] = None,
network_acls: Optional[NetworkAcls] = None,
identity: Optional[IdentityConfiguration] = None,
primary_user_assigned_identity: Optional[str] = None,
enable_data_isolation: bool = False,
Expand All @@ -115,6 +120,7 @@ def __init__(
resource_group=resource_group,
customer_managed_key=customer_managed_key,
public_network_access=public_network_access,
network_acls=network_acls,
identity=identity,
primary_user_assigned_identity=primary_user_assigned_identity,
managed_network=managed_network,
Expand Down Expand Up @@ -152,6 +158,7 @@ def _from_rest_object(cls, rest_obj: RestWorkspace, v2_service_context: Optional
managed_network=workspace_object.managed_network,
customer_managed_key=workspace_object.customer_managed_key,
public_network_access=workspace_object.public_network_access,
network_acls=workspace_object.network_acls,
identity=workspace_object.identity,
primary_user_assigned_identity=workspace_object.primary_user_assigned_identity,
storage_account=rest_obj.storage_account,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from typing import Any, Dict, Optional

from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml._schema.workspace import ProjectSchema
from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml.constants._common import WorkspaceKind
from azure.ai.ml.entities import Workspace
from azure.ai.ml.entities._workspace.workspace import Workspace


# Effectively a lightweight wrapper around a v2 SDK workspace
Expand Down
Loading

0 comments on commit 75b4da9

Please sign in to comment.