Skip to content

Commit c509fd5

Browse files
Merge pull request #31 from InfiniTensor/feat/dashboard-streamlit
feat(dashboard): add training analysis page and improve communication…
2 parents 90ec223 + 0cab933 commit c509fd5

File tree

9 files changed

+587
-37
lines changed

9 files changed

+587
-37
lines changed

dashboard/app.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def main():
6565
st.markdown("---")
6666

6767
results_dir = st.text_input(
68-
"测试结果目录", value="../output", help="包含 JSON/CSV 测试结果的目录"
68+
"测试结果目录", value="./output", help="包含 JSON/CSV 测试结果的目录"
6969
)
7070

7171
if not use_mongodb and results_dir != str(
@@ -122,20 +122,23 @@ def render_dashboard(run_id_filter: str):
122122
<div style="
123123
margin-top: 0.5em;
124124
margin-bottom: 1.5em;
125-
max-width: 1100px;
126-
font-size: 1.05em;
125+
max-width: 80%;
126+
font-size: 1.3em;
127127
line-height: 1.6;
128128
">
129129
<strong>InfiniMetrics Dashboard</strong> 用于统一展示
130130
<strong>通信(NCCL / 集合通信)</strong>、
131+
<strong>训练(Training / 分布式训练)</strong>、
131132
<strong>推理(直接推理 / 服务性能)</strong>、
132133
<strong>算子(核心算子性能)</strong>、
133134
<strong>硬件(内存带宽 / 缓存性能)</strong>
134135
等 AI 加速卡性能测试结果。
135136
<br/>
136137
测试框架输出 <code>JSON</code>(环境 / 配置 / 标量指标) +
137138
<code>CSV</code>(曲线 / 时序数据),
138-
Dashboard 自动加载并支持多次运行的对比分析与可视化。
139+
Dashboard 自动加载并支持多次运行的
140+
<strong>性能对比</strong>、<strong>趋势分析</strong> 与
141+
<strong>可视化展示</strong>。
139142
</div>
140143
""",
141144
unsafe_allow_html=True,
@@ -177,6 +180,7 @@ def _parse_time(t):
177180
# ========== Categorize runs ==========
178181
comm_runs = [r for r in runs if r.get("testcase", "").startswith("comm")]
179182
infer_runs = [r for r in runs if r.get("testcase", "").startswith("infer")]
183+
train_runs = [r for r in runs if r.get("testcase", "").startswith("train")]
180184

181185
ops_runs, hw_runs = [], []
182186
for r in runs:
@@ -188,13 +192,14 @@ def _parse_time(t):
188192
hw_runs.append(r)
189193

190194
# ========== KPI ==========
191-
c1, c2, c3, c4, c5, c6 = st.columns(6)
195+
c1, c2, c3, c4, c5, c6, c7 = st.columns(7)
192196
c1.metric("总测试数", total)
193197
c2.metric("成功率", f"{(success/total*100):.1f}%")
194198
c3.metric("通信测试", len(comm_runs))
195199
c4.metric("推理测试", len(infer_runs))
196-
c5.metric("算子测试", len(ops_runs))
197-
c6.metric("硬件检测", len(hw_runs))
200+
c5.metric("训练测试", len(train_runs))
201+
c6.metric("算子测试", len(ops_runs))
202+
c7.metric("硬件检测", len(hw_runs))
198203

199204
st.caption(f"失败测试数:{fail}")
200205
st.caption(f"当前筛选:加速卡={','.join(selected_accs) or '全部'}")
@@ -208,8 +213,9 @@ def _latest(lst):
208213
latest_comm = _latest(comm_runs)
209214
latest_infer = _latest(infer_runs)
210215
latest_ops = _latest(ops_runs)
216+
latest_train = _latest(train_runs)
211217

212-
colA, colB, colC = st.columns(3)
218+
colA, colB, colC, colD = st.columns(4)
213219

214220
with colA:
215221
st.markdown("#### 🔗 通信(最新)")
@@ -238,6 +244,17 @@ def _latest(lst):
238244
st.write(f"- time: {latest_ops.get('time','')}")
239245
st.write(f"- status: {'✅' if latest_ops.get('success') else '❌'}")
240246

247+
with colD:
248+
st.markdown("#### 🏋️ 训练(最新)")
249+
if not latest_train:
250+
st.info("暂无训练结果")
251+
else:
252+
framework = latest_train.get("config", {}).get("framework", "unknown")
253+
model = latest_train.get("config", {}).get("model", "unknown")
254+
st.write(f"- 框架/模型: `{framework}/{model}`")
255+
st.write(f"- time: {latest_train.get('time','')}")
256+
st.write(f"- status: {'✅' if latest_train.get('success') else '❌'}")
257+
241258
st.divider()
242259

243260
# ========== Recent runs table ==========
@@ -294,13 +311,15 @@ def _latest(lst):
294311
st.markdown("---")
295312
st.markdown("### 🚀 快速导航")
296313

297-
col1, col2, col3 = st.columns(3)
314+
col1, col2, col3, col4 = st.columns(4)
298315
if col1.button("🔗 通信测试分析", use_container_width=True):
299316
st.switch_page("pages/communication.py")
300317
if col2.button("⚡ 算子测试分析", use_container_width=True):
301318
st.switch_page("pages/operator.py")
302-
if col3.button("🤖 推理测试分析", use_container_width=True):
319+
if col3.button("🚀 推理测试分析", use_container_width=True):
303320
st.switch_page("pages/inference.py")
321+
if col4.button("🏋️ 训练测试分析", use_container_width=True):
322+
st.switch_page("pages/training.py")
304323

305324
except Exception as e:
306325
st.error(f"Dashboard 加载失败: {e}")

dashboard/pages/communication.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def main():
6060
# Status filter
6161
show_success = st.checkbox("仅显示成功测试", value=True)
6262

63-
# Apply filters
63+
# Apply filter
6464
filtered_runs = [
6565
r
6666
for r in comm_runs
@@ -123,6 +123,7 @@ def main():
123123
identifier = run_info.get("path") or run_info.get("run_id")
124124
result = st.session_state.data_loader.load_test_result(identifier)
125125
run_info["data"] = result
126+
126127
selected_runs.append(run_info)
127128

128129
# Tabs for different views
@@ -183,36 +184,30 @@ def main():
183184
st.plotly_chart(fig, use_container_width=True)
184185

185186
if len(selected_runs) == 1:
186-
st.markdown("#### 📌 核心指标(最新)")
187+
st.markdown("#### 关键指标")
187188
run = selected_runs[0]
188189
core = extract_core_metrics(run)
189190

190-
c1, c2, c3 = st.columns(3)
191-
192-
c1.metric(
191+
# First row: numeric metric cards
192+
cols = st.columns(3)
193+
cols[0].metric(
193194
"峰值带宽",
194-
(
195-
f"{core['bandwidth_gbps']:.2f} GB/s"
196-
if core["bandwidth_gbps"]
197-
else "-"
198-
),
195+
f"{core['bandwidth_gbps']:.2f} GB/s"
196+
if core["bandwidth_gbps"]
197+
else "-",
199198
)
200-
c2.metric(
199+
cols[1].metric(
201200
"平均延迟",
202201
f"{core['latency_us']:.2f} μs" if core["latency_us"] else "-",
203202
)
204-
c3.metric(
203+
cols[2].metric(
205204
"测试耗时",
206205
f"{core['duration_ms']:.2f} ms" if core["duration_ms"] else "-",
207206
)
208-
# Gauge charts for key metrics
209-
if len(selected_runs) == 1:
210-
st.markdown("#### 关键指标")
211-
run = selected_runs[0]
212207

213-
col1, col2, col3 = st.columns(3)
208+
cols = st.columns(3)
214209

215-
with col1:
210+
with cols[0]:
216211
# Find max bandwidth
217212
max_bw = 0
218213
for metric in run.get("data", {}).get("metrics", []):
@@ -233,7 +228,7 @@ def main():
233228
st.plotly_chart(fig, use_container_width=True)
234229
break
235230

236-
with col2:
231+
with cols[1]:
237232
# Find average latency
238233
avg_lat = 0
239234
for metric in run.get("data", {}).get("metrics", []):
@@ -254,7 +249,7 @@ def main():
254249
st.plotly_chart(fig, use_container_width=True)
255250
break
256251

257-
with col3:
252+
with cols[2]:
258253
# Extract duration
259254
duration = 0
260255
for metric in run.get("data", {}).get("metrics", []):

dashboard/pages/inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
create_summary_table_infer,
1212
)
1313

14-
init_page("推理测试分析 | InfiniMetrics", "🤖")
14+
init_page("推理测试分析 | InfiniMetrics", "🚀")
1515

1616

1717
def main():
@@ -180,7 +180,7 @@ def _plot_metric(metric_name_contains: str, container):
180180

181181
_plot_metric("infer.compute_latency", c1)
182182
_plot_metric("infer.ttft", c2)
183-
_plot_metric("infer.direct_throughput", c3)
183+
_plot_metric("infer.direct_throughput_tps", c3)
184184

185185
# ---------- Tables ----------
186186
with tab2:

dashboard/pages/training.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
#!/usr/bin/env python3
2+
"""Training tests analysis page."""
3+
4+
import streamlit as st
5+
6+
from common import init_page
7+
from components.header import render_header
8+
from utils.training_utils import (
9+
load_training_runs,
10+
filter_runs,
11+
create_run_options,
12+
load_selected_runs,
13+
create_training_summary,
14+
)
15+
from utils.training_plots import (
16+
render_performance_curves,
17+
render_throughput_comparison,
18+
render_data_tables,
19+
render_config_details,
20+
)
21+
22+
init_page("训练测试分析 | InfiniMetrics", "🏋️")
23+
24+
25+
def main():
    """Render the training analysis page.

    Fetches the training runs via ``load_training_runs`` using the data
    loader stored in ``st.session_state``, lets the user narrow them down
    with sidebar filters (framework, model, device count, success only),
    and renders the runs picked in the multiselect inside four tabs:
    performance curves, throughput comparison, data tables and
    configuration details.

    Returns early, after showing a message, when there are no runs at all,
    when the filters exclude every run, or when no run is selected.
    """
    render_header()
    st.markdown("## 🏋️ 训练性能测试分析")

    dl = st.session_state.data_loader
    runs = load_training_runs(dl)

    if not runs:
        # st.info renders Markdown, where a lone "\n" is only a soft break
        # and both sentences would run together on one line. The two
        # trailing spaces before the newline make it a hard line break.
        st.info(
            "未找到训练测试结果  \n"
            "请将训练测试结果放在 output/train/ 或 output/training/ 目录下"
        )
        return

    # Sidebar Filters
    with st.sidebar:
        st.markdown("### 🔍 筛选条件")

        # Candidate values are collected from the runs themselves; runs that
        # lack a field fall back to "unknown" (or 1 for the device count).
        # NOTE(review): sorted() assumes the values of each field are
        # mutually comparable (e.g. no None mixed with str) - confirm
        # against the loader's output.
        frameworks = sorted(
            {r.get("config", {}).get("framework", "unknown") for r in runs}
        )
        models = sorted({r.get("config", {}).get("model", "unknown") for r in runs})
        device_counts = sorted({r.get("device_used", 1) for r in runs})

        # Every filter starts with all values selected, i.e. nothing hidden.
        selected_fw = st.multiselect("框架", frameworks, default=frameworks)
        selected_models = st.multiselect("模型", models, default=models)
        selected_dev = st.multiselect("设备数", device_counts, default=device_counts)
        only_success = st.checkbox("仅显示成功测试", value=True)

        st.markdown("---")
        st.markdown("### 📈 图表选项")
        y_log = st.checkbox("Y轴对数刻度", value=False)
        smoothing = st.slider("平滑窗口", 1, 50, 5, help="对曲线进行移动平均平滑")

    # Apply filters
    filtered = filter_runs(
        runs, selected_fw, selected_models, selected_dev, only_success
    )
    st.caption(f"找到 {len(filtered)} 个训练测试")

    if not filtered:
        st.warning("没有符合条件的测试结果")
        return

    # Run Selection
    options = create_run_options(filtered)
    run_labels = list(options.keys())
    selected = st.multiselect(
        "选择要分析的测试运行(可多选对比)",
        run_labels,
        # Pre-select at most the first three runs; slicing already clamps
        # to the list length, so no explicit min() is needed.
        default=run_labels[:3],
    )

    if not selected:
        # Tell the user why the page is empty instead of returning silently.
        st.info("请至少选择一个测试运行进行分析")
        return

    # Load selected runs
    selected_runs = load_selected_runs(dl, filtered, options, selected)

    # Tabs
    tab1, tab2, tab3, tab4 = st.tabs(["📈 性能曲线", "📊 吞吐量对比", "📋 数据表格", "🔍 详细配置"])

    with tab1:
        render_performance_curves(selected_runs, smoothing, y_log)
    with tab2:
        render_throughput_comparison(selected_runs)
    with tab3:
        render_data_tables(selected_runs)
    with tab4:
        render_config_details(selected_runs, create_training_summary)
91+
92+
93+
if __name__ == "__main__":
94+
main()

dashboard/utils/data_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class InfiniMetricsDataLoader:
2828

2929
def __init__(
3030
self,
31-
results_dir: str = "../output",
31+
results_dir: str = "./output",
3232
use_mongodb: bool = False,
3333
mongo_config=None,
3434
fallback_to_files: bool = True,

dashboard/utils/data_sources.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def source_type(self) -> str:
5151
class FileDataSource(DataSource):
5252
"""File-based data source (reads from JSON/CSV files)."""
5353

54-
def __init__(self, results_dir: str = "../output"):
54+
def __init__(self, results_dir: str = "./output"):
5555
self.results_dir = Path(results_dir)
5656

5757
@property

0 commit comments

Comments
 (0)