diff --git a/examples/qualcomm/oss_scripts/llama/artifacts/stories260k_hybrid_llama_qnn.pte b/examples/qualcomm/oss_scripts/llama/artifacts/stories260k_hybrid_llama_qnn.pte index 5903c5b5c32..942773b0633 100644 Binary files a/examples/qualcomm/oss_scripts/llama/artifacts/stories260k_hybrid_llama_qnn.pte and b/examples/qualcomm/oss_scripts/llama/artifacts/stories260k_hybrid_llama_qnn.pte differ diff --git a/examples/qualcomm/oss_scripts/llama/dataset/builders.py b/examples/qualcomm/oss_scripts/llama/dataset/builders.py index 1e839e765e1..c593a802c04 100644 --- a/examples/qualcomm/oss_scripts/llama/dataset/builders.py +++ b/examples/qualcomm/oss_scripts/llama/dataset/builders.py @@ -263,6 +263,7 @@ def build_calib_dataloaders(self) -> Dict[str, Optional[DataLoader]]: if len(decoder_datasets) > 1 else decoder_datasets[0] ) + self._log_dataset_stats(datasets, cfg.batch_size, phase="calibration") return dict.fromkeys(_ALL_MODALITY_KEYS) | { modality: DataLoader( @@ -315,3 +316,31 @@ def build_runtime_dataloader( ) for modality, dataset in datasets.items() } + + @staticmethod + def _log_dataset_stats( + datasets: Dict[str, Dataset], + batch_size: int, + phase: str = "calibration", + ) -> None: + """Log sample/batch counts per modality; raises if any dataset < batch_size.""" + for modality, ds in datasets.items(): + n = len(ds) + n_batches = n // batch_size + dropped = n - n_batches * batch_size + drop_str = f" ({dropped} dropped)" if batch_size > 1 and dropped else "" + logging.info( + "%s '%s': %d samples, batch_size=%d, %d batches%s", + phase, + modality, + n, + batch_size, + n_batches, + drop_str, + ) + if batch_size > 1 and n < batch_size: + raise ValueError( + f"{phase} '{modality}' has {n} samples but " + f"batch_size={batch_size}. " + "Increase the data limit or reduce the batch size." + ) diff --git a/examples/qualcomm/oss_scripts/llama/evaluator/lm_eval_adapter.py b/examples/qualcomm/oss_scripts/llama/evaluator/lm_eval_adapter.py index 0314a6e14a6..8d17fed55b1 100644 --- a/examples/qualcomm/oss_scripts/llama/evaluator/lm_eval_adapter.py +++ b/examples/qualcomm/oss_scripts/llama/evaluator/lm_eval_adapter.py @@ -8,6 +8,7 @@ from typing import Callable, Optional, Union import torch +import torch.nn.functional as F from executorch.examples.models.llama.evaluate.eager_eval import EagerEvalWrapper from executorch.examples.qualcomm.oss_scripts.llama.inference import DecoderInference from pytorch_tokenizers.hf_tokenizer import HuggingFaceTokenizer @@ -34,6 +35,7 @@ def __init__( max_seq_length: int, get_example_inputs: Callable, use_i64_token: bool, + max_batch_size: int = 1, ): assert max_seq_length is not None, "max_seq_length must be provided" super().__init__( @@ -43,15 +45,23 @@ def __init__( self._runner = DecoderInference( get_example_inputs=get_example_inputs, max_context_len=max_seq_length, + max_batch_size=max_batch_size, use_i64_token=use_i64_token, ) + self._batch_size = max_batch_size + + @property + def batch_size(self): + return self._batch_size def _model_call(self, inps): - logits = self._runner.predict_step( + actual_bsz = inps.shape[0] + inps = F.pad(inps, (0, 0, 0, self._batch_size - actual_bsz)) + logits = self._runner.prediction_step( self._model, input_ids=inps, ) - return logits + return logits[:actual_bsz] def run_lm_eval( @@ -74,6 +84,7 @@ def run_lm_eval( max_seq_length=max_seq_length, get_example_inputs=get_example_inputs, use_i64_token=use_i64_token, + max_batch_size=max_batch_size, ) with torch.no_grad(): eval_results = simple_evaluate( diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index d3d4a475288..0611029ddb9 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -621,6 +621,15 @@ def _build_parser(): "Multiple files are merged.", ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size for text decoder quantization. Larger values increase throughput " + "but require more host memory. Only affects the CALIBRATE graph; DECODE and " + "PREFILL graphs always use batch size 1.", + ) + parser.add_argument( "-F", "--use_fp16", diff --git a/examples/qualcomm/oss_scripts/llama/masking_utils.py b/examples/qualcomm/oss_scripts/llama/masking_utils.py index a09cdf1240f..5561f25ec89 100644 --- a/examples/qualcomm/oss_scripts/llama/masking_utils.py +++ b/examples/qualcomm/oss_scripts/llama/masking_utils.py @@ -35,7 +35,8 @@ def create_causal_attn_mask(max_batch_size: int, ar_len: int, max_context_len: i ], dim=-1, ) - mask = mask[None, :, :].expand(max_batch_size, ar_len, max_context_len) + # num_heads=1: the mask broadcasts across all heads. + mask = mask[None, None, :, :].expand(max_batch_size, 1, ar_len, max_context_len) return mask @@ -68,7 +69,8 @@ def create_sliding_window_attn_mask( ], dim=-1, ) - mask = mask[None, :, :].expand(max_batch_size, ar_len, max_context_len) + # num_heads=1: the mask broadcasts across all heads. + mask = mask[None, None, :, :].expand(max_batch_size, 1, ar_len, max_context_len) return mask @@ -123,8 +125,8 @@ def _mask_padding_positions( ) -> None: """Mask positions beyond each sequence's actual length.""" actual_lens = torch.tensor([len(seq) for seq in input_ids]) - pad_rows = torch.arange(max_seq_length).unsqueeze(0) >= actual_lens.unsqueeze(1) - self.mask.masked_fill_(pad_rows.unsqueeze(-1), PADDING_MASK_VALUE) + pad_rows = torch.arange(max_seq_length) >= actual_lens.unsqueeze(1) + self.mask.masked_fill_(pad_rows[:, None, :, None], PADDING_MASK_VALUE) class CausalAttentionMask(BaseAttentionMask): diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.cpp index bc57eab5bde..7f73abdacb5 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.cpp @@ -247,10 +247,10 @@ Error QNNMultimodalRunner::load() { int32_t token_generator_ar_len = 0; int32_t max_cache_len = 0; int32_t max_ar_len = 0; - // atten mask: [1, AR-N, CL] + // atten mask: [1, 1, AR-N, CL] auto atten_mask_meta_token = method_meta->input_tensor_meta(1); - token_generator_ar_len = atten_mask_meta_token->sizes()[1]; - context_len_ = atten_mask_meta_token->sizes()[2]; + token_generator_ar_len = atten_mask_meta_token->sizes()[2]; + context_len_ = atten_mask_meta_token->sizes()[3]; if (eval_mode_ == EvalMode::kKVCached) { prompt_processor_ar_len = token_generator_ar_len; } else if ( @@ -259,7 +259,7 @@ Error QNNMultimodalRunner::load() { auto atten_mask_meta_prompt = text_decoder_->method_meta(prompt_processor_method_name) ->input_tensor_meta(1); - prompt_processor_ar_len = atten_mask_meta_prompt->sizes()[1]; + prompt_processor_ar_len = atten_mask_meta_prompt->sizes()[2]; } if (prompt_processor_ar_len == context_len_) max_cache_len = context_len_; diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index 611c4aaea35..0e1a27a751a 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -244,10 +244,10 @@ Error Runner::load() { int32_t token_generator_ar_len = 0; int32_t max_cache_len = 0; int32_t max_ar_len = 0; - // atten mask: [1, AR-N, CL] + // atten mask: [1, 1, AR-N, CL] auto atten_mask_meta_token = method_meta->input_tensor_meta(1); - token_generator_ar_len = atten_mask_meta_token->sizes()[1]; - context_len_ = atten_mask_meta_token->sizes()[2]; + token_generator_ar_len = atten_mask_meta_token->sizes()[2]; + context_len_ = atten_mask_meta_token->sizes()[3]; if (eval_mode_ == EvalMode::kKVCached) { prompt_processor_ar_len = token_generator_ar_len; } else if ( @@ -256,7 +256,7 @@ Error Runner::load() { auto atten_mask_meta_prompt = module_->method_meta(prompt_processor_method_name) ->input_tensor_meta(1); - prompt_processor_ar_len = atten_mask_meta_prompt->sizes()[1]; + prompt_processor_ar_len = atten_mask_meta_prompt->sizes()[2]; } if (prompt_processor_ar_len == context_len_) max_cache_len = context_len_; diff --git a/examples/qualcomm/oss_scripts/llama/wrappers/base_component.py b/examples/qualcomm/oss_scripts/llama/wrappers/base_component.py index 149a376e918..76da9063ada 100644 --- a/examples/qualcomm/oss_scripts/llama/wrappers/base_component.py +++ b/examples/qualcomm/oss_scripts/llama/wrappers/base_component.py @@ -90,7 +90,6 @@ def process_model_args( config: LLMModelConfig object to be used. mode: Mode of operation (PREFILL, DECODE, or CALIBRATE). """ - # TODO: support batch inputs if necessary if mode == Mode.DECODE: ar_len = ( # To get better performance, we round up to the nearest power of 2. @@ -107,8 +106,8 @@ def process_model_args( else: raise ValueError(f"Unsupported mode: {mode}") - # TODO: support multi_batch for CALIBRATION MODE - model_args.max_batch_size = 1 + # TODO: support batch inputs for runtime mode if necessary + model_args.max_batch_size = control_args.batch_size if mode == Mode.CALIBRATE else 1 model_args.max_seq_len = control_args.max_seq_len model_args.max_context_len = control_args.max_context_len model_args.use_kv_cache = ( diff --git a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py index 9bab682eac8..10d1abc648a 100644 --- a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py +++ b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py @@ -456,6 +456,7 @@ def _tag_ios(self, node, fixed_point_type): atten_mask_shape = { ( self.meta["get_max_batch_size"], + 1, # num_heads=1: the mask broadcasts across all heads. self.meta["get_ar_len"], self.meta["get_max_context_len"], ), @@ -609,6 +610,7 @@ def quantize(self, request: Request): # noqa: C901 use_i64_token=self.control_args.embedding_quantize is not None, num_fewshot=self.control_args.eval_num_fewshot, limit=self.control_args.eval_limit, + max_batch_size=self.meta["get_max_batch_size"], event_name="export_tasks", ) @@ -674,6 +676,7 @@ def quantize(self, request: Request): # noqa: C901 use_i64_token=self.control_args.embedding_quantize is not None, num_fewshot=self.control_args.eval_num_fewshot, limit=self.control_args.eval_limit, + max_batch_size=self.meta["get_max_batch_size"], event_name="convert_pt2e_tasks", ) @@ -1186,8 +1189,9 @@ def _calibrate(self, model, calibration_datasets): outputs.append(outputs_each_batch) return DataLoader( ModalityEncoderDataset(outputs), - batch_size=1, + batch_size=self.control_args.batch_size, shuffle=False, + drop_last=self.control_args.batch_size > 1, ) def quantize(self, request: Request):