Skip to content

Commit 926aeb1

Browse files
initial example script exploration (only pirate script tested)
1 parent b86ff64 commit 926aeb1

File tree

4 files changed

+298
-32
lines changed

4 files changed

+298
-32
lines changed

examples/compile_inference.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

examples/compile_inference.py.tmp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
# Compiled-inference exploration: surface Dynamo graph breaks, recompiles,
# and AOTAutograd traces while generating with an 8-bit quantized Gemma model.

import logging

import torch
import torch._dynamo
from torch._dynamo import explain
from torch._dynamo.backends.common import aot_autograd

# Enable verbose logging using PyTorch's structured logging system.
from torch._logging import set_logs
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

set_logs(dynamo=logging.DEBUG, graph_breaks=True, recompiles=True)  # Enable specific artifacts

torch.set_float32_matmul_precision("high")

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)

# torch._dynamo.explain wraps the callable and returns an ExplainOutput object
# (not a tuple, and it has no `traceback=` kwarg): call explain(fn)(**kwargs)
# and read the break count / reasons off the result.
explain_output = explain(model)(**input_ids)
print(f"Graph breaks: {explain_output.graph_break_count}\nDetails:")
for break_info in explain_output.break_reasons:
    print(f"- {break_info}")


def trace_handler(gm: torch.fx.GraphModule, example_inputs):
    """AOTAutograd forward-compiler hook: dump the traced graph, run it eagerly.

    Args:
        gm: the FX graph module captured by AOTAutograd.
        example_inputs: sample inputs supplied by the tracer (unused here).

    Returns:
        A callable executing the captured graph (eager fallback).
    """
    print(f"AOT traced graph with {len(gm.graph.nodes)} nodes")
    gm.graph.print_tabular()  # Show full graph structure
    return gm.forward


# `options=` is only honored by the inductor backend; to plug in a custom
# forward compiler, build an aot_autograd backend explicitly.
model = torch.compile(model, backend=aot_autograd(fw_compiler=trace_handler))

# Generate with compilation diagnostics enabled.
outputs = model.generate(**input_ids, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))

examples/compile_pirate_qlora.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
# for torch.compile trace run:
#
# TORCH_TRACE="./tracedir" TORCH_LOGS="graph_breaks" CUDA_VISIBLE_DEVICES=0 python examples/compile_pirate_qlora.py

# 🏴☠️⛵ Pirate Coder's Delight: Fine-Tune Mistral-7B to Speak Like a Buccaneer Hacker
# Using bitsandbytes 4-bit + torch.compile - 100% Original Dataset

from datasets import Dataset
from peft import LoraConfig
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

# 1. Load Model with Pirate-Optimized Quantization 🏴☠️
model_id = "Qwen/Qwen1.5-1.8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # Qwen has no pad token by default

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
        "bnb_4bit_use_double_quant": True,
    },
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)


# 2. Original Pirate Programmer Dataset 🦜
def pirate_formatting_func(example):
    """Fold instruction/response pairs into a single instruction-tuning string."""
    return {"text": f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['response']}"}


train_dataset = Dataset.from_list(
    [
        {
            "instruction": "Explain quantum computing using pirate slang",
            "response": "Arrr, matey! 'Tis like sailin' parallel seas...",
        },
        {
            "instruction": "Write Python code to find buried treasure",
            "response": "def find_booty():\n    return (sum(coordinates) / len(coordinates))",
        },
        {
            "instruction": "Why do pirates hate distributed systems?",
            "response": "Too many captains sink the ship, ye scallywag!",
        },
    ]
).map(pirate_formatting_func)


# 3. Tokenize Pirate Dataset
def tokenize_pirate_data(examples):
    """Batched tokenizer for datasets.map; Arrow stores plain lists, so no
    return_tensors is needed (tensors would be converted back anyway)."""
    return tokenizer(examples["text"], padding="max_length", max_length=256, truncation=True)


train_dataset = train_dataset.map(
    tokenize_pirate_data,
    batched=True,
    remove_columns=["instruction", "response"],  # Keep only text + tokenized fields
)

# 4. Configure QLoRA ⚙️
peft_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# 5. Training Configuration
training_args = SFTConfig(
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    max_steps=5,
    learning_rate=2e-5,
    max_seq_length=256,
    remove_unused_columns=False,
    output_dir="./pirate_coder",
    optim="paged_adamw_8bit",
    dataset_text_field="text",
    packing=True,
    torch_compile={
        "mode": "reduce-overhead",
        "fullgraph": False,
        "dynamic": False,
    },
    report_to="none",
    logging_steps=1,
)

