@@ -847,7 +847,10 @@ def _find_nvfp4_layers(state_dict: dict[str, torch.Tensor]) -> set[str]:
         s_key = f"{layer}.weight_scale"
         if s_key not in state_dict or w_key not in state_dict:
             continue
-        if state_dict[w_key].dtype == torch.uint8 and state_dict[s_key].dtype == torch.float8_e4m3fn:
+        if (
+            state_dict[w_key].dtype == torch.uint8
+            and state_dict[s_key].dtype == torch.float8_e4m3fn
+        ):
             layers.add(layer)
     return layers
 
@@ -882,7 +885,9 @@ def _roundup(a: int, b: int) -> int:
     rows, cols_w = weight.shape
     pad_r = _roundup(rows, 16) - rows
     pad_c_w = (_roundup(cols_w, 16) - cols_w) if padding_strategy == "row_col" else 0
-    pad_c_s = (_roundup(scale.shape[1], 16) - scale.shape[1]) if padding_strategy == "row_col" else 0
+    pad_c_s = (
+        (_roundup(scale.shape[1], 16) - scale.shape[1]) if padding_strategy == "row_col" else 0
+    )
 
     if pad_r > 0 or pad_c_w > 0:
         state_dict[w_key] = torch.nn.functional.pad(weight, (0, pad_c_w, 0, pad_r))
@@ -922,7 +927,9 @@ def _to_blocked(input_matrix: torch.Tensor) -> torch.Tensor:
     padded = input_matrix
     if (rows, cols) != (padded_rows, padded_cols):
         padded = torch.zeros(
-            (padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype,
+            (padded_rows, padded_cols),
+            device=input_matrix.device,
+            dtype=input_matrix.dtype,
         )
         padded[:rows, :cols] = input_matrix
     blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
@@ -935,7 +942,6 @@ def _to_blocked(input_matrix: torch.Tensor) -> torch.Tensor:
         s_key = f"{layer}.weight_scale"
         state_dict[s_key] = _to_blocked(state_dict[s_key].to(torch.float8_e4m3fn))
 
-
     return state_dict
 
 
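For reference, the _to_blocked call in the last hunk lays the per-block NVFP4 scales out as 128x4 tiles before they are written back into the state dict. The snippet below is a minimal standalone sketch of that tiling step, not the PR's code: the helper name tile_scales and the example shapes are hypothetical, and only the first view/permute stage of the swizzle is shown. It pads a scale matrix up to multiples of (128, 4) and regroups it one 128x4 tile per block, which is the intermediate layout the diff's padded.view(...).permute(0, 2, 1, 3) line produces.

import torch

def _roundup(a: int, b: int) -> int:
    # Smallest multiple of b that is >= a (mirrors the helper the diff relies on).
    return ((a + b - 1) // b) * b

def tile_scales(scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical illustration only: pad the scale matrix to multiples of
    # (128, 4), then regroup it so each leading index pair selects one
    # 128x4 tile of scales.
    rows, cols = scale.shape
    padded_rows, padded_cols = _roundup(rows, 128), _roundup(cols, 4)
    padded = torch.zeros(padded_rows, padded_cols, dtype=scale.dtype, device=scale.device)
    padded[:rows, :cols] = scale
    return (
        padded.view(padded_rows // 128, 128, padded_cols // 4, 4)
        .permute(0, 2, 1, 3)
        .contiguous()
    )

scales = torch.randn(300, 6)    # made-up (rows, cols) scale matrix
tiles = tile_scales(scales)
print(tiles.shape)              # torch.Size([3, 2, 128, 4])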