Skip to content

Commit 942cdec

Browse files
Gradient descent optimizer javadocs
1 parent aed031f commit 942cdec

1 file changed

Lines changed: 49 additions & 0 deletions

File tree

src/GradientDescentOptimizer.java

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,42 @@
77
import javax.swing.JPanel;
88
import javax.swing.JTextField;
99

10+
/**
11+
* The GradientDescentOptimizer class provides functionality to optimize the coefficients
12+
* for a linear combination of features using the gradient descent algorithm. This optimization
13+
* aims to maximize the separability between classes in a dataset.
14+
*/
1015
public class GradientDescentOptimizer {
1116

1217
private final CsvViewer csvViewer;
1318
private final double learningRate;
1419
private final int maxIterations;
1520
private final double tolerance;
1621

22+
/**
23+
* Constructs a GradientDescentOptimizer with the specified parameters.
24+
*
25+
* @param csvViewer the CsvViewer instance that manages the data and UI.
26+
* @param learningRate the learning rate for the gradient descent algorithm.
27+
* @param maxIterations the maximum number of iterations for the optimization process.
28+
* @param tolerance the tolerance for convergence in the optimization process.
29+
*/
1730
public GradientDescentOptimizer(CsvViewer csvViewer, double learningRate, int maxIterations, double tolerance) {
1831
this.csvViewer = csvViewer;
1932
this.learningRate = learningRate;
2033
this.maxIterations = maxIterations;
2134
this.tolerance = tolerance;
2235
}
2336

37+
/**
38+
* Optimizes the coefficients for the linear combination using gradient descent.
39+
* The optimized coefficients are then updated in the provided JPanel.
40+
*
41+
* @param originalColumnIndices the list of indices corresponding to the original columns in the dataset.
42+
* @param coefficients the list of coefficients to be optimized.
43+
* @param panel the JPanel containing the UI components for coefficient inputs.
44+
* @param trigFunction the trigonometric function to apply to the linear combination.
45+
*/
2446
public void optimizeCoefficientsUsingGradientDescent(List<Integer> originalColumnIndices, List<Double> coefficients, JPanel panel, String trigFunction) {
2547
initializeCoefficients(coefficients);
2648

@@ -54,6 +76,11 @@ public void optimizeCoefficientsUsingGradientDescent(List<Integer> originalColum
5476
updatePanelFields(coefficients, panel);
5577
}
5678

79+
/**
80+
* Initializes the coefficients to a default value of 1.0 if they are not already set.
81+
*
82+
* @param coefficients the list of coefficients to initialize.
83+
*/
5784
private void initializeCoefficients(List<Double> coefficients) {
5885
for (int i = 0; i < coefficients.size(); i++) {
5986
if (coefficients.get(i) == null) {
@@ -62,6 +89,15 @@ private void initializeCoefficients(List<Double> coefficients) {
6289
}
6390
}
6491

92+
/**
93+
* Evaluates the class separability using the specified coefficients and trigonometric function.
94+
* The separability is measured as the ratio of between-class variance to within-class variance.
95+
*
96+
* @param originalColumnIndices the list of indices corresponding to the original columns in the dataset.
97+
* @param coefficients the array of coefficients for the linear combination.
98+
* @param trigFunction the trigonometric function to apply to the linear combination.
99+
* @return the class separability score.
100+
*/
65101
private double evaluateClassSeparation(List<Integer> originalColumnIndices, double[] coefficients, String trigFunction) {
66102
Map<String, List<Double>> classSums = new HashMap<>();
67103
int classColumnIndex = csvViewer.getClassColumnIndex();
@@ -106,6 +142,13 @@ private double evaluateClassSeparation(List<Integer> originalColumnIndices, doub
106142
return betweenClassVariance / withinClassVariance;
107143
}
108144

145+
/**
146+
* Applies the specified trigonometric function to a given value.
147+
*
148+
* @param value the value to which the trigonometric function is applied.
149+
* @param trigFunction the trigonometric function to apply.
150+
* @return the result of applying the trigonometric function to the value.
151+
*/
109152
private double applyTrigFunction(double value, String trigFunction) {
110153
switch (trigFunction) {
111154
case "cos":
@@ -126,6 +169,12 @@ private double applyTrigFunction(double value, String trigFunction) {
126169
}
127170
}
128171

172+
/**
173+
* Updates the text fields in the provided JPanel with the optimized coefficients.
174+
*
175+
* @param coefficients the list of optimized coefficients.
176+
* @param panel the JPanel containing the text fields to update.
177+
*/
129178
private void updatePanelFields(List<Double> coefficients, JPanel panel) {
130179
for (int i = 0; i < coefficients.size(); i++) {
131180
JTextField coefficientField = (JTextField) panel.getComponent(2 * i + 1);

0 commit comments

Comments
 (0)