Skip to content

Commit 36ee3c3

Browse files
committed
Add MST filter to full data pipeline
1 parent cad715b commit 36ee3c3

2 files changed

Lines changed: 142 additions & 27 deletions

File tree

DVRViewPlugin/src/VolumeRenderer.cpp

Lines changed: 124 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -859,7 +859,7 @@ void VolumeRenderer::batchSearch(
859859
resultQueue.pop();
860860
}
861861

862-
QVector2D meanPos = ComputeMeanOfNN(answers, k, positionData, useWeightedMean);
862+
QVector2D meanPos = ComputeMeanOfNN(answers, k, positionData);
863863
meanPositionData[i * 2] = meanPos.x();
864864
meanPositionData[i * 2 + 1] = meanPos.y();
865865

@@ -962,7 +962,7 @@ void VolumeRenderer::getGPUFullDataModeBatches(std::vector<float>& frontfacesDat
962962
batches[batchIndex].push_back(idx);
963963

964964
// Compute the number of samples along this ray.
965-
int sampleCount = static_cast<int>(rayLength / _stepSize);
965+
int sampleCount = std::ceil(rayLength / _stepSize);
966966

967967
batchRaySampleAmount[batchIndex].push_back(sampleCount);
968968
// Update the batch total (in bytes) for partitioning.
@@ -979,7 +979,7 @@ void VolumeRenderer::getGPUFullDataModeBatches(std::vector<float>& frontfacesDat
979979
// Combine as many of the small batches as can possibly fit in the indicated GPU memory ---
980980

981981
// Calculate available GPU memory for the batch transfer
982-
size_t availableMemoryInBytes = std::min(int(_fullGPUMemorySize - _fullDataMemorySize - 100000), 2000000); // ~100MB reserved for other data
982+
size_t availableMemoryInBytes = std::min(size_t(_fullGPUMemorySize - _fullDataMemorySize - 100000), (size_t(2 * 1024 * 1024) * 1024)); // ~100MB reserved for other data
983983
if (availableMemoryInBytes < 0 || availableMemoryInBytes < maxBatchMemory)
984984
throw std::runtime_error("Not enough GPU memory available for the GPU-CPU batch transfer.");
985985

@@ -1153,27 +1153,26 @@ void VolumeRenderer::retrieveBatchFullData(std::vector<float>& cpuOutput, int ba
11531153

11541154
// TODO : This function should be moved to a more appropriate location, as it is not specific to the VolumeRenderer class.
11551155
// Compute the unweighted mean of a std::vector<QVector2D>
1156-
QVector2D computeMean(const std::vector<QVector2D>& points) {
1157-
if (points.empty())
1158-
return QVector2D(0, 0);
1159-
1160-
QVector2D sum = std::accumulate(points.begin(), points.end(), QVector2D(0, 0));
1161-
return sum / static_cast<float>(points.size());
1156+
QVector2D computeMean(const std::vector<QVector2D>& points,
1157+
const std::vector<int>& indices) {
1158+
QVector2D sum(0, 0);
1159+
for (int idx : indices) sum += points[idx];
1160+
return sum / float(indices.size());
11621161
}
11631162

1163+
11641164
// TODO : This function should be moved to a more appropriate location, as it is not specific to the VolumeRenderer class.
11651165
// Compute the weighted mean of a std::vector<QVector2D> given corresponding weight values.
1166-
QVector2D computeWeightedMean(const std::vector<QVector2D>& points, const std::vector<float>& weights) {
1167-
if (points.empty() || points.size() != weights.size())
1168-
return QVector2D(0, 0);
1169-
1170-
QVector2D weightedSum(0, 0);
1171-
float totalWeight = 0.0f;
1172-
for (size_t i = 0; i < points.size(); ++i) {
1173-
weightedSum += points[i] * weights[i];
1174-
totalWeight += weights[i];
1166+
QVector2D computeWeightedMean(const std::vector<QVector2D>& points,
1167+
const std::vector<float>& weights,
1168+
const std::vector<int>& indices) {
1169+
QVector2D sum(0, 0);
1170+
float weightSum = 0.0f;
1171+
for (int idx : indices) {
1172+
sum += points[idx] * weights[idx];
1173+
weightSum += weights[idx];
11751174
}
1176-
return (totalWeight > 0.0f) ? (weightedSum / totalWeight) : QVector2D(0, 0);
1175+
return sum / (weightSum);
11771176
}
11781177

11791178
// This function renders the full data to the screen using the composite shader.
@@ -1312,9 +1311,10 @@ void VolumeRenderer::renderBatchToScreen(int batchIndex, uint32_t sampleDim, std
13121311
// @param neighbours: A vector of pairs containing the distance and label of each neighbour.
13131312
// @param k: The number of neighbours to consider.
13141313
// @param positionData: A vector containing the 2D positions of the neighbours.
1315-
// @param useWeightedMean: A boolean indicating whether to use a weighted mean or not.
1316-
QVector2D VolumeRenderer::ComputeMeanOfNN(const std::vector<std::pair<float, hnswlib::labeltype>>& neighbors, int k, const std::vector<float>& positionData, bool useWeightedMean) {
1314+
QVector2D VolumeRenderer::ComputeMeanOfNN(const std::vector<std::pair<float, hnswlib::labeltype>>& neighbors, int k, const std::vector<float>& positionData) {
13171315
float epsilon = 1.0f; // To avoid division by zero and limit the impact of very close neighbours.
1316+
float clusterSlack = 0.1f; // Slack for cluster thresholding, can be adjusted based on the dataset.
1317+
13181318
std::vector<float> weights(k, 0.0f);
13191319
std::vector<QVector2D> candidatePositions(k, QVector2D(0.0f, 0.0f));
13201320
int j = 0;
@@ -1326,14 +1326,111 @@ QVector2D VolumeRenderer::ComputeMeanOfNN(const std::vector<std::pair<float, hns
13261326
candidatePositions[j] = QVector2D(posX, posY);
13271327
j++;
13281328
}
1329+
1330+
// 2) Optionally extract largest cluster via MST + relative-threshold
1331+
std::vector<int> chosenIndices;
1332+
chosenIndices.reserve(k);
1333+
1334+
if (useLargestCluster && k > 1) {
1335+
// 2a) Build full distance matrix
1336+
std::vector<float> distanceMatrix(k * k);
1337+
for (int a = 0; a < k; ++a) {
1338+
distanceMatrix[a * k + a] = 0.0f;
1339+
for (int b = a + 1; b < k; ++b) {
1340+
float dx = candidatePositions[a].x() - candidatePositions[b].x();
1341+
float dy = candidatePositions[a].y() - candidatePositions[b].y();
1342+
float d = std::sqrt(dx * dx + dy * dy);
1343+
distanceMatrix[a * k + b] = d;
1344+
distanceMatrix[b * k + a] = d;
1345+
}
1346+
}
1347+
1348+
// 2b) Prim’s MST: collect k-1 smallest edges
1349+
std::vector<bool> inTree(k, false);
1350+
std::vector<float> minEdgeToTree(k, FLT_MAX);
1351+
std::vector<int> mstParent(k, -1);
1352+
struct MstEdge { float weight; int u, v; };
1353+
std::vector<MstEdge> mstEdges;
1354+
mstEdges.reserve(k - 1);
1355+
1356+
inTree[0] = true;
1357+
for (int v = 1; v < k; ++v) {
1358+
minEdgeToTree[v] = distanceMatrix[v];
1359+
mstParent[v] = 0;
1360+
}
1361+
1362+
for (int e = 0; e < k - 1; ++e) {
1363+
// pick frontier vertex with smallest connecting edge
1364+
int bestV = -1;
1365+
float bestW = FLT_MAX;
1366+
for (int v = 0; v < k; ++v) {
1367+
if (!inTree[v] && minEdgeToTree[v] < bestW) {
1368+
bestW = minEdgeToTree[v];
1369+
bestV = v;
1370+
}
1371+
}
1372+
mstEdges.push_back({ bestW, mstParent[bestV], bestV });
1373+
inTree[bestV] = true;
1374+
1375+
// update frontier
1376+
for (int v = 0; v < k; ++v) {
1377+
float w = distanceMatrix[bestV * k + v];
1378+
if (!inTree[v] && w < minEdgeToTree[v]) {
1379+
minEdgeToTree[v] = w;
1380+
mstParent[v] = bestV;
1381+
}
1382+
}
1383+
}
1384+
1385+
// 2c) Determine threshold: smallest MST edge + slack
1386+
float minMstWeight = FLT_MAX;
1387+
for (auto& e : mstEdges) {
1388+
minMstWeight = std::min(minMstWeight, e.weight);
1389+
}
1390+
float threshold = minMstWeight + clusterSlack;
1391+
1392+
// 2d) Cut edges > threshold and form clusters via union-find
1393+
UnionFind uf(k);
1394+
for (auto& e : mstEdges) {
1395+
if (e.weight <= threshold) {
1396+
uf.unify(e.u, e.v);
1397+
}
1398+
}
1399+
1400+
// 2e) Group vertices by root to get clusters
1401+
std::unordered_map<int, std::vector<int>> clusters;
1402+
clusters.reserve(k);
1403+
for (int v = 0; v < k; ++v) {
1404+
clusters[uf.findRoot(v)].push_back(v);
1405+
}
1406+
1407+
// 2f) Pick the largest cluster
1408+
int maxSize = 0;
1409+
for (auto& kv : clusters) {
1410+
int sz = int(kv.second.size());
1411+
if (sz > maxSize) {
1412+
maxSize = sz;
1413+
chosenIndices = kv.second;
1414+
}
1415+
}
1416+
}
1417+
else {
1418+
// use all neighbors if clustering disabled
1419+
chosenIndices.resize(k);
1420+
std::iota(chosenIndices.begin(), chosenIndices.end(), 0);
1421+
}
1422+
1423+
// 3) Compute mean position (weighted or unweighted)
13291424
QVector2D meanPos;
13301425
if (useWeightedMean)
1331-
meanPos = computeWeightedMean(candidatePositions, weights);
1426+
meanPos = computeWeightedMean(candidatePositions, weights, chosenIndices);
13321427
else
1333-
meanPos = computeMean(candidatePositions);
1428+
meanPos = computeMean(candidatePositions, chosenIndices);
13341429
return meanPos;
13351430
}
13361431

1432+
1433+
13371434
void VolumeRenderer::updateRenderModeParameters()
13381435
{
13391436
// Get the screen dimensions and allocate arrays to read the front and back face textures.
@@ -1406,14 +1503,15 @@ void VolumeRenderer::renderFullData()
14061503
_reducedPosDataset->populateDataForDimensions(positionData, std::vector<int>{0, 1});
14071504
normalizePositionData(positionData);
14081505

1409-
// Perform the ANN search for the batch using the CPU output data and use them to retrieve the estimated position in the 2D space ---
1410-
14111506
// Run approximate nearest-neighbour search on the retrieved CPU data.
14121507
uint32_t sampleDim = _volumeDataset->getComponentsPerVoxel();
14131508
int64_t numQueries = static_cast<int64_t>(cpuOutput.size() / sampleDim);
14141509
std::vector<float> meanPositions(numQueries * 2);
14151510

1416-
int k = 3;
1511+
int k = 1; // Number of nearest neighbours to consider for the mean position computation.
1512+
if (_useShading) { // I just use the same button since it is not used anyway
1513+
k = 9;
1514+
}
14171515
bool useWeightedMean = true; // change to "true" if you need weighting.
14181516
batchSearch(cpuOutput, positionData, sampleDim, k, useWeightedMean, meanPositions);
14191517

DVRViewPlugin/src/VolumeRenderer.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class VolumeRenderer : protected QOpenGLFunctions_4_3_Core
123123
void getGPUFullDataModeBatches(std::vector<float>& frontfacesData, std::vector<float>& backfacesData);
124124
void retrieveBatchFullData(std::vector<float>& cpuOutput, int batchIndex, bool deleteBuffers);
125125
void renderBatchToScreen(int batchIndex, uint32_t sampleDim, std::vector<float>& meanPositions);
126-
QVector2D ComputeMeanOfNN(const std::vector<std::pair<float, hnswlib::labeltype>>& neighbors, int k, const std::vector<float>& positionData, bool useWeightedMean);
126+
QVector2D ComputeMeanOfNN(const std::vector<std::pair<float, hnswlib::labeltype>>& neighbors, int k, const std::vector<float>& positionData);
127127
void updateRenderModeParameters();
128128

129129
void renderFullData();
@@ -252,5 +252,22 @@ class VolumeRenderer : protected QOpenGLFunctions_4_3_Core
252252
// Calculate the size of the data in bytes
253253
size_t edgeTableSize = sizeof(MarchingCubes::edgeTable);
254254
size_t triTableSize = sizeof(MarchingCubes::triTable);
255+
256+
bool useWeightedMean = true; // change to "true" if you need weighting.
257+
bool useLargestCluster = true; // change to "false" if you do not want to use the largest cluster.
258+
259+
// Union-Find structure for connected components (used full data render pipeline to remove outliers)
260+
struct UnionFind {
261+
std::vector<int> parent;
262+
UnionFind(int n) : parent(n) { std::iota(parent.begin(), parent.end(), 0); }
263+
int findRoot(int x) {
264+
return parent[x] == x ? x : (parent[x] = findRoot(parent[x]));
265+
}
266+
void unify(int a, int b) {
267+
int ra = findRoot(a), rb = findRoot(b);
268+
if (ra != rb) parent[rb] = ra;
269+
}
270+
};
271+
255272
};
256273

0 commit comments

Comments
 (0)