# 6. Launch Training with Pirate Flair! 🚀
# NOTE: no formatting_func here — the "text" field is already built above, and
# the columns the formatter reads (instruction/response) were removed by map.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    args=training_args,
    peft_config=peft_config,
)

print("⚡ Batten down the hatches - training with torch.compile!")
trainer.train()

examples/compile_qlora.py.tmp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
# 🦄 Fine-Tune Llama-3 8B to Write Pirate Jokes & Shakespearean Sonnets
# Using bitsandbytes 4-bit + torch.compile 🏴☠️

# !pip install -qU "unsloth[colab] @ git+https://github.com/unslothai/unsloth.git" \
#   transformers==4.40.0 accelerate==0.30.0 bitsandbytes==0.43.0

import random

from datasets import load_dataset
import torch
from transformers import TrainerCallback, TrainingArguments
from trl import SFTTrainer
from unsloth import FastLanguageModel

# 1. Load Pre-Quantized Model with bitsandbytes 🎯
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/llama-3-8B-bnb-4bit",
    max_seq_length=2048,
    dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
    load_in_4bit=True,
    quantization_config={
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
        "bnb_4bit_use_double_quant": True,
    },
)

# 2. Prepare Creative Dataset 🎭
pirate_dataset = load_dataset(
    "json",
    data_files={"train": "https://huggingface.co/datasets/jondurbin/pirate-jokes/resolve/main/pirate_jokes.json"},
)


def format_creative_prompt(sample):
    """Render a sample into the Llama-3 chat template as a single text string."""
    return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a swashbuckling pirate poet. Respond ONLY in pirate speak or Shakespearean verse.<|eot_id|>
<|start_header_id|>user<|end_header_id|>
{sample['prompt']}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
{sample['response']}<|eot_id|>"""


dataset = pirate_dataset.map(format_creative_prompt)

# 3. Configure QLoRA 🔍
# NOTE(review): FastLanguageModel.get_peft_model takes no `torch_compile` kwarg
# (it would raise TypeError); compilation is applied explicitly further below.
model = FastLanguageModel.get_peft_model(
    model,
    r=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_alpha=64,
    lora_dropout=0.1,
    bias="none",
    use_gradient_checkpointing=True,
    random_state=3407,  # For reproducibility of creative outputs
)


# 4. Custom Compilation Probe with Graph Break Analysis 🕵️♂️
def detect_graph_breaks():
    """Compile the model once on a sample input with verbose Dynamo logging,
    so graph breaks show up in the logs; restores the verbosity afterwards."""
    import torch._dynamo

    original_verbose = torch._dynamo.config.verbose
    torch._dynamo.config.verbose = True
    try:
        # Trigger compilation with sample input
        sample_input = tokenizer("Arrr! Tell me about yer treasure...", return_tensors="pt").to("cuda")
        compiled_model = torch.compile(model)
        _ = compiled_model(**sample_input)
    finally:
        torch._dynamo.config.verbose = original_verbose


detect_graph_breaks()  # Initial graph break detection


# 5. Creative Generation During Training 🎩
class PirateStreamer(TrainerCallback):
    """Trainer callback that occasionally samples from the model mid-training.

    Must subclass TrainerCallback — the Trainer's callback handler dispatches
    on_* event methods, so a bare callable cannot be used here.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.prompt = "Yarrr! Why did the pirate's chicken cross the road?"

    def on_step_end(self, args, state, control, **kwargs):
        if random.random() < 0.1:  # 10% chance to generate during training
            with torch.no_grad():
                outputs = model.generate(
                    input_ids=self.tokenizer(self.prompt, return_tensors="pt").to("cuda").input_ids,
                    max_new_tokens=50,
                    temperature=0.7,
                    repetition_penalty=1.1,
                )
            print("\n🏴☠️ Crew's Update:", self.tokenizer.decode(outputs[0]))


# 6. Launch Training with Progressive Compilation 🚀
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=1024,
    packing=True,
    callbacks=[PirateStreamer(tokenizer)],
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_steps=10,
        max_steps=100,
        learning_rate=3e-5,
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        logging_steps=1,
        optim="paged_adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        output_dir="pirate-poet",
        report_to="none",
    ),
)

# Progressive compilation strategy: Trainer's per-step method is
# `training_step` (there is no `train_step` attribute on Trainer).
trainer.training_step = torch.compile(
    trainer.training_step,
    mode="reduce-overhead",
    fullgraph=False,  # Start with partial graphs
    dynamic=True,
)

print("🏁 Starting training - watch for graph breaks and pirate wisdom!")
trainer.train()

0 commit comments

Comments
 (0)