Skip to content

Commit c843b9c

Browse files
committed
training!
1 parent 85f2f14 commit c843b9c

3 files changed

Lines changed: 389 additions & 4 deletions

File tree

users/dorian_koch/speech_llm/chatterbox_inference.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
"turns": Sequence(
2828
{
2929
"speaker": Value("string"),
30-
"start_time": Value("float32"), # float32 is perfect for timestamps
30+
"start_time": Value("float32"),
31+
"end_time": Value("float32"),
3132
"text": Value("string"),
3233
}
3334
),
@@ -132,6 +133,7 @@ def gen_conversation(
132133
"speaker": turn["speaker"],
133134
"wav": wav,
134135
"start": start,
136+
"end": start + wav.shape[-1],
135137
"text": turn["text"],
136138
}
137139
)
@@ -151,7 +153,7 @@ def gen_conversation(
151153
# Determine total length.
152154
end_samples = 0
153155
for u in utterances:
154-
end_samples = max(end_samples, u["start"] + u["wav"].shape[-1])
156+
end_samples = max(end_samples, u["end"])
155157

156158
rendered = {}
157159
for s, exagg in speakers:
@@ -163,7 +165,7 @@ def gen_conversation(
163165
for u in utterances:
164166
s = u["speaker"]
165167
st = u["start"]
166-
en = st + u["wav"].shape[-1]
168+
en = u["end"]
167169
rendered[s][0, st:en] += u["wav"][0]
168170

169171
return rendered, utterances
@@ -200,6 +202,7 @@ def process_dialogue(
200202
"speaker": u["speaker"],
201203
"text": u["text"],
202204
"start_time": u["start"] / model.sr,
205+
"end_time": u["end"] / model.sr,
203206
}
204207
)
205208
with open(f"{output_dir}/metadata.json", "w", encoding="utf-8") as f:
@@ -385,6 +388,7 @@ def gen():
385388
"turns": {
386389
"speaker": [turn["speaker"] for turn in metadata],
387390
"start_time": [turn["start_time"] for turn in metadata],
391+
"end_time": [turn["end_time"] for turn in metadata],
388392
"text": [turn["text"] for turn in metadata],
389393
},
390394
}

users/dorian_koch/speech_llm/moshi.py

Lines changed: 180 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
11
from i6_experiments.users.zeyer.external_models.huggingface import (
22
DownloadHuggingFaceRepoJob,
33
)
4-
from .common import HF_CACHE_DIR
4+
from .common import HF_CACHE_DIR
5+
from pathlib import Path
6+
from sisyphus import Job, Task, tk
7+
import os
8+
import subprocess
9+
import json
10+
from i6_experiments.users.dorian_koch.jobs.hf import HfMergeShards
11+
import sys
12+
import moshi_finetune # needed to get moshi_finetune path for PYTHONPATH below
513

614
# None of this is used anywhere i think
715

16+
817
def download_moshi():
918
# projects/moshi/moshi/moshi/models/loaders.py
1019
# untested code...
@@ -16,3 +25,173 @@ def download_moshi():
1625

