Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
29 changes: 29 additions & 0 deletions examples/qualcomm/oss_scripts/llama/dataset/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ def build_calib_dataloaders(self) -> Dict[str, Optional[DataLoader]]:
if len(decoder_datasets) > 1
else decoder_datasets[0]
)
self._log_dataset_stats(datasets, cfg.batch_size, phase="calibration")

return dict.fromkeys(_ALL_MODALITY_KEYS) | {
modality: DataLoader(
Expand Down Expand Up @@ -315,3 +316,31 @@ def build_runtime_dataloader(
)
for modality, dataset in datasets.items()
}

@staticmethod
def _log_dataset_stats(
datasets: Dict[str, Dataset],
batch_size: int,
phase: str = "calibration",
) -> None:
"""Log sample/batch counts per modality; raises if any dataset < batch_size."""
for modality, ds in datasets.items():
n = len(ds)
n_batches = n // batch_size
dropped = n - n_batches * batch_size
drop_str = f" ({dropped} dropped)" if batch_size > 1 and dropped else ""
logging.info(
"%s '%s': %d samples, batch_size=%d, %d batches%s",
phase,
modality,
n,
batch_size,
n_batches,
drop_str,
)
if batch_size > 1 and n < batch_size:
raise ValueError(
f"{phase} '{modality}' has {n} samples but "
f"batch_size={batch_size}. "
"Increase the data limit or reduce the batch size."
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Callable, Optional, Union

import torch
import torch.nn.functional as F
from executorch.examples.models.llama.evaluate.eager_eval import EagerEvalWrapper
from executorch.examples.qualcomm.oss_scripts.llama.inference import DecoderInference
from pytorch_tokenizers.hf_tokenizer import HuggingFaceTokenizer
Expand All @@ -34,6 +35,7 @@ def __init__(
max_seq_length: int,
get_example_inputs: Callable,
use_i64_token: bool,
max_batch_size: int = 1,
):
assert max_seq_length is not None, "max_seq_length must be provided"
super().__init__(
Expand All @@ -43,15 +45,23 @@ def __init__(
self._runner = DecoderInference(
get_example_inputs=get_example_inputs,
max_context_len=max_seq_length,
max_batch_size=max_batch_size,
use_i64_token=use_i64_token,
)
self._batch_size = max_batch_size

@property
def batch_size(self):
return self._batch_size

def _model_call(self, inps):
logits = self._runner.predict_step(
actual_bsz = inps.shape[0]
inps = F.pad(inps, (0, 0, 0, self._batch_size - actual_bsz))
logits = self._runner.prediction_step(
self._model,
input_ids=inps,
)
return logits
return logits[:actual_bsz]


def run_lm_eval(
Expand All @@ -74,6 +84,7 @@ def run_lm_eval(
max_seq_length=max_seq_length,
get_example_inputs=get_example_inputs,
use_i64_token=use_i64_token,
max_batch_size=max_batch_size,
)
with torch.no_grad():
eval_results = simple_evaluate(
Expand Down
9 changes: 9 additions & 0 deletions examples/qualcomm/oss_scripts/llama/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,15 @@ def _build_parser():
"Multiple files are merged.",
)

parser.add_argument(
"--batch_size",
type=int,
default=1,
help="Batch size for text decoder quantization. Larger values increase throughput "
"but require more host memory. Only affects the CALIBRATE graph; DECODE and "
"PREFILL graphs always use batch size 1.",
)

parser.add_argument(
"-F",
"--use_fp16",
Expand Down
10 changes: 6 additions & 4 deletions examples/qualcomm/oss_scripts/llama/masking_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def create_causal_attn_mask(max_batch_size: int, ar_len: int, max_context_len: i
],
dim=-1,
)
mask = mask[None, :, :].expand(max_batch_size, ar_len, max_context_len)
# num_heads=1: the mask broadcasts across all heads.
mask = mask[None, None, :, :].expand(max_batch_size, 1, ar_len, max_context_len)
return mask


Expand Down Expand Up @@ -68,7 +69,8 @@ def create_sliding_window_attn_mask(
],
dim=-1,
)
mask = mask[None, :, :].expand(max_batch_size, ar_len, max_context_len)
# num_heads=1: the mask broadcasts across all heads.
mask = mask[None, None, :, :].expand(max_batch_size, 1, ar_len, max_context_len)
return mask


Expand Down Expand Up @@ -123,8 +125,8 @@ def _mask_padding_positions(
) -> None:
"""Mask positions beyond each sequence's actual length."""
actual_lens = torch.tensor([len(seq) for seq in input_ids])
pad_rows = torch.arange(max_seq_length).unsqueeze(0) >= actual_lens.unsqueeze(1)
self.mask.masked_fill_(pad_rows.unsqueeze(-1), PADDING_MASK_VALUE)
pad_rows = torch.arange(max_seq_length) >= actual_lens.unsqueeze(1)
self.mask.masked_fill_(pad_rows[:, None, :, None], PADDING_MASK_VALUE)


class CausalAttentionMask(BaseAttentionMask):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,10 @@ Error QNNMultimodalRunner::load() {
int32_t token_generator_ar_len = 0;
int32_t max_cache_len = 0;
int32_t max_ar_len = 0;
// atten mask: [1, AR-N, CL]
// atten mask: [1, 1, AR-N, CL]
auto atten_mask_meta_token = method_meta->input_tensor_meta(1);
token_generator_ar_len = atten_mask_meta_token->sizes()[1];
context_len_ = atten_mask_meta_token->sizes()[2];
token_generator_ar_len = atten_mask_meta_token->sizes()[2];
context_len_ = atten_mask_meta_token->sizes()[3];
if (eval_mode_ == EvalMode::kKVCached) {
prompt_processor_ar_len = token_generator_ar_len;
} else if (
Expand All @@ -259,7 +259,7 @@ Error QNNMultimodalRunner::load() {
auto atten_mask_meta_prompt =
text_decoder_->method_meta(prompt_processor_method_name)
->input_tensor_meta(1);
prompt_processor_ar_len = atten_mask_meta_prompt->sizes()[1];
prompt_processor_ar_len = atten_mask_meta_prompt->sizes()[2];
}
if (prompt_processor_ar_len == context_len_)
max_cache_len = context_len_;
Expand Down
8 changes: 4 additions & 4 deletions examples/qualcomm/oss_scripts/llama/runner/runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,10 @@ Error Runner::load() {
int32_t token_generator_ar_len = 0;
int32_t max_cache_len = 0;
int32_t max_ar_len = 0;
// atten mask: [1, AR-N, CL]
// atten mask: [1, 1, AR-N, CL]
auto atten_mask_meta_token = method_meta->input_tensor_meta(1);
token_generator_ar_len = atten_mask_meta_token->sizes()[1];
context_len_ = atten_mask_meta_token->sizes()[2];
token_generator_ar_len = atten_mask_meta_token->sizes()[2];
context_len_ = atten_mask_meta_token->sizes()[3];
if (eval_mode_ == EvalMode::kKVCached) {
prompt_processor_ar_len = token_generator_ar_len;
} else if (
Expand All @@ -256,7 +256,7 @@ Error Runner::load() {
auto atten_mask_meta_prompt =
module_->method_meta(prompt_processor_method_name)
->input_tensor_meta(1);
prompt_processor_ar_len = atten_mask_meta_prompt->sizes()[1];
prompt_processor_ar_len = atten_mask_meta_prompt->sizes()[2];
}
if (prompt_processor_ar_len == context_len_)
max_cache_len = context_len_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def process_model_args(
config: LLMModelConfig object to be used.
mode: Mode of operation (PREFILL, DECODE, or CALIBRATE).
"""
# TODO: support batch inputs if necessary
if mode == Mode.DECODE:
ar_len = (
# To get better performance, we round up to the nearest power of 2.
Expand All @@ -107,8 +106,8 @@ def process_model_args(
else:
raise ValueError(f"Unsupported mode: {mode}")

# TODO: support multi_batch for CALIBRATION MODE
model_args.max_batch_size = 1
# TODO: support batch inputs for runtime mode if necessary
model_args.max_batch_size = control_args.batch_size if mode == Mode.CALIBRATE else 1
model_args.max_seq_len = control_args.max_seq_len
model_args.max_context_len = control_args.max_context_len
model_args.use_kv_cache = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ def _tag_ios(self, node, fixed_point_type):
atten_mask_shape = {
(
self.meta["get_max_batch_size"],
1, # num_heads=1: the mask broadcasts across all heads.
self.meta["get_ar_len"],
self.meta["get_max_context_len"],
),
Expand Down Expand Up @@ -609,6 +610,7 @@ def quantize(self, request: Request): # noqa: C901
use_i64_token=self.control_args.embedding_quantize is not None,
num_fewshot=self.control_args.eval_num_fewshot,
limit=self.control_args.eval_limit,
max_batch_size=self.meta["get_max_batch_size"],
event_name="export_tasks",
)

Expand Down Expand Up @@ -674,6 +676,7 @@ def quantize(self, request: Request): # noqa: C901
use_i64_token=self.control_args.embedding_quantize is not None,
num_fewshot=self.control_args.eval_num_fewshot,
limit=self.control_args.eval_limit,
max_batch_size=self.meta["get_max_batch_size"],
event_name="convert_pt2e_tasks",
)

Expand Down Expand Up @@ -1186,8 +1189,9 @@ def _calibrate(self, model, calibration_datasets):
outputs.append(outputs_each_batch)
return DataLoader(
ModalityEncoderDataset(outputs),
batch_size=1,
batch_size=self.control_args.batch_size,
shuffle=False,
drop_last=self.control_args.batch_size > 1,
)

def quantize(self, request: Request):
Expand Down
Loading