diff --git a/README.md b/README.md
index fce505e..973e9a2 100644
--- a/README.md
+++ b/README.md
@@ -487,6 +487,7 @@ To set up multi-node training:
 - Register your models with the `AutoModelForCausalLM`, `AutoModel` and `AutoConfig` classes (see `custom_models/sba/__init__.py` for an example)
 4. Create a config file for your custom model, just need to specify the `model_type` to the one you just named for your custom model (example: `configs/sba_340m.json`).
 5. Training is extremely simple, you can just use the `flame.train.py` script to train your custom model.
+6. Edit `custom_models/__init__.py` to import your custom model (e.g., `from . import sba`). Set `__all__` accordingly so only valid packages are exposed.
diff --git a/custom_models/__init__.py b/custom_models/__init__.py
new file mode 100644
index 0000000..7e6824d
--- /dev/null
+++ b/custom_models/__init__.py
@@ -0,0 +1,3 @@
+from . import sba
+
+__all__ = ["sba"]
diff --git a/train.sh b/train.sh
index 3135445..5135e24 100644
--- a/train.sh
+++ b/train.sh
@@ -66,7 +66,7 @@
 steps=$(grep -oP '(?<=--training.steps )[^ ]+' <<< "$params")
 config=$(grep -oP '(?<=--model.config )[^ ]+' <<< "$params")
 tokenizer=$(grep -oP '(?<=--model.tokenizer_path )[^ ]+' <<< "$params")
 model=$(
-  python -c "import fla, sys; from transformers import AutoConfig; print(AutoConfig.from_pretrained(sys.argv[1]).to_json_string())" "$config" | jq -r '.model_type'
+  python -c "import fla, sys, custom_models; from transformers import AutoConfig; print(AutoConfig.from_pretrained(sys.argv[1]).to_json_string())" "$config" | jq -r '.model_type'
)
 mkdir -p $path