Skip to content

Commit

Permalink
Linear regression refactor (bigdatagenomics#20)
Browse files Browse the repository at this point in the history
* LogisticSiteRegression single test

* changed

* Finish linear regression impl

* Linear regression breeze impl complete

* Enable Anscombe tests

* Add PIQ test

* Fixes for p-value calculation, deviations, and matrix instantiation

* More linear regression tests

* Removed nearby() (not required with scalactic)

* Formatting

* Address Nate's review comments
  • Loading branch information
p-yang authored and nathanielparke committed Oct 13, 2017
1 parent dea8236 commit a72c7cf
Show file tree
Hide file tree
Showing 3 changed files with 440 additions and 382 deletions.
140 changes: 66 additions & 74 deletions ...rc/main/scala/org/bdgenomics/gnocchi/algorithms/siteregression/LinearSiteRegression.scala
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ import org.bdgenomics.gnocchi.models.variant.linear.{ AdditiveLinearVariantModel
import org.bdgenomics.gnocchi.primitives.association.LinearAssociation
import org.bdgenomics.gnocchi.primitives.phenotype.Phenotype
import org.bdgenomics.gnocchi.primitives.variants.CalledVariant
import org.apache.commons.math3.distribution.TDistribution
import org.apache.commons.math3.linear.SingularMatrixException
import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression
import breeze.linalg._
import breeze.numerics._
import breeze.stats._
import breeze.stats.distributions.StudentsT
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{ Dataset, SparkSession }

import scala.collection.immutable.Map
import scala.math.log10

trait LinearSiteRegression[VM <: LinearVariantModel[VM]] extends SiteRegression[VM] {

Expand All @@ -38,81 +38,73 @@ trait LinearSiteRegression[VM <: LinearVariantModel[VM]] extends SiteRegression[

// NOTE(review): this span comes from a diff view that carries no +/- markers. The statements from
// `val XandY` through the `catch` block (commons-math OLSMultipleLinearRegression) and the statements
// from `val (x, y) = prepareDesignMatrix(genotypes, phenotypes)` onward (breeze) look like the removed
// and the added version of the same method body — confirm against the commit before treating both as live.
/**
 * Regresses phenotype on genotype (plus covariates) for a single variant and packages the fit
 * statistics into a LinearAssociation.
 *
 * @param phenotypes phenotype records; prepareDesignMatrix looks them up by genotype sampleID
 * @param genotypes the called variant whose per-sample genotype states form the regressor of interest
 * @return a LinearAssociation built from the sum of squared deviations, the residual sum of squares,
 *         the genotype standard error, the t statistic, the residual degrees of freedom, the p-value,
 *         the fitted coefficients and genotypes.numValidSamples
 */
def applyToSite(phenotypes: Map[String, Phenotype],
genotypes: CalledVariant): LinearAssociation = {

val XandY = prepareDesignMatrix(phenotypes, genotypes)
val x = XandY.map(_._1.toArray).toArray
val y = XandY.map(_._2).toArray

// covariate count of the first phenotype entry plus one (presumably the genotype term — TODO confirm);
// NOTE(review): `.head` throws on an empty phenotypes map
val phenotypesLength = phenotypes.head._2.covariates.length + 1

try {
// create linear model
val ols = new OLSMultipleLinearRegression()

// input sample data
ols.newSampleData(y, x)

// calculate coefficients
val beta = ols.estimateRegressionParameters()

// calculate sum of squared residuals
val ssResiduals = ols.calculateResidualSumOfSquares()

// calculate sum of squared deviations
val ssDeviations = sumOfSquaredDeviations(genotypes)

// compute the regression parameters standard errors
val standardErrors = ols.estimateRegressionParametersStandardErrors()

// get standard error for genotype parameter (for p value calculation)
// index 1 presumably skips an intercept estimated by commons-math at index 0 — TODO confirm
val genoSE = standardErrors(1)

// test statistic t for jth parameter is equal to bj/SEbj, the parameter estimate divided by its standard error
val t = beta(1) / genoSE

/* calculate p-value and report:
Under null hypothesis (i.e. the j'th element of weight vector is 0) the relevant distribution is
a t-distribution with N-p-1 degrees of freedom. (N = number of samples, p = number of regressors i.e. genotype+covariates+intercept)
https://en.wikipedia.org/wiki/T-statistic
*/
val residualDegreesOfFreedom = genotypes.numValidSamples - phenotypesLength - 1
val tDist = new TDistribution(residualDegreesOfFreedom)
// two-sided p-value: twice the lower-tail probability at -|t|
val pvalue = 2 * tDist.cumulativeProbability(-math.abs(t))
// NOTE(review): logPValue is computed but not passed to LinearAssociation below
val logPValue = log10(pvalue)

LinearAssociation(
ssDeviations,
ssResiduals,
genoSE,
t,
residualDegreesOfFreedom,
pvalue,
beta.toList,
genotypes.numValidSamples)
} catch {
// NOTE(review): the try body above only calls commons-math; it is not evident from here what would
// raise breeze's MatrixSingularException — confirm
case _: breeze.linalg.MatrixSingularException => {
throw new SingularMatrixException()
}
}
// x is the design matrix and y the phenotype vector; prepareDesignMatrix builds each row as
// (1.0, genotype state, covariates...)
val (x, y) = prepareDesignMatrix(genotypes, phenotypes)

// TODO: Determine if QR factorization is faster
// breeze's `\` solves x * beta = y for beta (presumably in the least-squares sense — confirm in breeze docs)
val beta = x \ y

val residuals = y - (x * beta)
val ssResiduals = residuals.t * residuals

// calculate sum of squared deviations
val deviations = y - mean(y)
val ssDeviations = deviations.t * deviations

// compute the regression parameters standard errors
val betaVariance = diag(inv(x.t * x))
// same quantity as ssResiduals, divided by the residual degrees of freedom;
// NOTE(review): x.rows - x.cols is zero when there are exactly as many samples as regressors — confirm a guard exists
val sigma = residuals.t * residuals / (x.rows - x.cols)
val standardErrors = sqrt(sigma * betaVariance)

// get standard error for genotype parameter (for p value calculation)
// index 1 is the genotype column: column 0 of the design matrix is the constant 1.0
val genoSE = standardErrors(1)

// test statistic t for jth parameter is equal to bj/SEbj, the parameter estimate divided by its standard error
val t = beta(1) / genoSE

/* calculate p-value and report:
Under null hypothesis (i.e. the j'th element of weight vector is 0) the relevant distribution is
a t-distribution with N-p degrees of freedom.
(N = number of samples, p = number of regressors i.e. genotype+covariates+intercept)
https://en.wikipedia.org/wiki/T-statistic
*/
val residualDegreesOfFreedom = x.rows - x.cols
val tDist = StudentsT(residualDegreesOfFreedom)
// two-sided p-value: twice the lower-tail probability at -|t|
val pValue = 2 * tDist.cdf(-math.abs(t))

// NOTE(review): the sample count reported here is genotypes.numValidSamples, while the degrees of freedom
// use x.rows; the two may differ if prepareDesignMatrix drops samples that lack a phenotype — confirm
LinearAssociation(
ssDeviations,
ssResiduals,
genoSE,
t,
residualDegreesOfFreedom,
pValue,
beta.data.toList,
genotypes.numValidSamples)
}

private def prepareDesignMatrix(phenotypes: Map[String, Phenotype],
genotypes: CalledVariant): List[(List[Double], Double)] = {
private[algorithms] def prepareDesignMatrix(genotypes: CalledVariant,
phenotypes: Map[String, Phenotype]): (DenseMatrix[Double], DenseVector[Double]) = {
val filteredGenotypes = genotypes.samples.filter(_.value != ".")

// class for ols: org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression
// see http://commons.apache.org/proper/commons-math/javadocs/api-3.6.1/org/apache/commons/math3/stat/regression/OLSMultipleLinearRegression.html
val samplesGenotypes = genotypes.samples.filter(x => !x.value.contains(".")).map(x => (x.sampleID, List(clipOrKeepState(x.toDouble))))
val samplesCovariates = phenotypes.map(x => (x._1, x._2.covariates)).toMap
val cleanedSampleVector = samplesGenotypes.map(x => (x._1, (x._2 ++ samplesCovariates(x._1)).toList)).toMap
val (primitiveX, primitiveY) = filteredGenotypes.flatMap({
case gs if phenotypes.contains(gs.sampleID) => {
val pheno = phenotypes(gs.sampleID)
Some(1.0 +: clipOrKeepState(gs.toDouble) +: pheno.covariates.toArray, pheno.phenotype)
}
case _ => None
}).toArray.unzip

cleanedSampleVector.toList.map(x => (x._2, phenotypes(x._1).phenotype.toDouble))
}
if (primitiveX.length == 0) {
// TODO: Determine what to do when the design matrix is empty (e.g. no overlap between genotype and phenotype sample IDs).
throw new IllegalArgumentException("No overlap between phenotype and genotype state sample IDs.")
}

// NOTE: This may cause problems in the future depending on the JVM's maximum number of varargs. If it breaks,
// use one of these instead:
// val x = new DenseMatrix(primitiveX(0).length, primitiveX.length, primitiveX.flatten).t
// val x = new DenseMatrix(primitiveX.length, primitiveX(0).length, primitiveX.flatten, 0, primitiveX(0).length, isTranspose = true)
// val x = new DenseMatrix(primitiveX :_*)

protected def sumOfSquaredDeviations(genotypes: CalledVariant): Double = {
val sum = genotypes.samples.filter(x => !x.value.contains(".")).map(x => clipOrKeepState(x.toDouble)).sum
val mean = sum / genotypes.numValidSamples
val squaredDeviations = genotypes.samples.map(x => math.pow(x.toDouble - mean, 2))
squaredDeviations.sum
(new DenseMatrix(primitiveX.length, primitiveX(0).length, primitiveX.transpose.flatten), new DenseVector(primitiveY))
}

protected def constructVM(variant: CalledVariant,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,3 @@ trait GnocchiFunSuite extends SparkFunSuite {
phenos.toMap
}
}

Loading

0 comments on commit a72c7cf

Please sign in to comment.