Skip to content

Commit 43e7713

Browse files
Copilotssheorey
andcommitted
Add CPU fallback for NNS operations on SYCL devices
Co-authored-by: ssheorey <41028320+ssheorey@users.noreply.github.com>
1 parent 3569d35 commit 43e7713

4 files changed

Lines changed: 187 additions & 39 deletions

File tree

cpp/open3d/core/nns/NearestNeighborSearch.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,11 @@ class NearestNeighborSearch {
118118
const double radius,
119119
const int max_knn) const;
120120

121+
/// Get the device of the dataset points.
122+
core::Device GetDatasetDevice() const {
123+
return dataset_points_.GetDevice();
124+
}
125+
121126
private:
122127
bool SetIndex();
123128

cpp/open3d/t/geometry/PointCloud.cpp

Lines changed: 107 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -429,17 +429,22 @@ std::tuple<PointCloud, core::Tensor> PointCloud::RemoveRadiusOutliers(
429429
"Illegal input parameters, number of points and radius must be "
430430
"positive");
431431
}
432-
core::nns::NearestNeighborSearch target_nns(GetPointPositions());
432+
// NNS does not support SYCL devices. Use CPU for the NNS index.
433+
const core::Device device = GetDevice();
434+
const core::Device cpu_device("CPU:0");
435+
const core::Device nns_device = device.IsSYCL() ? cpu_device : device;
436+
const core::Tensor positions_nns = GetPointPositions().To(nns_device);
437+
core::nns::NearestNeighborSearch target_nns(positions_nns);
433438

434439
const bool check = target_nns.FixedRadiusIndex(search_radius);
435440
if (!check) {
436441
utility::LogError("Fixed radius search index is not set.");
437442
}
438443

439444
core::Tensor indices, distance, row_splits;
440-
std::tie(indices, distance, row_splits) = target_nns.FixedRadiusSearch(
441-
GetPointPositions(), search_radius, false);
442-
row_splits = row_splits.To(GetDevice());
445+
std::tie(indices, distance, row_splits) =
446+
target_nns.FixedRadiusSearch(positions_nns, search_radius, false);
447+
row_splits = row_splits.To(device);
443448

444449
const int64_t size = row_splits.GetLength();
445450
const core::Tensor num_neighbors =
@@ -462,15 +467,22 @@ std::tuple<PointCloud, core::Tensor> PointCloud::RemoveStatisticalOutliers(
462467
core::Tensor({0}, core::Bool, GetDevice()));
463468
}
464469

465-
core::nns::NearestNeighborSearch nns(GetPointPositions().Contiguous());
470+
// NNS does not support SYCL devices. Use CPU for the NNS index.
471+
const core::Device device = GetDevice();
472+
const core::Device cpu_device("CPU:0");
473+
const core::Device nns_device = device.IsSYCL() ? cpu_device : device;
474+
const core::Tensor positions_nns =
475+
GetPointPositions().Contiguous().To(nns_device);
476+
core::nns::NearestNeighborSearch nns(positions_nns);
466477
const bool check = nns.KnnIndex();
467478
if (!check) {
468479
utility::LogError("Knn search index is not set.");
469480
}
470481

471482
core::Tensor indices, distance2;
472-
std::tie(indices, distance2) =
473-
nns.KnnSearch(GetPointPositions(), nb_neighbors);
483+
std::tie(indices, distance2) = nns.KnnSearch(positions_nns, nb_neighbors);
484+
// Move results back to original device for further tensor operations.
485+
distance2 = distance2.To(device);
474486

