1616 */
1717package org.tensorflow
1818
19- import org.junit.jupiter.api.Test
2019import org.tensorflow.ndarray.Shape
2120import org.tensorflow.ndarray.get
2221import org.tensorflow.op.kotlin.KotlinOps
2322import org.tensorflow.op.kotlin.tf
2423import org.tensorflow.op.kotlin.withSubScope
2524import org.tensorflow.types.TFloat32
25+ import kotlin.test.Test
2626
27- public fun KotlinOps.DenseLayer (
27+ private fun KotlinOps.DenseLayer (
2828 name : String ,
2929 x : Operand <TFloat32 >,
3030 n : Int ,
3131 activation : KotlinOps .(Operand <TFloat32 >) -> Operand <TFloat32 > = { tf.nn.relu(it) }
3232): Operand <TFloat32 > = tf.withSubScope(name) {
3333 val inputDims = x.shape()[1 ]
34- val W = tf.variable(tf.math.add (tf.zeros(tf. array(inputDims.toInt(), n), TFloat32 :: class .java), constant( 1f )))
35- val b = tf.variable(tf.math.add (tf.zeros(tf. array(n), TFloat32 :: class .java), constant( 1f )))
36- activation(tf.math.add(tf.linalg.matMul(x, W ), b) )
34+ val W = tf.variable(tf.ones< TFloat32 > (tf.array(inputDims.toInt(), n)))
35+ val b = tf.variable(tf.ones< TFloat32 > (tf.array(n)))
36+ activation((x matMul W ) + b )
3737}
3838
39- public class Example {
39+ public class ExampleTest {
4040 @Test
4141 public fun mnistExample () {
4242 Graph {
4343 val input = tf.placeholderWithDefault(
44- tf.math.add (tf.zeros(tf. array(1 , 28 , 28 , 3 )), tf.constant( 1f )),
44+ tf.ones< TFloat32 > (tf.array(1 , 28 , 28 , 3 )),
4545 Shape .of(- 1 , 28 , 28 , 3 )
4646 )
4747
@@ -53,10 +53,11 @@ public class Example {
5353 DenseLayer (" OutputLayer" , x, 10 ) { tf.math.sigmoid(x) }
5454 }
5555
56- // useSession {
57- // val outputValue = it.run(fetches = listOf(output))[output]
58- // println(outputValue.data())
59- // }
56+ useSession { session ->
57+
58+ val outputValue = session.runner().fetch(output).run ()[0 ] as TFloat32
59+ println (outputValue.getFloat(0 ))
60+ }
6061 }
6162 }
6263}
0 commit comments