Skip to content

Commit 4d951cc

Browse files
committed
compiler: Refactor nbytes_avail_mapper to enable future code reuse
1 parent 98a14a4 commit 4d951cc

1 file changed

Lines changed: 40 additions & 9 deletions

File tree

devito/operator/operator.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,9 @@ def _prepare_arguments(self, autotune=None, **kwargs):
644644
for i in discretizations:
645645
args.update(i._arg_values(**kwargs))
646646

647+
# TODO: It should be possible to stop at this stage and return
648+
# the ArgumentsMap for further processing
649+
647650
# An ArgumentsMap carries additional metadata that may be used by
648651
# the subsequent phases of the arguments processing
649652
args = kwargs['args'] = ArgumentsMap(args, grid, self)
@@ -1264,6 +1267,27 @@ def nbytes_avail_mapper(self):
12641267
nproc = 1
12651268
mapper[host_layer] = int(ANYCPU.memavail() / nproc)
12661269

1270+
for layer, consumed in zip((host_layer, device_layer), self.nbytes_consumed):
1271+
mapper[layer] -= consumed
1272+
1273+
mapper = {k: int(v) for k, v in mapper.items()}
1274+
1275+
return mapper
1276+
1277+
# TODO: Add suitable tests for the `nbytes_consumed*` properties in due course
1278+
@cached_property
1279+
def nbytes_consumed(self):
1280+
consumed_host, consumed_device = self.nbytes_consumed_heap
1281+
return consumed_host, consumed_device + self.nbytes_consumed_memmapped
1282+
1283+
@cached_property
1284+
def nbytes_consumed_heap(self):
1285+
"""
1286+
Memory consumed by the corresponding operator, returned as a `(host, device)` tuple.
1287+
"""
1288+
host_layer = 0
1289+
device_layer = 0
1290+
12671291
# Temporaries such as Arrays are allocated and deallocated on-the-fly
12681292
# while in C land, so they need to be accounted for as well
12691293
for i in FindSymbols().visit(self.op):
@@ -1280,17 +1304,26 @@ def nbytes_avail_mapper(self):
12801304
continue
12811305

12821306
if i._mem_host:
1283-
mapper[host_layer] -= v
1307+
host_layer += v
12841308
elif i._mem_local:
12851309
if isinstance(self.platform, Device):
1286-
mapper[device_layer] -= v
1310+
device_layer += v
12871311
else:
1288-
mapper[host_layer] -= v
1312+
host_layer += v
12891313
elif i._mem_mapped:
12901314
if isinstance(self.platform, Device):
1291-
mapper[device_layer] -= v
1292-
mapper[host_layer] -= v
1315+
device_layer += v
1316+
host_layer += v
1317+
1318+
return host_layer, device_layer
12931319

1320+
@cached_property
1321+
def nbytes_consumed_memmapped(self):
1322+
"""
1323+
Additional memory consumed on the device by data that is yet to be memcpy-ed
1324+
from the host to the device at the start of the computation.
1325+
"""
1326+
device_layer = 0
12941327
# All input Functions are yet to be memcpy-ed to the device
12951328
# TODO: this may not be true depending on `devicerm`, which is however
12961329
# virtually never used
@@ -1304,13 +1337,11 @@ def nbytes_avail_mapper(self):
13041337
v = self[i.name]._obj.nbytes
13051338
except AttributeError:
13061339
v = i.nbytes
1307-
mapper[device_layer] -= v
1340+
device_layer += v
13081341
except AttributeError:
13091342
pass
13101343

1311-
mapper = {k: int(v) for k, v in mapper.items()}
1312-
1313-
return mapper
1344+
return device_layer
13141345

13151346

13161347
def parse_kwargs(**kwargs):

0 commit comments

Comments
 (0)