Skip to content

Commit cb6a4cb

Browse files
committed
Add initial support braodcasting
Relate-To: #94, #119
1 parent e56513b commit cb6a4cb

9 files changed

Lines changed: 1012 additions & 130 deletions

File tree

skainet-core/skainet-tensors-api/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ kotlin {
3535
sourceSets {
3636
commonTest.dependencies {
3737
implementation(libs.kotlin.test)
38+
implementation(project(":skainet-core:skainet-tensors"))
3839
}
3940
}
4041
}

skainet-core/skainet-tensors-api/src/commonMain/kotlin/sk/ainet/core/tensor/backend/ComputeBackend.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package sk.ainet.core.tensor.backend
22

33
import sk.ainet.core.tensor.DType
44
import sk.ainet.core.tensor.Tensor
5+
import sk.ainet.core.tensor.TensorFactory
56
import sk.ainet.core.tensor.TensorOps
67

78

@@ -12,7 +13,7 @@ import sk.ainet.core.tensor.TensorOps
1213
* hardware platform (CPU, GPU, etc.). Different backends can provide different
1314
* implementations of the same operations, optimized for their target platform.
1415
*/
15-
public interface ComputeBackend<D : DType, V> : TensorOps<D, V, Tensor<D, V>> {
16+
public interface ComputeBackend<D : DType, V> : TensorOps<D, V, Tensor<D, V>>, TensorFactory<D, V> {
1617
/**
1718
* The name of the backend.
1819
*/
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package sk.ainet.core.tensor
2+
3+
import sk.ainet.core.tensor.backend.CpuBackend
4+
import sk.ainet.core.tensor.backend.CpuBackendInt8
5+
import sk.ainet.core.tensor.backend.CpuBackendInt32
6+
import kotlin.test.*
7+
8+
/**
9+
* Tests for DefaultTensorFactories to verify factory initialization
10+
* and identify any IllegalStateException issues
11+
*/
12+
class DefaultTensorFactoriesTest {
13+
14+
@Test
15+
fun testFP32FactoryInitialization() {
16+
try {
17+
DefaultTensorFactories.setFP32Factory(CpuBackend())
18+
val factory = DefaultTensorFactories.getFP32Factory()
19+
val tensor = factory.zeros(Shape(2, 3))
20+
assertNotNull(tensor)
21+
assertEquals(Shape(2, 3), tensor.shape)
22+
println("[DEBUG_LOG] FP32 factory initialized successfully")
23+
} catch (e: IllegalStateException) {
24+
println("[DEBUG_LOG] FP32 factory not initialized: ${e.message}")
25+
fail("FP32 factory not initialized: ${e.message}")
26+
}
27+
}
28+
29+
@Test
30+
fun testInt8FactoryInitialization() {
31+
try {
32+
val factory = DefaultTensorFactories.getInt8Factory()
33+
val tensor = factory.zeros(Shape(2, 3))
34+
assertNotNull(tensor)
35+
assertEquals(Shape(2, 3), tensor.shape)
36+
println("[DEBUG_LOG] Int8 factory initialized successfully")
37+
} catch (e: IllegalStateException) {
38+
println("[DEBUG_LOG] Int8 factory not initialized: ${e.message}")
39+
fail("Int8 factory not initialized: ${e.message}")
40+
}
41+
}
42+
43+
@Test
44+
fun testInt32FactoryInitialization() {
45+
try {
46+
val factory = DefaultTensorFactories.getInt32Factory()
47+
val tensor = factory.zeros(Shape(2, 3))
48+
assertNotNull(tensor)
49+
assertEquals(Shape(2, 3), tensor.shape)
50+
println("[DEBUG_LOG] Int32 factory initialized successfully")
51+
} catch (e: IllegalStateException) {
52+
println("[DEBUG_LOG] Int32 factory not initialized: ${e.message}")
53+
fail("Int32 factory not initialized: ${e.message}")
54+
}
55+
}
56+
57+
@Test
58+
fun testAllFactoryTypes() {
59+
println("[DEBUG_LOG] Testing all factory types...")
60+
61+
// Force initialization by creating backend instances
62+
try {
63+
CpuBackend()
64+
println("[DEBUG_LOG] FP32 backend instantiated")
65+
} catch (e: Exception) {
66+
println("[DEBUG_LOG] Failed to instantiate FP32 backend: ${e.message}")
67+
}
68+
69+
try {
70+
CpuBackendInt8()
71+
println("[DEBUG_LOG] Int8 backend instantiated")
72+
} catch (e: Exception) {
73+
println("[DEBUG_LOG] Failed to instantiate Int8 backend: ${e.message}")
74+
}
75+
76+
try {
77+
CpuBackendInt32()
78+
println("[DEBUG_LOG] Int32 backend instantiated")
79+
} catch (e: Exception) {
80+
println("[DEBUG_LOG] Failed to instantiate Int32 backend: ${e.message}")
81+
}
82+
83+
// This test will show us which factories are missing
84+
val results = mutableListOf<String>()
85+
86+
try {
87+
DefaultTensorFactories.getFP32Factory()
88+
results.add("FP32: OK")
89+
} catch (e: IllegalStateException) {
90+
results.add("FP32: MISSING - ${e.message}")
91+
}
92+
93+
try {
94+
DefaultTensorFactories.getInt8Factory()
95+
results.add("Int8: OK")
96+
} catch (e: IllegalStateException) {
97+
results.add("Int8: MISSING - ${e.message}")
98+
}
99+
100+
try {
101+
DefaultTensorFactories.getInt32Factory()
102+
results.add("Int32: OK")
103+
} catch (e: IllegalStateException) {
104+
results.add("Int32: MISSING - ${e.message}")
105+
}
106+
107+
results.forEach { println("[DEBUG_LOG] $it") }
108+
109+
// We'll fail only if all factories are missing
110+
val missingCount = results.count { it.contains("MISSING") }
111+
if (missingCount == results.size) {
112+
fail("All factories are missing!")
113+
}
114+
}
115+
}

0 commit comments

Comments
 (0)