Skip to content

Commit 3b3c9d4

Browse files
authored
Qualcomm AI Engine Direct - Support attention sink for long context usecase (pytorch#16574)
### Summary - Support narrow operation - Support attention sink for static llama - Include the `--max_context_len` option to set the maximum length for the model's memory, and use `max_seq_len` to define the maximum sequence length for evaluation. - Specified `--use_attention_sink <sink_size>,<eviction_batch_size>` to enable attention sink feature in llama.py - Behavior matrix in `llama.py` - Given that `--compile_only`: - Specify `--use_attention_sink` -> Compile the LLM and the attention sink model - Otherwise, -> Compile the LLM only - Given that `--pre_get_pte`: - Specify `--use_attention_sink` -> If the criteria below are not met, compile the attention sink model before running inference. And then inference with attention sink - Check if the attention sink model exists - Verify sink_size and eviction batch size are identical - Otherwise, -> Inference LLM without attention sink - Neither `--compile_only` nor `--pre_get_pte`: - Specify `--use_attention_sink` -> Compile the LLM and the attention sink model and inference with attention sink - Otherwise, -> Compile the LLM and inference LLM without attention sink ### Test plan - Test for narrow op: - `python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNFloatingPointOperator.test_qnn_backend_narrow --model SM8750 --device $DEVICE --build_folder build-android` - `python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperator.test_qnn_backend_narrow --model SM8750 --device $DEVICE --build_folder build-android` - Test for attention sink in `llama.py` - `python backends/qualcomm/tests/test_qnn_delegate.py TestExampleLLMScript.test_attention_sink --model SM8750 --device $DEVICE -b build-android -a unit_test` ### Results - Test multi-conversation using attention sink in `llama.py` - `python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s $DEVICE -m SM8750 --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-1b_instruct --model_mode kv --max_seq_len 4096 --max_context_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" "Could you give a more difficult example in python?" "Could you add a GUI for this game?" "Could you tell me more about tkinter?" "Is it possible to deploy on a website?" --tasks wikitext --limit 1 --use_attention_sink 4,32` <details> <summary> Set llama 3.2 1b instruct to a max context length of 1024, and activate the attention sink feature with a sink_size of 4 and eviction_batch_size of 32. Then, run a multi-turn conversation with a sequence length of 4096 using five prompts: <br> "I would like to learn python, could you teach me with a simple example?", <br> "Could you give a more difficult example in python?", <br> "Could you add a GUI for this game?", <br> "Could you tell me more about tkinter?", <br> "Is it possible to deploy on a website?" </summary> ~~~ <|start_header_id|>user<|end_header_id|> I would like to learn python, could you teach me with a simple example?<|eot_id|><|start_header_id|>assistant<|end_header_id|> I'd be happy to help you learn Python. Here's a simple example to get you started: **Example: Guessing Game** In this example, we'll create a simple game where the computer thinks of a number between 1 and 100, and you try to guess it. After each guess, the computer will tell you if your guess is higher or lower than the number it's thinking of. **Code:** ```python import random # The number the computer is thinking of number = random.randint(1, 100) # Keep track of your guesses guesses = 0 # Keep asking for guesses until you guess the number while True: # Ask for a guess guess = input("Guess a number: ") # Check if the guess is valid try: # Convert the guess to an integer guess = int(guess) # Check if the guess is within the range if guess < 1 or guess > 100: print("Please enter a number between 1 and 100.") else: # Increment the number of guesses guesses += 1 # Check if you've guessed the number if guess == number: print(f"Congratulations! You guessed the number in {guesses} guesses.") break else: # Tell the computer how close your guess is print(f"Your guess is {guess} away from the number.") except ValueError: print("Invalid input. Please enter a number.") ``` **How to Run:** 1. Save this code in a file with a `.py` extension (e.g., `guessing_game.py`). 2. Open a terminal or command prompt and navigate to the directory where you saved the file. 3. Type `python guessing_game.py` to run the game. **What to Do:** 1. Run the game by typing `python guessing_game.py`. 2. Follow the prompts to guess a number. 3. Keep guessing until you guess the number correctly. I hope this example helps you get started with Python! Let me know if you have any questions or need further guidance.<|eot_id|><|start_header_id|>user<|end_header_id|> Could you give a more difficult example in python?<|eot_id|><|start_header_id|>assistant<|end_header_id|> Here's a more challenging example in Python: **Example: Rock, Paper, Scissors** In this example, we'll create a game of Rock, Paper, Scissors where the computer thinks of a random move and you try to guess it. After each guess, the computer will tell you if your move is Rock, Paper, or Scissors. **Code:** ```python import random # Define the possible moves moves = ["rock", "paper", "scissors"] # The computer's move computer_move = random.choice(moves) # Keep track of your guesses guesses = 0 # Keep asking for guesses until you guess the computer's move while True: # Ask for a guess guess = input("Guess a move (rock, paper, or scissors): ") # Check if the guess is valid if guess.lower() in moves: # Convert the guess to a move move = guess.lower() if move == computer_move: print(f"Congratulations! You guessed the computer's move correctly!") break else: # Increment the number of guesses guesses += 1 # Tell the computer how close your guess is print(f"Your guess is {guess} away from the computer's move.") else: print("Invalid input. Please enter a valid move (rock, paper, or scissors).") ``` **How to Run:** 1. Save this code in a file with a `.py` extension (e.g., `rock_paper_scissors.py`). 2. Open a terminal or command prompt and navigate to the directory where you saved the file. 3. Type `python rock_paper_scissors.py` to run the game. **What to Do:** 1. Run the game by typing `python rock_paper_scissors.py`. 2. Follow the prompts to guess a move. 3. Keep guessing until you guess the computer's move correctly. This example is more challenging because it: * Requires you to keep track of your guesses and the computer's move. * Requires you to convert your guess to a specific move (rock, paper, or scissors). * Requires you to tell the computer how close your guess is to the computer's move. I hope this example helps you get started with more advanced Python concepts! Let me know if you have any questions or need further guidance.<|eot_id|><|start_header_id|>user<|end_header_id|> Could you add a GUI for this game?<|eot_id|><|start_header_id|>assistant<|end_header_id|> Here's an example of how you can create a GUI for the Rock, Paper, Scissors game using Python's Tkinter library: **Code:** ```python import tkinter as tk from random import randint class RockPaperScissors: def __init__(self): self.window = tk.Tk() self.window.title("Rock, Paper, Scissors") self.window.geometry("300x200") self.move = tk.StringVar() self.move.set("rock") self.rock_button = tk.Button(self.window, text="Rock", command=self.rock) self.rock_button.pack() self.paper_button = tk.Button(self.window, text="Paper", command=self.paper) self.paper_button.pack() self.scissors_button = tk.Button(self.window, text="Scissors", command=self.scissors) self.scissors_button.pack() self.result_label = tk.Label(self.window, text="", font=('Arial', 24)) self.result_label.pack() def rock(self): computer_move = randint(0, 2) if computer_move == 0: self.move.set("rock") self.result_label.config(text=f"Computer chose {self.move.get()}.") elif computer_move == 1: self.move.set("paper") self.result_label.config(text=f"Computer chose {self.move.get()}.") else: self.move.set("scissors") self.result_label.config(text=f"Computer chose {self.move.get()}.") def paper(self): computer_move = randint(0, 2) if computer_move == 0: self.move.set("paper") self.result_label.config(text=f"Computer chose {self.move.get()}.") elif computer_move == 1: self.move.set("rock") self.result_label.config(text=f"Computer chose {self.move.get()}.") else: self.move.set("scissors") self.result_label.config(text=f"Computer chose {self.move.get()}.") def scissors(self): computer_move = randint(0, 2) if computer_move == 0: self.move.set("scissors") self.result_label.config(text=f"Computer chose {self.move.get()}.") elif computer_move == 1: self.move.set("paper") self.result_label.config(text=f"Computer chose {self.move.get()}.") else: self.move.set("rock") self.result_label.config(text=f"Computer chose {self.move.get()}.") def run(self): self.window.mainloop() if __name__ == "__main__": game = RockPaperScissors() game.run() ``` **How to Run:** 1. Save this code in a file with a `.py` extension (e.g., `rock_paper_scissors.py`). 2. Open a terminal or command prompt and navigate to the directory where you saved the file. 3. Type `python rock_paper_scissors.py` to run the game. 4. A GUI window will appear with three buttons: Rock, Paper, and Scissors. 5. Click on the buttons to play the game. This GUI implementation provides a simple and intuitive way to play the Rock, Paper, Scissors game. The game logic is implemented in the `rock`, `paper`, and `scissors` methods, which are called when the corresponding buttons are clicked. The game displays the result in the label below the buttons. Note that this is a basic implementation, and you can customize it to suit your needs. For example, you could add additional features like scoring, high scores, or a more complex game logic.<|eot_id|><|start_header_id|>user<|end_header_id|> Could you tell me more about tkinter?<|eot_id|><|start_header_id|>assistant<|end_header_id|> Tkinter is a Python library that allows you to create simple graphical user interfaces (GUIs) for your Python programs. It's a great tool for beginners and experienced Python developers alike. Here are some key features and concepts related to Tkinter: **Key Features:** 1. **Widgets:** Tkinter provides a variety of widgets, such as buttons, labels, text boxes, and more, that you can use to create your GUI. 2. **Layout Management:** Tkinter allows you to create a layout for your widgets using the `pack`, `grid`, or `place` methods. 3. **Event Handling:** Tkinter provides a way to handle events, such as button clicks, key presses, and mouse movements, which can trigger actions in your GUI. 4. **Graphics:** Tkinter can display images, charts, and other graphical elements in your GUI. **Basic Concepts:** 1. **Widgets:** A widget is a graphical element that you can use to create your GUI. Examples of widgets include buttons, labels, and text boxes. 2. **Layout:** A layout is a way of arranging widgets in your GUI. You can use the `pack`, `grid`, or `place` methods to create a layout. 3. **Event Handling:** Event handling is the process of responding to events, such as button clicks or key presses, in your GUI. 4. **Widgets:** You can create a GUI by creating a `Frame` widget, which is a container that holds other widgets. **Example Code:** Here's an example of a simple GUI created using Tkinter: ```python import tkinter as tk class MyGUI: def __init__(self): self.window = tk.Tk() self.window.title("My GUI") self.label = tk.Label(self.window, text="Hello, World!") self.label.pack() self.button = tk.Button(self.window, text="Click me!", command=self.button_click) self.button.pack() def button_click(self): print("Button clicked!") def run(self): self.window.mainloop() if __name__ == "__main__": gui = MyGUI() gui.run() ``` In this example, we create a `MyGUI` class that creates a window with a label and a button. When the button is clicked, the `button_click` method is called, which prints a message to the console. **Tips and Tricks:** 1. **Use the `pack` or `grid` layout manager:** These are the most common layout managers in Tkinter. 2. **Use the `mainloop` method:** This is the main method that starts the event loop of your GUI. 3. **Use the `bind` method:** This is used to bind a function to a specific event, such as button click or key press. 4. **Use the `config` method:** This is used to change the properties of a widget, such as its text or color. I hope this helps! Let me know if you have any specific questions or if you'd like more information on how to use Tkinter.<|eot_id|><|start_header_id|>user<|end_header_id|> Is it possible to deploy on a website?<|eot_id|><|start_header_id|>assistant<|end_header_id|> Yes, it is possible to deploy a Tkinter application on a website. However, it requires some additional steps and considerations. Here are some options: 1. **Run the application as a web server:** You can use a web server like Apache or Nginx to run your Tkinter application on a website. This will allow users to access your application through a web browser. 2. **Use a web framework:** You can use a web framework like Flask or Django to create a web application that uses Tkinter as a backend. This will allow you to create a web application that can be accessed through a web browser. 3. **Use a Python web server:** You can use a Python web server like `http.server` or `webserver` to run your Tkinter application on a website. This will allow you to access your application through a web browser. Here's an example of how you can deploy a Tkinter application on a website using `http.server`: ```python import tkinter as tk class MyGUI: def __init__(self): self.window = tk.Tk() self.window.title("My GUI") self.label = tk.Label(self.window, text="Hello, World!") self.label.pack() def run(self): self.window.mainloop() if __name__ == "__main__": gui = MyGUI() gui.run() ``` To run this code on a website, you'll need to install `http.server` or `webserver` and add it to your `PATH` environment variable. Here's an example of how to do this: ```bash # Install http.server pip install http.server # Add http.server to your PATH environment variable export PATH=/usr/local/bin:/usr/local/bin ``` Once you've installed `http.server`, you can run the following code to deploy your Tkinter application on a website: ```python import http.server class MyGUI: def __init__(self): self.window = http.server.HTTPServer(("", 8000), MyGUI) def run(self): self.window.run() if __name__ == "__main__": gui = MyGUI() gui.run() ``` This will start a web server on port 8000 and allow users to access your Tkinter application through a web browser. Note that deploying a Tkinter application on a website requires some technical knowledge and setup. If you're not comfortable with this, you may want to consider using a more modern web development framework or a Python web framework like Flask or Django.<|eot_id|> ~~~ </details> The performance for each run is nearly similar. ``` I 00:00:00.528911 executorch:prompt_processor.cpp:267] Prompt Processor: total 26 prompt tokens (AR-1 * 26 iters) I 00:00:00.934570 executorch:runner.cpp:459] RSS after prompt prefill: 1326.957031 MiB (0 if unsupported) I 00:00:08.314366 executorch:token_generator.cpp:345] Reached to the end of generation I 00:00:08.314416 executorch:runner.cpp:474] RSS after finishing text generation: 1326.957031 MiB (0 if unsupported) I 00:00:08.314457 executorch:stats.h:143] Prompt Tokens: 26 Generated Tokens: 446 I 00:00:08.314473 executorch:stats.h:149] Model Load Time: 0.526000 (seconds) I 00:00:08.314489 executorch:stats.h:159] Total inference time: 7.786000 (seconds) Rate: 57.282302 (tokens/second) I 00:00:08.314506 executorch:stats.h:167] Prompt evaluation: 0.407000 (seconds) Rate: 63.882064 (tokens/second) I 00:00:08.314527 executorch:stats.h:178] Generated 446 tokens: 7.379000 (seconds) Rate: 60.441794 (tokens/second) I 00:00:08.314544 executorch:stats.h:186] Time to first generated token: 0.407000 (seconds) I 00:00:08.314547 executorch:stats.h:193] Sampling time over 472 tokens: 0.797000 (seconds) ``` cc: @haowhsu-quic
1 parent 6399767 commit 3b3c9d4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1971
-429
lines changed

backends/qualcomm/quantizer/annotators.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,16 @@ def annotate_sign(node: Node, quantization_config: QuantizationConfig) -> None:
780780

781781
@register_annotator([torch.ops.aten.slice.Tensor])
782782
def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None:
783-
annotate_single_in_share_out(node, quantization_config)
783+
annotate_in_out_obs_sharing_op(node, quantization_config)
784+
if not _is_annotated([node]):
785+
annotate_single_in_share_out(node, quantization_config)
786+
787+
788+
@register_annotator([torch.ops.aten.narrow.default])
789+
def annotate_narrow(node: Node, quantization_config: QuantizationConfig) -> None:
790+
annotate_in_out_obs_sharing_op(node, quantization_config)
791+
if not _is_annotated([node]):
792+
annotate_single_in_share_out(node, quantization_config)
784793

785794

786795
@register_annotator([torch.ops.aten.slice_scatter.default])

backends/qualcomm/quantizer/qconfig.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def get_8a8w_qnn_ptq_config(
151151

152152

153153
def get_8a4w_qnn_ptq_config(
154-
act_symmetric: bool = True,
154+
act_symmetric: bool = False,
155155
act_observer=MovingAverageMinMaxObserver,
156156
eps: float = None,
157157
) -> QuantizationConfig:
@@ -210,15 +210,19 @@ def get_8a4w_qnn_ptq_config(
210210

211211
# 4 bits quantization only supports specific ops.
212212
def get_16a4w_qnn_ptq_config(
213-
act_observer=MovingAverageMinMaxObserver, eps: float = None
213+
act_symmetric: bool = False,
214+
act_observer=MovingAverageMinMaxObserver,
215+
eps: float = None,
214216
) -> QuantizationConfig:
215217
# the smallest defaults to DEFAULT_EPS_16BIT
216218
extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT}
217219
act_quantization_spec = QuantizationSpec(
218220
dtype=torch.int32,
219221
quant_min=torch.iinfo(torch.uint16).min,
220222
quant_max=torch.iinfo(torch.uint16).max,
221-
qscheme=torch.per_tensor_affine,
223+
qscheme=(
224+
torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
225+
),
222226
observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
223227
)
224228

@@ -250,15 +254,19 @@ def get_16a4w_qnn_ptq_config(
250254

251255

252256
def get_16a8w_qnn_ptq_config(
253-
act_observer=MovingAverageMinMaxObserver, eps: float = None
257+
act_symmetric: bool = False,
258+
act_observer=MovingAverageMinMaxObserver,
259+
eps: float = None,
254260
) -> QuantizationConfig:
255261
# the smallest defaults to DEFAULT_EPS_16BIT
256262
extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT}
257263
act_quantization_spec = QuantizationSpec(
258264
dtype=torch.int32,
259265
quant_min=torch.iinfo(torch.uint16).min,
260266
quant_max=torch.iinfo(torch.uint16).max,
261-
qscheme=torch.per_tensor_affine,
267+
qscheme=(
268+
torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
269+
),
262270
observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
263271
)
264272

@@ -288,15 +296,19 @@ def get_16a8w_qnn_ptq_config(
288296

289297

290298
def get_16a16w_qnn_ptq_config(
291-
act_observer=MovingAverageMinMaxObserver, eps: float = None
299+
act_symmetric: bool = False,
300+
act_observer=MovingAverageMinMaxObserver,
301+
eps: float = None,
292302
) -> QuantizationConfig:
293303
# the smallest defaults to DEFAULT_EPS_16BIT
294304
extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT}
295305
act_quantization_spec = QuantizationSpec(
296306
dtype=torch.int32,
297307
quant_min=torch.iinfo(torch.uint16).min,
298308
quant_max=torch.iinfo(torch.uint16).max,
299-
qscheme=torch.per_tensor_affine,
309+
qscheme=(
310+
torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
311+
),
300312
observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
301313
)
302314

@@ -330,22 +342,28 @@ def get_16a16w_qnn_ptq_config(
330342

331343
# TODO merge qat and ptq to a function, and use a bool flag to control it
332344
def get_16a8w_qnn_qat_config(
333-
act_observer=MovingAverageMinMaxObserver, eps: float = None
345+
act_symmetric: bool = False,
346+
act_observer=MovingAverageMinMaxObserver,
347+
eps: float = None,
334348
) -> QuantizationConfig:
335349
# the smallest defaults to DEFAULT_EPS_16BIT
336350
extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT}
337351
act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
338352
dtype=torch.int32,
339353
quant_min=torch.iinfo(torch.uint16).min,
340354
quant_max=torch.iinfo(torch.uint16).max,
341-
qscheme=torch.per_tensor_affine,
355+
qscheme=(
356+
torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
357+
),
342358
observer=act_observer.with_args(**extra_args),
343359
)
344360
act_quantization_spec = QuantizationSpec(
345361
dtype=torch.int32,
346362
quant_min=torch.iinfo(torch.uint16).min,
347363
quant_max=torch.iinfo(torch.uint16).max,
348-
qscheme=torch.per_tensor_affine,
364+
qscheme=(
365+
torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
366+
),
349367
observer_or_fake_quant_ctr=act_fake_quant_ctr,
350368
)
351369

@@ -648,22 +666,28 @@ def get_8a8w_qnn_qat_config(
648666

649667

650668
def get_16a4w_qnn_qat_config(
651-
act_observer=MovingAverageMinMaxObserver, eps: float = None
669+
act_symmetric: bool = False,
670+
act_observer=MovingAverageMinMaxObserver,
671+
eps: float = None,
652672
) -> QuantizationConfig:
653673
# the smallest defaults to DEFAULT_EPS_16BIT
654674
extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT}
655675
act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
656676
dtype=torch.int32,
657677
quant_min=torch.iinfo(torch.uint16).min,
658678
quant_max=torch.iinfo(torch.uint16).max,
659-
qscheme=torch.per_tensor_affine,
679+
qscheme=(
680+
torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
681+
),
660682
observer=act_observer.with_args(**extra_args),
661683
)
662684
act_quantization_spec = QuantizationSpec(
663685
dtype=torch.int32,
664686
quant_min=torch.iinfo(torch.uint16).min,
665687
quant_max=torch.iinfo(torch.uint16).max,
666-
qscheme=torch.per_tensor_affine,
688+
qscheme=(
689+
torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
690+
),
667691
observer_or_fake_quant_ctr=act_fake_quant_ctr,
668692
)
669693

