Commit 0e47e93

divya-kumari32, lchu6, daviswer, and ani300 authored
Mamba fix (#123)
* make mamba
* add quick debug
* add quick debug
* revert debug verbosity
* Learning rate scheduler changed (Constant)
* Cosine 0.01 decay
* Add AutoHandler
* Add Auto cfg option for AutoHandler
* Len gets called before open
* path/filepath typo fix
* Partitioning fix from mup-search
* Warmup interval change
* Schedule change
* Constant schedule
* LR schedule change (cool down and constant lr)
* Update dataset_utils.py: added a check for length of doc
* LR schedule change (Warmup + constant)
* Update main_training.py: cleanup for main_training.py
* Mirror doc len check into AHandler, fix mypy in AutoHandler
* Linting
* Further linting
* More mypy type fix
* Rename main_training.py to main_training_mamba.py
* Added main_training_llama.py file
* Rename fms_to_hf.py to fms_to_hf_mamba.py
* Added fms_to_hf_llama.py file
* Delete fms_fsdp/utils/config_utils.py
* Added mamba variant 9.8b
* Incremental mypy fix
* Fix imports (mypy)
* Rename adapters to work correctly
* linting

Signed-off-by: Davis Wertheimer <davis.wertheimer@ibm.com>
Signed-off-by: Antoni Viros i Martin <aviros@ibm.com>
Co-authored-by: Linsong Chu <lchu@us.ibm.com>
Co-authored-by: Davis Wertheimer <dww78@cornell.edu>
Co-authored-by: Antoni Viros i Martin <aviros@ibm.com>
1 parent 408c751 · commit 0e47e93

8 files changed: 339 additions & 37 deletions


fms_fsdp/utils/config_utils.py

Lines changed: 39 additions & 15 deletions
@@ -24,7 +24,7 @@ def update_config(config, **kwargs):
 
 def get_model_config(model_variant):
     if model_variant == "llama2_70b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             emb_dim=8192,
             multiple_of=4096,
             nheads=64,
@@ -33,7 +33,7 @@ def get_model_config(model_variant):
             hidden_grow_factor=28672 / 8192,
         )
     elif model_variant == "llama2_34b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             emb_dim=8192,
             nheads=64,
             kvheads=8,
@@ -43,27 +43,27 @@ def get_model_config(model_variant):
             rope_theta=1000000.0,
         )
     elif model_variant == "llama2_13b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             emb_dim=5120,
             nheads=40,
             nlayers=40,
             hidden_grow_factor=13824 / 5120,
         )
     elif model_variant == "llama2_7b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             hidden_grow_factor=11008 / 4096,
             kvheads=32,
         )
     elif model_variant == "llama2_1.4b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             emb_dim=2048,
             nheads=16,
             nlayers=24,
             hidden_grow_factor=3,
             kvheads=4,
         )
     elif model_variant == "llama3_8b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=4096,
             nheads=32,
@@ -74,7 +74,7 @@ def get_model_config(model_variant):
             rope_theta=500000.0,
         )
     elif model_variant == "llama3_8b_4k":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=4096,
             nheads=32,
@@ -85,7 +85,7 @@ def get_model_config(model_variant):
             rope_theta=500000.0,
         )
     elif model_variant == "llama3_1.8b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=2048,
             nheads=16,
@@ -96,7 +96,7 @@ def get_model_config(model_variant):
             rope_theta=500000.0,
         )
     elif model_variant == "llama3_1.8b_4k":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=2048,
             nheads=16,
@@ -107,7 +107,7 @@ def get_model_config(model_variant):
             rope_theta=500000.0,
         )
     elif model_variant == "llama3_3.2b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=3072,
             nheads=24,
@@ -118,7 +118,7 @@ def get_model_config(model_variant):
             rope_theta=500000.0,
         )
     elif model_variant == "llama3_3.2b_4k":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=3072,
             nheads=24,
@@ -129,7 +129,7 @@ def get_model_config(model_variant):
             rope_theta=500000.0,
         )
     elif model_variant == "llama3_70b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=8192,
             nheads=64,
@@ -140,7 +140,7 @@ def get_model_config(model_variant):
             rope_theta=500000.0,
         )
     elif model_variant == "llama3_70b_4k":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=8192,
             nheads=64,
@@ -151,15 +151,39 @@ def get_model_config(model_variant):
             rope_theta=500000.0,
         )
     elif model_variant == "llama3_194m_4k":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=1024,
             nheads=8,
             nlayers=10,
             max_expected_seq_len=4096,
             rope_theta=500000.0,
         )
+    elif model_variant == "mamba_9.8b":
+        model_config = {
+            "d_model": 4096,
+            "d_intermediate": 14336,
+            "n_layer": 32,
+            "vocab_size": 128256,
+            "ssm_cfg": {"layer": "Mamba2"},
+            "attn_layer_idx": [9, 18, 27],
+            "attn_cfg": {
+                "causal": True,
+                "d_conv": 0,
+                "head_dim": 128,
+                "num_heads": 32,
+                "num_heads_kv": 8,
+                "out_proj_bias": False,
+                "qkv_proj_bias": False,
+                "rotary_emb_dim": 64,
+            },
+            "rms_norm": True,
+            "residual_in_fp32": True,
+            "fused_add_norm": True,
+            "pad_vocab_size_multiple": 16,
+            "tie_embeddings": False,
+        }
     else:
         raise ValueError(f"model variant {model_variant} not supported.")
 
-    return llama_config
+    return model_config
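
Note that the new mamba_9.8b branch returns a plain dict of MambaConfig keyword arguments, while the llama branches still return a LLaMAConfig object, so callers are expected to unpack the dict themselves. A minimal sketch of the intended consumption, mirroring the fms_to_hf_mamba.py script added further down in this commit (assumes mamba_ssm is installed):

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

from fms_fsdp.utils.config_utils import get_model_config

# "mamba_9.8b" yields a dict of MambaConfig kwargs; llama variants yield a LLaMAConfig.
config_data = get_model_config("mamba_9.8b")
model = MambaLMHeadModel(MambaConfig(**config_data))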

fms_fsdp/utils/dataloader_utils.py

Lines changed: 5 additions & 3 deletions
@@ -2,6 +2,7 @@
 
 from fms_fsdp.utils.dataset_utils import (
     ArrowHandler,
+    AutoHandler,
     BufferDataset,
     CheckpointDataset,
     ParquetHandler,
@@ -16,6 +17,7 @@
 _handler_map = {
     "arrow": ArrowHandler,
     "hf_parquet": ParquetHandler,
+    "auto": AutoHandler,
 }
 
 
@@ -84,10 +86,10 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
     assert (
         cfg.file_type in _handler_map
     ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})"
-    if cfg.file_type == "hf_parquet":
-        filehandler = ParquetHandler(cfg.tokenizer_path, cfg.col_name)
+    if cfg.file_type == "hf_parquet" or cfg.file_type == "auto":
+        filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cfg.col_name)
     else:
-        filehandler = _handler_map[cfg.file_type](cfg.col_name)
+        filehandler = _handler_map[cfg.file_type]
     # Base reader layer
     data = StreamingDocDataset(
         cfg.data_path,
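
With the new "auto" file type, a dataset can mix .arrow and .parquet shards: AutoHandler (defined in dataset_utils.py below) defers to ArrowHandler or ParquetHandler per file. A hedged sketch of the dispatch; the tokenizer path and shard names here are hypothetical:

from fms_fsdp.utils.dataset_utils import AutoHandler

# Hypothetical paths, for illustration only.
handler = AutoHandler("/path/to/tokenizer", col_name="text")
handler.is_legal("shard_00.arrow")    # True: the extension is recognized
handler.is_legal("shard_01.parquet")  # True: the extension is recognized
reader = handler.open("shard_00.arrow")  # routes to ArrowHandler, remembered in handler.current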

fms_fsdp/utils/dataset_utils.py

Lines changed: 67 additions & 12 deletions
@@ -357,11 +357,11 @@ def length(self, path: str):
 
     def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
         doc = reader.get_batch(index)[self.col_name]
-        if len(doc) > 0:
-            if doc[0].as_py() in drop_tokens:
-                doc = doc.slice(1, len(doc) - 1)
-            if doc[-1].as_py() in drop_tokens:
-                doc = doc.slice(0, len(doc) - 1)
+        if len(doc) > 0 and doc[0].as_py() in drop_tokens:
+            doc = doc.slice(1, len(doc) - 1)
+        # Recheck len for edge case where doc=[eos]
+        if len(doc) > 0 and doc[-1].as_py() in drop_tokens:
+            doc = doc.slice(0, len(doc) - 1)
         return doc
 
     def slice(self, doc: pa.UInt32Array, index: int, n_pull: int) -> List:
@@ -384,24 +384,79 @@ def is_legal(self, filepath: str):
         return "parquet" in os.path.splitext(filepath)[1]
 
     def open(self, path: str):
-        return pq.read_pandas(path, columns=[self.col_name])[self.col_name]
+        return pq.read_pandas(path, columns=[self.col_name], partitioning=None)[
+            self.col_name
+        ]
 
     def length(self, path: str):
-        return pq.read_pandas(path, columns=[]).num_rows
+        return pq.read_metadata(path).num_rows
 
     def get(self, reader, index: int, drop_tokens: Set):
         doc = self.tokenizer(str(reader[index]))["input_ids"]
-        if len(doc) > 0:
-            if doc[0] in drop_tokens:
-                doc = doc[1:]
-            if doc[-1] in drop_tokens:
-                doc = doc[:-1]
+        if len(doc) > 0 and doc[0] in drop_tokens:
+            doc = doc[1:]
+        # Recheck len for edge case where doc=[eos]
+        if len(doc) > 0 and doc[-1] in drop_tokens:
+            doc = doc[:-1]
         return doc
 
     def slice(self, doc: List, index: int, n_pull: int) -> List:
         return doc[index : index + n_pull]
 
 
+class AutoHandler(_ShardFileHandler):
+    def __init__(self, tokenizer_path: str, col_name: str = "text"):
+        self.PHandler = ParquetHandler(tokenizer_path, col_name)
+        self.AHandler = ArrowHandler()
+        self.current = _ShardFileHandler()
+
+    def is_legal(self, filepath: str):
+        return (
+            "parquet" in os.path.splitext(filepath)[1]
+            or "arrow" in os.path.splitext(filepath)[1]
+        )
+
+    def open(self, path: str):
+        """
+        Open the file, to be indexed via self.get() method.
+        Avoid reading entire multi-Gb files when possible!
+        """
+        if "arrow" in os.path.splitext(path)[1]:
+            self.current = self.AHandler
+        else:
+            self.current = self.PHandler
+        return self.current.open(path)
+
+    def length(self, path: str):
+        """
+        Calculate the number of documents in the given file.
+        Avoid reading entire multi-Gb files when possible!
+        """
+        if "arrow" in os.path.splitext(path)[1]:
+            return self.AHandler.length(path)
+        else:
+            return self.PHandler.length(path)
+
+    def get(self, reader, index: int, drop_tokens: Set):
+        """
+        Given the output of self.open() and an index, return the document at that index.
+        Then, remove the first and/or last items if they appear in drop_tokens.
+        Try to avoid reading entire documents at a time in case of long documents,
+        but this is less important than avoiding reading entire files as above.
+        Output must support len().
+        """
+        return self.current.get(reader, index, drop_tokens)
+
+    def slice(self, doc, index: int, n_pull: int) -> List:
+        """
+        Given a long document, retrieve n_pull consecutive items starting from index.
+        Again, try to be memory-efficient when doing so, but efficiency in self.get()
+        and self.open() is far more important.
+        Must return a python list.
+        """
+        return self.current.slice(doc, index, n_pull)
+
+
 #### ------------------------- PIPELINE LAYERS ------------------------- ####
 
 
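The repeated "Recheck len" comment marks the bug this hunk fixes: when a document consists of exactly one drop token (doc=[eos]), the old nested-if version stripped the leading token and then evaluated doc[-1] on the now-empty document. A worked toy example of the fixed list-based logic from ParquetHandler.get; the token id used here is made up:

drop_tokens = {2}      # e.g., an eos token id; the value is made up
doc = [2]              # degenerate document containing only that token

if len(doc) > 0 and doc[0] in drop_tokens:
    doc = doc[1:]      # doc is now []
# Recheck len for edge case where doc=[eos]:
# without this guard, doc[-1] below would raise IndexError on the empty list.
if len(doc) > 0 and doc[-1] in drop_tokens:
    doc = doc[:-1]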
fms_to_hf.py renamed to fms_to_hf_llama.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 import fire
 import torch
-from fms.models.hf import to_hf_api
+from fms.models.hf.utils import to_hf_api
 from fms.models.llama import LLaMA
 from torch.distributed._shard.checkpoint import FileSystemReader, load_state_dict
 from transformers import LlamaConfig, LlamaForCausalLM

fms_to_hf_mamba.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+import fire
+from mamba_ssm.models.config_mamba import MambaConfig
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+from torch.distributed._shard.checkpoint import FileSystemReader, load_state_dict
+
+from fms_fsdp.utils.config_utils import get_model_config
+
+
+def main(model_variant, load_path, save_path, tokenizer_name_or_path):
+    print("Initializing model...")
+    config_data = get_model_config(model_variant)
+    mamba_config = MambaConfig(**config_data)
+    model = MambaLMHeadModel(mamba_config)
+
+    print(f"Reading state dict from {load_path}")
+    state_dict = {"model_state": model.state_dict()}
+    load_state_dict(
+        state_dict=state_dict, storage_reader=FileSystemReader(load_path), no_dist=True
+    )
+
+    print("Loading state dict into the model...")
+    model.load_state_dict(state_dict["model_state"])
+
+    print("Saving model to HF-compatible format...")
+    model.save_pretrained(save_path)
+
+    print("Copying tokenizer...")
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
+    tokenizer.save_pretrained(save_path)
+
+    print(f"Model saving at {save_path}")
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
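
Because the script hands main to fire.Fire, each of its arguments doubles as a command-line flag; an invocation would look something like the following, where all paths are placeholders:

python fms_to_hf_mamba.py \
    --model_variant=mamba_9.8b \
    --load_path=/path/to/fsdp/checkpoint \
    --save_path=/path/to/hf/output \
    --tokenizer_name_or_path=/path/to/tokenizer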
File renamed without changes.
