-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_results.py
More file actions
41 lines (34 loc) · 1.49 KB
/
plot_results.py
File metadata and controls
41 lines (34 loc) · 1.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import pandas as pd
import matplotlib.pyplot as plt
# Plot MSE against prediction horizon for every model, one panel per dataset,
# and save the combined figure as a PDF.

# Load the benchmark results; the columns read below are Dataset, Model,
# Pred_Len and MSE.
df = pd.read_csv("benchmark_results_gpu.csv")

# One panel per dataset, one curve per model.
datasets = df['Dataset'].unique()
models = df['Model'].unique()

# Preferred color and marker for each known model.
colors = {'LSTM_Baseline': '#1f77b4', 'DLinear': '#ff7f0e', 'FITS': '#2ca02c'}
markers = {'LSTM_Baseline': 'o', 'DLinear': 's', 'FITS': '^'}

# A model that appears in the CSV but not in the tables above would raise
# KeyError while plotting, so give each such model a fallback style.
fallback_colors = ['#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f']
fallback_markers = ['D', 'v', 'P', 'X', '*']
unstyled = [m for m in models if m not in colors or m not in markers]
for k, model in enumerate(unstyled):
    colors.setdefault(model, fallback_colors[k % len(fallback_colors)])
    markers.setdefault(model, fallback_markers[k % len(fallback_markers)])

# Layout: three columns, with one extra cell reserved for the shared legend.
# Never shrink below the 2x3 grid, which fits up to five datasets; grow by
# whole rows when there are more so no panel is dropped or overwritten.
n_cols = 3
n_cells_needed = len(datasets) + 1
n_rows = max(2, -(-n_cells_needed // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 4 * n_rows), sharey=False)
axes = axes.flatten()

# Draw one panel per dataset.
for ax, ds in zip(axes, datasets):
    sub = df[df['Dataset'] == ds]
    for model in models:
        # Sort by horizon so each curve is drawn left to right.
        model_data = sub[sub['Model'] == model].sort_values('Pred_Len')
        ax.plot(model_data['Pred_Len'], model_data['MSE'],
                marker=markers[model], color=colors[model],
                label=str(model).replace('_Baseline', ''),
                linewidth=1.8, markersize=6)
    ax.set_title(ds, fontsize=12, fontweight='bold')
    ax.set_xlabel('Prediction Horizon', fontsize=10)
    ax.set_ylabel('MSE', fontsize=10)
    ax.grid(True, linestyle='--', alpha=0.5)

# Cells without a dataset carry no data, so hide their axes. The last cell
# hosts the shared legend, built from the handles of the first panel (every
# panel plots the same models in the same order).
for ax in axes[len(datasets):]:
    ax.axis('off')
handles, labels = axes[0].get_legend_handles_labels()
axes[-1].legend(handles, labels, loc='center', fontsize=12, frameon=False)

# Tighten the layout and save.
plt.tight_layout()
plt.savefig("mse_curves_multimodel.pdf", dpi=300, bbox_inches='tight')
print("✅ 多模型对比曲线已保存至 ./mse_curves_multimodel.pdf")