Skip to content

Commit a81ef44

Browse files
authored
Unify Android JNI to single IRunner, wire prefill to runner (pytorch#17756)
Replace the dual-runner pattern (runner_ + multi_modal_runner_) with a single IRunner* that holds either TextLLMRunner or MultimodalRunner, leveraging MultimodalRunner's new IRunner inheritance from pytorch#17741. Each prefill method (text, images, audio) now immediately calls IRunner::prefill(vector<MultimodalInput>) instead of buffering inputs for later consumption by generate(). A needs_bos_ flag tracks whether the next prefill should apply BOS tokens — MultimodalRunner also guards this via pos_==0 internally, but TextLLMRunner trusts the caller. generate(), stop(), load(), and reset() no longer branch on model_type_category_; all dispatch through the unified runner_. Rename all JNI native methods from append* to prefill* to match the existing Java public API naming.
1 parent d0820e1 commit a81ef44

2 files changed

Lines changed: 216 additions & 167 deletions

File tree

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -373,19 +373,18 @@ public int generate(
373373
}
374374

375375
/**
376-
* Prefill a multimodal Module with the given images input.
376+
* Prefill the KV cache with the given image input.
377377
*
378378
* @param image Input image as an int array
379379
* @param width Input image width
380380
* @param height Input image height
381381
* @param channels Input image number of channels
382-
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
383-
* exposed to user.
382+
* @return 0 on success
384383
* @throws RuntimeException if the prefill failed
385384
*/
386385
@Experimental
387386
public long prefillImages(int[] image, int width, int height, int channels) {
388-
int nativeResult = appendImagesInput(image, width, height, channels);
387+
int nativeResult = prefillImagesInput(image, width, height, channels);
389388
if (nativeResult != 0) {
390389
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
391390
}
@@ -434,7 +433,7 @@ public void prefillImages(ByteBuffer image, int width, int height, int channels)
434433
}
435434
// slice() so that getDirectBufferAddress on the native side returns a pointer
436435
// starting at the current position, not the base address.
437-
int nativeResult = appendImagesInputBuffer(image.slice(), width, height, channels);
436+
int nativeResult = prefillImagesInputBuffer(image.slice(), width, height, channels);
438437
if (nativeResult != 0) {
439438
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
440439
}
@@ -500,128 +499,125 @@ public void prefillNormalizedImage(ByteBuffer image, int width, int height, int
500499
}
501500
// slice() so that getDirectBufferAddress on the native side returns a pointer
502501
// starting at the current position, not the base address.
503-
int nativeResult = appendNormalizedImagesInputBuffer(image.slice(), width, height, channels);
502+
int nativeResult = prefillNormalizedImagesInputBuffer(image.slice(), width, height, channels);
504503
if (nativeResult != 0) {
505504
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
506505
}
507506
}
508507

509-
private native int appendImagesInput(int[] image, int width, int height, int channels);
508+
private native int prefillImagesInput(int[] image, int width, int height, int channels);
510509

511-
private native int appendImagesInputBuffer(ByteBuffer image, int width, int height, int channels);
510+
private native int prefillImagesInputBuffer(
511+
ByteBuffer image, int width, int height, int channels);
512512

