Skip to content

Commit 44f4dc0

Browse files
authored
[Modular] guard ModularPipeline.blocks attribute (#13014)
* up * style
1 parent fd705bd commit 44f4dc0

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1598,11 +1598,11 @@ def __init__(
15981598
else:
15991599
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
16001600

1601-
self.blocks = blocks
1601+
self._blocks = blocks
16021602
self._components_manager = components_manager
16031603
self._collection = collection
1604-
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
1605-
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
1604+
self._component_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_components}
1605+
self._config_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_configs}
16061606

16071607
# update component_specs and config_specs based on modular_model_index.json
16081608
if modular_config_dict is not None:
@@ -1649,7 +1649,9 @@ def __init__(
16491649
for name, config_spec in self._config_specs.items():
16501650
default_configs[name] = config_spec.default
16511651
self.register_to_config(**default_configs)
1652-
self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
1652+
self.register_to_config(
1653+
_blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None
1654+
)
16531655

16541656
@property
16551657
def default_call_parameters(self) -> Dict[str, Any]:
@@ -1658,7 +1660,7 @@ def default_call_parameters(self) -> Dict[str, Any]:
16581660
- Dictionary mapping input names to their default values
16591661
"""
16601662
params = {}
1661-
for input_param in self.blocks.inputs:
1663+
for input_param in self._blocks.inputs:
16621664
params[input_param.name] = input_param.default
16631665
return params
16641666

@@ -1829,7 +1831,15 @@ def doc(self):
18291831
Returns:
18301832
- The docstring of the pipeline blocks
18311833
"""
1832-
return self.blocks.doc
1834+
return self._blocks.doc
1835+
1836+
@property
1837+
def blocks(self) -> ModularPipelineBlocks:
1838+
"""
1839+
Returns:
1840+
- A copy of the pipeline blocks
1841+
"""
1842+
return deepcopy(self._blocks)
18331843

18341844
def register_components(self, **kwargs):
18351845
"""
@@ -2565,7 +2575,7 @@ def _dict_to_component_spec(
25652575
)
25662576

25672577
def set_progress_bar_config(self, **kwargs):
2568-
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
2578+
for sub_block_name, sub_block in self._blocks.sub_blocks.items():
25692579
if hasattr(sub_block, "set_progress_bar_config"):
25702580
sub_block.set_progress_bar_config(**kwargs)
25712581

@@ -2619,7 +2629,7 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =
26192629

26202630
# Add inputs to state, using defaults if not provided in the kwargs or the state
26212631
# if same input already in the state, will override it if provided in the kwargs
2622-
for expected_input_param in self.blocks.inputs:
2632+
for expected_input_param in self._blocks.inputs:
26232633
name = expected_input_param.name
26242634
default = expected_input_param.default
26252635
kwargs_type = expected_input_param.kwargs_type
@@ -2638,9 +2648,9 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =
26382648
# Run the pipeline
26392649
with torch.no_grad():
26402650
try:
2641-
_, state = self.blocks(self, state)
2651+
_, state = self._blocks(self, state)
26422652
except Exception:
2643-
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
2653+
error_msg = f"Error in block: ({self._blocks.__class__.__name__}):\n"
26442654
logger.error(error_msg)
26452655
raise
26462656

0 commit comments

Comments
 (0)