Skip to content

Commit dc21310

Browse files
committed
Add unit test showcasing using slicing with image tensors.
1 parent 9ae815b commit dc21310

1 file changed

Lines changed: 237 additions & 0 deletions

File tree

  • skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
package sk.ainet.lang.tensor
2+
3+
import sk.ainet.lang.tensor.dsl.*
4+
import sk.ainet.lang.types.FP32
5+
import kotlin.test.Test
6+
import kotlin.test.assertEquals
7+
import kotlin.test.assertNotNull
8+
9+
/**
10+
* Comprehensive test suite for pixel-by-pixel tensor access functionality.
11+
* Demonstrates how to access every pixel separately in 4D tensors (BCHW format).
12+
*/
13+
class PixelAccessTest {
14+
15+
@Test
16+
fun testPixelByPixelAccess() {
17+
println("[DEBUG_LOG] Testing pixel-by-pixel access for 4D tensor")
18+
19+
// Create a sample 4D tensor for computer vision: [Batch=2, Channels=3, Height=4, Width=4]
20+
val imageTensor = with<FP32, Float>(testFactory) {
21+
tensor(2, 3, 4, 4) { _ ->
22+
init { indices ->
23+
// Initialize with meaningful pattern: batch*1000 + channel*100 + height*10 + width
24+
(indices[0] * 1000 + indices[1] * 100 + indices[2] * 10 + indices[3]).toFloat()
25+
}
26+
}
27+
}
28+
29+
assertNotNull(imageTensor)
30+
assertEquals(Shape(2, 3, 4, 4), imageTensor.shape)
31+
32+
println("[DEBUG_LOG] Created image tensor with shape: ${imageTensor.shape}")
33+
34+
// Test accessing every pixel individually using data[indices] method
35+
for (batch in 0 until 2) {
36+
for (channel in 0 until 3) {
37+
for (height in 0 until 4) {
38+
for (width in 0 until 4) {
39+
val expectedValue = (batch * 1000 + channel * 100 + height * 10 + width).toFloat()
40+
val actualValue = imageTensor.data[batch, channel, height, width]
41+
42+
assertEquals(expectedValue, actualValue, 0.001f,
43+
"Pixel mismatch at [$batch,$channel,$height,$width]: expected $expectedValue, got $actualValue")
44+
}
45+
}
46+
}
47+
}
48+
49+
println("[DEBUG_LOG] Successfully accessed and verified all ${2*3*4*4} pixels")
50+
}
51+
52+
@Test
53+
fun testSpecificPixelPatterns() {
54+
println("[DEBUG_LOG] Testing specific pixel access patterns")
55+
56+
// Create a 4D tensor with a different pattern
57+
val tensor = with<FP32, Float>(testFactory) {
58+
tensor(2, 3, 4, 4) { _ ->
59+
init { indices ->
60+
// Pattern: sum of all indices
61+
(indices[0] + indices[1] + indices[2] + indices[3]).toFloat()
62+
}
63+
}
64+
}
65+
66+
// Test corner pixels
67+
assertEquals(0.0f, tensor.data[0, 0, 0, 0]) // Top-left corner of first batch/channel
68+
assertEquals(9.0f, tensor.data[1, 2, 3, 3]) // Bottom-right corner of second batch, third channel (1+2+3+3=9)
69+
70+
// Test center pixels
71+
assertEquals(4.0f, tensor.data[0, 1, 1, 2]) // Center-ish pixel
72+
assertEquals(6.0f, tensor.data[1, 1, 2, 2]) // Another center pixel (1+1+2+2=6)
73+
74+
// Test edge pixels
75+
assertEquals(3.0f, tensor.data[0, 0, 0, 3]) // Right edge, first row
76+
assertEquals(6.0f, tensor.data[0, 0, 3, 3]) // Bottom-right corner, first batch/channel
77+
78+
println("[DEBUG_LOG] All specific pixel patterns verified successfully")
79+
}
80+
81+
@Test
82+
fun testFirstBatchFirstChannelPixels() {
83+
println("[DEBUG_LOG] Testing comprehensive access to first batch, first channel")
84+
85+
val imageTensor = with<FP32, Float>(testFactory) {
86+
tensor(2, 3, 4, 4) { _ ->
87+
init { indices ->
88+
// Initialize with meaningful pattern: batch*1000 + channel*100 + height*10 + width
89+
(indices[0] * 1000 + indices[1] * 100 + indices[2] * 10 + indices[3]).toFloat()
90+
}
91+
}
92+
}
93+
94+
println("[DEBUG_LOG] Sample values from tensor[0,0,:,:] (first batch, first channel):")
95+
val pixelValues = mutableListOf<Float>()
96+
97+
for (h in 0 until 4) {
98+
val rowValues = mutableListOf<Float>()
99+
for (w in 0 until 4) {
100+
val value = imageTensor.data[0, 0, h, w]
101+
rowValues.add(value)
102+
pixelValues.add(value)
103+
print("${value.toInt()}\t")
104+
}
105+
println()
106+
107+
// Verify the row values match expected pattern
108+
for ((w, value) in rowValues.withIndex()) {
109+
val expectedValue = (h * 10 + w).toFloat()
110+
assertEquals(expectedValue, value, 0.001f,
111+
"Row $h, Col $w: expected $expectedValue, got $value")
112+
}
113+
}
114+
115+
// Verify we captured all 16 pixels for the first batch/channel
116+
assertEquals(16, pixelValues.size)
117+
118+
// Verify specific expected values
119+
assertEquals(0.0f, pixelValues[0]) // [0,0]
120+
assertEquals(33.0f, pixelValues[15]) // [3,3] = 3*10 + 3 = 33
121+
assertEquals(13.0f, pixelValues[7]) // [1,3] = 1*10 + 3 = 13 (index 7 is row 1, col 3)
122+
123+
println("[DEBUG_LOG] First batch, first channel pixels verified successfully")
124+
}
125+
126+
@Test
127+
fun testAllChannelsFromFirstBatch() {
128+
println("[DEBUG_LOG] Testing access to all channels from first batch")
129+
130+
val imageTensor = with<FP32, Float>(testFactory) {
131+
tensor(2, 3, 4, 4) { _ ->
132+
init { indices ->
133+
// Initialize with meaningful pattern: batch*1000 + channel*100 + height*10 + width
134+
(indices[0] * 1000 + indices[1] * 100 + indices[2] * 10 + indices[3]).toFloat()
135+
}
136+
}
137+
}
138+
139+
// Test accessing all channels for the first batch, first pixel [0,0]
140+
for (channel in 0 until 3) {
141+
val value = imageTensor.data[0, channel, 0, 0]
142+
val expectedValue = (channel * 100).toFloat() // batch=0, h=0, w=0
143+
assertEquals(expectedValue, value, 0.001f,
144+
"Channel $channel pixel [0,0]: expected $expectedValue, got $value")
145+
146+
println("[DEBUG_LOG] Channel $channel, pixel [0,0] = $value")
147+
}
148+
149+
// Test accessing all channels for a middle pixel [2,2]
150+
for (channel in 0 until 3) {
151+
val value = imageTensor.data[0, channel, 2, 2]
152+
val expectedValue = (channel * 100 + 2 * 10 + 2).toFloat() // batch=0, h=2, w=2
153+
assertEquals(expectedValue, value, 0.001f,
154+
"Channel $channel pixel [2,2]: expected $expectedValue, got $value")
155+
156+
println("[DEBUG_LOG] Channel $channel, pixel [2,2] = $value")
157+
}
158+
159+
println("[DEBUG_LOG] All channels from first batch verified successfully")
160+
}
161+
162+
@Test
163+
fun testBatchSeparation() {
164+
println("[DEBUG_LOG] Testing pixel access across different batches")
165+
166+
val imageTensor = with<FP32, Float>(testFactory) {
167+
tensor(2, 3, 4, 4) { _ ->
168+
init { indices ->
169+
// Initialize with meaningful pattern: batch*1000 + channel*100 + height*10 + width
170+
(indices[0] * 1000 + indices[1] * 100 + indices[2] * 10 + indices[3]).toFloat()
171+
}
172+
}
173+
}
174+
175+
// Compare same pixel location across different batches
176+
val pixel00Batch0 = imageTensor.data[0, 0, 0, 0] // Should be 0
177+
val pixel00Batch1 = imageTensor.data[1, 0, 0, 0] // Should be 1000
178+
179+
assertEquals(0.0f, pixel00Batch0, 0.001f)
180+
assertEquals(1000.0f, pixel00Batch1, 0.001f)
181+
182+
// Compare same pixel location, same channel, different batches
183+
val pixel22Ch1Batch0 = imageTensor.data[0, 1, 2, 2] // 0*1000 + 1*100 + 2*10 + 2 = 122
184+
val pixel22Ch1Batch1 = imageTensor.data[1, 1, 2, 2] // 1*1000 + 1*100 + 2*10 + 2 = 1122
185+
186+
assertEquals(122.0f, pixel22Ch1Batch0, 0.001f)
187+
assertEquals(1122.0f, pixel22Ch1Batch1, 0.001f)
188+
189+
println("[DEBUG_LOG] Batch separation verified: batch 0 pixel [1,2,2] = $pixel22Ch1Batch0, batch 1 pixel [1,2,2] = $pixel22Ch1Batch1")
190+
191+
// Verify the difference is exactly 1000 (the batch multiplier)
192+
val batchDifference = pixel22Ch1Batch1 - pixel22Ch1Batch0
193+
assertEquals(1000.0f, batchDifference, 0.001f)
194+
195+
println("[DEBUG_LOG] Batch separation test completed successfully")
196+
}
197+
198+
@Test
199+
fun testEdgeAndCornerPixels() {
200+
println("[DEBUG_LOG] Testing edge and corner pixel access")
201+
202+
val imageTensor = with<FP32, Float>(testFactory) {
203+
tensor(2, 3, 4, 4) { _ ->
204+
init { indices ->
205+
// Simple pattern: just the sum of indices
206+
(indices[0] + indices[1] + indices[2] + indices[3]).toFloat()
207+
}
208+
}
209+
}
210+
211+
// Test all four corners of first batch, first channel
212+
val topLeft = imageTensor.data[0, 0, 0, 0] // 0+0+0+0 = 0
213+
val topRight = imageTensor.data[0, 0, 0, 3] // 0+0+0+3 = 3
214+
val bottomLeft = imageTensor.data[0, 0, 3, 0] // 0+0+3+0 = 3
215+
val bottomRight = imageTensor.data[0, 0, 3, 3] // 0+0+3+3 = 6
216+
217+
assertEquals(0.0f, topLeft)
218+
assertEquals(3.0f, topRight)
219+
assertEquals(3.0f, bottomLeft)
220+
assertEquals(6.0f, bottomRight)
221+
222+
// Test edge pixels (middle of each edge)
223+
val topEdge = imageTensor.data[0, 0, 0, 2] // 0+0+0+2 = 2
224+
val bottomEdge = imageTensor.data[0, 0, 3, 2] // 0+0+3+2 = 5
225+
val leftEdge = imageTensor.data[0, 0, 2, 0] // 0+0+2+0 = 2
226+
val rightEdge = imageTensor.data[0, 0, 2, 3] // 0+0+2+3 = 5
227+
228+
assertEquals(2.0f, topEdge)
229+
assertEquals(5.0f, bottomEdge)
230+
assertEquals(2.0f, leftEdge)
231+
assertEquals(5.0f, rightEdge)
232+
233+
println("[DEBUG_LOG] All corner and edge pixels verified successfully")
234+
println("[DEBUG_LOG] Corners: TL=$topLeft, TR=$topRight, BL=$bottomLeft, BR=$bottomRight")
235+
println("[DEBUG_LOG] Edges: T=$topEdge, B=$bottomEdge, L=$leftEdge, R=$rightEdge")
236+
}
237+
}

0 commit comments

Comments
 (0)