1726
def moshi_inference_server(model):
1827
pass
28+
29+
30+
# runs moshi annotate.py
31+
class MoshiAnnotate(Job):
32+
def __init__(
33+
self,
34+
*,
35+
venv_python_path: tk.Path,
36+
in_hf: tk.Path,
37+
shard: int | None = None,
38+
num_shards: int | None = None,
39+
):
40+
self.venv_python_path = venv_python_path
41+
self.in_hf = in_hf
42+
self.shard = shard
43+
self.num_shards = num_shards
44+
self.out_annotations = self.output_path("annotations", directory=True)
45+
self.rqmt = {
46+
"gpu": 1,
47+
"cpu": 6,
48+
"mem": 8,
49+
"time": 1,
50+
}
51+
52+
def tasks(self):
53+
yield Task("run", rqmt=self.rqmt)
54+
55+
def run(self):
56+
this_file_path = Path(__file__).resolve()
57+
moshi_annotate_path = this_file_path.parent / "moshi_annotate_inference.py"
58+
59+
work_dir = os.path.join(os.getcwd(), "annotate_inference_workdir")
60+
os.makedirs(work_dir, exist_ok=True)
61+
62+
command = [
63+
self.venv_python_path.get(),
64+
str(moshi_annotate_path),
65+
"--in_hf",
66+
str(self.in_hf.get()),
67+
]
68+
# if self.out_dir is not None:
69+
# command += ["--out_dir", str(self.out_dir.get())]
70+
# else:
71+
command += ["--out_dir", str(self.out_annotations.get())]
72+
if self.shard is not None and self.num_shards is not None:
73+
command += [
74+
"--in_hf_shard",
75+
str(self.shard),
76+
"--in_hf_num_shards",
77+
str(self.num_shards),
78+
]
79+
env = os.environ.copy()
80+
env["PYTHONUNBUFFERED"] = "1"
81+
env["HF_HOME"] = HF_CACHE_DIR.get()
82+
top_level_file = sys.modules["moshi_finetune"].__file__
83+
package_base_dir = str(Path(top_level_file).parent.parent)
84+
env["PYTHONPATH"] = (
85+
f"{package_base_dir}{os.pathsep}{env['PYTHONPATH']}"
86+
if "PYTHONPATH" in env
87+
else package_base_dir
88+
)
89+
90+
print("Env:")
91+
for k, v in env.items():
92+
if k not in ["PYTHONPATH"]:
93+
continue
94+
print(f"{k}: {v}")
95+
96+
print(
97+
f"Running Moshi annotate with command: {' '.join(command)}",
98+
flush=True,
99+
)
100+
print(f"Using HF cache directory: {HF_CACHE_DIR}")
101+
subprocess.run(command, env=env, check=True)
102+
103+
104+
class MoshiFinetune(Job):
105+
def __init__(self, venv_python_path: tk.Path, train_data: tk.Path):
106+
self.train_data = train_data
107+
self.venv_python_path = venv_python_path
108+
self.out_config = self.output_path("config.yaml")
109+
self.rqmt = {
110+
"gpu": 1,
111+
"cpu": 6,
112+
"mem": 24,
113+
"time": 1,
114+
}
115+
116+
def tasks(self):
117+
yield Task("write_config", mini_task=True)
118+
yield Task("run", rqmt=self.rqmt)
119+
120+
def write_config(self):
121+
run_dir = os.path.join(os.getcwd(), "run_dir/")
122+
txt = f"""
123+
# data
124+
data:
125+
eval_data: '' # Optional Fill
126+
shuffle: true
127+
train_data: '{self.train_data.get()}' # Fill
128+
129+
# model
130+
moshi_paths:
131+
hf_repo_id: "kyutai/moshiko-pytorch-bf16"
132+
133+
full_finetuning: false # Activate lora.enable if partial finetuning
134+
lora:
135+
enable: true # Set to False if full_finetuning is True
136+
rank: 128
137+
scaling: 2.
138+
ft_embed: false # Optional, set to True if you want to finetune the embedding layer
139+
140+
first_codebook_weight_multiplier: 100.
141+
text_padding_weight: .5
142+
143+
# optim
144+
duration_sec: 100
145+
batch_size: 16
146+
max_steps: 2000
147+
gradient_checkpointing: true
148+
optim:
149+
lr: 2e-6
150+
weight_decay: 0.1
151+
pct_start: 0.05
152+
153+
# other
154+
seed: 0
155+
log_freq: 1
156+
eval_freq: 100
157+
do_eval: false
158+
do_ckpt: true
159+
ckpt_freq: 100
160+
161+
162+
save_adapters: true # Must be False if full_finetuning is True
163+
164+
run_dir: "{run_dir}" # Fill
165+
"""
166+
with open(self.out_config, "w") as f:
167+
f.write(txt)
168+
169+
def run(self):
170+
command = [
171+
self.venv_python_path.get(),
172+
"-m",
173+
"torch.distributed.run",
174+
"--nproc-per-node",
175+
str(self.rqmt["gpu"]),
176+
"-m",
177+
"moshi_finetune.train",
178+
self.out_config.get(),
179+
]
180+
181+
env = os.environ.copy()
182+
env["PYTHONUNBUFFERED"] = "1"
183+
env["HF_HOME"] = HF_CACHE_DIR.get()
184+
top_level_file = sys.modules["moshi_finetune"].__file__
185+
package_base_dir = f"{str(Path(top_level_file).parent.parent)}{os.pathsep}{str(Path(top_level_file).parent)}"
186+
187+
env["PYTHONPATH"] = (
188+
f"{package_base_dir}{os.pathsep}{env['PYTHONPATH']}"
189+
if "PYTHONPATH" in env
190+
else package_base_dir
191+
)
192+
print(
193+
f"Running Moshi training with command: {' '.join(command)}",
194+
flush=True,
195+
)
196+
print(f"Using HF cache directory: {HF_CACHE_DIR}")
197+
subprocess.run(command, env=env, check=True)

0 commit comments

Comments
 (0)