diff --git a/INSTALL_APPLICATION.md b/INSTALL_APPLICATION.md index 24fe1297a8..3158b22fd5 100644 --- a/INSTALL_APPLICATION.md +++ b/INSTALL_APPLICATION.md @@ -6,7 +6,7 @@ Please use a recent Java JDK. See [Setting up Java for Tetrad](https://github.co To download the Tetrad jar, please click the following link (which will always be updated to the latest version): -https://s01.oss.sonatype.org/content/repositories/releases/io/github/cmu-phil/tetrad-gui/7.5.0/tetrad-gui-7.5.0-launch.jar +https://s01.oss.sonatype.org/content/repositories/releases/io/github/cmu-phil/tetrad-gui/7.6.0/tetrad-gui-7.6.0-launch.jar You may be able to launch this jar by double-clicking the jar file name. However, on a Mac, this presents some security challenges. On all platforms, the jar may be launched at the command line (with a specification of the amount of RAM you will allow it to use) using this command: diff --git a/README.md b/README.md index 1d07814678..96f161e71a 100644 --- a/README.md +++ b/README.md @@ -10,8 +10,6 @@ See out insructions for [Installing the Tetrad Application](https://github.com/c We have a project, [py-tetrad](https://github.com/cmu-phil/py-tetrad), that allows you to incorporate arbitrary Tetrad code into a Python workflow. It's new, and the installation is still nonstandard, but it had a good response. This requires Python 3.5+. and Java JDK 9+. -Please see our [description](https://sites.google.com/view/tetradcausal/tetrad-in-python - ## Tetrad in R We also have a project, [rpy-tetrad](https://github.com/cmu-phil/py-tetrad/tree/main/pytetrad/R), that allows you to incorporate _some_ Tetrad functionality in R. It's also new, and the installation for it is also still nonstandard, but has gotten good feedback. This requires Python 3.5+ and Java JDK 9+. 
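Much of this patch reworks the Markov check around the new ConditioningSetType enum and the MarkovCheck class. As a quick orientation before the file-by-file changes, the sketch below pieces together how that API is driven by the statistic classes touched in this diff (FractionDependentUnderNull, PvalueUniformityUnderNull, and friends). It is illustrative only and not part of the patch; the graph and dataset arguments are placeholders.

```java
// Illustrative sketch only; not part of this patch. The calls mirror those in the
// statistic classes changed in this diff.
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.search.ConditioningSetType;
import edu.cmu.tetrad.search.MarkovCheck;
import edu.cmu.tetrad.search.test.IndTestFisherZ;

public class MarkovCheckSketch {
    // 'estGraph' and 'data' are placeholders for an estimated graph and its dataset.
    public static void printMarkovStats(Graph estGraph, DataSet data) {
        double alpha = 0.05;
        MarkovCheck check = new MarkovCheck(estGraph,
                new IndTestFisherZ(data, alpha), ConditioningSetType.LOCAL_MARKOV);
        check.generateResults();

        // Fraction of graph-implied independence facts judged dependent (Markov check).
        System.out.println("Dependent among implied indep facts: " + check.getFractionDependent(true));
        // Fraction of graph-implied dependence facts judged dependent (dependent-distribution check).
        System.out.println("Dependent among implied dep facts: " + check.getFractionDependent(false));
        // Kolmogorov-Smirnov p-value for uniformity of the independence-fact p-values.
        System.out.println("KS uniformity p-value: " + check.getKsPValue(true));
    }
}
```

Passing ORDERED_LOCAL_MARKOV, MARKOV_BLANKET, or GLOBAL_MARKOV as the third constructor argument selects the other conditioning-set strategies introduced below.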
diff --git a/data-reader/pom.xml b/data-reader/pom.xml index 448848e968..b18df6845f 100644 --- a/data-reader/pom.xml +++ b/data-reader/pom.xml @@ -5,7 +5,7 @@ io.github.cmu-phil tetrad - 7.6.0 + 7.6.1 data-reader diff --git a/pom.xml b/pom.xml index 200272634a..b03d15628e 100644 --- a/pom.xml +++ b/pom.xml @@ -4,7 +4,7 @@ 4.0.0 io.github.cmu-phil tetrad - 7.6.0 + 7.6.1 pom Tetrad Project diff --git a/tetrad-gui/pom.xml b/tetrad-gui/pom.xml index 945acb3ec4..0483b1ab23 100644 --- a/tetrad-gui/pom.xml +++ b/tetrad-gui/pom.xml @@ -6,7 +6,7 @@ io.github.cmu-phil tetrad - 7.6.0 + 7.6.1 tetrad-gui diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java index 20da5dae5e..3e06f09bc4 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java @@ -27,8 +27,8 @@ import edu.cmu.tetrad.graph.GraphUtils; import edu.cmu.tetrad.graph.IndependenceFact; import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.search.ConditioningSetType; import edu.cmu.tetrad.search.IndependenceTest; -import edu.cmu.tetrad.search.MarkovCheck; import edu.cmu.tetrad.search.test.IndependenceResult; import edu.cmu.tetrad.search.test.MsepTest; import edu.cmu.tetrad.util.NumberFormatUtil; @@ -99,20 +99,24 @@ public MarkovCheckEditor(MarkovCheckIndTestModel model) { throw new NullPointerException("Expecting a model"); } - conditioningSetTypeJComboBox.addItem("Parents(X)"); + conditioningSetTypeJComboBox.addItem("Parents(X) (Local Markov)"); + conditioningSetTypeJComboBox.addItem("Parents(X) for a Valid Order (Ordered Local Markov)"); conditioningSetTypeJComboBox.addItem("MarkovBlanket(X)"); - conditioningSetTypeJComboBox.addItem("All Subsets"); + conditioningSetTypeJComboBox.addItem("All Subsets (Global Markov)"); conditioningSetTypeJComboBox.addActionListener(e -> { switch ((String) Objects.requireNonNull(conditioningSetTypeJComboBox.getSelectedItem())) { - case "Parents(X)": - model.getMarkovCheck().setSetType(MarkovCheck.ConditioningSetType.PARENTS); + case "Parents(X) (Local Markov)": + model.getMarkovCheck().setSetType(ConditioningSetType.LOCAL_MARKOV); + break; + case "Parents(X) for a Valid Order (Ordered Local Markov)": + model.getMarkovCheck().setSetType(ConditioningSetType.ORDERED_LOCAL_MARKOV); break; case "MarkovBlanket(X)": - model.getMarkovCheck().setSetType(MarkovCheck.ConditioningSetType.MARKOV_BLANKET); + model.getMarkovCheck().setSetType(ConditioningSetType.MARKOV_BLANKET); break; - case "All Subsets": - model.getMarkovCheck().setSetType(MarkovCheck.ConditioningSetType.ALL_SUBSETS); + case "All Subsets (Global Markov)": + model.getMarkovCheck().setSetType(ConditioningSetType.GLOBAL_MARKOV); break; default: throw new IllegalArgumentException("Unknown conditioning set type: " + @@ -121,7 +125,7 @@ public MarkovCheckEditor(MarkovCheckIndTestModel model) { class MyWatchedProcess extends WatchedProcess { public void watch() { - if (model.getMarkovCheck().getSetType() == MarkovCheck.ConditioningSetType.ALL_SUBSETS && model.getVars().size() > 12) { + if (model.getMarkovCheck().getSetType() == ConditioningSetType.GLOBAL_MARKOV && model.getVars().size() > 12) { int ret = JOptionPane.showOptionDialog(MarkovCheckEditor.this, "The all subsets option is exponential and can become extremely slow beyond 12" + "\nvariables. You may possibly be required to force quit Tetrad. 
Continue?", "Warning", @@ -277,7 +281,7 @@ public void watch() { JTabbedPane pane = new JTabbedPane(); pane.addTab("Check Markov", indep); - pane.addTab("Check Faithfulness", dep); + pane.addTab("Check Dependent Distribution", dep); pane.addTab("Help", scroll); box.add(pane); @@ -310,13 +314,13 @@ public void watch() { @NotNull private static String getHelpMessage() { - return "This tool lets you plot statistics for independence tests of a pair of variables given some conditioning calculated for one of those variables, for a given graph and dataset. Two tables are made, one in which the independence facts predicted by the graph using these conditioning sets are tested in the data and the other in which the graph's predicted dependence facts are tested. The first of these sets is a test for \"Markov\" for the relevant conditioning sets; the is a test for \"Faithfulness.”\n" + + return "This tool lets you plot statistics for independence tests of a pair of variables given some conditioning calculated for one of those variables, for a given graph and dataset. Two tables are made, one in which the independence facts predicted by the graph using these conditioning sets are tested in the data and the other in which the graph's predicted dependence facts are tested. The first of these is a check for \"Markov\" (a check for implied independence facts) for the chosen conditioning sets; the second is a check of the \"Dependent Distribution\" (a check of implied dependence facts).\n" + "\n" + "Each table gives columns for the independence fact being checked, its test result, and its statistic. This statistic is either a p-value, ranging from 0 to 1, where p-values above the alpha level of the test are judged as independent, or a score bump, where this bump is negative for independent judgments and positive for dependent judgments.\n" + "\n" + "If the independence test yields a p-value, as for instance, for the Fisher Z test (for the linear, Gaussian case) or else the Chi-Square test (for the multinomial case), then under the null hypothesis of independence and for a consistent test, these p-values should be distributed as Uniform(0, 1). That is, it should be just as likely to see p-values in any range of equal width. If the test is inconsistent or the graph is incorrect (i.e., the parents of some or all of the nodes in the graph are incorrect), then this distribution of p-values will not be Uniform. To visualize this, we display the histogram of the p-values with equally sized bins; the bars in this histogram, for this case, should ideally all be of equal height.\n" + "\n" + - "If the first bar in this histogram is especially high (for the p-value case), that means that many tests are being judged as dependent. For checking Faithfulness, one hopes that this list is non-empty, then this first bar will be especially high, since high p-values are for examples where the graph is unfaithful to the distribution. These are likely for for cases where paths in the graph cancel unfaithfully. But for checking Markov, one hopes that this first bar will be the same height as all of the other bars.\n" + + "If the first bar in this histogram is especially high (for the p-value case), that means that many tests are being judged as dependent. For checking the dependent distribution, one hopes that this list is non-empty, in which case this first bar will be especially high, since high p-values are for examples where the graph is unfaithful to the distribution. 
These are likely for cases where paths in the graph cancel unfaithfully. But for checking Markov, one hopes that this first bar will be the same height as all of the other bars.\n" + "\n" + "To make it especially clear, we give two statistics in the interface. The first is the percentage of p-values judged dependent on the test. If an alpha level is used in the test, this number should be very close to the alpha level for the Local Markov check since the distribution of p-values under this condition is Uniform. For the second, we test the Uniformity of the p-values using a Kolmogorov-Smirnov test. The p-value returned by this test should be greater than the user’s preferred alpha level if the distribution of p-values is Uniform and less then this alpha level if the distribution of p-values is non-Uniform.\n" + "\n" + @@ -492,6 +496,12 @@ public void mouseClicked(MouseEvent e) { // scroll.setPreferredSize(new Dimension(400, 400)); b1.add(scroll); + Box b1a = Box.createHorizontalBox(); + JLabel label = new JLabel("Table contents can be selected and copied into, e.g., Excel."); + b1a.add(label); + b1a.add(Box.createHorizontalGlue()); + b1.add(b1a); + Box b4 = Box.createHorizontalBox(); b4.add(Box.createGlue()); b4.add(Box.createHorizontalStrut(10)); @@ -678,6 +688,12 @@ public void mouseClicked(MouseEvent e) { // scroll.setPreferredSize(new Dimension(400, 400)); b1.add(scroll); + Box b1a = Box.createHorizontalBox(); + JLabel label = new JLabel("Table contents can be selected and copied into, e.g., Excel."); + b1a.add(label); + b1a.add(Box.createHorizontalGlue()); + b1.add(b1a); + Box b4 = Box.createHorizontalBox(); b4.add(Box.createGlue()); b4.add(Box.createHorizontalStrut(10)); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MarkovCheckIndTestModel.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MarkovCheckIndTestModel.java index 05ed27b7aa..a0dcba9ee0 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MarkovCheckIndTestModel.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MarkovCheckIndTestModel.java @@ -24,6 +24,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.Knowledge; import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.ConditioningSetType; import edu.cmu.tetrad.search.IndependenceTest; import edu.cmu.tetrad.search.MarkovCheck; import edu.cmu.tetrad.search.test.IndependenceResult; @@ -64,7 +65,7 @@ public static Knowledge serializableInstance() { } public void setIndependenceTest(IndependenceTest test) { - this.markovCheck = new MarkovCheck(this.graph, test, this.markovCheck == null ? MarkovCheck.ConditioningSetType.PARENTS : this.markovCheck.getSetType()); + this.markovCheck = new MarkovCheck(this.graph, test, this.markovCheck == null ? 
ConditioningSetType.LOCAL_MARKOV : this.markovCheck.getSetType()); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/LayoutUtils.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/LayoutUtils.java index 59f6984a47..d1a3252e2e 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/LayoutUtils.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/LayoutUtils.java @@ -499,7 +499,7 @@ public static void circleLayout(LayoutEditable layoutEditable) { int m = FastMath.min(r.width, r.height) / 2; - LayoutUtil.defaultLayout(graph); + LayoutUtil.circleLayout(graph); layoutEditable.layoutByGraph(graph); LayoutUtils.layout = Layout.circle; } diff --git a/tetrad-lib/dependency-reduced-pom.xml b/tetrad-lib/dependency-reduced-pom.xml new file mode 100644 index 0000000000..1a34c60792 --- /dev/null +++ b/tetrad-lib/dependency-reduced-pom.xml @@ -0,0 +1,62 @@ + + + + tetrad + io.github.cmu-phil + 7.6.0-SNAPSHOT + + 4.0.0 + tetrad-lib + + + + org.apache.maven.wagon + wagon-ssh + 2.10 + + + + + maven-compiler-plugin + 3.11.0 + + 1.8 + 1.8 + + + + maven-shade-plugin + 3.1.0 + + + package + + shade + + + + + + maven-antrun-plugin + 3.1.0 + + + compile + + run + + + + + + + + + + + + + UTF-8 + + + diff --git a/tetrad-lib/pom.xml b/tetrad-lib/pom.xml index 76fb0acc05..b963ae99b4 100644 --- a/tetrad-lib/pom.xml +++ b/tetrad-lib/pom.xml @@ -6,7 +6,7 @@ io.github.cmu-phil tetrad - 7.6.0 + 7.6.1 tetrad-lib @@ -22,6 +22,35 @@ 1.8 + + org.apache.maven.plugins + maven-shade-plugin + 3.1.0 + + + package + + shade + + + + + + + all-permissions + ${project.name} + ${project.version} + + + + true + shaded + + + + + maven-antrun-plugin 3.1.0 diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/DirectLingam.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/DirectLingam.java index 42cefbfbe8..d30c33400d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/DirectLingam.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/DirectLingam.java @@ -13,6 +13,7 @@ import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.score.Score; +import edu.cmu.tetrad.search.utils.LogUtilsSearch; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.Params; import edu.cmu.tetrad.util.TetradLogger; @@ -56,6 +57,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { Graph graph = search.search(); TetradLogger.getInstance().forceLogMessage(graph.toString()); + LogUtilsSearch.stampWithBic(graph, dataSet); return graph; } else { DirectLingam algorithm = new DirectLingam(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/IcaLingam.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/IcaLingam.java index c97e7a0614..3734fff3b0 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/IcaLingam.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/IcaLingam.java @@ -11,6 +11,7 @@ import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.IcaLingD; +import edu.cmu.tetrad.search.utils.LogUtilsSearch; import edu.cmu.tetrad.util.Matrix; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.Params; @@ -56,6 +57,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { 
TetradLogger.getInstance().forceLogMessage(bHat.toString()); TetradLogger.getInstance().forceLogMessage(graph.toString()); + LogUtilsSearch.stampWithBic(graph, dataSet); return graph; } else { IcaLingam algorithm = new IcaLingam(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Boss.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Boss.java index ec75e01630..a7c92da553 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Boss.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Boss.java @@ -15,6 +15,7 @@ import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.PermutationSearch; import edu.cmu.tetrad.search.score.Score; +import edu.cmu.tetrad.search.utils.LogUtilsSearch; import edu.cmu.tetrad.search.utils.TsUtils; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.Params; @@ -51,7 +52,6 @@ public Boss(ScoreWrapper score) { this.score = score; } - @Override public Graph search(DataModel dataModel, Parameters parameters) { if (parameters.getInt(Params.NUMBER_RESAMPLING) < 1) { @@ -75,8 +75,9 @@ public Graph search(DataModel dataModel, Parameters parameters) { boss.setVerbose(parameters.getBoolean(Params.VERBOSE)); PermutationSearch permutationSearch = new PermutationSearch(boss); permutationSearch.setKnowledge(this.knowledge); - - return permutationSearch.search(); + Graph graph = permutationSearch.search(); + LogUtilsSearch.stampWithScores(graph, dataModel, score); + return graph; } else { Boss algorithm = new Boss(this.score); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/PcLingam.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/BossLingam.java similarity index 85% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/PcLingam.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/BossLingam.java index 04e2a4b0eb..4dc585c651 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/PcLingam.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/BossLingam.java @@ -23,23 +23,25 @@ import java.util.ArrayList; import java.util.List; +import static edu.cmu.tetrad.search.utils.LogUtilsSearch.stampWithBic; + /** * Peter/Clark algorithm (PC). 
* * @author josephramsey */ -@edu.cmu.tetrad.annotation.Algorithm(name = "PC-LiNGAM", command = "pc-lingam", algoType = AlgType.forbid_latent_common_causes) +@edu.cmu.tetrad.annotation.Algorithm(name = "BOSS-LiNGAM", command = "boss-lingam", algoType = AlgType.forbid_latent_common_causes) @Bootstrapping -public class PcLingam implements Algorithm, HasKnowledge, UsesScoreWrapper, ReturnsBootstrapGraphs { +public class BossLingam implements Algorithm, HasKnowledge, UsesScoreWrapper, ReturnsBootstrapGraphs { private static final long serialVersionUID = 23L; private ScoreWrapper score; private Knowledge knowledge = new Knowledge(); private List bootstrapGraphs = new ArrayList<>(); - public PcLingam() { + public BossLingam() { } - public PcLingam(ScoreWrapper scoreWrapper) { + public BossLingam(ScoreWrapper scoreWrapper) { this.score = scoreWrapper; } @@ -69,11 +71,13 @@ public Graph search(DataModel dataModel, Parameters parameters) { Graph cpdag = permutationSearch.search(); - edu.cmu.tetrad.search.PcLingam pcLingam = new edu.cmu.tetrad.search.PcLingam(cpdag, (DataSet) dataModel); + edu.cmu.tetrad.search.BossLingam bossLingam = new edu.cmu.tetrad.search.BossLingam(cpdag, (DataSet) dataModel); + Graph graph = bossLingam.search(); - return pcLingam.search(); + stampWithBic(graph, dataModel); + return graph; } else { - PcLingam pcAll = new PcLingam(this.score); + BossLingam pcAll = new BossLingam(this.score); DataSet data = (DataSet) dataModel; GeneralResamplingTest search = new GeneralResamplingTest(data, pcAll, parameters.getInt(Params.NUMBER_RESAMPLING), parameters.getDouble(Params.PERCENT_RESAMPLE_SIZE), parameters.getBoolean(Params.RESAMPLING_WITH_REPLACEMENT), parameters.getInt(Params.RESAMPLING_ENSEMBLE), parameters.getBoolean(Params.ADD_ORIGINAL_DATASET)); @@ -94,7 +98,7 @@ public Graph getComparisonGraph(Graph graph) { @Override public String getDescription() { - return "PC-LiNGAM using " + this.score.getDescription(); + return "BOSS-LiNGAM using " + this.score.getDescription(); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Cpc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Cpc.java index 8430b3b623..7c34c9f73b 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Cpc.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Cpc.java @@ -23,6 +23,8 @@ import java.util.ArrayList; import java.util.List; +import static edu.cmu.tetrad.search.utils.LogUtilsSearch.stampWithBic; + /** * Conservative PC (CPC). 
* @@ -107,7 +109,9 @@ public Graph search(DataModel dataModel, Parameters parameters) { search.setVerbose(parameters.getBoolean(Params.VERBOSE)); search.setKnowledge(knowledge); search.setConflictRule(conflictRule); - return search.search(); + Graph graph = search.search(); + stampWithBic(graph, dataModel); + return graph; } else { Cpc pcAll = new Cpc(this.test); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java index ca9d0460f3..d3298a8528 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; import edu.cmu.tetrad.algcomparison.algorithm.ReturnsBootstrapGraphs; import edu.cmu.tetrad.algcomparison.score.ScoreWrapper; +import edu.cmu.tetrad.algcomparison.statistic.BicEst; import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; import edu.cmu.tetrad.algcomparison.utils.TakesExternalGraph; import edu.cmu.tetrad.algcomparison.utils.UsesScoreWrapper; @@ -15,6 +16,7 @@ import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.score.Score; +import edu.cmu.tetrad.search.utils.LogUtilsSearch; import edu.cmu.tetrad.search.utils.TsUtils; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.Params; @@ -98,6 +100,11 @@ public Graph search(DataModel dataModel, Parameters parameters) { graph = search.search(); + if (!graph.getAllAttributes().containsKey("BIC")) { + graph.addAttribute("BIC", new BicEst().getValue(null, graph, dataModel)); + } + + LogUtilsSearch.stampWithScores(graph, dataModel, score); return graph; } else { Fges fges = new Fges(this.score); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Grasp.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Grasp.java index 174d70ed85..287699409d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Grasp.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Grasp.java @@ -17,6 +17,7 @@ import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.IndependenceTest; import edu.cmu.tetrad.search.score.Score; +import edu.cmu.tetrad.search.utils.LogUtilsSearch; import edu.cmu.tetrad.search.utils.TsUtils; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.Params; @@ -87,7 +88,9 @@ public Graph search(DataModel dataModel, Parameters parameters) { grasp.setNumStarts(parameters.getInt(Params.NUM_STARTS)); grasp.setKnowledge(this.knowledge); grasp.bestOrder(score.getVariables()); - return grasp.getGraph(parameters.getBoolean(Params.OUTPUT_CPDAG)); + Graph graph = grasp.getGraph(parameters.getBoolean(Params.OUTPUT_CPDAG)); + LogUtilsSearch.stampWithScores(graph, dataModel, score); + return graph; } else { Grasp algorithm = new Grasp(this.test, this.score); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pc.java index f836722598..decfad9141 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pc.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pc.java @@ -23,6 +23,8 @@ import java.util.ArrayList; 
import java.util.List; +import static edu.cmu.tetrad.search.utils.LogUtilsSearch.stampWithBic; + /** * Peter/Clark algorithm (PC). * @@ -106,7 +108,9 @@ public Graph search(DataModel dataModel, Parameters parameters) { search.setKnowledge(this.knowledge); search.setStable(parameters.getBoolean(Params.STABLE_FAS)); search.setConflictRule(conflictRule); - return search.search(); + Graph graph = search.search(); + stampWithBic(graph, dataModel); + return graph; } else { Pc pcAll = new Pc(this.test); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Sp.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Sp.java index d0813c6b81..d61d61584e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Sp.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Sp.java @@ -16,6 +16,7 @@ import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.PermutationSearch; import edu.cmu.tetrad.search.score.Score; +import edu.cmu.tetrad.search.utils.LogUtilsSearch; import edu.cmu.tetrad.search.utils.TsUtils; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.Params; @@ -43,7 +44,6 @@ public class Sp implements Algorithm, UsesScoreWrapper, HasKnowledge, ReturnsBoo private Knowledge knowledge = new Knowledge(); private List bootstrapGraphs = new ArrayList<>(); - public Sp() { // Used in reflection; do not delete. } @@ -69,8 +69,9 @@ public Graph search(DataModel dataModel, Parameters parameters) { Score score = this.score.getScore(dataModel, parameters); PermutationSearch permutationSearch = new PermutationSearch(new edu.cmu.tetrad.search.Sp(score)); permutationSearch.setKnowledge(this.knowledge); - - return permutationSearch.search(); + Graph graph = permutationSearch.search(); + LogUtilsSearch.stampWithScores(graph, dataModel, score); + return graph; } else { Sp algorithm = new Sp(this.score); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderAlternative.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderAlternative.java index 176e4b0201..01579b610f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderAlternative.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderAlternative.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataSet; import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.ConditioningSetType; import edu.cmu.tetrad.search.MarkovCheck; import edu.cmu.tetrad.search.test.IndTestFisherZ; @@ -36,7 +37,7 @@ public String getDescription() { @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), MarkovCheck.ConditioningSetType.PARENTS); + MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), ConditioningSetType.LOCAL_MARKOV); markovCheck.generateResults(); return markovCheck.getFractionDependent(false); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderNull.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderNull.java index 70f9299e1f..737d298b61 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderNull.java +++ 
b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderNull.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataSet; import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.ConditioningSetType; import edu.cmu.tetrad.search.MarkovCheck; import edu.cmu.tetrad.search.test.IndTestFisherZ; @@ -36,7 +37,7 @@ public String getDescription() { @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), MarkovCheck.ConditioningSetType.PARENTS); + MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), ConditioningSetType.LOCAL_MARKOV); markovCheck.generateResults(); return markovCheck.getFractionDependent(true); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MarkovAdequacyScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MarkovAdequacyScore.java index 6833b6f6e7..5bafefbb19 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MarkovAdequacyScore.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MarkovAdequacyScore.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataSet; import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.ConditioningSetType; import edu.cmu.tetrad.search.MarkovCheck; import edu.cmu.tetrad.search.test.IndTestFisherZ; @@ -29,7 +30,7 @@ public String getDescription() { @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, 0.01), MarkovCheck.ConditioningSetType.PARENTS); + MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, 0.01), ConditioningSetType.LOCAL_MARKOV); markovCheck.generateResults(); return markovCheck.getMarkovAdequacyScore(alpha); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueDistanceToAlpha.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueDistanceToAlpha.java index 8fef3f87ca..42fe7dfce6 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueDistanceToAlpha.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueDistanceToAlpha.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataSet; import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.ConditioningSetType; import edu.cmu.tetrad.search.MarkovCheck; import edu.cmu.tetrad.search.test.IndTestFisherZ; @@ -35,7 +36,7 @@ public String getDescription() { @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), MarkovCheck.ConditioningSetType.PARENTS); + MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), ConditioningSetType.LOCAL_MARKOV); markovCheck.generateResults(); return abs(alpha - markovCheck.getKsPValue(true)); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueUniformityUnderNull.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueUniformityUnderNull.java index 1a23b18e6f..cf152d6cb0 100644 --- 
a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueUniformityUnderNull.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueUniformityUnderNull.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataSet; import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.ConditioningSetType; import edu.cmu.tetrad.search.MarkovCheck; import edu.cmu.tetrad.search.test.IndTestFisherZ; @@ -33,7 +34,7 @@ public String getDescription() { @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), MarkovCheck.ConditioningSetType.PARENTS); + MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), ConditioningSetType.LOCAL_MARKOV); markovCheck.generateResults(); return markovCheck.getKsPValue(true); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java index d023bd4dc4..320faccf17 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java @@ -54,9 +54,16 @@ public final class Knowledge implements TetradSerializable { // private static final Pattern VARNAME_PATTERN = Pattern.compile("[A-Za-z0-9:_\\-.]+"); // private static final Pattern SPEC_PATTERN = Pattern.compile("[A-Za-z0-9:-_,\\-.*]+"); private static final Pattern COMMAN_DELIM = Pattern.compile(","); - private final Set variables; - private final Set>> forbiddenRulesSpecs; - private final Set>> requiredRulesSpecs; + + private final Set variables; + + // This needs to be a list for backward compatibility. Need to check when adding + // a new spec whether it's already in the list. + private final List>> forbiddenRulesSpecs; + + // This needs to be a list for backward compatibility. Need to check when adding + // a new spec whether it's already in the list. + private final List>> requiredRulesSpecs; private final List> tierSpecs; // Legacy. 
private final List knowledgeGroups; @@ -65,8 +72,8 @@ public final class Knowledge implements TetradSerializable { public Knowledge() { this.variables = new HashSet<>(); - this.forbiddenRulesSpecs = new HashSet<>(); - this.requiredRulesSpecs = new HashSet<>(); + this.forbiddenRulesSpecs = new ArrayList<>(); + this.requiredRulesSpecs = new ArrayList<>(); this.tierSpecs = new ArrayList<>(); this.knowledgeGroups = new LinkedList<>(); this.knowledgeGroupRules = new HashMap<>(); @@ -265,9 +272,13 @@ public void addKnowledgeGroup(KnowledgeGroup group) { this.knowledgeGroupRules.put(group, o); if (group.getType() == KnowledgeGroup.FORBIDDEN) { - this.forbiddenRulesSpecs.add(o); + if (!forbiddenRulesSpecs.contains(o)) { + this.forbiddenRulesSpecs.add(o); + } } else if (group.getType() == KnowledgeGroup.REQUIRED) { - this.requiredRulesSpecs.add(o); + if (!requiredRulesSpecs.contains(o)) { + this.requiredRulesSpecs.add(o); + } } } @@ -536,7 +547,11 @@ public void setForbidden(String var1, String var2) { OrderedPair> o = new OrderedPair<>(f1, f2); - this.forbiddenRulesSpecs.add(o); + if (!forbiddenRulesSpecs.contains(o)) { + if (!forbiddenRulesSpecs.contains(o)) { + this.forbiddenRulesSpecs.add(o); + } + } } /** @@ -580,7 +595,9 @@ public void setRequired(String var1, String var2) { OrderedPair> o = new OrderedPair<>(f1, f2); - this.requiredRulesSpecs.add(o); + if (!requiredRulesSpecs.contains(o)) { + this.requiredRulesSpecs.add(o); + } } /** @@ -609,9 +626,13 @@ public void setKnowledgeGroup(int index, KnowledgeGroup group) { knowledgeGroupRules.put(group, o); if (group.getType() == KnowledgeGroup.FORBIDDEN) { - this.forbiddenRulesSpecs.add(o); + if (!forbiddenRulesSpecs.contains(o)) { + this.forbiddenRulesSpecs.add(o); + } } else if (group.getType() == KnowledgeGroup.REQUIRED) { - this.requiredRulesSpecs.add(o); + if (!requiredRulesSpecs.contains(o)) { + this.requiredRulesSpecs.add(o); + } } this.knowledgeGroups.set(index, group); @@ -640,7 +661,9 @@ public void setTierForbiddenWithin(int tier, boolean forbidden) { OrderedPair> o = new OrderedPair<>(varsInTier, varsInTier); if (forbidden) { - this.forbiddenRulesSpecs.add(o); + if (!forbiddenRulesSpecs.contains(o)) { + this.forbiddenRulesSpecs.add(o); + } } else { this.forbiddenRulesSpecs.remove(o); } @@ -717,7 +740,10 @@ public List getListOfForbiddenEdges() { } for (int i = this.tierSpecs.size() - 1; i >= 0; i--) { - for (int j = i; j >= 0; j--) { + + // Make sure this iterates from i - 1 to 0 or else all directed edges will be + // forbidden within tiers! + for (int j = i - 1; j >= 0; j--) { Set tieri = this.tierSpecs.get(i); Set tierj = this.tierSpecs.get(j); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java index c5bd1ac995..dbb476ca16 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java @@ -45,7 +45,7 @@ public static void defaultLayout(Graph graph) { * Arranges the nodes in the graph in a circle. * @param graph the graph to be arranged. 
*/ - private static void circleLayout(Graph graph) { + public static void circleLayout(Graph graph) { if (graph == null) { return; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java index 266b4b1022..112506f4aa 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java @@ -30,7 +30,7 @@ private static void addToSet(Map> previous, Node b, Node c) { * * @param initialOrder Variables in the order will be kept as close to this initial order as possible, either the * forward order or the reverse order, depending on the next parameter. - * @param forward Whether the variable will be iterated over in forward or reverse direction. + * @param forward Whether the variables will be iterated over in forward or reverse direction. * @return The valid causal order found. */ public List getValidOrder(List initialOrder, boolean forward) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/PcLingam.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BossLingam.java similarity index 96% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/search/PcLingam.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/search/BossLingam.java index e5792a7fcb..ed9758a42a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/PcLingam.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BossLingam.java @@ -30,19 +30,16 @@ import edu.cmu.tetrad.regression.Regression; import edu.cmu.tetrad.regression.RegressionDataset; import edu.cmu.tetrad.regression.RegressionResult; -import edu.cmu.tetrad.search.utils.MeekRules; import edu.cmu.tetrad.util.Matrix; import edu.cmu.tetrad.util.TetradLogger; import edu.cmu.tetrad.util.Vector; import org.apache.commons.math3.util.FastMath; -import java.text.DecimalFormat; -import java.text.NumberFormat; import java.util.ArrayList; import java.util.List; /** - *

Implements the PC-LiNGAM algorithm which first finds a CPDAG for the variables + *

Implements the BOSS-LiNGAM algorithm which first finds a CPDAG for the variables * and then uses a non-Gaussian orientation method to orient the undirected edges. The reference is as follows: * *

>Hoyer et al., "Causal discovery of linear acyclic models with arbitrary @@ -65,7 +62,7 @@ * @author patrickhoyer * @author josephramsey */ -public class PcLingam { +public class BossLingam { private final Graph cpdag; private final DataSet dataSet; private double[] pValues; @@ -78,7 +75,7 @@ public class PcLingam { * @param cpdag The CPDAG whose unoriented edges are to be oriented. * @param dataSet Teh dataset to use. */ - public PcLingam(Graph cpdag, DataSet dataSet) + public BossLingam(Graph cpdag, DataSet dataSet) throws IllegalArgumentException { if (cpdag == null) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ConditioningSetType.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ConditioningSetType.java new file mode 100644 index 0000000000..7744e915be --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ConditioningSetType.java @@ -0,0 +1,10 @@ +package edu.cmu.tetrad.search; + +/** + * The type of conditioning set to use for the Markov check. The default is LOCAL_MARKOV, which conditions on the + * parents of the target variable; ORDERED_LOCAL_MARKOV restricts those parents to ones preceding the target in a valid + * causal order, MARKOV_BLANKET conditions on the Markov blanket of the target, and GLOBAL_MARKOV checks all subsets. + */ +public enum ConditioningSetType { + LOCAL_MARKOV, ORDERED_LOCAL_MARKOV, MARKOV_BLANKET, GLOBAL_MARKOV +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Demixer.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Demixer.java new file mode 100644 index 0000000000..b3854ed823 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Demixer.java @@ -0,0 +1,266 @@ +package edu.cmu.tetrad.search; + +import edu.cmu.tetrad.data.DataSet; +import edu.cmu.tetrad.util.Matrix; +import edu.cmu.tetrad.util.Vector; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Random; + +/** + * Uses expectation-maximization to sort a data set with data sampled from two or more multivariate Gaussian + * distributions into its component data sets. 
+ * + * @author Madelyn Glymour + */ +public class Demixer { + + private final int numVars; + private final int numCases; + private final int numClusters; // number of clusters + private final DataSet data; + private final double[][] dataArray; // v-by-n data matrix + private final Matrix[] variances; + private final double[][] meansArray; // k-by-v matrix representing means for each variable for each of k models + private final Matrix[] variancesArray; // k-by-v-by-v matrix representing covariance matrix for each of k models + private final double[] weightsArray; // array of length k representing weights for each model + private final double[][] gammaArray; // k-by-n matrix representing gamma for each data case in each model + private boolean demixed = false; + + public Demixer(DataSet data, int k) { + this.numClusters = k; + this.data = data; + dataArray = data.getDoubleData().toArray(); + numVars = data.getNumColumns(); + numCases = data.getNumRows(); + meansArray = new double[k][numVars]; + weightsArray = new double[k]; + variancesArray = new Matrix[k]; + variances = new Matrix[k]; + gammaArray = new double[k][numCases]; + + Random rand = new Random(); + + // initialize the means array to the mean of each variable plus noise + for (int i = 0; i < numVars; i++) { + for (int j = 0; j < k; j++) { + meansArray[j][i] = calcMean(data.getDoubleData().getColumn(i)) + (rand.nextGaussian()); + } + } + + // initialize the weights array uniformly + for (int i = 0; i < k; i++) { + weightsArray[i] = Math.abs((1.0 / k)); + } + + // initialize the covariance matrix array to the actual covariance matrix + for (int i = 0; i < k; i++) { + variances[i] = data.getCovarianceMatrix(); + } + } + + /* + * Runs the E-M algorithm iteratively until the weights array converges. Returns a MixtureModel object containing + * the final values of the means, covariance matrices, weights, and gammas arrays. 
+ */ + public MixtureModel demix() { + double[] tempWeights = new double[numClusters]; + + System.arraycopy(weightsArray, 0, tempWeights, 0, numClusters); + + boolean weightsUnequal = true; + ArrayList diffsList; + int iterCounter = 0; + + System.out.println("Weights: " + Arrays.toString(weightsArray)); + + // convergence check + while (weightsUnequal) { + expectation(); + maximization(); + + System.out.println("Weights: " + Arrays.toString(weightsArray)); + + diffsList = new ArrayList<>(); // list of differences between new weights and old weights + for (int i = 0; i < numClusters; i++) { + diffsList.add(Math.abs(weightsArray[i] - tempWeights[i])); + } + + Collections.sort(diffsList); // sort the list + + // if the largest difference is below the threshold, or we've passed 100 iterations, converge + if (diffsList.get(numClusters - 1) < 0.0001 || iterCounter > 100) { + weightsUnequal = false; + } + + // new weights are now the old weights + System.arraycopy(weightsArray, 0, tempWeights, 0, numClusters); + + iterCounter++; + } + + MixtureModel model = new MixtureModel(data, dataArray, meansArray, weightsArray, variancesArray, gammaArray); + demixed = true; + + return model; + + } + + /* + * Returns true if the algorithm has been run, and the gamma, mean, and covariance arrays are at their stable values + */ + public boolean isDemixed() { + return demixed; + } + + /* + * Computes the probability that each case belongs to each model (the gamma), given the current values of the mean, + * weight, and covariance arrays + */ + private void expectation() { + + double gamma; + double divisor; + + for (int i = 0; i < numClusters; i++) { + for (int j = 0; j < numCases; j++) { + gamma = weightsArray[i] * normalPDF(j, i); + divisor = gamma; + + for (int w = 0; w < numClusters; w++) { + if (w != i) { + divisor += (weightsArray[w] * normalPDF(j, w)); + } + } + gamma = gamma / divisor; + gammaArray[i][j] = gamma; + } + } + } + + /* + * Estimates the means, covariances, and weight of each model, given the current values of the gamma array + */ + private void maximization() { + + // the weight of each model is the sum of the gamma for each case in that model, divided by the number of cases + double weight; + + for (int i = 0; i < numClusters; i++) { + weight = 0; + for (int j = 0; j < numCases; j++) { + weight += gammaArray[i][j]; + } + weight = weight / numCases; + weightsArray[i] = weight; + } + + // the mean for each variable in each model is determined by the weighted mean of that variable in the model + // (where each case i in the variable in model k is weighted by the gamma(i, k) + double meanNumerator; + double meanDivisor; + double mean; + + for (int i = 0; i < numClusters; i++) { + for (int v = 0; v < numVars; v++) { + meanNumerator = 0; + meanDivisor = 0; + for (int j = 0; j < numCases; j++) { + + meanNumerator += gammaArray[i][j] * dataArray[j][v]; + meanDivisor += gammaArray[i][j]; + } + mean = meanNumerator / meanDivisor; + meansArray[i][v] = mean; + } + } + + // the covariance matrix for each model is determined by the covariance matrix of the data, weighted by the + // gamma values for that model + double var; + + for (int i = 0; i < numClusters; i++) { + for (int v = 0; v < numVars; v++) { + for (int v2 = v; v2 < numVars; v2++) { + var = getVar(i, v, v2, numCases, gammaArray, dataArray, meansArray); + // if(Math.abs(var) >= 0.5) { + variancesArray[i].set(v, v2, var); + variancesArray[i].set(v2, v, var); + + // Reset the variances if things start to go awry with the algorithm; turns out not 
to be necessary + // } else{ + // Random rand = new Random(); + // double temp = 0.5 + rand.nextDouble(); + // variancesArray[i][v][v2] = temp; + // variancesArray[i][v2][v] = temp; + // } + } + } + variances[i] = new Matrix(variancesArray[i]); + } + + } + + static double getVar(int i, int v, int v2, int numCases, double[][] gammaArray, double[][] dataArray, double[][] meansArray) { + double varNumerator; + double varDivisor; + double var; + varNumerator = 0; + varDivisor = 0; + + for (int j = 0; j < numCases; j++) { + varNumerator += gammaArray[i][j] * (dataArray[j][v] - meansArray[i][v]) * (dataArray[j][v2] - meansArray[i][v2]); + varDivisor += gammaArray[i][j]; + } + + var = varNumerator / varDivisor; + return var; + } + + /* + * For an input case and model, returns the value of the model's normal PDF for that case, using the current + * estimations of the means and covariance matrix + */ + private double normalPDF(int caseIndex, int weightIndex) { + Matrix cov = variances[weightIndex]; + + Matrix covIn = cov.inverse(); + double[] mu = meansArray[weightIndex]; + double[] thisCase = dataArray[caseIndex]; + + double[][] diffs = new double[1][numVars]; + + for (int i = 0; i < numVars; i++) { + diffs[0][i] = thisCase[i] - mu[i]; + } + + Matrix diffsMatrix = new Matrix(diffs); + Matrix diffsTranspose = diffsMatrix.transpose(); + + Matrix distance = covIn.times(diffsTranspose); // inverse of the covariance matrix * (x - mu) + + distance = diffsMatrix.times(distance); // squared + + double distanceScal = distance.get(0, 0); // distance is a scalar, but in matrix representation + distanceScal = distanceScal * (-.5); + distanceScal = Math.exp(distanceScal); + distanceScal = distanceScal / Math.sqrt(2 * Math.PI * cov.det()); // exp(-.5 * distance) / sqrt(2 * pi * cov) + + return distanceScal; + } + + /* + * Returns the mean of a variable, input as a Vector + */ + private double calcMean(Vector dataPoints) { + double sum = 0; + + for (int i = 0; i < dataPoints.size(); i++) { + sum += dataPoints.get(i); + } + + return sum / dataPoints.size(); + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DemixerMMLKun.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DemixerMMLKun.java new file mode 100644 index 0000000000..75461c4c7c --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DemixerMMLKun.java @@ -0,0 +1,417 @@ +package edu.cmu.tetrad.search; + +import edu.cmu.tetrad.cluster.KMeans; +import edu.cmu.tetrad.data.BoxDataSet; +import edu.cmu.tetrad.data.DataSet; +import edu.cmu.tetrad.data.DoubleDataBox; +import edu.cmu.tetrad.data.SimpleDataLoader; +import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.util.Matrix; +import edu.cmu.tetrad.util.MatrixUtils; +import edu.pitt.dbmi.data.reader.Delimiter; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Created by user on 2/27/18. + */ +public class DemixerMMLKun { + + private final double minWeight; + + public DemixerMMLKun() { + minWeight = 1e-3; + } + + public static void main(String... 
args) { + DataSet dataSet; + + try { + dataSet = SimpleDataLoader.loadContinuousData(new File("/Users/user/Documents/Demix_Testing/NonGaussian/sub_1500_4var_3comp.txt"), + "//", '\"', "*", true, Delimiter.TAB, false); + } catch (IOException e) { + throw new RuntimeException(e); + } + + DemixerMMLKun pedro = new DemixerMMLKun(); + long startTime = System.currentTimeMillis(); + MixtureModel model = pedro.demix(dataSet, 25); + long elapsed = System.currentTimeMillis() - startTime; + + double[] weights = model.getWeights(); + for (double weight : weights) { + System.out.print(weight + "\t"); + } + + try { + FileWriter writer = new FileWriter("/Users/user/Documents/Demix_Testing/sub_1500_4var_3comp.txt"); + BufferedWriter bufferedWriter = new BufferedWriter(writer); + + for (int i = 0; i < dataSet.getNumRows(); i++) { + bufferedWriter.write(model.getDistribution(i) + "\n"); + } + bufferedWriter.flush(); + bufferedWriter.close(); + + DataSet[] dataSets = model.getDemixedData(); + + for (int i = 0; i < dataSets.length; i++) { + writer = new FileWriter("/Users/user/Documents/Demix_Testing/sub_1500_4var_3comp_demixed_" + (i + 1) + ".txt"); + bufferedWriter = new BufferedWriter(writer); + bufferedWriter.write(dataSets[i].toString()); + bufferedWriter.flush(); + bufferedWriter.close(); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + + System.out.println("Elapsed: " + elapsed / 1000); + } + + private MixtureModel demix(DataSet data, int k) { + double[][] dataArray = data.getDoubleData().toArray(); + int numVars = data.getNumColumns(); + int numCases = data.getNumRows(); + double lambda2 = Math.sqrt((Math.log(numCases) - Math.log(Math.log(numCases))) / 2.0); + double lambda = lambda2 * ((Math.pow(numVars, 2)) * 0.5 + 1.5 * numVars + 1); // part of the MML score + double epsilon = 1e-6; // part of the MML score + double threshold = 1e-8; // threshold for MML score convergence + + System.out.println("Lambda: " + lambda); + + // Initialize clusterings with kmeans + KMeans kMeans = KMeans.randomClusters(k); + kMeans.cluster(data.getDoubleData()); + + // Use initial clusterings to get initial means, variances, and gamma arrays + double[][] meansArray = new double[k][numVars]; + double[] weightsArray = new double[k]; + Matrix[] variancesArray = new Matrix[k]; + Matrix[] variances = new Matrix[k]; + double[][] gammaArray = new double[k][numCases]; + + List> clusters = kMeans.getClusters(); + List cluster; + int clusterSize; + double[] means; + + double[][] clusterMatrixArray; + + for (int i = 0; i < clusters.size(); i++) { + cluster = clusters.get(i); + clusterSize = cluster.size(); + means = new double[numVars]; + + for (int j = 0; j < numVars; j++) { + means[j] = 0; + } + + clusterMatrixArray = new double[clusterSize][numVars]; + + for (int j = 0; j < clusterSize; j++) { + // System.out.print(Integer.toString(cluster.get(j)) + "\t"); + MatrixUtils.sum(means, dataArray[cluster.get(j)]); + clusterMatrixArray[j] = dataArray[cluster.get(j)]; + } + + // Initial mean is mean of cluster + means = MatrixUtils.scalarProduct(1.0 / clusterSize, means); + meansArray[i] = means; + + // Initial weight is percentage of rows taken up by cluster + weightsArray[i] = ((double) clusterSize) / ((double) numCases); + + // Initial covariance matrix is cov matrix of cluster, unless cluster cov matrix has 0 determinant + DoubleDataBox box = new DoubleDataBox(clusterMatrixArray); + List variables = data.getVariables(); + BoxDataSet clusterData = new BoxDataSet(box, variables); + Matrix clusterCovMatrix = 
clusterData.getCovarianceMatrix(); + if (MatrixUtils.determinant(clusterCovMatrix.toArray()) == 0) { + variances[i] = MatrixUtils.cholesky(data.getCovarianceMatrix()); + variancesArray[i] = data.getCovarianceMatrix(); + } else { + variances[i] = MatrixUtils.cholesky(clusterCovMatrix); + variancesArray[i] = clusterCovMatrix; + } + + } + + double gamma; + double divisor; + + for (int z = 0; z < k; z++) { + for (int j = 0; j < numCases; j++) { + gamma = weightsArray[z] * normalPDF(j, z, variances, meansArray, dataArray, numVars); + divisor = gamma; + + for (int w = 0; w < k; w++) { + if (w != z) { + divisor += (weightsArray[w] * normalPDF(j, w, variances, meansArray, dataArray, numVars)); + } + } + + //Initial gamma is weighted probability for the case in cluster k, divided by the sum of weighted probabilities in all clusters + gamma = gamma / divisor; + gammaArray[z][j] = gamma; + } + } + + // Verbose debugging output + System.out.println("Clusters: " + k); + System.out.println("Weights: " + Arrays.toString(weightsArray)); + + // oldLogL and newLogL determine convergence + double oldLogL = Double.POSITIVE_INFINITY; + double newLogL; + + DeterminingStats stats; + + while (true) { + + // maximization step + stats = innerStep(data, dataArray, weightsArray, meansArray, variancesArray, variances, gammaArray, numCases, numVars, lambda); + meansArray = stats.getMeans(); + weightsArray = stats.getWeights(); + variancesArray = stats.getVariances(); + variances = stats.getVarMatrixArray(); + + k = weightsArray.length; + + // fail if there are no clusters + if (k == 0) { + break; + } + + // verbose debugging output + System.out.println("Clusters: " + k); + System.out.println("Weights: " + Arrays.toString(weightsArray)); + + // expectation step; gamma computed as above, I should probably make a separate method for it + for (int i = 0; i < k; i++) { + + for (int j = 0; j < numCases; j++) { + + double pdf = normalPDF(j, i, variances, meansArray, dataArray, numVars); + + gamma = weightsArray[i] * pdf; + + divisor = gamma; + + for (int w = 0; w < k; w++) { + if (w != i) { + divisor += (weightsArray[w] * normalPDF(j, w, variances, meansArray, dataArray, numVars)); + } + } + gamma = gamma / divisor; + + + gammaArray[i][j] = gamma; + } + } + + // check for convergence + double mml = 0; + double gammaMean; + for (int i = 0; i < weightsArray.length; i++) { + gammaMean = 0; + for (int j = 0; j < numCases; j++) { + gammaMean += gammaArray[i][j]; + } + gammaMean /= numCases; + mml += Math.log(gammaMean); + } + + mml /= weightsArray.length; + + double weightSum = 0; + + for (double v : weightsArray) { + weightSum += Math.log(v / epsilon + 1); + } + + weightSum *= lambda / numCases; + + newLogL = mml + weightSum; + + // if oldLogL and newLogL converge, end; otherwise, set oldLogL to newLogL + if (Math.abs(oldLogL / (newLogL) - 1) < threshold) { + break; + } else { + oldLogL = newLogL; + } + + } + + return new MixtureModel(data, dataArray, meansArray, weightsArray, variancesArray, gammaArray); + } + + /** + * Performs the maximization step + */ + private DeterminingStats innerStep(DataSet data, double[][] dataArray, double[] weightsArray, double[][] meansArray, Matrix[] variancesArray, Matrix[] variances, double[][] gammaArray, int numCases, int numVars, double lambda) { + + double weight; + double pSum; // sum of all gammas for a case + double meanNumerator; + double mean; + Matrix tempVar; + + ArrayList meansList = new ArrayList<>(); + ArrayList varsLilst = new ArrayList<>(); + ArrayList varMatList = new 
ArrayList<>(); + + for (int i = 0; i < weightsArray.length; i++) { + + // maximize weights + pSum = 0; + for (int j = 0; j < numCases; j++) { + pSum += gammaArray[i][j]; + } + + weight = (pSum - lambda) / (numCases - (lambda * weightsArray.length)); + weightsArray[i] = weight; + + // maximize covariance matrices + tempVar = new Matrix(numVars, numVars); + + for (int v = 0; v < numVars; v++) { + + // maximize means + meanNumerator = 0; + for (int j = 0; j < numCases; j++) { + + meanNumerator += gammaArray[i][j] * dataArray[j][v]; + } + mean = meanNumerator / pSum; + meansArray[i][v] = mean; + + for (int v2 = v; v2 < numVars; v2++) { + double var = Demixer.getVar(i, v, v2, numCases, gammaArray, dataArray, meansArray); + tempVar.set(v, v2, var); + tempVar.set(v2, v, var); + } + } + + Matrix varMatrix = new Matrix(tempVar); + if (varMatrix.det() != 0) { + variancesArray[i] = MatrixUtils.cholesky(tempVar); + variances[i] = MatrixUtils.cholesky(varMatrix); + } else { + variances[i] = MatrixUtils.cholesky(data.getCovarianceMatrix()); + variancesArray[i] = data.getCovarianceMatrix(); + } + + } + + System.out.println(); + + // check weights, and remove any clusters with weights below threshold + ArrayList weightsList = new ArrayList<>(); + + for (int i = 0; i < weightsArray.length; i++) { + + if (weightsArray[i] >= minWeight) { + weightsList.add(weightsArray[i]); + meansList.add(meansArray[i]); + varsLilst.add(variancesArray[i]); + varMatList.add(variances[i]); + } + } + + double[] tempWeightsArray = new double[weightsList.size()]; + double[][] tempMeansArray = new double[weightsList.size()][numVars]; + Matrix[] tempVarsArray = new Matrix[weightsList.size()]; + Matrix[] tempVariances = new Matrix[weightsList.size()]; + for (int i = 0; i < weightsList.size(); i++) { + tempWeightsArray[i] = weightsList.get(i); + tempMeansArray[i] = meansList.get(i); + tempVarsArray[i] = varsLilst.get(i); + tempVariances[i] = varMatList.get(i); + } + + weightsArray = tempWeightsArray; + meansArray = tempMeansArray; + variancesArray = tempVarsArray; + variances = tempVariances; + + return new DeterminingStats(meansArray, weightsArray, variancesArray, variances); + } + + /** + * Returns the value of the Normal PDF for a given case if it belongs to a given cluster + */ + private double normalPDF(int caseIndex, int weightIndex, Matrix[] variances, double[][] meansArray, double[][] dataArray, int numVars) { + Matrix cov = variances[weightIndex]; + cov = cov.transpose(); + + Matrix covIn = cov.inverse(); + double[] mu = meansArray[weightIndex]; + double[] thisCase = dataArray[caseIndex]; + + double[][] diffs = new double[1][numVars]; + + for (int i = 0; i < numVars; i++) { + diffs[0][i] = thisCase[i] - mu[i]; + } + + Matrix diffsMatrix = new Matrix(diffs); + Matrix mah = diffsMatrix.times(covIn); + + double val; + double mahScal = 0; + for (int i = 0; i < mah.getNumRows(); i++) { + for (int j = 0; j < mah.getNumColumns(); j++) { + val = mah.get(i, j); + val = val * val; + mahScal += val; + mah.set(i, j, val); + } + } + + double distanceScal = Math.pow(2 * Math.PI, -(numVars) / 2.0); + distanceScal = distanceScal / cov.det(); + distanceScal = distanceScal * Math.exp(-.5 * mahScal); + + return distanceScal; + } + + /** + * Private wrapper class for statistics to be maximized + */ + private static class DeterminingStats { + private final double[][] meansArray; + private final double[] weightsArray; + private final Matrix[] variancesArray; + private final Matrix[] varMatrixArray; + + public DeterminingStats(double[][] 
meansArray, double[] weightsArray, Matrix[] variancesArray, Matrix[] varMatrixArray) { + this.meansArray = meansArray; + this.weightsArray = weightsArray; + this.variancesArray = variancesArray; + this.varMatrixArray = varMatrixArray; + } + + public double[] getWeights() { + return weightsArray; + } + + public double[][] getMeans() { + return meansArray; + } + + public Matrix[] getVariances() { + return variancesArray; + } + + public Matrix[] getVarMatrixArray() { + return varMatrixArray; + } + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index fef052901c..bca58d5fcc 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -74,7 +74,7 @@ public void generateResults() { resultsIndep.clear(); resultsDep.clear(); - if (setType == ConditioningSetType.ALL_SUBSETS) { + if (setType == ConditioningSetType.GLOBAL_MARKOV) { AllSubsetsIndependenceFacts result = getAllSubsetsIndependenceFacts(graph); generateResultsAllSubsets(true, result.msep, result.mconn); generateResultsAllSubsets(false, result.msep, result.mconn); @@ -83,12 +83,29 @@ public void generateResults() { List nodes = new ArrayList<>(variables); Collections.sort(nodes); + List order = graph.paths().getValidOrder(graph.getNodes(), true); + for (Node x : nodes) { Set z; switch (setType) { - case PARENTS: + case LOCAL_MARKOV: + z = new HashSet<>(graph.getParents(x)); + break; + case ORDERED_LOCAL_MARKOV: + if (order == null) throw new IllegalArgumentException("No valid order found."); z = new HashSet<>(graph.getParents(x)); + + // Keep only the parents in Prefix(x). + for (Node w : new ArrayList<>(z)) { + int i1 = order.indexOf(x); + int i2 = order.indexOf(w); + + if (i2 >= i1) { + z.remove(w); + } + } + break; case MARKOV_BLANKET: z = GraphUtils.markovBlanket(x, graph); @@ -161,8 +178,8 @@ public static AllSubsetsIndependenceFacts getAllSubsetsIndependenceFacts(Graph g } public static class AllSubsetsIndependenceFacts { - public final List msep; - public final List mconn; + private final List msep; + private final List mconn; public AllSubsetsIndependenceFacts(List msep, List mconn) { this.msep = msep; @@ -189,6 +206,14 @@ public String toStringDep() { return builder.toString(); } + + public List getMsep() { + return msep; + } + + public List getMconn() { + return mconn; + } } /** @@ -536,12 +561,4 @@ private List getResultsLocal(boolean indep) { } - /** - * The type of conditioning set to use for the Markov check. The default is PARENTS, which uses the parents of the - * target variable to predict the separation set. DAG_MB uses the Markov blanket of the target variable in a DAG - * setting, and PAG_MB uses a Markov blanket of the target variable in a PAG setting. 
- */ - public enum ConditioningSetType { - PARENTS, MARKOV_BLANKET, ALL_SUBSETS - } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MixtureModel.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MixtureModel.java new file mode 100644 index 0000000000..8a1516ef38 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MixtureModel.java @@ -0,0 +1,207 @@ +package edu.cmu.tetrad.search; + +import edu.cmu.tetrad.data.BoxDataSet; +import edu.cmu.tetrad.data.CovarianceMatrixOnTheFly; +import edu.cmu.tetrad.data.DataSet; +import edu.cmu.tetrad.data.DoubleDataBox; +import edu.cmu.tetrad.search.score.SemBicScore; +import edu.cmu.tetrad.util.Matrix; + +/** + * Represents a Gaussian mixture model -- a dataset with data sampled from two or more multivariate Gaussian + * distributions. + * + * @author Madelyn Glymour + */ +public class MixtureModel { + private final DataSet data; + private final int[] cases; + private final int[] caseCounts; + private final double[][] dataArray; // v-by-n data matrix + private final double[][] meansArray; // k-by-v matrix representing means for each variable for each of k models + private final double[] weightsArray; // array of length k representing weights for each model + private final double[][] gammaArray; // k-by-n matrix representing gamma for each data case in each model + private final Matrix[] variancesArray; // k-by-v-by-v matrix representing covariance matrix for each of k models + private final int numModels; // number of models in mixture + + public MixtureModel(DataSet data, double[][] dataArray, double[][] meansArray, double[] weightsArray, Matrix[] variancesArray, double[][] gammaArray) { + this.data = data; + this.dataArray = dataArray; + this.meansArray = meansArray; + this.weightsArray = weightsArray; + this.variancesArray = variancesArray; + this.numModels = weightsArray.length; + this.gammaArray = gammaArray; + this.cases = new int[data.getNumRows()]; + + // set the individual model for each case + for (int i = 0; i < cases.length; i++) { + cases[i] = getDistribution(i); + } + + this.caseCounts = new int[numModels]; + + // count the number of cases in each individual data set + for (int i = 0; i < numModels; i++) { + caseCounts[i] = 0; + } + + for (int aCase : cases) { + for (int j = 0; j < numModels; j++) { + if (aCase == j) { + caseCounts[j]++; + break; + } + } + } + } + + /** + * @return the mixed data set in array form + */ + public double[][] getData() { + return dataArray; + } + + /** + * @return the means matrix + */ + public double[][] getMeans() { + return meansArray; + } + + /** + * @return the weights array + */ + public double[] getWeights() { + return weightsArray; + } + + /** + * @return the variance matrix + */ + public Matrix[] getVariances() { + return variancesArray; + } + + /** + * @return an array assigning each case an integer corresponding to a model + */ + public int[] getCases() { + return cases; + } + + /** + * Classifies a given case into a model, based on which model has the highest gamma value for that case. 
+ */ + public int getDistribution(int caseNum) { + + // hard classification + int dist = 0; + double highest = 0; + + for (int i = 0; i < numModels; i++) { + if (gammaArray[i][caseNum] > highest) { + highest = gammaArray[i][caseNum]; + dist = i; + } + + } + + return dist; + + // soft classification, deprecated because it doesn't classify as well + + /*int gammaSum = 0; + + for (int i = 0; i < k; i++) { + gammaSum += gammaArray[i][caseNum]; + } + + Random rand = new Random(); + double test = gammaSum * rand.nextDouble(); + + if(test < gammaArray[0][caseNum]){ + return 0; + } + + double sum = gammaArray[0][caseNum]; + + for (int i = 1; i < k; i++){ + sum = sum+gammaArray[i][caseNum]; + if(test < sum){ + return i; + } + } + + return k - 1; */ + } + + /* + * Sort the mixed data set into its component data sets. + * + * @return a list of data sets + */ + public DataSet[] getDemixedData() { + DoubleDataBox[] dataBoxes = new DoubleDataBox[numModels]; + int[] caseIndices = new int[numModels]; + + for (int i = 0; i < numModels; i++) { + dataBoxes[i] = new DoubleDataBox(caseCounts[i], data.getNumColumns()); + caseIndices[i] = 0; + } + + int index; + DoubleDataBox box; + int count; + for (int i = 0; i < cases.length; i++) { + + // get the correct data set and corresponding case count for this case + index = cases[i]; + box = dataBoxes[index]; + count = caseIndices[index]; + + // set the [count]th row of the given data set to the ith row of the mixed data set + for (int j = 0; j < data.getNumColumns(); j++) { + box.set(count, j, data.getDouble(i, j)); + } + + dataBoxes[index] = box; //make sure that the changes get carried to the next iteration of the loop + caseIndices[index] = count + 1; //increment case count of this data set + } + + // create list of data sets + DataSet[] dataSets = new DataSet[numModels]; + for (int i = 0; i < numModels; i++) { + dataSets[i] = new BoxDataSet(dataBoxes[i], data.getVariables()); + } + + return dataSets; + } + + /** + * Perform an FGES search on each of the demixed data sets. + * + * @return the BIC scores of the graphs returned by searches. 
+ */ + public double[] searchDemixedData() { + DataSet[] dataSets = getDemixedData(); + SemBicScore score; + edu.cmu.tetrad.search.Fges fges; + DataSet dataSet; + double bic; + double[] bicScores = new double[numModels]; + + for (int i = 0; i < numModels; i++) { + dataSet = dataSets[i]; + score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet)); + score.setPenaltyDiscount(2.0); + fges = new edu.cmu.tetrad.search.Fges(score); + fges.search(); + bic = fges.getModelScore(); + bicScores[i] = bic; + } + + return bicScores; + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/score/SemBicScorer.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/score/SemBicScorer.java index c6d78ba379..57c5979b6a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/score/SemBicScorer.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/score/SemBicScorer.java @@ -44,7 +44,7 @@ public static double scoreDag(Graph dag, DataModel data, double penaltyDiscount, SemBicScore score; if (data instanceof ICovarianceMatrix) { - score = new SemBicScore((ICovarianceMatrix) dag); + score = new SemBicScore((ICovarianceMatrix) data); } else if (data instanceof DataSet) { score = new SemBicScore((DataSet) data, precomputeCovariances); } else { @@ -69,7 +69,10 @@ public static double scoreDag(Graph dag, DataModel data, double penaltyDiscount, parentIndices[count++] = hashIndices.get(parent); } - _score += score.localScore(hashIndices.get(node), parentIndices); + double score1 = score.localScore(hashIndices.get(node), parentIndices); + if (!Double.isNaN(score1)) { + _score += score1; + } } return _score; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/GrowShrinkTree.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/GrowShrinkTree.java index 596f422052..2c9cec9924 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/GrowShrinkTree.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/GrowShrinkTree.java @@ -56,12 +56,22 @@ public Integer getIndex(Node node) { return this.index.get(node); } +// public Double localScore() { +// return this.score.localScore(this.nodeIndex); +// } +// +// public Double localScore(int[] X) { +// return this.score.localScore(this.nodeIndex, X); +// } + public Double localScore() { - return this.score.localScore(this.nodeIndex); + double score = this.score.localScore(this.nodeIndex); + return Double.isNaN(score) ? 0 : score; } public Double localScore(int[] X) { - return this.score.localScore(this.nodeIndex, X); + double score = this.score.localScore(this.nodeIndex, X); + return Double.isNaN(score) ? 
Double.NEGATIVE_INFINITY : score; } public boolean isRequired(Node node) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/LogUtilsSearch.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/LogUtilsSearch.java index da798c4473..3bbd503f04 100755 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/LogUtilsSearch.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/LogUtilsSearch.java @@ -21,14 +21,18 @@ package edu.cmu.tetrad.search.utils; +import edu.cmu.tetrad.algcomparison.statistic.BicEst; +import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphTransforms; import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.search.score.Score; import edu.cmu.tetrad.util.NumberFormatUtil; +import org.jetbrains.annotations.NotNull; import java.text.NumberFormat; -import java.util.Iterator; -import java.util.List; -import java.util.Set; +import java.util.*; /** * Contains utilities for logging search steps. @@ -139,6 +143,51 @@ public static String getScoreFact(Node i, List parents) { return fact.toString(); } + + public static Map buildIndexing(List nodes) { + Map hashIndices = new HashMap<>(); + + int i = -1; + + for (Node n : nodes) { + hashIndices.put(n, ++i); + } + + return hashIndices; + } + + @NotNull + public static void stampWithScores(Graph graph, DataModel dataModel, Score score) { + if (!graph.getAllAttributes().containsKey("Score")) { + Graph dag = GraphTransforms.dagFromCPDAG(graph); + Map hashIndices = buildIndexing(dag.getNodes()); + + double _score = 0.0; + + for (Node node : dag.getNodes()) { + List x = dag.getParents(node); + + int[] parentIndices = new int[x.size()]; + + int count = 0; + for (Node parent : x) { + parentIndices[count++] = hashIndices.get(parent); + } + + _score += score.localScore(hashIndices.get(node), parentIndices); + } + + graph.addAttribute("Score", _score); + } + + stampWithBic(graph, dataModel); + } + + public static void stampWithBic(Graph graph, DataModel dataModel) { + if (!graph.getAllAttributes().containsKey("BIC")) { + graph.addAttribute("BIC", new BicEst().getValue(null, graph, dataModel)); + } + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/TeyssierScorer.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/TeyssierScorer.java index c1175ce252..79584d3823 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/TeyssierScorer.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/TeyssierScorer.java @@ -625,8 +625,7 @@ private double sum() { if (this.scores.get(i) == null) { recalculate(i); } - double score1 = this.scores.get(i).getScore(); - score += score1; + score += this.scores.get(i).getScore(); } return score; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/FasLofs.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/FasLofs.java index 55584e6f7b..88db208200 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/FasLofs.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/FasLofs.java @@ -41,7 +41,7 @@ * generalized. Instead of hard-coding FAS, an arbitrary algorithm can be used to obtain adjacencies. Instead of * hard-coding robust skew, and arbitrary algorithm can be used to to pairwise orientation. Instead of orienting all * edges, an option can be given to just orient the edges that are unoriented in the input graph (see, e.g., PC LiNGAM). 
- * This was an early attempt at this. For PC-LiNGAM, see this paper:
+ * This was an early attempt at this. For BOSS-LiNGAM, see this paper:
  *
  * Hoyer, P. O., Hyvarinen, A., Scheines, R., Spirtes, P. L., Ramsey, J., Lacerda, G.,
  * &amp; Shimizu, S. (2012). Causal discovery of linear acyclic models with arbitrary distributions. arXiv preprint
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFci.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFci.java
index 5141a62e56..61c99d1c38 100644
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFci.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFci.java
@@ -161,7 +161,7 @@ public void testSearch11() {
         knowledge.addToTier(2, "X3");

         checkSearch("Latent(L1),Latent(L2),L1-->X1,L1-->X2,L2-->X2,L2-->X3",
-                "X1<->X2,X2<->X3", knowledge);
+                "X1o->X2,X2<->X3", knowledge);
     }

     @Test
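For reference on the Demixer/MixtureModel changes above: the E-step is computing standard Gaussian-mixture responsibilities, with weightsArray[k] playing the role of w_k, meansArray[k] of mu_k, the per-cluster covariance of Sigma_k, and gammaArray[k][j] of the responsibility of cluster k for case j. The formulas below are the textbook quantities only; they are not a claim about the exact constants or Cholesky bookkeeping used inside normalPDF.

```latex
N(x_j \mid \mu_k, \Sigma_k) \;=\; (2\pi)^{-v/2}\,\lvert\Sigma_k\rvert^{-1/2}
  \exp\!\Big(-\tfrac{1}{2}\,(x_j-\mu_k)^\top \Sigma_k^{-1} (x_j-\mu_k)\Big),
\qquad
\gamma_{kj} \;=\; \frac{w_k\,N(x_j \mid \mu_k, \Sigma_k)}{\sum_{m=1}^{K} w_m\,N(x_j \mid \mu_m, \Sigma_m)}.
```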
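The E-step comment in Demixer ("I should probably make a separate method for it") suggests factoring the gamma computation out. A minimal sketch of such a helper, meant to live inside Demixer and reusing its existing normalPDF and arrays exactly as they appear in the diff; the method name responsibility is hypothetical:

```java
// Hypothetical refactor of the repeated E-step code: the responsibility of
// cluster z for case j is its weighted density divided by the sum of the
// weighted densities over all clusters. Uses the class's existing normalPDF.
private double responsibility(int j, int z, double[] weightsArray, Matrix[] variances,
                              double[][] meansArray, double[][] dataArray, int numVars) {
    double numerator = weightsArray[z] * normalPDF(j, z, variances, meansArray, dataArray, numVars);
    double divisor = numerator;

    for (int w = 0; w < weightsArray.length; w++) {
        if (w != z) {
            divisor += weightsArray[w] * normalPDF(j, w, variances, meansArray, dataArray, numVars);
        }
    }

    return numerator / divisor;
}
```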
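On the MarkovCheck side, the new ORDERED_LOCAL_MARKOV branch conditions each variable x on only those parents that precede x in a valid causal order. The same conditioning-set computation, written as a small self-contained sketch (the class and method names are illustrative only; the graph calls are the ones used in the diff):

```java
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;

import java.util.HashSet;
import java.util.List;
import java.util.Set;

final class OrderedLocalMarkovExample {

    // For ordered local Markov, condition x on Parents(x) intersected with
    // Prefix(x): drop any parent that does not come before x in the order.
    static Set<Node> orderedLocalMarkovSet(Graph graph, Node x, List<Node> order) {
        Set<Node> z = new HashSet<>(graph.getParents(x));
        int xIndex = order.indexOf(x);
        z.removeIf(w -> order.indexOf(w) >= xIndex);
        return z;
    }
}
```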