Skip to content

Commit

Permalink
Added this test for other databases (#954)
Browse files Browse the repository at this point in the history
Refactor inferNullability test logic

Extracted common inferNullability test logic to `commonTestScenarios.kt` for reusability. Removed redundant code from individual test files and added necessary imports to support the new structure.
  • Loading branch information
zaleslaw authored Nov 19, 2024
1 parent c4bb29c commit 7a895e2
Show file tree
Hide file tree
Showing 10 changed files with 172 additions and 368 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package org.jetbrains.kotlinx.dataframe.io

import io.kotest.matchers.shouldBe
import org.intellij.lang.annotations.Language
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.api.schema
import org.jetbrains.kotlinx.dataframe.io.db.MsSql
import java.sql.Connection
import java.sql.ResultSet
import kotlin.reflect.typeOf

internal fun inferNullability(connection: Connection) {
// prepare tables and data
@Language("SQL")
val createTestTable1Query = """
CREATE TABLE TestTable1 (
id INT PRIMARY KEY,
name VARCHAR(50),
surname VARCHAR(50),
age INT NOT NULL
)
"""

connection.createStatement().execute(createTestTable1Query)

connection.createStatement()
.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (1, 'John', 'Crawford', 40)")
connection.createStatement()
.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (2, 'Alice', 'Smith', 25)")
connection.createStatement()
.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (3, 'Bob', 'Johnson', 47)")
connection.createStatement()
.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (4, 'Sam', NULL, 15)")

// start testing `readSqlTable` method

// with default inferNullability: Boolean = true
val tableName = "TestTable1"
val df = DataFrame.readSqlTable(connection, tableName)
df.schema().columns["id"]!!.type shouldBe typeOf<Int>()
df.schema().columns["name"]!!.type shouldBe typeOf<String>()
df.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df.schema().columns["age"]!!.type shouldBe typeOf<Int>()

val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName)
dataSchema.columns.size shouldBe 4
dataSchema.columns["id"]!!.type shouldBe typeOf<Int>()
dataSchema.columns["name"]!!.type shouldBe typeOf<String?>()
dataSchema.columns["surname"]!!.type shouldBe typeOf<String?>()
dataSchema.columns["age"]!!.type shouldBe typeOf<Int>()

// with inferNullability: Boolean = false
val df1 = DataFrame.readSqlTable(connection, tableName, inferNullability = false)
df1.schema().columns["id"]!!.type shouldBe typeOf<Int>()

// this column changed a type because it doesn't contain nulls
df1.schema().columns["name"]!!.type shouldBe typeOf<String?>()
df1.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df1.schema().columns["age"]!!.type shouldBe typeOf<Int>()

// end testing `readSqlTable` method

// start testing `readSQLQuery` method

// ith default inferNullability: Boolean = true
@Language("SQL")
val sqlQuery =
"""
SELECT name, surname, age FROM TestTable1
""".trimIndent()

val df2 = DataFrame.readSqlQuery(connection, sqlQuery)
df2.schema().columns["name"]!!.type shouldBe typeOf<String>()
df2.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df2.schema().columns["age"]!!.type shouldBe typeOf<Int>()

val dataSchema2 = DataFrame.getSchemaForSqlQuery(connection, sqlQuery)
dataSchema2.columns.size shouldBe 3
dataSchema2.columns["name"]!!.type shouldBe typeOf<String?>()
dataSchema2.columns["surname"]!!.type shouldBe typeOf<String?>()
dataSchema2.columns["age"]!!.type shouldBe typeOf<Int>()

// with inferNullability: Boolean = false
val df3 = DataFrame.readSqlQuery(connection, sqlQuery, inferNullability = false)
// this column changed a type because it doesn't contain nulls
df3.schema().columns["name"]!!.type shouldBe typeOf<String?>()
df3.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df3.schema().columns["age"]!!.type shouldBe typeOf<Int>()

// end testing `readSQLQuery` method

// start testing `readResultSet` method

connection.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_UPDATABLE).use { st ->
@Language("SQL")
val selectStatement = "SELECT * FROM TestTable1"

st.executeQuery(selectStatement).use { rs ->
// ith default inferNullability: Boolean = true
val df4 = DataFrame.readResultSet(rs, MsSql)
df4.schema().columns["id"]!!.type shouldBe typeOf<Int>()
df4.schema().columns["name"]!!.type shouldBe typeOf<String>()
df4.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df4.schema().columns["age"]!!.type shouldBe typeOf<Int>()

rs.beforeFirst()

val dataSchema3 = DataFrame.getSchemaForResultSet(rs, MsSql)
dataSchema3.columns.size shouldBe 4
dataSchema3.columns["id"]!!.type shouldBe typeOf<Int>()
dataSchema3.columns["name"]!!.type shouldBe typeOf<String?>()
dataSchema3.columns["surname"]!!.type shouldBe typeOf<String?>()
dataSchema3.columns["age"]!!.type shouldBe typeOf<Int>()

// with inferNullability: Boolean = false
rs.beforeFirst()

val df5 = DataFrame.readResultSet(rs, MsSql, inferNullability = false)
df5.schema().columns["id"]!!.type shouldBe typeOf<Int>()

// this column changed a type because it doesn't contain nulls
df5.schema().columns["name"]!!.type shouldBe typeOf<String?>()
df5.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df5.schema().columns["age"]!!.type shouldBe typeOf<Int>()
}
}
// end testing `readResultSet` method

