Skip to content

Commit 9357243

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents 80453c4 + b9380e4 commit 9357243

2 files changed

Lines changed: 43 additions & 30 deletions

File tree

scripts/run_inference.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# Third Party
1717
from peft import AutoPeftModelForCausalLM
1818
from tqdm import tqdm
19-
from transformers import AutoTokenizer
19+
from transformers import AutoModelForCausalLM, AutoTokenizer
2020
import torch
2121

2222

@@ -156,16 +156,22 @@ def load(
156156
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
157157
# Apply the configs to the adapter config of this model; if no overrides
158158
# are provided, then the context manager doesn't have any effect.
159-
with AdapterConfigPatcher(checkpoint_path, overrides):
160-
try:
161-
peft_model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
162-
except OSError as e:
163-
print("Failed to initialize checkpoint model!")
164-
raise e
159+
try:
160+
with AdapterConfigPatcher(checkpoint_path, overrides):
161+
try:
162+
model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
163+
except OSError as e:
164+
print("Failed to initialize checkpoint model!")
165+
raise e
166+
except FileNotFoundError:
167+
print("No adapter config found! Loading as a merged model...")
168+
# Unable to find the adapter config; fall back to loading as a merged model
169+
model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
170+
165171
device = "cuda" if torch.cuda.is_available() else None
166172
print(f"Inferred device: {device}")
167-
peft_model.to(device)
168-
return cls(peft_model, tokenizer, device)
173+
model.to(device)
174+
return cls(model, tokenizer, device)
169175

170176
def run(self, text: str, *, max_new_tokens: int) -> str:
171177
"""Runs inference on an instance of this model.
@@ -198,7 +204,7 @@ def main():
198204
description="Loads a tuned model and runs an inference call(s) through it"
199205
)
200206
parser.add_argument(
201-
"--model", help="Path to tuned model to be loaded", required=True
207+
"--model", help="Path to tuned model / merged model to be loaded", required=True
202208
)
203209
parser.add_argument(
204210
"--out_file",
@@ -207,7 +213,7 @@ def main():
207213
)
208214
parser.add_argument(
209215
"--base_model_name_or_path",
210-
help="Override for base model to be used [default: value in model adapter_config.json]",
216+
help="Override for base model to be used for non-merged models [default: value in model adapter_config.json]",
211217
default=None,
212218
)
213219
parser.add_argument(

tuning/sft_trainer.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -47,34 +47,41 @@ def __init__(self, logger):
4747
self.logger = logger
4848

4949
def on_log(self, args, state, control, logs=None, **kwargs):
50-
"""Checks if this log contains keys of interest, e.g., los, and if so, creates
50+
"""Checks if this log contains keys of interest, e.g., loss, and if so, creates
5151
train_loss.jsonl in the model output dir (if it doesn't already exist),
5252
appends the subdict of the log & dumps the file.
5353
"""
5454
# All processes get the logs from this node; only update from process 0.
5555
if not state.is_world_process_zero:
5656
return
5757

58+
# separate evaluation loss from train loss
5859
log_file_path = os.path.join(args.output_dir, "train_loss.jsonl")
60+
eval_log_file_path = os.path.join(args.output_dir, "eval_loss.jsonl")
5961
if logs is not None and "loss" in logs and "epoch" in logs:
60-
try:
61-
# Take the subdict of the last log line; if any log_keys aren't part of this log
62-
# object, asssume this line is something else, e.g., train completion, and skip.
63-
log_obj = {
64-
"name": "loss",
65-
"data": {
66-
"epoch": round(logs["epoch"], 2),
67-
"step": state.global_step,
68-
"value": logs["loss"],
69-
"timestamp": datetime.isoformat(datetime.now()),
70-
},
71-
}
72-
except KeyError:
73-
return
74-
75-
# append the current log to the jsonl file
76-
with open(log_file_path, "a") as log_file:
77-
log_file.write(f"{json.dumps(log_obj, sort_keys=True)}\n")
62+
self._track_loss("loss", log_file_path, logs, state)
63+
elif logs is not None and "eval_loss" in logs and "epoch" in logs:
64+
self._track_loss("eval_loss", eval_log_file_path, logs, state)
65+
66+
def _track_loss(self, loss_key, log_file, logs, state):
67+
try:
68+
# Take the subdict of the last log line; if any log_keys aren't part of this log
69+
# object, assume this line is something else, e.g., train completion, and skip.
70+
log_obj = {
71+
"name": loss_key,
72+
"data": {
73+
"epoch": round(logs["epoch"], 2),
74+
"step": state.global_step,
75+
"value": logs[loss_key],
76+
"timestamp": datetime.isoformat(datetime.now()),
77+
},
78+
}
79+
except KeyError:
80+
return
81+
82+
# append the current log to the jsonl file
83+
with open(log_file, "a") as f:
84+
f.write(f"{json.dumps(log_obj, sort_keys=True)}\n")
7885

7986

8087
def train(

0 commit comments

Comments
 (0)