Skip to content

Commit 77a9a98

Browse files
Birdasaursamypr100
andauthored
Fixes to logP and cov functions for GMMs. (#133)
Co-authored-by: samypr100 <3933065+samypr100@users.noreply.github.com>
1 parent f567802 commit 77a9a98

9 files changed

Lines changed: 93 additions & 48 deletions

File tree

build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ assemble {
367367
jlink {
368368
def fileSep = System.getProperty('file.separator')
369369
def imageZipFile = layout.buildDirectory.file("${artifactNameUpper}-${project.version}.zip")
370-
options.set(['--strip-debug', '--compress', '2', '--no-header-files', '--no-man-pages'])
370+
options.set(['--strip-debug', '--compress', 'zip-6', '--no-header-files', '--no-man-pages'])
371371
imageZip.set(imageZipFile)
372372
launcher {
373373
def currentOS = org.gradle.internal.os.OperatingSystem.current()

src/main/java/edu/jhuapl/trinity/javafx/components/panes/Shape3DControlPane.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ private void buildFindClustersTab() {
121121
findClustersTab.setContent(findClusterBorderPane);
122122

123123
componentsSpinner = new Spinner(
124-
new SpinnerValueFactory.IntegerSpinnerValueFactory(2, 20, 5, 1));
124+
new SpinnerValueFactory.IntegerSpinnerValueFactory(2, 500, 5, 1));
125125
componentsSpinner.setPrefWidth(SPINNER_PREF_WIDTH);
126126
componentsSpinner.setEditable(true);
127127
iterationsSpinner = new Spinner(

src/main/java/edu/jhuapl/trinity/javafx/handlers/ManifoldEventHandler.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -298,11 +298,6 @@ public void handleExport(ManifoldEvent event) {
298298
}
299299

300300
public void handleNewManifoldData(ManifoldEvent event) {
301-
// Platform.runLater(() -> {
302-
// App.getAppScene().getRoot().fireEvent(
303-
// new CommandTerminalEvent("Loading Manifold Data...",
304-
// new Font("Consolas", 20), Color.GREEN));
305-
// });
306301
// System.out.println("Loading Manifold Data...");
307302
ManifoldData md = (ManifoldData) event.object1;
308303
//convert deserialized points to Fxyz3D point3ds

src/main/java/edu/jhuapl/trinity/javafx/javafx3d/AsteroidFieldPane.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ private void resetAsteroids() {
757757

758758
public Manifold3D makeHull(List<Point3D> labelMatchedPoints, String label, Double tolerance) {
759759
Manifold3D manifold3D = new Manifold3D(
760-
labelMatchedPoints, true, true, true, tolerance
760+
labelMatchedPoints, true, false, false, tolerance
761761
);
762762
manifold3D.quickhullMeshView.setCullFace(CullFace.FRONT);
763763
// manifold3D.addEventHandler(MouseEvent.MOUSE_CLICKED, e -> {

src/main/java/edu/jhuapl/trinity/javafx/javafx3d/Manifold3D.java

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ public class Manifold3D extends Group {
7777

7878
public Manifold3D(List<Point3D> point3DList, boolean triangulate, boolean makeLines, boolean makePoints, Double tolerance) {
7979
originalPoint3Ds = point3DList;
80-
buildHullMesh(point3DList, triangulate, makeLines, makePoints, tolerance);
80+
buildHullMesh(point3DList, triangulate, tolerance);
8181

8282
List<Point3D> fxyzPoints = new ArrayList<>();
8383
for (int i = 0; i < hull.getNumVertices(); i++) {
@@ -349,7 +349,7 @@ public void refreshMesh(List<Point3D> point3DList, boolean triangulate, boolean
349349
quickhullLinesTriangleMesh.getPoints().clear();
350350
quickhullLinesTriangleMesh.getTexCoords().clear();
351351
quickhullLinesTriangleMesh.getFaces().clear();
352-
buildHullMesh(point3DList, triangulate, makeLines, makePoints, tolerance);
352+
buildHullMesh(point3DList, triangulate, tolerance);
353353
quickhullMeshView.setMesh(quickhullTriangleMesh);
354354
if (makeLines) {
355355
quickhullLinesTriangleMesh.getPoints().addAll(quickhullTriangleMesh.getPoints());
@@ -361,7 +361,7 @@ public void refreshMesh(List<Point3D> point3DList, boolean triangulate, boolean
361361
// makeDebugPoints(hull, artScale, false);
362362
}
363363

364-
private void buildHullMesh(List<Point3D> point3DList, boolean triangulate, boolean makeLines, boolean makePoints, Double tolerance) {
364+
private void buildHullMesh(List<Point3D> point3DList, boolean triangulate, Double tolerance) {
365365
hull = new QuickHull3D();
366366
if (null != tolerance)
367367
hull.setExplicitDistanceTolerance(tolerance);
@@ -455,19 +455,21 @@ public void handle(long now) {
455455
}
456456

457457
public void makeLines() {
458+
boolean wasVisible = null != quickhullLinesMeshView
459+
? quickhullLinesMeshView.isVisible() : false;
458460
quickhullLinesTriangleMesh = new TriangleMesh();
459461
quickhullLinesTriangleMesh.getPoints().addAll(quickhullTriangleMesh.getPoints());
460462
quickhullLinesTriangleMesh.getTexCoords().addAll(quickhullTriangleMesh.getTexCoords());
461463
quickhullLinesTriangleMesh.getFaces().addAll(quickhullTriangleMesh.getFaces());
462464

463465
quickhullLinesMeshView = new MeshView(quickhullLinesTriangleMesh);
464-
PhongMaterial quickhullLinesMaterial = new PhongMaterial(Color.BLUE);
465-
quickhullLinesMaterial.setSpecularColor(Color.BLUE); //fix for aarch64 Mac Ventura
466+
PhongMaterial quickhullLinesMaterial = new PhongMaterial(Color.ALICEBLUE);
467+
quickhullLinesMaterial.setSpecularColor(Color.ALICEBLUE); //fix for aarch64 Mac Ventura
466468
quickhullLinesMeshView.setMaterial(quickhullLinesMaterial);
467469
quickhullLinesMeshView.setDrawMode(DrawMode.LINE);
468470
quickhullLinesMeshView.setCullFace(CullFace.NONE);
469471
quickhullLinesMeshView.setMouseTransparent(true);
470-
472+
quickhullLinesMeshView.setVisible(wasVisible);
471473
getChildren().add(quickhullLinesMeshView);
472474
}
473475

@@ -489,14 +491,15 @@ public void makeDebugPoints(QuickHull3D hull, float scale, boolean print) {
489491
sb.append(", ");
490492
}
491493

492-
Sphere sphere = new Sphere(2.5);
493-
PhongMaterial mat = new PhongMaterial(Color.BLUE);
494-
mat.setSpecularColor(Color.BLUE); // fix for aarch64 Mac Ventura
494+
Sphere sphere = new Sphere(1.5);
495+
PhongMaterial mat = new PhongMaterial(Color.ALICEBLUE);
496+
mat.setSpecularColor(Color.ALICEBLUE); // fix for aarch64 Mac Ventura
495497
sphere.setMaterial(mat);
496498
sphere.setTranslateX(point3D.x);
497499
sphere.setTranslateY(point3D.y);
498500
sphere.setTranslateZ(point3D.z);
499501
extrasGroup.getChildren().add(sphere);
502+
sphere.setVisible(false);
500503

501504
Label newLabel = new Label(String.valueOf(i));
502505
labelGroup.getChildren().addAll(newLabel);

src/main/java/edu/jhuapl/trinity/utils/clustering/ClusterUtils.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import org.slf4j.Logger;
44
import org.slf4j.LoggerFactory;
55

6+
import java.util.ArrayList;
7+
import java.util.List;
8+
69
import static java.lang.Math.abs;
710
import static java.lang.Math.sqrt;
811

@@ -56,6 +59,29 @@ public static double squaredDistanceWithMissingValues(double[] x, double[] y) {
5659
return dist;
5760
}
5861

62+
public static List<List<double[]>> extractGMMClusters(
63+
double[][] data,
64+
GaussianMixture gmm,
65+
double threshold) {
66+
List<List<double[]>> clusterPoints = new ArrayList<>();
67+
for (int i = 0; i < gmm.components.length; i++) {
68+
clusterPoints.add(new ArrayList<>());
69+
}
70+
71+
for (double[] x : data) {
72+
double[] post = gmm.posteriori(x);
73+
int k = ClusterUtils.whichMax(post);
74+
if (post[k] >= threshold) {
75+
clusterPoints.get(k).add(x);
76+
}
77+
}
78+
79+
// Filter out tiny/degenerate clusters if needed
80+
clusterPoints.removeIf(list -> list.size() < 4);
81+
82+
return clusterPoints;
83+
}
84+
5985
/**
6086
* Returns the sum of an array.
6187
*

src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianDistribution.java

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -200,18 +200,20 @@ public double scatter() {
200200
return sigmaDet;
201201
}
202202

203-
public double logp(double[] x) {
204-
if (x.length != dim) {
205-
throw new IllegalArgumentException("Sample has different dimension.");
206-
}
207-
203+
public double mahalanobis2(double[] x) {
208204
double[] v = x.clone();
209205
ClusterUtils.sub(v, mu);
210-
// double result = sigmaInv.xAx(v) / -2.0;
211-
// double[] Ax = mv(x);
212-
double[] Ax = sigmaInv.operate(v);
213-
double result = ClusterUtils.dot(x, Ax) / -2.0;
214-
return result - pdfConstant;
206+
double[] Av = sigmaInv.operate(v);
207+
return ClusterUtils.dot(v, Av);
208+
}
209+
210+
public double logp(double[] x) {
211+
if (x.length != dim) throw new IllegalArgumentException("Sample has different dimension.");
212+
double[] v = x.clone();
213+
ClusterUtils.sub(v, mu); // v = x - μ
214+
double[] Av = sigmaInv.operate(v); // Σ⁻¹ v
215+
double quad = ClusterUtils.dot(v, Av); // vᵀ Σ⁻¹ v
216+
return -0.5 * quad - pdfConstant;
215217
}
216218

217219
public double p(double[] x) {
@@ -472,6 +474,10 @@ public double logLikelihood(double[][] x) {
472474
return L;
473475
}
474476

477+
public int dim() {
478+
return dim;
479+
}
480+
475481
@Override
476482
public String toString() {
477483
return String.format("Gaussian(mu = %s, sigma = %s)", Arrays.toString(mu), sigma);

src/main/java/edu/jhuapl/trinity/utils/clustering/GaussianMixture.java

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -240,30 +240,45 @@ public double[] mean() {
240240
}
241241

242242
public RealMatrix cov() {
243-
double w = components[0].priori();
244-
RealMatrix v = components[0].distribution().cov();
245-
246-
int m = v.getRowDimension();
247-
int n = v.getColumnDimension();
248-
RealMatrix cov = MatrixUtils.createRealMatrix(m, n);
243+
double[] mu = mean();
244+
int d = mu.length;
245+
RealMatrix C = MatrixUtils.createRealMatrix(d, d);
249246

250-
for (int i = 0; i < m; i++) {
251-
for (int j = 0; j < n; j++) {
252-
cov.setEntry(i, j, w * w * v.getEntry(i, j));
253-
}
247+
// within-component variance
248+
for (GaussianMixtureComponent c : components) {
249+
double w = c.priori();
250+
RealMatrix Sk = c.distribution().cov();
251+
C = C.add(Sk.scalarMultiply(w));
254252
}
255253

256-
for (int k = 1; k < components.length; k++) {
257-
w = components[k].priori();
258-
v = components[k].distribution().cov();
259-
for (int i = 0; i < m; i++) {
260-
for (int j = 0; j < n; j++) {
261-
cov.addToEntry(i, j, w * w * v.getEntry(i, j));
262-
}
263-
}
254+
// between-component variance
255+
for (GaussianMixtureComponent c : components) {
256+
double w = c.priori();
257+
double[] mk = c.distribution().mean();
258+
double[] diff = mk.clone();
259+
ClusterUtils.sub(diff, mu);
260+
RealMatrix outer = MatrixUtils.createColumnRealMatrix(diff)
261+
.multiply(MatrixUtils.createRowRealMatrix(diff));
262+
C = C.add(outer.scalarMultiply(w));
264263
}
264+
return C;
265+
}
266+
267+
public boolean inDistribution(double[] x, double q) {
268+
// choose most responsible component
269+
double[] r = posteriori(x);
270+
int idx = ClusterUtils.whichMax(r);
271+
GaussianMixtureComponent c = components[idx];
272+
double d2 = c.distribution().mahalanobis2(x);
273+
// chi-square threshold
274+
org.apache.commons.math3.distribution.ChiSquaredDistribution chi =
275+
new org.apache.commons.math3.distribution.ChiSquaredDistribution(c.distribution().dim());
276+
double thresh = chi.inverseCumulativeProbability(q);
277+
return d2 <= thresh;
278+
}
265279

266-
return cov;
280+
public boolean inDistributionByLogP(double[] x, double tau) {
281+
return Math.log(p(x)) >= tau;
267282
}
268283

269284
public Pair<Integer, Double> maxPostProb(double[] x) {

src/main/resources/edu/jhuapl/trinity/fxml/ManifoldControl.fxml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -450,9 +450,9 @@
450450
</HBox>
451451
<HBox alignment="CENTER" spacing="10.0" GridPane.rowIndex="2">
452452
<children>
453-
<CheckBox fx:id="showWireframeCheckBox" mnemonicParsing="false" selected="true"
453+
<CheckBox fx:id="showWireframeCheckBox" mnemonicParsing="false" selected="false"
454454
text="Show Wire Frame"/>
455-
<CheckBox fx:id="showControlPointsCheckBox" mnemonicParsing="false" selected="true"
455+
<CheckBox fx:id="showControlPointsCheckBox" mnemonicParsing="false" selected="false"
456456
text="Show Control Points"/>
457457
</children>
458458
</HBox>

0 commit comments

Comments
 (0)