|
2 | 2 |
|
3 | 3 | #include <cuda/std/span> |
4 | 4 | #include <cuda/std/tuple> |
| 5 | +#include <cuda_runtime.h> |
5 | 6 | #include <stdexcept> |
| 7 | +#include <thrust/device_vector.h> |
| 8 | +#include <thrust/host_vector.h> |
| 9 | +#include <thrust/iterator/zip_iterator.h> |
| 10 | +#include <vector> |
6 | 11 |
|
7 | 12 | #include "geometry.cuh" |
8 | 13 | #include "utils.cuh" |
@@ -46,74 +51,49 @@ public: |
46 | 51 | template <MemoryLocation location> |
47 | 52 | class FMBScene { |
48 | 53 | private: |
49 | | - FMB* fmbs_; |
50 | | - float* log_weights_; |
| 54 | + // Host memory -> thrust::host_vector |
| 55 | + // Device memory -> thrust::device_vector |
| 56 | + template <typename T> |
| 57 | + using vector_t = std::conditional_t<location == MemoryLocation::HOST, thrust::host_vector<T>, |
| 58 | + thrust::device_vector<T>>; |
| 59 | + |
| 60 | + vector_t<FMB> fmbs_; |
| 61 | + vector_t<float> log_weights_; |
51 | 62 | size_t size_; |
52 | 63 |
|
53 | 64 | public: |
54 | | - __host__ FMBScene(size_t size); |
55 | | - |
56 | | - __host__ ~FMBScene(); |
57 | | - |
58 | | - CUDA_CALLABLE cuda::std::tuple<FMB&, float&> operator[](const uint32_t i) { |
59 | | - return cuda::std::tie(fmbs_[i], log_weights_[i]); |
| 65 | + __host__ FMBScene(size_t size) : size_{size}, fmbs_(size), log_weights_(size) {}; |
| 66 | + |
| 67 | + // Converting constructor: copies the elements of two std::vectors into the scene |
| 68 | + // This enables easy construction from the Python side |
| 69 | + __host__ FMBScene<location>(const std::vector<FMB>& fmbs, const std::vector<float>& log_weights) |
| 70 | + : size_{fmbs.size()}, fmbs_(fmbs.begin(), fmbs.end()), |
| 71 | + log_weights_(log_weights.begin(), log_weights.end()) { |
| 72 | + if (fmbs.size() != log_weights.size()) { |
| 73 | + throw std::invalid_argument( |
| 74 | + "FMBScene constructor: fmbs and log_weights must have the same size"); |
| 75 | + } |
60 | 76 | } |
61 | 77 |
|
62 | | - CUDA_CALLABLE cuda::std::tuple<const FMB&, const float&> operator[](const uint32_t i) const { |
63 | | - return cuda::std::tie(fmbs_[i], log_weights_[i]); |
| 78 | + CUDA_CALLABLE auto operator[](const uint32_t i) { |
| 79 | + return cuda::std::make_tuple(fmbs_[i], log_weights_[i]); |
64 | 80 | } |
65 | 81 |
|
66 | | - class Iterator { |
67 | | - private: |
68 | | - FMB* fmb_ptr_; |
69 | | - float* log_weight_ptr_; |
70 | | - |
71 | | - public: |
72 | | - CUDA_CALLABLE Iterator(FMB* const fmb_ptr, float* const log_weight_ptr) |
73 | | - : fmb_ptr_{fmb_ptr}, log_weight_ptr_{log_weight_ptr} {} |
74 | | - CUDA_CALLABLE cuda::std::tuple<FMB&, float&> operator*() { |
75 | | - return cuda::std::tie(*fmb_ptr_, *log_weight_ptr_); |
76 | | - } |
77 | | - CUDA_CALLABLE bool operator!=(const Iterator& other) const { |
78 | | - return fmb_ptr_ != other.fmb_ptr_ || log_weight_ptr_ != other.log_weight_ptr_; |
79 | | - } |
80 | | - CUDA_CALLABLE Iterator& operator++() { |
81 | | - fmb_ptr_++, log_weight_ptr_++; |
82 | | - return *this; |
83 | | - } |
84 | | - }; |
85 | | - |
86 | | - class ConstIterator { |
87 | | - private: |
88 | | - const FMB* fmb_ptr_; |
89 | | - const float* log_weight_ptr_; |
90 | | - |
91 | | - public: |
92 | | - CUDA_CALLABLE ConstIterator(const FMB* const fmb_ptr, const float* const log_weight_ptr) |
93 | | - : fmb_ptr_{fmb_ptr}, log_weight_ptr_{log_weight_ptr} {} |
94 | | - CUDA_CALLABLE cuda::std::tuple<const FMB&, const float&> operator*() const { |
95 | | - return cuda::std::tie(*fmb_ptr_, *log_weight_ptr_); |
96 | | - } |
97 | | - CUDA_CALLABLE bool operator!=(const ConstIterator& other) const { |
98 | | - return fmb_ptr_ != other.fmb_ptr_ || log_weight_ptr_ != other.log_weight_ptr_; |
99 | | - } |
100 | | - CUDA_CALLABLE ConstIterator& operator++() { |
101 | | - fmb_ptr_++, log_weight_ptr_++; |
102 | | - return *this; |
103 | | - } |
104 | | - }; |
| 82 | + CUDA_CALLABLE auto operator[](const uint32_t i) const { |
| 83 | + return cuda::std::make_tuple(fmbs_[i], log_weights_[i]); |
| 84 | + } |
105 | 85 |
|
106 | | - CUDA_CALLABLE Iterator begin() { |
107 | | - return Iterator(fmbs_, log_weights_); |
| 86 | + CUDA_CALLABLE auto begin() { |
| 87 | + return thrust::make_zip_iterator(fmbs_.begin(), log_weights_.begin()); |
108 | 88 | } |
109 | | - CUDA_CALLABLE Iterator end() { |
110 | | - return Iterator(fmbs_ + size_, log_weights_ + size_); |
| 89 | + CUDA_CALLABLE auto end() { |
| 90 | + return thrust::make_zip_iterator(fmbs_.end(), log_weights_.end()); |
111 | 91 | } |
112 | | - CUDA_CALLABLE ConstIterator begin() const { |
113 | | - return ConstIterator(fmbs_, log_weights_); |
| 92 | + CUDA_CALLABLE auto begin() const { |
| 93 | + return thrust::make_zip_iterator(fmbs_.begin(), log_weights_.begin()); |
114 | 94 | } |
115 | | - CUDA_CALLABLE ConstIterator end() const { |
116 | | - return ConstIterator(fmbs_ + size_, log_weights_ + size_); |
| 95 | + CUDA_CALLABLE auto end() const { |
| 96 | + return thrust::make_zip_iterator(fmbs_.end(), log_weights_.end()); |
117 | 97 | } |
118 | 98 | CUDA_CALLABLE const FMB& get_fmb(uint32_t idx) const { |
119 | 99 | return fmbs_[idx]; |
|
0 commit comments