@@ -424,7 +424,7 @@ def __init__(
         self.max_seq_length = pte_max_context_len

     def run(self, prompt):
-        golden_logits, _ = INFERENCE_REGISTRY[True](
+        result = INFERENCE_REGISTRY[True](
             get_example_inputs=self.get_example_inputs,
             prompt=prompt,
             module=self.source_model,
@@ -433,6 +433,7 @@ def run(self, prompt):
             use_i64_token=self.args.embedding_quantize is not None,
             collect_logits=True,
         )
+        golden_logits = result.logits

         input_file_name = f"{self.args.artifact}/input_tokens.raw"
examples/qualcomm/oss_scripts/llama/decoder_utils.py (81 changes: 53 additions & 28 deletions)
@@ -77,6 +77,13 @@ class DecoderInputs:
     embedding: Optional[torch.Tensor] = None


+@dataclass
+class DecoderOutputs:
+    logits: Optional[torch.Tensor] = None
+    token_list: Optional[List[int]] = None
+    input_samples: Optional[List] = None
+
+
 class GraphModuleCalibrationWrapper(EagerEvalWrapper):
     """
     A wrapper class for calibration
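
Note: a minimal, self-contained sketch of the pattern this dataclass enables. The DecoderOutputs definition is copied from the hunk above; run_inference and its values are hypothetical stand-ins for kv_inference/prefill_inference. Callers read named, optional fields instead of unpacking positional tuples, so a return site can add or drop a field without breaking call sites (compare the golden_logits change in the first file):

    from dataclasses import dataclass
    from typing import List, Optional

    import torch

    @dataclass
    class DecoderOutputs:
        logits: Optional[torch.Tensor] = None
        token_list: Optional[List[int]] = None
        input_samples: Optional[List] = None

    def run_inference(collect_logits: bool = False) -> DecoderOutputs:
        # Stand-in for the real inference functions: fields stay None
        # unless the caller opted in.
        logits = torch.zeros(1, 4, 128) if collect_logits else None
        return DecoderOutputs(logits=logits, token_list=[1, 2, 3])

    result = run_inference(collect_logits=True)
    golden_logits = result.logits        # attribute access, as in the hunks above
    assert result.input_samples is None  # not requested, so not collected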
@@ -94,6 +101,7 @@ def __init__( # noqa: C901
         get_example_inputs: Callable,
         use_i64_token: bool,
         seq_mse_candidates: int,
+        collect_input_samples: bool = False,
     ):
         # n seq len = n-1 cache len, so we len(inps) = n-1 during _model_call
         assert max_seq_length is not None, "max_seq_length must be provided"
@@ -108,30 +116,32 @@
         self.use_i64_token = use_i64_token
         self.seq_mse_candidates = seq_mse_candidates
         self._input_samples = None
+        self.collect_input_samples = collect_input_samples

     def get_input_samples(self):
         return self._input_samples

     def _model_call(self, inps):
-        all_logits = None
         kwargs = {}
         if self._use_kv_cache:
             kwargs["ar_len"] = self.ar_len
             kwargs["seq_mse_candidates"] = self.seq_mse_candidates

-        all_logits, self._input_samples = INFERENCE_REGISTRY[self._use_kv_cache](
+        result = INFERENCE_REGISTRY[self._use_kv_cache](
             self.get_example_inputs,
             inps,
             self._model,
             self._tokenizer,
             max_seq_len=self.max_seq_length,
             use_i64_token=self.use_i64_token,
             collect_logits=True,
+            collect_input_samples=self.collect_input_samples,
             **kwargs,
         )
+        self._input_samples = result.input_samples
        # one shot is enough for seq mse
         self.seq_mse_candidates = 0
-        return all_logits
+        return result.logits


 class LookaheadDecoder:
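
Note: the wrapper now asks for input samples explicitly instead of always receiving them. A hypothetical sketch of that opt-in pattern (names are stand-ins, not this repo's API), showing that nothing is retained when capture is off:

    from typing import Any, List, Optional, Tuple

    def forward_with_optional_capture(
        batches: List[Any], collect_input_samples: bool = False
    ) -> Tuple[List[Any], Optional[List[Any]]]:
        # Samples are retained only when the caller opts in, mirroring the
        # collect_input_samples plumbing added in this hunk.
        input_samples: Optional[List[Any]] = [] if collect_input_samples else None
        outputs = []
        for batch in batches:
            if collect_input_samples:
                input_samples.append(batch)
            outputs.append(batch)  # stand-in for the real forward pass
        return outputs, input_samples

    outs, samples = forward_with_optional_capture([1, 2, 3])
    assert samples is None  # nothing is held when capture is off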
@@ -731,7 +741,8 @@ def kv_inference( # noqa: C901
     collect_logits=False,
     seq_mse_candidates=0,
     lookahead_config=None,
-):
+    collect_input_samples=False,
+) -> DecoderOutputs:
     input_samples = []  # Record input sample for quantization error analysis
     is_multimodal = all(
         [
@@ -818,6 +829,7 @@

     # record total input tokens and generated tokens
     total_token_list = prompt_token_list
+    last_token_in_prompt = prompt_token_list[-1] if len(prompt_token_list) > 0 else None

     # 3. prepare decoder inputs
     inputs = DecoderInputs(
@@ -845,28 +857,33 @@

     # Phase 2: Generate tokens until the EOS token is generated or max_seq_len is reached.
     # When run on wikitext for ppl evaluation, this while-loop is not expected to run.
-    generate_input_sample = _generate(
-        inputs,
-        cur_pos,
-        module,
-        tokenizer,
-        tok_embedding,
-        ar_len,
-        max_seq_len,
-        k_caches,
-        v_caches,
-        total_token_list,
-        lookahead_config,
-    )
-    if generate_input_sample is not None:
-        input_samples.append(generate_input_sample)
-    else:
-        input_samples.append(prefill_input_sample)
+    generate_input_sample = None
+    if last_token_in_prompt != tokenizer.eos_id:
+        generate_input_sample = _generate(
+            inputs,
+            cur_pos,
+            module,
+            tokenizer,
+            tok_embedding,
+            ar_len,
+            max_seq_len,
+            k_caches,
+            v_caches,
+            total_token_list,
+            lookahead_config,
+        )
+
+    if collect_input_samples:
+        input_samples.append(generate_input_sample or prefill_input_sample)

     logging.info(f"kv inference result:\n{tokenizer.decode(total_token_list)}")
     if collect_logits:
         result_logits = torch.cat(result_logits, dim=1)
-        return result_logits, input_samples
+    return DecoderOutputs(
+        logits=result_logits if collect_logits else None,
+        token_list=total_token_list,
+        input_samples=input_samples if collect_input_samples else None,
+    )


 @register_inference(use_kv_cache=False)
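
Note: two behavioral points in this hunk, illustrated by a tiny runnable sketch (EOS_ID and the sample dicts are hypothetical stand-ins). First, decoding is skipped entirely when the prompt already ends with EOS. Second, "generate_input_sample or prefill_input_sample" falls back to the prefill sample whenever no decode sample was produced; the "or" also treats a falsy-but-present sample as absent, which matches the old "is not None" check only if real samples are always truthy.

    EOS_ID = 2  # hypothetical tokenizer.eos_id

    prompt_token_list = [5, 9, EOS_ID]
    last_token_in_prompt = prompt_token_list[-1] if len(prompt_token_list) > 0 else None

    generate_input_sample = None
    if last_token_in_prompt != EOS_ID:
        generate_input_sample = {"stage": "decode"}  # stand-in for _generate(...)

    prefill_input_sample = {"stage": "prefill"}
    input_samples = []
    input_samples.append(generate_input_sample or prefill_input_sample)

    # The prompt ends with EOS, so no decode step ran and the prefill
    # sample is recorded instead.
    assert input_samples == [{"stage": "prefill"}]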
@@ -882,7 +899,8 @@ def prefill_inference(
     max_seq_len=512,
     use_i64_token=False,
     collect_logits=False,
-):
+    collect_input_samples=False,
+) -> DecoderOutputs:
     input_samples = None  # Record input sample for quantization error analysis
     is_multimodal = all(
         [
@@ -950,7 +968,11 @@
         pos += 1
     if isinstance(prompt, str):
         logging.info(f"prefill inference result:\n{tokenizer.decode(token_list)}")
-    return result_logits, [input_samples]
+    return DecoderOutputs(
+        logits=result_logits if collect_logits else None,
+        token_list=token_list,
+        input_samples=[input_samples] if collect_input_samples else None,
+    )


 def graph_module_inference(
@@ -972,7 +994,8 @@ def graph_module_inference(
     event_name: Optional[str] = None,
     seq_mse_candidates: int = 0,
     lookahead_config: Optional[Tuple[int]] = None,
-):
+    collect_input_samples: bool = False,
+) -> DecoderOutputs:
     """
     This function supports model execution from static nn.Module decoder model
     all the way to edge program.
@@ -988,7 +1011,7 @@
         kwargs["ar_len"] = ar_len
         kwargs["lookahead_config"] = lookahead_config

-        _, input_samples = INFERENCE_REGISTRY[use_kv_cache](
+        result = INFERENCE_REGISTRY[use_kv_cache](
             get_example_inputs,
             prompt,
             module,
@@ -1000,10 +1023,11 @@
             max_seq_len=max_seq_len,
             use_i64_token=use_i64_token,
             collect_logits=False,
+            collect_input_samples=collect_input_samples,
             **kwargs,
         )
         logging.info(f"Prompt summary for {event_name}")
-        return input_samples
+        return result
     else:
         calibration_wrapper = GraphModuleCalibrationWrapper(
             model=module,
@@ -1014,6 +1038,7 @@
             get_example_inputs=get_example_inputs,
             use_i64_token=use_i64_token,
             seq_mse_candidates=seq_mse_candidates,
+            collect_input_samples=collect_input_samples,
         )
         with torch.no_grad():
             eval_results = simple_evaluate(
@@ -1026,4 +1051,4 @@
         for task, res in eval_results["results"].items():
             logging.info(f"{task}: {res}")

-    return calibration_wrapper.get_input_samples()
+    return DecoderOutputs(input_samples=calibration_wrapper.get_input_samples())
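
Note: every path through this file now returns DecoderOutputs, and the eager paths are all reached through INFERENCE_REGISTRY[use_kv_cache]. A guessed, self-contained reconstruction of that dispatch mechanism (the real register_inference lives elsewhere in this repo and may differ):

    from typing import Callable, Dict

    INFERENCE_REGISTRY: Dict[bool, Callable] = {}

    def register_inference(use_kv_cache: bool):
        # Registers an inference function under its use_kv_cache key.
        def wrap(fn: Callable) -> Callable:
            INFERENCE_REGISTRY[use_kv_cache] = fn
            return fn
        return wrap

    @register_inference(use_kv_cache=True)
    def kv_inference(prompt: str) -> str:
        return f"kv:{prompt}"

    @register_inference(use_kv_cache=False)
    def prefill_inference(prompt: str) -> str:
        return f"prefill:{prompt}"

    # Both paths are reached through the same lookup, as in graph_module_inference.
    assert INFERENCE_REGISTRY[True]("hi") == "kv:hi"
    assert INFERENCE_REGISTRY[False]("hi") == "prefill:hi"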
examples/qualcomm/oss_scripts/llama/wrappers/base_component.py (13 changes: 10 additions & 3 deletions)
@@ -40,6 +40,7 @@
 class Mode(Enum):
     PREFILL = 1
     DECODE = 2
+    CALIBRATE = 3


 def log_info(func):
@@ -83,7 +84,7 @@ def process_model_args(
         model_args: ModelArgs object to be modified.
         quant_recipe: Quantization recipe to be used.
         config: LLMModelConfig object to be used.
-        mode: Mode of operation (PREFILL or DECODE).
+        mode: Mode of operation (PREFILL, DECODE, or CALIBRATE).
     """
     # TODO: support batch inputs if necessary
     if mode == Mode.DECODE:
@@ -95,13 +96,19 @@
             if control_args.model_mode == "lookahead"
             else 1
         )
-    else:
+    elif mode == Mode.PREFILL:
         ar_len = control_args.prefill_ar_len
+    elif mode == Mode.CALIBRATE:
+        ar_len = control_args.max_context_len
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")

     model_args.max_batch_size = 1
     model_args.max_seq_len = control_args.max_seq_len
     model_args.max_context_len = control_args.max_context_len
-    model_args.use_kv_cache = control_args.max_context_len != ar_len
+    model_args.use_kv_cache = (
+        control_args.max_context_len != ar_len or mode == Mode.CALIBRATE
+    )
     model_args.enable_r3 = config.r3
     model_args.ar_len = ar_len
     model_args.kv_io_bit_width = quant_recipe.get_kv_io_bit_width()
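
Note: a self-contained sketch of the new mode handling (control_args/model_args are reduced to plain parameters, and the lengths are hypothetical; the DECODE branch is elided). CALIBRATE pins ar_len to the full context length, which would make the old "max_context_len != ar_len" test turn the KV cache off, so the added "or mode == Mode.CALIBRATE" clause keeps it on:

    from enum import Enum

    class Mode(Enum):
        PREFILL = 1
        DECODE = 2
        CALIBRATE = 3

    def resolve_ar_len(mode: Mode, prefill_ar_len: int = 32, max_context_len: int = 512):
        # Mirrors the branch added in this hunk (DECODE elided for brevity).
        if mode == Mode.PREFILL:
            ar_len = prefill_ar_len
        elif mode == Mode.CALIBRATE:
            ar_len = max_context_len
        else:
            raise ValueError(f"Unsupported mode: {mode}")
        use_kv_cache = max_context_len != ar_len or mode == Mode.CALIBRATE
        return ar_len, use_kv_cache

    assert resolve_ar_len(Mode.PREFILL) == (32, True)
    # Without the extra clause, CALIBRATE would yield (512, False).
    assert resolve_ar_len(Mode.CALIBRATE) == (512, True)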