Commit eeb2950

Author: Alex J Lennon

Training: --yes flag, MFA --copy-only, skip successful steps

- Add --yes/-y to train.py to auto-confirm download/prepare (no EOF in non-interactive runs)
- Add --copy-only to run_mfa_alignment_prepared.sh to copy alignment from cache only (fixes Step 5 with 100k+ files via find -print0)
- Dataset manager: run MFA with --copy-only when alignment cache exists; require JSONs for 'prepared'; skip corpus creation when WAV+LAB already exist
- QUICKSTART: AMD GPU (ROCm) section, note about uv run reverting to CUDA wheel
- SSH multiplexing rule for ai-tools LXC

Made-with: Cursor
1 parent f43a35a commit eeb2950
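The `--yes/-y` flag mentioned in the commit message might be wired up roughly like this. This is a sketch, not the actual train.py source: the flag name matches the commit, and the `interactive` plumbing matches the `data_pipeline.py` diff, but `build_parser` and the exact argparse layout are assumptions.

```python
# Hypothetical sketch of how train.py could wire --yes/-y into the
# dataset-preparation call; names other than --yes and `interactive`
# are illustrative.
import argparse

def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Train viseme model")
    parser.add_argument("--config", required=True, help="Path to recipe TOML")
    parser.add_argument(
        "-y", "--yes", action="store_true",
        help="Auto-confirm dataset download/prepare (non-interactive runs)",
    )
    return parser

args = build_parser().parse_args(["--config", "recipe.toml", "--yes"])
# Downstream, the flag becomes interactive=False for the dataset manager,
# so prompts are auto-confirmed instead of hitting EOF in CI/cron runs.
interactive = not args.yes
print(interactive)
```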

7 files changed

Lines changed: 274 additions & 39 deletions


QUICKSTART.md

Lines changed: 69 additions & 3 deletions
@@ -83,16 +83,82 @@ To train for **UK English** (British phoneme set and viseme mapping):
 
 The UK recipe uses `training/configs/viseme_map_en_uk_mfa.json`, which maps the UK MFA phone set (IPA-style symbols) to the same 15 visemes. When prompted to download/prepare data, answer **`y`**; alignment will run with the UK model.
 
+## 4c. Full training (production ONNX)
+
+The quick recipes (4 and 4b) use **dev-clean** only and produce a small ONNX suitable for testing. For a **production-quality** model you need to train on the **full** LibriSpeech training sets and then export to ONNX.
+
+**Data:** Full training uses LibriSpeech **train-clean-100** (~6GB), **train-clean-360** (~23GB), and **train-other-500** (~30GB). The first time you run, the script will prompt to download and prepare these; preparation (WAV + MFA alignment) takes a long time per split. **GPU optional:** the recipes default to `device = "cpu"` so training runs without a GPU; for much faster training set `[hardware] device = "cuda"` in the recipe (or `mps` on Apple Silicon).
+
+**US English (full):**
+
+```bash
+uv run python training/train.py --config training/recipes/tcn_config.toml
+```
+
+When prompted to download and prepare missing datasets, answer **`y`**. Each split (train-clean-100, train-clean-360, train-other-500) will be downloaded, converted, and aligned with MFA (US) in turn. Training runs for up to 100 epochs with early stopping.
+
+**UK English (full):**
+
+1. Install UK MFA models (see 4b).
+2. Set MFA env vars and run the full UK recipe:
+
+```bash
+export MFA_ACOUSTIC_MODEL=english_mfa MFA_DICTIONARY_MODEL=english_uk_mfa
+uv run python training/train.py --config training/recipes/tcn_full_uk.toml
+```
+
+Answer **`y`** when asked to download and prepare datasets. Alignment will use the UK dictionary.
+
+**Export to ONNX:** After training, export the best checkpoint so the realtime harness and C# app can use it:
+
+```bash
+uv run python training/tools/export_onnx.py --list
+uv run python training/tools/export_onnx.py --run <run_name> --checkpoint best
+```
+
+`--list` shows available runs under `training/runs/`. Use the run name (e.g. `tcn_full_uk_2026-02-21_12-00-00`) with `--run`. The export writes to `export/<run_name>/` (model.onnx and config.json). The realtime script and C# app pick the newest `export/*/model.onnx` by default.
+
+**Smaller full run:** To try full training with less data, edit the recipe and set e.g. `splits = ["train-clean-100"]` (100h only). Use `training/recipes/tcn_config.toml` (US) or `training/recipes/tcn_full_uk.toml` (UK).
+
 ## 5. Optional: use GPU
 
-Edit `training/recipes/tcn_quick_laptop.toml` and set:
+Edit the recipe (e.g. `training/recipes/tcn_quick_laptop.toml` or `tcn_config.toml`) and set:
 
 ```toml
 [hardware]
-device = "cuda"   # or "mps" on Apple Silicon
+device = "cuda"   # NVIDIA GPU, or AMD GPU with ROCm (same API)
+# device = "mps"  # Apple Silicon
 ```
 
-If CUDA/MPS isn’t available, the trainer falls back to CPU and logs a warning.
+If CUDA/ROCm/MPS isn’t available, the trainer falls back to CPU and logs a warning.
+
+### 5b. AMD GPU (ROCm)
+
+The default `uv sync` installs PyTorch built for **NVIDIA CUDA**. On a machine with an **AMD GPU** (e.g. Radeon RX 7700/7800, Navi 32), you need PyTorch built for **ROCm** so that `torch.cuda.is_available()` is True (ROCm uses the same `torch.cuda` API).
+
+**1. Ensure the GPU is visible**
+
+- Kernel driver: `/dev/kfd` and `/dev/dri/renderD*` should exist (amdgpu driver).
+- Your user must be in the `render` (and usually `video`) group so the process can open those devices:
+  `groups` should list `render`; if not, add with `sudo usermod -aG render,video $USER` and log in again.
+
+**2. Install PyTorch with ROCm**
+
+From the project root, override the default torch/torchaudio with the ROCm wheels. Use the index that matches your ROCm version (see [PyTorch get-started](https://pytorch.org/get-started/locally/) and choose Linux → Pip → ROCm). Example for ROCm 6.3:
+
+```bash
+uv pip install torch torchaudio --index-url https://download.pytorch.org/whl/rocm6.3
+```
+
+If your distro uses a different ROCm version, use the matching index (e.g. `rocm5.6`, `rocm6.2`). Python 3.13 may not have ROCm wheels on all indices; if so, try the [AMD ROCm docs](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/) or PyTorch “Previous versions” for a compatible wheel.
+
+**3. Use the GPU in training**
+
+In the recipe set `device = "cuda"` (same as for NVIDIA). Then run training as usual; the trainer will use the AMD GPU via ROCm.
+
+**Verify:** `uv run python -c "import torch; print(torch.cuda.is_available(), torch.cuda.get_device_name(0) if torch.cuda.is_available() else '')"` should print `True` and the GPU name.
+
+**Note:** If you use `uv` and install ROCm via `uv pip install ... --index-url ...rocm6.3`, then `uv run` will re-sync from the lock file and can revert to the default CUDA wheel. To keep using the GPU, run training with the venv Python directly, e.g. `.venv/bin/python training/train.py --config ...`, or a wrapper script that calls `.venv/bin/python`.
 
 ## 6. Where outputs go
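The CPU-fallback behaviour described in this section (request `cuda`/`mps`, warn and fall back to CPU when unavailable) can be sketched as pure logic. This is an assumed illustration, not the trainer's actual code; in the real trainer the availability flags would come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`.

```python
# Minimal sketch of the device-fallback rule, with availability passed in
# so it can be shown without importing torch. With ROCm builds of PyTorch,
# torch.cuda.is_available() reports the AMD GPU, so "cuda" covers ROCm too.
import logging

logger = logging.getLogger("trainer")

def resolve_device(requested: str, cuda_ok: bool, mps_ok: bool) -> str:
    """Return the device to use, falling back to CPU if unavailable."""
    if requested == "cuda" and cuda_ok:
        return "cuda"
    if requested == "mps" and mps_ok:
        return "mps"
    if requested != "cpu":
        logger.warning("%s not available, falling back to CPU", requested)
    return "cpu"

print(resolve_device("cuda", cuda_ok=False, mps_ok=False))  # cpu
```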

README.md

Lines changed: 4 additions & 1 deletion
@@ -27,7 +27,10 @@ mfa model download g2p english_us_arpa
 
 Dataset Download is now integrated in the training script.
 
-```python training/train.py --config training/recipes/tcn_config.toml```
+**Quick (laptop) training:**
+`uv run python training/train.py --config training/recipes/tcn_quick_laptop.toml`
+
+**Full training (production ONNX):** See [QUICKSTART.md](QUICKSTART.md) section 4c. Use `training/recipes/tcn_config.toml` (US) or `training/recipes/tcn_full_uk.toml` (UK), then export with `training/tools/export_onnx.py --run <run_name> --checkpoint best`.
 
 
 This project uses the [LibriSpeech ASR corpus](https://openslr.org/12/) (CC BY 4.0 license).

run_mfa_alignment_prepared.sh

Lines changed: 42 additions & 6 deletions
@@ -40,12 +40,21 @@ command_exists() {
 }
 
 # Check arguments
-# Usage: DATASET [MODEL] or DATASET ACOUSTIC DICTIONARY (for UK: english_mfa english_uk_mfa)
+# Usage: DATASET [MODEL] or DATASET ACOUSTIC DICTIONARY or DATASET [MODEL] --copy-only
+COPY_ONLY=0
+while [[ $# -gt 0 ]]; do
+    case "$1" in
+        --copy-only) COPY_ONLY=1; shift ;;
+        *) break ;;
+    esac
+done
+
 if [ $# -lt 1 ] || [ $# -gt 3 ]; then
-    print_error "Usage: $0 DATASET_NAME [MFA_MODEL]"
-    print_error "   or: $0 DATASET_NAME ACOUSTIC_MODEL DICTIONARY_MODEL"
+    print_error "Usage: $0 DATASET_NAME [MFA_MODEL] [--copy-only]"
+    print_error "   or: $0 DATASET_NAME ACOUSTIC_MODEL DICTIONARY_MODEL [--copy-only]"
     print_error "Example: $0 test-clean"
     print_error "Example (UK): $0 dev-clean english_mfa english_uk_mfa"
+    print_error "Example (copy only, skip MFA when alignment cache exists): $0 train-clean-360 --copy-only"
     exit 1
 fi
 
@@ -113,6 +122,33 @@ fi
 
 print_status "Found ${WAV_COUNT} WAV files and ${LAB_COUNT} LAB files"
 
+# --- Copy-only mode: only copy existing alignment from cache to prepared (skip MFA) ---
+if [ "${COPY_ONLY}" -eq 1 ]; then
+    if [ ! -d "${TEMP_OUT_ALIGN}" ]; then
+        print_error "Copy-only mode: alignment output not found at ${TEMP_OUT_ALIGN}"
+        print_error "Run without --copy-only to perform full MFA alignment first."
+        exit 1
+    fi
+    JSON_COUNT=$(find "${TEMP_OUT_ALIGN}" -name "*.json" -not -name "alignment_analysis*" | wc -l)
+    if [ "${JSON_COUNT}" -eq 0 ]; then
+        print_error "Copy-only mode: no JSON alignment files in ${TEMP_OUT_ALIGN}"
+        exit 1
+    fi
+    print_status "Copy-only: copying ${JSON_COUNT} alignment files to prepared dataset..."
+    ALIGNED_COUNT=0
+    while IFS= read -r -d '' json_file; do
+        base_name=$(basename "$json_file" .json)
+        dest_file="${PREPARED_DIR}/${base_name}.json"
+        if [ -f "${PREPARED_DIR}/${base_name}.wav" ]; then
+            cp "$json_file" "$dest_file"
+            ALIGNED_COUNT=$((ALIGNED_COUNT + 1))
+        fi
+    done < <(find "${TEMP_OUT_ALIGN}" -name "*.json" -not -name "alignment_analysis*" -print0)
+    print_success "Copied ${ALIGNED_COUNT} alignment files to ${PREPARED_DIR}"
+    print_success "Done (copy-only). No cleanup - cache left at ${TEMP_OUT_ALIGN}"
+    exit 0
+fi
+
 # Create necessary directories
 mkdir -p "${TEMP_CORPUS}"
 mkdir -p "${MFA_DIR}"
@@ -220,8 +256,8 @@ if [ $? -eq 0 ]; then
     print_status "Step 5: Copying alignment results to prepared dataset..."
 
     ALIGNED_COUNT=0
-    # Find all JSON files in speaker subdirectories (skip alignment_analysis.csv)
-    for json_file in $(find "${TEMP_OUT_ALIGN}" -name "*.json" -not -name "alignment_analysis*"); do
+    # Use a null-delimited while-read loop (find -print0) to avoid command-line length limits with 100k+ files
+    while IFS= read -r -d '' json_file; do
        base_name=$(basename "$json_file" .json)
        dest_file="${PREPARED_DIR}/${base_name}.json"
 
@@ -231,7 +267,7 @@ if [ $? -eq 0 ]; then
        else
            print_warning "No corresponding WAV file for alignment: ${base_name}"
        fi
-    done
+    done < <(find "${TEMP_OUT_ALIGN}" -name "*.json" -not -name "alignment_analysis*" -print0)
 
    print_success "Copied ${ALIGNED_COUNT} alignment files to prepared dataset"
 
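The switch above from `for json_file in $(find ...)` to a null-delimited `while read` loop matters for two reasons: the `$(...)` expansion builds one huge word list (which word-splits on spaces and can blow up with 100k+ files), while `find -print0` streams NUL-delimited paths safely. A small self-contained demo of the robust pattern (paths here are temporary, not the script's real directories):

```shell
#!/usr/bin/env bash
# Demo: count *.json files robustly, including names with spaces,
# using the same -print0 / read -d '' pattern as the script above.
set -euo pipefail
tmp=$(mktemp -d)
touch "$tmp/plain.json" "$tmp/with space.json"

count=0
while IFS= read -r -d '' f; do
    count=$((count + 1))
done < <(find "$tmp" -name "*.json" -print0)
echo "$count"   # 2 — the space-containing name is handled as one path

# The old pattern, `for f in $(find "$tmp" -name "*.json")`, would split
# "with space.json" into two words and miscount.
rm -rf "$tmp"
```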

training/modules/data_pipeline.py

Lines changed: 12 additions & 6 deletions
@@ -320,7 +320,7 @@ class LibriSpeechDataset(Dataset):
     """
 
     def __init__(self, config: TrainingConfiguration, split: str,
-                 is_training: bool = True, data_root: Optional[str] = None):
+                 is_training: bool = True, data_root: Optional[str] = None, interactive: bool = True):
         """
         Initialize LibriSpeech dataset
 
@@ -329,6 +329,7 @@ def __init__(self, config: TrainingConfiguration, split: str,
             split: Dataset split (e.g., "train-clean-100", "dev-clean")
             is_training: Whether this is for training (affects augmentation)
             data_root: Root directory for LibriSpeech data
+            interactive: If False, auto-confirm dataset download/prepare (--yes)
         """
         self.config = config
         self.split = split
@@ -347,7 +348,7 @@ def __init__(self, config: TrainingConfiguration, split: str,
         self.dataset_manager = DatasetManager()
 
         # Ensure dataset is prepared
-        if not self.dataset_manager.prepare_datasets([split], interactive=True):
+        if not self.dataset_manager.prepare_datasets([split], interactive=interactive):
             raise RuntimeError(f"Failed to prepare dataset: {split}")
 
         # Load prepared data file list
@@ -688,14 +689,16 @@ def collate_audio_samples(batch: List[AudioSample]) -> Dict[str, torch.Tensor]:
 
 def create_data_loaders(config: TrainingConfiguration,
                         data_root: Optional[str] = None,
-                        pin_memory: Optional[bool] = None) -> Tuple[DataLoader, DataLoader, DataLoader]:
+                        pin_memory: Optional[bool] = None,
+                        interactive: bool = True) -> Tuple[DataLoader, DataLoader, DataLoader]:
     """
     Create training, validation, and test data loaders
 
     Args:
         config: Training configuration
        data_root: Root directory for data (optional)
        pin_memory: Override pin_memory setting (optional)
+       interactive: If False, auto-confirm dataset download/prepare (--yes)
 
     Returns:
        Tuple of (train_loader, val_loader, test_loader)
@@ -710,7 +713,8 @@ def create_data_loaders(config: TrainingConfiguration,
             config=config,
             split=split,
             is_training=True,
-            data_root=data_root
+            data_root=data_root,
+            interactive=interactive
         )
         train_datasets.append(dataset)
 
@@ -722,14 +726,16 @@ def create_data_loaders(config: TrainingConfiguration,
         config=config,
         split=config.data.val_split,
         is_training=False,
-        data_root=data_root
+        data_root=data_root,
+        interactive=interactive
     )
 
     test_dataset = LibriSpeechDataset(
         config=config,
         split=config.data.test_split,
         is_training=False,
-        data_root=data_root
+        data_root=data_root,
+        interactive=interactive
     )
 
     # Create data loaders
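The pattern in this diff — threading a new keyword through a factory with a default that preserves existing call sites — can be shown with a toy sketch (names here are illustrative, not the project API):

```python
# Toy illustration: add `interactive` with default True so old callers are
# unchanged, while train.py --yes can pass interactive=False explicitly.
def make_dataset(split: str, interactive: bool = True) -> dict:
    return {"split": split, "interactive": interactive}

def create_loaders(splits, interactive: bool = True):
    return [make_dataset(s, interactive=interactive) for s in splits]

print(create_loaders(["dev-clean"])[0]["interactive"])                      # True
print(create_loaders(["dev-clean"], interactive=False)[0]["interactive"])   # False
```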

training/modules/dataset_manager.py

Lines changed: 41 additions & 20 deletions
@@ -99,18 +99,21 @@ def _check_aligned_data(self, dataset: str) -> bool:
         return False
 
     def _check_prepared_data(self, dataset: str) -> bool:
-        """Check if prepared data exists (WAV + LAB in flat structure).
-
-        JSON alignment files are optional and may be added later.
-        """
+        """Check if prepared data exists and is ready for training (WAV + LAB + alignment JSONs)."""
         prepared_dataset_dir = self.prepared_dir / dataset
         if not prepared_dataset_dir.exists():
             return False
-
-        # Require matching WAV and LAB pairs; JSONs are not required
+
         wav_files = set(f.stem for f in prepared_dataset_dir.glob("*.wav"))
         lab_files = set(f.stem for f in prepared_dataset_dir.glob("*.lab"))
-        return len(wav_files) > 0 and wav_files == lab_files
+        if len(wav_files) == 0 or wav_files != lab_files:
+            return False
+
+        # Require at least one alignment JSON so we don't treat "MFA copy failed" as ready
+        json_stems = set(f.stem for f in prepared_dataset_dir.glob("*.json"))
+        if not json_stems or not json_stems.intersection(wav_files):
+            return False
+        return True
 
     def prepare_datasets(self, datasets: List[str], interactive: bool = True) -> bool:
         """
@@ -163,8 +166,12 @@ def prepare_datasets(self, datasets: List[str], interactive: bool = True) -> boo
                 logger.error("Cannot proceed without required datasets")
                 return False
             else:
-                logger.error(f"Missing datasets: {', '.join(missing_datasets)}")
-                return False
+                # Non-interactive (e.g. --yes): auto-confirm download and prepare
+                logger.info(f"Auto-confirming download/prepare for: {', '.join(missing_datasets)}")
+                if not self._download_datasets(missing_datasets):
+                    logger.error("Failed to download datasets")
+                    return False
+                needs_preparation.extend(missing_datasets)
 
         # Handle datasets that need preparation
         if needs_preparation:
@@ -224,9 +231,15 @@ def _download_datasets(self, datasets: List[str]) -> bool:
     def _prepare_single_dataset(self, dataset: str) -> bool:
         """Prepare a single dataset through the full pipeline"""
         logger.info(f"Preparing dataset: {dataset}")
-
-        # Step 1: Create prepared dataset (WAV + LAB) directly
-        if not self._check_prepared_data(dataset):
+        prepared_dataset_dir = self.prepared_dir / dataset
+
+        # Step 1: Create prepared dataset (WAV + LAB) only if missing
+        has_wav_lab = (
+            prepared_dataset_dir.exists()
+            and len(list(prepared_dataset_dir.glob("*.wav"))) > 0
+            and len(list(prepared_dataset_dir.glob("*.lab"))) > 0
+        )
+        if not has_wav_lab and not self._check_prepared_data(dataset):
             logger.info("Creating prepared dataset...")
             if not self._create_corpus(dataset):
                 return False
@@ -279,22 +292,30 @@ def _create_corpus(self, dataset: str) -> bool:
             return False
 
     def _run_alignment(self, dataset: str) -> bool:
-        """Run MFA alignment using the prepared dataset MFA script"""
+        """Run MFA alignment using the prepared dataset MFA script.
+        If alignment output already exists in cache, runs with --copy-only to skip MFA.
+        """
         try:
-            # Path to the MFA alignment script
             mfa_script = self.project_root / "run_mfa_alignment_prepared.sh"
-
             if not mfa_script.exists():
                 logger.error(f"MFA alignment script not found: {mfa_script}")
                 return False
-
-            logger.info(f"Running MFA alignment for {dataset}...")
-            cmd = [str(mfa_script), dataset]
-
+
+            cache_align_dir = self.cache_dir / f"out_align_{dataset}"
+            copy_only = cache_align_dir.is_dir() and any(
+                cache_align_dir.rglob("*.json")
+            )
+            if copy_only:
+                logger.info(f"Alignment cache exists for {dataset}, running copy-only...")
+                cmd = [str(mfa_script), dataset, "--copy-only"]
+            else:
+                logger.info(f"Running MFA alignment for {dataset}...")
+                cmd = [str(mfa_script), dataset]
+
             result = subprocess.run(cmd, check=True, capture_output=False)
             logger.info(f"MFA alignment completed for {dataset}")
             return True
-
+
         except subprocess.CalledProcessError as e:
             logger.error(f"MFA alignment failed for {dataset}: {e}")
             return False
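The readiness rule this diff adds to `_check_prepared_data` — matching WAV/LAB stems plus at least one alignment JSON overlapping the WAVs — can be exercised standalone. The function name below is illustrative, not the project API:

```python
# Standalone sketch of the "prepared and ready" check: WAV and LAB stems
# must match, and at least one alignment JSON must overlap the WAVs, so a
# failed MFA copy isn't mistaken for a ready dataset.
from pathlib import Path
import tempfile

def prepared_ready(prepared_dir: Path) -> bool:
    if not prepared_dir.exists():
        return False
    wavs = {f.stem for f in prepared_dir.glob("*.wav")}
    labs = {f.stem for f in prepared_dir.glob("*.lab")}
    if not wavs or wavs != labs:
        return False
    jsons = {f.stem for f in prepared_dir.glob("*.json")}
    return bool(jsons & wavs)

with tempfile.TemporaryDirectory() as d:
    root = Path(d)
    (root / "utt1.wav").touch()
    (root / "utt1.lab").touch()
    print(prepared_ready(root))   # False: WAV+LAB present but no alignment JSON yet
    (root / "utt1.json").touch()
    print(prepared_ready(root))   # True
```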
