Skip to content

Commit 8c990a1

Browse files
authored
fix: Added per rank log file for ODM (#168)
* Added per rank log file for ODM Signed-off-by: romit <romit@ibm.com> * Pinned transformers version Signed-off-by: romit <romit@ibm.com> * Pinned transformers in framework package * Fixed CI/CD for other packages Signed-off-by: romit <romit@ibm.com> --------- Signed-off-by: romit <romit@ibm.com>
1 parent e806a00 commit 8c990a1

8 files changed

Lines changed: 22 additions & 22 deletions

File tree

plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def save_fsdp_optimizer(
113113
)
114114
sd_options = _prepare_sd_options(fsdp_plugin)
115115
# get the state dicts for model and optimizer
116-
(model_state_dict, optimizer_state_dict) = get_state_dict(
116+
model_state_dict, optimizer_state_dict = get_state_dict(
117117
model, optimizer, options=sd_options
118118
)
119119

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def _maybe_scatter(
389389

390390
# expect these products to be produced by an earlier
391391
# all-to-all gather call
392-
(send_counts, recv_counts, bins, sorted_expert_idxs, sorted_scattered_idxs) = (
392+
send_counts, recv_counts, bins, sorted_expert_idxs, sorted_scattered_idxs = (
393393
gather_products
394394
)
395395

@@ -421,7 +421,7 @@ def forward(self, hidden_states: torch.Tensor):
421421
# compute the routing logits, weights, and expert assignments
422422
# - router_logits: will be passed out of forward, used for computing
423423
# routing loss.
424-
(router_logits, routing_weights, selected_experts) = (
424+
router_logits, routing_weights, selected_experts = (
425425
self._compute_routing_weights(hidden_states)
426426
)
427427

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def _maybe_reshape_scattermoe_expert_weights(
188188
num_experts: int,
189189
intermediate_size: int,
190190
):
191-
(_is_w1, _is_w2, _is_w3) = [
191+
_is_w1, _is_w2, _is_w3 = [
192192
f"{x}.weight" in scatter_key for x in PARAM_NAME_WEIGHT_SCATTERMOE
193193
]
194194

plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/utils/peft.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,8 @@ def get_gptq_peft_model(
163163
model.model, model_id, adapter_name
164164
)
165165
except Exception as exc:
166-
raise NotImplementedError(
167-
f"{model.__class__.__name__} not support \
168-
{peft_config.peft_type.value} peft type yet."
169-
) from exc
166+
raise NotImplementedError(f"{model.__class__.__name__} not support \
167+
{peft_config.peft_type.value} peft type yet.") from exc
170168

171169
return peft_model
172170

plugins/accelerated-peft/tests/test_gptqmodel.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,5 @@ def test_quantizing_pretrained_model_outputs_match(
297297
target = torch.nn.functional.softmax(original_logits, dim=-1)
298298
target = target.view(BS * SEQLEN, -1)
299299
error = loss_fn(input, target)
300-
assert error.lt(
301-
LOSS_TOLERANCE
302-
), "Model logits don't match between both libraries \
300+
assert error.lt(LOSS_TOLERANCE), "Model logits don't match between both libraries \
303301
after quantization"

plugins/framework/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ dependencies = [
2727
"peft>=0.15.0",
2828
"accelerate @ git+https://github.com/huggingface/accelerate.git@5998f8625b8dfde9253c241233ff13bc2c18635d",
2929
"pandas",
30+
"transformers>=4.55.0,<=4.55.4",
3031
]
3132

3233
[tool.hatch.build.targets.wheel]

plugins/online-data-mixing/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ dependencies = [
3030
"datasets==4.*",
3131
"torchdata==0.11.0",
3232
"sentence-transformers==5.*",
33+
"transformers>=4.55.0,<=4.55.4",
3334
]
3435

3536
[project.optional-dependencies]

plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def __init__(
142142
self.id2cat = dict(enumerate(self.category_list))
143143
self.cat2id = {c: i for i, c in enumerate(self.category_list)}
144144
self.total_categories = len(self.category_list)
145+
self.rank = os.environ.get("RANK", "0")
145146

146147
# If no starting weights are given, then all arms (categories)
147148
# are equally important. Weights based on the size of the datasets
@@ -174,7 +175,7 @@ def __init__(
174175
self.output_dir = output_dir
175176
if not os.path.exists(self.output_dir):
176177
os.makedirs(self.output_dir)
177-
self.log_file_path = os.path.join(self.output_dir, "odm.jsonl")
178+
self.log_file_path = os.path.join(self.output_dir, f"odm_rank_{self.rank}.jsonl")
178179
logger.info(
179180
"Logs for online data mixing to be stored at {log_file_path}".format(
180181
log_file_path=self.log_file_path
@@ -191,6 +192,7 @@ def __init__(
191192
"rewards": [0] * self.total_categories,
192193
"count": 0,
193194
"action": "", # one of sample or update
195+
"rank": self.rank,
194196
}
195197

196198
# Local RNG so every process can deterministically sample identical streams.
@@ -274,6 +276,7 @@ def __next__(self):
274276
"action": "sample",
275277
}
276278
)
279+
277280
return sample
278281

279282
def load_state_dict(self, state_dict):
@@ -548,13 +551,12 @@ def update_sampling_weights(self, model, accelerator, state):
548551
count = accelerator.reduce(count, reduction="sum")
549552

550553
self._update_weights(count, rewards)
551-
if accelerator and accelerator.is_main_process:
552-
self.log_to_file(
553-
{
554-
"current_sampling_weights": self.sampling_weights.tolist(),
555-
"current_sampling_ratio": self.sampling_ratio,
556-
"rewards": rewards.tolist(),
557-
"count": count.tolist(),
558-
"action": "update",
559-
}
560-
)
554+
self.log_to_file(
555+
{
556+
"current_sampling_weights": self.sampling_weights.tolist(),
557+
"current_sampling_ratio": self.sampling_ratio,
558+
"rewards": rewards.tolist(),
559+
"count": count.tolist(),
560+
"action": "update",
561+
}
562+
)

0 commit comments

Comments
 (0)