Skip to content

Commit 654c7db

Browse files
TimDettmers and claude committed
fix: Use dict return value from forward() in training script
forward() returns {"loss": tensor} or {"logits": tensor}, not a raw tensor.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 837a0dc commit 654c7db

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

scripts/train_qwen3_30b.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ def train_standard(model, input_ids_list, labels_list, n_steps=100, lr=1e-4):
8787
labels = labels_list[idx].unsqueeze(0).cuda()
8888

8989
optimizer.zero_grad()
90-
loss = model(input_ids, labels)
90+
result = model(input_ids, labels)
91+
loss = result["loss"]
9192
loss.backward()
9293
optimizer.step()
9394

@@ -188,9 +189,12 @@ def main():
188189
input_ids = tokens["input_ids"].cuda()
189190

190191
with torch.no_grad():
191-
output = model_reload(input_ids, labels=None)
192-
# Just verify it runs without error
193-
print(f" LoRA reload OK, output shape: {output.shape if hasattr(output, 'shape') else type(output)}")
192+
result = model_reload(input_ids, labels=None)
193+
logits = result["logits"]
194+
# Generate a few tokens greedily
195+
next_tokens = logits.argmax(dim=-1)
196+
generated = tokenizer.decode(next_tokens[0], skip_special_tokens=True)
197+
print(f" LoRA reload OK, generated: {generated[:100]}")
194198

195199
# Save results
196200
results = {

0 commit comments

Comments (0)