Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions python_bindings/src/halide/halide_/PyBuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,6 @@ Halide::Runtime::Buffer<T, Dims, InClassDimStorage> pybufferinfo_to_halidebuffer
return Halide::Runtime::Buffer<T, Dims, InClassDimStorage>(t, info.ptr, (int)info.ndim, dims);
}

template<typename T = void,
Comment thread
alexreinking marked this conversation as resolved.
int Dims = AnyDims,
int InClassDimStorage = (Dims == AnyDims ? 4 : std::max(Dims, 1))>
Halide::Runtime::Buffer<T, Dims, InClassDimStorage> pybuffer_to_halidebuffer(const py::buffer &pyb, bool writable, bool reverse_axes) {
return pybufferinfo_to_halidebuffer(pyb.request(writable), reverse_axes);
}

} // namespace PythonBindings
} // namespace Halide

Expand Down
39 changes: 36 additions & 3 deletions python_bindings/src/halide/halide_/PyCallable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,28 @@ T cast_to(const py::handle &h) {
}
}

std::pair<bool, bool> is_any_contiguous(const py::buffer_info &info) {
py::ssize_t c_stride = info.itemsize;
py::ssize_t f_stride = info.itemsize;
bool c_contig = true;
bool f_contig = true;

for (size_t i = 0; i < info.ndim; ++i) {
size_t c_idx = info.ndim - 1 - i;
if (info.strides[c_idx] != c_stride) {
c_contig = false;
}
c_stride *= info.shape[c_idx];

if (info.strides[i] != f_stride) {
f_contig = false;
}
f_stride *= info.shape[i];
}

return {c_contig, f_contig};
}

} // namespace

class PyCallable {
Expand Down Expand Up @@ -92,11 +114,22 @@ class PyCallable {
auto b = cast_to<Halide::Buffer<>>(value);
raw_buffer = b.raw_buffer();
} else {
py::buffer py_buffer_value = cast_to<py::buffer>(value);
const bool writable = c_arg.is_output();
const bool reverse_axes = true;

const py::buffer_info value_buffer_info = py_buffer_value.request(writable);
auto [c_contig, f_contig] = is_any_contiguous(value_buffer_info);

if (!c_contig && !f_contig) {
throw Halide::Error("Invalid buffer: only C or F contiguous buffers are supported");
}

// It is possible for a buffer to be both C and F contiguous
// (e.g., a scalar or a 1D buffer).
const bool reverse_axes = c_contig && !f_contig;
Comment thread
jiawen marked this conversation as resolved.
buffers.buffers[slot] =
pybuffer_to_halidebuffer<void, AnyDims, MaxFastDimensions>(
cast_to<py::buffer>(value), writable, reverse_axes);
pybufferinfo_to_halidebuffer<void, AnyDims, MaxFastDimensions>(
value_buffer_info, reverse_axes);
raw_buffer = buffers.buffers[slot].raw_buffer();
}
// Mark all input buffers as having a dirty host, so that the Halide call will
Expand Down
68 changes: 67 additions & 1 deletion python_bindings/test/correctness/callable.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import halide as hl
import numpy as np
Comment thread
alexreinking marked this conversation as resolved.

from simplepy_generator import SimplePy
import simplecpp_pystub # noqa: F401 - needed for create_callable_from_generator("simplecpp") to work
Expand Down Expand Up @@ -169,11 +170,74 @@ def _check(offset=0):
assert False, "Did not see expected exception!"


def test_callable_buffer_conventions():
# Make a generator that echoes the extents and strides of its input buffer
# to two output buffers.
@hl.generator(name="echo_dims")
class EchoDims:
input = hl.InputBuffer(hl.Int(32), 3)
output_extents = hl.OutputBuffer(hl.Int(32), 1)
output_strides = hl.OutputBuffer(hl.Int(32), 1)

def generate(self):
g = self
d = hl.Var("d")
g.output_extents[d] = hl.mux(d, [
g.input.dim(0).extent(),
g.input.dim(1).extent(),
g.input.dim(2).extent()
])
g.output_strides[d] = hl.mux(d, [
g.input.dim(0).stride(),
g.input.dim(1).stride(),
g.input.dim(2).stride()
])

with hl.GeneratorContext(hl.Target("host-debug")):
echo_dims = EchoDims()
echo_dims_callable = echo_dims.compile_to_callable()

output_extents = hl.Buffer(hl.Int(32), [3])
output_strides = hl.Buffer(hl.Int(32), [3])

# C-contiguous input reverses dimensions.
# Note that numpy defaults to `order='C'`.
input_c = np.zeros((16, 12, 3), dtype=np.int32, order='C')
echo_dims_callable(input_c, output_extents, output_strides)
assert output_extents[0] == 3
assert output_extents[1] == 12
assert output_extents[2] == 16
assert output_strides[0] == 1
assert output_strides[1] == 3
assert output_strides[2] == 36

# F-contiguous input preserves dimensions.
input_f = np.zeros((16, 12, 3), dtype=np.int32, order='F')
echo_dims_callable(input_f, output_extents, output_strides)
assert output_extents[0] == 16
assert output_extents[1] == 12
assert output_extents[2] == 3
assert output_strides[0] == 1
assert output_strides[1] == 16
assert output_strides[2] == 192

# Non-contiguous inputs are rejected.
input_noncontig = np.zeros((16, 12, 3), dtype=np.int32)
input_noncontig = np.transpose(input_noncontig, (1, 0, 2))
try:
echo_dims_callable(input_noncontig, output_extents, output_strides)
except hl.HalideError as e:
assert "Invalid buffer: only C or F contiguous buffers are supported" in str(e)
else:
assert False, "Did not see expected exception!"


if __name__ == "__main__":
# test_callable()

def via_simplecpp_pystub(target, generator_params):
return hl.create_callable_from_generator(target, "simplecpp", generator_params)
return hl.create_callable_from_generator(target, "simplecpp",
generator_params)

def via_simplepy(target, generator_params):
with hl.GeneratorContext(target):
Expand All @@ -182,3 +246,5 @@ def via_simplepy(target, generator_params):

test_simple(via_simplecpp_pystub)
test_simple(via_simplepy)

test_callable_buffer_conventions()
Loading