Skip to content

Commit c3adb50

Browse files
thowellcopybara-github
authored andcommitted
_realloc_island
PiperOrigin-RevId: 941618362 Change-Id: Idef1deb0ff4976de42aa585993988ae3363cd1fc
1 parent 0e58c48 commit c3adb50

2 files changed

Lines changed: 112 additions & 0 deletions

File tree

python/mujoco/bindings_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,49 @@ def test_realloc_con_efc(self):
699699
self.assertEmpty(self.data.contact)
700700
self.assertEmpty(self.data.efc_id)
701701

702+
def test_realloc_island(self):
703+
# Test allocation on fresh data (on its own)
704+
nisland = 2
705+
nidof = 4
706+
mujoco._functions._realloc_island(self.data, nisland=nisland, nidof=nidof)
707+
self.assertEqual(self.data.nisland, nisland)
708+
self.assertEqual(self.data.nidof, nidof)
709+
self.assertEqual(self.data.island_nv.shape, (nisland,))
710+
self.assertEqual(self.data.ifrc_smooth.shape, (nidof,))
711+
712+
# Test allocation after _realloc_con_efc
713+
nefc = 10
714+
mujoco._functions._realloc_con_efc(self.data, ncon=0, nefc=nefc)
715+
716+
nisland = 3
717+
nidof = 5
718+
mujoco._functions._realloc_island(self.data, nisland=nisland, nidof=nidof)
719+
720+
self.assertEqual(self.data.nisland, nisland)
721+
self.assertEqual(self.data.nidof, nidof)
722+
self.assertEqual(self.data.island_nv.shape, (nisland,))
723+
self.assertEqual(self.data.ifrc_smooth.shape, (nidof,))
724+
725+
# Test re-allocation (calling it again with different sizes)
726+
nisland2 = 4
727+
nidof2 = 6
728+
mujoco._functions._realloc_island(self.data, nisland=nisland2, nidof=nidof2)
729+
730+
self.assertEqual(self.data.nisland, nisland2)
731+
self.assertEqual(self.data.nidof, nidof2)
732+
self.assertEqual(self.data.island_nv.shape, (nisland2,))
733+
self.assertEqual(self.data.ifrc_smooth.shape, (nidof2,))
734+
735+
# Test insufficient memory handling
736+
expected_error = (
737+
r'Insufficient arena memory, currently allocated memory=' +
738+
r'"[0-9]+[A-Z]?". Increase using <size memory="X"/>.'
739+
)
740+
with self.assertRaisesRegex(mujoco.FatalError, expected_error):
741+
mujoco._functions._realloc_island(self.data, 100000000, 100000000)
742+
self.assertEqual(self.data.nisland, 0)
743+
self.assertEqual(self.data.nidof, 0)
744+
702745
def test_mj_struct_list_equality(self):
703746
model2 = mujoco.MjModel.from_xml_string(TEST_XML)
704747
data2 = mujoco.MjData(model2)

python/mujoco/functions.cc

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,6 +1696,7 @@ PYBIND11_MODULE(_functions, pymodule, pybind11::mod_gil_not_used()) {
16961696
#define X(type, name, nr, nc) data->name = nullptr;
16971697
MJDATA_ARENA_POINTERS_SOLVER
16981698
MJDATA_ARENA_POINTERS_DUAL
1699+
MJDATA_ARENA_POINTERS_ISLAND
16991700
#undef X
17001701
};
17011702

@@ -1744,6 +1745,74 @@ PYBIND11_MODULE(_functions, pymodule, pybind11::mod_gil_not_used()) {
17441745
},
17451746
py::arg("d"), py::arg("ncon"), py::arg("nefc"), py::arg("nJ") = -1,
17461747
py::call_guard<py::gil_scoped_release>());
1748+
1749+
pymodule.def(
1750+
"_realloc_island",
1751+
[](MjDataWrapper& d, int nisland, int nidof) {
1752+
raw::MjData* data = d.get();
1753+
1754+
size_t parena_start = data->parena;
1755+
// Find island block start in arena to reclaim memory on re-allocation.
1756+
char* min_ptr = nullptr;
1757+
#define X(type, name, nr, nc) \
1758+
if (data->name && \
1759+
(!min_ptr || reinterpret_cast<char*>(data->name) < min_ptr)) { \
1760+
min_ptr = reinterpret_cast<char*>(data->name); \
1761+
}
1762+
MJDATA_ARENA_POINTERS_ISLAND
1763+
#undef X
1764+
if (min_ptr && data->arena) {
1765+
parena_start = min_ptr - static_cast<char*>(data->arena);
1766+
}
1767+
1768+
auto cleanup = [](raw::MjData* data, size_t target_parena) {
1769+
#define X(type, name, nr, nc) data->name = nullptr;
1770+
MJDATA_ARENA_POINTERS_ISLAND
1771+
#undef X
1772+
data->nisland = 0;
1773+
data->nidof = 0;
1774+
data->parena = target_parena;
1775+
#ifdef ADDRESS_SANITIZER
1776+
ASAN_POISON_MEMORY_REGION(
1777+
static_cast<char*>(data->arena) + target_parena,
1778+
data->narena - data->pstack - target_parena);
1779+
#endif
1780+
};
1781+
1782+
cleanup(data, parena_start);
1783+
1784+
char error_msg[128];
1785+
error_msg[0] = '\0';
1786+
const char* error_msg_fmt =
1787+
"Insufficient arena memory, currently allocated memory=\"%s\". "
1788+
"Increase using <size memory=\"X\"/>.";
1789+
1790+
data->nisland = nisland;
1791+
data->nidof = nidof;
1792+
1793+
#undef MJ_M
1794+
#define MJ_M(x) d.model().get()->x
1795+
#undef MJ_D
1796+
#define MJ_D(x) data->x
1797+
#define X(type, name, nr, nc) \
1798+
data->name = static_cast<type*>(InterceptMjErrors(::mj_arenaAllocByte)( \
1799+
data, sizeof(type) * (nr) * (nc), alignof(type))); \
1800+
if (!data->name) { \
1801+
cleanup(data, parena_start); \
1802+
std::snprintf(error_msg, sizeof(error_msg), error_msg_fmt, \
1803+
mju_writeNumBytes(data->narena)); \
1804+
throw FatalError(error_msg); \
1805+
}
1806+
1807+
MJDATA_ARENA_POINTERS_ISLAND
1808+
#undef X
1809+
#undef MJ_D
1810+
#define MJ_D(x) x
1811+
#undef MJ_M
1812+
#define MJ_M(x) x
1813+
},
1814+
py::arg("d"), py::arg("nisland"), py::arg("nidof"),
1815+
py::call_guard<py::gil_scoped_release>());
17471816
} // PYBIND11_MODULE NOLINT(readability/fn_size)
17481817
} // namespace
17491818
} // namespace mujoco::python

0 commit comments

Comments
 (0)