Skip to content

Commit 5be05b1

Browse files
committed
Add save_npy function
1 parent d57ae5c commit 5be05b1

6 files changed

Lines changed: 315 additions & 0 deletions

File tree

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ target_sources(
134134
src/ddc/discrete_space.cpp
135135
src/ddc/discrete_element.cpp
136136
src/ddc/discrete_vector.cpp
137+
src/ddc/save_npy.cpp
137138
src/ddc/print.cpp
138139
INTERFACE
139140
FILE_SET HEADERS
@@ -171,6 +172,7 @@ target_sources(
171172
src/ddc/print.hpp
172173
src/ddc/real_type.hpp
173174
src/ddc/reducer.hpp
175+
src/ddc/save_npy.hpp
174176
src/ddc/scope_guard.hpp
175177
src/ddc/sparse_discrete_domain.hpp
176178
src/ddc/strided_discrete_domain.hpp

src/ddc/ddc.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,4 @@ namespace ddc {
8080

8181
// Output
8282
#include "print.hpp"
83+
#include "save_npy.hpp"

src/ddc/save_npy.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// Copyright (C) The DDC development team, see COPYRIGHT.md file
2+
//
3+
// SPDX-License-Identifier: MIT
4+
5+
#include <bit>
6+
#include <cstddef>
7+
#include <cstdint>
8+
#include <filesystem>
9+
#include <fstream>
10+
#include <numeric>
11+
#include <stdexcept>
12+
#include <string>
13+
#include <vector>
14+
15+
#include "save_npy.hpp"
16+
17+
namespace ddc::detail {
18+
19+
// Selects the NPY byte-order marker for elements that are `itemsize` bytes wide.
//
// One-byte elements have no byte order and yield `not_applicable`. Any other
// size yields `little_endian` when the compilation target is little-endian and
// `big_endian` otherwise.
NpyByteOrder get_byte_order(std::size_t itemsize) noexcept
{
    bool const single_byte = (itemsize == 1);
    if (single_byte) {
        return NpyByteOrder::not_applicable;
    }

    return (std::endian::native == std::endian::little) ? NpyByteOrder::little_endian
                                                         : NpyByteOrder::big_endian;
}
31+
32+
std::string NpyDtype::str() const
33+
{
34+
return std::string(1, static_cast<char>(byte_order)) + static_cast<char>(kind)
35+
+ std::to_string(itemsize);
36+
}
37+
38+
void save_npy(std::ostream& os, NpyArrayView const& view)
39+
{
40+
// Build shape string: (d0, d1, ..., dN,)
41+
std::string shape_str = "(";
42+
for (std::size_t ext : view.shape) {
43+
shape_str += std::to_string(ext);
44+
shape_str += ", ";
45+
}
46+
shape_str += ")";
47+
48+
std::string header_dict = std::string("{'descr': '") + view.dtype.str()
49+
+ "', 'fortran_order': " + (view.fortran_order ? "True" : "False")
50+
+ ", 'shape': " + shape_str + ", }";
51+
52+
// Pad header to a multiple of 64 bytes
53+
constexpr std::size_t prefix_size = 6 + 1 + 1 + 2; // magic + major + minor + hlen
54+
std::size_t const total_header = prefix_size + header_dict.size() + 1; // +1 for '\n'
55+
std::size_t const padded = ((total_header + 63) / 64) * 64;
56+
header_dict += std::string(padded - total_header, ' ');
57+
header_dict += '\n';
58+
59+
if (header_dict.size() > std::numeric_limits<std::uint16_t>::max()) {
60+
throw std::runtime_error("save_npy: header too large for npy v1.0.");
61+
}
62+
auto const hlen = static_cast<std::uint16_t>(header_dict.size());
63+
64+
// Magic + version
65+
os.write("\x93NUMPY", 6);
66+
std::uint8_t const major = 1;
67+
std::uint8_t const minor = 0;
68+
os.write(reinterpret_cast<char const*>(&major), 1);
69+
os.write(reinterpret_cast<char const*>(&minor), 1);
70+
71+
// Header length + content
72+
os.write(reinterpret_cast<char const*>(&hlen), sizeof(hlen));
73+
os.write(header_dict.data(), header_dict.size());
74+
75+
// Raw data
76+
std::size_t const n_elems
77+
= std::accumulate(view.shape.begin(), view.shape.end(), 1ULL, std::multiplies<> {});
78+
os.write(reinterpret_cast<char const*>(view.data), n_elems * view.dtype.itemsize);
79+
}
80+
81+
void save_npy(std::filesystem::path const& filename, NpyArrayView const& view)
82+
{
83+
std::ofstream file(filename, std::ios::binary);
84+
file.exceptions(std::ios::failbit | std::ios::badbit);
85+
86+
save_npy(file, view);
87+
}
88+
89+
} // namespace ddc::detail

src/ddc/save_npy.hpp

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
// Copyright (C) The DDC development team, see COPYRIGHT.md file
2+
//
3+
// SPDX-License-Identifier: MIT
4+
5+
#pragma once
6+
7+
#include <complex>
8+
#include <cstddef>
9+
#include <filesystem>
10+
#include <iosfwd>
11+
#include <string>
12+
#include <type_traits>
13+
#include <vector>
14+
15+
#include <Kokkos_Core.hpp>
16+
17+
namespace ddc::detail {
18+
19+
// Byte-order marker of an NPY dtype descriptor. Each enumerator's value is the
// character emitted in the descriptor string.
enum class NpyByteOrder : char {
    little_endian = '<',
    big_endian = '>',
    not_applicable = '|',
};
20+
21+
// Returns the NPY byte-order marker to use for elements of `itemsize` bytes.
NpyByteOrder get_byte_order(std::size_t itemsize) noexcept;
22+
23+
// Type-kind character of an NPY dtype descriptor. Each enumerator's value is
// the character emitted in the descriptor string.
enum class NpyKind : char {
    boolean = 'b', // bool
    signed_int = 'i', // signed integers
    unsigned_int = 'u', // unsigned integers
    floating_point = 'f', // floating-point numbers
    complex = 'c', // complex numbers
    other = 'V', // raw bytes without arithmetic meaning
};
31+
32+
struct NpyDtype
33+
{
34+
NpyByteOrder byte_order;
35+
NpyKind kind;
36+
std::size_t itemsize; // in bytes
37+
38+
std::string str() const;
39+
};
40+
41+
template <typename T>
42+
NpyDtype convert_to_npy_dtype()
43+
{
44+
std::size_t const itemsize = sizeof(T);
45+
NpyByteOrder const byte_order = get_byte_order(itemsize);
46+
NpyKind kind;
47+
48+
if constexpr (std::is_same_v<T, bool>) {
49+
// ── Single-byte / untyped ─────────────────────────────────────────
50+
kind = NpyKind::boolean;
51+
} else if constexpr (std::is_same_v<T, std::byte>) {
52+
// std::byte → raw byte buffer, no arithmetic meaning
53+
kind = NpyKind::other;
54+
} else if constexpr (std::is_same_v<T, char>) {
55+
// char is a distinct type; its signedness is implementation-defined
56+
kind = std::is_signed_v<char> ? NpyKind::signed_int : NpyKind::unsigned_int;
57+
} else if constexpr (
58+
std::is_same_v<T, std::complex<float>> || std::is_same_v<T, std::complex<double>>) {
59+
// ── Complex ───────────────────────────────────────────────────────
60+
// NumPy 'c' dtype stores interleaved real+imag, same layout as std::complex
61+
kind = NpyKind::complex;
62+
} else if constexpr (std::is_floating_point_v<T>) {
63+
// ── Floating-point ────────────────────────────────────────────────
64+
static_assert(
65+
!std::is_same_v<T, long double>,
66+
"long double is platform-specific (80/96/128-bit); cast to double first.");
67+
kind = NpyKind::floating_point;
68+
} else if constexpr (std::is_signed_v<T>) {
69+
// ── Integers ──────────────────────────────────────────────────────
70+
kind = NpyKind::signed_int;
71+
} else if constexpr (std::is_unsigned_v<T>) {
72+
kind = NpyKind::unsigned_int;
73+
} else {
74+
static_assert(sizeof(T) == 0, "Unsupported type for NpyDtype::of<T>()");
75+
}
76+
77+
return {byte_order, kind, itemsize};
78+
}
79+
80+
struct NpyArrayView
81+
{
82+
void const* data;
83+
NpyDtype dtype;
84+
std::vector<std::size_t> shape;
85+
bool fortran_order;
86+
};
87+
88+
// Writes `view` to `os` in NPY v1.0 format: magic and version, header length, padded header, raw element bytes.
void save_npy(std::ostream& os, NpyArrayView const& view);
89+
90+
// Opens `filename` in binary mode and writes `view` to it in NPY format.
void save_npy(std::filesystem::path const& filename, NpyArrayView const& view);
91+
92+
} // namespace ddc::detail
93+
94+
namespace ddc::experimental {
95+
96+
template <typename T, typename Extents, typename Layout, typename Accessor>
97+
void save_npy(std::ostream& os, Kokkos::mdspan<T, Extents, Layout, Accessor> const& mds)
98+
{
99+
static_assert(
100+
std::is_same_v<Layout, Kokkos::layout_left>
101+
|| std::is_same_v<Layout, Kokkos::layout_right>,
102+
"save_npy: only contiguous layouts supported.");
103+
static_assert(
104+
std::is_same_v<Accessor, Kokkos::default_accessor<T>>
105+
|| std::is_same_v<Accessor, Kokkos::default_accessor<T const>>,
106+
"save_npy: non-host-accessible accessor. Use create_mirror_view + deep_copy first.");
107+
108+
std::vector<std::size_t> shape(Extents::rank());
109+
for (std::size_t i = 0; i < Extents::rank(); ++i) {
110+
shape[i] = mds.extent(i);
111+
}
112+
113+
ddc::detail::save_npy(
114+
os,
115+
ddc::detail::NpyArrayView {
116+
mds.data_handle(),
117+
ddc::detail::convert_to_npy_dtype<std::remove_const_t<T>>(),
118+
std::move(shape),
119+
std::is_same_v<Layout, Kokkos::layout_left>,
120+
});
121+
}
122+
123+
template <typename T, typename Extents, typename Layout, typename Accessor>
124+
void save_npy(
125+
std::filesystem::path const& filename,
126+
Kokkos::mdspan<T, Extents, Layout, Accessor> const& mds)
127+
{
128+
static_assert(
129+
std::is_same_v<Layout, Kokkos::layout_left>
130+
|| std::is_same_v<Layout, Kokkos::layout_right>,
131+
"save_npy: only contiguous layouts supported.");
132+
static_assert(
133+
std::is_same_v<Accessor, Kokkos::default_accessor<T>>
134+
|| std::is_same_v<Accessor, Kokkos::default_accessor<T const>>,
135+
"save_npy: non-host-accessible accessor. Use create_mirror_view + deep_copy first.");
136+
137+
std::vector<std::size_t> shape(Extents::rank());
138+
for (std::size_t i = 0; i < Extents::rank(); ++i) {
139+
shape[i] = mds.extent(i);
140+
}
141+
142+
ddc::detail::save_npy(
143+
filename,
144+
ddc::detail::NpyArrayView {
145+
mds.data_handle(),
146+
ddc::detail::convert_to_npy_dtype<std::remove_const_t<T>>(),
147+
std::move(shape),
148+
std::is_same_v<Layout, Kokkos::layout_left>,
149+
});
150+
}
151+
152+
} // namespace ddc::experimental

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ add_executable(
4141
reducer.cpp
4242
relocatable_device_code.cpp
4343
relocatable_device_code_initialization.cpp
44+
save_npy.cpp
4445
sparse_discrete_domain.cpp
4546
strided_discrete_domain.cpp
4647
tagged_vector.cpp

tests/save_npy.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Copyright (C) The DDC development team, see COPYRIGHT.md file
2+
//
3+
// SPDX-License-Identifier: MIT
4+
5+
#include <array>
#include <cctype>
#include <cmath>
#include <complex>
#include <cstdint>
#include <filesystem>
#include <string>
#include <type_traits>
#include <typeinfo>

#include <ddc/ddc.hpp>

#include <gtest/gtest.h>

#include <Kokkos_Core.hpp>
14+
15+
namespace {

// Extents of the 3-D array written by the test, and its total element count.
constexpr std::array ns {2, 3, 4};
constexpr int n = ns[0] * ns[1] * ns[2];

// Produces a fill value of type T derived from the complex number 2.3 + 0.4i:
// complex types take the whole number, floating-point types its real part,
// and all remaining types the real part rounded to the nearest integer.
template <typename T>
constexpr T make_value()
{
    constexpr bool is_complex
            = std::is_same_v<T, std::complex<float>> || std::is_same_v<T, std::complex<double>>;
    std::complex<double> const reference(2.3, 0.4);

    if constexpr (is_complex) {
        return T(reference);
    } else if constexpr (std::is_floating_point_v<T>) {
        return T(reference.real());
    } else {
        return T(std::llround(reference.real()));
    }
}

} // namespace
35+
36+
template <typename T>
37+
struct SaveNpyTest : public ::testing::Test
38+
{
39+
using data_type = T;
40+
};
41+
42+
// Element types for which the save_npy test is instantiated.
using SaveNpyTypes = ::testing::Types<
        // floating-point
        float,
        double,
        // complex
        std::complex<float>,
        std::complex<double>,
        // plain char (distinct from both signed char and unsigned char)
        char,
        // signed integers
        signed char,
        short,
        int,
        long,
        long long,
        // unsigned integers
        unsigned char,
        unsigned short,
        unsigned int,
        unsigned long,
        unsigned long long>;
58+
59+
// Instantiate the SaveNpyTest cases once per element type listed in SaveNpyTypes.
TYPED_TEST_SUITE(SaveNpyTest, SaveNpyTypes);
60+
61+
// Saves a 2x3x4 host array of the fixture's element type and checks the size
// of the file produced.
TYPED_TEST(SaveNpyTest, SaveNpy)
{
    using data_type = typename TestFixture::data_type;
    std::string const label(typeid(data_type).name());
    Kokkos::View<data_type*, Kokkos::HostSpace> alloc(label, n);
    Kokkos::deep_copy(alloc, make_value<data_type>());
    Kokkos::mdspan view(alloc.data(), ns);

    // One output file per element type, so that instantiations running in
    // parallel do not overwrite each other's file. The type name is
    // implementation-defined, hence the restriction to alphanumeric characters.
    std::string file_stem = "test_save_npy_";
    for (char const c : label) {
        file_stem += (std::isalnum(static_cast<unsigned char>(c)) != 0) ? c : '_';
    }
    std::filesystem::path const filename = file_stem + ".npy";

    ASSERT_NO_THROW(ddc::experimental::save_npy(filename, view));

    // Read the size and clean up before asserting, so a failure leaves no file behind.
    std::uintmax_t const file_size = std::filesystem::file_size(filename);
    std::filesystem::remove(filename);

    // The file is a preamble padded to a multiple of 64 bytes followed by the raw elements.
    std::uintmax_t const payload_size = static_cast<std::uintmax_t>(n) * sizeof(data_type);
    ASSERT_GT(file_size, payload_size);
    EXPECT_EQ((file_size - payload_size) % 64U, 0U);
}

0 commit comments

Comments
 (0)