Skip to content

Commit fbfa43c

Browse files
[bugfix] [4/n] Validate SamplingParam.update keys against dataclass fields (#1234)
hasattr-based validation let callers overwrite methods (e.g. update itself) or other non-field attributes. Switch update() and _from_preset() to check dataclasses.fields instead, and use logger.error with a clearer "has no field" message on unknown keys.
1 parent 7da18b3 commit fbfa43c

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

fastvideo/api/sampling_param.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
import copy
3-
from dataclasses import dataclass, field
3+
from dataclasses import dataclass, field, fields
44
from typing import Any
55

66
from fastvideo.logger import init_logger
@@ -117,11 +117,12 @@ def check_sampling_param(self):
117117
raise ValueError("prompt_path must be a txt file")
118118

119119
def update(self, source_dict: dict[str, Any]) -> None:
120+
valid_fields = {f.name for f in fields(self)}
120121
for key, value in source_dict.items():
121-
if hasattr(self, key):
122+
if key in valid_fields:
122123
setattr(self, key, value)
123124
else:
124-
logger.warning("%s has no attribute %s", type(self).__name__, key)
125+
logger.error("%s has no field %s", type(self).__name__, key)
125126

126127
self.__post_init__()
127128

@@ -162,8 +163,9 @@ def _from_preset(
162163

163164
preset = get_preset(preset_name, model_family)
164165
sp = cls()
166+
valid_fields = {f.name for f in fields(cls)}
165167
for key, value in preset.defaults.items():
166-
if hasattr(sp, key):
168+
if key in valid_fields:
167169
setattr(sp, key, copy.deepcopy(value))
168170
sp.__post_init__()
169171
return sp

0 commit comments

Comments
 (0)