Skip to content

Commit cc44b86

Browse files
authored
Chore(pt): add --model-branch as alias (deepmodeling#4730)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added a new command-line option alias, `--model-branch`, for the `--head` argument in the "freeze" and "test" subcommands. Both options can now be used interchangeably. - **Tests** - Updated tests to use the new `--model-branch` argument in command-line examples, ensuring compatibility with the updated interface. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 17da9d2 commit cc44b86

3 files changed

Lines changed: 6 additions & 4 deletions

File tree

deepmd/main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,9 +330,10 @@ def main_parser() -> argparse.ArgumentParser:
330330
)
331331
parser_frz.add_argument(
332332
"--head",
333+
"--model-branch",
333334
default=None,
334335
type=str,
335-
help="(Supported backend: PyTorch) Task head to freeze if in multi-task mode.",
336+
help="(Supported backend: PyTorch) Task head (alias: model branch) to freeze if in multi-task mode.",
336337
)
337338

338339
# * test script ********************************************************************
@@ -409,9 +410,10 @@ def main_parser() -> argparse.ArgumentParser:
409410
)
410411
parser_tst.add_argument(
411412
"--head",
413+
"--model-branch",
412414
default=None,
413415
type=str,
414-
help="(Supported backend: PyTorch) Task head to test if in multi-task mode.",
416+
help="(Supported backend: PyTorch) Task head (alias: model branch) to test if in multi-task mode.",
415417
)
416418

417419
# * compress model *****************************************************************

source/tests/pd/test_dp_show.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def setUp(self):
148148
)
149149
trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links)
150150
trainer.run()
151-
run_dp("dp --pd freeze --head model_1")
151+
run_dp("dp --pd freeze --model-branch model_1")
152152

153153
def test_checkpoint(self):
154154
INPUT = "model.ckpt.pd"

source/tests/pt/test_dp_show.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def setUp(self) -> None:
140140
)
141141
trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links)
142142
trainer.run()
143-
run_dp("dp --pt freeze --head model_1")
143+
run_dp("dp --pt freeze --model-branch model_1")
144144

145145
def test_checkpoint(self) -> None:
146146
INPUT = "model.ckpt.pt"

0 commit comments

Comments
 (0)