Skip to content

Commit 0c3c2e2

Browse files
committed
Begin modularizing
1 parent e9c1771 commit 0c3c2e2

1 file changed

Lines changed: 173 additions & 100 deletions

File tree

mdio/intersection.h

Lines changed: 173 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -31,101 +31,107 @@ class IndexSelection {
3131

3232
template <typename T>
3333
mdio::Future<void> add_selection(const ValueDescriptor<T>& descriptor) {
34-
using Interval = typename Variable<T>::Interval;
35-
36-
MDIO_ASSIGN_OR_RETURN(auto var, dataset_.variables.get<T>(std::string(descriptor.label.label())));
37-
auto fut = var.Read();
38-
MDIO_ASSIGN_OR_RETURN(auto intervals, var.get_intervals());
39-
if (!fut.status().ok()) return fut.status();
40-
41-
auto data = fut.value();
42-
const T* data_ptr = data.get_data_accessor().data();
43-
Index offset = data.get_flattened_offset();
44-
Index n_samples = data.num_samples();
45-
46-
auto current_pos = intervals;
47-
bool isInRun = false;
48-
std::vector<std::vector<Interval>> local_runs;
34+
// using Interval = typename Variable<T>::Interval;
35+
36+
// MDIO_ASSIGN_OR_RETURN(auto var, dataset_.variables.get<T>(std::string(descriptor.label.label())));
37+
// auto fut = var.Read();
38+
// MDIO_ASSIGN_OR_RETURN(auto intervals, var.get_intervals());
39+
// if (!fut.status().ok()) return fut.status();
40+
41+
// auto data = fut.value();
42+
// const T* data_ptr = data.get_data_accessor().data();
43+
// Index offset = data.get_flattened_offset();
44+
// Index n_samples = data.num_samples();
45+
46+
// auto current_pos = intervals;
47+
// bool isInRun = false;
48+
// std::vector<std::vector<Interval>> local_runs;
49+
50+
// for (mdio::Index idx = offset; idx < offset + n_samples; ++idx) {
51+
// if (data_ptr[idx] == descriptor.value) {
52+
// if (!isInRun) {
53+
// isInRun = true;
54+
// std::vector<Interval> run = current_pos;
55+
// for (auto& pos : run) {
56+
// pos.exclusive_max = pos.inclusive_min + 1;
57+
// }
58+
// local_runs.push_back(std::move(run));
59+
// } else {
60+
// auto& run = local_runs.back();
61+
// for (auto i=0; i<current_pos.size(); ++i) {
62+
// run[i].exclusive_max = current_pos[i].inclusive_min + 1;
63+
// }
64+
// }
65+
// } else {
66+
// isInRun = false;
67+
// }
68+
// _current_position_increment<T>(current_pos, intervals);
69+
// }
70+
71+
// if (local_runs.empty()) {
72+
// std::stringstream ss;
73+
// ss << "No matches for coordinate '" << descriptor.label.label() << "'";
74+
// return absl::NotFoundError(ss.str());
75+
// }
76+
77+
// auto new_runs = _from_intervals<T>(local_runs);
78+
79+
// // First time calling add_selection_2
80+
// if (kept_runs_.empty()) {
81+
// kept_runs_ = std::move(new_runs);
82+
// } else {
83+
// // now intersect each kept_run with each new local run
84+
// std::vector<std::vector<mdio::RangeDescriptor<mdio::Index>>> new_kept;
85+
// new_kept.reserve(kept_runs_.size() * local_runs.size());
86+
87+
// for (const auto& kept : kept_runs_) {
88+
// for (const auto& run : new_runs) {
89+
// // start from the old run
90+
// auto intersection = kept;
91+
// bool empty = false;
92+
93+
// // for each descriptor in the new run...
94+
// for (const auto& d_new : run) {
95+
// // try to find the same label in the kept run
96+
// auto it = std::find_if(
97+
// intersection.begin(), intersection.end(),
98+
// [&](auto const& d_old) {
99+
// return d_old.label.label() == d_new.label.label();
100+
// });
101+
102+
// if (it != intersection.end()) {
103+
// // intersect intervals
104+
// auto& d_old = *it;
105+
// auto new_min = std::max(d_old.start, d_new.start);
106+
// auto new_max = std::min(d_old.stop, d_new.stop);
107+
// if (new_min >= new_max) {
108+
// empty = true;
109+
// break;
110+
// }
111+
// d_old.start = new_min;
112+
// d_old.stop = new_max;
113+
// } else {
114+
// // brand-new dimension: append it
115+
// intersection.push_back(d_new);
116+
// }
117+
// }
118+
119+
// if (!empty) {
120+
// new_kept.push_back(std::move(intersection));
121+
// }
122+
// }
123+
// }
124+
125+
// kept_runs_ = std::move(new_kept);
126+
// }
127+
128+
// return absl::OkStatus();
49129

50-
for (mdio::Index idx = offset; idx < offset + n_samples; ++idx) {
51-
if (data_ptr[idx] == descriptor.value) {
52-
if (!isInRun) {
53-
isInRun = true;
54-
std::vector<Interval> run = current_pos;
55-
for (auto& pos : run) {
56-
pos.exclusive_max = pos.inclusive_min + 1;
57-
}
58-
local_runs.push_back(std::move(run));
59-
} else {
60-
auto& run = local_runs.back();
61-
for (auto i=0; i<current_pos.size(); ++i) {
62-
run[i].exclusive_max = current_pos[i].inclusive_min + 1;
63-
}
64-
}
65-
} else {
66-
isInRun = false;
67-
}
68-
_current_position_increment<T>(current_pos, intervals);
69-
}
70-
71-
if (local_runs.empty()) {
72-
std::stringstream ss;
73-
ss << "No matches for coordinate '" << descriptor.label.label() << "'";
74-
return absl::NotFoundError(ss.str());
75-
}
76-
77-
auto new_runs = _from_intervals<T>(local_runs);
78-
79-
// First time calling add_selection_2
80130
if (kept_runs_.empty()) {
81-
kept_runs_ = std::move(new_runs);
131+
return _init_runs(descriptor);
82132
} else {
83-
// now intersect each kept_run with each new local run
84-
std::vector<std::vector<mdio::RangeDescriptor<mdio::Index>>> new_kept;
85-
new_kept.reserve(kept_runs_.size() * local_runs.size());
86-
87-
for (const auto& kept : kept_runs_) {
88-
for (const auto& run : new_runs) {
89-
// start from the old run
90-
auto intersection = kept;
91-
bool empty = false;
92-
93-
// for each descriptor in the new run...
94-
for (const auto& d_new : run) {
95-
// try to find the same label in the kept run
96-
auto it = std::find_if(
97-
intersection.begin(), intersection.end(),
98-
[&](auto const& d_old) {
99-
return d_old.label.label() == d_new.label.label();
100-
});
101-
102-
if (it != intersection.end()) {
103-
// intersect intervals
104-
auto& d_old = *it;
105-
auto new_min = std::max(d_old.start, d_new.start);
106-
auto new_max = std::min(d_old.stop, d_new.stop);
107-
if (new_min >= new_max) {
108-
empty = true;
109-
break;
110-
}
111-
d_old.start = new_min;
112-
d_old.stop = new_max;
113-
} else {
114-
// brand-new dimension: append it
115-
intersection.push_back(d_new);
116-
}
117-
}
118-
119-
if (!empty) {
120-
new_kept.push_back(std::move(intersection));
121-
}
122-
}
123-
}
124-
125-
kept_runs_ = std::move(new_kept);
133+
return _add_new_run(descriptor);
126134
}
127-
128-
return absl::OkStatus();
129135
}
130136

131137
/// \brief Emit a RangeDescriptor per surviving tuple coordinate, without coalescing.
@@ -157,10 +163,15 @@ class IndexSelection {
157163
std::vector<T> keys;
158164
keys.reserve(n);
159165
for (auto &f : reads) {
160-
if (!f.status().ok()) return f.status();
161-
auto data = f.value();
162-
keys.push_back(data.get_data_accessor().data()[data.get_flattened_offset()]
163-
);
166+
// if (!f.status().ok()) return f.status();
167+
// auto data = f.value();
168+
// keys.push_back(data.get_data_accessor().data()[data.get_flattened_offset()]);
169+
MDIO_ASSIGN_OR_RETURN(auto resolution, _resolve_future<T>(f));
170+
auto data = std::get<0>(resolution);
171+
auto data_ptr = std::get<1>(resolution);
172+
auto offset = std::get<2>(resolution);
173+
// auto n = std::get<3>(resolution); // Not required
174+
keys.push_back(data_ptr[offset]);
164175
}
165176

166177
// 2) Build and stable-sort an index array [0…n-1] by key
@@ -193,11 +204,11 @@ class IndexSelection {
193204
MDIO_ASSIGN_OR_RETURN(auto ds, non_const_ds.isel(desc));
194205
MDIO_ASSIGN_OR_RETURN(auto var, ds.variables.get<T>(output_variable));
195206
auto fut = var.Read();
196-
if (!fut.status().ok()) return fut.status();
197-
auto data = fut.value();
198-
T* data_ptr = data.get_data_accessor().data();
199-
Index n = data.num_samples();
200-
Index offset = data.get_flattened_offset();
207+
MDIO_ASSIGN_OR_RETURN(auto resolution, _resolve_future<T>(fut));
208+
auto data = std::get<0>(resolution);
209+
auto data_ptr = std::get<1>(resolution);
210+
auto offset = std::get<2>(resolution);
211+
auto n = std::get<3>(resolution);
201212
std::vector<T> buffer(n);
202213
std::memcpy(buffer.data(), data_ptr + offset, n * sizeof(T));
203214
ret.insert(ret.end(), buffer.begin(), buffer.end());
@@ -213,6 +224,58 @@ class IndexSelection {
213224
tensorstore::IndexDomain<> base_domain_;
214225
std::vector<std::vector<mdio::RangeDescriptor<mdio::Index>>> kept_runs_;
215226

227+
template <typename T>
228+
Future<void> _init_runs(const ValueDescriptor<T>& descriptor) {
229+
using Interval = typename Variable<T>::Interval;
230+
MDIO_ASSIGN_OR_RETURN(auto var, dataset_.variables.get<T>(std::string(descriptor.label.label())));
231+
auto fut = var.Read();
232+
MDIO_ASSIGN_OR_RETURN(auto intervals, var.get_intervals());
233+
MDIO_ASSIGN_OR_RETURN(auto resolution, _resolve_future<T>(fut));
234+
auto data = std::get<0>(resolution);
235+
auto data_ptr = std::get<1>(resolution);
236+
auto offset = std::get<2>(resolution);
237+
auto n_samples = std::get<3>(resolution);
238+
239+
auto current_pos = intervals;
240+
bool isInRun = false;
241+
std::vector<std::vector<Interval>> local_runs;
242+
243+
for (mdio::Index idx = offset; idx < offset + n_samples; ++idx) {
244+
if (data_ptr[idx] == descriptor.value) {
245+
if (!isInRun) {
246+
isInRun = true;
247+
std::vector<Interval> run = current_pos;
248+
for (auto& pos : run) {
249+
pos.exclusive_max = pos.inclusive_min + 1;
250+
}
251+
local_runs.push_back(std::move(run));
252+
} else {
253+
auto& run = local_runs.back();
254+
for (auto i=0; i<current_pos.size(); ++i) {
255+
run[i].exclusive_max = current_pos[i].inclusive_min + 1;
256+
}
257+
}
258+
} else {
259+
isInRun = false;
260+
}
261+
_current_position_increment<T>(current_pos, intervals);
262+
}
263+
264+
if (local_runs.empty()) {
265+
std::stringstream ss;
266+
ss << "No matches for coordinate '" << descriptor.label.label() << "'";
267+
return absl::NotFoundError(ss.str());
268+
}
269+
270+
kept_runs_ = _from_intervals<T>(local_runs);
271+
return absl::OkStatus();
272+
}
273+
274+
template <typename T>
275+
Future<void> _add_new_run(const ValueDescriptor<T>& descriptor) {
276+
return absl::UnimplementedError("Adding selection to an existing IndexSelection is not yet implemented");
277+
}
278+
216279
template <typename T>
217280
std::vector<std::vector<mdio::RangeDescriptor<mdio::Index>>> _from_intervals(std::vector<std::vector<typename mdio::Variable<T>::Interval>>& intervals) {
218281
std::vector<std::vector<mdio::RangeDescriptor<mdio::Index>>> ret;
@@ -241,6 +304,16 @@ class IndexSelection {
241304
position[d].inclusive_min = interval[d].inclusive_min;
242305
}
243306
}
307+
308+
template <typename T>
309+
Result<std::tuple<VariableData<T>, const T*, Index, Index>> _resolve_future(Future<VariableData<T>>& fut) {
310+
if (!fut.status().ok()) return fut.status();
311+
auto data = fut.value();
312+
const T* data_ptr = data.get_data_accessor().data();
313+
Index offset = data.get_flattened_offset();
314+
Index n_samples = data.num_samples();
315+
return std::make_tuple(std::move(data), data_ptr, offset, n_samples);
316+
}
244317
};
245318

246319
} // namespace mdio

0 commit comments

Comments
 (0)