Skip to content

Commit 1ce7118

Browse files
authored
Fix ref leak in mx.save/load with file like object (#3187)
1 parent 72e04f7 commit 1ce7118

1 file changed

Lines changed: 12 additions & 5 deletions

File tree

python/src/load.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,12 @@ class PyFileReader : public mx::io::Reader {
162162

163163
private:
164164
void _read(char* data, size_t n) {
165-
auto memview = PyMemoryView_FromMemory(data, n, PyBUF_WRITE);
166-
nb::object bytes_read = readinto_func_(nb::handle(memview));
165+
nb::object memview =
166+
nb::steal<nb::object>(PyMemoryView_FromMemory(data, n, PyBUF_WRITE));
167+
if (!memview.is_valid()) {
168+
throw std::runtime_error("[load] Failed to create memoryview for read");
169+
}
170+
nb::object bytes_read = readinto_func_(memview);
167171

168172
if (bytes_read.is_none() || nb::cast<size_t>(bytes_read) < n) {
169173
throw std::runtime_error("[load] Failed to read from python stream");
@@ -374,9 +378,12 @@ class PyFileWriter : public mx::io::Writer {
374378
void write(const char* data, size_t n) override {
375379
nb::gil_scoped_acquire gil;
376380

377-
auto memview =
378-
PyMemoryView_FromMemory(const_cast<char*>(data), n, PyBUF_READ);
379-
nb::object bytes_written = write_func_(nb::handle(memview));
381+
nb::object memview = nb::steal<nb::object>(
382+
PyMemoryView_FromMemory(const_cast<char*>(data), n, PyBUF_READ));
383+
if (!memview.is_valid()) {
384+
throw std::runtime_error("[load] Failed to create memoryview for write");
385+
}
386+
nb::object bytes_written = write_func_(memview);
380387

381388
if (bytes_written.is_none() || nb::cast<size_t>(bytes_written) < n) {
382389
throw std::runtime_error("[load] Failed to write to python stream");

0 commit comments

Comments
 (0)