Skip to content

Commit 68cc3c6

Browse files
enchancement: prevent Data Leakage, not sending testdata for bestK
1 parent 4e330dd commit 68cc3c6

2 files changed

Lines changed: 10 additions & 6 deletions

File tree

ConsoleApp2/Analyser.cs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ public static IEstimator<ITransformer> FeaturizeText(MLContext mlContext)
6969
return mlContext.Transforms.Text.FeaturizeText(FeaturesColumnName, nameof(CommitMLData.CommitName));
7070
}
7171

72-
public static int GetOrFindBestK(MLContext mlContext, IDataView trainData, IDataView testData, IEstimator<ITransformer> featurizer, string kFilePath)
72+
public static int GetOrFindBestK(MLContext mlContext, IDataView trainData, IEstimator<ITransformer> featurizer, string kFilePath)
7373
{
7474
if (File.Exists(kFilePath))
7575
{
@@ -80,16 +80,20 @@ public static int GetOrFindBestK(MLContext mlContext, IDataView trainData, IData
8080
}
8181
}
8282

83-
Console.WriteLine("Finding best K via Grid Search...");
83+
Console.WriteLine("Finding best K via Grid Search using validation split...");
84+
var split = mlContext.Data.TrainTestSplit(trainData, testFraction: 0.2);
85+
var subTrainData = split.TrainSet;
86+
var validationData = split.TestSet;
87+
8488
int bestK = 2;
8589
double bestMetric = double.MaxValue; // Lower Davies-Bouldin is better for measuring clustering quality
8690

8791
for (int k = 2; k <= 10; k++)
8892
{
8993
var pipeline = featurizer.Append(mlContext.Clustering.Trainers.KMeans(featureColumnName: FeaturesColumnName, numberOfClusters: k));
90-
var model = pipeline.Fit(trainData);
91-
92-
var predictions = model.Transform(testData);
94+
var model = pipeline.Fit(subTrainData);
95+
96+
var predictions = model.Transform(validationData);
9397
var metrics = mlContext.Clustering.Evaluate(predictions, labelColumnName: null, scoreColumnName: "Score", featureColumnName: FeaturesColumnName);
9498

9599
Console.WriteLine($"K = {k} | Davies-Bouldin: {metrics.DaviesBouldinIndex:F4} | Avg Distance: {metrics.AverageDistance:F4}");

ConsoleApp2/Program.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ static async Task Main(string[] args)
3535
var featurizer = Analyser.FeaturizeText(mlContext);
3636

3737
// 4. Find Best K using Grid Search or load from file
38-
int bestK = Analyser.GetOrFindBestK(mlContext, split.TrainSet, split.TestSet, featurizer, kFilePath);
38+
int bestK = Analyser.GetOrFindBestK(mlContext, split.TrainSet, featurizer, kFilePath);
3939

4040
// 5. Train KMeans Clusterer
4141
var model = Analyser.TrainKMeansClusterer(mlContext, split.TrainSet, featurizer, bestK);

0 commit comments

Comments
 (0)