Skip to content

Commit abd5d1a

Browse files
authored
Merge branch 'main' into main
2 parents 89f3c29 + 441b69e commit abd5d1a

5 files changed

Lines changed: 83 additions & 11 deletions

File tree

src/diffusers/models/model_loading_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
is_torch_version,
4848
logging,
4949
)
50+
from ..utils.distributed_utils import is_torch_dist_rank_zero
5051

5152

5253
logger = logging.get_logger(__name__)
@@ -429,8 +430,12 @@ def _load_shard_files_with_threadpool(
429430
low_cpu_mem_usage=low_cpu_mem_usage,
430431
)
431432

433+
tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
434+
if not is_torch_dist_rank_zero():
435+
tqdm_kwargs["disable"] = True
436+
432437
with ThreadPoolExecutor(max_workers=num_workers) as executor:
433-
with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
438+
with logging.tqdm(**tqdm_kwargs) as pbar:
434439
futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
435440
for future in as_completed(futures):
436441
result = future.result()

src/diffusers/models/modeling_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,8 @@
5959
is_torch_version,
6060
logging,
6161
)
62-
from ..utils.hub_utils import (
63-
PushToHubMixin,
64-
load_or_create_model_card,
65-
populate_model_card,
66-
)
62+
from ..utils.distributed_utils import is_torch_dist_rank_zero
63+
from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
6764
from ..utils.torch_utils import empty_device_cache
6865
from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
6966
from .model_loading_utils import (
@@ -1672,7 +1669,10 @@ def _load_pretrained_model(
16721669
else:
16731670
shard_files = resolved_model_file
16741671
if len(resolved_model_file) > 1:
1675-
shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
1672+
shard_tqdm_kwargs = {"desc": "Loading checkpoint shards"}
1673+
if not is_torch_dist_rank_zero():
1674+
shard_tqdm_kwargs["disable"] = True
1675+
shard_files = logging.tqdm(resolved_model_file, **shard_tqdm_kwargs)
16761676

16771677
for shard_file in shard_files:
16781678
offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
logging,
6868
numpy_to_pil,
6969
)
70+
from ..utils.distributed_utils import is_torch_dist_rank_zero
7071
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
7172
from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module
7273

@@ -982,7 +983,11 @@ def load_module(name, value):
982983
# 7. Load each module in the pipeline
983984
current_device_map = None
984985
_maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
985-
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
986+
logging_tqdm_kwargs = {"desc": "Loading pipeline components..."}
987+
if not is_torch_dist_rank_zero():
988+
logging_tqdm_kwargs["disable"] = True
989+
990+
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), **logging_tqdm_kwargs):
986991
# 7.1 device_map shenanigans
987992
if final_device_map is not None:
988993
if isinstance(final_device_map, dict) and len(final_device_map) > 0:
@@ -1908,10 +1913,14 @@ def progress_bar(self, iterable=None, total=None):
19081913
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
19091914
)
19101915

1916+
progress_bar_config = dict(self._progress_bar_config)
1917+
if "disable" not in progress_bar_config:
1918+
progress_bar_config["disable"] = not is_torch_dist_rank_zero()
1919+
19111920
if iterable is not None:
1912-
return tqdm(iterable, **self._progress_bar_config)
1921+
return tqdm(iterable, **progress_bar_config)
19131922
elif total is not None:
1914-
return tqdm(total=total, **self._progress_bar_config)
1923+
return tqdm(total=total, **progress_bar_config)
19151924
else:
19161925
raise ValueError("Either `total` or `iterable` has to be defined.")
19171926

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2025 The HuggingFace Inc. team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
try:
17+
import torch
18+
except ImportError:
19+
torch = None
20+
21+
22+
def is_torch_dist_rank_zero() -> bool:
23+
if torch is None:
24+
return True
25+
26+
dist_module = getattr(torch, "distributed", None)
27+
if dist_module is None or not dist_module.is_available():
28+
return True
29+
30+
if not dist_module.is_initialized():
31+
return True
32+
33+
try:
34+
return dist_module.get_rank() == 0
35+
except (RuntimeError, ValueError):
36+
return True

src/diffusers/utils/logging.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232

3333
from tqdm import auto as tqdm_lib
3434

35+
from .distributed_utils import is_torch_dist_rank_zero
36+
3537

3638
_lock = threading.Lock()
3739
_default_handler: Optional[logging.Handler] = None
@@ -47,6 +49,23 @@
4749
_default_log_level = logging.WARNING
4850

4951
_tqdm_active = True
52+
_rank_zero_filter = None
53+
54+
55+
class _RankZeroFilter(logging.Filter):
56+
def filter(self, record):
57+
# Always allow rank-zero logs, but keep debug-level messages from all ranks for troubleshooting.
58+
return is_torch_dist_rank_zero() or record.levelno <= logging.DEBUG
59+
60+
61+
def _ensure_rank_zero_filter(logger: logging.Logger) -> None:
62+
global _rank_zero_filter
63+
64+
if _rank_zero_filter is None:
65+
_rank_zero_filter = _RankZeroFilter()
66+
67+
if not any(isinstance(f, _RankZeroFilter) for f in logger.filters):
68+
logger.addFilter(_rank_zero_filter)
5069

5170

5271
def _get_default_logging_level() -> int:
@@ -90,6 +109,7 @@ def _configure_library_root_logger() -> None:
90109
library_root_logger.addHandler(_default_handler)
91110
library_root_logger.setLevel(_get_default_logging_level())
92111
library_root_logger.propagate = False
112+
_ensure_rank_zero_filter(library_root_logger)
93113

94114

95115
def _reset_library_root_logger() -> None:
@@ -120,7 +140,9 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
120140
name = _get_library_name()
121141

122142
_configure_library_root_logger()
123-
return logging.getLogger(name)
143+
logger = logging.getLogger(name)
144+
_ensure_rank_zero_filter(logger)
145+
return logger
124146

125147

126148
def get_verbosity() -> int:

0 commit comments

Comments
 (0)