[KYUUBI #6223] Fix Scala interpreter cannot access spark.jars issue
# 🔍 Description
## Issue References 🔗

This pull request fixes #6223

Even when the user specifies `spark.jars`, the classes in those jars cannot be accessed from Scala code run through the interpreter. A minimal repro is sketched below.
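For illustration (the jar path, URL shape, and user name are assumptions; this mirrors the regression test added below), suppose `/tmp/udf.jar` contains an object `test.utils.Math` and the session is created with `spark.jars` pointing at it:

```scala
// Illustrative repro; /tmp/udf.jar and the JDBC URL shape are assumptions.
import java.sql.DriverManager
import org.apache.kyuubi.jdbc.hive.KyuubiStatement

Class.forName("org.apache.kyuubi.jdbc.KyuubiHiveDriver") // register the driver
val conn = DriverManager.getConnection(
  "jdbc:hive2://localhost:10009/default?spark.jars=/tmp/udf.jar", "user", "")
val stmt = conn.createStatement().asInstanceOf[KyuubiStatement]
// Before this fix the Scala interpreter cannot resolve the class, even though
// the jar has been distributed to the engine:
//   error: object utils is not a member of package test
stmt.executeScala("import test.utils.Math")
```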

## Describe Your Solution 🔧

Add the jars held by Spark's `MutableURLClassLoader` to the REPL classpath via the interpreter's `-classpath` argument, as sketched below.
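A condensed sketch of the approach (it mirrors the patch below, not the verbatim code): walk the context class-loader chain until Spark's `MutableURLClassLoader`, keep its local `file:` jar URLs — which include everything distributed via `spark.jars` — and pass them to the Scala compiler through `-classpath`:

```scala
import java.io.File
import java.net.URL
import org.apache.spark.util.MutableURLClassLoader

def getAllJars(start: ClassLoader): Array[URL] = {
  var cl: ClassLoader = start
  var jars = Array.empty[URL]
  while (cl != null) {
    cl match {
      case loader: MutableURLClassLoader =>
        // spark.jars entries end up as file: URLs in Spark's mutable class loader
        jars = loader.getURLs.filter(u => u.getProtocol == "file" && new File(u.getPath).isFile)
        cl = null // stop at the first MutableURLClassLoader
      case _ => cl = cl.getParent
    }
  }
  jars
}

// The collected URLs are then joined into a single interpreter argument:
// List("-Yrepl-class-based", ..., "-classpath", jars.mkString(File.pathSeparator))
```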

## Types of changes 🔖

- [x] Bugfix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)

## Test Plan 🧪

#### Behavior Without This Pull Request ⚰️

Scala statements executed through the engine's Scala interpreter fail to resolve classes that ship in jars specified via `spark.jars`.

#### Behavior With This Pull Request 🎉

Those jars are on the interpreter's classpath, so Scala statements can import and use their classes.

#### Related Unit Tests
Covered by the new UT `Scala REPL should see jars added by spark.jars` in `KyuubiOperationPerConnectionSuite`.

---

# Checklist 📝

- [ ] This patch was not authored or co-authored using [Generative Tooling](https://www.apache.org/legal/generative-tooling.html)

**Be nice. Be informative.**

Closes #6235 from turboFei/scala_repl_urls.

Closes #6223

3445026 [Wang, Fei] scala 2.13
cc6e289 [Wang, Fei] todo
a8b3731 [Wang, Fei] refine
65b438c [Wang, Fei] remove scala reflect check
eb257c7 [Wang, Fei] using -classpath
e1c6f0e [Wang, Fei] revert 2.13
15d3766 [Wang, Fei] repl
41ebe10 [Wang, Fei] fix ut
ed5d344 [Wang, Fei] info
1cdd82a [Wang, Fei] comment
aa4292d [Wang, Fei] fix
a10cfa5 [Wang, Fei] ut
63fdb88 [Wang, Fei] Use global.classPath.asURLs instead of class loader urls

Authored-by: Wang, Fei <[email protected]>
Signed-off-by: Cheng Pan <[email protected]>
turboFei authored and pan3793 committed Apr 3, 2024
1 parent 9649884 commit 9b618c9
Showing 3 changed files with 85 additions and 43 deletions.
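The three changed files, inferred from the hunks below: the Scala 2.12 and Scala 2.13 variants of `KyuubiSparkILoop.scala`, which receive the classpath fix, and `KyuubiOperationPerConnectionSuite.scala`, which gains the regression test.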
KyuubiSparkILoop.scala (Scala 2.12 variant)
@@ -18,6 +18,7 @@
 package org.apache.kyuubi.engine.spark.repl
 
 import java.io.{ByteArrayOutputStream, File, PrintWriter}
+import java.net.URL
 import java.util.concurrent.locks.ReentrantLock
 
 import scala.tools.nsc.Settings
@@ -28,47 +29,35 @@ import org.apache.spark.repl.SparkILoop
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.util.MutableURLClassLoader
 
-import org.apache.kyuubi.Utils
+import org.apache.kyuubi.{Logging, Utils}
 
 private[spark] case class KyuubiSparkILoop private (
     spark: SparkSession,
     output: ByteArrayOutputStream)
-  extends SparkILoop(None, new PrintWriter(output)) {
+  extends SparkILoop(None, new PrintWriter(output)) with Logging {
   import KyuubiSparkILoop._
 
   val result = new DataFrameHolder(spark)
 
   private def initialize(): Unit = withLockRequired {
+    val currentClassLoader = Thread.currentThread().getContextClassLoader
+    val interpreterClasspath = getAllJars(currentClassLoader).mkString(File.pathSeparator)
+    info(s"Adding jars to Scala interpreter's class path: $interpreterClasspath")
     settings = new Settings
     val interpArguments = List(
       "-Yrepl-class-based",
       "-Yrepl-outdir",
-      s"${spark.sparkContext.getConf.get("spark.repl.class.outputDir")}")
+      s"${spark.sparkContext.getConf.get("spark.repl.class.outputDir")}",
+      "-classpath",
+      interpreterClasspath)
     settings.processArguments(interpArguments, processAll = true)
     settings.usejavacp.value = true
-    val currentClassLoader = Thread.currentThread().getContextClassLoader
     settings.embeddedDefaults(currentClassLoader)
     this.createInterpreter()
    this.initializeSynchronous()
     try {
       this.compilerClasspath
       this.ensureClassLoader()
-      var classLoader: ClassLoader = Thread.currentThread().getContextClassLoader
-      while (classLoader != null) {
-        classLoader match {
-          case loader: MutableURLClassLoader =>
-            val allJars = loader.getURLs.filter { u =>
-              val file = new File(u.getPath)
-              u.getProtocol == "file" && file.isFile &&
-                file.getName.contains("scala-lang_scala-reflect")
-            }
-            this.addUrlsToClassPath(allJars: _*)
-            classLoader = null
-          case _ =>
-            classLoader = classLoader.getParent
-        }
-      }
-
       this.addUrlsToClassPath(
         classOf[DataFrameHolder].getProtectionDomain.getCodeSource.getLocation)
     } finally {
@@ -97,6 +86,24 @@ private[spark] case class KyuubiSparkILoop private (
     }
   }
 
+  private def getAllJars(currentClassLoader: ClassLoader): Array[URL] = {
+    var classLoader: ClassLoader = currentClassLoader
+    var allJars = Array.empty[URL]
+    while (classLoader != null) {
+      classLoader match {
+        case loader: MutableURLClassLoader =>
+          allJars = loader.getURLs.filter { u =>
+            // TODO: handle SPARK-47475 since Spark 4.0.0 in the future
+            u.getProtocol == "file" && new File(u.getPath).isFile
+          }
+          classLoader = null
+        case _ =>
+          classLoader = classLoader.getParent
+      }
+    }
+    allJars
+  }
+
   def getResult(statementId: String): DataFrame = result.get(statementId)
 
   def clearResult(statementId: String): Unit = result.unset(statementId)
KyuubiSparkILoop.scala (Scala 2.13 variant)
@@ -18,6 +18,7 @@
 package org.apache.kyuubi.engine.spark.repl
 
 import java.io.{ByteArrayOutputStream, File, PrintWriter}
+import java.net.URL
 import java.util.concurrent.locks.ReentrantLock
 
 import scala.tools.nsc.Settings
@@ -28,48 +29,36 @@ import org.apache.spark.repl.SparkILoop
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.util.MutableURLClassLoader
 
-import org.apache.kyuubi.Utils
+import org.apache.kyuubi.{Logging, Utils}
 
 private[spark] case class KyuubiSparkILoop private (
     spark: SparkSession,
     output: ByteArrayOutputStream)
-  extends SparkILoop(null, new PrintWriter(output)) {
+  extends SparkILoop(null, new PrintWriter(output)) with Logging {
   import KyuubiSparkILoop._
 
   val result = new DataFrameHolder(spark)
 
   private def initialize(): Unit = withLockRequired {
+    val currentClassLoader = Thread.currentThread().getContextClassLoader
+    val interpreterClasspath = getAllJars(currentClassLoader).mkString(File.pathSeparator)
+    info(s"Adding jars to Scala interpreter's class path: $interpreterClasspath")
     val settings = new Settings
     val interpArguments = List(
       "-Yrepl-class-based",
       "-Yrepl-outdir",
-      s"${spark.sparkContext.getConf.get("spark.repl.class.outputDir")}")
+      s"${spark.sparkContext.getConf.get("spark.repl.class.outputDir")}",
+      "-classpath",
+      interpreterClasspath)
     settings.processArguments(interpArguments, processAll = true)
     settings.usejavacp.value = true
-    val currentClassLoader = Thread.currentThread().getContextClassLoader
     settings.embeddedDefaults(currentClassLoader)
     this.createInterpreter(settings)
     val iMain = this.intp.asInstanceOf[IMain]
     iMain.initializeCompiler()
     try {
       this.compilerClasspath
       iMain.ensureClassLoader()
-      var classLoader: ClassLoader = Thread.currentThread().getContextClassLoader
-      while (classLoader != null) {
-        classLoader match {
-          case loader: MutableURLClassLoader =>
-            val allJars = loader.getURLs.filter { u =>
-              val file = new File(u.getPath)
-              u.getProtocol == "file" && file.isFile &&
-                file.getName.contains("scala-lang_scala-reflect")
-            }
-            this.addUrlsToClassPath(allJars: _*)
-            classLoader = null
-          case _ =>
-            classLoader = classLoader.getParent
-        }
-      }
-
       this.addUrlsToClassPath(
         classOf[DataFrameHolder].getProtectionDomain.getCodeSource.getLocation)
     } finally {
@@ -98,6 +87,23 @@ private[spark] case class KyuubiSparkILoop private (
     }
   }
 
+  private def getAllJars(currentClassLoader: ClassLoader): Array[URL] = {
+    var classLoader: ClassLoader = currentClassLoader
+    var allJars = Array.empty[URL]
+    while (classLoader != null) {
+      classLoader match {
+        case loader: MutableURLClassLoader =>
+          allJars = loader.getURLs.filter { u =>
+            u.getProtocol == "file" && new File(u.getPath).isFile
+          }
+          classLoader = null
+        case _ =>
+          classLoader = classLoader.getParent
+      }
+    }
+    allJars
+  }
+
   def getResult(statementId: String): DataFrame = result.get(statementId)
 
   def clearResult(statementId: String): Unit = result.unset(statementId)
KyuubiOperationPerConnectionSuite.scala
@@ -19,18 +19,19 @@ package org.apache.kyuubi.operation
 
 import java.sql.SQLException
 import java.util
-import java.util.Properties
+import java.util.{Properties, UUID}
 
 import scala.collection.JavaConverters._
 
+import org.apache.hadoop.fs.Path
 import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
 
-import org.apache.kyuubi.{KYUUBI_VERSION, WithKyuubiServer}
+import org.apache.kyuubi.{KYUUBI_VERSION, Utils, WithKyuubiServer}
 import org.apache.kyuubi.config.{KyuubiConf, KyuubiReservedKeys}
 import org.apache.kyuubi.config.KyuubiConf.SESSION_CONF_ADVISOR
 import org.apache.kyuubi.engine.{ApplicationManagerInfo, ApplicationState}
 import org.apache.kyuubi.jdbc.KyuubiHiveDriver
-import org.apache.kyuubi.jdbc.hive.{KyuubiConnection, KyuubiSQLException}
+import org.apache.kyuubi.jdbc.hive.{KyuubiConnection, KyuubiSQLException, KyuubiStatement}
 import org.apache.kyuubi.metrics.{MetricsConstants, MetricsSystem}
 import org.apache.kyuubi.plugin.SessionConfAdvisor
 import org.apache.kyuubi.session.{KyuubiSessionImpl, KyuubiSessionManager, SessionHandle, SessionType}
@@ -346,6 +347,34 @@ class KyuubiOperationPerConnectionSuite extends WithKyuubiServer with HiveJDBCTe
       }
     }
   }
+
+  test("Scala REPL should see jars added by spark.jars") {
+    val jarDir = Utils.createTempDir().toFile
+    val udfCode =
+      """
+        |package test.utils
+        |
+        |object Math {
+        |  def add(x: Int, y: Int): Int = x + y
+        |}
+        |
+        |""".stripMargin
+    val jarFile = UserJarTestUtils.createJarFile(
+      udfCode,
+      "test",
+      s"test-function-${UUID.randomUUID}.jar",
+      jarDir.toString)
+    val localPath = new Path(jarFile.getAbsolutePath)
+    withSessionConf()(Map("spark.jars" -> localPath.toString))() {
+      withJdbcStatement() { statement =>
+        val kyuubiStatement = statement.asInstanceOf[KyuubiStatement]
+        kyuubiStatement.executeScala("import test.utils.{Math => TMath}")
+        val rs = kyuubiStatement.executeScala("println(TMath.add(1,2))")
+        rs.next()
+        assert(rs.getString(1) === "3")
+      }
+    }
+  }
 }
 
 class TestSessionConfAdvisor extends SessionConfAdvisor {
