@@ -10,7 +10,7 @@ import sk.ainet.core.tensor.Int8
 import sk.ainet.core.tensor.Int32
 import sk.ainet.core.tensor.TensorFactory
 import sk.ainet.core.tensor.DefaultTensorFactories
-import sk.ainet.core.tensor.backend.CpuBackend
+import sk.ainet.core.tensor.TensorData
 import sk.ainet.nn.Flatten
 import sk.ainet.nn.Input
 import sk.ainet.nn.Linear
@@ -59,65 +59,12 @@ public fun <T : DType, V> network(
         .apply(content)
         .create()
 
-/**
- * Backward compatibility function - creates a network using FP32/Float precision.
- * This function maintains compatibility with existing code that doesn't specify generic types.
- * Uses CpuBackend as the default factory.
- *
- * @param content The DSL content block that defines the network structure
- * @return A Module<FP32, Float> representing the complete neural network
- */
-@NetworkDsl
-@JvmName("networkFP32Default")
-public fun network(content: NeuralNetworkDsl<FP32, Float>.() -> Unit): Module<FP32, Float> =
-    network(CpuBackend(), content)
-
75- /* *
76- * Convenience function for creating FP32/Float precision networks.
77- * Provides explicit type specification for better code readability.
78- * Uses CpuBackend as default factory.
79- *
80- * @param content The DSL content block that defines the network structure
81- * @return A Module<FP32, Float> representing the complete neural network
82- */
83- @NetworkDsl
84- public fun networkFP32 (content : NeuralNetworkDsl <FP32 , Float >.() -> Unit ): Module <FP32 , Float > =
85- network(CpuBackend (), content)
86-
-/**
- * Generic network builder function with automatic factory resolution.
- * This function automatically selects the appropriate TensorFactory based on the generic types.
- *
- * Currently supports:
- * - FP32, Float → uses the CPU FP32 backend
- * - Int8, Byte → uses the CPU Int8 backend
- * - Int32, Int → uses the CPU Int32 backend
- *
- * @param T The data type (DType) - must extend DType (e.g., FP32, Int8, Int32)
- * @param V The value type - must match the DType's native type
- * @param content The DSL content block that defines the network structure
- * @return A Module<T, V> representing the complete neural network
- *
- * Example usage:
- * ```kotlin
- * val fpNetwork = network<FP32, Float> {
- *     input(784)
- *     dense(128)
- *     dense(10)
- * }
- *
- * val intNetwork = network<Int8, Byte> {
- *     input(28)
- *     dense(16)
- * }
- * ```
- */
 @NetworkDsl
 @JvmName("networkWithAutoFactory")
 public inline fun <reified T : DType, reified V> network(
     noinline content: NeuralNetworkDsl<T, V>.() -> Unit
 ): Module<T, V> {
-    CpuBackend()
+    // CpuBackend()
     val factory = when {
         T::class == FP32::class && V::class == Float::class -> {
             @Suppress("UNCHECKED_CAST")
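With the FP32/Float convenience overloads removed, the reified builder above is the remaining public entry point; the factory is resolved from the type arguments. A minimal usage sketch, mirroring the example from the removed KDoc (`input`/`dense` as defined elsewhere in this DSL):

```kotlin
// Factory resolution is driven by the reified type arguments,
// so no backend instance is passed explicitly.
val fpNetwork = network<FP32, Float> {
    input(784)
    dense(128)
    dense(10)
}

// Same builder, different precision: resolves the CPU Int8 factory.
val intNetwork = network<Int8, Byte> {
    input(28)
    dense(16)
}
```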
@@ -228,7 +175,9 @@ public interface NeuralNetworkDsl<T : DType, V> : NetworkDslItem {
 public interface DENSE<T : DType, V> : NetworkDslItem {
     public var activation: (Tensor<T, V>) -> Tensor<T, V>
     public var units: Int
-    public fun weights(initBlock: (Shape) -> Tensor<T, V>)
+
+    // public fun weights(initBlock: (Shape) -> Tensor<T, V>)
+    public fun weights(initBlock: WeightsScope<T, V>.(Shape) -> Tensor<T, V>)
     public fun bias(initBlock: (Shape) -> Tensor<T, V>)
 
     // Factory-based convenience methods
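The `weights` initializer now takes a `WeightsScope<T, V>` receiver, so scope helpers resolve without qualification while the shape is still passed as the lambda parameter. A hedged call-site sketch (the `dense(units) { ... }` block and the `factory` property on the dense scope are assumed from the surrounding DSL, not shown in this hunk):

```kotlin
dense(128) {
    // `this` inside the lambda is a WeightsScope, so random()
    // resolves against the layer's weight shape directly.
    weights { random() }
    // bias keeps the plain (Shape) -> Tensor signature.
    bias { shape -> factory.zeros(shape) }
}
```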
@@ -310,22 +259,6 @@ public interface DENSE<T : DType, V> : NetworkDslItem {
         factory.randomUniform(shape, min, max, random)
 }
 
-// Extension functions for convenient parameterless initialization
-/**
- * Extension function for weights initialization with implicit shape context.
- */
-public fun <T : DType, V> DENSE<T, V>.weights(initBlock: WeightsScope<T, V>.() -> Tensor<T, V>) {
-    val scope = WeightsScopeImpl(factory, weightsShape)
-    weights { scope.initBlock() }
-}
-
-/**
- * Extension function for bias initialization with implicit shape context.
- */
-public fun <T : DType, V> DENSE<T, V>.bias(initBlock: BiasScope<T, V>.() -> Tensor<T, V>) {
-    val scope = BiasScopeImpl(factory, biasShape)
-    bias { scope.initBlock() }
-}
-
 
 /**
  * Scope for weights initialization with implicit shape context.
@@ -335,6 +268,26 @@ public interface WeightsScope<T : DType, V> {
     public val factory: TensorFactory<T, V>
     public val shape: Shape
 
+    // fromXX overloads for Float values
+    public fun from(vararg data: Float): Tensor<T, V> = fromArray(data)
+    public fun fromList(data: List<Float>): Tensor<T, V> = fromArray(data.toFloatArray())
+    public fun fromArray(data: FloatArray): Tensor<T, V> {
+        require(data.size == shape.volume) {
+            "Data size ${data.size} doesn't match shape volume ${shape.volume}"
+        }
+        return factory.fromArray(shape, data)
+    }
+
+    // fromXX overloads for Int values
+    public fun from(vararg data: Int): Tensor<T, V> = fromArray(data)
+    public fun fromList(data: List<Int>): Tensor<T, V> = fromArray(data.toIntArray())
+    public fun fromArray(data: IntArray): Tensor<T, V> {
+        require(data.size == shape.volume) {
+            "Data size ${data.size} doesn't match shape volume ${shape.volume}"
+        }
+        return factory.fromArray(shape, data)
+    }
+
     public fun zeros(): Tensor<T, V> = factory.zeros(shape)
     public fun ones(): Tensor<T, V> = factory.ones(shape)
     public fun random(): Tensor<T, V> = factory.random(shape)
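These helpers let an initializer supply explicit values instead of a generated fill; any size mismatch against `shape.volume` fails the `require` check. A sketch, assuming a layer whose weight shape has volume 4:

```kotlin
weights { // this: WeightsScope<T, V>
    // All three forms build the same tensor from explicit values:
    from(0.1f, 0.2f, 0.3f, 0.4f)
    // fromList(listOf(0.1f, 0.2f, 0.3f, 0.4f))
    // fromArray(floatArrayOf(0.1f, 0.2f, 0.3f, 0.4f))
}
```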
@@ -423,6 +376,11 @@ public interface FLATTEN<T : DType, V> : NetworkDslItem {
     public var endDim: Int
 }
 
+@NetworkDsl
+public interface VALUES<T : DType, V> : NetworkDslItem {
+    public var tensorValues: TensorData<T, V>
+}
+
 private fun getDefaultName(id: String, s: String, size: Int): String {
     if (id.isNotEmpty()) return id
     return "$s-$size"
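Only the VALUES interface itself appears in this diff; how the builder wires it up is not shown. A speculative sketch of a minimal implementation (the class name and its use are hypothetical):

```kotlin
// Hypothetical: a trivial VALUES carrier. Everything here except the
// VALUES interface and the TensorData type is assumed, not part of this diff.
private class ValuesImpl<T : DType, V>(
    override var tensorValues: TensorData<T, V>
) : VALUES<T, V>
```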
@@ -551,8 +509,11 @@ public class DenseImpl<T : DType, V>(
         _outputDimension = value
     }
 
-    override fun weights(initBlock: (Shape) -> Tensor<T, V>) {
-        weightsValue = initBlock(weightsShape)
+    override fun weights(initBlock: WeightsScope<T, V>.(Shape) -> Tensor<T, V>) {
+        // Run the initializer with a WeightsScope receiver so scope helpers
+        // (zeros(), from(...), random()) resolve against the layer's weight shape.
+        val scope = WeightsScopeImpl(factory, weightsShape)
+        weightsValue = scope.initBlock(weightsShape)
     }
 
     override fun bias(initBlock: (Shape) -> Tensor<T, V>) {