77from tempfile import gettempdir
88
99from sympy import sympify
10+ from sympy import Basic as SympyBasic
1011import numpy as np
1112
1213from devito .arch import ANYCPU , Device , compiler_registry , platform_registry
4041__all__ = ['Operator' ]
4142
4243
44+ _layers = (disk_layer , host_layer , device_layer )
45+
46+
4347class Operator (Callable ):
4448
4549 """
@@ -597,6 +601,8 @@ def _prepare_arguments(self, autotune=None, **kwargs):
597601
598602 # Prepare to process data-carriers
599603 args = kwargs ['args' ] = ReducerMap ()
604+
605+ # TODO: Add 'estimate-memory' here
600606 kwargs ['metadata' ] = {'language' : self ._language ,
601607 'platform' : self ._platform ,
602608 'transients' : self .transients ,
@@ -644,9 +650,6 @@ def _prepare_arguments(self, autotune=None, **kwargs):
644650 for i in discretizations :
645651 args .update (i ._arg_values (** kwargs ))
646652
647- # TODO: Want to be able to simply stop at this stage and get
648- # the ArgumentsMap for processing
649-
650653 # An ArgumentsMap carries additional metadata that may be used by
651654 # the subsequent phases of the arguments processing
652655 args = kwargs ['args' ] = ArgumentsMap (args , grid , self )
@@ -860,6 +863,38 @@ def cinterface(self, force=False):
860863 def __call__ (self , ** kwargs ):
861864 return self .apply (** kwargs )
862865
866+ def estimate_memory (self , human_readable = True , ** kwargs ):
867+ """
868+ Estimate the memory consumed by the Operator.
869+
870+ TODO: Finish this docstring
871+ """
872+ # Build the arguments list for which to get the memory consumption
873+ # This is so that the estimate will factor in overrides
874+ args = self .arguments (** kwargs )
875+ mem = args .nbytes_avail_mapper
876+
877+ if human_readable :
878+ # TODO: Fill real values in here
879+ # TODO: Format these values to have 3 digits and suitable units
880+
881+ headline = f"Memory consumption for operator `{ self .name } `:"
882+ # Table is 28 characters wide
883+ lpad = " " * ((len (headline ) - 28 ) // 2 )
884+ info (
885+ "\n "
886+ f"{ headline } \n "
887+ f"{ lpad } ┌────────┬────────┬────────┐\n "
888+ f"{ lpad } │ Disk │ Host │ Device │\n "
889+ f"{ lpad } ├────────┼────────┼────────┤\n "
890+ f"{ lpad } │ 1 │ 2 │ 3 │\n "
891+ f"{ lpad } └────────┴────────┴────────┘\n "
892+ )
893+ else :
894+ info (f"{ self .name } { mem [disk_layer ]} { mem [host_layer ]} { mem [device_layer ]} " )
895+
896+ from IPython import embed ; embed ()
897+
863898 def apply (self , ** kwargs ):
864899 """
865900 Execute the Operator.
@@ -1251,6 +1286,7 @@ def nbytes_avail_mapper(self):
12511286 """
12521287 mapper = {}
12531288
1289+ # TODO: This doesn't account for the size of the snapshots?
12541290 # The amount of space available on the disk
12551291 usage = shutil .disk_usage (gettempdir ())
12561292 mapper [disk_layer ] = usage .free
@@ -1267,26 +1303,95 @@ def nbytes_avail_mapper(self):
12671303 nproc = 1
12681304 mapper [host_layer ] = int (ANYCPU .memavail () / nproc )
12691305
1270- for layer , consumed in zip ((host_layer , device_layer ), self .nbytes_consumed ):
1271- mapper [layer ] -= consumed
1306+ for layer in (host_layer , device_layer ):
1307+ try :
1308+ mapper [layer ] -= self .nbytes_consumed_operator [layer ]
1309+ except KeyError :
1310+ continue
12721311
12731312 mapper = {k : int (v ) for k , v in mapper .items ()}
12741313
12751314 return mapper
12761315
12771316 # TODO: This will want some suitable tests in due course
1317+ # TODO: Might want to also check the spillover onto disk
12781318 @cached_property
12791319 def nbytes_consumed (self ):
1280- consumed_host , consumed_device = self .nbytes_consumed_heap
1281- return consumed_host , consumed_device + self .nbytes_consumed_memmapped
1320+ """Memory consumed by all objects in the operator"""
1321+ mem_locations = (
1322+ self .nbytes_consumed_function ,
1323+ self .nbytes_consumed_array ,
1324+ self .nbytes_consumed_memmapped
1325+ )
1326+ return {layer : sum (loc [layer ] for loc in mem_locations ) for layer in _layers }
1327+
1328+ @cached_property
1329+ def nbytes_consumed_operator (self ):
1330+ """Memory consumed by objects allocated within the operator"""
1331+ mem_locations = (
1332+ self .nbytes_consumed_array ,
1333+ self .nbytes_consumed_memmapped
1334+ )
1335+ return {layer : sum (loc [layer ] for loc in mem_locations ) for layer in _layers }
1336+
1337+ @cached_property
1338+ def nbytes_consumed_function (self ):
1339+ """
1340+ Memory consumed on both device and host by Functions in the
1341+ corresponding operator.
1342+ """
1343+ def get_nbytes (obj ):
1344+ if obj .is_regular :
1345+ nbytes = obj .nbytes
1346+ else :
1347+ nbytes = obj .nbytes_max
1348+
1349+ # Could nominally have symbolic nbytes at this point
1350+ if isinstance (nbytes , SympyBasic ):
1351+ return subs_op_args (nbytes , self )
1352+ else :
1353+ return nbytes
1354+
1355+ host = 0
1356+ device = 0
1357+
1358+ # Symbols in the operator which may or may not carry data
1359+ op_symbols = FindSymbols ().visit (self .op )
1360+
1361+ # Filter out arrays, aliases and non-AbstractFunction objects
1362+ op_symbols = [i for i in op_symbols if i .is_AbstractFunction
1363+ and not i .is_Array and not i .alias ]
1364+
1365+ for i in op_symbols :
1366+ # FIXME: Probably wrong for streamed functions
1367+ # Will overreport memory usage currently
1368+ try :
1369+ v = get_nbytes (self [i .name ]._obj )
1370+ except AttributeError :
1371+ v = get_nbytes (i )
1372+
1373+ if i ._mem_host :
1374+ host += v
1375+ elif i ._mem_local :
1376+ if isinstance (self .platform , Device ):
1377+ device += v
1378+ else :
1379+ host += v
1380+ elif i ._mem_mapped :
1381+ if isinstance (self .platform , Device ):
1382+ device += v
1383+ host += v
1384+
1385+ return {disk_layer : 0 , host_layer : host , device_layer : device }
12821386
12831387 @cached_property
1284- def nbytes_consumed_heap (self ):
1388+ def nbytes_consumed_array (self ):
12851389 """
1286- Memory consumed on both device and host by the corresponding operator.
1390+ Memory consumed on both device and host by C-land Arrays
1391+ in the corresponding operator.
12871392 """
1288- host_layer = 0
1289- device_layer = 0
1393+ host = 0
1394+ device = 0
12901395
12911396 # Temporaries such as Arrays are allocated and deallocated on-the-fly
12921397 # while in C land, so they need to be accounted for as well
@@ -1304,26 +1409,26 @@ def nbytes_consumed_heap(self):
13041409 continue
13051410
13061411 if i ._mem_host :
1307- host_layer += v
1412+ host += v
13081413 elif i ._mem_local :
13091414 if isinstance (self .platform , Device ):
1310- device_layer += v
1415+ device += v
13111416 else :
1312- host_layer += v
1417+ host += v
13131418 elif i ._mem_mapped :
13141419 if isinstance (self .platform , Device ):
1315- device_layer += v
1316- host_layer += v
1420+ device += v
1421+ host += v
13171422
1318- return host_layer , device_layer
1423+ return { disk_layer : 0 , host_layer : host , device_layer : device }
13191424
13201425 @cached_property
13211426 def nbytes_consumed_memmapped (self ):
13221427 """
13231428 Memory also consumed on device by data which is to be memcpy-d
13241429 from host to device at the start of computation.
13251430 """
1326- device_layer = 0
1431+ device = 0
13271432 # All input Functions are yet to be memcpy-ed to the device
13281433 # TODO: this may not be true depending on `devicerm`, which is however
13291434 # virtually never used
@@ -1337,11 +1442,11 @@ def nbytes_consumed_memmapped(self):
13371442 v = self [i .name ]._obj .nbytes
13381443 except AttributeError :
13391444 v = i .nbytes
1340- device_layer += v
1445+ device += v
13411446 except AttributeError :
13421447 pass
13431448
1344- return device_layer
1449+ return { disk_layer : 0 , host_layer : 0 , device_layer : device }
13451450
13461451
13471452def parse_kwargs (** kwargs ):
0 commit comments