Skip to content

Commit 7524370

Browse files
committed
Fix verbosity gating across wids modules
Introduce WIDS_VERBOSE flag driven by env var Gate ShardListDataset summary prints behind WIDS_VERBOSE Replace print-based warnings with warnings.warn Respect WIDS_VERBOSE in TarFileReader verbose default
1 parent 5c717c0 commit 7524370

2 files changed

Lines changed: 14 additions & 5 deletions

File tree

src/wids/wids.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from urllib.parse import quote, urlparse
1515

1616
import numpy as np
17+
# Verbosity flag for wids; set WIDS_VERBOSE=1 to enable verbose output
18+
WIDS_VERBOSE = bool(int(os.environ.get("WIDS_VERBOSE", 0)))
1719
import torch.distributed as dist
1820

1921
from .wids_decode import default_decoder
@@ -100,7 +102,8 @@ def group_by_key(names):
100102
for i, fname in enumerate(names):
101103
# Ignore files that are not in a subdirectory.
102104
if "." not in fname:
103-
print(f"Warning: Ignoring file {fname} (no '.')")
105+
# Warn about files without extensions; can be silenced via warnings filter
106+
warnings.warn(f"Ignoring file {fname} (no '.')")
104107
continue
105108
key, ext = splitname(fname)
106109
if key != last_key:
@@ -435,7 +438,8 @@ def __init__(
435438
self.cache_dir = os.environ.get("WIDS_CACHE", "/tmp/_wids_cache")
436439
self.localname = DefaultLocalname(self.cache_dir)
437440

438-
if True or int(os.environ.get("WIDS_VERBOSE", 0)):
441+
# Only print dataset summary if verbosity enabled
442+
if WIDS_VERBOSE:
439443
nbytes = sum(shard.get("filesize", 0) for shard in self.shards)
440444
nsamples = sum(shard["nsamples"] for shard in self.shards)
441445
print(
@@ -480,7 +484,8 @@ def check_cache_misses(self):
480484
if accesses > 100 and misses / accesses > 0.3:
481485
# output a warning only once
482486
self.check_cache_misses = lambda: None
483-
print("Warning: ShardListDataset has a cache miss rate of {:.1%}%".format(misses * 100.0 / accesses))
487+
# Warn about high cache miss rate; can be silenced via warnings filter
488+
warnings.warn(f"ShardListDataset has a cache miss rate of {misses/accesses:.1%}")
484489

485490
def get_shard(self, index):
486491
"""Get the shard and index within the shard corresponding to the given index."""

src/wids/wids_tar.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@ def find_index_file(file):
1717

1818

1919
class TarFileReader:
20-
def __init__(self, file, index_file=find_index_file, verbose=True):
21-
self.verbose = verbose
20+
def __init__(self, file, index_file=find_index_file, verbose=None):
21+
# Determine verbosity: use parameter if provided, else env var WIDS_VERBOSE
22+
if verbose is None:
23+
self.verbose = bool(int(os.environ.get("WIDS_VERBOSE", 0)))
24+
else:
25+
self.verbose = bool(verbose)
2226
if callable(index_file):
2327
index_file = index_file(file)
2428
self.index_file = index_file

0 commit comments

Comments
 (0)