Commit 75b82e2

Merge pull request #588 from SKaiNET-developers/feature/dsl-lazy-zero-init
feat(dsl): lazy zero-init for parameter placeholders
2 parents 62cf5ce + 6eda5d2 commit 75b82e2

6 files changed

Lines changed: 266 additions & 10 deletions


skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/context/ExecutionContext.kt

Lines changed: 14 additions & 0 deletions
@@ -64,6 +64,20 @@ public interface ExecutionContext {
         return fromData(data, dtype)
     }
 
+    /**
+     * Lazy-initialized zero tensor — see [TensorDataFactory.placeholder].
+     * The underlying primitive array allocates on first read; if the parameter
+     * is replaced before any read (the common case for DSL modules whose weights
+     * are loaded from disk), the allocation is skipped entirely.
+     */
+    public fun <T : DType, V> placeholder(
+        shape: Shape,
+        dtype: KClass<T>
+    ): Tensor<T, V> {
+        val data = tensorDataFactory.placeholder<T, V>(shape, dtype)
+        return fromData(data, dtype)
+    }
+
     public fun <T : DType, V> ones(
         shape: Shape,
         dtype: KClass<T>
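
At the call site the new entry point is a drop-in for `zeros`. A minimal sketch of the intended usage, assuming an `ExecutionContext` instance (`ctx` below) is already in scope; how one is obtained is outside this diff:

    // Sketch only: `ctx` is an assumed in-scope ExecutionContext.
    // Both calls yield a tensor that reads as all zeros.
    val eager = ctx.zeros<FP32, Float>(Shape(1024, 1024), FP32::class)       // ~4 MB FloatArray allocated immediately
    val lazy = ctx.placeholder<FP32, Float>(Shape(1024, 1024), FP32::class)  // no primitive array allocated yet

    // If `lazy` is replaced before any read, its backing array never materializes.
    // If it is read, it materializes to zeros and behaves exactly like `eager`.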

skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/NetworkBuilder.kt

Lines changed: 10 additions & 10 deletions
@@ -621,7 +621,7 @@ private fun <T : DType, V> createLinear(
 
         myInitWeights == null && myInitBias != null -> {
 
-            val safeWeights = executionContext.tensorDataFactory.zeros<T, V>(Shape(outFeatures, inFeatures), kClass)
+            val safeWeights = executionContext.tensorDataFactory.placeholder<T, V>(Shape(outFeatures, inFeatures), kClass)
             val initW = executionContext.fromData(safeWeights, kClass)
 
             Linear(
@@ -635,7 +635,7 @@ private fun <T : DType, V> createLinear(
         }
 
         myInitWeights != null && myInitBias == null -> {
-            val safeBias = executionContext.tensorDataFactory.zeros<T, V>(Shape(outFeatures), kClass)
+            val safeBias = executionContext.tensorDataFactory.placeholder<T, V>(Shape(outFeatures), kClass)
             val initB = executionContext.fromData(safeBias, kClass)
 
             Linear(
@@ -649,8 +649,8 @@ private fun <T : DType, V> createLinear(
         }
 
         else -> {
-            val safeWeights = executionContext.tensorDataFactory.zeros<T, V>(Shape(outFeatures, inFeatures), kClass)
-            val safeBias = executionContext.tensorDataFactory.zeros<T, V>(Shape(outFeatures), kClass)
+            val safeWeights = executionContext.tensorDataFactory.placeholder<T, V>(Shape(outFeatures, inFeatures), kClass)
+            val safeBias = executionContext.tensorDataFactory.placeholder<T, V>(Shape(outFeatures), kClass)
             val initW = executionContext.fromData(safeWeights, kClass)
             val initB = executionContext.fromData(safeBias, kClass)
 
@@ -792,10 +792,10 @@ public class Conv2dImpl<T : DType, V>(
         require(inChannels > 0) { "Conv2d inChannels must be > 0 (set explicitly if not inferred)." }
 
         // Create default tensors if not provided
-        val weights = weightsValue ?: executionContext.zeros(weightsShape, kClass)
+        val weights = weightsValue ?: executionContext.placeholder(weightsShape, kClass)
 
         val biasParam = if (bias) {
-            biasValue ?: executionContext.zeros(biasShape, kClass)
+            biasValue ?: executionContext.placeholder(biasShape, kClass)
         } else null
 
         return Conv2d(
@@ -921,8 +921,8 @@ public class Conv1dImpl<T : DType, V>(
         require(kernelSize > 0) { "Conv1d kernelSize must be > 0." }
         require(inChannels > 0) { "Conv1d inChannels must be > 0." }
 
-        val weights = weightsValue ?: executionContext.zeros(weightsShape, kClass)
-        val biasParam = if (bias) biasValue ?: executionContext.zeros(biasShape, kClass) else null
+        val weights = weightsValue ?: executionContext.placeholder(weightsShape, kClass)
+        val biasParam = if (bias) biasValue ?: executionContext.placeholder(biasShape, kClass) else null
 
         return Conv1d(
             inChannels = inChannels,
@@ -993,8 +993,8 @@ public class Conv3dImpl<T : DType, V>(
         require(kernelSize.first > 0 && kernelSize.second > 0 && kernelSize.third > 0) { "Conv3d kernelSize must be > 0." }
         require(inChannels > 0) { "Conv3d inChannels must be > 0." }
 
-        val weights = weightsValue ?: executionContext.zeros(weightsShape, kClass)
-        val biasParam = if (bias) biasValue ?: executionContext.zeros(biasShape, kClass) else null
+        val weights = weightsValue ?: executionContext.placeholder(weightsShape, kClass)
+        val biasParam = if (bias) biasValue ?: executionContext.placeholder(biasShape, kClass) else null
 
         return Conv3d(
             inChannels = inChannels,
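
All six call sites change mechanically from `zeros` to `placeholder`; none of them read the tensor before handing it to the module as an initial parameter value. A hedged sketch of the two lifecycles this enables, using only the factory API from this PR (`loadedTensor` is hypothetical):

    val factory = DenseTensorDataFactory()
    val weights = factory.placeholder<FP32, Float>(Shape(128, 64), FP32::class)

    // Lifecycle A (the common case): a weight loader substitutes the parameter
    // before any read, e.g. `parameter.value = loadedTensor` as described in the
    // KDoc of the new lazy classes. The placeholder becomes unreachable and its
    // FloatArray(128 * 64) is never allocated.

    // Lifecycle B: no loader runs. The first read fires the lazy and
    // materializes zeros, matching the old zeros() behavior exactly.
    val w00 = weights[0, 0]  // 0.0f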

skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/DenseTensorDataFactory.kt

Lines changed: 16 additions & 0 deletions
@@ -358,6 +358,22 @@ public class DenseTensorDataFactory: TensorDataFactory {
         }
     }
 
+    /**
+     * Returns a [LazyZeroFloatArrayTensorData] / [LazyZeroIntArrayTensorData] for FP32 /
+     * FP16 / Int32. The underlying primitive array materializes only on the first
+     * `get`/`set`/`buffer` access. For Int8 (byte-backed) we currently fall back to
+     * [zeros]; the eager byte allocation is rarely the dominant cost on real models.
+     */
+    override fun <T : DType, V> placeholder(shape: Shape, dtype: KClass<T>): TensorData<T, V> {
+        @Suppress("UNCHECKED_CAST")
+        return when (dtype) {
+            FP32::class -> LazyZeroFloatArrayTensorData<T>(shape) as TensorData<T, V>
+            FP16::class -> LazyZeroFloatArrayTensorData<T>(shape) as TensorData<T, V>
+            Int32::class -> LazyZeroIntArrayTensorData<T>(shape) as TensorData<T, V>
+            else -> zeros(shape, dtype)
+        }
+    }
+
     override fun <T : DType, V> ones(shape: Shape, dtype: KClass<T>): TensorData<T, V> {
         @Suppress("UNCHECKED_CAST")
         return when (dtype) {
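
Dispatch is on the `KClass` dtype token, so laziness is decided per call. A short sketch mirroring the behavior the new test file pins (the `Float` value type for FP16 is an assumption here, inferred from its float-backed lazy variant):

    val f = DenseTensorDataFactory()
    val fp32 = f.placeholder<FP32, Float>(Shape(10), FP32::class)  // lazy FloatArray
    val fp16 = f.placeholder<FP16, Float>(Shape(10), FP16::class)  // lazy, float-backed like FP32 (value type assumed)
    val i32 = f.placeholder<Int32, Int>(Shape(10), Int32::class)   // lazy IntArray
    val i8 = f.placeholder<Int8, Byte>(Shape(10), Int8::class)     // eager: falls back to zeros()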
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
package sk.ainet.lang.tensor.data

import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.storage.ActiveMemoryTracker
import sk.ainet.lang.types.DType

/**
 * Zero-allocation [FloatArrayTensorData] whose underlying [FloatArray] materializes
 * lazily on first read.
 *
 * Use when a parameter tensor is going to be replaced before any forward / backward
 * pass — e.g. immediately after the DSL builds a `Linear`/`Embedding`/`Conv` module
 * the loader's `WeightMapper.applyWeights` substitutes the entire `Tensor` via
 * `parameter.value = loadedTensor`. The placeholder is then GC'd before its lazy
 * fires, eliminating the eager `FloatArray(shape.volume)` cost.
 *
 * Behavior is identical to [DenseFloatArrayTensorData] backed by a zero-filled
 * `FloatArray` for any consumer that doesn't substitute first — the lazy
 * materializes to zeros on the first `get`/`set`/`buffer` access and is then
 * cached, so repeated reads return the same values that an eager zero allocation
 * would have produced.
 */
public class LazyZeroFloatArrayTensorData<T : DType>(
    initialShape: Shape
) : FloatArrayTensorData<T> {
    override val shape: Shape = Shape(initialShape.dimensions.copyOf())
    private val strides: IntArray = this.shape.computeStrides()

    private val backing: FloatArray by lazy {
        ActiveMemoryTracker.recordCopy(
            "LazyZeroFloatArrayTensorData.materialize",
            shape.volume.toLong() * 4
        )
        FloatArray(shape.volume)
    }

    override val buffer: FloatArray
        get() = backing

    override fun get(vararg indices: Int): Float =
        backing[calcFlatIndex(shape, strides, indices)]

    override fun set(vararg indices: Int, value: Float) {
        backing[calcFlatIndex(shape, strides, indices)] = value
    }
}

/**
 * Zero-allocation [IntArrayTensorData] whose backing [IntArray] materializes
 * lazily on first read. See [LazyZeroFloatArrayTensorData].
 */
public class LazyZeroIntArrayTensorData<T : DType>(
    initialShape: Shape
) : IntArrayTensorData<T> {
    override val shape: Shape = Shape(initialShape.dimensions.copyOf())
    private val strides: IntArray = this.shape.computeStrides()

    private val backing: IntArray by lazy {
        ActiveMemoryTracker.recordCopy(
            "LazyZeroIntArrayTensorData.materialize",
            shape.volume.toLong() * 4
        )
        IntArray(shape.volume)
    }

    override val buffer: IntArray
        get() = backing

    override fun get(vararg indices: Int): Int =
        backing[calcFlatIndex(shape, strides, indices)]

    override fun set(vararg indices: Int, value: Int) {
        backing[calcFlatIndex(shape, strides, indices)] = value
    }
}

private fun calcFlatIndex(shape: Shape, strides: IntArray, indices: IntArray): Int {
    require(indices.size == shape.dimensions.size) {
        "Number of indices (${indices.size}) must match tensor dimensions (${shape.dimensions.size})"
    }
    var flat = 0
    for (i in indices.indices) {
        val idx = indices[i]
        require(idx >= 0 && idx < shape.dimensions[i]) {
            "Index $idx out of bounds for dimension $i with size ${shape.dimensions[i]}"
        }
        flat += idx * strides[i]
    }
    return flat
}
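
The mechanism is Kotlin's `by lazy` delegate: the initializer runs at most once, on first access, and the result is cached; the default mode is `LazyThreadSafetyMode.SYNCHRONIZED`, so concurrent first reads still see a single backing array. A self-contained miniature of the same pattern, using only the standard library (none of the SKaiNET types):

    class LazyBuffer(private val size: Int) {
        var materialized = false
            private set

        private val backing: FloatArray by lazy {
            materialized = true  // side effect shows when the initializer ran
            FloatArray(size)     // the deferred allocation
        }

        operator fun get(i: Int): Float = backing[i]
    }

    fun main() {
        val buf = LazyBuffer(1_000_000)
        println(buf.materialized)  // false: construction allocated nothing
        println(buf[0])            // 0.0: first read fires the lazy
        println(buf.materialized)  // true: and the array stays cached
    }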

skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/TensorDataFactory.kt

Lines changed: 21 additions & 0 deletions
@@ -11,6 +11,27 @@ import kotlin.reflect.KClass
  */
 public interface TensorDataFactory {
     public fun <T : DType, V> zeros(shape: Shape, dtype: KClass<T>): TensorData<T, V>
+
+    /**
+     * Allocates a zero-filled tensor whose underlying storage materializes lazily
+     * on first read.
+     *
+     * Behavior is identical to [zeros] for any caller that reads the tensor — a
+     * fresh zero buffer is produced on first access and cached for subsequent
+     * reads. The benefit is for callers that **never** read the tensor before
+     * replacing it, which is the common case in DSL-built modules whose
+     * parameters get substituted by a downstream weight loader (e.g.
+     * `WeightMapper.applyWeights` sets `parameter.value = loadedTensor`). For
+     * those callers, the `FloatArray(shape.volume)` allocation never happens.
+     *
+     * The default implementation falls back to [zeros], preserving existing
+     * behavior for any custom factory that does not opt in. Implementations
+     * that have a meaningful lazy form (e.g. [DenseTensorDataFactory]) should
+     * override.
+     */
+    public fun <T : DType, V> placeholder(shape: Shape, dtype: KClass<T>): TensorData<T, V> =
+        zeros(shape, dtype)
+
     public fun <T : DType, V> ones(shape: Shape, dtype: KClass<T>): TensorData<T, V>
     public fun <T : DType, V> full(shape: Shape, dtype: KClass<T>, value: Number): TensorData<T, V>
     public fun <T : DType, V> randn(
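
The interface default is what keeps the change non-breaking: existing factories compile and behave exactly as before, and only implementations that override gain laziness. The same pattern in miniature (hypothetical `MiniFactory`, not a SKaiNET type):

    interface MiniFactory {
        fun zeros(size: Int): FloatArray
        // Interface default: placeholder is indistinguishable from zeros
        // unless an implementation opts in with an override.
        fun placeholder(size: Int): FloatArray = zeros(size)
    }

    // An existing factory that never heard of placeholder() keeps working:
    // placeholder() silently delegates to its zeros().
    class EagerFactory : MiniFactory {
        override fun zeros(size: Int) = FloatArray(size)
    }

    fun main() {
        println(EagerFactory().placeholder(8).size)  // 8, eagerly allocated
    }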
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
package sk.ainet.lang.tensor.data

import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.types.FP32
import sk.ainet.lang.types.Int32
import sk.ainet.lang.types.Int8
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertNotSame
import kotlin.test.assertSame

/**
 * Pins the contract for [TensorDataFactory.placeholder]:
 *
 * 1. Reports the requested shape without touching the underlying buffer.
 * 2. Materializes to zeros on the first read — value parity with [zeros].
 * 3. Caches the materialized buffer (no re-allocation across reads).
 *
 * The benefit (deferred allocation) doesn't show up directly in unit tests, but
 * the parity guarantee means any caller that *does* read the tensor sees the
 * same values an eager [zeros] call would have produced — so dropping in
 * `placeholder` for `zeros` in DSL parameter init is a strict improvement.
 */
class PlaceholderTensorDataTest {

    private val factory = DenseTensorDataFactory()

    @Test
    fun placeholder_reports_shape_without_materializing() {
        val shape = Shape(64, 64)
        val td = factory.placeholder<FP32, Float>(shape, FP32::class)

        // Reading shape must not require allocating the underlying buffer.
        assertEquals(shape, td.shape)
        // Returned shape is a defensive copy — mutating one shouldn't affect the
        // factory-issued tensor's view.
        assertEquals(64, td.shape.dimensions[0])
        assertEquals(64, td.shape.dimensions[1])
    }

    @Test
    fun placeholder_materializes_to_zeros_on_first_read_fp32() {
        val td = factory.placeholder<FP32, Float>(Shape(2, 3), FP32::class)

        // Every position reads as 0.0f — same as zeros().
        for (i in 0 until 2) for (j in 0 until 3) {
            assertEquals(0.0f, td[i, j], "[$i,$j] must be 0.0f on first read")
        }
    }

    @Test
    fun placeholder_supports_writes_and_reads_back_fp32() {
        val td = factory.placeholder<FP32, Float>(Shape(4), FP32::class)

        td[2] = 7.5f
        assertEquals(7.5f, td[2])
        assertEquals(0.0f, td[0])
        assertEquals(0.0f, td[3])
    }

    @Test
    fun placeholder_buffer_is_stable_across_reads() {
        val td = factory.placeholder<FP32, Float>(Shape(8), FP32::class)
            as FloatArrayTensorData<FP32>

        val first = td.buffer
        val second = td.buffer
        // Same backing FloatArray on every access — the lazy fires once.
        assertSame(first, second, "buffer must be cached after first materialization")
    }

    @Test
    fun placeholder_value_parity_with_zeros_fp32() {
        val shape = Shape(5, 7)
        val placeholder = factory.placeholder<FP32, Float>(shape, FP32::class)
        val zeros = factory.zeros<FP32, Float>(shape, FP32::class)

        for (i in 0 until 5) for (j in 0 until 7) {
            assertEquals(zeros[i, j], placeholder[i, j],
                "placeholder must match zeros at [$i,$j]")
        }
    }

    @Test
    fun placeholder_int32_materializes_to_zeros() {
        val td = factory.placeholder<Int32, Int>(Shape(3), Int32::class)
        assertEquals(0, td[0])
        assertEquals(0, td[1])
        assertEquals(0, td[2])
    }

    @Test
    fun placeholder_int8_falls_back_to_zeros() {
        // Int8 has no lazy variant — falls back to eager zeros. The test pins
        // the value contract; it shouldn't throw and reads must be 0.
        val td = factory.placeholder<Int8, Byte>(Shape(4), Int8::class)
        for (i in 0 until 4) {
            assertEquals(0.toByte(), td[i])
        }
    }

    @Test
    fun placeholder_returns_distinct_instances() {
        // Two placeholder calls must not share underlying state — separate Linear
        // layers must not see each other's writes.
        val a = factory.placeholder<FP32, Float>(Shape(4), FP32::class)
            as FloatArrayTensorData<FP32>
        val b = factory.placeholder<FP32, Float>(Shape(4), FP32::class)
            as FloatArrayTensorData<FP32>

        assertNotSame(a.buffer, b.buffer)
        a[0] = 99.0f
        assertEquals(0.0f, b[0], "placeholder b must not see writes to placeholder a")
    }
}
