@@ -644,6 +644,9 @@ def _prepare_arguments(self, autotune=None, **kwargs):
644644 for i in discretizations :
645645 args .update (i ._arg_values (** kwargs ))
646646
647+ # TODO: Make it possible to stop at this stage and simply return
648+ # the ArgumentsMap for further processing
649+
647650 # An ArgumentsMap carries additional metadata that may be used by
648651 # the subsequent phases of the arguments processing
649652 args = kwargs ['args' ] = ArgumentsMap (args , grid , self )
@@ -1264,6 +1267,27 @@ def nbytes_avail_mapper(self):
12641267 nproc = 1
12651268 mapper [host_layer ] = int (ANYCPU .memavail () / nproc )
12661269
1270+ for layer , consumed in zip ((host_layer , device_layer ), self .nbytes_consumed ):
1271+ mapper [layer ] -= consumed
1272+
1273+ mapper = {k : int (v ) for k , v in mapper .items ()}
1274+
1275+ return mapper
1276+
1277+ # TODO: This will want some suitable tests in due course
@cached_property
def nbytes_consumed(self):
    """
    Bytes consumed by the operator, as a ``(host, device)`` pair.

    The host figure is the host entry of `nbytes_consumed_heap`; the
    device figure is the device entry of `nbytes_consumed_heap` plus
    `nbytes_consumed_memmapped`.
    """
    heap = self.nbytes_consumed_heap
    on_host, on_device = heap

    # Data still to be copied over also counts against the device
    staged = self.nbytes_consumed_memmapped

    return (on_host, on_device + staged)
1282+
1283+ @cached_property
1284+ def nbytes_consumed_heap (self ):
1285+ """
1286+ Memory consumed on the host and on the device, in this order, by the corresponding operator.
1287+ """
1288+ host_layer = 0
1289+ device_layer = 0
1290+
12671291 # Temporaries such as Arrays are allocated and deallocated on-the-fly
12681292 # while in C land, so they need to be accounted for as well
12691293 for i in FindSymbols ().visit (self .op ):
@@ -1280,17 +1304,26 @@ def nbytes_avail_mapper(self):
12801304 continue
12811305
12821306 if i ._mem_host :
1283- mapper [ host_layer ] - = v
1307+ host_layer + = v
12841308 elif i ._mem_local :
12851309 if isinstance (self .platform , Device ):
1286- mapper [ device_layer ] - = v
1310+ device_layer + = v
12871311 else :
1288- mapper [ host_layer ] - = v
1312+ host_layer + = v
12891313 elif i ._mem_mapped :
12901314 if isinstance (self .platform , Device ):
1291- mapper [device_layer ] -= v
1292- mapper [host_layer ] -= v
1315+ device_layer += v
1316+ host_layer += v
1317+
1318+ return host_layer , device_layer
12931319
1320+ @cached_property
1321+ def nbytes_consumed_memmapped (self ):
1322+ """
1323+ Additional memory consumed on the device by data that is yet to be
1324+ memcpy-ed from host to device at the start of the computation.
1325+ """
1326+ device_layer = 0
12941327 # All input Functions are yet to be memcpy-ed to the device
12951328 # TODO: this may not be true depending on `devicerm`, which is however
12961329 # virtually never used
@@ -1304,13 +1337,11 @@ def nbytes_avail_mapper(self):
13041337 v = self [i .name ]._obj .nbytes
13051338 except AttributeError :
13061339 v = i .nbytes
1307- mapper [ device_layer ] - = v
1340+ device_layer + = v
13081341 except AttributeError :
13091342 pass
13101343
1311- mapper = {k : int (v ) for k , v in mapper .items ()}
1312-
1313- return mapper
1344+ return device_layer
13141345
13151346
13161347def parse_kwargs (** kwargs ):
0 commit comments