Skip to content

Commit 13b501c

Browse files
authored
Merge branch 'main' into yash-tfv5
2 parents 8322d9b + 78573e9 commit 13b501c

5 files changed

Lines changed: 112 additions & 1 deletion

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
- [Advanced Data Processing](./docs/advanced-data-preprocessing.md#data-config)
99
- [Guidelines on supported data formats](./docs/advanced-data-preprocessing.md#use-cases-supported-via-command-line-argument-training_data_path)
1010
- [Offline data processing](#offline-data-preprocessing)
11-
- [Online data mixing](./docs/online-data-mixing.md)
11+
- [Online data mixing](./docs/advanced-data-preprocessing.md#online-data-mixing-section)
1212
- [Additional Frameworks](#additional-frameworks)
1313
- [Inference](#inference)
1414
- [Validation](#validation)

build/Dockerfile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ ARG ENABLE_MLFLOW=false
2525
ARG ENABLE_FMS_ACCELERATION=true
2626
ARG ENABLE_SCANNER=false
2727
ARG ENABLE_CLEARML=false
28+
ARG ENABLE_RECOMMENDER=true
2829

2930
## Base Layer ##################################################################
3031
FROM registry.access.redhat.com/ubi9/ubi:${BASE_UBI_IMAGE_TAG} AS base
@@ -188,6 +189,9 @@ RUN if [[ "${ENABLE_SCANNER}" == "true" ]]; then \
188189
RUN if [[ "${ENABLE_CLEARML}" == "true" ]]; then \
189190
python -m pip install --user "$(head bdist_name)[clearml]"; \
190191
fi
192+
RUN if [[ "${ENABLE_RECOMMENDER}" == "true" ]]; then \
193+
python -m pip install --user "$(head bdist_name)[tuning-config-recommender]"; \
194+
fi
191195

192196
# Clean up the wheel module. It's only needed by flash-attn install
193197
RUN python -m pip uninstall wheel build -y && \

build/nvcr.Dockerfile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ ARG ENABLE_MLFLOW=false
3434
ARG ENABLE_SCANNER=false
3535
ARG ENABLE_CLEARML=true
3636
ARG ENABLE_TRITON_KERNELS=true
37+
ARG ENABLE_RECOMMENDER=true
3738

3839
# Ensures to always build mamba_ssm from source
3940
ENV PIP_NO_BINARY=mamba-ssm,mamba_ssm
@@ -76,6 +77,9 @@ RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \
7677
RUN if [[ "${ENABLE_SCANNER}" == "true" ]]; then \
7778
pip install --no-cache-dir ${SOURCE_DIR}[scanner-dev]; \
7879
fi
80+
RUN if [[ "${ENABLE_RECOMMENDER}" == "true" ]]; then \
81+
pip install --no-cache-dir ${SOURCE_DIR}[tuning-config-recommender]; \
82+
fi
7983

8084
# cleanup
8185
RUN rm -rf /root/.cache /tmp/* /opt/pytorch

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ fms-accel-all = [
6060
"fms-acceleration-moe",
6161
"fms-acceleration-odm"
6262
]
63+
tuning-config-recommender = ["tuning-config-recommender>=0.1.5"]
6364

6465
[tool.setuptools.packages.find]
6566
exclude = ["tests", "tests.*"]
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#!/usr/bin/env python3
2+
import argparse
3+
import os
4+
from pathlib import Path
5+
6+
from huggingface_hub import snapshot_download
7+
from transformers import AutoConfig, AutoTokenizer
8+
9+
from fms_acceleration_moe.utils import recover_safetensors_from_dcp
10+
11+
12+
# Redirect the HuggingFace cache into the workspace volume so downloaded
# snapshots persist across container runs. setdefault() means an HF_HOME
# already present in the environment wins over this default.
HF_CACHE = "/workspace/.hf"
os.environ.setdefault("HF_HOME", HF_CACHE)
14+
15+
16+
def has_weights(p: Path) -> bool:
    """Return True if directory *p* contains safetensors model weights.

    Recognizes the three HF checkpoint layouts: a single-file checkpoint
    (``model.safetensors``), a sharded checkpoint with an index
    (``model.safetensors.index.json``), or shard files matching
    ``model-*.safetensors``.
    """
    single_file_markers = ("model.safetensors", "model.safetensors.index.json")
    if any((p / name).exists() for name in single_file_markers):
        return True
    # Fall back to sharded checkpoints without relying on the index file.
    return any(p.glob("model-*.safetensors"))
22+
23+
24+
def get_base_model(model_id_or_path: str, allow_download: bool) -> Path:
    """Resolve a base model to a local directory that holds safetensors weights.

    If *model_id_or_path* is an existing local path it is validated and
    returned as an absolute path. Otherwise the model is treated as a hub
    repo id and downloaded — but only when *allow_download* is True.

    Raises:
        RuntimeError: when a local path has no weights, when the model is
            not local and downloads are disabled, or when a downloaded
            snapshot is missing weight files.
    """
    candidate = Path(model_id_or_path)

    if candidate.exists():
        if not has_weights(candidate):
            raise RuntimeError(f"No base weights found in {candidate}")
        return candidate.resolve()

    if not allow_download:
        raise RuntimeError("Base model not found locally and downloads disabled")

    # Pull only what is needed to rebuild an HF checkpoint (weights, config,
    # tokenizer assets) rather than the entire repository.
    wanted_patterns = [
        "config.json",
        "model*.safetensors",
        "tokenizer*",
        "special_tokens_map.json",
        "generation_config.json",
    ]
    snapshot_dir = Path(
        snapshot_download(repo_id=model_id_or_path, allow_patterns=wanted_patterns)
    ).resolve()

    if not has_weights(snapshot_dir):
        raise RuntimeError(f"Downloaded base model but weights missing in {snapshot_dir}")

    return snapshot_dir
51+
52+
53+
def main():
    """Convert a DCP checkpoint into an HF-compatible checkpoint directory.

    Pipeline: resolve the base model to a local snapshot, recover
    safetensors from the DCP checkpoint into the output directory, then
    write the tokenizer (optionally with a chat template and extra special
    tokens) and a config whose vocab_size matches the tokenizer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dcp_checkpoint_dir", required=True, type=Path)
    parser.add_argument("--pretrained_model_name_or_path", required=True)
    parser.add_argument("--output_dir", required=True, type=Path)
    parser.add_argument("--allow_model_download", action="store_true")
    parser.add_argument("--additional_special_tokens", nargs="*", default=[])
    parser.add_argument("--chat_template", type=str, default=None)
    args = parser.parse_args()

    out_dir = args.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    # base model (local snapshot)
    base_dir = get_base_model(
        args.pretrained_model_name_or_path,
        args.allow_model_download,
    )

    # dcp to hf compatible
    recover_safetensors_from_dcp(str(args.dcp_checkpoint_dir), str(base_dir), str(out_dir))

    # tokenizer chat_template plus additional tokens
    tok = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path)
    if args.chat_template is not None:
        tok.chat_template = args.chat_template
    if args.additional_special_tokens:
        tok.add_special_tokens({"additional_special_tokens": args.additional_special_tokens})
    tok.save_pretrained(out_dir)

    # Keep config vocab_size in sync with the (possibly extended) tokenizer.
    # NOTE(review): this only updates the config — presumably downstream
    # loading resizes embeddings; confirm if special tokens were added.
    cfg = AutoConfig.from_pretrained(base_dir)
    cfg.vocab_size = len(tok)
    cfg.save_pretrained(out_dir)

    print(f"[OK] HF checkpoint written to {out_dir}")
    print(f"[OK] vocab_size = {len(tok)}")
98+
99+
100+
# Script entry point: run the conversion only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
102+

0 commit comments

Comments
 (0)