22#include < pybind11/pybind11.h>
33
44#include < cstdint>
5- #include < stdexcept>
65
76#include " mandelbrot.h"
87
1110
1211namespace py = pybind11;
1312
14- namespace {
15-
16- using Image = py::array_t <std::int32_t , py::array::c_style>;
17-
18- // Shared wrapper: validate the arguments, allocate the output image, and run one
19- // of the compute functions with the GIL released. Returns a (height, width)
20- // NumPy array of escape counts.
21- Image render (void (*compute)(const MandelbrotParams &, std::int32_t *), int width, int height,
22- int max_iterations, double xmin, double xmax, double ymin, double ymax) {
23- if (width <= 0 || height <= 0 ) {
24- throw std::invalid_argument (" width and height must be positive" );
25- }
26- if (max_iterations <= 0 ) {
27- throw std::invalid_argument (" max_iterations must be positive" );
28- }
29-
30- const MandelbrotParams params{width, height, max_iterations, xmin, xmax, ymin, ymax};
31- Image image ({height, width});
32- std::int32_t *data = image.mutable_data ();
33-
34- {
35- py::gil_scoped_release release;
36- compute (params, data);
37- }
38- return image;
39- }
40-
41- } // namespace
42-
4313PYBIND11_MODULE (_core, m, py::mod_gil_not_used(), py::multiple_interpreters::per_interpreter_gil()) {
4414 m.doc () = R"pbdoc(
4515 Pybind11 + CUDA Mandelbrot example
@@ -55,35 +25,38 @@ PYBIND11_MODULE(_core, m, py::mod_gil_not_used(), py::multiple_interpreters::per
5525 cuda_available
5626 )pbdoc" ;
5727
58- const char *doc = R"pbdoc(
59- Render the Mandelbrot set.
60-
61- Returns a ``(height, width)`` int32 NumPy array; each value is the number
62- of iterations before the point escaped (``max_iterations`` if it never
63- did).
64- )pbdoc" ;
65-
66- m.def (" mandelbrot_cpu" ,
67- [](int width, int height, int max_iterations, double xmin, double xmax, double ymin,
68- double ymax) {
69- return render (&mandelbrot_cpu, width, height, max_iterations, xmin, xmax, ymin, ymax);
70- },
71- py::arg (" width" ) = 800 , py::arg (" height" ) = 600 , py::arg (" max_iterations" ) = 100 ,
72- py::arg (" xmin" ) = -2.0 , py::arg (" xmax" ) = 1.0 , py::arg (" ymin" ) = -1.25 ,
73- py::arg (" ymax" ) = 1.25 , doc);
74-
75- m.def (" mandelbrot_gpu" ,
76- [](int width, int height, int max_iterations, double xmin, double xmax, double ymin,
77- double ymax) {
78- return render (&mandelbrot_gpu, width, height, max_iterations, xmin, xmax, ymin, ymax);
79- },
80- py::arg (" width" ) = 800 , py::arg (" height" ) = 600 , py::arg (" max_iterations" ) = 100 ,
81- py::arg (" xmin" ) = -2.0 , py::arg (" xmax" ) = 1.0 , py::arg (" ymin" ) = -1.25 ,
82- py::arg (" ymax" ) = 1.25 , doc);
83-
84- m.def (" cuda_available" , &cuda_available, R"pbdoc(
85- Return True if a CUDA-capable device is available at runtime.
86- )pbdoc" );
28+ m.def (
29+ " mandelbrot_cpu" ,
30+ [](int width, int height, int max_iterations) {
31+ // Allocate the (height, width) output image and fill it on the CPU,
32+ // releasing the GIL while the C++ code runs.
33+ py::array_t <std::int32_t > image ({height, width});
34+ std::int32_t *data = image.mutable_data ();
35+ {
36+ py::gil_scoped_release release;
37+ mandelbrot_cpu (width, height, max_iterations, data);
38+ }
39+ return image;
40+ },
41+ py::arg (" width" ) = 800 , py::arg (" height" ) = 600 , py::arg (" max_iterations" ) = 100 ,
42+ " Render the Mandelbrot set on the CPU, returning a (height, width) int32 array." );
43+
44+ m.def (
45+ " mandelbrot_gpu" ,
46+ [](int width, int height, int max_iterations) {
47+ py::array_t <std::int32_t > image ({height, width});
48+ std::int32_t *data = image.mutable_data ();
49+ {
50+ py::gil_scoped_release release;
51+ mandelbrot_gpu (width, height, max_iterations, data);
52+ }
53+ return image;
54+ },
55+ py::arg (" width" ) = 800 , py::arg (" height" ) = 600 , py::arg (" max_iterations" ) = 100 ,
56+ " Render the Mandelbrot set on the GPU, returning a (height, width) int32 array." );
57+
58+ m.def (" cuda_available" , &cuda_available,
59+ " Return True if a CUDA-capable device is available at runtime." );
8760
8861#ifdef VERSION_INFO
8962 m.attr (" __version__" ) = MACRO_STRINGIFY (VERSION_INFO );
0 commit comments