diff --git a/examples/lightning/training.py b/examples/lightning/training.py
index ab1164878..7b18d7021 100644
--- a/examples/lightning/training.py
+++ b/examples/lightning/training.py
@@ -2,7 +2,6 @@
 import math
 import os
 
-from dataclasses import _MISSING_TYPE
 from dataclasses import dataclass
 
 import datasets
@@ -22,6 +21,8 @@
 from liger_kernel.transformers import AutoLigerKernelForCausalLM
 from liger_kernel.utils import infer_device
 
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+
 _RETAIN_COLUMNS = {"input_ids", "attention_mask", "labels"}
 QUESTION = ""
 CHOICES = ""
@@ -46,10 +47,8 @@ class Args:
 
 def warmup_cosine_schedule(warmup_steps, total_steps, min_lr=0):
     def lr_lambda(current_step):
         if current_step < warmup_steps:
-            # Linear warmup
             return float(current_step) / float(max(1, warmup_steps))
         else:
-            # Cosine annealing
             progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
             return max(min_lr, 0.5 * (1 + math.cos(math.pi * progress)))
@@ -61,7 +60,7 @@ def parse_args() -> Args:
     for k, v in Args.__dataclass_fields__.items():
         parser.add_argument(f"--{k}", type=v.type, default=v.default)
     parsed = parser.parse_args()
-    return Args(**{k: v for k, v in vars(parsed).items() if not isinstance(v, _MISSING_TYPE)})
+    return Args(**vars(parsed))
 
 
 class LanguageModel(pl.LightningModule):
@@ -72,7 +71,7 @@ def __init__(self, args: Args, tokenizer):
         self.model = None
 
     def configure_model(self):
-        # https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html#speed-up-model-initialization
+        # https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html#speed-up-model-initialization
         if self.model is not None:
             return
         self.model = AutoLigerKernelForCausalLM.from_pretrained(
@@ -89,7 +88,7 @@ def training_step(self, batch):
         outputs = self.model(
             input_ids=batch["input_ids"],
             attention_mask=batch["attention_mask"],
-            labels=batch["labels"],
+            labels=batch.get("labels"),
         )
         loss = outputs.loss
         self.log_dict(
@@ -107,11 +106,11 @@ def validation_step(self, batch):
         outputs = self.model(
             input_ids=batch["input_ids"],
             attention_mask=batch["attention_mask"],
-            labels=batch["labels"],
+            labels=batch.get("labels"),
         )
         loss = outputs.loss
         self.log_dict(
-            {"val_loss": outputs.loss},
+            {"val_loss": loss},
             on_step=True,
             on_epoch=True,
             prog_bar=True,
@@ -182,10 +181,10 @@ def setup(self, stage) -> None:
         dataset = datasets.load_dataset(self.args.data, "auxiliary_train")
         flattened_data = [
             {
-                "answer": x["train"]["answer"],
-                "choices": x["train"]["choices"],
-                "question": x["train"]["question"],
-                "subject": x["train"]["subject"],
+                "answer": x["answer"],
+                "choices": x["choices"],
+                "question": x["question"],
+                "subject": x["subject"],
             }
             for x in dataset["train"]
         ]
@@ -237,11 +236,11 @@ def train():
 
     if args.strategy == "fsdp":
         strategy = FSDPStrategy(
-            auto_wrap_policy=layers,
+            auto_wrap_policy=transformer_auto_wrap_policy,
             sharding_strategy="FULL_SHARD",
             backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
             sync_module_states=True,
-            activation_checkpointing_policy=layers,
+            activation_checkpointing_policy=transformer_auto_wrap_policy,
             mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),
             forward_prefetch=True,
         )
@@ -251,16 +250,18 @@ def train():
         precision = "bf16-mixed"
     elif args.strategy == "ddp":
         strategy = "ddp"
-        precision = "bf16-true"
+        precision = "bf16-mixed"
     else:
         strategy = "auto"
-        precision = "bf16-true"
+        precision = "bf16-mixed"
 
     device = infer_device()
+    devices = args.num_gpu or (torch.cuda.device_count() if torch.cuda.is_available() else 1)
+
     trainer = pl.Trainer(
         accelerator=device,
         strategy=strategy,
-        devices=(getattr(torch, device).device_count() if args.num_gpu is None else args.num_gpu),
+        devices=devices,
         default_root_dir=args.output_dir,
         log_every_n_steps=1,
         max_epochs=1,
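
A note on the warmup_cosine_schedule context above (not part of the patch): the function is a closure factory, and assuming it returns its inner lr_lambda (the return statement falls outside this hunk), it plugs directly into torch.optim.lr_scheduler.LambdaLR. A minimal usage sketch; the parameter, optimizer, and step counts are illustrative:

    import torch

    param = torch.nn.Parameter(torch.zeros(1))       # stand-in parameter
    optimizer = torch.optim.AdamW([param], lr=3e-4)  # illustrative optimizer
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=warmup_cosine_schedule(warmup_steps=100, total_steps=1000),
    )
    for _ in range(1000):
        optimizer.step()
        scheduler.step()  # LR rises linearly for 100 steps, then follows a cosine decay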
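
A note on the FSDPStrategy hunk (not part of the patch): torch.distributed.fsdp.wrap.transformer_auto_wrap_policy takes a mandatory transformer_layer_cls keyword, so it would raise a TypeError when FSDP calls it as passed above; it is normally bound with functools.partial first. A minimal sketch of that pattern, assuming a Llama-style model (LlamaDecoderLayer is illustrative, not something this diff pins down):

    import functools

    from lightning.pytorch.strategies import FSDPStrategy
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    from transformers.models.llama.modeling_llama import LlamaDecoderLayer  # illustrative

    # Bind the required transformer_layer_cls before handing the policy to FSDP.
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={LlamaDecoderLayer},  # blocks FSDP wraps as single units
    )

    strategy = FSDPStrategy(
        auto_wrap_policy=wrap_policy,
        activation_checkpointing_policy={LlamaDecoderLayer},  # a plain set of classes also works here
    )

Relatedly, `devices = args.num_gpu or ...` falls back to the detected device count whenever num_gpu is falsy, so an explicit 0 behaves like None; `or` tests truthiness, not just absence.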