82 changes: 50 additions & 32 deletions fms_mo/utils/dq_utils.py
@@ -11,54 +11,72 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utils for DQ
"""Utility functions for Direct Quantization" (DQ)."""

"""

def config_quantize_smooth_layers(qcfg: dict):
"""Update qcfg with model-dependent config parameters:
- qlayer_name_pattern: identifier of transformer layers containing linear layers
to quantize (if any, tracing is bypassed)
- scale_layers: identifier of linear layers to apply smoothquant on
- qskip_layer_name: full name of linear layers that will not be quantized
- act_scale_path: path to save/load smoothquant activation scales

def config_quantize_smooth_layers(qcfg):
"""
To set the config for each model, for example
layers to quantize
layers to skip
layers to apply smooth-scale
block_size
smooth_alpha
Selected model is determined by comparing all architecture identifiers against
`model` and `model_type` fields in qcfg.

NOTE: layer quantization skip is determined by bool `qskip_large_mag_layers`
NOTE: different versions of granite models are based on different architectures
(chronologically: bigcode -> llama -> granite)
"""

llama_architecture = [
"llama",
"Nemotron",
"granite-3b-code",
"granite-8b-code",
]
granite_BigCode_architecture = [
bigcode_architecture = [
"granite-3b-base",
"granite-13b-base",
"granite-20b-code",
"granite-20b-code",
]
if (
any(model in qcfg["model"] for model in llama_architecture)
or any(model in qcfg["model_type"] for model in llama_architecture)
and qcfg["qskip_large_mag_layers"]
granite_architecture = [
"granite-3.0-8b-base",
"granite-3.0-8b-instruct",
"granite-3.1-8b-base",
"granite-3.1-8b-instruct",
"granite-3.2-8b-instruct",
"granite-3.3-8b-base",
"granite-3.3-8b-instruct",
]

if any(model in qcfg["model"] for model in llama_architecture) or any(
model in qcfg["model_type"] for model in llama_architecture
):
qcfg["qlayer_name_pattern"] = ["model.layers."]
qcfg["scale_layers"] = ["k_proj", "v_proj", "gate_proj", "up_proj"]
large_mag_layers = {
"2-7b": [1, 30],
"2-70b": [2, 8, 79],
"3-8B": [1, 31],
"3-70B": [3, 78, 79],
"405B-Instruct": [5, 124, 125],
}
for llama_family, layers in large_mag_layers.items():
if llama_family in qcfg["model"]:
qcfg["qskip_layer_name"] += [
f"model.layers.{i}.mlp.down_proj" for i in layers
]
break

if qcfg["qskip_large_mag_layers"]:
large_mag_layers = {
"2-7b": [1, 30],
"2-70b": [2, 8, 79],
"3-8B": [1, 31],
"3-70B": [3, 78, 79],
"405B-Instruct": [5, 124, 125],
}
for llama_family, layers in large_mag_layers.items():
if llama_family in qcfg["model"]:
qcfg["qskip_layer_name"] += [
f"model.layers.{i}.mlp.down_proj" for i in layers
]
break
elif any(model in qcfg["model"] for model in granite_architecture) or any(
model in qcfg["model_type"] for model in granite_architecture
):
qcfg["qlayer_name_pattern"] = ["model.layers."]
qcfg["scale_layers"] = ["k_proj", "v_proj", "gate_proj", "up_proj"]
# NOTE: supported granite-v3 models do not need layer skip for large magnitude
elif "mixtral" in qcfg["model"]:
qcfg["qlayer_name_pattern"] = (
["model.layers"] if qcfg["nbits_bmm1"] == 32 else []
@@ -81,10 +99,10 @@ def config_quantize_smooth_layers(qcfg):
]
]
qcfg["act_scale_path"] = "./act_scales/Mixtral-8x7B-v0.1.pt"
elif any(model in qcfg["model"] for model in granite_BigCode_architecture):
elif any(model in qcfg["model"] for model in bigcode_architecture):
qcfg["qlayer_name_pattern"] = ["transformer.h"]
qcfg["scale_layers"] = ["c_attn", "c_fc"]
qcfg["qskip_layer_name"] = []
# NOTE: supported bigcode models do not need layer skip for large magnitude
if "granite-3b-base-v2" in qcfg["model"]:
qcfg["act_scale_path"] = "./act_scales/granite_3b_base_v2_500_nw.pt"
if "granite-13b-base-v2" in qcfg["model"]: