forked from NVIDIA/Model-Optimizer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathblock_config.py
More file actions
304 lines (247 loc) · 9.76 KB
/
block_config.py
File metadata and controls
304 lines (247 loc) · 9.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# mypy: ignore-errors
import dataclasses
import inspect
import warnings
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, List, Optional, Type, Union, get_args, get_origin
__all__ = [
"BaseDataclass",
"SubblockConfig",
"MoEConfig",
"MambaConfig",
"Llama4AttentionConfig",
"AttentionConfig",
"FFNConfig",
"SUBBLOCK_CLS_DICT",
"BlockConfig",
"maybe_cast_block_configs",
]
@dataclass(frozen=True, kw_only=True)
class BaseDataclass:
"""
A dataclass base class with several utilities:
1. Comparison via string representation.
2. Initialization of dataclasses fields from dicts.
3. Setting attributes even though it's frozen (but only inside __post_init__!)
"""
def __eq__(self, other: "BaseDataclass") -> bool:
return str(self) == str(other)
def __hash__(self) -> int:
return hash(str(self))
def __lt__(self, other: "BaseDataclass") -> bool:
return str(self) < str(other)
def _force_setattr(self, name: str, value: Any) -> None:
"""
Set an attribute even in frozen dataclasses.
Use only inside __post_init__!
"""
assert _is_called_from_post_init(), (
"_force_setattr should only be called from __post_init__, "
"if you need to change an attribute use dataclasses.replace "
"or create a new instance :)"
)
object.__setattr__(self, name, value)
def __post_init__(self):
"""
Init dataclass fields from dicts
"""
for field in dataclasses.fields(self):
field_dict = getattr(self, field.name)
if isinstance(field_dict, dict) and _is_dataclass_type(field.type):
dataclass_cls = _get_dataclass_type(field.type)
sub_fields = [field.name for field in dataclasses.fields(dataclass_cls)]
unsupported_fields = [
field_name for field_name in field_dict.keys() if field_name not in sub_fields
]
if len(unsupported_fields) > 0:
warnings.warn(
f"Removed unsupported fields {unsupported_fields} from {dataclass_cls}"
)
field_dict = {k: v for k, v in field_dict.items() if k not in unsupported_fields}
self._force_setattr(field.name, dataclass_cls(**field_dict))
def _is_called_from_post_init() -> bool:
frame = inspect.currentframe()
while frame:
if frame.f_code.co_name == "__post_init__":
return True
frame = frame.f_back
return False
def _is_dataclass_type(tp: Type) -> bool:
"""
Like dataclasses.is_dataclass but also works for Optional[] and Union[] of a dataclass type
"""
try:
_get_dataclass_type(tp)
return True
except:
return False
def _get_dataclass_type(tp: Type) -> dataclass:
"""
If the given type is a dataclass, the function returns it.
If it is a Union[] or Optional[], the function extracts the first dataclass type.
If no dataclass type is found, the function raises a ValueError.
"""
origin = get_origin(tp)
if origin is Union:
for type_in_union in get_args(tp):
if dataclasses.is_dataclass(type_in_union):
return type_in_union
if dataclasses.is_dataclass(tp):
return tp
raise ValueError("Not a dataclass")
@dataclass(frozen=True, kw_only=True)
class SubblockConfig(BaseDataclass):
"""Base configuration for a subblock (e.g. attention or FFN) within a transformer block."""
no_op: bool = False
replace_with_linear: bool = False
sparsify: Optional[list[str]] = None
weights_precision: Optional[str] = "bf16"
def __post_init__(self):
super().__post_init__()
assert not (self.no_op and self.replace_with_linear)
if self.no_op:
self._force_setattr("sparsify", None)
@abstractmethod
def to_blockconfig(self) -> "BlockConfig":
""" "
Convert to a block including this subblock only.
"""
...
@dataclass(frozen=True, kw_only=True)
class MoEConfig(BaseDataclass):
"""
Configuration class for Mixture of Experts parameters.
"""
num_local_experts: int = 8
num_experts_per_tok: int = 1
expert_intermediate_dim: int = 8192
shared_expert_intermediate_dim: int = 8192
# router_aux_loss_coef: float = 0.01
# router_z_loss_coef: float = 0.0 # Optional z-loss coefficient
def __post_init__(self):
# Validate the configuration
if self.num_local_experts <= 0:
raise ValueError(f"num_local_experts must be positive, got {self.num_local_experts}")
if self.num_experts_per_tok <= 0:
raise ValueError(
f"num_experts_per_tok must be positive, got {self.num_experts_per_tok}"
)
if self.num_experts_per_tok > self.num_local_experts:
raise ValueError(
f"num_experts_per_tok ({self.num_experts_per_tok}) cannot be greater than num_local_experts ({self.num_local_experts})"
)
# if self.router_aux_loss_coef < 0:
# raise ValueError(f"router_aux_loss_coef must be non-negative, got {self.router_aux_loss_coef}")
@dataclass(frozen=True, kw_only=True)
class MambaConfig(BaseDataclass):
"""Configuration for a Mamba (state-space model) subblock."""
state_dim: int
num_heads: int
head_dim: int
num_groups: int
@dataclass(frozen=True, kw_only=True)
class Llama4AttentionConfig(BaseDataclass):
"""Configuration for Llama-4-specific attention parameters."""
attention_chunk_size: Optional[int] = None
use_rope: Optional[bool] = None
use_qk_norm: Optional[bool] = None
attn_scale: Optional[float] = None
floor_scale: Optional[float] = None
attn_temperature_tuning: Optional[bool] = None
attention_dropout: Optional[float] = None
@dataclass(frozen=True, kw_only=True)
class AttentionConfig(SubblockConfig):
"""Configuration for an attention subblock within a transformer block."""
num_key_value_heads: Optional[int] = None
llama4: Optional[Llama4AttentionConfig] = None
mamba: Optional[MambaConfig] = None
def __post_init__(self):
super().__post_init__()
if self.no_op:
assert not self.is_mamba
assert not self.is_llama4
if self.no_op or self.is_mamba:
for irrelevant_att in [
"num_key_value_heads",
]:
self._force_setattr(irrelevant_att, None)
else:
assert self.num_key_value_heads is not None
def to_blockconfig(self) -> "BlockConfig":
return BlockConfig(attention=self, ffn=FFNConfig(no_op=True))
@property
def is_llama4(self) -> bool:
return self.llama4 is not None
@property
def is_mamba(self) -> bool:
return self.mamba is not None
@dataclass(frozen=True, kw_only=True)
class FFNConfig(SubblockConfig):
"""Configuration for a feed-forward network subblock within a transformer block."""
moe: Optional[MoEConfig] = None
intermediate_size: Optional[int] = None
def __post_init__(self):
super().__post_init__()
if self.no_op:
self._force_setattr("moe", None)
self._force_setattr("intermediate_size", None)
elif self.is_moe:
self._force_setattr("intermediate_size", None)
else:
assert self.intermediate_size is not None, (
"Intermediate size must be provided for an FFN block"
)
def to_blockconfig(self) -> "BlockConfig":
return BlockConfig(attention=AttentionConfig(no_op=True), ffn=self)
@property
def is_moe(self) -> bool:
return self.moe is not None
SUBBLOCK_CLS_DICT = {
"attention": AttentionConfig,
"ffn": FFNConfig,
}
@dataclass(frozen=True, kw_only=True)
class BlockConfig(BaseDataclass):
"""Configuration for a single transformer block, including its attention and FFN subblocks."""
attention: Optional[AttentionConfig] = None
ffn: Optional[FFNConfig] = None
parallel_blocks: Optional[list["BlockConfig"]] = None
def __post_init__(self):
super().__post_init__()
if (self.parallel_blocks is not None) and isinstance(self.parallel_blocks[0], dict):
initialized_block_configs = [
BlockConfig(**block_config) for block_config in self.parallel_blocks
]
self._force_setattr("parallel_blocks", initialized_block_configs)
def to_dict(self) -> dict:
"""Convert BlockConfig to a dictionary."""
return dataclasses.asdict(self)
def maybe_cast_block_configs(
block_configs: List[BlockConfig | dict] | None,
) -> List[BlockConfig] | None:
"""Cast a list of dicts to BlockConfig objects if needed.
Args:
block_configs: List of BlockConfig or dict objects, or None.
Returns:
List of BlockConfig objects, or None if input is None/empty.
"""
if not block_configs:
return block_configs
if isinstance(block_configs[0], dict):
return [BlockConfig(**conf) for conf in block_configs]
return block_configs