diff --git a/core/src/main/scala/org/apache/spark/eventhubs/EventHubsUtils.scala b/core/src/main/scala/org/apache/spark/eventhubs/EventHubsUtils.scala
index 826ddca86..d344fb05d 100644
--- a/core/src/main/scala/org/apache/spark/eventhubs/EventHubsUtils.scala
+++ b/core/src/main/scala/org/apache/spark/eventhubs/EventHubsUtils.scala
@@ -40,9 +40,9 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.streaming.StreamingContext
 import org.apache.spark.streaming.api.java.{ JavaInputDStream, JavaStreamingContext }
 import org.apache.spark.streaming.eventhubs.EventHubsDirectDStream
-import org.apache.spark.{ SparkContext, TaskContext }
+import org.apache.spark.{ SparkContext, SparkEnv, TaskContext }
 import org.apache.spark.rpc.RpcEndpointRef
-import org.apache.spark.SparkEnv
+import org.apache.spark.util.RpcUtils
 
 /**
  * Helper to create Direct DStreams which consume events from Event Hubs.
@@ -53,12 +53,23 @@ object EventHubsUtils extends Logging {
 
   private def createRpcEndpoint() = {
     if (partitionPerformanceReceiverRef == null) {
-      // RPC endpoint for partition performance communication in the driver
-      val partitionsStatusTracker = PartitionsStatusTracker.getPartitionStatusTracker
-      val partitionPerformanceReceiver: PartitionPerformanceReceiver =
-        new PartitionPerformanceReceiver(SparkEnv.get.rpcEnv, partitionsStatusTracker)
-      partitionPerformanceReceiverRef = SparkEnv.get.rpcEnv
-        .setupEndpoint(PartitionPerformanceReceiver.ENDPOINT_NAME, partitionPerformanceReceiver)
+      // RPC endpoint for partition performance communication in the driver.
+      // Reuse an endpoint that is already registered; otherwise create one.
+      try {
+        partitionPerformanceReceiverRef = RpcUtils.makeDriverRef(
+          PartitionPerformanceReceiver.ENDPOINT_NAME,
+          SparkEnv.get.conf,
+          SparkEnv.get.rpcEnv)
+        logInfo(
+          "Found an existing partitionPerformanceReceiverRef on the driver; reusing it rather than creating a new one")
+      } catch {
+        case _: Exception =>
+          val partitionsStatusTracker = PartitionsStatusTracker.getPartitionStatusTracker
+          val partitionPerformanceReceiver: PartitionPerformanceReceiver =
+            new PartitionPerformanceReceiver(SparkEnv.get.rpcEnv, partitionsStatusTracker)
+          partitionPerformanceReceiverRef = SparkEnv.get.rpcEnv
+            .setupEndpoint(PartitionPerformanceReceiver.ENDPOINT_NAME, partitionPerformanceReceiver)
+      }
     }
   }
 
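
The second hunk follows a lookup-or-register pattern: `RpcUtils.makeDriverRef` throws when no endpoint is registered under the given name, and the catch branch falls back to `rpcEnv.setupEndpoint`. A minimal, self-contained sketch of that pattern is below; it is not part of this change. The names `SketchReceiver`, `getOrCreateRef`, and `"sketch-endpoint"` are hypothetical, and because Spark's `RpcEnv`/`RpcUtils` APIs are `private[spark]`, the sketch assumes it lives under the `org.apache.spark` package tree, as the connector's own classes do.

```scala
// Hypothetical package chosen only so private[spark] RPC APIs are visible.
package org.apache.spark.eventhubs.sketch

import org.apache.spark.SparkEnv
import org.apache.spark.rpc.{ RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint }
import org.apache.spark.util.RpcUtils

// Hypothetical endpoint; a real one would carry state, such as the
// PartitionsStatusTracker in the diff above.
private class SketchReceiver(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
  override def receive: PartialFunction[Any, Unit] = {
    case _ => // a real endpoint would dispatch on message types here
  }
}

object SketchReceiver {
  val EndpointName = "sketch-endpoint"

  // First try to look up an endpoint already registered on the driver;
  // makeDriverRef throws if none exists under this name, and the catch
  // branch registers a fresh endpoint on the local RpcEnv instead.
  def getOrCreateRef(): RpcEndpointRef = {
    val env = SparkEnv.get
    try {
      RpcUtils.makeDriverRef(EndpointName, env.conf, env.rpcEnv)
    } catch {
      case _: Exception =>
        env.rpcEnv.setupEndpoint(EndpointName, new SketchReceiver(env.rpcEnv))
    }
  }
}
```

This is why the guard `if (partitionPerformanceReceiverRef == null)` alone was not enough: the field is null in every fresh JVM classloader, so probing the driver's RpcEnv first avoids a duplicate registration of the same endpoint name.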