513-
private native int appendNormalizedImagesInputBuffer(
513+
private native int prefillNormalizedImagesInputBuffer(
514514
ByteBuffer image, int width, int height, int channels);
515515

516516
/**
517-
* Prefill a multimodal Module with the given images input.
517+
* Prefill the KV cache with the given normalized image input.
518518
*
519519
* @param image Input normalized image as a float array
520520
* @param width Input image width
521521
* @param height Input image height
522522
* @param channels Input image number of channels
523-
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
524-
* exposed to user.
523+
* @return 0 on success
525524
* @throws RuntimeException if the prefill failed
526525
*/
527526
@Experimental
528527
public long prefillImages(float[] image, int width, int height, int channels) {
529-
int nativeResult = appendNormalizedImagesInput(image, width, height, channels);
528+
int nativeResult = prefillNormalizedImagesInput(image, width, height, channels);
530529
if (nativeResult != 0) {
531530
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
532531
}
533532
return 0;
534533
}
535534

536-
private native int appendNormalizedImagesInput(
535+
private native int prefillNormalizedImagesInput(
537536
float[] image, int width, int height, int channels);
538537

539538
/**
540-
* Prefill a multimodal Module with the given audio input.
539+
* Prefill the KV cache with the given preprocessed audio input.
541540
*
542541
* @param audio Input preprocessed audio as a byte array
543542
* @param batch_size Input batch size
544543
* @param n_bins Input number of bins
545544
* @param n_frames Input number of frames
546-
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
547-
* exposed to user.
545+
* @return 0 on success
548546
* @throws RuntimeException if the prefill failed
549547
*/
550548
@Experimental
551549
public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) {
552-
int nativeResult = appendAudioInput(audio, batch_size, n_bins, n_frames);
550+
int nativeResult = prefillAudioInput(audio, batch_size, n_bins, n_frames);
553551
if (nativeResult != 0) {
554552
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
555553
}
556554
return 0;
557555
}
558556

559-
private native int appendAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames);
557+
private native int prefillAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames);
560558

561559
/**
562-
* Prefill a multimodal Module with the given audio input.
560+
* Prefill the KV cache with the given preprocessed audio input.
563561
*
564562
* @param audio Input preprocessed audio as a float array
565563
* @param batch_size Input batch size
566564
* @param n_bins Input number of bins
567565
* @param n_frames Input number of frames
568-
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
569-
* exposed to user.
566+
* @return 0 on success
570567
* @throws RuntimeException if the prefill failed
571568
*/
572569
@Experimental
573570
public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) {
574-
int nativeResult = appendAudioInputFloat(audio, batch_size, n_bins, n_frames);
571+
int nativeResult = prefillAudioInputFloat(audio, batch_size, n_bins, n_frames);
575572
if (nativeResult != 0) {
576573
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
577574
}
578575
return 0;
579576
}
580577

581-
private native int appendAudioInputFloat(float[] audio, int batch_size, int n_bins, int n_frames);
578+
private native int prefillAudioInputFloat(
579+
float[] audio, int batch_size, int n_bins, int n_frames);
582580

583581
/**
584-
* Prefill a multimodal Module with the given raw audio input.
582+
* Prefill the KV cache with the given raw audio input.
585583
*
586584
* @param audio Input raw audio as a byte array
587585
* @param batch_size Input batch size
588586
* @param n_channels Input number of channels
589587
* @param n_samples Input number of samples
590-
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
591-
* exposed to user.
588+
* @return 0 on success
592589
* @throws RuntimeException if the prefill failed
593590
*/
594591
@Experimental
595592
public long prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) {
596-
int nativeResult = appendRawAudioInput(audio, batch_size, n_channels, n_samples);
593+
int nativeResult = prefillRawAudioInput(audio, batch_size, n_channels, n_samples);
597594
if (nativeResult != 0) {
598595
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
599596
}
600597
return 0;
601598
}
602599

603-
private native int appendRawAudioInput(
600+
private native int prefillRawAudioInput(
604601
byte[] audio, int batch_size, int n_channels, int n_samples);
605602

606603
/**
607-
* Prefill a multimodal Module with the given text input.
604+
* Prefill the KV cache with the given text prompt.
608605
*
609606
* @param prompt The text prompt to prefill.
610-
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
611-
* exposed to user.
607+
* @return 0 on success
612608
* @throws RuntimeException if the prefill failed
613609
*/
614610
@Experimental
615611
public long prefillPrompt(String prompt) {
616-
int nativeResult = appendTextInput(prompt);
612+
int nativeResult = prefillTextInput(prompt);
617613
if (nativeResult != 0) {
618614
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
619615
}
620616
return 0;
621617
}
622618

623619
// returns status
624-
private native int appendTextInput(String prompt);
620+
private native int prefillTextInput(String prompt);
625621

626622
/**
627623
* Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.

0 commit comments

Comments
 (0)