connection.createStatement().execute("DROP TABLE TestTable1")
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
import org.jetbrains.kotlinx.dataframe.api.add
import org.jetbrains.kotlinx.dataframe.api.cast
import org.jetbrains.kotlinx.dataframe.api.filter
import org.jetbrains.kotlinx.dataframe.api.schema
import org.jetbrains.kotlinx.dataframe.api.select
import org.jetbrains.kotlinx.dataframe.io.DbConnectionConfig
import org.jetbrains.kotlinx.dataframe.io.db.H2
Expand All @@ -20,6 +19,7 @@ import org.jetbrains.kotlinx.dataframe.io.getSchemaForAllSqlTables
import org.jetbrains.kotlinx.dataframe.io.getSchemaForResultSet
import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlQuery
import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlTable
import org.jetbrains.kotlinx.dataframe.io.inferNullability
import org.jetbrains.kotlinx.dataframe.io.readAllSqlTables
import org.jetbrains.kotlinx.dataframe.io.readDataFrame
import org.jetbrains.kotlinx.dataframe.io.readResultSet
Expand Down Expand Up @@ -841,128 +841,9 @@ class JdbcTest {
saleDataSchema1.columns["amount"]!!.type shouldBe typeOf<BigDecimal>()
}

// TODO: add the same test for each particular database and refactor the scenario to the common test case
// https://github.com/Kotlin/dataframe/issues/688
@Test
fun `infer nullability`() {
// prepare tables and data
@Language("SQL")
val createTestTable1Query = """
CREATE TABLE TestTable1 (
id INT PRIMARY KEY,
name VARCHAR(50),
surname VARCHAR(50),
age INT NOT NULL
)
"""

connection.createStatement().execute(createTestTable1Query)

connection.createStatement()
.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (1, 'John', 'Crawford', 40)")
connection.createStatement()
.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (2, 'Alice', 'Smith', 25)")
connection.createStatement()
.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (3, 'Bob', 'Johnson', 47)")
connection.createStatement()
.execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (4, 'Sam', NULL, 15)")

// start testing `readSqlTable` method

// with default inferNullability: Boolean = true
val tableName = "TestTable1"
val df = DataFrame.readSqlTable(connection, tableName)
df.schema().columns["id"]!!.type shouldBe typeOf<Int>()
df.schema().columns["name"]!!.type shouldBe typeOf<String>()
df.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df.schema().columns["age"]!!.type shouldBe typeOf<Int>()

val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName)
dataSchema.columns.size shouldBe 4
dataSchema.columns["id"]!!.type shouldBe typeOf<Int>()
dataSchema.columns["name"]!!.type shouldBe typeOf<String?>()
dataSchema.columns["surname"]!!.type shouldBe typeOf<String?>()
dataSchema.columns["age"]!!.type shouldBe typeOf<Int>()

