Skip to content

Commit 79b40b5

Browse files
psiddhGithub Executorchclaude
authored
Add instrumentation tests for ByteBuffer prefill validation (#17835)
Tests cover all Java-side validation paths for prefillImages(ByteBuffer) and prefillNormalizedImage(ByteBuffer): non-direct buffer, insufficient remaining bytes, zero/negative dimensions, non-native byte order, and misaligned position. Valid-buffer tests confirm validation passes without requiring a multimodal model. Co-authored-by: Github Executorch <github_executorch@arm.com> Co-authored-by: Claude <noreply@anthropic.com>
1 parent 60f764f commit 79b40b5

1 file changed

Lines changed: 175 additions & 0 deletions

File tree

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@ import androidx.test.ext.junit.runners.AndroidJUnit4
1111
import java.io.File
1212
import java.io.IOException
1313
import java.net.URISyntaxException
14+
import java.nio.ByteBuffer
15+
import java.nio.ByteOrder
1416
import org.apache.commons.io.FileUtils
1517
import org.json.JSONException
1618
import org.json.JSONObject
1719
import org.junit.Assert.assertEquals
20+
import org.junit.Assert.assertThrows
1821
import org.junit.Assert.assertTrue
1922
import org.junit.Before
2023
import org.junit.Test
@@ -98,6 +101,178 @@ class LlmModuleInstrumentationTest : LlmCallback {
98101
} catch (_: JSONException) {}
99102
}
100103