backends/qualcomm/quantizer/quant_recipe.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def __init__(
7373
is_qat: bool,
7474
granularity: QuantGranularity,
7575
act_observer: UniformQuantizationObserverBase,
76+
act_symmetric: bool,
7677
extra_kwargs: Dict,
7778
note: str,
7879
priority: int,
@@ -81,6 +82,7 @@ def __init__(
8182
self.is_qat = is_qat
8283
self.granularity = granularity
8384
self.act_observer = act_observer
85+
self.act_symmetric = act_symmetric
8486
self.extra_kwargs = extra_kwargs
8587
self.note = note
8688
self.priority = priority
@@ -91,6 +93,7 @@ def __init__(
9193
is_conv_per_channel=True,
9294
is_linear_per_channel=True,
9395
act_observer=self.act_observer,
96+
act_symmetric=self.act_symmetric,
9497
)
9598

9699
@abstractmethod
@@ -143,6 +146,7 @@ def __init__(
143146
is_qat,
144147
granularity,
145148
act_observer,
149+
act_symmetric,
146150
extra_kwargs,
147151
note,
148152
priority,
@@ -153,6 +157,7 @@ def __init__(
153157
is_qat,
154158
granularity,
155159
act_observer,
160+
act_symmetric,
156161
extra_kwargs,
157162
note,
158163
priority,
@@ -179,6 +184,7 @@ def __init__(
179184
is_qat,
180185
granularity,
181186
act_observer,
187+
act_symmetric,
182188
extra_kwargs,
183189
note,
184190
priority,
@@ -189,6 +195,7 @@ def __init__(
189195
is_qat,
190196
granularity,
191197
act_observer,
198+
act_symmetric,
192199
extra_kwargs,
193200
note,
194201
priority,
@@ -228,6 +235,7 @@ def __init__(
228235
is_qat,
229236
act_observer: UniformQuantizationObserverBase,
230237
granularity: QuantGranularity,
238+
act_symmetric: bool = False,
231239
note: str = "",
232240
extra_kwargs: Optional[dict] = None,
233241
verbose: bool = False,
@@ -257,6 +265,7 @@ def __init__(
257265
is_qat,
258266
granularity,
259267
act_observer,
268+
act_symmetric,
260269
extra_kwargs or {},
261270
note,
262271
priority=1,
@@ -311,6 +320,7 @@ def add_node_target(
311320
is_qat,
312321
act_observer: UniformQuantizationObserverBase,
313322
granularity: QuantGranularity,
323+
act_symmetric: bool = False,
314324
note: str = "",
315325
priority: int = 1,
316326
extra_kwargs: Optional[dict] = None,
@@ -321,6 +331,7 @@ def add_node_target(
321331
is_qat,
322332
granularity,
323333
act_observer,
334+
act_symmetric,
324335
extra_kwargs or {},
325336
note,
326337
priority,
@@ -336,6 +347,7 @@ def add_regex(
336347
is_qat,
337348
act_observer: UniformQuantizationObserverBase,
338349
granularity: QuantGranularity,
350+
act_symmetric: bool = False,
339351
note: str = "",
340352
priority: int = 1,
341353
extra_kwargs: Optional[dict] = None,
@@ -359,6 +371,7 @@ def add_regex(
359371
is_qat,
360372
granularity,
361373
act_observer,
374+
act_symmetric,
362375
extra_kwargs or {},
363376
note,
364377
priority,

backends/qualcomm/quantizer/quantizer.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ class ModuleQConfig:
160160
is_conv_per_channel: bool = False
161161
is_linear_per_channel: bool = False
162162
act_observer: Optional[UniformQuantizationObserverBase] = None
163+
act_symmetric: bool = False
163164
eps: Optional[float] = None
164165

165166
def __post_init__(self):
@@ -173,9 +174,13 @@ def __post_init__(self):
173174
per_block_quant_config_func,
174175
) = QUANT_CONFIG_DICT[(self.quant_dtype, self.is_qat)]
175176
self.quant_config = (
176-
quant_config_func(act_observer=self.act_observer, eps=self.eps)
177+
quant_config_func(
178+
act_symmetric=self.act_symmetric,
179+
act_observer=self.act_observer,
180+
eps=self.eps,
181+
)
177182
if self.act_observer
178-
else quant_config_func(eps=self.eps)
183+
else quant_config_func(act_symmetric=self.act_symmetric, eps=self.eps)
179184
)
180185

181186
# Assume per_channel_quant/per_block_quant only happen on axis_0 or axis_1, increase the range if there's a need
@@ -186,12 +191,15 @@ def __post_init__(self):
186191
self.per_channel_quant_config_list.append(
187192
(
188193
per_channel_quant_config_func(
194+
act_symmetric=self.act_symmetric,
189195
act_observer=self.act_observer,
190196
ch_axis=i,
191197
eps=self.eps,
192198
)
193199
if self.act_observer
194-
else per_channel_quant_config_func(ch_axis=i, eps=self.eps)
200+
else per_channel_quant_config_func(
201+
act_symmetric=self.act_symmetric, ch_axis=i, eps=self.eps
202+
)
195203
)
196204
)
197205

@@ -229,10 +237,14 @@ def __post_init__(self):
229237
self.per_block_quant_config_list.append(
230238
(
231239
per_block_quant_config_func(
232-
act_observer=self.act_observer, ch_axis=i
240+
act_symmetric=self.act_symmetric,
241+
act_observer=self.act_observer,
242+
ch_axis=i,
233243
)
234244
if self.act_observer
235-
else per_block_quant_config_func(ch_axis=i)
245+
else per_block_quant_config_func(
246+
act_symmetric=self.act_symmetric, ch_axis=i
247+
)
236248
)
237249
)
238250

@@ -412,6 +424,7 @@ def set_default_quant_config(
412424
is_conv_per_channel=False,
413425
is_linear_per_channel=False,
414426
act_observer=None,
427+
act_symmetric=False,
415428
eps=None,
416429
) -> None:
417430
"""
@@ -432,6 +445,7 @@ def set_default_quant_config(
432445
is_conv_per_channel=is_conv_per_channel,
433446
is_linear_per_channel=is_linear_per_channel,
434447
act_observer=act_observer,
448+
act_symmetric=act_symmetric,
435449
eps=eps,
436450
)
437451

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,6 +1632,14 @@ def forward(self, x):
16321632
return attn_output
16331633

16341634

1635+
class Narrow(torch.nn.Module):
1636+
def __init__(self):
1637+
super().__init__()
1638+
1639+
def forward(self, x):
1640+
return (x.narrow(1, 4, 32),)
1641+
1642+
16351643
class Neg(torch.nn.Module):
16361644
def __init__(self):
16371645
super().__init__()

backends/qualcomm/tests/test_passes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def test_mha_to_sha(self):
7070
# Initailize model config
7171
args = ModelArgs()
7272
args.max_seq_len = 128
73+
args.max_context_len = 128
7374
args.ar_len = 32
7475
args.use_kv_cache = True
7576
args.dim = 32

0 commit comments

Comments
 (0)