// with inferNullability: Boolean = false
val df1 = DataFrame.readSqlTable(connection, tableName, inferNullability = false)
df1.schema().columns["id"]!!.type shouldBe typeOf<Int>()

// this column changed a type because it doesn't contain nulls
df1.schema().columns["name"]!!.type shouldBe typeOf<String?>()
df1.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df1.schema().columns["age"]!!.type shouldBe typeOf<Int>()

// end testing `readSqlTable` method

// start testing `readSQLQuery` method

// ith default inferNullability: Boolean = true
@Language("SQL")
val sqlQuery =
"""
SELECT name, surname, age FROM TestTable1
""".trimIndent()

val df2 = DataFrame.readSqlQuery(connection, sqlQuery)
df2.schema().columns["name"]!!.type shouldBe typeOf<String>()
df2.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df2.schema().columns["age"]!!.type shouldBe typeOf<Int>()

val dataSchema2 = DataFrame.getSchemaForSqlQuery(connection, sqlQuery)
dataSchema2.columns.size shouldBe 3
dataSchema2.columns["name"]!!.type shouldBe typeOf<String?>()
dataSchema2.columns["surname"]!!.type shouldBe typeOf<String?>()
dataSchema2.columns["age"]!!.type shouldBe typeOf<Int>()

// with inferNullability: Boolean = false
val df3 = DataFrame.readSqlQuery(connection, sqlQuery, inferNullability = false)

// this column changed a type because it doesn't contain nulls
df3.schema().columns["name"]!!.type shouldBe typeOf<String?>()
df3.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df3.schema().columns["age"]!!.type shouldBe typeOf<Int>()

// end testing `readSQLQuery` method

// start testing `readResultSet` method

connection.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_UPDATABLE).use { st ->
@Language("SQL")
val selectStatement = "SELECT * FROM TestTable1"

st.executeQuery(selectStatement).use { rs ->
// ith default inferNullability: Boolean = true
val df4 = DataFrame.readResultSet(rs, H2(MySql))
df4.schema().columns["id"]!!.type shouldBe typeOf<Int>()
df4.schema().columns["name"]!!.type shouldBe typeOf<String>()
df4.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df4.schema().columns["age"]!!.type shouldBe typeOf<Int>()

rs.beforeFirst()

val dataSchema3 = DataFrame.getSchemaForResultSet(rs, H2(MySql))
dataSchema3.columns.size shouldBe 4
dataSchema3.columns["id"]!!.type shouldBe typeOf<Int>()
dataSchema3.columns["name"]!!.type shouldBe typeOf<String?>()
dataSchema3.columns["surname"]!!.type shouldBe typeOf<String?>()
dataSchema3.columns["age"]!!.type shouldBe typeOf<Int>()

// with inferNullability: Boolean = false
rs.beforeFirst()

val df5 = DataFrame.readResultSet(rs, H2(MySql), inferNullability = false)
df5.schema().columns["id"]!!.type shouldBe typeOf<Int>()

// this column changed a type because it doesn't contain nulls
df5.schema().columns["name"]!!.type shouldBe typeOf<String?>()
df5.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df5.schema().columns["age"]!!.type shouldBe typeOf<Int>()
}
}
// end testing `readResultSet` method

connection.createStatement().execute("DROP TABLE TestTable1")
inferNullability(connection)
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import org.jetbrains.kotlinx.dataframe.api.filter
import org.jetbrains.kotlinx.dataframe.api.select
import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlQuery
import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlTable
import org.jetbrains.kotlinx.dataframe.io.inferNullability
import org.jetbrains.kotlinx.dataframe.io.readAllSqlTables
import org.jetbrains.kotlinx.dataframe.io.readSqlQuery
import org.jetbrains.kotlinx.dataframe.io.readSqlTable
Expand Down Expand Up @@ -417,4 +418,9 @@ class MariadbH2Test {
schema.columns["doublecol"]!!.type shouldBe typeOf<Double>()
schema.columns["decimalcol"]!!.type shouldBe typeOf<BigDecimal>()
}

@Test
fun `infer nullability`() {
inferNullability(connection)
}
}
Loading

0 comments on commit 7a895e2

Please sign in to comment.