Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,20 @@ This repository contains training, inference, and evaluation code for the paper

## Installation

Installation requires Python 3.11+. To install the package and all dependencies with pip:
Installation requires Python 3.11+.

To install the package with `pip`:

```bash
git clone https://github.com/EleutherAI/aria
cd aria
pip install -e ".[all]"
cd aria && pip install -e ".[all]"
```

To install the package with `uv`:

```bash
git clone https://github.com/EleutherAI/aria
cd aria && uv sync --extra all
```

## Quickstart
Expand Down Expand Up @@ -79,7 +87,7 @@ Our embedding model was trained to capture composition-level and performance-lev

## Real-time demo

In `demo/` we provide an MLX (Apple Silicon) implementation of the real-time interactive piano-continuation demo showcased in our release blog post. In order to use the demo, you must download the demo-specific model checkpoint which enhances the model to additionally control the sustain pedal ([direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/model-demo.safetensors?download=true)).
In `demo/` we provide an MLX (Apple Silicon) implementation of the real-time interactive piano-continuation demo showcased in our release blog post. To use the demo, you must download the demo-specific model checkpoint, which enhances the model to additionally control the sustain pedal ([direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/model-demo.safetensors?download=true)). If you are using `uv`, install the demo dependencies with `uv sync --extra demo`.

For our demonstration, we used an acoustic Yamaha Disklavier piano with simultaneous MIDI input and output ports connected via a standard MIDI interface. We disabled the built-in Disklavier playback mode, instead manually calibrating key-velocity latency to enhance responsiveness. You may recreate this in your own environment with our acoustic calibration settings, using the following script:

Expand All @@ -93,8 +101,8 @@ python ./demo/demo_mlx.py \
--hardware ./demo/hardware/c4dm-disklavier.json \
--midi_control_signal 67 \
--midi_reset_control_signal 66 \
--temp 0.85 \
--min_p 0.05
--temp 0.9 \
--min_p 0.035
```

A MIDI input device is not strictly required to play around with the demo: by using the `--midi_path` and `--midi_through` arguments, you can mock real-time input by playing from a MIDI file. You only need MIDI drivers (e.g., CoreMIDI) and a virtual software instrument (e.g., Fluidsynth, Pianoteq) to render the output. In this mode, you can initiate the model takeover by pressing the enter key.
Expand All @@ -107,8 +115,8 @@ python ./demo/demo_mlx.py \
--midi_path ${MIDI_PATH} \
--midi_through <midi-playback-port> \
--midi_out <midi-playback-port> \
--temp 0.85 \
--min_p 0.05
--temp 0.9 \
--min_p 0.035
```

❗**NOTE**: Responsiveness of the real-time demo is dependent on your system configuration, specifically GPU memory bandwidth.
Expand Down
4 changes: 2 additions & 2 deletions aria/inference/model_mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,13 @@ def fill_condition_kv(self, emb: mx.array):
assert self.model_config.emb_size is not None

input_pos = mx.array([0], dtype=mx.int32)
mask = self.causal_mask[None, None, input_pos]
mask = self.causal_mask[None, None, input_pos, :1]
offset = 0

x = mx.expand_dims(emb, axis=1)

for layer in self.encode_layers:
x = layer(x, input_pos, offset, mask)
x = layer(x, input_pos, 0, offset, mask)

def __call__(
self,
Expand Down
36 changes: 18 additions & 18 deletions aria/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def get_tokenizer_name(
train_config = TrainingDataset.get_config_from_path(train_data_paths[0])
val_config = TrainingDataset.get_config_from_path(val_data_path)

assert (
train_config["tokenizer_name"] == val_config["tokenizer_name"]
), "Dataset tokenizers don't match"
assert train_config["tokenizer_name"] == val_config["tokenizer_name"], (
"Dataset tokenizers don't match"
)

return train_config["tokenizer_name"]

Expand Down Expand Up @@ -127,9 +127,9 @@ def setup_project_dir(project_dir: str | None):
elif project_dir:
# Run checks on project directory
if os.path.isdir(project_dir):
assert (
len(os.listdir(project_dir)) == 0
), "Provided project directory is not empty"
assert len(os.listdir(project_dir)) == 0, (
"Provided project directory is not empty"
)
project_dir_abs = os.path.abspath(project_dir)
elif os.path.isfile(project_dir):
raise FileExistsError(
Expand Down Expand Up @@ -226,9 +226,9 @@ def get_dataloaders(
if init_epoch:
train_dataset.init_epoch(idx=init_epoch)

assert (
len(val_dataset.epoch_files_by_dir[0]) == 1
), "val-data directory should only contain one epoch"
assert len(val_dataset.epoch_files_by_dir[0]) == 1, (
"val-data directory should only contain one epoch"
)

if apply_aug:
train_dataset.set_transform(tokenizer.export_data_aug())
Expand Down Expand Up @@ -427,9 +427,9 @@ def val_loop(dataloader, _epoch: int):
return avg_val_loss

if steps_per_checkpoint:
assert (
steps_per_checkpoint > 1
), "Invalid checkpoint mode value (too small)"
assert steps_per_checkpoint > 1, (
"Invalid checkpoint mode value (too small)"
)

TRAILING_LOSS_STEPS = 200
PAD_ID = train_dataloader.dataset.tokenizer.pad_id
Expand Down Expand Up @@ -516,9 +516,9 @@ def resume_train(
assert torch.cuda.is_available() is True, "CUDA not available"
assert os.path.isdir(checkpoint_dir), f"No dir at {checkpoint_dir}"
for train_data_path in train_data_paths:
assert os.path.isdir(
train_data_path
), f"No dir found at {train_data_path}"
assert os.path.isdir(train_data_path), (
f"No dir found at {train_data_path}"
)
assert os.path.isdir(val_data_path), f"No dir found at {val_data_path}"

tokenizer_name = get_tokenizer_name(train_data_paths, val_data_path)
Expand Down Expand Up @@ -649,9 +649,9 @@ def train(
assert batch_size > 0, "Invalid batch size"
assert torch.cuda.is_available() is True, "CUDA not available"
for train_data_path in train_data_paths:
assert os.path.isdir(
train_data_path
), f"No dir found at {train_data_path}"
assert os.path.isdir(train_data_path), (
f"No dir found at {train_data_path}"
)
assert os.path.isdir(val_data_path), f"No dir found at {val_data_path}"

tokenizer_name = get_tokenizer_name(train_data_paths, val_data_path)
Expand Down
51 changes: 0 additions & 51 deletions demo/config.json

This file was deleted.

Loading
Loading