diff --git a/joblibspark/utils.py b/joblibspark/utils.py index 50fa7ed..8c32775 100644 --- a/joblibspark/utils.py +++ b/joblibspark/utils.py @@ -46,9 +46,12 @@ def create_resource_profile(num_cpus_worker_node, num_gpus_worker_node): """ resource_profile = None if Version(pyspark.__version__).release > (3, 1, 0): - from pyspark.resource.profile import ResourceProfileBuilder - from pyspark.resource.requests import TaskResourceRequests - + try: + from pyspark.resource.profile import ResourceProfileBuilder + from pyspark.resource.requests import TaskResourceRequests + except ImportError: + pass + task_res_req = TaskResourceRequests().cpus(num_cpus_worker_node) if num_gpus_worker_node > 0: task_res_req = task_res_req.resource("gpu", num_gpus_worker_node)