Skip to content

Commit a2b6819

Browse files
committed
chore: Merge remote-tracking branch 'origin/main' into hf_checkpoint_conversion_for_fsdp2
2 parents 3be2921 + b856127 commit a2b6819

3 files changed

Lines changed: 13 additions & 4 deletions

File tree

src/modalities/config/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,7 @@ class EvaluationResultToDiscSubscriberConfig(BaseModel):
494494

495495
class WandBEvaluationResultSubscriberConfig(BaseModel):
496496
global_rank: int
497+
entity: Optional[str] = None
497498
project: str
498499
experiment_id: str
499500
mode: WandbMode
@@ -547,7 +548,7 @@ def load_app_config_dict(
547548
"""
548549

549550
def cuda_env_resolver_fun(var_name: str) -> int | str | None:
550-
int_env_variable_names = ["LOCAL_RANK", "WORLD_SIZE", "RANK"]
551+
int_env_variable_names = ["LOCAL_RANK", "WORLD_SIZE", "RANK", "LOCAL_WORLD_SIZE"]
551552
return int(os.environ[var_name]) if var_name in int_env_variable_names else os.getenv(var_name)
552553

553554
def modalities_env_resolver_fun(var_name: str, kwargs: dict[str, Any]) -> str | Path:

src/modalities/logging_broker/subscriber_impl/results_subscriber.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,16 @@ def __init__(
6565
project: str,
6666
experiment_id: str,
6767
mode: WandbMode,
68-
logging_directory: Path,
68+
logging_directory: Path | None,
6969
config_file_path: Path,
70+
entity: str | None = None,
7071
) -> None:
7172
super().__init__()
7273

7374
with open(config_file_path, "r", encoding="utf-8") as file:
7475
config = yaml.safe_load(file)
7576
self.run = wandb.init(
77+
entity=entity,
7678
project=project,
7779
name=experiment_id,
7880
mode=mode.value.lower(),
@@ -81,7 +83,7 @@ def __init__(
8183
settings=wandb.Settings(init_timeout=120),
8284
)
8385

84-
self.run.log_artifact(config_file_path, name=f"config_{wandb.run.id}", type="config")
86+
self.run.log_artifact(config_file_path, name=f"config_{self.run.id}", type="config")
8587

8688
def consume_dict(self, message_dict: dict[str, Any]):
8789
for k, v in message_dict.items():

src/modalities/logging_broker/subscriber_impl/subscriber_factory.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def get_wandb_result_subscriber(
6868
mode: WandbMode,
6969
config_file_path: Path,
7070
directory: Optional[Path] = None,
71+
entity: Optional[str] = None,
7172
) -> WandBEvaluationResultSubscriber:
7273
if global_rank == 0 and (mode != WandbMode.DISABLED):
7374
if directory is not None:
@@ -88,7 +89,12 @@ def get_wandb_result_subscriber(
8889
absolute_dir = None
8990

9091
result_subscriber = WandBEvaluationResultSubscriber(
91-
project, experiment_id, mode, absolute_dir, config_file_path
92+
project=project,
93+
experiment_id=experiment_id,
94+
mode=mode,
95+
logging_directory=absolute_dir,
96+
config_file_path=config_file_path,
97+
entity=entity,
9298
)
9399
else:
94100
result_subscriber = ResultsSubscriberFactory.get_dummy_result_subscriber()

0 commit comments

Comments
 (0)