diff --git a/.pylintrc b/.pylintrc
index 3ce134a02..e7d47d160 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -475,7 +475,7 @@ notes-rgx=
 [REFACTORING]

 # Maximum number of nested blocks for function / method body
-max-nested-blocks=5
+max-nested-blocks=6

 # Complete name of functions that never returns. When checking for
 # inconsistent-return-statements if a never returning function is called then
diff --git a/README.md b/README.md
index 7ef790a89..90b963933 100644
--- a/README.md
+++ b/README.md
@@ -855,6 +855,9 @@ Notes:
  - When a boolean is passed, the expert parallel degree defaults to 1 and further the behaviour would be as follows:
     - if True, it is Scatter MoE Kernels with experts sharded based on the top level sharding protocol (e.g. FSDP).
     - if False, Scatter MoE Kernels with complete replication of experts across ranks.
+ - FSDP must be used when LoRA tuning with `--fast_moe`
+ - LoRA tuning with ScatterMoE is supported, but because of inference restrictions in vLLM/vanilla PEFT, the expert layers and the router linear layer should not be trained as `target_modules` for models tuned with ScatterMoE. Users have control over which `target_modules` they wish to train:
+    - At this time, only attention layers are trainable when using LoRA with ScatterMoE. Until support for the router linear layer is added, target modules must be specified explicitly (i.e. `target_modules: ["q_proj", "v_proj", "o_proj", "k_proj"]`) instead of passing `target_modules: ["all-linear"]`.
 - `world_size` must be divisible by the `ep_degree`
 - `number of experts` in the MoE module must be divisible by the `ep_degree`
 - Running fast moe modifies the state dict of the model, and must be post-processed which happens automatically and the converted checkpoint can be found at `hf_converted_checkpoint` folder within every saved checkpoint directory. Alternatively, we can perform similar option manually through [checkpoint utils](https://github.com/foundation-model-stack/fms-acceleration/blob/main/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py) script.
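For illustration only (not part of the diff), a minimal sketch of the configuration the new README note describes: attention projections listed explicitly instead of `all-linear`. PEFT's `LoraConfig` is used here as a stand-in for the repo's own LoRA argument dataclass, and the `FastMoe`/`FastMoeConfig` import path is assumed from the file touched later in this diff.

```python
# Minimal sketch under the assumptions stated above.
from peft import LoraConfig

# Import path assumed from tuning/config/acceleration_configs/fast_moe.py.
from tuning.config.acceleration_configs.fast_moe import FastMoe, FastMoeConfig

# Expert-parallel degree of 1 (experts are not sharded); a boolean also works,
# as described in the README notes above.
fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))

# Only attention projections are trainable with ScatterMoE LoRA today, so the
# modules are listed explicitly rather than passing target_modules="all-linear".
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

# These objects would then be passed to sft_trainer.train(..., lora_config,
# fast_moe_config=fast_moe_config), mirroring the new test added below.
```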
diff --git a/build/accelerate_launch.py b/build/accelerate_launch.py
index 6cbc7d252..43cf8dda0 100644
--- a/build/accelerate_launch.py
+++ b/build/accelerate_launch.py
@@ -146,6 +146,17 @@ def main():
                     save_model_dir, save_model_dir, num_added_tokens
                 )

+            # In case of ScatterMoE LoRA
+            hf_converted_checkpoint = os.path.join(
+                save_model_dir, "hf_converted_checkpoint"
+            )
+            if os.path.exists(
+                os.path.join(hf_converted_checkpoint, "adapter_model.safetensors")
+            ):
+                post_process_vLLM_adapters_new_tokens(
+                    hf_converted_checkpoint, hf_converted_checkpoint, num_added_tokens
+                )
+
         if (
             os.path.exists(os.path.join(output_dir, "added_tokens_info.json"))
             and job_config.get("save_strategy") != "no"
@@ -159,11 +170,30 @@ def main():
             for _, dirs, _ in os.walk(output_dir, topdown=False):
                 for name in dirs:
                     if "checkpoint-" in name.lower():
-                        post_process_vLLM_adapters_new_tokens(
-                            os.path.join(output_dir, name),
-                            os.path.join(output_dir, name),
-                            num_added_tokens,
+                        base_checkpoint_dir = os.path.join(output_dir, name)
+                        hf_converted_checkpoint = os.path.join(
+                            base_checkpoint_dir, "hf_converted_checkpoint"
+                        )
+
+                        # Use hf_converted_checkpoint if it exists, otherwise use base_checkpoint_dir
+                        checkpoint_dir = (
+                            hf_converted_checkpoint
+                            if os.path.exists(
+                                os.path.join(
+                                    hf_converted_checkpoint, "adapter_model.safetensors"
+                                )
+                            )
+                            else base_checkpoint_dir
                         )
+
+                        if os.path.exists(
+                            os.path.join(checkpoint_dir, "adapter_model.safetensors")
+                        ):
+                            post_process_vLLM_adapters_new_tokens(
+                                checkpoint_dir,
+                                checkpoint_dir,
+                                num_added_tokens,
+                            )
         else:
             logging.warning(
                 "Failed to post-process: file added_tokens_info.json not in path %s",
diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py
index 80a445304..757d9fa00 100644
--- a/tests/acceleration/test_acceleration_framework.py
+++ b/tests/acceleration/test_acceleration_framework.py
@@ -532,8 +532,8 @@ def test_framework_initialized_properly_moe():
     )

     # spy inside the train to ensure that the ilab plugin is called
-    assert spy["model_loader_calls"] == 1
-    assert spy["augmentation_calls"] == 0
+    assert spy["model_loader_calls"] == 0
+    assert spy["augmentation_calls"] == 1
     assert spy["get_ready_for_train_calls"] == 1


@@ -776,37 +776,34 @@ def test_error_raised_fast_moe_with_non_moe_model():
     """
     Ensure error is thrown when `--fast_moe` is passed and model is not MoE
     """
-    with pytest.raises(
-        AttributeError,
-        match="'LlamaConfig' object has no attribute 'num_local_experts'",
-    ):
-        with tempfile.TemporaryDirectory() as tempdir:
+    with tempfile.TemporaryDirectory() as tempdir:

-            model_args = copy.deepcopy(MODEL_ARGS)
-            model_args.model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v0.3"
-            model_args.torch_dtype = torch.bfloat16
-            train_args = copy.deepcopy(TRAIN_ARGS)
-            train_args.output_dir = tempdir
-            train_args.save_strategy = "no"
-            train_args.bf16 = True
-            data_args = copy.deepcopy(DATA_ARGS)
-            data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT
-            data_args.response_template = "\n\n### Label:"
-            data_args.dataset_text_field = "output"
+        model_args = copy.deepcopy(MODEL_ARGS)
+        model_args.model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v0.3"
+        model_args.torch_dtype = torch.bfloat16
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+        train_args.save_strategy = "no"
+        train_args.bf16 = True
+        data_args = copy.deepcopy(DATA_ARGS)
+        data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT
+        data_args.response_template = "\n\n### Label:"
+        data_args.dataset_text_field = "output"

-            # initialize a config
-            moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))
+        # initialize a config
+        moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))

-            # 1. mock a plugin class
-            # 2. register the mocked plugins
-            # 3. call sft_trainer.train
-            with build_framework_and_maybe_instantiate(
-                [
-                    (["training.moe.scattermoe"], ScatterMoEAccelerationPlugin),
-                ],
-                instantiate=False,
-            ):
-                with instantiate_model_patcher():
+        # 1. mock a plugin class
+        # 2. register the mocked plugins
+        # 3. call sft_trainer.train
+        with build_framework_and_maybe_instantiate(
+            [
+                (["training.moe.scattermoe"], ScatterMoEAccelerationPlugin),
+            ],
+            instantiate=False,
+        ):
+            with instantiate_model_patcher():
+                with pytest.raises((ValueError, AttributeError)):
                     sft_trainer.train(
                         model_args,
                         data_args,
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index d5d6c4a28..1dfc553b4 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -1453,6 +1453,61 @@ def test_run_moe_ft_and_inference_ep1_kernels(dataset_path, ep_degree):
     )


+@pytest.mark.skipif(
+    not is_fms_accelerate_available(plugins="moe"),
+    reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
+)
+@pytest.mark.parametrize(
+    "target_modules",
+    [
+        "all-linear",
+        ["q_proj"],
+        ["q_proj", "k_proj"],
+        ["q_proj", "k_proj", "v_proj"],
+        ["q_proj", "k_proj", "v_proj", "o_proj"],
+    ],
+)
+@pytest.mark.parametrize("ep_degree", [True, False])
+@pytest.mark.parametrize("dataset_path", [TWITTER_COMPLAINTS_DATA_JSONL])
+def test_run_moe_lora_and_inference(dataset_path, target_modules, ep_degree):
+    """Check that we can LoRA tune a MoE model and that the hf converted checkpoint is created"""
+    with tempfile.TemporaryDirectory() as tempdir:
+        data_args = copy.deepcopy(DATA_ARGS)
+        data_args.training_data_path = dataset_path
+        model_args = copy.deepcopy(MODEL_ARGS)
+        model_args.model_name_or_path = "ibm-granite/granite-3.1-1b-a400m-base"
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+        lora_args = copy.deepcopy(PEFT_LORA_ARGS)
+        lora_args.r = 16
+        lora_args.target_modules = target_modules
+        fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=ep_degree))
+
+        if target_modules == "all-linear":
+            with pytest.raises(ValueError):
+                sft_trainer.train(
+                    model_args,
+                    data_args,
+                    train_args,
+                    lora_args,
+                    fast_moe_config=fast_moe_config,
+                )
+        else:
+            sft_trainer.train(
+                model_args,
+                data_args,
+                train_args,
+                lora_args,
+                fast_moe_config=fast_moe_config,
+            )
+            _test_run_inference(
+                checkpoint_path=os.path.join(
+                    _get_checkpoint_path(tempdir), "hf_converted_checkpoint"
+                ),
+                base_model_name_or_path="ibm-granite/granite-3.1-1b-a400m-base",
+            )
+
+
 @pytest.mark.skipif(
     not is_fms_accelerate_available(plugins="moe"),
     reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
@@ -1491,9 +1546,9 @@ def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
     _validate_training(tempdir)


-def _test_run_inference(checkpoint_path):
+def _test_run_inference(checkpoint_path, base_model_name_or_path=None):
     # Load the model
-    loaded_model = TunedCausalLM.load(checkpoint_path)
+    loaded_model = TunedCausalLM.load(checkpoint_path, base_model_name_or_path)

     # Run inference on the text
     output_inference = loaded_model.run(
diff --git a/tuning/config/acceleration_configs/fast_moe.py b/tuning/config/acceleration_configs/fast_moe.py
index 1ace18dfa..37602daf1 100644
--- a/tuning/config/acceleration_configs/fast_moe.py
+++ b/tuning/config/acceleration_configs/fast_moe.py
@@ -16,6 +16,7 @@
 from dataclasses import dataclass, field
 from typing import Union
 import argparse
+import json
 import os

 # Third Party
@@ -121,10 +122,29 @@ def checkpoint(checkpoint_dir, save_dir):
                         args,
                         os.path.join(hf_converted_output_dir, TRAINING_ARGS_NAME),
                     )
-                    # Save model config files
-                    self.trainer.model.config.save_pretrained(
-                        hf_converted_output_dir
-                    )
+
+                    # Unwrap FSDP module
+                    model = self.trainer.model
+                    if hasattr(model, "module"):
+                        model = model.module
+
+                    if hasattr(model, "peft_config"):
+                        lora_config = model.peft_config["default"]
+                        config_dict = lora_config.to_dict()
+                        config_dict["target_modules"] = sorted(
+                            list(config_dict["target_modules"])
+                        )
+                        with open(
+                            os.path.join(
+                                hf_converted_output_dir, "adapter_config.json"
+                            ),
+                            "w",
+                            encoding="utf-8",
+                        ) as f:
+                            json.dump(config_dict, f, indent=2)
+
+                    else:
+                        model.config.save_pretrained(hf_converted_output_dir)

                 except Exception as e:
                     raise ValueError(
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index d67ea02b8..b51a723e8 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -168,6 +168,33 @@ def train(
             "Trainer should not perform packing when using `--padding_free`"
         )

+    if fast_moe_config is not None:
+        # Checking for unsupported modules with Scatter MoE for LoRA
+        # Only raise an error for `all-linear`
+        restricted_modules = ["all-linear"]
+        if (
+            peft_config is not None
+            and hasattr(peft_config, "target_modules")
+            and any(
+                module in (peft_config.target_modules or [])
+                for module in restricted_modules
+            )
+        ):
+            raise ValueError(
+                "`--fast_moe` with LoRA does not currently support `all-linear` "
+                "as target modules. Please explicitly specify target "
+                "modules when using `--fast_moe` with LoRA."
+            )
+        # For any other target modules, warn about MoE-internal layers
+        if peft_config is not None and hasattr(peft_config, "target_modules"):
+            logger.warning(
+                "You are running LoRA with the ScatterMoE plugin; please note that "
+                "passing target modules that are part of the MoE module can cause unexpected "
+                "behaviors and unsuccessful tuning while LoRA tuning with ScatterMoE. "
+                "For safe tuning, only pass linear modules such as those in the attention layer "
+                "(i.e. ['q_proj', 'v_proj', 'o_proj', 'k_proj'])."
+            )
+
     task_type = "CAUSAL_LM"
     additional_metrics = {}

@@ -360,6 +387,15 @@ def train(
             model, (peft_config,) = framework.augmentation(
                 model, train_args, modifiable_args=(peft_config,)
             )
+            # HACK - For LoRA ScatterMoE, disable grad for ScatterMoE modules.
+            # In the future, requires_grad should be enabled for LoRA tuning
+            # with ScatterMoE and this code should be removed.
+            if peft_config is not None:
+                for module in model.modules():
+                    # Use string comparison to check if ScatterMoE module
+                    if module.__class__.__name__ == "ScatterMoE":
+                        for param in module.parameters():
+                            param.requires_grad = False

     # HACK - The SFT Trainer has internal validation which inspects the name of the class
     # being used for the HF training args; if it's a TrainingArguments class, which is