Commit
add spark cluster test
lu-wang-dl committed May 30, 2024
1 parent 0a2667a commit 7c10678
Showing 2 changed files with 33 additions and 16 deletions.
4 changes: 4 additions & 0 deletions test/discover_2_gpu.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+# This script is used in the Spark GPU cluster config for discovering available GPUs.
+echo "{\"name\":\"gpu\",\"addresses\":[\"0\",\"1\"]}"
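Note: Spark runs a discovery script like this on each worker and expects a single JSON object on stdout giving the resource name and its device addresses, which it parses into a ResourceInformation. A minimal sketch of exercising that contract from Python (the check_discovery_output helper is illustrative, not part of this commit):

import json
import subprocess


def check_discovery_output(script_path):
    # Run the discovery script much as Spark would and parse its stdout.
    out = subprocess.run(
        ["bash", script_path], capture_output=True, text=True, check=True
    ).stdout
    info = json.loads(out)
    # Spark expects a resource name plus the list of device addresses.
    assert info["name"] == "gpu"
    assert info["addresses"] == ["0", "1"]
    return info


# Example: check_discovery_output("test/discover_2_gpu.sh")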
45 changes: 29 additions & 16 deletions test/test_backend.py
@@ -1,14 +1,16 @@
 import warnings
 import os
 from packaging.version import Version
-from unittest import mock
+import unittest
+from unittest.mock import MagicMock
 
+import pyspark
 from pyspark.sql import SparkSession
 
 from joblibspark.backend import SparkDistributedBackend
 
-class TestLocalSparkCluster:
 
+class TestLocalSparkCluster(unittest.TestCase):
     @classmethod
     def setup_class(cls):
         cls.spark = (
@@ -39,41 +41,52 @@ def test_resource_profile_supported(self):
         assert not backend._support_stage_scheduling
 
 
-class TestBasicSparkCluster:
-    spark = None
+class TestBasicSparkCluster(unittest.TestCase):
     @classmethod
     def setup_class(cls):
         cls.num_cpus_per_spark_task = 1
         cls.num_gpus_per_spark_task = 1
+        gpu_discovery_script_path = os.path.join(
+            os.path.dirname(os.path.abspath(__file__)), "discover_2_gpu.sh"
+        )
 
         cls.spark = (
             SparkSession.builder.master("local-cluster[1, 2, 1024]")
-            .config("spark.task.cpus", "1")
-            .config("spark.task.maxFailures", "1")
-            .getOrCreate()
+            .config("spark.task.cpus", "1")
+            .config("spark.task.resource.gpu.amount", "1")
+            .config("spark.executor.cores", "2")
+            .config("spark.worker.resource.gpu.amount", "2")
+            .config("spark.executor.resource.gpu.amount", "2")
+            .config("spark.task.maxFailures", "1")
+            .config(
+                "spark.worker.resource.gpu.discoveryScript", gpu_discovery_script_path
+            )
+            .getOrCreate()
         )
 
     @classmethod
     def teardown_class(cls):
         cls.spark.stop()
 
+    @unittest.skipIf(Version(pyspark.__version__).release < (3, 4, 0),
+                     "Resource group is only supported since spark 3.4.0")
     def test_resource_profile(self):
         backend = SparkDistributedBackend(
             num_cpus_per_spark_task=self.num_cpus_per_spark_task,
             num_gpus_per_spark_task=self.num_gpus_per_spark_task)
 
-        if Version(pyspark.__version__).release >= (3, 4, 0):
-            assert backend._support_stage_scheduling
+        assert backend._support_stage_scheduling
 
-            resource_group = backend._resource_profile
-            assert resource_group.taskResources['cpus'].amount == 1.0
-            assert resource_group.taskResources['gpu'].amount == 1.0
+        resource_group = backend._resource_profile
+        assert resource_group.taskResources['cpus'].amount == 1.0
+        assert resource_group.taskResources['gpu'].amount == 1.0
 
+    @unittest.skipIf(Version(pyspark.__version__).release < (3, 4, 0),
+                     "Resource group is only supported since spark 3.4.0")
     def test_resource_with_default(self):
         backend = SparkDistributedBackend()
-        if Version(pyspark.__version__).release >= (3, 4, 0):
-            assert backend._support_stage_scheduling
-
-            resource_group = backend._resource_profile
-            assert resource_group.taskResources['cpus'].amount == 1.0
+        assert backend._support_stage_scheduling
+
+        resource_group = backend._resource_profile
+        assert resource_group.taskResources['cpus'].amount == 1.0

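For reference, the _resource_profile these assertions inspect corresponds to PySpark's stage-level scheduling API, which this code path supports from Spark 3.4 onward. A rough sketch of building an equivalent task-level ResourceProfile directly, assuming one CPU and one GPU per task as configured in the test above:

from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests

# One CPU and one GPU per task, mirroring num_cpus_per_spark_task=1
# and num_gpus_per_spark_task=1 in the test above.
task_requests = TaskResourceRequests().cpus(1).resource("gpu", 1)
profile = ResourceProfileBuilder().require(task_requests).build  # .build is a property

print(profile.taskResources["cpus"].amount)  # 1.0
print(profile.taskResources["gpu"].amount)   # 1.0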
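End to end, the SparkDistributedBackend under test is what joblib-spark registers behind joblib's "spark" backend. A minimal usage sketch along the lines of the project README (the square function and the n_jobs value are arbitrary examples, not from this commit):

from joblib import Parallel, delayed, parallel_backend
from joblibspark import register_spark

register_spark()  # make the "spark" backend available to joblib


def square(x):
    return x * x


# Each joblib task is executed as a Spark task on the cluster.
with parallel_backend("spark", n_jobs=2):
    results = Parallel()(delayed(square)(i) for i in range(10))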