Skip to content

Commit 8e78519

Browse files
committed
Add TrainingMemoryScheduler to keep input tensors alive for the entire tiling schedule
1 parent 8f3df5a commit 8e78519

2 files changed

Lines changed: 31 additions & 4 deletions

File tree

DeeployTest/testMVPTraining.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from testUtils.codeGenerate import generateTrainingTestNetwork
1515
from testUtils.platformMapping import mapDeployer, mapPlatform, setupMemoryPlatform
1616
from testUtils.testRunner import TestGeneratorArgumentParser
17-
from testUtils.tilingUtils import SBTiler
17+
from testUtils.tilingUtils import TrainingSBTiler
1818
from testUtils.typeMapping import inferTypeAndOffset
1919

2020
from Deeploy.AbstractDataTypes import PointerClass
@@ -207,11 +207,11 @@ def generateTiledTrainingNetwork(args) -> None:
207207
AnnotateDefaultMemoryLevel(memoryHierarchy),
208208
])
209209

210-
# 9. Wrap with tiler (SBTiler only — DBTiler conflicts with InPlaceAccumulatorV2 alias).
210+
# 9. Wrap with tiler (TrainingSBTiler: SB strategy + extended input lifetimes for backward pass).
211211
unique_params = f"{args.dumpdir}_L1{args.l1}_L2{args.l2}_{args.defaultMemLevel}"
212212
testIdentifier = hashlib.md5(unique_params.encode()).hexdigest()[:16]
213213

214-
deployer = TilerDeployerWrapper(deployer, SBTiler, testName=testIdentifier, workDir=args.dumpdir)
214+
deployer = TilerDeployerWrapper(deployer, TrainingSBTiler, testName=testIdentifier, workDir=args.dumpdir)
215215
deployer.tiler.visualizeMemoryAlloc = args.plotMemAlloc
216216
deployer.tiler.memoryAllocStrategy = args.memAllocStrategy
217217
deployer.tiler.searchStrategy = args.searchStrategy

DeeployTest/testUtils/tilingUtils.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from typing import List, Union
5+
from typing import Dict, List, Optional, Tuple, Union
66

77
from ortools.constraint_solver.pywrapcp import IntVar
88

99
from Deeploy.DeeployTypes import NetworkContext, SubGraph, TransientBuffer
10+
from Deeploy.TilingExtension.MemoryConstraints import PatternMemoryConstraints
11+
from Deeploy.TilingExtension.MemoryScheduler import MemoryScheduler
1012
from Deeploy.TilingExtension.TilerExtension import Tiler
1113
from Deeploy.TilingExtension.TilerModel import TilerModel
1214

@@ -43,3 +45,28 @@ class SBTiler(Tiler):
4345
def multiBufferStrategy(self, tilerModel: TilerModel, ctxt: NetworkContext, pattern: SubGraph,
                        path: List[str], hop: str, tensorName: str) -> Union[int, IntVar]:
    """Return the buffering factor for ``tensorName``.

    None of the arguments are inspected; the result is the constant ``1``
    for every tensor (presumably one buffer per tensor, i.e. single
    buffering -- confirm against the base ``Tiler`` contract).
    """
    bufferCount = 1
    return bufferCount
48+
49+
50+
class TrainingMemoryScheduler(MemoryScheduler):
    """Memory scheduler for training networks that pins input tensors.

    The lifetimes computed by the parent scheduler are post-processed so that
    every tensor whose context buffer is flagged ``is_input`` spans the whole
    tiling schedule. This keeps forward-pass inputs live while the backward
    pass runs.
    """

    def _calculateLifetimes(self, ctxt: NetworkContext, patternMemoryConstraint: PatternMemoryConstraints,
                            memoryLevel: str) -> Tuple[Dict[str, Tuple[int, int]], Dict]:
        """Compute tensor lifetimes, stretching input tensors over all steps.

        Delegates to the parent implementation, then overwrites the lifetime
        of each input tensor with ``(0, len(patternMemoryConstraint.nodeConstraints))``.

        Returns the (mutated) lifetime map and the tensor map produced by the
        parent implementation.
        """
        lifetimes, tensors = super()._calculateLifetimes(ctxt, patternMemoryConstraint, memoryLevel)

        # Upper bound shared by every pinned input: the number of node constraints.
        fullSpan = (0, len(patternMemoryConstraint.nodeConstraints))

        # Only values of existing keys are replaced, so the dict keeps its size
        # and can be updated safely while it is being iterated.
        for name in lifetimes:
            if ctxt.lookup(name).is_input:
                lifetimes[name] = fullSpan

        return lifetimes, tensors
69+
70+
71+
class TrainingSBTiler(SBTiler):
    """``SBTiler`` subclass whose ``memorySchedulerClass`` is ``TrainingMemoryScheduler``."""

    # Scheduler class override; everything else is inherited from SBTiler.
    memorySchedulerClass = TrainingMemoryScheduler

0 commit comments

Comments
 (0)