diff --git a/examples/run-example.scala b/examples/run-example.scala
index 17d600bb..7096db4a 100644
--- a/examples/run-example.scala
+++ b/examples/run-example.scala
@@ -1,7 +1,7 @@
import org.bdgenomics.gnocchi.sql.GnocchiSession._
-import org.bdgenomics.gnocchi.algorithms.siteregression.AdditiveLinearRegression
-val genotypesPath = "examples/testData/time_phenos.vcf"
-val phenotypesPath = "examples/testData/tab_time_phenos.txt"
+import org.bdgenomics.gnocchi.algorithms.siteregression.LinearSiteRegression
+val genotypesPath = "examples/testData/time_genos_1.vcf"
+val phenotypesPath = "examples/testData/tab_time_phenos_1.txt"
val geno = sc.loadGenotypes(genotypesPath)
val pheno = sc.loadPhenotypes(phenotypesPath, "IID", "pheno_1", "\t", Option(phenotypesPath), Option(List("pheno_4", "pheno_5")))
@@ -10,4 +10,4 @@ val filteredGenoVariants = sc.filterVariants(filteredGeno, geno = 0.1, maf = 0.1
val broadPheno = sc.broadcast(pheno)
-val assoications = AdditiveLinearRegression(geno, broadPheno)
\ No newline at end of file
+val assoications = LinearSiteRegression(geno, broadPheno)
diff --git a/examples/test_merge.scala b/examples/test_merge.scala
index f5e34356..59b6abd8 100644
--- a/examples/test_merge.scala
+++ b/examples/test_merge.scala
@@ -1,6 +1,6 @@
import org.bdgenomics.gnocchi.sql.GnocchiSession._
-val genotypesPath1 = "testData/time_genos_1.vcf"
-val phenotypesPath1 = "testData/tab_time_phenos_1.txt"
+val genotypesPath1 = "examples/testData/time_genos_1.vcf"
+val phenotypesPath1 = "examples/testData/tab_time_phenos_1.txt"
val geno1 = sc.loadGenotypes(genotypesPath1)
val pheno1 = sc.loadPhenotypes(phenotypesPath1, "IID", "pheno_1", "\t")
@@ -19,7 +19,7 @@ val broadPheno1 = sc.broadcast(pheno1)
// val broadPheno2 = sc.broadcast(pheno2)
-import org.bdgenomics.gnocchi.algorithms.siteregression.AdditiveLinearRegression
+import org.bdgenomics.gnocchi.algorithms.siteregression.LinearSiteRegression
-val assoc_1 = AdditiveLinearRegression(fullFiltered1, broadPheno1)
+val assoc_1 = LinearSiteRegression(fullFiltered1, broadPheno1)
// val assoc_2 = AdditiveLinearRegression(fullFiltered2, broadPheno2)
diff --git a/gnocchi-cli/src/main/scala/org/bdgenomics/gnocchi/cli/RegressPhenotypes.scala b/gnocchi-cli/src/main/scala/org/bdgenomics/gnocchi/cli/RegressPhenotypes.scala
index 80d40062..de51c997 100755
--- a/gnocchi-cli/src/main/scala/org/bdgenomics/gnocchi/cli/RegressPhenotypes.scala
+++ b/gnocchi-cli/src/main/scala/org/bdgenomics/gnocchi/cli/RegressPhenotypes.scala
@@ -18,9 +18,7 @@
package org.bdgenomics.gnocchi.cli
import org.bdgenomics.gnocchi.algorithms.siteregression._
-import org.bdgenomics.gnocchi.models.variant.VariantModel
-import org.bdgenomics.gnocchi.models.variant.linear.{ AdditiveLinearVariantModel, DominantLinearVariantModel }
-import org.bdgenomics.gnocchi.models.variant.logistic.{ AdditiveLogisticVariantModel, DominantLogisticVariantModel }
+import org.bdgenomics.gnocchi.models.variant.{ LinearVariantModel, LogisticVariantModel, VariantModel }
import org.bdgenomics.gnocchi.sql.GnocchiSession._
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
@@ -122,20 +120,20 @@ class RegressPhenotypes(protected val args: RegressPhenotypesArgs) extends BDGSp
args.associationType match {
case "ADDITIVE_LINEAR" => {
- val associations = AdditiveLinearRegression(filteredGeno, broadPhenotype)
- logResults[AdditiveLinearVariantModel](associations, sc)
+ val associations = LinearSiteRegression(filteredGeno, broadPhenotype, "ADDITIVE")
+ logResults[LinearVariantModel](associations, sc)
}
case "DOMINANT_LINEAR" => {
- val associations = DominantLinearRegression(filteredGeno, broadPhenotype)
- logResults[DominantLinearVariantModel](associations, sc)
+ val associations = LinearSiteRegression(filteredGeno, broadPhenotype, "DOMINANT")
+ logResults[LinearVariantModel](associations, sc)
}
case "ADDITIVE_LOGISTIC" => {
- val associations = AdditiveLogisticRegression(filteredGeno, broadPhenotype)
- logResults[AdditiveLogisticVariantModel](associations, sc)
+ val associations = LogisticSiteRegression(filteredGeno, broadPhenotype, "ADDITIVE")
+ logResults[LogisticVariantModel](associations, sc)
}
case "DOMINANT_LOGISTIC" => {
- val associations = DominantLogisticRegression(filteredGeno, broadPhenotype)
- logResults[DominantLogisticVariantModel](associations, sc)
+ val associations = LogisticSiteRegression(filteredGeno, broadPhenotype, "DOMINANT")
+ logResults[LogisticVariantModel](associations, sc)
}
}
}
diff --git a/gnocchi-core/pom.xml b/gnocchi-core/pom.xml
index 19b0bca3..9d8d1ea7 100755
--- a/gnocchi-core/pom.xml
+++ b/gnocchi-core/pom.xml
@@ -45,7 +45,6 @@
As explained here: http://stackoverflow.com/questions/1660441/java-flag-to-enable-extended-serialization-debugging-info
The second option allows us better debugging for serialization-based errors.
-->
- -Xmx1024m -Dsun.io.serialization.extendedDebugInfo=true -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.PDuErZrK -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.d1mt6vD9 -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.eRmiFaTc -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.u9bHWT1h -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.J5tKjn9N -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.PyTFRsyh
F
diff --git a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/algorithms/siteregression/LinearSiteRegression.scala b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/algorithms/siteregression/LinearSiteRegression.scala
index 1f39fdea..c0be5927 100755
--- a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/algorithms/siteregression/LinearSiteRegression.scala
+++ b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/algorithms/siteregression/LinearSiteRegression.scala
@@ -17,7 +17,6 @@
*/
package org.bdgenomics.gnocchi.algorithms.siteregression
-import org.bdgenomics.gnocchi.models.variant.linear.{ AdditiveLinearVariantModel, DominantLinearVariantModel, LinearVariantModel }
import org.bdgenomics.gnocchi.primitives.association.LinearAssociation
import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
@@ -27,18 +26,31 @@ import breeze.stats._
import breeze.stats.distributions.StudentsT
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{ Dataset, SparkSession }
+import org.bdgenomics.gnocchi.models.variant.LinearVariantModel
import scala.collection.immutable.Map
-trait LinearSiteRegression[VM <: LinearVariantModel[VM]] extends SiteRegression[VM] {
+trait LinearSiteRegression extends SiteRegression[LinearVariantModel, LinearAssociation] {
+
+ val sparkSession = SparkSession.builder().getOrCreate()
+ import sparkSession.implicits._
def apply(genotypes: Dataset[CalledVariant],
phenotypes: Broadcast[Map[String, Phenotype]],
- validationStringency: String = "STRICT"): Dataset[VM]
+ allelicAssumption: String = "ADDITIVE",
+ validationStringency: String = "STRICT"): Dataset[LinearVariantModel] = {
+ //ToDo: Singular Matrix Exceptions
+ genotypes.map((genos: CalledVariant) => {
+ val association = applyToSite(phenotypes.value, genos, allelicAssumption)
+ constructVM(genos, phenotypes.value.head._2, association, allelicAssumption)
+ })
+ }
def applyToSite(phenotypes: Map[String, Phenotype],
- genotypes: CalledVariant): LinearAssociation = {
- val (x, y) = prepareDesignMatrix(genotypes, phenotypes)
+ genotypes: CalledVariant,
+ allelicAssumption: String): LinearAssociation = {
+
+ val (x, y) = prepareDesignMatrix(genotypes, phenotypes, allelicAssumption)
// TODO: Determine if QR factorization is faster
val beta = x \ y
@@ -83,13 +95,20 @@ trait LinearSiteRegression[VM <: LinearVariantModel[VM]] extends SiteRegression[
}
private[algorithms] def prepareDesignMatrix(genotypes: CalledVariant,
- phenotypes: Map[String, Phenotype]): (DenseMatrix[Double], DenseVector[Double]) = {
- val filteredGenotypes = genotypes.samples.filter(_.value != ".")
+ phenotypes: Map[String, Phenotype],
+ allelicAssumption: String): (DenseMatrix[Double], DenseVector[Double]) = {
+
+ val validGenos = genotypes.samples.filter(_.value != ".")
- val (primitiveX, primitiveY) = filteredGenotypes.flatMap({
- case gs if phenotypes.contains(gs.sampleID) => {
- val pheno = phenotypes(gs.sampleID)
- Some(1.0 +: clipOrKeepState(gs.toDouble) +: pheno.covariates.toArray, pheno.phenotype)
+ val samplesGenotypes = allelicAssumption.toUpperCase match {
+ case "ADDITIVE" => validGenos.map(genotypeState => (genotypeState.sampleID, genotypeState.additive))
+ case "DOMINANT" => validGenos.map(genotypeState => (genotypeState.sampleID, genotypeState.dominant))
+ case "RECESSIVE" => validGenos.map(genotypeState => (genotypeState.sampleID, genotypeState.recessive))
+ }
+
+ val (primitiveX, primitiveY) = samplesGenotypes.flatMap({
+ case (sampleID, genotype) if phenotypes.contains(sampleID) => {
+ Some(1.0 +: genotype +: phenotypes(sampleID).covariates.toArray, phenotypes(sampleID).phenotype)
}
case _ => None
}).toArray.unzip
@@ -107,73 +126,22 @@ trait LinearSiteRegression[VM <: LinearVariantModel[VM]] extends SiteRegression[
(new DenseMatrix(primitiveX.length, primitiveX(0).length, primitiveX.transpose.flatten), new DenseVector(primitiveY))
}
- protected def constructVM(variant: CalledVariant,
- phenotype: Phenotype,
- association: LinearAssociation): VM
-}
-
-object AdditiveLinearRegression extends AdditiveLinearRegression {
- val regressionName = "additiveLinearRegression"
-}
-
-trait AdditiveLinearRegression extends LinearSiteRegression[AdditiveLinearVariantModel] with Additive {
- val sparkSession = SparkSession.builder().getOrCreate()
- import sparkSession.implicits._
-
- def apply(genotypes: Dataset[CalledVariant],
- phenotypes: Broadcast[Map[String, Phenotype]],
- validationStringency: String = "STRICT"): Dataset[AdditiveLinearVariantModel] = {
-
- //ToDo: Singular Matrix Exceptions
- genotypes.map((genos: CalledVariant) => {
- val association = applyToSite(phenotypes.value, genos)
- constructVM(genos, phenotypes.value.head._2, association)
- })
- }
-
- protected def constructVM(variant: CalledVariant,
- phenotype: Phenotype,
- association: LinearAssociation): AdditiveLinearVariantModel = {
- AdditiveLinearVariantModel(variant.uniqueID,
+ def constructVM(variant: CalledVariant,
+ phenotype: Phenotype,
+ association: LinearAssociation,
+ allelicAssumption: String): LinearVariantModel = {
+ LinearVariantModel(variant.uniqueID,
association,
phenotype.phenoName,
variant.chromosome,
variant.position,
variant.referenceAllele,
variant.alternateAllele,
+ allelicAssumption,
phaseSetId = 0)
}
}
-object DominantLinearRegression extends DominantLinearRegression {
- val regressionName = "dominantLinearRegression"
-}
-
-trait DominantLinearRegression extends LinearSiteRegression[DominantLinearVariantModel] with Dominant {
- val sparkSession = SparkSession.builder().getOrCreate()
- import sparkSession.implicits._
-
- def apply(genotypes: Dataset[CalledVariant],
- phenotypes: Broadcast[Map[String, Phenotype]],
- validationStringency: String = "STRICT"): Dataset[DominantLinearVariantModel] = {
-
- //ToDo: Singular Matrix Exceptions
- genotypes.map((genos: CalledVariant) => {
- val association = applyToSite(phenotypes.value, genos)
- constructVM(genos, phenotypes.value.head._2, association)
- })
- }
-
- protected def constructVM(variant: CalledVariant,
- phenotype: Phenotype,
- association: LinearAssociation): DominantLinearVariantModel = {
- DominantLinearVariantModel(variant.uniqueID,
- association,
- phenotype.phenoName,
- variant.chromosome,
- variant.position,
- variant.referenceAllele,
- variant.alternateAllele,
- phaseSetId = 0)
- }
-}
+object LinearSiteRegression extends LinearSiteRegression {
+ val regressionName = "LinearSiteRegression"
+}
\ No newline at end of file
diff --git a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/algorithms/siteregression/LogisticSiteRegression.scala b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/algorithms/siteregression/LogisticSiteRegression.scala
index b43f9943..4965c9e5 100644
--- a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/algorithms/siteregression/LogisticSiteRegression.scala
+++ b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/algorithms/siteregression/LogisticSiteRegression.scala
@@ -19,27 +19,44 @@ package org.bdgenomics.gnocchi.algorithms.siteregression
import breeze.linalg._
import breeze.numerics._
-import org.bdgenomics.gnocchi.models.variant.logistic.{ AdditiveLogisticVariantModel, DominantLogisticVariantModel, LogisticVariantModel }
import org.bdgenomics.gnocchi.primitives.association.LogisticAssociation
import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
import org.apache.commons.math3.distribution.ChiSquaredDistribution
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{ Dataset, SparkSession }
+import org.bdgenomics.gnocchi.models.variant.LogisticVariantModel
import scala.annotation.tailrec
import scala.collection.immutable.Map
-trait LogisticSiteRegression[VM <: LogisticVariantModel[VM]] extends SiteRegression[VM] {
+trait LogisticSiteRegression extends SiteRegression[LogisticVariantModel, LogisticAssociation] {
+
+ val sparkSession = SparkSession.builder().getOrCreate()
+ import sparkSession.implicits._
def apply(genotypes: Dataset[CalledVariant],
phenotypes: Broadcast[Map[String, Phenotype]],
- validationStringency: String = "STRICT"): Dataset[VM]
+ allelicAssumption: String = "ADDITIVE",
+ validationStringency: String = "STRICT"): Dataset[LogisticVariantModel] = {
+ genotypes.flatMap((genos: CalledVariant) => {
+ try {
+ val association = applyToSite(phenotypes.value, genos, allelicAssumption)
+ Some(constructVM(genos, phenotypes.value.head._2, association, allelicAssumption))
+ } catch {
+ case e: breeze.linalg.MatrixSingularException => {
+ logError(e.toString)
+ None
+ }
+ }
+ })
+ }
def applyToSite(phenotypes: Map[String, Phenotype],
- genotypes: CalledVariant): LogisticAssociation = {
+ genotypes: CalledVariant,
+ allelicAssumption: String): LogisticAssociation = {
- val (data, labels) = prepareDesignMatrix(phenotypes, genotypes)
+ val (data, labels) = prepareDesignMatrix(phenotypes, genotypes, allelicAssumption)
val numObservations = genotypes.samples.count(x => !x.value.contains("."))
val maxIter = 1000
@@ -130,11 +147,16 @@ trait LogisticSiteRegression[VM <: LogisticVariantModel[VM]] extends SiteRegress
* is [[DenseVector]] of labels
*/
def prepareDesignMatrix(phenotypes: Map[String, Phenotype],
- genotypes: CalledVariant): (DenseMatrix[Double], DenseVector[Double]) = {
+ genotypes: CalledVariant,
+ allelicAssumption: String): (DenseMatrix[Double], DenseVector[Double]) = {
- val samplesGenotypes = genotypes.samples
- .filter { case genotypeState => !genotypeState.value.contains(".") }
- .map { case genotypeState => (genotypeState.sampleID, List(clipOrKeepState(genotypeState.toDouble))) }
+ val validGenos = genotypes.samples.filter(genotypeState => !genotypeState.value.contains("."))
+
+ val samplesGenotypes = allelicAssumption.toUpperCase match {
+ case "ADDITIVE" => validGenos.map(genotypeState => (genotypeState.sampleID, List(genotypeState.additive)))
+ case "DOMINANT" => validGenos.map(genotypeState => (genotypeState.sampleID, List(genotypeState.dominant)))
+ case "RECESSIVE" => validGenos.map(genotypeState => (genotypeState.sampleID, List(genotypeState.recessive)))
+ }
val cleanedSampleVector = samplesGenotypes
.map { case (sampleID, genotype) => (sampleID, (genotype ++ phenotypes(sampleID).covariates).toList) }
@@ -145,89 +167,22 @@ trait LogisticSiteRegression[VM <: LogisticVariantModel[VM]] extends SiteRegress
(X, Y)
}
- protected def constructVM(variant: CalledVariant,
- phenotype: Phenotype,
- association: LogisticAssociation): VM
-}
-
-object AdditiveLogisticRegression extends AdditiveLogisticRegression {
- val regressionName = "additiveLogisticRegression"
-}
-
-trait AdditiveLogisticRegression extends LogisticSiteRegression[AdditiveLogisticVariantModel] with Additive {
- val sparkSession = SparkSession.builder().getOrCreate()
- import sparkSession.implicits._
-
- def apply(genotypes: Dataset[CalledVariant],
- phenotypes: Broadcast[Map[String, Phenotype]],
- validationStringency: String = "STRICT"): Dataset[AdditiveLogisticVariantModel] = {
-
- // Note: we would like to use a map below, but need some way to deal with singular matrix exceptions being thrown
- // by applyToSite. flatMap unpacks the Some/None objects into the correct product case classes.
- genotypes.flatMap((genos: CalledVariant) => {
- try {
- val association = applyToSite(phenotypes.value, genos)
- Some(constructVM(genos, phenotypes.value.head._2, association))
- } catch {
- case e: breeze.linalg.MatrixSingularException => {
- logError(e.toString)
- None: Option[AdditiveLogisticVariantModel]
- }
- }
- })
- }
-
- protected def constructVM(variant: CalledVariant,
- phenotype: Phenotype,
- association: LogisticAssociation): AdditiveLogisticVariantModel = {
- AdditiveLogisticVariantModel(variant.uniqueID,
+ def constructVM(variant: CalledVariant,
+ phenotype: Phenotype,
+ association: LogisticAssociation,
+ allelicAssumption: String): LogisticVariantModel = {
+ LogisticVariantModel(variant.uniqueID,
association,
phenotype.phenoName,
variant.chromosome,
variant.position,
variant.referenceAllele,
variant.alternateAllele,
+ allelicAssumption,
phaseSetId = 0)
}
}
-object DominantLogisticRegression extends DominantLogisticRegression {
- val regressionName = "dominantLogisticRegression"
-}
-
-trait DominantLogisticRegression extends LogisticSiteRegression[DominantLogisticVariantModel] with Dominant {
- val sparkSession = SparkSession.builder().getOrCreate()
- import sparkSession.implicits._
-
- def apply(genotypes: Dataset[CalledVariant],
- phenotypes: Broadcast[Map[String, Phenotype]],
- validationStringency: String = "STRICT"): Dataset[DominantLogisticVariantModel] = {
-
- // Note: we would like to use a map below, but need some way to deal with singular matrix exceptions being thrown
- // by applyToSite. flatMap unpacks the Some/None objects into the correct product case classes.
- genotypes.flatMap((genos: CalledVariant) => {
- try {
- val association = applyToSite(phenotypes.value, genos)
- Some(constructVM(genos, phenotypes.value.head._2, association))
- } catch {
- case e: breeze.linalg.MatrixSingularException => {
- logError(e.toString)
- None: Option[DominantLogisticVariantModel]
- }
- }
- })
- }
-
- protected def constructVM(variant: CalledVariant,
- phenotype: Phenotype,
- association: LogisticAssociation): DominantLogisticVariantModel = {
- DominantLogisticVariantModel(variant.uniqueID,
- association,
- phenotype.phenoName,
- variant.chromosome,
- variant.position,
- variant.referenceAllele,
- variant.alternateAllele,
- phaseSetId = 0)
- }
+object LogisticSiteRegression extends LogisticSiteRegression {
+ val regressionName = "LogisticSiteRegression"
}
diff --git a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/algorithms/siteregression/SiteRegression.scala b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/algorithms/siteregression/SiteRegression.scala
index d6eebe29..f6375c12 100644
--- a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/algorithms/siteregression/SiteRegression.scala
+++ b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/algorithms/siteregression/SiteRegression.scala
@@ -17,6 +17,8 @@
*/
package org.bdgenomics.gnocchi.algorithms.siteregression
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.Dataset
import org.bdgenomics.gnocchi.models.variant.VariantModel
import org.bdgenomics.gnocchi.primitives.association.Association
import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
@@ -25,62 +27,22 @@ import org.bdgenomics.utils.misc.Logging
import scala.collection.immutable.Map
-trait SiteRegression[VM <: VariantModel[VM]] extends Serializable with Logging {
+trait SiteRegression[VM <: VariantModel[VM], A <: Association] extends Serializable with Logging {
val regressionName: String
- // def apply(genotypes: Dataset[CalledVariant],
- // phenotypes: Broadcast[Map[String, Phenotype]],
- // validationStringency: String = "STRICT"): Dataset[VM]
+ def apply(genotypes: Dataset[CalledVariant],
+ phenotypes: Broadcast[Map[String, Phenotype]],
+ allelicAssumption: String = "ADDITIVE",
+ validationStringency: String = "STRICT"): Dataset[VM]
def applyToSite(phenotypes: Map[String, Phenotype],
- genotypes: CalledVariant): Association
+ genotypes: CalledVariant,
+ allelicAssumption: String): A
- /**
- * Known implementations: [[Additive]], [[Dominant]]
- *
- * @param gs GenotypeState object to be clipped
- * @return Formatted GenotypeState object
- */
- def clipOrKeepState(gs: Double): Double
+ def constructVM(variant: CalledVariant,
+ phenotype: Phenotype,
+ association: A,
+ allelicAssumption: String): VM
}
-trait Additive {
-
- /**
- * Formats a GenotypeState object by converting the state to a double. Uses cumulative weighting of genotype
- * states which is typical of an Additive model.
- *
- * @param gs GenotypeState object to be clipped
- * @return Formatted GenotypeState object
- */
- def clipOrKeepState(gs: Double): Double = {
- gs
- }
-}
-
-trait Dominant {
-
- /**
- * 1/0 or 0/1 or 1/1 ==> response
- *
- * @param gs GenotypeState object to be clipped
- * @return Formatted GenotypeState object
- */
- def clipOrKeepState(gs: Double): Double = {
- if (gs == 0.0) 0.0 else 1.0
- }
-}
-
-trait Recessive {
-
- /**
- * 1/1 ==> response
- *
- * @param gs GenotypeState object to be clipped
- * @return Formatted GenotypeState object
- */
- def clipOrKeepState(gs: Double): Double = {
- if (gs == 2.0) 1.0 else 0.0
- }
-}
\ No newline at end of file
diff --git a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/GnocchiModel.scala b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/GnocchiModel.scala
index e3623f66..a2e9d8a4 100755
--- a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/GnocchiModel.scala
+++ b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/GnocchiModel.scala
@@ -42,11 +42,6 @@ case class GnocchiModelMetaData(modelType: String,
}
}
-//object GnocchiModel extends GnocchiModel {
-//
-// def apply(): GnocchiModel
-//}
-
/**
* A trait that wraps an RDD of variant-specific models that are incrementally
* updated, an RDD of variant-specific models that are recomputed over entire
@@ -110,6 +105,9 @@ trait GnocchiModel[VM <: VariantModel[VM], GM <: GnocchiModel[VM, GM]] {
* @return Returns an RDD of incrementally updated VariantModels
*/
def mergeVariantModels(newVariantModels: Dataset[VM]): Dataset[VM]
+ // = {
+ // variantModels.joinWith(newVariantModels, variantModels("uniqueID") === newVariantModels("uniqueID")).map(x => x._1.mergeWith(x._2))
+ // }
// /**
// * Returns VariantModels created from full recompute over all data for each variant
diff --git a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/linear/DominantLinearGnocchiModel.scala b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/LinearGnocchiModel.scala
similarity index 72%
rename from gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/linear/DominantLinearGnocchiModel.scala
rename to gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/LinearGnocchiModel.scala
index aa310121..f14fec17 100644
--- a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/linear/DominantLinearGnocchiModel.scala
+++ b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/LinearGnocchiModel.scala
@@ -15,20 +15,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.bdgenomics.gnocchi.models.linear
+package org.bdgenomics.gnocchi.models
-import org.bdgenomics.gnocchi.algorithms.siteregression.DominantLinearRegression
-import org.bdgenomics.gnocchi.models._
-import org.bdgenomics.gnocchi.models.variant.QualityControlVariantModel
-import org.bdgenomics.gnocchi.models.variant.linear.DominantLinearVariantModel
-import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
-import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{ Dataset, SparkSession }
+import org.bdgenomics.gnocchi.algorithms.siteregression.LinearSiteRegression
+import org.bdgenomics.gnocchi.models.variant.{ LinearVariantModel, QualityControlVariantModel }
+import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
+import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
-object DominantLinearGnocchiModelFactory {
+object LinearGnocchiModelFactory {
- val regressionName = "dominantLinearRegression"
+ val regressionName = "LinearRegression"
val sparkSession = SparkSession.builder().getOrCreate()
import sparkSession.implicits._
@@ -37,10 +35,10 @@ object DominantLinearGnocchiModelFactory {
phenotypeNames: Option[List[String]],
QCVariantIDs: Option[Set[String]] = None,
QCVariantSamplingRate: Double = 0.1,
- validationStringency: String = "STRICT"): DominantLinearGnocchiModel = {
+ validationStringency: String = "STRICT"): LinearGnocchiModel = {
// ToDo: sampling QC Variants better.
- val variantModels = DominantLinearRegression(genotypes, phenotypes, validationStringency)
+ val variantModels = LinearSiteRegression(genotypes, phenotypes, validationStringency = validationStringency) // pass by name: the 3rd positional parameter is allelicAssumption
// Create QCVariantModels
val comparisonVariants = if (QCVariantIDs.isEmpty) {
@@ -53,7 +51,7 @@ object DominantLinearGnocchiModelFactory {
.joinWith(comparisonVariants, variantModels("uniqueID") === comparisonVariants("uniqueID"), "inner")
.withColumnRenamed("_1", "variantModel")
.withColumnRenamed("_2", "variant")
- .as[QualityControlVariantModel[DominantLinearVariantModel]]
+ .as[QualityControlVariantModel[LinearVariantModel]]
val phenoNames = if (phenotypeNames.isEmpty) {
List(phenotypes.value.head._2.phenoName) ++ (1 to phenotypes.value.head._2.covariates.length).map(x => "covar_" + x)
@@ -68,23 +66,23 @@ object DominantLinearGnocchiModelFactory {
genotypes.count().toInt,
flaggedVariantModels = Option(QCVariantModels.select("variant.uniqueID").as[String].collect().toList))
- DominantLinearGnocchiModel(metaData = metadata,
+ LinearGnocchiModel(metaData = metadata,
variantModels = variantModels,
QCVariantModels = QCVariantModels,
QCPhenotypes = phenotypes.value)
}
}
-case class DominantLinearGnocchiModel(metaData: GnocchiModelMetaData,
- variantModels: Dataset[DominantLinearVariantModel],
- QCVariantModels: Dataset[QualityControlVariantModel[DominantLinearVariantModel]],
- QCPhenotypes: Map[String, Phenotype])
- extends GnocchiModel[DominantLinearVariantModel, DominantLinearGnocchiModel] {
+case class LinearGnocchiModel(metaData: GnocchiModelMetaData,
+ variantModels: Dataset[LinearVariantModel],
+ QCVariantModels: Dataset[QualityControlVariantModel[LinearVariantModel]],
+ QCPhenotypes: Map[String, Phenotype])
+ extends GnocchiModel[LinearVariantModel, LinearGnocchiModel] {
val sparkSession = SparkSession.builder().getOrCreate()
import sparkSession.implicits._
- def mergeGnocchiModel(otherModel: GnocchiModel[DominantLinearVariantModel, DominantLinearGnocchiModel]): GnocchiModel[DominantLinearVariantModel, DominantLinearGnocchiModel] = {
+ def mergeGnocchiModel(otherModel: GnocchiModel[LinearVariantModel, LinearGnocchiModel]): GnocchiModel[LinearVariantModel, LinearGnocchiModel] = {
require(otherModel.metaData.modelType == metaData.modelType,
"Models being merged are not the same type. Type equality is required to merge two models correctly.")
@@ -100,17 +98,17 @@ case class DominantLinearGnocchiModel(metaData: GnocchiModelMetaData,
val mergedQCVariantModels = mergedVMs.joinWith(mergedQCVariants, mergedVMs("uniqueID") === mergedQCVariants("uniqueID"), "inner")
.withColumnRenamed("_1", "variantModel")
.withColumnRenamed("_2", "variant")
- .as[QualityControlVariantModel[DominantLinearVariantModel]]
+ .as[QualityControlVariantModel[LinearVariantModel]]
val mergedQCPhenotypes = QCPhenotypes ++ otherModel.QCPhenotypes
- DominantLinearGnocchiModel(updatedMetaData, mergedVMs, mergedQCVariantModels, mergedQCPhenotypes)
+ LinearGnocchiModel(updatedMetaData, mergedVMs, mergedQCVariantModels, mergedQCPhenotypes)
}
- def mergeVariantModels(newVariantModels: Dataset[DominantLinearVariantModel]): Dataset[DominantLinearVariantModel] = {
+ def mergeVariantModels(newVariantModels: Dataset[LinearVariantModel]): Dataset[LinearVariantModel] = {
variantModels.joinWith(newVariantModels, variantModels("uniqueID") === newVariantModels("uniqueID")).map(x => x._1.mergeWith(x._2))
}
- def mergeQCVariants(newQCVariantModels: Dataset[QualityControlVariantModel[DominantLinearVariantModel]]): Dataset[CalledVariant] = {
+ def mergeQCVariants(newQCVariantModels: Dataset[QualityControlVariantModel[LinearVariantModel]]): Dataset[CalledVariant] = {
val variants1 = QCVariantModels.map(_.variant)
val variants2 = newQCVariantModels.map(_.variant)
diff --git a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/logistic/AdditiveLogisticGnocchiModel.scala b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/LogisticGnocchiModel.scala
similarity index 72%
rename from gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/logistic/AdditiveLogisticGnocchiModel.scala
rename to gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/LogisticGnocchiModel.scala
index c209d791..d6d58ae9 100644
--- a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/logistic/AdditiveLogisticGnocchiModel.scala
+++ b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/LogisticGnocchiModel.scala
@@ -15,18 +15,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.bdgenomics.gnocchi.models.logistic
+package org.bdgenomics.gnocchi.models
-import org.bdgenomics.gnocchi.algorithms.siteregression.AdditiveLogisticRegression
-import org.bdgenomics.gnocchi.models._
-import org.bdgenomics.gnocchi.models.variant.QualityControlVariantModel
-import org.bdgenomics.gnocchi.models.variant.logistic.AdditiveLogisticVariantModel
-import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
-import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{ Dataset, SparkSession }
+import org.bdgenomics.gnocchi.algorithms.siteregression.LogisticSiteRegression
+import org.bdgenomics.gnocchi.models.variant.{ LogisticVariantModel, QualityControlVariantModel }
+import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
+import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
-object AdditiveLogisticGnocchiModelFactory {
+object LogisticGnocchiModelFactory {
val regressionName = "additiveLinearRegression"
val sparkSession = SparkSession.builder().getOrCreate()
@@ -37,10 +35,10 @@ object AdditiveLogisticGnocchiModelFactory {
phenotypeNames: Option[List[String]],
QCVariantIDs: Option[Set[String]] = None,
QCVariantSamplingRate: Double = 0.1,
- validationStringency: String = "STRICT"): AdditiveLogisticGnocchiModel = {
+ validationStringency: String = "STRICT"): LogisticGnocchiModel = {
// ToDo: sampling QC Variants better.
- val variantModels = AdditiveLogisticRegression(genotypes, phenotypes)
+ val variantModels = LogisticSiteRegression(genotypes, phenotypes)
// Create QCVariantModels
val comparisonVariants = if (QCVariantIDs.isEmpty) {
@@ -53,7 +51,7 @@ object AdditiveLogisticGnocchiModelFactory {
.joinWith(comparisonVariants, variantModels("uniqueID") === comparisonVariants("uniqueID"), "inner")
.withColumnRenamed("_1", "variantModel")
.withColumnRenamed("_2", "variant")
- .as[QualityControlVariantModel[AdditiveLogisticVariantModel]]
+ .as[QualityControlVariantModel[LogisticVariantModel]]
val phenoNames = if (phenotypeNames.isEmpty) {
List(phenotypes.value.head._2.phenoName) ++ (1 to phenotypes.value.head._2.covariates.length).map(x => "covar_" + x)
@@ -68,23 +66,23 @@ object AdditiveLogisticGnocchiModelFactory {
genotypes.count().toInt,
flaggedVariantModels = Option(QCVariantModels.select("variant.uniqueID").as[String].collect().toList))
- AdditiveLogisticGnocchiModel(metaData = metadata,
+ LogisticGnocchiModel(metaData = metadata,
variantModels = variantModels,
QCVariantModels = QCVariantModels,
QCPhenotypes = phenotypes.value)
}
}
-case class AdditiveLogisticGnocchiModel(metaData: GnocchiModelMetaData,
- variantModels: Dataset[AdditiveLogisticVariantModel],
- QCVariantModels: Dataset[QualityControlVariantModel[AdditiveLogisticVariantModel]],
- QCPhenotypes: Map[String, Phenotype])
- extends GnocchiModel[AdditiveLogisticVariantModel, AdditiveLogisticGnocchiModel] {
+case class LogisticGnocchiModel(metaData: GnocchiModelMetaData,
+ variantModels: Dataset[LogisticVariantModel],
+ QCVariantModels: Dataset[QualityControlVariantModel[LogisticVariantModel]],
+ QCPhenotypes: Map[String, Phenotype])
+ extends GnocchiModel[LogisticVariantModel, LogisticGnocchiModel] {
val sparkSession = SparkSession.builder().getOrCreate()
import sparkSession.implicits._
- def mergeGnocchiModel(otherModel: GnocchiModel[AdditiveLogisticVariantModel, AdditiveLogisticGnocchiModel]): GnocchiModel[AdditiveLogisticVariantModel, AdditiveLogisticGnocchiModel] = {
+ def mergeGnocchiModel(otherModel: GnocchiModel[LogisticVariantModel, LogisticGnocchiModel]): GnocchiModel[LogisticVariantModel, LogisticGnocchiModel] = {
require(otherModel.metaData.modelType == metaData.modelType,
"Models being merged are not the same type. Type equality is required to merge two models correctly.")
@@ -100,17 +98,17 @@ case class AdditiveLogisticGnocchiModel(metaData: GnocchiModelMetaData,
val mergedQCVariantModels = mergedVMs.joinWith(mergedQCVariants, mergedVMs("uniqueID") === mergedQCVariants("uniqueID"), "inner")
.withColumnRenamed("_1", "variantModel")
.withColumnRenamed("_2", "variant")
- .as[QualityControlVariantModel[AdditiveLogisticVariantModel]]
+ .as[QualityControlVariantModel[LogisticVariantModel]]
val mergedQCPhenotypes = QCPhenotypes ++ otherModel.QCPhenotypes
- AdditiveLogisticGnocchiModel(updatedMetaData, mergedVMs, mergedQCVariantModels, mergedQCPhenotypes)
+ LogisticGnocchiModel(updatedMetaData, mergedVMs, mergedQCVariantModels, mergedQCPhenotypes)
}
- def mergeVariantModels(newVariantModels: Dataset[AdditiveLogisticVariantModel]): Dataset[AdditiveLogisticVariantModel] = {
+ def mergeVariantModels(newVariantModels: Dataset[LogisticVariantModel]): Dataset[LogisticVariantModel] = {
variantModels.joinWith(newVariantModels, variantModels("uniqueID") === newVariantModels("uniqueID")).map(x => x._1.mergeWith(x._2))
}
- def mergeQCVariants(newQCVariantModels: Dataset[QualityControlVariantModel[AdditiveLogisticVariantModel]]): Dataset[CalledVariant] = {
+ def mergeQCVariants(newQCVariantModels: Dataset[QualityControlVariantModel[LogisticVariantModel]]): Dataset[CalledVariant] = {
val variants1 = QCVariantModels.map(_.variant)
val variants2 = newQCVariantModels.map(_.variant)
diff --git a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/linear/AdditiveLinearGnocchiModel.scala b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/linear/AdditiveLinearGnocchiModel.scala
deleted file mode 100644
index faef4b37..00000000
--- a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/linear/AdditiveLinearGnocchiModel.scala
+++ /dev/null
@@ -1,136 +0,0 @@
-/**
- * Licensed to Big Data Genomics (BDG) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The BDG licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.bdgenomics.gnocchi.models.linear
-
-import org.bdgenomics.gnocchi.algorithms.siteregression.AdditiveLinearRegression
-import org.bdgenomics.gnocchi.models._
-import org.bdgenomics.gnocchi.models.variant.QualityControlVariantModel
-import org.bdgenomics.gnocchi.models.variant.linear.AdditiveLinearVariantModel
-import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
-import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.sql.{ Dataset, SparkSession }
-
-import scala.collection.immutable.Map
-
-object AdditiveLinearGnocchiModelFactory {
-
- val regressionName = "additiveLinearRegression"
- val sparkSession = SparkSession.builder().getOrCreate()
- import sparkSession.implicits._
-
- def apply(genotypes: Dataset[CalledVariant],
- phenotypes: Broadcast[Map[String, Phenotype]],
- phenotypeNames: Option[List[String]],
- QCVariantIDs: Option[Set[String]] = None,
- QCVariantSamplingRate: Double = 0.1,
- validationStringency: String = "STRICT"): AdditiveLinearGnocchiModel = {
-
- // ToDo: sampling QC Variants better.
- val variantModels = AdditiveLinearRegression(genotypes, phenotypes, validationStringency)
-
- // Create QCVariantModels
- val comparisonVariants = if (QCVariantIDs.isEmpty) {
- genotypes.sample(withReplacement = false, fraction = 0.1)
- } else {
- genotypes.filter(x => QCVariantIDs.get.contains(x.uniqueID))
- }
-
- val QCVariantModels = variantModels
- .joinWith(comparisonVariants, variantModels("uniqueID") === comparisonVariants("uniqueID"), "inner")
- .withColumnRenamed("_1", "variantModel")
- .withColumnRenamed("_2", "variant")
- .as[QualityControlVariantModel[AdditiveLinearVariantModel]]
-
- val phenoNames = if (phenotypeNames.isEmpty) {
- List(phenotypes.value.head._2.phenoName) ++ (1 to phenotypes.value.head._2.covariates.length).map(x => "covar_" + x)
- } else {
- phenotypeNames.get
- }
-
- // Create metadata
- val metadata = GnocchiModelMetaData(regressionName,
- phenoNames.head,
- phenoNames.tail.mkString(","),
- genotypes.count().toInt,
- flaggedVariantModels = Option(QCVariantModels.select("variant.uniqueID").as[String].collect().toList))
-
- AdditiveLinearGnocchiModel(metaData = metadata,
- variantModels = variantModels,
- QCVariantModels = QCVariantModels,
- QCPhenotypes = phenotypes.value)
- }
-}
-
-case class AdditiveLinearGnocchiModel(metaData: GnocchiModelMetaData,
- variantModels: Dataset[AdditiveLinearVariantModel],
- QCVariantModels: Dataset[QualityControlVariantModel[AdditiveLinearVariantModel]],
- QCPhenotypes: Map[String, Phenotype])
- extends GnocchiModel[AdditiveLinearVariantModel, AdditiveLinearGnocchiModel] {
-
- val sparkSession = SparkSession.builder().getOrCreate()
- import sparkSession.implicits._
-
- def mergeGnocchiModel(otherModel: GnocchiModel[AdditiveLinearVariantModel, AdditiveLinearGnocchiModel]): GnocchiModel[AdditiveLinearVariantModel, AdditiveLinearGnocchiModel] = {
-
- require(otherModel.metaData.modelType == metaData.modelType,
- "Models being merged are not the same type. Type equality is required to merge two models correctly.")
-
- val mergedVMs = mergeVariantModels(otherModel.variantModels)
-
- // ToDo: 1. [DONE] make sure models are of same type 2. [DONE] find intersection of QCVariants and use those as the gnocchiModel
- // ToDo: QCVariants 3. Make sure the phenotype of the models are the same 4. Make sure the covariates of the model
- // ToDo: are the same (currently broken because covariates stored in [[Phenotype]] object are the values not names)
- val updatedMetaData = updateMetaData(otherModel.metaData.numSamples)
-
- val mergedQCVariants = mergeQCVariants(otherModel.QCVariantModels)
- val mergedQCVariantModels = mergedVMs.joinWith(mergedQCVariants, mergedVMs("uniqueID") === mergedQCVariants("uniqueID"), "inner")
- .withColumnRenamed("_1", "variantModel")
- .withColumnRenamed("_2", "variant")
- .as[QualityControlVariantModel[AdditiveLinearVariantModel]]
- val mergedQCPhenotypes = QCPhenotypes ++ otherModel.QCPhenotypes
-
- AdditiveLinearGnocchiModel(updatedMetaData, mergedVMs, mergedQCVariantModels, mergedQCPhenotypes)
- }
-
- def mergeVariantModels(newVariantModels: Dataset[AdditiveLinearVariantModel]): Dataset[AdditiveLinearVariantModel] = {
-
- // ToDo: Logging here to denote what the results of the merge were. How many variants were retained / thrown out from
- // ToDo: each model.
- variantModels.joinWith(newVariantModels, variantModels("uniqueID") === newVariantModels("uniqueID")).map(x => x._1.mergeWith(x._2))
- }
-
- def mergeQCVariants(newQCVariantModels: Dataset[QualityControlVariantModel[AdditiveLinearVariantModel]]): Dataset[CalledVariant] = {
- val variants1 = QCVariantModels.map(_.variant)
- val variants2 = newQCVariantModels.map(_.variant)
-
- variants1.joinWith(variants2, variants1("uniqueID") === variants2("uniqueID"))
- .as[(CalledVariant, CalledVariant)]
- .map(x =>
- CalledVariant(x._1.chromosome,
- x._1.position,
- x._1.uniqueID,
- x._1.referenceAllele,
- x._1.alternateAllele,
- x._1.qualityScore,
- x._1.filter,
- x._1.info,
- x._1.format,
- x._1.samples ++ x._2.samples))
- }
-}
diff --git a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/logistic/DominantLogisticGnocchiModel.scala b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/logistic/DominantLogisticGnocchiModel.scala
deleted file mode 100644
index 9a2b87e5..00000000
--- a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/logistic/DominantLogisticGnocchiModel.scala
+++ /dev/null
@@ -1,131 +0,0 @@
-/**
- * Licensed to Big Data Genomics (BDG) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The BDG licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.bdgenomics.gnocchi.models.logistic
-
-import org.bdgenomics.gnocchi.algorithms.siteregression.DominantLogisticRegression
-import org.bdgenomics.gnocchi.models._
-import org.bdgenomics.gnocchi.models.variant.QualityControlVariantModel
-import org.bdgenomics.gnocchi.models.variant.logistic.DominantLogisticVariantModel
-import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
-import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.sql.{ Dataset, SparkSession }
-
-object DominantLogisticGnocchiModelFactory {
-
- val regressionName = "DominantLogisticRegression"
- val sparkSession = SparkSession.builder().getOrCreate()
- import sparkSession.implicits._
-
- def apply(genotypes: Dataset[CalledVariant],
- phenotypes: Broadcast[Map[String, Phenotype]],
- phenotypeNames: Option[List[String]],
- QCVariantIDs: Option[Set[String]] = None,
- QCVariantSamplingRate: Double = 0.1,
- validationStringency: String = "STRICT"): DominantLogisticGnocchiModel = {
-
- // ToDo: sampling QC Variants better.
- val variantModels = DominantLogisticRegression(genotypes, phenotypes, validationStringency)
-
- // Create QCVariantModels
- val comparisonVariants = if (QCVariantIDs.isEmpty) {
- genotypes.sample(withReplacement = false, fraction = 0.1)
- } else {
- genotypes.filter(x => QCVariantIDs.get.contains(x.uniqueID))
- }
-
- val QCVariantModels = variantModels
- .joinWith(comparisonVariants, variantModels("uniqueID") === comparisonVariants("uniqueID"), "inner")
- .withColumnRenamed("_1", "variantModel")
- .withColumnRenamed("_2", "variant")
- .as[QualityControlVariantModel[DominantLogisticVariantModel]]
-
- val phenoNames = if (phenotypeNames.isEmpty) {
- List(phenotypes.value.head._2.phenoName) ++ (1 to phenotypes.value.head._2.covariates.length).map(x => "covar_" + x)
- } else {
- phenotypeNames.get
- }
-
- // Create metadata
- val metadata = GnocchiModelMetaData(regressionName,
- phenoNames.head,
- phenoNames.tail.mkString(","),
- genotypes.count().toInt,
- flaggedVariantModels = Option(QCVariantModels.select("variant.uniqueID").as[String].collect().toList))
-
- DominantLogisticGnocchiModel(metaData = metadata,
- variantModels = variantModels,
- QCVariantModels = QCVariantModels,
- QCPhenotypes = phenotypes.value)
- }
-}
-
-case class DominantLogisticGnocchiModel(metaData: GnocchiModelMetaData,
- variantModels: Dataset[DominantLogisticVariantModel],
- QCVariantModels: Dataset[QualityControlVariantModel[DominantLogisticVariantModel]],
- QCPhenotypes: Map[String, Phenotype])
- extends GnocchiModel[DominantLogisticVariantModel, DominantLogisticGnocchiModel] {
-
- val sparkSession = SparkSession.builder().getOrCreate()
- import sparkSession.implicits._
-
- def mergeGnocchiModel(otherModel: GnocchiModel[DominantLogisticVariantModel, DominantLogisticGnocchiModel]): GnocchiModel[DominantLogisticVariantModel, DominantLogisticGnocchiModel] = {
-
- require(otherModel.metaData.modelType == metaData.modelType,
- "Models being merged are not the same type. Type equality is required to merge two models correctly.")
-
- val mergedVMs = mergeVariantModels(otherModel.variantModels)
-
- // ToDo: 1. [DONE] make sure models are of same type 2. [DONE] find intersection of QCVariants and use those as the gnocchiModel
- // ToDo: QCVariants 3. Make sure the phenotype of the models are the same 4. Make sure the covariates of the model
- // ToDo: are the same (currently broken because covariates stored in [[Phenotype]] object are the values not names)
- val updatedMetaData = updateMetaData(otherModel.metaData.numSamples)
-
- val mergedQCVariants = mergeQCVariants(otherModel.QCVariantModels)
- val mergedQCVariantModels = mergedVMs.joinWith(mergedQCVariants, mergedVMs("uniqueID") === mergedQCVariants("uniqueID"), "inner")
- .withColumnRenamed("_1", "variantModel")
- .withColumnRenamed("_2", "variant")
- .as[QualityControlVariantModel[DominantLogisticVariantModel]]
- val mergedQCPhenotypes = QCPhenotypes ++ otherModel.QCPhenotypes
-
- DominantLogisticGnocchiModel(updatedMetaData, mergedVMs, mergedQCVariantModels, mergedQCPhenotypes)
- }
-
- def mergeVariantModels(newVariantModels: Dataset[DominantLogisticVariantModel]): Dataset[DominantLogisticVariantModel] = {
- variantModels.joinWith(newVariantModels, variantModels("uniqueID") === newVariantModels("uniqueID")).map(x => x._1.mergeWith(x._2))
- }
-
- def mergeQCVariants(newQCVariantModels: Dataset[QualityControlVariantModel[DominantLogisticVariantModel]]): Dataset[CalledVariant] = {
- val variants1 = QCVariantModels.map(_.variant)
- val variants2 = newQCVariantModels.map(_.variant)
-
- variants1.joinWith(variants2, variants1("uniqueID") === variants2("uniqueID"))
- .as[(CalledVariant, CalledVariant)]
- .map(x =>
- CalledVariant(x._1.chromosome,
- x._1.position,
- x._1.uniqueID,
- x._1.referenceAllele,
- x._1.alternateAllele,
- x._1.qualityScore,
- x._1.filter,
- x._1.info,
- x._1.format,
- x._1.samples ++ x._2.samples))
- }
-}
diff --git a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/linear/LinearVariantModel.scala b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/LinearVariantModel.scala
similarity index 68%
rename from gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/linear/LinearVariantModel.scala
rename to gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/LinearVariantModel.scala
index 11642fb7..ec9d1516 100644
--- a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/linear/LinearVariantModel.scala
+++ b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/LinearVariantModel.scala
@@ -1,11 +1,47 @@
-package org.bdgenomics.gnocchi.models.variant.linear
+/**
+ * Licensed to Big Data Genomics (BDG) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The BDG licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.bdgenomics.gnocchi.models.variant
-import org.bdgenomics.gnocchi.models.variant.VariantModel
-import org.bdgenomics.gnocchi.primitives.association.LinearAssociation
import org.apache.commons.math3.distribution.TDistribution
+import org.bdgenomics.gnocchi.algorithms.siteregression.LinearSiteRegression
+import org.bdgenomics.gnocchi.primitives.association.LinearAssociation
+import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
+import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
+
+import scala.collection.immutable.Map
-trait LinearVariantModel[VM <: LinearVariantModel[VM]] extends VariantModel[VM] {
- val association: LinearAssociation
+case class LinearVariantModel(uniqueID: String,
+ association: LinearAssociation,
+ phenotype: String,
+ chromosome: Int,
+ position: Int,
+ referenceAllele: String,
+ alternateAllele: String,
+ allelicAssumption: String,
+ phaseSetId: Int = 0) extends VariantModel[LinearVariantModel] with LinearSiteRegression {
+
+ val modelType: String = "Linear Variant Model"
+ val regressionName = "Linear Regression"
+
+  /** Runs applyToSite on the new batch under this model's allelicAssumption, wraps the result as a model for uniqueID, and merges it into this model. */
+  def update(genotypes: CalledVariant, phenotypes: Map[String, Phenotype]): LinearVariantModel = {
+    mergeWith(constructVariantModel(uniqueID, applyToSite(phenotypes, genotypes, allelicAssumption)))
+  }
/**
* Returns updated LinearVariantModel of correct subtype
@@ -15,7 +51,7 @@ trait LinearVariantModel[VM <: LinearVariantModel[VM]] extends VariantModel[VM]
*
* @return Returns updated LinearVariantModel of correct subtype
*/
- def mergeWith(variantModel: VM): VM = {
+ def mergeWith(variantModel: LinearVariantModel): LinearVariantModel = {
val updatedNumSamples = updateNumSamples(variantModel.association.numSamples)
val updatedWeights = updateWeights(variantModel.association.weights, variantModel.association.numSamples)
val updatedSsDeviations = updateSsDeviations(variantModel.association.ssDeviations)
@@ -153,6 +189,39 @@ trait LinearVariantModel[VM <: LinearVariantModel[VM]] extends VariantModel[VM]
updatedResidualDegreesOfFreedom: Int,
updatedPValue: Double,
updatedWeights: List[Double],
- updatedNumSamples: Int): VM
+ updatedNumSamples: Int): LinearVariantModel = {
+
+ val updatedAssociation = LinearAssociation(ssDeviations = updatedSsDeviations,
+ ssResiduals = updatedSsResiduals,
+ geneticParameterStandardError = updatedGeneticParameterStandardError,
+ tStatistic = updatedtStatistic,
+ residualDegreesOfFreedom = updatedResidualDegreesOfFreedom,
+ pValue = updatedPValue,
+ weights = updatedWeights,
+ numSamples = updatedNumSamples)
+
+ LinearVariantModel(variantID,
+ updatedAssociation,
+ phenotype,
+ chromosome,
+ position,
+ referenceAllele,
+ alternateAllele,
+ allelicAssumption,
+ phaseSetId)
+ }
+
+  /**
+   * Builds a variant model that carries the given ID and association and takes
+   * every other field (phenotype, chromosome, position, alleles, allelic
+   * assumption, phase set ID) from this model.
+   *
+   * @param variantID unique ID stored on the returned model
+   * @param association association stored on the returned model
+   * @return a new LinearVariantModel; this model is left unchanged
+   */
+  def constructVariantModel(variantID: String,
+                            association: LinearAssociation): LinearVariantModel =
+    copy(uniqueID = variantID, association = association)
}
diff --git a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/logistic/LogisticVariantModel.scala b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/LogisticVariantModel.scala
similarity index 62%
rename from gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/logistic/LogisticVariantModel.scala
rename to gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/LogisticVariantModel.scala
index f70352ec..e9ad0682 100644
--- a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/logistic/LogisticVariantModel.scala
+++ b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/LogisticVariantModel.scala
@@ -15,14 +15,29 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.bdgenomics.gnocchi.models.variant.logistic
+package org.bdgenomics.gnocchi.models.variant
-import org.bdgenomics.gnocchi.models.variant.VariantModel
-import org.bdgenomics.gnocchi.primitives.association.LogisticAssociation
import org.apache.commons.math3.distribution.ChiSquaredDistribution
+import org.apache.commons.math3.linear.SingularMatrixException
+import org.bdgenomics.gnocchi.algorithms.siteregression.LogisticSiteRegression
+import org.bdgenomics.gnocchi.primitives.association.LogisticAssociation
+import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
+import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
+
+import scala.collection.immutable.Map
+
+case class LogisticVariantModel(uniqueID: String,
+ association: LogisticAssociation,
+ phenotype: String,
+ chromosome: Int,
+ position: Int,
+ referenceAllele: String,
+ alternateAllele: String,
+ allelicAssumption: String,
+ phaseSetId: Int = 0) extends VariantModel[LogisticVariantModel] with LogisticSiteRegression {
-trait LogisticVariantModel[VM <: LogisticVariantModel[VM]] extends VariantModel[VM] {
- val association: LogisticAssociation
+ val modelType = "Logistic Variant Model"
+ val regressionName = "Logistic Regression"
/**
* Returns updated LogisticVariantModel of correct subtype
@@ -32,7 +47,7 @@ trait LogisticVariantModel[VM <: LogisticVariantModel[VM]] extends VariantModel[
*
* @return Returns updated LogisticVariantModel of correct subtype
*/
- def mergeWith(variantModel: VM): VM = {
+ def mergeWith(variantModel: LogisticVariantModel): LogisticVariantModel = {
val updatedNumSamples = updateNumSamples(variantModel.association.numSamples)
val updatedGeneticParameterStandardError = computeGeneticParameterStandardError(variantModel.association.geneticParameterStandardError, variantModel.association.numSamples)
val updatedWeights = updateWeights(variantModel.association.weights, variantModel.association.numSamples)
@@ -58,7 +73,6 @@ trait LogisticVariantModel[VM <: LogisticVariantModel[VM]] extends VariantModel[
*/
def computeGeneticParameterStandardError(batchStandardError: Double, batchNumSamples: Int): Double = {
(batchStandardError * batchNumSamples.toDouble + association.geneticParameterStandardError * association.numSamples.toDouble) / (batchNumSamples + association.numSamples).toDouble
-
}
/**
@@ -87,9 +101,50 @@ trait LogisticVariantModel[VM <: LogisticVariantModel[VM]] extends VariantModel[
1 - chiDist.cumulativeProbability(waldStatistic)
}
+  /** Builds a model from the new batch via applyToSite and merges it into this one; a SingularMatrixException from that step reaches the caller. */
+  def update(genotypes: CalledVariant,
+             phenotypes: Map[String, Phenotype]): LogisticVariantModel = {
+    // TODO: apply validation stringency here instead of always propagating the failure
+    val batchVariantModel = try {
+      constructVariantModel(uniqueID, applyToSite(phenotypes, genotypes, allelicAssumption))
+    } catch {
+      case error: SingularMatrixException => throw error // rethrow the caught instance: keeps its stack trace
+    }
+    mergeWith(batchVariantModel)
+  }
+
def constructVariantModel(variantId: String,
updatedGeneticParameterStandardError: Double,
updatedPValue: Double,
updatedWeights: List[Double],
- updatedNumSamples: Int): VM
+ updatedNumSamples: Int): LogisticVariantModel = {
+
+ val association = LogisticAssociation(weights = updatedWeights,
+ geneticParameterStandardError = updatedGeneticParameterStandardError,
+ pValue = updatedPValue,
+ numSamples = updatedNumSamples)
+
+ LogisticVariantModel(variantId,
+ association,
+ phenotype,
+ chromosome,
+ position,
+ referenceAllele,
+ alternateAllele,
+ allelicAssumption,
+ phaseSetId)
+ }
+
+  /**
+   * Builds a variant model that carries the given ID and association and takes
+   * every other field (phenotype, chromosome, position, alleles, allelic
+   * assumption, phase set ID) from this model.
+   *
+   * @param variantId unique ID stored on the returned model
+   * @param association association stored on the returned model
+   * @return a new LogisticVariantModel; this model is left unchanged
+   */
+  def constructVariantModel(variantId: String,
+                            association: LogisticAssociation): LogisticVariantModel =
+    copy(uniqueID = variantId, association = association)
}
diff --git a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/linear/AdditiveLinearVariantModel.scala b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/linear/AdditiveLinearVariantModel.scala
deleted file mode 100644
index 215576ec..00000000
--- a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/linear/AdditiveLinearVariantModel.scala
+++ /dev/null
@@ -1,98 +0,0 @@
-/**
- * Licensed to Big Data Genomics (BDG) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The BDG licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.bdgenomics.gnocchi.models.variant.linear
-
-import org.bdgenomics.gnocchi.algorithms.siteregression.AdditiveLinearRegression
-import org.bdgenomics.gnocchi.primitives.association.LinearAssociation
-import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
-import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
-
-import scala.collection.immutable.Map
-
-case class AdditiveLinearVariantModel(uniqueID: String,
- association: LinearAssociation,
- phenotype: String,
- chromosome: Int,
- position: Int,
- referenceAllele: String,
- alternateAllele: String,
- phaseSetId: Int = 0)
- extends LinearVariantModel[AdditiveLinearVariantModel]
- with AdditiveLinearRegression with Serializable {
-
- type VM = AdditiveLinearVariantModel
- val modelType = "Additive Linear Variant Model"
- val regressionName = "Additive Linear Regression"
-
- /**
- * Updates the AdditiveLinearVariantModel given a new batch of data
- *
- * @param observations Array containing data at the particular site for
- * all samples. Format of each element is:
- * (gs, Array(pheno, covar1, ... covarp))
- * where gs is the diploid genotype at that site for the
- * given sample [0, 1, or 2], pheno is the sample's value for
- * the primary phenotype being regressed on, and covar1-covarp
- * are that sample's values for each covariate.
- */
- def update(genotypes: CalledVariant, phenotypes: Map[String, Phenotype]): AdditiveLinearVariantModel = {
- val batchVariantModel = constructVariantModel(uniqueID, applyToSite(phenotypes, genotypes))
- mergeWith(batchVariantModel)
- }
-
- def constructVariantModel(variantID: String,
- updatedSsDeviations: Double,
- updatedSsResiduals: Double,
- updatedGeneticParameterStandardError: Double,
- updatedtStatistic: Double,
- updatedResidualDegreesOfFreedom: Int,
- updatedPValue: Double,
- updatedWeights: List[Double],
- updatedNumSamples: Int): AdditiveLinearVariantModel = {
-
- val updatedAssociation = LinearAssociation(ssDeviations = updatedSsDeviations,
- ssResiduals = updatedSsResiduals,
- geneticParameterStandardError = updatedGeneticParameterStandardError,
- tStatistic = updatedtStatistic,
- residualDegreesOfFreedom = updatedResidualDegreesOfFreedom,
- pValue = updatedPValue,
- weights = updatedWeights,
- numSamples = updatedNumSamples)
-
- AdditiveLinearVariantModel(variantID,
- updatedAssociation,
- phenotype,
- chromosome,
- position,
- referenceAllele,
- alternateAllele,
- phaseSetId)
- }
-
- def constructVariantModel(variantID: String,
- association: LinearAssociation): AdditiveLinearVariantModel = {
- AdditiveLinearVariantModel(variantID,
- association,
- phenotype,
- chromosome,
- position,
- referenceAllele,
- alternateAllele,
- phaseSetId)
- }
-}
diff --git a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/linear/DominantLinearVariantModel.scala b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/linear/DominantLinearVariantModel.scala
deleted file mode 100644
index 13d457a4..00000000
--- a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/linear/DominantLinearVariantModel.scala
+++ /dev/null
@@ -1,98 +0,0 @@
-/**
- * Licensed to Big Data Genomics (BDG) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The BDG licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.bdgenomics.gnocchi.models.variant.linear
-
-import org.bdgenomics.gnocchi.algorithms.siteregression.DominantLinearRegression
-import org.bdgenomics.gnocchi.primitives.association.LinearAssociation
-import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
-import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
-
-import scala.collection.immutable.Map
-
-case class DominantLinearVariantModel(uniqueID: String,
- association: LinearAssociation,
- phenotype: String,
- chromosome: Int,
- position: Int,
- referenceAllele: String,
- alternateAllele: String,
- phaseSetId: Int = 0)
- extends LinearVariantModel[DominantLinearVariantModel]
- with DominantLinearRegression with Serializable {
-
- type VM = DominantLinearVariantModel
- val modelType = "Dominant Linear Variant Model"
- val regressionName = "Dominant Linear Regression"
-
- /**
- * Updates the DominantLinearVariantModel given a new batch of data
- *
- * @param observations Array containing data at the particular site for
- * all samples. Format of each element is:
- * (gs, Array(pheno, covar1, ... covarp))
- * where gs is the diploid genotype at that site for the
- * given sample [0, 1, or 2], pheno is the sample's value for
- * the primary phenotype being regressed on, and covar1-covarp
- * are that sample's values for each covariate.
- */
- def update(genotypes: CalledVariant, phenotypes: Map[String, Phenotype]): DominantLinearVariantModel = {
- val batchVariantModel = constructVariantModel(uniqueID, applyToSite(phenotypes, genotypes))
- mergeWith(batchVariantModel)
- }
-
- def constructVariantModel(variantID: String,
- updatedSsDeviations: Double,
- updatedSsResiduals: Double,
- updatedGeneticParameterStandardError: Double,
- updatedtStatistic: Double,
- updatedResidualDegreesOfFreedom: Int,
- updatedPValue: Double,
- updatedWeights: List[Double],
- updatedNumSamples: Int): DominantLinearVariantModel = {
-
- val updatedAssociation = LinearAssociation(ssDeviations = updatedSsDeviations,
- ssResiduals = updatedSsResiduals,
- geneticParameterStandardError = updatedGeneticParameterStandardError,
- tStatistic = updatedtStatistic,
- residualDegreesOfFreedom = updatedResidualDegreesOfFreedom,
- pValue = updatedPValue,
- weights = updatedWeights,
- numSamples = updatedNumSamples)
-
- DominantLinearVariantModel(variantID,
- updatedAssociation,
- phenotype,
- chromosome,
- position,
- referenceAllele,
- alternateAllele,
- phaseSetId)
- }
-
- def constructVariantModel(variantID: String,
- association: LinearAssociation): DominantLinearVariantModel = {
- DominantLinearVariantModel(variantID,
- association,
- phenotype,
- chromosome,
- position,
- referenceAllele,
- alternateAllele, phaseSetId)
- }
-
-}
diff --git a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/logistic/AdditiveLogisticVariantModel.scala b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/logistic/AdditiveLogisticVariantModel.scala
deleted file mode 100644
index 2974b2fe..00000000
--- a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/logistic/AdditiveLogisticVariantModel.scala
+++ /dev/null
@@ -1,94 +0,0 @@
-/**
- * Licensed to Big Data Genomics (BDG) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The BDG licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.bdgenomics.gnocchi.models.variant.logistic
-
-import org.bdgenomics.gnocchi.algorithms.siteregression.AdditiveLogisticRegression
-import org.bdgenomics.gnocchi.primitives.association.LogisticAssociation
-import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
-import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
-import org.apache.commons.math3.linear.SingularMatrixException
-
-import scala.collection.immutable.Map
-
-case class AdditiveLogisticVariantModel(uniqueID: String,
- association: LogisticAssociation,
- phenotype: String,
- chromosome: Int,
- position: Int,
- referenceAllele: String,
- alternateAllele: String,
- phaseSetId: Int = 0)
- extends LogisticVariantModel[AdditiveLogisticVariantModel]
- with AdditiveLogisticRegression with Serializable {
-
- val modelType = "Additive Logistic Variant Model"
- override val regressionName = "Additive Logistic Regression"
-
- /**
- * Updates the LogisticVariantModel given a new batch of data
- *
- * @param observations Array containing data at the particular site for
- * all samples. Format of each element is:
- * (gs, Array(pheno, covar1, ... covarp))
- * where gs is the diploid genotype at that site for the
- * given sample [0, 1, or 2], pheno is the sample's value for
- * the primary phenotype being regressed on, and covar1-covarp
- * are that sample's values for each covariate.
- */
- def update(genotypes: CalledVariant, phenotypes: Map[String, Phenotype]): AdditiveLogisticVariantModel = {
-
- //TODO: add validation stringency here rather than just creating empty association object
- val batchVariantModel = try {
- constructVariantModel(uniqueID, applyToSite(phenotypes, genotypes))
- } catch {
- case error: SingularMatrixException => throw new SingularMatrixException()
- }
- mergeWith(batchVariantModel)
- }
-
- def constructVariantModel(variantId: String,
- updatedGeneticParameterStandardError: Double,
- updatedPValue: Double,
- updatedWeights: List[Double],
- updatedNumSamples: Int): AdditiveLogisticVariantModel = {
-
- val association = LogisticAssociation(weights = updatedWeights,
- geneticParameterStandardError = updatedGeneticParameterStandardError,
- pValue = updatedPValue,
- numSamples = updatedNumSamples)
-
- AdditiveLogisticVariantModel(variantId,
- association,
- phenotype,
- chromosome,
- position,
- referenceAllele,
- alternateAllele, phaseSetId)
- }
-
- def constructVariantModel(variantId: String,
- association: LogisticAssociation): AdditiveLogisticVariantModel = {
- AdditiveLogisticVariantModel(variantId,
- association,
- phenotype,
- chromosome,
- position,
- referenceAllele,
- alternateAllele, phaseSetId)
- }
-}
diff --git a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/logistic/DominantLogisticVariantModel.scala b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/logistic/DominantLogisticVariantModel.scala
deleted file mode 100644
index 44b53c9d..00000000
--- a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/models/variant/logistic/DominantLogisticVariantModel.scala
+++ /dev/null
@@ -1,98 +0,0 @@
-/**
- * Licensed to Big Data Genomics (BDG) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The BDG licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.bdgenomics.gnocchi.models.variant.logistic
-
-import org.bdgenomics.gnocchi.algorithms.siteregression.DominantLogisticRegression
-import org.bdgenomics.gnocchi.primitives.association.LogisticAssociation
-import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
-import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
-import org.apache.commons.math3.linear.SingularMatrixException
-
-import scala.collection.immutable.Map
-
-case class DominantLogisticVariantModel(uniqueID: String,
- association: LogisticAssociation,
- phenotype: String,
- chromosome: Int,
- position: Int,
- referenceAllele: String,
- alternateAllele: String,
- phaseSetId: Int = 0)
- extends LogisticVariantModel[DominantLogisticVariantModel]
- with DominantLogisticRegression with Serializable {
-
- val modelType = "Dominant Logistic Variant Model"
- override val regressionName = "Dominant Logistic Regression"
-
- /**
- * Updates the LogisticVariantModel given a new batch of data
- *
- * @param observations Array containing data at the particular site for
- * all samples. Format of each element is:
- * (gs, Array(pheno, covar1, ... covarp))
- * where gs is the diploid genotype at that site for the
- * given sample [0, 1, or 2], pheno is the sample's value for
- * the primary phenotype being regressed on, and covar1-covarp
- * are that sample's values for each covariate.
- */
- def update(genotypes: CalledVariant, phenotypes: Map[String, Phenotype]): DominantLogisticVariantModel = {
-
- //TODO: add validation stringency here rather than just creating empty association object
- val batchVariantModel = try {
- constructVariantModel(uniqueID, applyToSite(phenotypes, genotypes))
- } catch {
- case error: SingularMatrixException => throw new SingularMatrixException()
- }
- mergeWith(batchVariantModel)
- }
-
- def constructVariantModel(variantId: String,
- updatedGeneticParameterStandardError: Double,
- updatedPValue: Double,
- updatedWeights: List[Double],
- updatedNumSamples: Int): DominantLogisticVariantModel = {
-
- val association = LogisticAssociation(weights = updatedWeights,
- geneticParameterStandardError = updatedGeneticParameterStandardError,
- pValue = updatedPValue,
- numSamples = updatedNumSamples)
-
- DominantLogisticVariantModel(variantId,
- association,
- phenotype,
- chromosome,
- position,
- referenceAllele,
- alternateAllele,
- phaseSetId)
- }
-
- def constructVariantModel(variantId: String,
- association: LogisticAssociation): DominantLogisticVariantModel = {
-
- DominantLogisticVariantModel(variantId,
- association,
- phenotype,
- chromosome,
- position,
- referenceAllele,
- alternateAllele,
- phaseSetId)
- }
-}
diff --git a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/primitives/genotype/GenotypeState.scala b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/primitives/genotype/GenotypeState.scala
index 30d51c51..678b6b8d 100644
--- a/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/primitives/genotype/GenotypeState.scala
+++ b/gnocchi-core/src/main/scala/org/bdgenomics/gnocchi/primitives/genotype/GenotypeState.scala
@@ -18,4 +18,16 @@ case class GenotypeState(sampleID: String,
def toList: List[String] = {
value.split("/|\\|").toList
}
+
+ def additive: Double = { // additive encoding: the value of toDouble, passed through unchanged
+ toDouble
+ }
+
+ def dominant: Double = { // dominant encoding: 0.0 when toDouble is exactly 0.0, otherwise 1.0
+ if (toDouble == 0.0) 0.0 else 1.0
+ }
+
+ def recessive: Double = { // recessive encoding: 1.0 when toDouble is exactly 2.0, otherwise 0.0
+ if (toDouble == 2.0) 1.0 else 0.0
+ }
}
\ No newline at end of file
diff --git a/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/algorithms/LinearSiteRegressionSuite.scala b/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/algorithms/LinearSiteRegressionSuite.scala
index 5a8a8bd8..bfc57cba 100755
--- a/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/algorithms/LinearSiteRegressionSuite.scala
+++ b/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/algorithms/LinearSiteRegressionSuite.scala
@@ -19,7 +19,7 @@ package org.bdgenomics.gnocchi.algorithms
import breeze.linalg.MatrixSingularException
import org.bdgenomics.gnocchi.GnocchiFunSuite
-import org.bdgenomics.gnocchi.algorithms.siteregression.{ AdditiveLinearRegression, DominantLinearRegression }
+import org.bdgenomics.gnocchi.algorithms.siteregression.LinearSiteRegression
import org.bdgenomics.gnocchi.primitives.genotype.GenotypeState
import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
@@ -75,7 +75,7 @@ class LinearSiteRegressionSuite extends GnocchiFunSuite {
.toMap
// use additiveLinearRegression to regress on AscombeI
- val regressionResult = AdditiveLinearRegression.applyToSite(phenoMap, cv)
+ val regressionResult = LinearSiteRegression.applyToSite(phenoMap, cv, "ADDITIVE")
// Assert that the rsquared is in the right threshold.
// R^2 = 1 - (SS_res / SS_tot)
@@ -112,7 +112,7 @@ class LinearSiteRegressionSuite extends GnocchiFunSuite {
.toMap
// use additiveLinearRegression to regress on AscombeII
- val regressionResult = AdditiveLinearRegression.applyToSite(phenoMap, cv)
+ val regressionResult = LinearSiteRegression.applyToSite(phenoMap, cv, "ADDITIVE")
// Assert that the rsquared is in the right threshold.
// R^2 = 1 - (SS_res / SS_tot)
@@ -149,7 +149,7 @@ class LinearSiteRegressionSuite extends GnocchiFunSuite {
.toMap
// use additiveLinearRegression to regress on AscombeIII
- val regressionResult = AdditiveLinearRegression.applyToSite(phenoMap, cv)
+ val regressionResult = LinearSiteRegression.applyToSite(phenoMap, cv, "ADDITIVE")
// Assert that the rsquared is in the right threshold.
// R^2 = 1 - (SS_res / SS_tot)
@@ -186,7 +186,7 @@ class LinearSiteRegressionSuite extends GnocchiFunSuite {
.toMap
//use additiveLinearRegression to regress on AscombeIV
- val regressionResult = AdditiveLinearRegression.applyToSite(phenoMap, cv)
+ val regressionResult = LinearSiteRegression.applyToSite(phenoMap, cv, "ADDITIVE")
// Assert that the rsquared is in the right threshold.
// R^2 = 1 - (SS_res / SS_tot)
@@ -295,7 +295,7 @@ class LinearSiteRegressionSuite extends GnocchiFunSuite {
.toMap
// use additiveLinearRegression to regress on PIQ
- val regressionResult = AdditiveLinearRegression.applyToSite(phenoMap, cv)
+ val regressionResult = LinearSiteRegression.applyToSite(phenoMap, cv, "ADDITIVE")
// Assert that the rsquared is in the right threshold.
// R^2 = 1 - (SS_res / SS_tot)
@@ -321,7 +321,7 @@ class LinearSiteRegressionSuite extends GnocchiFunSuite {
val phenoMap = Map("sample1" -> Phenotype("sample1", "pheno1", 1))
intercept[MatrixSingularException] {
- val regressionResult = AdditiveLinearRegression.applyToSite(phenoMap, cv)
+ val regressionResult = LinearSiteRegression.applyToSite(phenoMap, cv, "ADDITIVE")
}
}
@@ -332,7 +332,7 @@ class LinearSiteRegressionSuite extends GnocchiFunSuite {
val phenoMap = Map("sample2" -> Phenotype("sample2", "pheno1", 1))
intercept[IllegalArgumentException] {
- val regressionResult = AdditiveLinearRegression.applyToSite(phenoMap, cv)
+ val regressionResult = LinearSiteRegression.applyToSite(phenoMap, cv, "ADDITIVE")
}
}
@@ -368,8 +368,7 @@ class LinearSiteRegressionSuite extends GnocchiFunSuite {
.map(item => (item._2.toString, Phenotype(item._2.toString, "pheno1", item._1(0), item._1.slice(1, 3).toList)))
.toMap
- val (x, y) = AdditiveLinearRegression.prepareDesignMatrix(cv, phenoMap)
-
+ val (x, y) = LinearSiteRegression.prepareDesignMatrix(cv, phenoMap, "ADDITIVE")
// Verify length of X and Y matrices
assert(x.rows === 5)
assert(y.length === 5)
@@ -397,7 +396,7 @@ class LinearSiteRegressionSuite extends GnocchiFunSuite {
.map(item => (item._2.toString, Phenotype(item._2.toString, "pheno1", item._1(0), item._1.slice(1, 3).toList)))
.toMap
- val (x, y) = AdditiveLinearRegression.prepareDesignMatrix(cv, phenoMap)
+ val (x, y) = LinearSiteRegression.prepareDesignMatrix(cv, phenoMap, "ADDITIVE")
// Verify length of Y label vector
assert(y.length === 5)
@@ -424,7 +423,7 @@ class LinearSiteRegressionSuite extends GnocchiFunSuite {
.map(item => (item._2.toString, Phenotype(item._2.toString, "pheno1", item._1(0), item._1.slice(1, 3).toList)))
.toMap
- val (x, y) = AdditiveLinearRegression.prepareDesignMatrix(cv, phenoMap)
+ val (x, y) = LinearSiteRegression.prepareDesignMatrix(cv, phenoMap, "ADDITIVE")
// Verify length of X data matrix
assert(x.rows === 5)
@@ -451,7 +450,7 @@ class LinearSiteRegressionSuite extends GnocchiFunSuite {
.map(item => (item._2.toString, Phenotype(item._2.toString, "pheno1", item._1(0), item._1.slice(1, 3).toList)))
.toMap
- val (x, y) = AdditiveLinearRegression.prepareDesignMatrix(cv, phenoMap)
+ val (x, y) = LinearSiteRegression.prepareDesignMatrix(cv, phenoMap, "ADDITIVE")
// Verify length of X data matrix
assert(x.rows === 5)
@@ -466,8 +465,8 @@ class LinearSiteRegressionSuite extends GnocchiFunSuite {
}
- sparkTest("AdditiveLinearRegression should have regressionName set to `additiveLinearRegression`") {
- assert(AdditiveLinearRegression.regressionName === "additiveLinearRegression")
+ sparkTest("LinearSiteRegression should have regressionName set to `LinearSiteRegression`") {
+ assert(LinearSiteRegression.regressionName === "LinearSiteRegression")
}
// DominantLinearRegression tests
@@ -475,8 +474,4 @@ class LinearSiteRegressionSuite extends GnocchiFunSuite {
ignore("DominantLinearRegression.constructVM should call the clip or keep state from the `Dominant` trait.") {
}
-
- sparkTest("DominantLinearRegression should have regressionName set to `dominantLinearRegression`") {
- assert(DominantLinearRegression.regressionName === "dominantLinearRegression")
- }
}
diff --git a/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/algorithms/LogisticSiteRegressionSuite.scala b/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/algorithms/LogisticSiteRegressionSuite.scala
index ead84c69..3bbf71f5 100755
--- a/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/algorithms/LogisticSiteRegressionSuite.scala
+++ b/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/algorithms/LogisticSiteRegressionSuite.scala
@@ -19,11 +19,11 @@ package org.bdgenomics.gnocchi.algorithms
import breeze.linalg
import breeze.linalg.{ DenseMatrix, DenseVector, MatrixSingularException }
-import org.bdgenomics.gnocchi.algorithms.siteregression.{ AdditiveLogisticRegression, DominantLogisticRegression }
-import org.bdgenomics.gnocchi.models.variant.logistic.AdditiveLogisticVariantModel
+import org.bdgenomics.gnocchi.algorithms.siteregression.LogisticSiteRegression
import org.bdgenomics.gnocchi.primitives.association.LogisticAssociation
import org.apache.spark.sql.SparkSession
import org.bdgenomics.gnocchi.GnocchiFunSuite
+import org.bdgenomics.gnocchi.models.variant.LogisticVariantModel
import org.mockito.{ ArgumentMatchers, Mockito }
class LogisticSiteRegressionSuite extends GnocchiFunSuite {
@@ -86,7 +86,7 @@ class LogisticSiteRegressionSuite extends GnocchiFunSuite {
val cvDS = sparkSession.createDataset(List(cv))
val phenos = sc.broadcast(createSamplePhenotype(calledVariant = Option(cv)))
- val assoc = AdditiveLogisticRegression(cvDS, phenos).collect
+ val assoc = LogisticSiteRegression(cvDS, phenos).collect
// Note: due to lazy compute, the error won't actually
// materialize till an action is called on the dataset, hence the collect
@@ -105,7 +105,7 @@ class LogisticSiteRegressionSuite extends GnocchiFunSuite {
val cvDS = sparkSession.createDataset(List(cv))
val phenos = sc.broadcast(createSamplePhenotype(calledVariant = Option(cv)))
- val assoc = DominantLogisticRegression(cvDS, phenos).collect
+ val assoc = LogisticSiteRegression(cvDS, phenos, "DOMINANT").collect
// Note: due to lazy compute, the error won't actually
// materialize till an action is called on the dataset, hence the collect
@@ -130,7 +130,7 @@ class LogisticSiteRegressionSuite extends GnocchiFunSuite {
val phenos = sc.broadcast(createSamplePhenotype(calledVariant = Option(cv), numCovariate = 0))
val cvDS = sparkSession.createDataset(List(cv))
- val assoc = AdditiveLogisticRegression(cvDS, phenos).collect
+ val assoc = LogisticSiteRegression(cvDS, phenos).collect
assert(assoc.length != 0, "LogisticSiteRegression.applyToSite breaks on missing covariates.")
}
@@ -146,9 +146,9 @@ class LogisticSiteRegressionSuite extends GnocchiFunSuite {
val cvDS = sparkSession.createDataset(List(cv))
val phenos = sc.broadcast(createSamplePhenotype(calledVariant = Option(cv), numCovariate = 10))
- val assoc = AdditiveLogisticRegression(cvDS, phenos).collect
+ val assoc = LogisticSiteRegression(cvDS, phenos).collect
- assert(assoc.head.isInstanceOf[AdditiveLogisticVariantModel], "LogisticSiteRegression.applyToSite does not return a AdditiveLogisticVariantModel")
+ assert(assoc.head.isInstanceOf[LogisticVariantModel], "LogisticSiteRegression.applyToSite does not return a LogisticVariantModel")
}
// LogisticSiteRegression.prepareDesignMatrix tests
@@ -159,7 +159,7 @@ class LogisticSiteRegressionSuite extends GnocchiFunSuite {
val cv = createSampleCalledVariant(samples = Option(gs))
val phenos = createSamplePhenotype(calledVariant = Option(cv))
- val (data, label) = AdditiveLogisticRegression.prepareDesignMatrix(phenos, cv)
+ val (data, label) = LogisticSiteRegression.prepareDesignMatrix(phenos, cv, "ADDITIVE")
assert(data.rows == cv.numValidSamples, "LogisticSiteRegression.prepareDesignMatrix doesn't filter out missing values properly, design matrix.")
assert(data.isInstanceOf[DenseMatrix[Double]], "LogisticSiteRegression.prepareDesignMatrix doesn't produce a `breeze.linalg.DenseMatrix[Double]`.")
@@ -172,7 +172,7 @@ class LogisticSiteRegressionSuite extends GnocchiFunSuite {
val cv = createSampleCalledVariant(samples = Option(gs))
val phenos = createSamplePhenotype(calledVariant = Option(cv))
- val (data, label) = AdditiveLogisticRegression.prepareDesignMatrix(phenos, cv)
+ val (data, label) = LogisticSiteRegression.prepareDesignMatrix(phenos, cv, "ADDITIVE")
val genos = DenseVector(cv.samples.filter(!_.toList.contains(".")).map(_.toDouble): _*)
assert(data(::, 1) == genos, "LogisticSiteRegression.prepareDesignMatrix places genos in the wrong place")
@@ -183,7 +183,7 @@ class LogisticSiteRegressionSuite extends GnocchiFunSuite {
val cv = createSampleCalledVariant(samples = Option(gs))
val phenos = createSamplePhenotype(calledVariant = Option(cv), numCovariate = 3)
- val (data, label) = AdditiveLogisticRegression.prepareDesignMatrix(phenos, cv)
+ val (data, label) = LogisticSiteRegression.prepareDesignMatrix(phenos, cv, "ADDITIVE")
val covs = data(::, 2 to -1)
val rows = phenos.filter(x => cv.samples.filter(!_.toList.contains(".")).map(_.sampleID).contains(x._1))
@@ -199,7 +199,7 @@ class LogisticSiteRegressionSuite extends GnocchiFunSuite {
val cv = createSampleCalledVariant(samples = Option(gs))
val phenos = createSamplePhenotype(calledVariant = Option(cv), numCovariate = 3)
- val XandY = AdditiveLogisticRegression.prepareDesignMatrix(phenos, cv)
+ val XandY = LogisticSiteRegression.prepareDesignMatrix(phenos, cv, "ADDITIVE")
assert(XandY.isInstanceOf[(DenseMatrix[Double], DenseVector[Double])], "LogisticSiteRegression.prepareDesignMatrix returned an incorrect type.")
}
@@ -215,11 +215,11 @@ class LogisticSiteRegressionSuite extends GnocchiFunSuite {
val cv = createSampleCalledVariant(samples = Option(gs))
val phenos = createSamplePhenotype(calledVariant = Option(cv))
- val (data, label) = AdditiveLogisticRegression.prepareDesignMatrix(phenos, cv)
+ val (data, label) = LogisticSiteRegression.prepareDesignMatrix(phenos, cv, "ADDITIVE")
val beta = DenseVector.zeros[Double](data.cols)
try {
- AdditiveLogisticRegression.findBeta(data, label, beta)
+ LogisticSiteRegression.findBeta(data, label, beta)
fail("LogisticSiteRegression.findBeta does not break on singular hessian.")
} catch {
case e: breeze.linalg.MatrixSingularException =>
diff --git a/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/algorithms/SiteRegressionSuite.scala b/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/algorithms/SiteRegressionSuite.scala
deleted file mode 100644
index c5a864da..00000000
--- a/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/algorithms/SiteRegressionSuite.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-package org.bdgenomics.gnocchi.algorithms
-
-import org.bdgenomics.gnocchi.algorithms.siteregression.{ AdditiveLinearRegression, AdditiveLogisticRegression, DominantLinearRegression, DominantLogisticRegression }
-import org.bdgenomics.gnocchi.GnocchiFunSuite
-import org.mockito.Mockito._
-
-class SiteRegressionSuite extends GnocchiFunSuite {
-
- sparkTest("SiteRegression.Dominant.clipOrKeepState should map 0.0 to 0.0 and everything else to 1.0") {
- assert(DominantLinearRegression.clipOrKeepState(2.0) == 1.0,
- "SiteRegression.Dominant.clipOrKeepState does not correctly map 2.0 to 1.0")
- assert(DominantLinearRegression.clipOrKeepState(1.0) == 1.0,
- "SiteRegression.Dominant.clipOrKeepState does not correctly map 1.0 to 1.0")
- assert(DominantLinearRegression.clipOrKeepState(0.0) == 0.0,
- "SiteRegression.Dominant.clipOrKeepState does not correctly map 0.0 to 0.0")
- assert(DominantLogisticRegression.clipOrKeepState(2.0) == 1.0,
- "SiteRegression.Dominant.clipOrKeepState does not correctly map 2.0 to 1.0")
- assert(DominantLogisticRegression.clipOrKeepState(1.0) == 1.0,
- "SiteRegression.Dominant.clipOrKeepState does not correctly map 1.0 to 1.0")
- assert(DominantLogisticRegression.clipOrKeepState(0.0) == 0.0,
- "SiteRegression.Dominant.clipOrKeepState does not correctly map 0.0 to 0.0")
- }
-
- sparkTest("SiteRegression.Additive.clipOrKeepState should be an identity map") {
- assert(AdditiveLinearRegression.clipOrKeepState(2.0) == 2.0,
- "SiteRegression.Additive.clipOrKeepState does not correctly map 2.0 to 2.0")
- assert(AdditiveLinearRegression.clipOrKeepState(1.0) == 1.0,
- "SiteRegression.Additive.clipOrKeepState does not correctly map 1.0 to 1.0")
- assert(AdditiveLinearRegression.clipOrKeepState(0.0) == 0.0,
- "SiteRegression.Additive.clipOrKeepState does not correctly map 0.0 to 0.0")
- assert(AdditiveLogisticRegression.clipOrKeepState(2.0) == 2.0,
- "SiteRegression.Additive.clipOrKeepState does not correctly map 2.0 to 2.0")
- assert(AdditiveLogisticRegression.clipOrKeepState(1.0) == 1.0,
- "SiteRegression.Additive.clipOrKeepState does not correctly map 1.0 to 1.0")
- assert(AdditiveLogisticRegression.clipOrKeepState(0.0) == 0.0,
- "SiteRegression.Additive.clipOrKeepState does not correctly map 0.0 to 0.0")
- }
-
- ignore("SiteRegression.Recessive.clipOrKeepState should map 2.0 to 1.0 and everything else to 0.0") {
- // assert(RecessiveLinearRegression.clipOrKeepState(2.0) == 1.0,
- // "SiteRegression.Dominant.clipOrKeepState does not correctly map 2.0 to 1.0")
- // assert(RecessiveLinearRegression.clipOrKeepState(1.0) == 0.0,
- // "SiteRegression.Dominant.clipOrKeepState does not correctly map 1.0 to 0.0")
- // assert(RecessiveLinearRegression.clipOrKeepState(0.0) == 0.0,
- // "SiteRegression.Dominant.clipOrKeepState does not correctly map 0.0 to 0.0")
- }
-}
diff --git a/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/primitives/genotype/GenotypeStateSuite.scala b/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/primitives/genotype/GenotypeStateSuite.scala
index 44c2e57f..1663f7c0 100644
--- a/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/primitives/genotype/GenotypeStateSuite.scala
+++ b/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/primitives/genotype/GenotypeStateSuite.scala
@@ -31,4 +31,45 @@ class GenotypeStateSuite extends GnocchiFunSuite {
val gs = GenotypeState("1234", "1/0")
assert(gs.toList == List[String]("1", "0"), "GenotypeState.toList does not correctly split the genotypes on forward slash delimiter.")
}
+
+ // Allelic Assumption tests
+
+ sparkTest("GenotypeState.dominant should map 0.0 to 0.0 and everything else to 1.0") {
+ val gs0 = GenotypeState("0", "0/0")
+ val gs1 = GenotypeState("1", "0/1")
+ val gs2 = GenotypeState("2", "1/1")
+ // gs0, gs1 and gs2 stand for the states 0.0, 1.0 and 2.0 named in the messages; only gs0 may stay at 0.0
+ assert(gs2.dominant == 1.0,
+ "GenotypeState.dominant does not correctly map 2.0 to 1.0")
+ assert(gs1.dominant == 1.0,
+ "GenotypeState.dominant does not correctly map 1.0 to 1.0")
+ assert(gs0.dominant == 0.0,
+ "GenotypeState.dominant does not correctly map 0.0 to 0.0")
+ }
+
+ sparkTest("GenotypeState.additive should be an identity map") {
+ val gs0 = GenotypeState("0", "0/0")
+ val gs1 = GenotypeState("1", "0/1")
+ val gs2 = GenotypeState("2", "1/1")
+ // identity check: gs0, gs1 and gs2 must come back as 0.0, 1.0 and 2.0 respectively
+ assert(gs2.additive == 2.0,
+ "GenotypeState.additive does not correctly map 2.0 to 2.0")
+ assert(gs1.additive == 1.0,
+ "GenotypeState.additive does not correctly map 1.0 to 1.0")
+ assert(gs0.additive == 0.0,
+ "GenotypeState.additive does not correctly map 0.0 to 0.0")
+ }
+
+ sparkTest("GenotypeState.recessive should map 2.0 to 1.0 and everything else to 0.0") {
+ val gs0 = GenotypeState("0", "0/0")
+ val gs1 = GenotypeState("1", "0/1")
+ val gs2 = GenotypeState("2", "1/1")
+ // only gs2 (the 2.0 state in the messages) may map to 1.0; gs0 and gs1 must both map to 0.0
+ assert(gs2.recessive == 1.0,
+ "GenotypeState.recessive does not correctly map 2.0 to 1.0")
+ assert(gs1.recessive == 0.0,
+ "GenotypeState.recessive does not correctly map 1.0 to 0.0")
+ assert(gs0.recessive == 0.0,
+ "GenotypeState.recessive does not correctly map 0.0 to 0.0")
+ }
}
diff --git a/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/sql/GnocchiSessionSuite.scala b/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/sql/GnocchiSessionSuite.scala
index 5e82d939..a3812c7c 100644
--- a/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/sql/GnocchiSessionSuite.scala
+++ b/gnocchi-core/src/test/scala/org/bdgenomics/gnocchi/sql/GnocchiSessionSuite.scala
@@ -19,7 +19,6 @@
package org.bdgenomics.gnocchi.sql
import org.bdgenomics.gnocchi.GnocchiFunSuite
-import org.bdgenomics.gnocchi.algorithms.siteregression.{ AdditiveLinearRegression, DominantLinearRegression }
import org.bdgenomics.gnocchi.primitives.genotype.GenotypeState
import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
diff --git a/pom.xml b/pom.xml
index 1fcc4f0a..cf5691c2 100755
--- a/pom.xml
+++ b/pom.xml
@@ -163,7 +163,6 @@
As explained here: http://stackoverflow.com/questions/1660441/java-flag-to-enable-extended-serialization-debugging-info
The second option allows us better debugging for serialization-based errors.
-->
- -Xmx1024m -Dsun.io.serialization.extendedDebugInfo=true -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.PDuErZrK -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.d1mt6vD9 -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.eRmiFaTc -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.u9bHWT1h -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.J5tKjn9N -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.PyTFRsyh