Skip to content

Commit a684c41

Browse files
committed
Introduce explicit Weights scope for adding tensor factories.
Related-To: #94
1 parent a9a6691 commit a684c41

4 files changed

Lines changed: 47 additions & 237 deletions

File tree

skainet-core/skainet-tensors-api/src/commonMain/kotlin/sk/ainet/core/tensor/TensorFactories.kt

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,24 @@ package sk.ainet.core.tensor
22

33
import kotlin.random.Random
44

5-
public interface TensorFactory<T: DType, V> {
5+
public interface TensorFactory<T : DType, V> {
66
public fun zeros(shape: Shape): Tensor<T, V>
77
public fun ones(shape: Shape): Tensor<T, V>
88
public fun random(shape: Shape): Tensor<T, V>
9-
9+
1010
// Advanced random methods with seed control
1111
public fun random(shape: Shape, seed: Long): Tensor<T, V>
1212
public fun random(shape: Shape, random: Random): Tensor<T, V>
13-
13+
1414
// Distribution-based random methods
1515
public fun randomNormal(shape: Shape, mean: Double = 0.0, std: Double = 1.0): Tensor<T, V>
1616
public fun randomNormal(shape: Shape, mean: Double = 0.0, std: Double = 1.0, seed: Long): Tensor<T, V>
1717
public fun randomNormal(shape: Shape, mean: Double = 0.0, std: Double = 1.0, random: Random): Tensor<T, V>
18-
18+
1919
public fun randomUniform(shape: Shape, min: Double = 0.0, max: Double = 1.0): Tensor<T, V>
2020
public fun randomUniform(shape: Shape, min: Double = 0.0, max: Double = 1.0, seed: Long): Tensor<T, V>
2121
public fun randomUniform(shape: Shape, min: Double = 0.0, max: Double = 1.0, random: Random): Tensor<T, V>
22+
23+
public fun fromArray(shape: Shape, data: FloatArray): Tensor<T, V>
24+
public fun fromArray(shape: Shape, data: IntArray): Tensor<T, V>
2225
}

skainet-nn/skainet-nn-api/src/commonMain/kotlin/sk/ainet/nn/dsl/NetworkBuilder.kt

Lines changed: 40 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import sk.ainet.core.tensor.Int8
1010
import sk.ainet.core.tensor.Int32
1111
import sk.ainet.core.tensor.TensorFactory
1212
import sk.ainet.core.tensor.DefaultTensorFactories
13-
import sk.ainet.core.tensor.backend.CpuBackend
13+
import sk.ainet.core.tensor.TensorData
1414
import sk.ainet.nn.Flatten
1515
import sk.ainet.nn.Input
1616
import sk.ainet.nn.Linear
@@ -59,65 +59,12 @@ public fun <T : DType, V> network(
5959
.apply(content)
6060
.create()
6161

62-
/**
63-
* Backward compatibility function - creates a network using FP32/Float precision.
64-
* This function maintains compatibility with existing code that doesn't specify generic types.
65-
* Uses CpuBackend as default factory.
66-
*
67-
* @param content The DSL content block that defines the network structure
68-
* @return A Module<FP32, Float> representing the complete neural network
69-
*/
70-
@NetworkDsl
71-
@JvmName("networkFP32Default")
72-
public fun network(content: NeuralNetworkDsl<FP32, Float>.() -> Unit): Module<FP32, Float> =
73-
network(CpuBackend(), content)
74-
75-
/**
76-
* Convenience function for creating FP32/Float precision networks.
77-
* Provides explicit type specification for better code readability.
78-
* Uses CpuBackend as default factory.
79-
*
80-
* @param content The DSL content block that defines the network structure
81-
* @return A Module<FP32, Float> representing the complete neural network
82-
*/
83-
@NetworkDsl
84-
public fun networkFP32(content: NeuralNetworkDsl<FP32, Float>.() -> Unit): Module<FP32, Float> =
85-
network(CpuBackend(), content)
86-
87-
/**
88-
* Generic network builder function with automatic factory resolution.
89-
* This function automatically selects the appropriate TensorFactory based on the generic types.
90-
*
91-
* Currently supports:
92-
* - FP32, Float → Uses CPU FP32 backend
93-
* - Int8, Byte → Uses CPU Int8 backend
94-
* - Int32, Int → Uses CPU Int32 backend
95-
*
96-
* @param T The data type (DType) - must extend DType (e.g., FP32, Int8, Int32)
97-
* @param V The value type - must match the DType's native type
98-
* @param content The DSL content block that defines the network structure
99-
* @return A Module<T, V> representing the complete neural network
100-
*
101-
* Example usage:
102-
* ```kotlin
103-
* val fpNetwork = network<FP32, Float> {
104-
* input(784)
105-
* dense(128)
106-
* dense(10)
107-
* }
108-
*
109-
* val intNetwork = network<Int8, Byte> {
110-
* input(28)
111-
* dense(16)
112-
* }
113-
* ```
114-
*/
11562
@NetworkDsl
11663
@JvmName("networkWithAutoFactory")
11764
public inline fun <reified T : DType, reified V> network(
11865
noinline content: NeuralNetworkDsl<T, V>.() -> Unit
11966
): Module<T, V> {
120-
CpuBackend()
67+
// CpuBackend()
12168
val factory = when {
12269
T::class == FP32::class && V::class == Float::class -> {
12370
@Suppress("UNCHECKED_CAST")
@@ -228,7 +175,9 @@ public interface NeuralNetworkDsl<T : DType, V> : NetworkDslItem {
228175
public interface DENSE<T : DType, V> : NetworkDslItem {
229176
public var activation: (Tensor<T, V>) -> Tensor<T, V>
230177
public var units: Int
231-
public fun weights(initBlock: (Shape) -> Tensor<T, V>)
178+
179+
//public fun weights(initBlock: (Shape) -> Tensor<T, V>)
180+
public fun weights(initBlock: WeightsScope<T, V>.(Shape) -> Tensor<T, V>)
232181
public fun bias(initBlock: (Shape) -> Tensor<T, V>)
233182

234183
// Factory-based convenience methods
@@ -310,22 +259,6 @@ public interface DENSE<T : DType, V> : NetworkDslItem {
310259
factory.randomUniform(shape, min, max, random)
311260
}
312261

313-
// Extension functions for convenient parameterless initialization
314-
/**
315-
* Extension function for weights initialization with implicit shape context.
316-
*/
317-
public fun <T : DType, V> DENSE<T, V>.weights(initBlock: WeightsScope<T, V>.() -> Tensor<T, V>) {
318-
val scope = WeightsScopeImpl(factory, weightsShape)
319-
weights { scope.initBlock() }
320-
}
321-
322-
/**
323-
* Extension function for bias initialization with implicit shape context.
324-
*/
325-
public fun <T : DType, V> DENSE<T, V>.bias(initBlock: BiasScope<T, V>.() -> Tensor<T, V>) {
326-
val scope = BiasScopeImpl(factory, biasShape)
327-
bias { scope.initBlock() }
328-
}
329262

330263
/**
331264
* Scope for weights initialization with implicit shape context.
@@ -335,6 +268,26 @@ public interface WeightsScope<T : DType, V> {
335268
public val factory: TensorFactory<T, V>
336269
public val shape: Shape
337270

271+
// fromXX Float
272+
public fun from(vararg data: Float): Tensor<T, V> = fromArray(data.toTypedArray().toFloatArray())
273+
public fun fromList(data: List<Float>): Tensor<T, V> = fromArray(data.toFloatArray())
274+
public fun fromArray(data: FloatArray): Tensor<T, V> {
275+
require(data.size == shape.volume) {
276+
"Data size ${data.size} doesn't match shape volume ${shape.volume}"
277+
}
278+
return factory.fromArray(shape, data)
279+
}
280+
281+
// fromXX Int
282+
public fun from(vararg data: Int): Tensor<T, V> = fromArray(data.toTypedArray().toIntArray())
283+
public fun fromList(data: List<Int>): Tensor<T, V> = fromArray(data.toIntArray())
284+
public fun fromArray(data: IntArray): Tensor<T, V> {
285+
require(data.size == shape.volume) {
286+
"Data size ${data.size} doesn't match shape volume ${shape.volume}"
287+
}
288+
return factory.fromArray(shape, data)
289+
}
290+
338291
public fun zeros(): Tensor<T, V> = factory.zeros(shape)
339292
public fun ones(): Tensor<T, V> = factory.ones(shape)
340293
public fun random(): Tensor<T, V> = factory.random(shape)
@@ -423,6 +376,11 @@ public interface FLATTEN<T : DType, V> : NetworkDslItem {
423376
public var endDim: Int
424377
}
425378

379+
@NetworkDsl
380+
public interface VALUES<T : DType, V> : NetworkDslItem {
381+
public var tensorValues: TensorData<T, V>
382+
}
383+
426384
private fun getDefaultName(id: String, s: String, size: Int): String {
427385
if (id.isNotEmpty()) return id
428386
return "$s-$size"
@@ -551,8 +509,16 @@ public class DenseImpl<T : DType, V>(
551509
_outputDimension = value
552510
}
553511

554-
override fun weights(initBlock: (Shape) -> Tensor<T, V>) {
555-
weightsValue = initBlock(weightsShape)
512+
private fun initWeights(tensor: Tensor<T, V>) {
513+
tensor
514+
515+
}
516+
517+
override fun weights(initBlock: WeightsScope<T, V>.(Shape) -> Tensor<T, V>) {
518+
val scope = WeightsScopeImpl(factory, weightsShape)
519+
//initWeights = scope.initBlock(weightsShape)
520+
521+
556522
}
557523

558524
override fun bias(initBlock: (Shape) -> Tensor<T, V>) {

skainet-nn/skainet-nn-api/src/commonMain/kotlin/sk/ainet/nn/dsl/extensions/TensorCreationExtensions.kt

Lines changed: 0 additions & 158 deletions
This file was deleted.

skainet-nn/skainet-nn-api/src/commonTest/kotlin/sk/ainet/nn/dsl/SinNetworkShapeTest.kt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import sk.ainet.nn.Linear
66
import sk.ainet.nn.topology.ModuleParameters
77
import sk.ainet.nn.topology.weights
88
import sk.ainet.nn.topology.bias
9-
import sk.ainet.nn.dsl.extensions.*
109
import sk.ainet.core.tensor.factory.fromBytes
1110
import sk.ainet.core.tensor.factory.ByteArrayConverter
1211
import kotlin.test.*

0 commit comments

Comments
 (0)