475487
core::Tensor avg_distances = distance2.Sqrt().Mean({1});
476488
const double cloud_mean =
@@ -587,15 +599,22 @@ std::tuple<PointCloud, core::Tensor> PointCloud::ComputeBoundaryPoints(
587599
const core::Tensor normals_d = GetPointNormals().Contiguous();
588600

589601
// Compute nearest neighbors.
602+
// NNS does not support SYCL devices. Use CPU for the NNS index.
603+
const core::Device cpu_device("CPU:0");
604+
const core::Device nns_device = device.IsSYCL() ? cpu_device : device;
605+
const core::Tensor points_nns = points_d.To(nns_device);
590606
core::Tensor indices, distance2, counts;
591-
core::nns::NearestNeighborSearch tree(points_d, core::Int32);
607+
core::nns::NearestNeighborSearch tree(points_nns, core::Int32);
592608

593609
bool check = tree.HybridIndex(radius);
594610
if (!check) {
595611
utility::LogError("Building HybridIndex failed.");
596612
}
597613
std::tie(indices, distance2, counts) =
598-
tree.HybridSearch(points_d, radius, max_nn);
614+
tree.HybridSearch(points_nns, radius, max_nn);
615+
// Move NNS results back to the original device.
616+
indices = indices.To(device);
617+
counts = counts.To(device);
599618
utility::LogDebug(
600619
"Use HybridSearch [max_nn: {} | radius {}] for computing "
601620
"boundary points.",
@@ -608,6 +627,18 @@ std::tuple<PointCloud, core::Tensor> PointCloud::ComputeBoundaryPoints(
608627
} else if (IsCUDA()) {
609628
CUDA_CALL(kernel::pointcloud::ComputeBoundaryPointsCUDA, points_d,
610629
normals_d, indices, counts, mask, angle_threshold);
630+
} else if (IsSYCL()) {
631+
// No SYCL kernel; use the CPU kernel with CPU tensors.
632+
const core::Tensor points_cpu = points_d.To(cpu_device);
633+
const core::Tensor normals_cpu = normals_d.To(cpu_device);
634+
const core::Tensor indices_cpu = indices.To(cpu_device);
635+
const core::Tensor counts_cpu = counts.To(cpu_device);
636+
core::Tensor mask_cpu =
637+
core::Tensor::Zeros({num_points}, core::Bool, cpu_device);
638+
kernel::pointcloud::ComputeBoundaryPointsCPU(points_cpu, normals_cpu,
639+
indices_cpu, counts_cpu,
640+
mask_cpu, angle_threshold);
641+
mask = mask_cpu.To(device);
611642
} else {
612643
utility::LogError("Unimplemented device");
613644
}
@@ -623,6 +654,8 @@ void PointCloud::EstimateNormals(
623654

624655
const core::Dtype dtype = this->GetPointPositions().GetDtype();
625656
const core::Device device = GetDevice();
657+
// CPU device for SYCL fallback (NNS is not supported on SYCL devices).
658+
const core::Device cpu_device("CPU:0");
626659

627660
const bool has_normals = HasPointNormals();
628661

@@ -655,6 +688,16 @@ void PointCloud::EstimateNormals(
655688
this->GetPointPositions().Contiguous(),
656689
this->GetPointAttr("covariances"), radius.value(),
657690
max_knn.value());
691+
} else if (IsSYCL()) {
692+
// NNS is not supported on SYCL; use the CPU kernel with CPU data.
693+
core::Tensor points_cpu =
694+
this->GetPointPositions().Contiguous().To(cpu_device);
695+
core::Tensor covariances_cpu =
696+
this->GetPointAttr("covariances").To(cpu_device);
697+
kernel::pointcloud::EstimateCovariancesUsingHybridSearchCPU(
698+
points_cpu, covariances_cpu, radius.value(),
699+
max_knn.value());
700+
this->SetPointAttr("covariances", covariances_cpu.To(device));
658701
} else {
659702
utility::LogError("Unimplemented device");
660703
}
@@ -669,12 +712,20 @@ void PointCloud::EstimateNormals(
669712
CUDA_CALL(kernel::pointcloud::EstimateCovariancesUsingKNNSearchCUDA,
670713
this->GetPointPositions().Contiguous(),
671714
this->GetPointAttr("covariances"), max_knn.value());
715+
} else if (IsSYCL()) {
716+
core::Tensor points_cpu =
717+
this->GetPointPositions().Contiguous().To(cpu_device);
718+
core::Tensor covariances_cpu =
719+
this->GetPointAttr("covariances").To(cpu_device);
720+
kernel::pointcloud::EstimateCovariancesUsingKNNSearchCPU(
721+
points_cpu, covariances_cpu, max_knn.value());
722+
this->SetPointAttr("covariances", covariances_cpu.To(device));
672723
} else {
673724
utility::LogError("Unimplemented device");
674725
}
675726
} else if (!max_knn.has_value() && radius.has_value()) {
676727
utility::LogDebug("Using Radius Search for computing covariances");
677-
// Computes and sets `covariances` attribute using KNN Search method.
728+
// Computes and sets `covariances` attribute using Radius Search method.
678729
if (IsCPU()) {
679730
kernel::pointcloud::EstimateCovariancesUsingRadiusSearchCPU(
680731
this->GetPointPositions().Contiguous(),
@@ -684,6 +735,14 @@ void PointCloud::EstimateNormals(
684735
EstimateCovariancesUsingRadiusSearchCUDA,
685736
this->GetPointPositions().Contiguous(),
686737
this->GetPointAttr("covariances"), radius.value());
738+
} else if (IsSYCL()) {
739+
core::Tensor points_cpu =
740+
this->GetPointPositions().Contiguous().To(cpu_device);
741+
core::Tensor covariances_cpu =
742+
this->GetPointAttr("covariances").To(cpu_device);
743+
kernel::pointcloud::EstimateCovariancesUsingRadiusSearchCPU(
744+
points_cpu, covariances_cpu, radius.value());
745+
this->SetPointAttr("covariances", covariances_cpu.To(device));
687746
} else {
688747
utility::LogError("Unimplemented device");
689748
}
@@ -700,6 +759,14 @@ void PointCloud::EstimateNormals(
700759
CUDA_CALL(kernel::pointcloud::EstimateNormalsFromCovariancesCUDA,
701760
this->GetPointAttr("covariances"), this->GetPointNormals(),
702761
has_normals);
762+
} else if (IsSYCL()) {
763+
// NNS is not supported on SYCL; use the CPU kernel with CPU data.
764+
core::Tensor covariances_cpu =
765+
this->GetPointAttr("covariances").To(cpu_device);
766+
core::Tensor normals_cpu = this->GetPointNormals().To(cpu_device);
767+
kernel::pointcloud::EstimateNormalsFromCovariancesCPU(
768+
covariances_cpu, normals_cpu, has_normals);
769+
this->SetPointNormals(normals_cpu.To(device));
703770
} else {
704771
utility::LogError("Unimplemented device");
705772
}
@@ -789,9 +856,10 @@ void PointCloud::EstimateColorGradients(
789856

790857
const core::Dtype dtype = this->GetPointColors().GetDtype();
791858
const core::Device device = GetDevice();
859+
// CPU device for SYCL fallback (NNS is not supported on SYCL devices).
860+
const core::Device cpu_device("CPU:0");
792861

793862
if (!this->HasPointAttr("color_gradients")) {
794-
this->SetPointAttr(
795863
"color_gradients",
796864
core::Tensor::Empty({GetPointPositions().GetLength(), 3}, dtype,
797865
device));
@@ -823,6 +891,16 @@ void PointCloud::EstimateColorGradients(
823891
this->GetPointColors().Contiguous(),
824892
this->GetPointAttr("color_gradients"), radius.value(),
825893
max_knn.value());
894+
} else if (IsSYCL()) {
895+
// NNS is not supported on SYCL; use the CPU kernel with CPU data.
896+
core::Tensor grad_cpu =
897+
this->GetPointAttr("color_gradients").To(cpu_device);
898+
kernel::pointcloud::EstimateColorGradientsUsingHybridSearchCPU(
899+
this->GetPointPositions().Contiguous().To(cpu_device),
900+
this->GetPointNormals().Contiguous().To(cpu_device),
901+
this->GetPointColors().Contiguous().To(cpu_device), grad_cpu,
902+
radius.value(), max_knn.value());
903+
this->SetPointAttr("color_gradients", grad_cpu.To(device));
826904
} else {
827905
utility::LogError("Unimplemented device");
828906
}
@@ -841,6 +919,15 @@ void PointCloud::EstimateColorGradients(
841919
this->GetPointNormals().Contiguous(),
842920
this->GetPointColors().Contiguous(),
843921
this->GetPointAttr("color_gradients"), max_knn.value());
922+
} else if (IsSYCL()) {
923+
core::Tensor grad_cpu =
924+
this->GetPointAttr("color_gradients").To(cpu_device);
925+
kernel::pointcloud::EstimateColorGradientsUsingKNNSearchCPU(
926+
this->GetPointPositions().Contiguous().To(cpu_device),
927+
this->GetPointNormals().Contiguous().To(cpu_device),
928+
this->GetPointColors().Contiguous().To(cpu_device), grad_cpu,
929+
max_knn.value());
930+
this->SetPointAttr("color_gradients", grad_cpu.To(device));
844931
} else {
845932
utility::LogError("Unimplemented device");
846933
}
@@ -859,6 +946,15 @@ void PointCloud::EstimateColorGradients(
859946
this->GetPointNormals().Contiguous(),
860947
this->GetPointColors().Contiguous(),
861948
this->GetPointAttr("color_gradients"), radius.value());
949+
} else if (IsSYCL()) {
950+
core::Tensor grad_cpu =
951+
this->GetPointAttr("color_gradients").To(cpu_device);
952+
kernel::pointcloud::EstimateColorGradientsUsingRadiusSearchCPU(
953+
this->GetPointPositions().Contiguous().To(cpu_device),
954+
this->GetPointNormals().Contiguous().To(cpu_device),
955+
this->GetPointColors().Contiguous().To(cpu_device), grad_cpu,
956+
radius.value());
957+
this->SetPointAttr("color_gradients", grad_cpu.To(device));
862958
} else {
863959
utility::LogError("Unimplemented device");
864960
}

0 commit comments

Comments
 (0)