Skip to content

Commit

Permalink
Dynamic LCP Mode Choice Update
Browse files Browse the repository at this point in the history
  • Loading branch information
CorinStaves committed Sep 18, 2024
1 parent f2dc7b0 commit 4446fed
Show file tree
Hide file tree
Showing 15 changed files with 701 additions and 469 deletions.
25 changes: 18 additions & 7 deletions src/main/java/estimation/BFGS.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package estimation;

import estimation.dynamic.DynamicUtilityComponent;
import org.apache.log4j.Logger;
import smile.math.DifferentiableMultivariateFunction;
import smile.math.MultivariateFunction;

Expand Down Expand Up @@ -76,7 +77,7 @@
* @author Haifeng Li
*/
public class BFGS {
private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(BFGS.class);
private static final Logger logger = Logger.getLogger(BFGS.class);
/** A number close to zero, between machine epsilon and its square root. */
private static final double EPSILON = Double.parseDouble(System.getProperty("smile.bfgs.epsilon", "1E-8"));
/** The convergence criterion on x values. */
Expand Down Expand Up @@ -167,12 +168,8 @@ public static Results minimize(DifferentiableMultivariateFunction func,DynamicUt

for (int iter = 1; iter <= maxIter; iter++) {

if(dynamicUtilityComponent != null) {
dynamicUtilityComponent.update(x);
}

// The new function evaluation occurs in line search.
f = linesearch(func, x, f, g, xi, xnew, stpmax);
f = linesearch(func, x, f, g, xi, xnew, stpmax, dynamicUtilityComponent);

logger.info(String.format("BFGS: the function value after %3d iterations: %.5f", iter, f));

Expand Down Expand Up @@ -323,7 +320,8 @@ public static Results minimize(DifferentiableMultivariateFunction func,DynamicUt
*
* @return the new function value.
*/
private static double linesearch(MultivariateFunction func, double[] xold, double fold, double[] g, double[] p, double[] x, double stpmax) {
private static double linesearch(MultivariateFunction func, double[] xold, double fold, double[] g, double[] p, double[] x, double stpmax,
DynamicUtilityComponent dynamicUtilityComponent) {
if (stpmax <= 0) {
throw new IllegalArgumentException("Invalid upper bound of linear search step: " + stpmax);
}
Expand Down Expand Up @@ -369,21 +367,34 @@ private static double linesearch(MultivariateFunction func, double[] xold, doubl

double alam2 = 0.0, f2 = 0.0;
double a, b, disc, rhs1, rhs2, tmpalam;
int runCount = 0;
while (true) {
runCount++;
// Evaluate the function and gradient at stp
// and compute the directional derivative.
for (int i = 0; i < n; i++) {
x[i] = xold[i] + alam * p[i];
}

// Update dynamic component
if(dynamicUtilityComponent != null) {
dynamicUtilityComponent.update(x);
}

double f = func.apply(x);

// Convergence on &Delta; x.
if (alam < alammin) {
System.arraycopy(xold, 0, x, 0, n);
logger.info("Linesearch ran " + runCount + " times, no update.");
// Go back to old dynamic component (unlikely)
if(dynamicUtilityComponent != null) {
dynamicUtilityComponent.update(x);
}
return f;
} else if (f <= fold + ftol * alam * slope) {
// Sufficient function decrease.
logger.info("Linesearch ran " + runCount + " times, completed with sufficient function decrease.");
return f;
} else {
// Backtrack
Expand Down
39 changes: 33 additions & 6 deletions src/main/java/estimation/CoefficientsWriter.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package estimation;

import estimation.utilities.AbstractUtilityFunction;
import estimation.utilities.AbstractUtilitySpecification;
import io.ioUtils;
import org.apache.log4j.Logger;

Expand All @@ -14,8 +14,27 @@ public class CoefficientsWriter {
private final static Logger logger = Logger.getLogger(CoefficientsWriter.class);
private final static String SEP = ",";

static void print(AbstractUtilitySpecification u, BFGS.Results results, double[] se, double[] t, double[] pVal, String[] sig) {

logger.info("ESTIMATION RESULTS AFTER " + results.iterations + " ITERATIONS:");
Iterator<String> coeffNames = u.getCoeffNames().iterator();
double[] finalResults = u.expandCoeffs(results.xAtEachIteration.get(results.iterations));
String[] seAll = u.expand(se,"%.5f");
String[] tAll = u.expand(t,"% .5f");
String[] pAll = u.expand(pVal,"%.5f");
String[] sigAll = u.expand(sig);

System.out.printf("| %-30s | %-10s | %-7s | %-10s | %-10s |%n","COEFFICIENT NAME","VALUE","STD.ERR","T.TEST","P.VAL");
int i = 0;
while(coeffNames.hasNext()) {
System.out.printf("| %-30s | % .7f | %-7s | %-10s | %-7s %-3s |%n",coeffNames.next(),finalResults[i],seAll[i],tAll[i],pAll[i],sigAll[i]);
i++;
}

}

// Write results to csv file
static void write(AbstractUtilityFunction u, BFGS.Results results, double[] se, double[] t, double[] pVal, String[] sig, String filePath) {
static void write(AbstractUtilitySpecification u, BFGS.Results results, double[] se, double[] t, double[] pVal, String[] sig, String filePath) {

PrintWriter out = ioUtils.openFileForSequentialWriting(new File(filePath),false);
assert out != null;
Expand All @@ -25,10 +44,18 @@ static void write(AbstractUtilityFunction u, BFGS.Results results, double[] se,
out.println(i + SEP + results.lAtEachIteration.get(i) + SEP + Arrays.stream(u.expandCoeffs(results.xAtEachIteration.get(i))).mapToObj(String::valueOf).collect(Collectors.joining(SEP)));
}

out.println("std.err" + SEP + SEP + String.join(SEP,u.expand(se)));
out.println("t.test" + SEP + SEP + String.join(SEP,u.expand(t)));
out.println("p.val" + SEP + SEP + String.join(SEP,u.expand(pVal)));
out.println("sig" + SEP + SEP + String.join(SEP,u.expand(sig)));
if(se != null) {
out.println("std.err" + SEP + SEP + String.join(SEP,u.expand(se)));
}
if(t != null) {
out.println("t.test" + SEP + SEP + String.join(SEP,u.expand(t)));
}
if(pVal != null) {
out.println("p.val" + SEP + SEP + String.join(SEP,u.expand(pVal)));
}
// if(sig != null) {
// out.println("sig" + SEP + SEP + String.join(SEP,u.expand(sig)));
// }

out.close();
logger.info("Wrote coefficients at each iteration to " + filePath);
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/estimation/FlexibleMultinomialObjective.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package estimation;

import estimation.utilities.AbstractUtilityFunction;
import estimation.utilities.AbstractUtilitySpecification;
import smile.math.DifferentiableMultivariateFunction;
import smile.math.MathEx;

Expand All @@ -9,7 +9,7 @@

public class FlexibleMultinomialObjective implements DifferentiableMultivariateFunction {

final AbstractUtilityFunction u;
final AbstractUtilitySpecification u;
final int[] y;
final int k;
final int p;
Expand All @@ -20,7 +20,7 @@ public class FlexibleMultinomialObjective implements DifferentiableMultivariateF
final double[][] posterioris;


FlexibleMultinomialObjective(AbstractUtilityFunction u, int[] y, int k, double lambda) {
FlexibleMultinomialObjective(AbstractUtilitySpecification u, int[] y, int k, double lambda) {

this.u = u;
this.y = y;
Expand Down Expand Up @@ -91,7 +91,7 @@ public double g(double[] w, double[] g) {
for (int j = 0; j < k; j++) {
double err = (y[i] == j ? 1.0 : 0.0) - posteriori[j];
for (int l = 0; l < p; l++) {
gradient[l] -= err * u.getDerivative(i,j,l);
gradient[l] -= err * u.getDerivative(i,j,wAll,l);
}
}

Expand Down
45 changes: 31 additions & 14 deletions src/main/java/estimation/MultinomialLogit.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package estimation;

import estimation.utilities.AbstractUtilityFunction;
import estimation.utilities.AbstractUtilitySpecification;
import org.apache.log4j.Logger;
import smile.math.matrix.Matrix;
import smile.stat.distribution.GaussianDistribution;
Expand All @@ -12,7 +12,7 @@ public class MultinomialLogit {

private final static Logger logger = Logger.getLogger(MultinomialLogit.class);

public static void run(AbstractUtilityFunction u, int[] y, int k, double lambda, double tol, int maxIter, String resultsFileName) {
public static void run(AbstractUtilitySpecification u, int[] y, int k, double lambda, double tol, int maxIter, String resultsFileName) {

if (lambda < 0.0) {
throw new IllegalArgumentException("Invalid regularization factor: " + lambda);
Expand All @@ -34,25 +34,42 @@ public static void run(AbstractUtilityFunction u, int[] y, int k, double lambda,
logger.info("finished estimation.");

// Approximate variance-coviariance matrix (from BFGS method) – for debugging only
Matrix approxVarCov = Matrix.of(results.hessian);
// Matrix approxVarCov = Matrix.of(results.hessian);

// Hessian computed as numerical jacobian of gradient
Matrix numericalHessian = Matrix.of(Jacobian.richardson(objective,w));
Matrix.EVD eigenvalues = numericalHessian.eigen().sort();
double maxEigenvalue = Arrays.stream(eigenvalues.wr).max().orElseThrow();
logger.info("Eigenvalues: " + Arrays.stream(eigenvalues.wr).mapToObj(d -> String.format("%.5f",d)).collect(Collectors.joining(" , ")));
logger.info("Maximum eigenvalue: " + maxEigenvalue);

// Variance-covariance matrix computed as inverse of hessian
Matrix varCov = numericalHessian.inverse();
varCov.mul(-1);
// Check if matrix is singular
if(Arrays.stream(eigenvalues.wr).anyMatch(e -> e == 0)) {
logger.error("Hessian matrix is singular! Cannot compute standard errors");
CoefficientsWriter.write(u,results,null,null,null,null,resultsFileName);
} else {

// Standard errors
double[] se = Arrays.stream(varCov.diag()).map(Math::sqrt).toArray();
double[] t = tTest(w,se);
double[] pVal = pVal(t);
String[] sig = sig(t);
// Check convergence to saddle point (probably can still run)
if (maxEigenvalue > 0){
logger.error("Hessian is not negative definite! Convergence to saddle point!");
}

// Variance-covariance matrix computed as inverse of hessian
Matrix varCov = numericalHessian.inverse();
varCov.mul(-1);

// Standard errors
double[] se = Arrays.stream(varCov.diag()).map(Math::sqrt).toArray();
double[] t = tTest(w,se);
double[] pVal = pVal(t);
String[] sig = sig(t);

// Print results (to complete later...)
CoefficientsWriter.write(u,results,se,t,pVal,sig,resultsFileName);
// Print results to screen
CoefficientsWriter.print(u,results,se,t,pVal,sig);

// Print results to file
CoefficientsWriter.write(u,results,se,t,pVal,sig,resultsFileName);
}
}


Expand Down Expand Up @@ -91,7 +108,7 @@ private static String[] sig(double[] t) {
} else {
sig = "";
}
result[i] = String.format("%.4f",p[i]) + sig;
result[i] = sig;
}
return result;
}
Expand Down
18 changes: 10 additions & 8 deletions src/main/java/estimation/RunMnlDynamic.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package estimation;

import estimation.utilities.AbstractUtilityFunction;
import estimation.utilities.AbstractUtilitySpecification;
import estimation.utilities.MNL_Dynamic;
import gis.GisUtils;
import gis.GpkgReader;
Expand Down Expand Up @@ -37,15 +37,15 @@ public static void main(String[] args) throws IOException, FactoryException {

Resources.initializeResources(args[0]);

// Read Boundary Shapefile
logger.info("Reading boundary shapefile...");
Geometry boundary = GpkgReader.readNetworkBoundary();

// Read in TRADS trips from CSV
logger.info("Reading fixed input data from ascii file...");
LogitData logitData = new LogitData(args[1],"choice","t.ID");
logitData.read();

// Read Boundary Shapefile
logger.info("Reading boundary shapefile...");
Geometry boundary = GpkgReader.readNetworkBoundary();

// Read in TRADS trips
logger.info("Reading person micro data from ascii file...");
Set<Trip> trips = DiaryReader.readTrips(boundary);
Expand Down Expand Up @@ -75,7 +75,7 @@ public static void main(String[] args) throws IOException, FactoryException {
TravelTime ttBike = bicycle.getTravelTimeFast(networkBike);

// Deal with intrazonal trips – can remove after we get X/Y coordinates for TRADS)
Set<SimpleFeature> OAs = GisUtils.readGpkg("zones/gm_oa.gpkg");
Set<SimpleFeature> OAs = GisUtils.readGpkg("zones/2011/gm_oa.gpkg");

// Organise classes
int[] y = logitData.getChoices();
Expand All @@ -85,10 +85,12 @@ public static void main(String[] args) throws IOException, FactoryException {
System.out.println("Identified " + k + " classes.");

// Utility function
AbstractUtilityFunction u = new MNL_Dynamic(logitData,trip_data,OAs,networkBike,bike,ttBike,networkWalk,null,ttWalk);
AbstractUtilitySpecification u = new MNL_Dynamic(logitData,trip_data,OAs,networkBike,bike,ttBike,networkWalk,null,ttWalk);
// AbstractUtilityFunction u = new MNL_Static(logitData);


// Start model
MultinomialLogit.run(u,y,k,0,1e-10,500,"dynamic_results.csv");
MultinomialLogit.run(u,y,k,0,1e-10,500,"dynamic6.csv");

logger.info("finished estimation.");
}
Expand Down
7 changes: 7 additions & 0 deletions src/main/java/estimation/UtilityFunction.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package estimation;

public interface UtilityFunction {

double applyAsDouble(double[] a,int b);

}
Loading

0 comments on commit 4446fed

Please sign in to comment.