Skip to content

Commit e97d597

Browse files
committed
compiler: Start adding estimate_memory utility
1 parent 4d951cc commit e97d597

1 file changed

Lines changed: 125 additions & 20 deletions

File tree

devito/operator/operator.py

Lines changed: 125 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from tempfile import gettempdir
88

99
from sympy import sympify
10+
from sympy import Basic as SympyBasic
1011
import numpy as np
1112

1213
from devito.arch import ANYCPU, Device, compiler_registry, platform_registry
@@ -40,6 +41,9 @@
4041
__all__ = ['Operator']
4142

4243

44+
_layers = (disk_layer, host_layer, device_layer)
45+
46+
4347
class Operator(Callable):
4448

4549
"""
@@ -597,6 +601,8 @@ def _prepare_arguments(self, autotune=None, **kwargs):
597601

598602
# Prepare to process data-carriers
599603
args = kwargs['args'] = ReducerMap()
604+
605+
# TODO: Add 'estimate-memory' here
600606
kwargs['metadata'] = {'language': self._language,
601607
'platform': self._platform,
602608
'transients': self.transients,
@@ -644,9 +650,6 @@ def _prepare_arguments(self, autotune=None, **kwargs):
644650
for i in discretizations:
645651
args.update(i._arg_values(**kwargs))
646652

647-
# TODO: Want to be able to simply stop at this stage and get
648-
# the ArgumentsMap for processing
649-
650653
# An ArgumentsMap carries additional metadata that may be used by
651654
# the subsequent phases of the arguments processing
652655
args = kwargs['args'] = ArgumentsMap(args, grid, self)
@@ -860,6 +863,38 @@ def cinterface(self, force=False):
860863
def __call__(self, **kwargs):
861864
return self.apply(**kwargs)
862865

866+
def estimate_memory(self, human_readable=True, **kwargs):
867+
"""
868+
Estimate the memory consumed by the Operator.
869+
870+
TODO: Finish this docstring
871+
"""
872+
# Build the arguments list for which to get the memory consumption
873+
# This is so that the estimate will factor in overrides
874+
args = self.arguments(**kwargs)
875+
mem = args.nbytes_avail_mapper
876+
877+
if human_readable:
878+
# TODO: Fill real values in here
879+
# TODO: Format these values to have 3 digits and suitable units
880+
881+
headline = f"Memory consumption for operator `{self.name}`:"
882+
# Table is 28 characters wide
883+
lpad = " "*((len(headline) - 28) // 2)
884+
info(
885+
"\n"
886+
f"{headline}\n"
887+
f"{lpad}┌────────┬────────┬────────┐\n"
888+
f"{lpad}│ Disk │ Host │ Device │\n"
889+
f"{lpad}├────────┼────────┼────────┤\n"
890+
f"{lpad}│ 1 │ 2 │ 3 │\n"
891+
f"{lpad}└────────┴────────┴────────┘\n"
892+
)
893+
else:
894+
info(f"{self.name} {mem[disk_layer]} {mem[host_layer]} {mem[device_layer]}")
895+
896+
from IPython import embed; embed()
897+
863898
def apply(self, **kwargs):
864899
"""
865900
Execute the Operator.
@@ -1251,6 +1286,7 @@ def nbytes_avail_mapper(self):
12511286
"""
12521287
mapper = {}
12531288

1289+
# TODO: This doesn't account for the size of the snapshots?
12541290
# The amount of space available on the disk
12551291
usage = shutil.disk_usage(gettempdir())
12561292
mapper[disk_layer] = usage.free
@@ -1267,26 +1303,95 @@ def nbytes_avail_mapper(self):
12671303
nproc = 1
12681304
mapper[host_layer] = int(ANYCPU.memavail() / nproc)
12691305

1270-
for layer, consumed in zip((host_layer, device_layer), self.nbytes_consumed):
1271-
mapper[layer] -= consumed
1306+
for layer in (host_layer, device_layer):
1307+
try:
1308+
mapper[layer] -= self.nbytes_consumed_operator[layer]
1309+
except KeyError:
1310+
continue
12721311

12731312
mapper = {k: int(v) for k, v in mapper.items()}
12741313

12751314
return mapper
12761315

12771316
# TODO: This will want some suitable tests in due course
1317+
# TODO: Might want to also check the spillover onto disk
12781318
@cached_property
12791319
def nbytes_consumed(self):
1280-
consumed_host, consumed_device = self.nbytes_consumed_heap
1281-
return consumed_host, consumed_device + self.nbytes_consumed_memmapped
1320+
"""Memory consumed by all objects in the operator"""
1321+
mem_locations = (
1322+
self.nbytes_consumed_function,
1323+
self.nbytes_consumed_array,
1324+
self.nbytes_consumed_memmapped
1325+
)
1326+
return {layer: sum(loc[layer] for loc in mem_locations) for layer in _layers}
1327+
1328+
@cached_property
1329+
def nbytes_consumed_operator(self):
1330+
"""Memory consumed by objects allocated within the operator"""
1331+
mem_locations = (
1332+
self.nbytes_consumed_array,
1333+
self.nbytes_consumed_memmapped
1334+
)
1335+
return {layer: sum(loc[layer] for loc in mem_locations) for layer in _layers}
1336+
1337+
@cached_property
1338+
def nbytes_consumed_function(self):
1339+
"""
1340+
Memory consumed on both device and host by Functions in the
1341+
corresponding operator.
1342+
"""
1343+
def get_nbytes(obj):
1344+
if obj.is_regular:
1345+
nbytes = obj.nbytes
1346+
else:
1347+
nbytes = obj.nbytes_max
1348+
1349+
# Could nominally have symbolic nbytes at this point
1350+
if isinstance(nbytes, SympyBasic):
1351+
return subs_op_args(nbytes, self)
1352+
else:
1353+
return nbytes
1354+
1355+
host = 0
1356+
device = 0
1357+
1358+
# Symbols in the operator which may or may not carry data
1359+
op_symbols = FindSymbols().visit(self.op)
1360+
1361+
# Filter out arrays, aliases and non-AbstractFunction objects
1362+
op_symbols = [i for i in op_symbols if i.is_AbstractFunction
1363+
and not i.is_Array and not i.alias]
1364+
1365+
for i in op_symbols:
1366+
# FIXME: Probably wrong for streamed functions
1367+
# Will overreport memory usage currently
1368+
try:
1369+
v = get_nbytes(self[i.name]._obj)
1370+
except AttributeError:
1371+
v = get_nbytes(i)
1372+
1373+
if i._mem_host:
1374+
host += v
1375+
elif i._mem_local:
1376+
if isinstance(self.platform, Device):
1377+
device += v
1378+
else:
1379+
host += v
1380+
elif i._mem_mapped:
1381+
if isinstance(self.platform, Device):
1382+
device += v
1383+
host += v
1384+
1385+
return {disk_layer: 0, host_layer: host, device_layer: device}
12821386

12831387
@cached_property
1284-
def nbytes_consumed_heap(self):
1388+
def nbytes_consumed_array(self):
12851389
"""
1286-
Memory consumed on both device and host by the corresponding operator.
1390+
Memory consumed on both device and host by C-land Arrays
1391+
in the corresponding operator.
12871392
"""
1288-
host_layer = 0
1289-
device_layer = 0
1393+
host = 0
1394+
device = 0
12901395

12911396
# Temporaries such as Arrays are allocated and deallocated on-the-fly
12921397
# while in C land, so they need to be accounted for as well
@@ -1304,26 +1409,26 @@ def nbytes_consumed_heap(self):
13041409
continue
13051410

13061411
if i._mem_host:
1307-
host_layer += v
1412+
host += v
13081413
elif i._mem_local:
13091414
if isinstance(self.platform, Device):
1310-
device_layer += v
1415+
device += v
13111416
else:
1312-
host_layer += v
1417+
host += v
13131418
elif i._mem_mapped:
13141419
if isinstance(self.platform, Device):
1315-
device_layer += v
1316-
host_layer += v
1420+
device += v
1421+
host += v
13171422

1318-
return host_layer, device_layer
1423+
return {disk_layer: 0, host_layer: host, device_layer: device}
13191424

13201425
@cached_property
13211426
def nbytes_consumed_memmapped(self):
13221427
"""
13231428
Memory also consumed on device by data which is to be memcpy-d
13241429
from host to device at the start of computation.
13251430
"""
1326-
device_layer = 0
1431+
device = 0
13271432
# All input Functions are yet to be memcpy-ed to the device
13281433
# TODO: this may not be true depending on `devicerm`, which is however
13291434
# virtually never used
@@ -1337,11 +1442,11 @@ def nbytes_consumed_memmapped(self):
13371442
v = self[i.name]._obj.nbytes
13381443
except AttributeError:
13391444
v = i.nbytes
1340-
device_layer += v
1445+
device += v
13411446
except AttributeError:
13421447
pass
13431448

1344-
return device_layer
1449+
return {disk_layer: 0, host_layer: 0, device_layer: device}
13451450

13461451

13471452
def parse_kwargs(**kwargs):

0 commit comments

Comments
 (0)