Skip to content

Commit 38d9522

Browse files
Address coderabbit comments
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 62070ae commit 38d9522

File tree

7 files changed

+23
-12
lines changed

7 files changed

+23
-12
lines changed

examples/puzzletron/evaluation/hf_deployable_anymodel.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -331,7 +331,6 @@ def get_triton_input(self):
331331
Tensor(name="top_p", shape=(-1,), dtype=np.single, optional=True),
332332
Tensor(name="temperature", shape=(-1,), dtype=np.single, optional=True),
333333
Tensor(name="random_seed", shape=(-1,), dtype=np.int_, optional=True),
334-
Tensor(name="max_length", shape=(-1,), dtype=np.int_, optional=True),
335334
Tensor(name="output_logits", shape=(-1,), dtype=np.bool_, optional=True),
336335
Tensor(name="output_scores", shape=(-1,), dtype=np.bool_, optional=True),
337336
)

modelopt/torch/puzzletron/mip/mip_with_multi_layer_replacements.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,9 @@ def run_mip(
5555
)
5656
print("\n\n\n")
5757

58+
if not replacements:
59+
return [], 0.0, {}
60+
5861
mip_model = Model()
5962

6063
objective_vars = []

modelopt/torch/puzzletron/replacement_library/build_replacement_library.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -88,7 +88,7 @@ def build_replacement_library(
8888
add_attention_no_ops,
8989
trust_remote_code=trust_remote_code,
9090
)
91-
block_library_df = _build_block_library_from_subblocks(subblocks_df)
91+
block_library_df = _build_block_library_from_subblocks(subblocks_df, master_puzzle_dir)
9292

9393
layer_replacements = _build_layer_replacements(
9494
block_library_df, master_puzzle_dir, teacher_checkpoint_dir, trust_remote_code
@@ -143,7 +143,9 @@ def infer_teacher_dir(
143143
return teacher_checkpoint_dir
144144

145145

146-
def _build_block_library_from_subblocks(subblocks_df: pd.DataFrame) -> pd.DataFrame:
146+
def _build_block_library_from_subblocks(
147+
subblocks_df: pd.DataFrame, output_dir: Path
148+
) -> pd.DataFrame:
147149
joint_blocks_df = subblocks_df.dropna(subset=["block_config"]).copy()
148150
constructed_blocks_df = _construct_blocks_from_subblocks(subblocks_df)
149151

@@ -164,8 +166,12 @@ def _build_block_library_from_subblocks(subblocks_df: pd.DataFrame) -> pd.DataFr
164166
dups_with_same_block_idx = dups[dups["block_idx"] == dup_block_idx]
165167
for _, row in dups_with_same_block_idx.head(10).iterrows():
166168
mprint(row.to_dict())
167-
json_dump(block_library_df.to_dict(orient="records"), "ERROR_block_library.json")
168-
json_dump(subblocks_df.to_dict(orient="records"), "ERROR_subblock_library.json")
169+
json_dump(
170+
block_library_df.to_dict(orient="records"), output_dir / "ERROR_block_library.json"
171+
)
172+
json_dump(
173+
subblocks_df.to_dict(orient="records"), output_dir / "ERROR_subblock_library.json"
174+
)
169175
raise ValueError(
170176
f"Found {len(dups)} duplicate blocks in the block library. See ERROR_block_library.json and ERROR_subblock_library.json for more details."
171177
)

modelopt/torch/puzzletron/scoring.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -82,8 +82,10 @@ def main(cfg: DictConfig) -> None:
8282
cfg = hydra.utils.instantiate(cfg)
8383
mprint(cfg)
8484
dist.setup(timeout=cfg.nccl_timeout_minutes)
85-
launch_scoring(cfg)
86-
dist.cleanup()
85+
try:
86+
launch_scoring(cfg)
87+
finally:
88+
dist.cleanup()
8789

8890

8991
if __name__ == "__main__":

modelopt/torch/puzzletron/sewing_kit/passage.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -297,6 +297,7 @@ def __enter__(self):
297297
def __exit__(self, exc_type, exc_val, exc_tb):
298298
assert self.active_context_manager is not None
299299
self.active_context_manager.__exit__(exc_type, exc_val, exc_tb)
300+
self.active_context_manager = None
300301

301302
def freeze(self):
302303
self.eval()

modelopt/torch/puzzletron/sewing_kit/utils.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -375,21 +375,21 @@ def has_fake_tensor(v: Any) -> bool:
375375

376376
def _get_device_for_distributed(
377377
group: Optional[torch.distributed.ProcessGroup] = None,
378-
) -> str:
378+
) -> torch.device:
379379
"""
380380
Determine the appropriate device for distributed communication based on the backend.
381381
NCCL backend requires CUDA tensors, while Gloo supports both CPU and CUDA.
382382
"""
383383
if not torch.distributed.is_initialized():
384-
return "cpu"
384+
return torch.device("cpu")
385385

386386
backend = torch.distributed.get_backend(group)
387387
if backend == "nccl":
388388
# NCCL requires CUDA tensors
389-
return torch.cuda.current_device()
389+
return torch.device("cuda", torch.cuda.current_device())
390390
else:
391391
# Gloo and other backends support CPU tensors
392-
return "cpu"
392+
return torch.device("cpu")
393393

394394

395395
def distributed_isend_obj(

modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -342,7 +342,7 @@ def optimized_safe_save(kwargs):
342342
# Check for any failures
343343
failed_saves = sum(1 for r in results if not r)
344344
if failed_saves > 0:
345-
mprint(f" Warning: {failed_saves} files failed to save")
345+
raise RuntimeError(f" {failed_saves} shard file(s) failed to save")
346346
else:
347347
mprint(" Using single-threaded saving...")
348348
for kwargs in safe_save_kwargs:

0 commit comments

Comments
 (0)