Non-deterministic behavior of gradient registration #367
Interestingly, I have now run the consistency test on the current develop branch, and the issue seems to be fixed.
Unfortunately, this is not yet fixed. I tested on the recent develop branch at v0.91-RC3. The MeanPointwiseLossMetric uses the ParVector and therefore has some small numerical inaccuracies in terms of reproducibility. I used the snippet below to check it:
import breeze.linalg.{DenseMatrix, DenseVector}
import scalismo.common.interpolation.TriangleMeshInterpolator3D
import scalismo.common.{DifferentiableField, Domain, EuclideanSpace, EuclideanSpace3D, Field, PointId, Scalar}
import scalismo.geometry.{EuclideanVector, NDSpace, Point, _3D}
import scalismo.kernels.{DiagonalKernel3D, GaussianKernel3D, MatrixValuedPDKernel}
import scalismo.mesh.{TriangleCell, TriangleList, TriangleMesh, TriangleMesh3D}
import scalismo.numerics.{FixedPointsUniformMeshSampler3D, Sampler}
import scalismo.registration.RegistrationMetric.ValueAndDerivative
import scalismo.registration.{GaussianProcessTransformationSpace, ImageMetric, MeanPointwiseLossMetric, MeanSquaresMetric}
import scalismo.statisticalmodel.{GaussianProcess, LowRankGaussianProcess}
import scalismo.transformations.TransformationSpace
import scalismo.utils.Random
import scala.collection.parallel.immutable.ParVector
object TestPointMetric extends App {
scalismo.initialize()
implicit val rng: Random = Random(1024L)
val mesh1 = generateRandomMesh()
val mesh2 = generateRandomMesh()
val zeroMean = Field(EuclideanSpace3D, (_: Point[_3D]) => EuclideanVector.zeros[_3D])
val cov: MatrixValuedPDKernel[_3D] = DiagonalKernel3D(GaussianKernel3D(sigma = 20) * 5.0, outputDim = 3)
val gp = GaussianProcess(zeroMean, cov)
val interpolator = TriangleMeshInterpolator3D[EuclideanVector[_3D]]()
val lowRankGP = LowRankGaussianProcess.approximateGPCholesky(
mesh1,
gp,
relativeTolerance = 0.1,
interpolator = interpolator)
val transformationSpace = GaussianProcessTransformationSpace(lowRankGP)
val fixedImage = mesh1.operations.toDistanceImage
val movingImage = mesh2.operations.toDistanceImage
val sampler = FixedPointsUniformMeshSampler3D(mesh1, numberOfPoints = 1000)(rng)
// THE PROVIDED METRIC HAS NUMERICAL INACCURACIES
// val metric = MeanSquaresMetric(fixedImage, movingImage, transformationSpace, sampler)
// THE METRIC WITH THE PAR REMOVED IS SLOWER BUT ALWAYS PROVIDES EXACTLY THE SAME VALUE
val metric = SequentialMeanPointwiseLossMetric(fixedImage, movingImage, transformationSpace, sampler)
val coefficients = DenseVector.fill(lowRankGP.rank)(rng.scalaRandom.nextDouble())
val reference = metric.value(coefficients)
val values = (0 until 1000).map{i => println(i); metric.value(coefficients)}
println(s"max diff: ${values.map(value => math.abs(value-reference)).max}")
println(s"${values.count( _ == reference)} identical values")
def generateRandomMesh(numberOfPoints: Int = 400, numberOfTriangles: Int = 500) = {
TriangleMesh3D(IndexedSeq.fill(numberOfPoints)(Point(
rng.scalaRandom.nextDouble(),
rng.scalaRandom.nextDouble(),
rng.scalaRandom.nextDouble()
)),TriangleList(IndexedSeq.fill(numberOfTriangles)(
TriangleCell(
PointId(rng.scalaRandom.nextInt(numberOfPoints)),
PointId(rng.scalaRandom.nextInt(numberOfPoints)),
PointId(rng.scalaRandom.nextInt(numberOfPoints))
)
)))
}
}
case class SequentialMeanPointwiseLossMetric[D: NDSpace, A: Scalar](
fixedImage: Field[D, A],
movingImage: DifferentiableField[D, A],
transformationSpace: TransformationSpace[D],
sampler: Sampler[D]
) extends ImageMetric[D, A] {
override val ndSpace: NDSpace[D] = implicitly[NDSpace[D]]
val scalar = Scalar[A]
protected def lossFunction(v: A): Double = {
val value = scalar.toDouble(v)
value * value
}
protected def lossFunctionDerivative(v: A): Double = {
2.0 * scalar.toDouble(v)
}
def value(parameters: DenseVector[Double]): Double = {
computeValue(parameters, sampler)
}
// compute the derivative of the cost function
def derivative(parameters: DenseVector[Double]): DenseVector[Double] = {
computeDerivative(parameters, sampler)
}
override def valueAndDerivative(parameters: DenseVector[Double]): ValueAndDerivative = {
// We create a new sampler, which always returns the same points. In this way we can make sure that the
// same sample points are used for computing the value and the derivative
val sampleOnceSampler = new Sampler[D] {
override val numberOfPoints: Int = sampler.numberOfPoints
private val samples = sampler.sample()
override def sample(): IndexedSeq[(Point[D], Double)] = samples
override def volumeOfSampleRegion: Double = sampler.volumeOfSampleRegion
}
val value = computeValue(parameters, sampleOnceSampler)
val derivative = computeDerivative(parameters, sampleOnceSampler)
ValueAndDerivative(value, derivative)
}
private def computeValue(parameters: DenseVector[Double], sampler: Sampler[D]) = {
val transform = transformationSpace.transformationForParameters(parameters)
val warpedImage = movingImage.compose(transform)
val metricValue = (fixedImage - warpedImage).andThen(lossFunction _).liftValues
// we compute the mean using a monte carlo integration
val samples = sampler.sample()
samples.toIndexedSeq.map { case (pt, _) => metricValue(pt).getOrElse(0.0) }.sum / samples.size
}
private def computeDerivative(parameters: DenseVector[Double], sampler: Sampler[D]): DenseVector[Double] = {
val transform = transformationSpace.transformationForParameters(parameters)
val movingImageGradient = movingImage.differentiate
val warpedImage = movingImage.compose(transform)
val dDMovingImage = (warpedImage - fixedImage).andThen(lossFunctionDerivative _)
val dMovingImageDomain = Domain.intersection(warpedImage.domain, fixedImage.domain)
val fullMetricGradient = (x: Point[D]) => {
val domain = Domain.intersection(fixedImage.domain, dMovingImageDomain)
if (domain.isDefinedAt(x))
Some(
transform
.derivativeWRTParameters(x)
.t * (movingImageGradient(transform(x)) * dDMovingImage(x).toDouble).toBreezeVector
)
else None
}
// we compute the mean using a monte carlo integration
val samples = sampler.sample()
val zeroVector = DenseVector.zeros[Double](transformationSpace.numberOfParameters)
val gradientValues = new ParVector(samples.toVector).map {
case (pt, _) => fullMetricGradient(pt).getOrElse(zeroVector)
}
gradientValues.foldLeft(zeroVector)((acc, g) => acc + g) * (1.0 / samples.size)
}
}
PS: A much simpler test, illustrating the problem, is the following snippet using scala-cli:
//> using scala "2.13"
//> using lib "org.scala-lang.modules:scala-parallel-collections_2.13:1.0.4"
import scala.collection.parallel.CollectionConverters._
object test_ParVector extends App {
val rng = scala.util.Random
val values = IndexedSeq.fill(10000)(rng.nextDouble * 10000)
val parValues = values.par
val reference = values.sum
println((0 until 10000).count { _ =>
val current = parValues.sum
if (reference != current) {
println(reference - current)
true
} else {
false
}
})
}
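For context, the underlying cause is simply that Double addition is not associative, so the grouping chosen by a parallel reduction can change the last bits of the result. A minimal standalone illustration (no extra dependencies; the object name is only for this sketch):
object NonAssociativeDoubles extends App {
  val a = 0.1
  val b = 0.2
  val c = 0.3
  // The two groupings round differently, so the results are not bit-identical.
  println((a + b) + c) // 0.6000000000000001
  println(a + (b + c)) // 0.6
  println(((a + b) + c) == (a + (b + c))) // false
}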
If we impose an order on the ParVector summation (with .toIndexedSeq.sorted.sum), we could still leverage some parallelization while remaining deterministic. I tested replacing line 121 of your first snippet with
samples.par.map { case (pt, _) => metricValue(pt).getOrElse(0.0) }.toIndexedSeq.sorted.sum / samples.size
and it seems to work (approx. 3x faster than the sequential version).
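For illustration, here is a small self-contained sketch of that sort-before-sum idea applied to the simple ParVector test above (it assumes scala-parallel-collections on the classpath; determinism comes from fixing the summation order, not from making the arithmetic exact):
//> using scala "2.13"
//> using lib "org.scala-lang.modules:scala-parallel-collections_2.13:1.0.4"
import scala.collection.parallel.CollectionConverters._
object SortedParSumTest extends App {
  val rng = scala.util.Random
  val values = IndexedSeq.fill(10000)(rng.nextDouble() * 10000)
  // Map in parallel (here just squaring as a stand-in for the loss), then fix
  // the order of the terms before the sequential sum.
  def deterministicSum(xs: IndexedSeq[Double]): Double =
    xs.par.map(x => x * x).toIndexedSeq.sorted.sum
  val reference = deterministicSum(values)
  // Unlike the plain parallel sum, repeated evaluations are bit-identical.
  println((0 until 1000).count(_ => deterministicSum(values) != reference)) // expected: 0
}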
Are the errors we get so large that the order in which we sum up the terms has a practical impact on the optimization, or is the problem just that the two resulting numbers are not exactly equal? If the former is true, we should maybe look more carefully at the numerics and see how we could stabilize the computations; it would mean that the result we get depends on sheer luck, and any particular order that we enforce could still be suboptimal. If it does not have a real impact on the results, we should maybe just make our tests more tolerant.
It seems that for most practical applications the error from the optimization, as shown above, is not within a range I would consider harmful. One problem we are currently facing is that when you start a stochastic algorithm from just this tiny bit of difference, you can end up with really different results: maybe not better or worse, but definitely not exactly reproducible, which is sometimes requested. In my opinion, it is a fair question whether scalismo should offer such precise reproducibility or not. If one needs it, as we do at the moment, one can also implement a different metric.
I would like to come back to this issue, as reproducible results are important for our deliverables. In the end, there were 2 issues causing the non-deterministic behavior:
1. The parallel summation in the MeanPointwiseLossMetric, whose result depends on the order in which the terms are combined.
2. An issue related to the Spatial Index cache.
Back then, we published our own Scalismo version v0.90.0-deterministic to fix these issues (Shapemeans-GmbH@86e6818). However, long term, it would be great if this were integrated into the main branch somehow.
Thanks for reviving this discussion. I think we really need to fix point 2 of your list; this is a clear bug. I am still reluctant about the fix for point 1. To me it seems that if the differences are tiny, we should not incur a huge performance penalty just to have perfect reproducibility. Results from a floating-point computation should never be compared to an exact value but should always lie within an epsilon range. Of course, if large differences are observed, we should do something about it. But if the range is small, and the differences only multiply due to a downstream task (as Andreas mentions above), then I think it is a bug in the downstream task. Do you have any experiments that quantify the differences of the metric numerically?
The short version:
The bit longer version: I disagree with the claim that if differences multiply due to a downstream task, this means there is a bug in the downstream task. If the downstream task is stochastic (for example MCMC), we can currently ensure deterministic results by seeding all random generators properly. However, if some "randomness" comes from the upstream result, which was assumed to be deterministic, this can lead to different results for the same input. And as Andreas noted, the scale of the differences from the upstream is not really what is important here. For Scalismo to be used in practical applications (e.g. medical devices), the same input needs to give the same result across different runs. As said above, these arguments do not affect this issue, since it can be solved with the 2 bullet points above; they are just important for future developments, in my opinion.
@Ghazi-Bouabene After an (offline) discussion with @thogerig, the problem is now more clear to me. As you suggested above, it should be sufficient to modify line 121. But instead of sorting the collection, we should be able to just use .seq before summing.
Maybe you could see if it solves your problem and do a PR? |
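If I read the suggestion correctly, the modified summation would look roughly like the sketch below, mirroring how the derivative is computed in the snippet above (metricValue and samples are the names from that snippet; whether this is the final form is of course up to the PR):
// Map in parallel, but switch back to a sequential collection with .seq before
// summing, so the additions always happen in the original element order.
new ParVector(samples.toVector).map { case (pt, _) => metricValue(pt).getOrElse(0.0) }.seq.sum / samples.size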
Is there any news on this issue? Does the fix with .seq solve the problem, or will you use your own implementation of the metric?
The fix we implemented in the metric was to convert the values to BigDecimal before performing the sum (Shapemeans-GmbH@86e6818). As previously said, if there are objections to changing the metric, we can fix this on the application side and not necessarily in Scalismo. For the other issue, related to the Spatial Index cache, we will push a PR soon.
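For reference, a minimal sketch of that idea (not the actual commit; the method name is illustrative). Exact BigDecimal arithmetic makes the sum independent of how a parallel reduction groups the terms, at the cost of extra allocations:
import java.math.{BigDecimal => JBigDecimal}
import scala.collection.parallel.CollectionConverters._
// Each Double is converted to its exact BigDecimal value; BigDecimal.add is
// exact, so the grouping chosen by the parallel reduction cannot change the
// result. Rounding back to Double happens only once, at the very end.
// (Assumes a non-empty collection.)
def orderIndependentSum(values: IndexedSeq[Double]): Double =
  values.par.map(v => new JBigDecimal(v)).reduce(_.add(_)).doubleValue()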
Great, thanks for the clarification. Then we leave the metric as is. Maybe it is still worthwhile to give |
The following snippet for testing the numerical consistency of the registration output led to non-deterministic behavior, since not all test runs gave exactly the same output.
The reason seems to be the parallel calculation of the sum in MeanPointwiseLossMetric.scala at line 81. Changing that line by adding a toIndexedSeq (see the sketch below), the test passes and all runs are consistent. It has been tested mainly on 0.18.
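To make this concrete, a sketch of the shape of that change (the exact line in MeanPointwiseLossMetric.scala may differ; metricValue and samples are the names used in the snippets above):
// Converting the sample collection to a plain IndexedSeq before mapping and
// summing forces a sequential, left-to-right summation, which makes the
// metric value reproducible across runs.
samples.toIndexedSeq.map { case (pt, _) => metricValue(pt).getOrElse(0.0) }.sum / samples.size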
I will make a PR with a test and a fix.