104+
// --- prefillImages(ByteBuffer) validation tests ---
105+
106+
@Test
107+
fun testPrefillImagesByteBuffer_nonDirectThrows() {
108+
val heapBuffer = ByteBuffer.allocate(2 * 2 * 3)
109+
assertThrows(IllegalArgumentException::class.java) {
110+
llmModule.prefillImages(heapBuffer, 2, 2, 3)
111+
}
112+
}
113+
114+
@Test
115+
fun testPrefillImagesByteBuffer_insufficientRemainingThrows() {
116+
val buffer = ByteBuffer.allocateDirect(10)
117+
assertThrows(IllegalArgumentException::class.java) { llmModule.prefillImages(buffer, 2, 2, 3) }
118+
}
119+
120+
@Test
121+
fun testPrefillImagesByteBuffer_zeroWidthThrows() {
122+
val buffer = ByteBuffer.allocateDirect(12)
123+
assertThrows(IllegalArgumentException::class.java) { llmModule.prefillImages(buffer, 0, 2, 3) }
124+
}
125+
126+
@Test
127+
fun testPrefillImagesByteBuffer_zeroHeightThrows() {
128+
val buffer = ByteBuffer.allocateDirect(12)
129+
assertThrows(IllegalArgumentException::class.java) { llmModule.prefillImages(buffer, 2, 0, 3) }
130+
}
131+
132+
@Test
133+
fun testPrefillImagesByteBuffer_zeroChannelsThrows() {
134+
val buffer = ByteBuffer.allocateDirect(12)
135+
assertThrows(IllegalArgumentException::class.java) { llmModule.prefillImages(buffer, 2, 2, 0) }
136+
}
137+
138+
@Test
139+
fun testPrefillImagesByteBuffer_negativeWidthThrows() {
140+
val buffer = ByteBuffer.allocateDirect(12)
141+
assertThrows(IllegalArgumentException::class.java) { llmModule.prefillImages(buffer, -1, 2, 3) }
142+
}
143+
144+
@Test
145+
fun testPrefillImagesByteBuffer_negativeHeightThrows() {
146+
val buffer = ByteBuffer.allocateDirect(12)
147+
assertThrows(IllegalArgumentException::class.java) { llmModule.prefillImages(buffer, 2, -1, 3) }
148+
}
149+
150+
@Test
151+
fun testPrefillImagesByteBuffer_negativeChannelsThrows() {
152+
val buffer = ByteBuffer.allocateDirect(12)
153+
assertThrows(IllegalArgumentException::class.java) { llmModule.prefillImages(buffer, 2, 2, -1) }
154+
}
155+
156+
@Test
157+
fun testPrefillImagesByteBuffer_validBufferPassesValidation() {
158+
val buffer = ByteBuffer.allocateDirect(2 * 2 * 3)
159+
try {
160+
llmModule.prefillImages(buffer, 2, 2, 3)
161+
} catch (e: IllegalArgumentException) {
162+
throw AssertionError("Validation should not reject a correctly sized direct buffer", e)
163+
} catch (_: RuntimeException) {
164+
// Expected: native call may fail since this is a text-only model
165+
}
166+
}
167+
168+
// --- prefillNormalizedImage(ByteBuffer) validation tests ---
169+
170+
@Test
171+
fun testPrefillNormalizedImage_nonDirectThrows() {
172+
val heapBuffer = ByteBuffer.allocate(2 * 2 * 3 * 4)
173+
assertThrows(IllegalArgumentException::class.java) {
174+
llmModule.prefillNormalizedImage(heapBuffer, 2, 2, 3)
175+
}
176+
}
177+
178+
@Test
179+
fun testPrefillNormalizedImage_insufficientRemainingThrows() {
180+
val buffer = ByteBuffer.allocateDirect(10)
181+
buffer.order(ByteOrder.nativeOrder())
182+
assertThrows(IllegalArgumentException::class.java) {
183+
llmModule.prefillNormalizedImage(buffer, 2, 2, 3)
184+
}
185+
}
186+
187+
@Test
188+
fun testPrefillNormalizedImage_zeroWidthThrows() {
189+
val buffer = ByteBuffer.allocateDirect(2 * 2 * 3 * 4)
190+
buffer.order(ByteOrder.nativeOrder())
191+
assertThrows(IllegalArgumentException::class.java) {
192+
llmModule.prefillNormalizedImage(buffer, 0, 2, 3)
193+
}
194+
}
195+
196+
@Test
197+
fun testPrefillNormalizedImage_zeroHeightThrows() {
198+
val buffer = ByteBuffer.allocateDirect(2 * 2 * 3 * 4)
199+
buffer.order(ByteOrder.nativeOrder())
200+
assertThrows(IllegalArgumentException::class.java) {
201+
llmModule.prefillNormalizedImage(buffer, 2, 0, 3)
202+
}
203+
}
204+
205+
@Test
206+
fun testPrefillNormalizedImage_zeroChannelsThrows() {
207+
val buffer = ByteBuffer.allocateDirect(2 * 2 * 3 * 4)
208+
buffer.order(ByteOrder.nativeOrder())
209+
assertThrows(IllegalArgumentException::class.java) {
210+
llmModule.prefillNormalizedImage(buffer, 2, 2, 0)
211+
}
212+
}
213+
214+
@Test
215+
fun testPrefillNormalizedImage_negativeWidthThrows() {
216+
val buffer = ByteBuffer.allocateDirect(2 * 2 * 3 * 4)
217+
buffer.order(ByteOrder.nativeOrder())
218+
assertThrows(IllegalArgumentException::class.java) {
219+
llmModule.prefillNormalizedImage(buffer, -1, 2, 3)
220+
}
221+
}
222+
223+
@Test
224+
fun testPrefillNormalizedImage_negativeHeightThrows() {
225+
val buffer = ByteBuffer.allocateDirect(2 * 2 * 3 * 4)
226+
buffer.order(ByteOrder.nativeOrder())
227+
assertThrows(IllegalArgumentException::class.java) {
228+
llmModule.prefillNormalizedImage(buffer, 2, -1, 3)
229+
}
230+
}
231+
232+
@Test
233+
fun testPrefillNormalizedImage_negativeChannelsThrows() {
234+
val buffer = ByteBuffer.allocateDirect(2 * 2 * 3 * 4)
235+
buffer.order(ByteOrder.nativeOrder())
236+
assertThrows(IllegalArgumentException::class.java) {
237+
llmModule.prefillNormalizedImage(buffer, 2, 2, -1)
238+
}
239+
}
240+
241+
@Test
242+
fun testPrefillNormalizedImage_nonNativeByteOrderThrows() {
243+
val buffer = ByteBuffer.allocateDirect(2 * 2 * 3 * 4)
244+
val nonNativeOrder =
245+
if (ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN) ByteOrder.BIG_ENDIAN
246+
else ByteOrder.LITTLE_ENDIAN
247+
buffer.order(nonNativeOrder)
248+
assertThrows(IllegalArgumentException::class.java) {
249+
llmModule.prefillNormalizedImage(buffer, 2, 2, 3)
250+
}
251+
}
252+
253+
@Test
254+
fun testPrefillNormalizedImage_misalignedPositionThrows() {
255+
val buffer = ByteBuffer.allocateDirect(2 * 2 * 3 * 4 + 1)
256+
buffer.order(ByteOrder.nativeOrder())
257+
buffer.position(1)
258+
assertThrows(IllegalArgumentException::class.java) {
259+
llmModule.prefillNormalizedImage(buffer, 2, 2, 3)
260+
}
261+
}
262+
263+
@Test
264+
fun testPrefillNormalizedImage_validBufferPassesValidation() {
265+
val buffer = ByteBuffer.allocateDirect(2 * 2 * 3 * 4)
266+
buffer.order(ByteOrder.nativeOrder())
267+
try {
268+
llmModule.prefillNormalizedImage(buffer, 2, 2, 3)
269+
} catch (e: IllegalArgumentException) {
270+
throw AssertionError("Validation should not reject a correctly sized direct buffer", e)
271+
} catch (_: RuntimeException) {
272+
// Expected: native call may fail since this is a text-only model
273+
}
274+
}
275+
101276
companion object {
102277
private const val TEST_FILE_NAME = "/stories.pte"
103278
private const val TOKENIZER_FILE_NAME = "/tokenizer.bin"

0 commit comments

Comments
 (0)