-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathagentic_training.py
More file actions
87 lines (76 loc) · 3.54 KB
/
Copy pathagentic_training.py
File metadata and controls
87 lines (76 loc) · 3.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import json
import google.generativeai as genai
class ObserveAgent:
    """LLM-backed monitor that turns recent training metrics into control directives.

    Each :meth:`observe` call records a one-line summary of the current step in
    a rolling history, sends that history to a Gemini model, and returns the
    model's JSON reply as a dict. The system instruction asks for the keys
    ``lr``, ``max_grad_norm``, ``action`` ("continue"/"stop"),
    ``save_checkpoint`` and ``message``. The model is not guaranteed to comply,
    so missing keys are filled with safe defaults.
    """

    # System prompt constraining the model's reply to a single JSON object.
    _SYSTEM_INSTRUCTION = (
        "You are an expert AI training assistant monitoring a deep learning model.\n"
        "Your job is to analyze metrics and output ONLY valid JSON.\n"
        'Format: {"lr": <float>,"max_grad_norm": <float>, "action": "continue"|"stop", '
        '"save_checkpoint": <boolean>, "message": "<string>"}\n'
        "Rules:\n"
        '- If loss is NaN or explodes suddenly, set action to "stop".\n'
        "- If val_loss drops significantly breaking a plateau, set save_checkpoint to true.\n"
        "- Adjust LR slightly up or down based on the gradient and loss trends.\n"
        "- Provide a short and professional training report in `message` in chinese.\n"
    )

    def __init__(self, api_key=None, model_name="gemini-3.1-pro-preview", history_size=10):
        """Create the agent and its Gemini model.

        :param api_key: Google Gemini API key. When falsy, ``genai.configure``
            is not called here. The library presumably then relies on its own
            prior configuration (e.g. an earlier ``genai.configure`` call).
        :param model_name: Gemini model identifier passed to ``GenerativeModel``.
        :param history_size: how many of the most recent step summaries are
            sent to the model as context.
        :raises ValueError: if ``history_size`` is less than 1.
        """
        if history_size < 1:
            raise ValueError(f"history_size must be >= 1, got {history_size}")
        if api_key:
            genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(
            model_name,
            system_instruction=self._SYSTEM_INSTRUCTION,
        )
        # Rolling window of per-step summaries so the model can judge trends.
        self.history = []
        self.history_size = history_size

    @staticmethod
    def _format_metric(value, spec):
        """Format a metric with the format ``spec``; ``None`` becomes "N/A"."""
        return "N/A" if value is None else format(value, spec)

    @staticmethod
    def _parse_response(text):
        """Decode the JSON object embedded in a model reply.

        The reply is cut from the first ``{`` to the last ``}``. This tolerates
        Markdown code fences (with or without a language tag, in any case) and
        any prose the model adds around the object.

        :raises ValueError: if no object is present or it is not valid JSON
            (``json.JSONDecodeError`` is a subclass of ``ValueError``).
        """
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end < start:
            raise ValueError(f"no JSON object in model reply: {text[:200]!r}")
        return json.loads(text[start:end + 1])

    def observe(self, step, train_loss, val_loss, current_lr, grad_norm,
                training_context, *, max_grad_norm=None):
        """Send the current training state to Gemini and return its directive.

        :param step: current step, inserted verbatim into the prompt.
        :param train_loss: training loss (formatted with 4 decimals).
        :param val_loss: validation loss, or ``None`` when no evaluation ran
            this step (rendered as "N/A").
        :param current_lr: learning rate in effect; also the default ``lr``.
        :param grad_norm: observed gradient norm.
        :param training_context: free-form description of the run (e.g. "LLM").
        :param max_grad_norm: optional current gradient-clipping threshold.
            When given, it is the default ``max_grad_norm`` in the result.
        :return: the model's directive dict. Any of ``lr``, ``action`` or
            ``save_checkpoint`` it omits (and ``max_grad_norm`` when supplied)
            are filled from the defaults. On any API or parsing failure, the
            defaults plus an error ``message`` are returned so training can
            continue.
        """
        fmt = self._format_metric
        prompt = (
            f"Step: {step}, Train Loss: {fmt(train_loss, '.4f')}, "
            f"Val Loss: {fmt(val_loss, '.4f')}, LR: {fmt(current_lr, '.6e')}, "
            f"Grad Norm: {fmt(grad_norm, '.4f')},training_info:{training_context}"
        )
        self.history.append(prompt)
        # Keep only the most recent steps to bound the prompt size.
        if len(self.history) > self.history_size:
            self.history = self.history[-self.history_size:]
        full_prompt = (
            "Recent history:\n" + "\n".join(self.history)
            + "\n\nAnalyze the training state and output JSON directives."
        )

        # Safe values used when the model fails or leaves a key out.
        defaults = {"lr": current_lr, "action": "continue", "save_checkpoint": False}
        if max_grad_norm is not None:
            defaults["max_grad_norm"] = max_grad_norm

        try:
            response = self.model.generate_content(full_prompt)
            directive = self._parse_response(response.text)
        except Exception as e:
            # Deliberate best-effort boundary: a monitoring failure must never
            # crash the training loop, so report it and keep training.
            print(f"⚠️ Agent parsing error: {e}")
            return {**defaults, "message": f"Error generating response: {e}"}
        return {**defaults, **directive}
# --- Self-test (runs only when this script is executed directly) ---
if __name__ == "__main__":
    # Needs the GEMINI_API_KEY environment variable. Without it, the self-test
    # prints a hint and exits instead of calling the API.
    test_key = os.environ.get("GEMINI_API_KEY", "")
    if not test_key:
        print("💡 请设置环境变量 GEMINI_API_KEY 以进行本地测试。")
    else:
        print("🤖 正在启动 Agent 测试...")
        agent = ObserveAgent(api_key=test_key)
        # Simulate a single training step.
        directive = agent.observe(
            step=100,
            train_loss=2.456,
            val_loss=2.501,
            current_lr=1e-4,
            grad_norm=1.2,
            training_context="LLM",
        )
        print("\n📊 Agent 返回的分析结果:")
        print(json.dumps(directive, indent=2, ensure_ascii=False))