Skip to content

Commit fc8ea3f

Browse files
committed
refactor: use SkipTransformer with pointer assignment for Reshape, avoiding unnecessary DMA and memcpy
1 parent 4865516 commit fc8ea3f

2 files changed

Lines changed: 18 additions & 21 deletions

File tree

Deeploy/Targets/Snitch/Bindings.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from Deeploy.AbstractDataTypes import PointerClass
88
from Deeploy.CommonExtensions.CodeTransformationPasses.Closure import ClosureGeneration, MemoryAwareClosureGeneration
99
from Deeploy.CommonExtensions.CodeTransformationPasses.MemoryAllocation import ArgumentStructGeneration, \
10-
MemoryManagementGeneration
10+
MemoryManagementGeneration, MemoryPassthroughGeneration
1111
from Deeploy.CommonExtensions.DataTypes import float32_t, int8_t, int32_t, uint8_t
1212
from Deeploy.DeeployTypes import CodeTransformation, NodeBinding
1313
from Deeploy.FutureExtension.CodeTransformationPasses.FutureCodeTransformation import FutureGeneration
@@ -43,6 +43,13 @@
4343
MemoryManagementGeneration(),
4444
FutureGeneration()])
4545

46+
SkipTransformer = CodeTransformation(
47+
[SnitchSynchCoresPass(),
48+
ArgumentStructGeneration(),
49+
MemoryPassthroughGeneration("L.*"),
50+
MemoryPassthroughGeneration(),
51+
FutureGeneration()])
52+
4653
TiledTransformer = CodeTransformation([
4754
SnitchCoreFilterPass("compute"),
4855
TilingVariableReplacement("L1"),
@@ -184,10 +191,10 @@
184191
TransposeTemplate.referenceTemplate, BasicTransformer)
185192
]
186193

187-
# Reshape Bindings (Tiled)
194+
# Reshape Bindings (pointer passthrough, no DMA needed)
188195
SnitchReshapeBindings = [
189196
NodeBinding(ReshapeChecker([PointerClass(float32_t)], [PointerClass(float32_t)]), ReshapeTemplate.referenceTemplate,
190-
TiledTransformer)
197+
SkipTransformer)
191198
]
192199

193200
# Gather Bindings (Tiled)

Deeploy/Targets/Snitch/Templates/ReshapeTemplate.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44

55
from typing import Dict, List, Tuple
66

7-
import numpy as np
8-
97
from Deeploy.DeeployTypes import NetworkContext, OperatorRepresentation, VariableBuffer
108
from Deeploy.Targets.Generic.Templates.ReshapeTemplate import _ReshapeTemplate
119

@@ -17,28 +15,20 @@ def alignToContext(self, ctxt: NetworkContext,
1715

1816
ctxt, operatorRepresentation, _ = super().alignToContext(ctxt, operatorRepresentation)
1917

20-
# Calculate size for multi-core parallel copy
2118
bufferIn = ctxt.lookup(operatorRepresentation['data_in'])
2219
assert isinstance(bufferIn, VariableBuffer)
23-
operatorRepresentation['size'] = int(np.prod(bufferIn.shape))
20+
bufferOut = ctxt.lookup(operatorRepresentation['data_out'])
21+
assert isinstance(bufferOut, VariableBuffer)
22+
23+
# Set alias so input and output share the same memory
24+
bufferOut._alias = bufferIn.name
2425

2526
return ctxt, operatorRepresentation, []
2627

2728

28-
# Reshape uses multi-core parallel copy
29-
# When aliases work (internal nodes), this copies between same memory (no-op effect)
30-
# When aliases don't work (global I/O), this copies data correctly
29+
# Reshape only reinterprets tensor shape without modifying data.
30+
# Uses SkipTransformer (no DMA), consistent with PULPOpen.
3131
referenceTemplate = _SnitchReshapeTemplate("""
3232
// Reshape (Name: ${nodeName}, Op: ${nodeOp})
33-
{
34-
uint32_t core_id = snrt_cluster_core_idx();
35-
uint32_t num_cores = snrt_cluster_compute_core_num();
36-
uint32_t total = ${size};
37-
uint32_t chunk = total / num_cores;
38-
uint32_t start = core_id * chunk;
39-
uint32_t end = (core_id == num_cores - 1) ? total : start + chunk;
40-
for (uint32_t i = start; i < end; i++) {
41-
${data_out}[i] = ${data_in}[i];
42-
}
43-
}
33+
${data_out} = ${data_in};
4434
""")

0 commit comments

Comments
 (0)