Skip to content

Commit

Permalink
Move to a more generic structure for models package (bigdatagenomics#23)
Browse files Browse the repository at this point in the history
* models refactor

* refactor models

* generics for linear

* concordance with new models

* fix test suite
  • Loading branch information
nathanielparke authored Oct 16, 2017
1 parent a72c7cf commit 1637917
Show file tree
Hide file tree
Showing 25 changed files with 371 additions and 1,027 deletions.
8 changes: 4 additions & 4 deletions examples/run-example.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import org.bdgenomics.gnocchi.sql.GnocchiSession._
import org.bdgenomics.gnocchi.algorithms.siteregression.AdditiveLinearRegression
val genotypesPath = "examples/testData/time_phenos.vcf"
val phenotypesPath = "examples/testData/tab_time_phenos.txt"
import org.bdgenomics.gnocchi.algorithms.siteregression.LinearSiteRegression
val genotypesPath = "examples/testData/time_genos_1.vcf"
val phenotypesPath = "examples/testData/tab_time_phenos_1.txt"
val geno = sc.loadGenotypes(genotypesPath)
val pheno = sc.loadPhenotypes(phenotypesPath, "IID", "pheno_1", "\t", Option(phenotypesPath), Option(List("pheno_4", "pheno_5")))

Expand All @@ -10,4 +10,4 @@ val filteredGenoVariants = sc.filterVariants(filteredGeno, geno = 0.1, maf = 0.1

val broadPheno = sc.broadcast(pheno)

val assoications = AdditiveLinearRegression(geno, broadPheno)
// Run the site-wise linear regression over the loaded genotypes and the
// broadcast phenotypes. The only regression object this script imports is
// LinearSiteRegression, so that is the name that must be applied here.
val assoications = LinearSiteRegression(geno, broadPheno)
8 changes: 4 additions & 4 deletions examples/test_merge.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import org.bdgenomics.gnocchi.sql.GnocchiSession._
val genotypesPath1 = "testData/time_genos_1.vcf"
val phenotypesPath1 = "testData/tab_time_phenos_1.txt"
val genotypesPath1 = "examples/testData/time_genos_1.vcf"
val phenotypesPath1 = "examples/testData/tab_time_phenos_1.txt"
val geno1 = sc.loadGenotypes(genotypesPath1)
val pheno1 = sc.loadPhenotypes(phenotypesPath1, "IID", "pheno_1", "\t")

Expand All @@ -19,7 +19,7 @@ val broadPheno1 = sc.broadcast(pheno1)

// val broadPheno2 = sc.broadcast(pheno2)

import org.bdgenomics.gnocchi.algorithms.siteregression.AdditiveLinearRegression
import org.bdgenomics.gnocchi.algorithms.siteregression.LinearSiteRegression

val assoc_1 = AdditiveLinearRegression(fullFiltered1, broadPheno1)
val assoc_1 = LinearSiteRegression(fullFiltered1, broadPheno1)
// val assoc_2 = AdditiveLinearRegression(fullFiltered2, broadPheno2)
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
package org.bdgenomics.gnocchi.cli

import org.bdgenomics.gnocchi.algorithms.siteregression._
import org.bdgenomics.gnocchi.models.variant.VariantModel
import org.bdgenomics.gnocchi.models.variant.linear.{ AdditiveLinearVariantModel, DominantLinearVariantModel }
import org.bdgenomics.gnocchi.models.variant.logistic.{ AdditiveLogisticVariantModel, DominantLogisticVariantModel }
import org.bdgenomics.gnocchi.models.variant.{ LinearVariantModel, LogisticVariantModel, VariantModel }
import org.bdgenomics.gnocchi.sql.GnocchiSession._
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
Expand Down Expand Up @@ -122,20 +120,20 @@ class RegressPhenotypes(protected val args: RegressPhenotypesArgs) extends BDGSp

args.associationType match {
case "ADDITIVE_LINEAR" => {
val associations = AdditiveLinearRegression(filteredGeno, broadPhenotype)
logResults[AdditiveLinearVariantModel](associations, sc)
val associations = LinearSiteRegression(filteredGeno, broadPhenotype, "ADDITIVE")
logResults[LinearVariantModel](associations, sc)
}
case "DOMINANT_LINEAR" => {
val associations = DominantLinearRegression(filteredGeno, broadPhenotype)
logResults[DominantLinearVariantModel](associations, sc)
val associations = LinearSiteRegression(filteredGeno, broadPhenotype, "DOMINANT")
logResults[LinearVariantModel](associations, sc)
}
case "ADDITIVE_LOGISTIC" => {
val associations = AdditiveLogisticRegression(filteredGeno, broadPhenotype)
logResults[AdditiveLogisticVariantModel](associations, sc)
val associations = LogisticSiteRegression(filteredGeno, broadPhenotype, "ADDITIVE")
logResults[LogisticVariantModel](associations, sc)
}
case "DOMINANT_LOGISTIC" => {
val associations = DominantLogisticRegression(filteredGeno, broadPhenotype)
logResults[DominantLogisticVariantModel](associations, sc)
val associations = LogisticSiteRegression(filteredGeno, broadPhenotype, "DOMINANT")
logResults[LogisticVariantModel](associations, sc)
}
}
}
Expand Down
1 change: 0 additions & 1 deletion gnocchi-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
As explained here: http://stackoverflow.com/questions/1660441/java-flag-to-enable-extended-serialization-debugging-info
The second option allows us better debugging for serialization-based errors.
-->
<argLine>-Xmx1024m -Dsun.io.serialization.extendedDebugInfo=true -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.PDuErZrK -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.d1mt6vD9 -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.eRmiFaTc -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.u9bHWT1h -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.J5tKjn9N -Djava.io.tmpdir=/var/folders/66/8r2ldh5d0xq4zcglwgjf6gv80000gn/T/gnocchiTestMvnXXXXXXX.PyTFRsyh</argLine>
<stdout>F</stdout>
</configuration>
<executions>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
*/
package org.bdgenomics.gnocchi.algorithms.siteregression

import org.bdgenomics.gnocchi.models.variant.linear.{ AdditiveLinearVariantModel, DominantLinearVariantModel, LinearVariantModel }
import org.bdgenomics.gnocchi.primitives.association.LinearAssociation
import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
Expand All @@ -27,18 +26,31 @@ import breeze.stats._
import breeze.stats.distributions.StudentsT
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{ Dataset, SparkSession }
import org.bdgenomics.gnocchi.models.variant.LinearVariantModel

import scala.collection.immutable.Map

trait LinearSiteRegression[VM <: LinearVariantModel[VM]] extends SiteRegression[VM] {
trait LinearSiteRegression extends SiteRegression[LinearVariantModel, LinearAssociation] {

val sparkSession = SparkSession.builder().getOrCreate()
import sparkSession.implicits._

def apply(genotypes: Dataset[CalledVariant],
phenotypes: Broadcast[Map[String, Phenotype]],
validationStringency: String = "STRICT"): Dataset[VM]
allelicAssumption: String = "ADDITIVE",
validationStringency: String = "STRICT"): Dataset[LinearVariantModel] = {
//ToDo: Singular Matrix Exceptions
genotypes.map((genos: CalledVariant) => {
val association = applyToSite(phenotypes.value, genos, allelicAssumption)
constructVM(genos, phenotypes.value.head._2, association, allelicAssumption)
})
}

def applyToSite(phenotypes: Map[String, Phenotype],
genotypes: CalledVariant): LinearAssociation = {
val (x, y) = prepareDesignMatrix(genotypes, phenotypes)
genotypes: CalledVariant,
allelicAssumption: String): LinearAssociation = {

val (x, y) = prepareDesignMatrix(genotypes, phenotypes, allelicAssumption)

// TODO: Determine if QR factorization is faster
val beta = x \ y
Expand Down Expand Up @@ -83,13 +95,20 @@ trait LinearSiteRegression[VM <: LinearVariantModel[VM]] extends SiteRegression[
}

private[algorithms] def prepareDesignMatrix(genotypes: CalledVariant,
phenotypes: Map[String, Phenotype]): (DenseMatrix[Double], DenseVector[Double]) = {
val filteredGenotypes = genotypes.samples.filter(_.value != ".")
phenotypes: Map[String, Phenotype],
allelicAssumption: String): (DenseMatrix[Double], DenseVector[Double]) = {

val validGenos = genotypes.samples.filter(_.value != ".")

val (primitiveX, primitiveY) = filteredGenotypes.flatMap({
case gs if phenotypes.contains(gs.sampleID) => {
val pheno = phenotypes(gs.sampleID)
Some(1.0 +: clipOrKeepState(gs.toDouble) +: pheno.covariates.toArray, pheno.phenotype)
// Pair every valid sample with its genotype encoded under the requested
// allelic assumption. The assumption is upper-cased first, so the match is
// case-insensitive.
val samplesGenotypes = allelicAssumption.toUpperCase match {
case "ADDITIVE" => validGenos.map(genotypeState => (genotypeState.sampleID, genotypeState.additive))
case "DOMINANT" => validGenos.map(genotypeState => (genotypeState.sampleID, genotypeState.dominant))
case "RECESSIVE" => validGenos.map(genotypeState => (genotypeState.sampleID, genotypeState.recessive))
// Fail with a descriptive error instead of an opaque scala.MatchError.
case _ => throw new IllegalArgumentException(
s"Unsupported allelic assumption '$allelicAssumption'; expected ADDITIVE, DOMINANT or RECESSIVE")
}

val (primitiveX, primitiveY) = samplesGenotypes.flatMap({
case (sampleID, genotype) if phenotypes.contains(sampleID) => {
Some(1.0 +: genotype +: phenotypes(sampleID).covariates.toArray, phenotypes(sampleID).phenotype)
}
case _ => None
}).toArray.unzip
Expand All @@ -107,73 +126,22 @@ trait LinearSiteRegression[VM <: LinearVariantModel[VM]] extends SiteRegression[
(new DenseMatrix(primitiveX.length, primitiveX(0).length, primitiveX.transpose.flatten), new DenseVector(primitiveY))
}

protected def constructVM(variant: CalledVariant,
phenotype: Phenotype,
association: LinearAssociation): VM
}

object AdditiveLinearRegression extends AdditiveLinearRegression {
val regressionName = "additiveLinearRegression"
}

trait AdditiveLinearRegression extends LinearSiteRegression[AdditiveLinearVariantModel] with Additive {
val sparkSession = SparkSession.builder().getOrCreate()
import sparkSession.implicits._

def apply(genotypes: Dataset[CalledVariant],
phenotypes: Broadcast[Map[String, Phenotype]],
validationStringency: String = "STRICT"): Dataset[AdditiveLinearVariantModel] = {

//ToDo: Singular Matrix Exceptions
genotypes.map((genos: CalledVariant) => {
val association = applyToSite(phenotypes.value, genos)
constructVM(genos, phenotypes.value.head._2, association)
})
}

protected def constructVM(variant: CalledVariant,
phenotype: Phenotype,
association: LinearAssociation): AdditiveLinearVariantModel = {
AdditiveLinearVariantModel(variant.uniqueID,
/**
 * Builds the LinearVariantModel for a single site.
 *
 * @param variant the called variant; supplies uniqueID, chromosome, position
 *                and the reference/alternate alleles stored on the model
 * @param phenotype the phenotype whose phenoName is recorded on the model
 * @param association the linear association attached to the model
 * @param allelicAssumption passed through unchanged to the model
 * @return a LinearVariantModel whose phaseSetId is always 0
 */
def constructVM(variant: CalledVariant,
phenotype: Phenotype,
association: LinearAssociation,
allelicAssumption: String): LinearVariantModel = {
LinearVariantModel(variant.uniqueID,
association,
phenotype.phenoName,
variant.chromosome,
variant.position,
variant.referenceAllele,
variant.alternateAllele,
allelicAssumption,
phaseSetId = 0)
}
}

object DominantLinearRegression extends DominantLinearRegression {
val regressionName = "dominantLinearRegression"
}

trait DominantLinearRegression extends LinearSiteRegression[DominantLinearVariantModel] with Dominant {
val sparkSession = SparkSession.builder().getOrCreate()
import sparkSession.implicits._

def apply(genotypes: Dataset[CalledVariant],
phenotypes: Broadcast[Map[String, Phenotype]],
validationStringency: String = "STRICT"): Dataset[DominantLinearVariantModel] = {

//ToDo: Singular Matrix Exceptions
genotypes.map((genos: CalledVariant) => {
val association = applyToSite(phenotypes.value, genos)
constructVM(genos, phenotypes.value.head._2, association)
})
}

protected def constructVM(variant: CalledVariant,
phenotype: Phenotype,
association: LinearAssociation): DominantLinearVariantModel = {
DominantLinearVariantModel(variant.uniqueID,
association,
phenotype.phenoName,
variant.chromosome,
variant.position,
variant.referenceAllele,
variant.alternateAllele,
phaseSetId = 0)
}
}
// Singleton instance of the LinearSiteRegression trait; it defines the
// regressionName value as "LinearSiteRegression".
object LinearSiteRegression extends LinearSiteRegression {
val regressionName = "LinearSiteRegression"
}
Loading

0 comments on commit 1637917

Please sign in to comment.