Skip to content

Commit 2ba4273

Browse files
committed
Improve Streamlit dashboard layout and constants reuse
1 parent 3e186cf commit 2ba4273

File tree

7 files changed

+130
-116
lines changed

7 files changed

+130
-116
lines changed

dashboard/app.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pathlib import Path
77
import sys
88
from datetime import datetime
9+
from infinimetrics.common.constants import AcceleratorType
910

1011
# Add project root to path
1112
project_root = Path(__file__).parent
@@ -53,18 +54,23 @@ def main():
5354
st.markdown("---")
5455
st.markdown("## 🧠 筛选条件")
5556

57+
# Base accelerator types from constants.py
58+
ACCELERATOR_OPTIONS = ["cpu"] + [a.value for a in AcceleratorType]
59+
60+
# UI display names (only labels live here)
5661
ACCELERATOR_LABELS = {
57-
"nvidia": "NVIDIA",
5862
"cpu": "CPU",
59-
"mlu": "寒武纪 MLU",
60-
"npu": "昇腾 NPU",
61-
"musa": "摩尔线程 MUSA",
63+
AcceleratorType.NVIDIA.value: "NVIDIA",
64+
AcceleratorType.AMD.value: "AMD",
65+
AcceleratorType.ASCEND.value: "昇腾 NPU",
66+
AcceleratorType.CAMBRICON.value: "寒武纪 MLU",
67+
AcceleratorType.GENERIC.value: "Generic",
6268
}
6369

6470
selected_accs = st.multiselect(
6571
"加速卡类型",
66-
options=list(ACCELERATOR_LABELS.keys()),
67-
default=list(ACCELERATOR_LABELS.keys()),
72+
options=ACCELERATOR_OPTIONS,
73+
default=ACCELERATOR_OPTIONS,
6874
format_func=lambda x: ACCELERATOR_LABELS.get(x, x),
6975
)
7076
st.session_state.selected_accelerators = selected_accs
@@ -263,11 +269,11 @@ def _latest(lst):
263269

264270
col1, col2, col3 = st.columns(3)
265271
if col1.button("🔗 通信测试分析", use_container_width=True):
266-
st.switch_page("pages/1_comm.py")
272+
st.switch_page("pages/communication.py")
267273
if col2.button("⚡ 算子测试分析", use_container_width=True):
268-
st.switch_page("pages/2_ops.py")
274+
st.switch_page("pages/operator.py")
269275
if col3.button("🤖 推理测试分析", use_container_width=True):
270-
st.switch_page("pages/3_infer.py")
276+
st.switch_page("pages/inference.py")
271277

272278
except Exception as e:
273279
st.error(f"Dashboard 加载失败: {e}")

dashboard/common.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#!/usr/bin/env python3
2+
"""Common utilities for dashboard pages."""
3+
4+
import streamlit as st
5+
import sys
6+
from pathlib import Path
7+
8+
9+
def init_page(page_title: str, page_icon: str):
10+
"""
11+
通用页面初始化:
12+
- 设置 Streamlit 页面配置
13+
- 初始化 DataLoader
14+
- 设置项目路径
15+
"""
16+
17+
# Add project root to Python path
18+
project_root = Path(__file__).parent.parent
19+
if str(project_root) not in sys.path:
20+
sys.path.append(str(project_root))
21+
22+
# Page configuration
23+
st.set_page_config(page_title=page_title, page_icon=page_icon, layout="wide")
24+
25+
# Initialize DataLoader
26+
if "data_loader" not in st.session_state:
27+
from utils.data_loader import InfiniMetricsDataLoader
28+
29+
st.session_state.data_loader = InfiniMetricsDataLoader()
Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,23 @@
11
#!/usr/bin/env python3
22
"""Communication tests analysis page."""
33

4-
import sys
54
import streamlit as st
65
import pandas as pd
7-
from pathlib import Path
8-
from utils.metrics import extract_core_metrics
9-
10-
# Add parent directory to path
11-
project_root = Path(__file__).parent.parent
12-
sys.path.append(str(project_root))
136

7+
from common import init_page
148
from components.header import render_header
15-
from utils.data_loader import InfiniMetricsDataLoader, get_friendly_size
9+
from utils.data_loader import get_friendly_size
10+
from utils.metrics import extract_core_metrics
1611
from utils.visualizations import (
17-
plot_bandwidth_vs_size,
18-
plot_latency_vs_size,
12+
plot_metric_vs_size,
1913
plot_comparison_matrix,
20-
create_summary_table,
2114
create_gauge_chart,
15+
create_summary_table,
16+
plot_timeseries_auto,
17+
create_summary_table_infer,
2218
)
2319

24-
# Page configuration
25-
st.set_page_config(page_title="通信测试分析 | InfiniMetrics", page_icon="🔗", layout="wide")
26-
27-
# Initialize session state
28-
if "data_loader" not in st.session_state:
29-
st.session_state.data_loader = InfiniMetricsDataLoader()
20+
init_page("推理测试分析 | InfiniMetrics", "🔗")
3021

3122

3223
def main():
@@ -90,7 +81,7 @@ def main():
9081
# Run selector
9182
st.markdown("### 选择测试运行")
9283

93-
# Run ID 模糊搜索(真正生效)
84+
# Run ID Fuzzy search (really works)
9485
run_id_kw = st.text_input(
9586
"🔎 Run ID 模糊搜索(支持前缀 / 子串)",
9687
placeholder="例如:20240109 / abcd1234",
@@ -149,8 +140,11 @@ def main():
149140
and metric.get("data") is not None
150141
):
151142
df = metric["data"]
152-
fig = plot_bandwidth_vs_size(
153-
df, f"带宽分析 - {run['operation']}", y_log_scale
143+
fig = plot_metric_vs_size(
144+
df=df,
145+
metric_type="bandwidth",
146+
title=f"带宽分析 - {run['operation']}",
147+
y_log_scale=y_log_scale,
154148
)
155149
st.plotly_chart(fig, use_container_width=True)
156150
break
@@ -172,8 +166,11 @@ def main():
172166
and metric.get("data") is not None
173167
):
174168
df = metric["data"]
175-
fig = plot_latency_vs_size(
176-
df, f"延迟分析 - {run['operation']}", y_log_scale
169+
fig = plot_metric_vs_size(
170+
df=df,
171+
metric_type="latency",
172+
title=f"延迟分析 - {run['operation']}",
173+
y_log_scale=y_log_scale,
177174
)
178175
st.plotly_chart(fig, use_container_width=True)
179176
break
@@ -201,9 +198,9 @@ def main():
201198
f"{core['latency_us']:.2f} μs" if core["latency_us"] else "-",
202199
)
203200
c3.metric(
204-
"TTFT", f"{core['ttft_ms']:.2f} ms" if core["ttft_ms"] else "-"
201+
"测试耗时",
202+
f"{core['duration_ms']:.2f} ms" if core["duration_ms"] else "-",
205203
)
206-
207204
# Gauge charts for key metrics
208205
if len(selected_runs) == 1:
209206
st.markdown("#### 关键指标")
Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,16 @@
33

44
import streamlit as st
55
import pandas as pd
6-
from pathlib import Path
7-
import sys
8-
9-
project_root = Path(__file__).parent.parent
10-
sys.path.append(str(project_root))
116

7+
from common import init_page
128
from components.header import render_header
139
from utils.data_loader import InfiniMetricsDataLoader, get_friendly_size
1410
from utils.visualizations import (
1511
plot_timeseries_auto,
1612
create_summary_table_infer,
1713
)
1814

19-
st.set_page_config(page_title="推理测试分析 | InfiniMetrics", page_icon="🤖", layout="wide")
20-
21-
if "data_loader" not in st.session_state:
22-
st.session_state.data_loader = InfiniMetricsDataLoader()
15+
init_page("推理测试分析 | InfiniMetrics", "🤖")
2316

2417

2518
def main():
Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,15 @@
33

44
import streamlit as st
55
import pandas as pd
6-
from pathlib import Path
7-
import sys
8-
9-
project_root = Path(__file__).parent.parent
10-
sys.path.append(str(project_root))
116

7+
from common import init_page
128
from components.header import render_header
13-
from utils.data_loader import InfiniMetricsDataLoader
149
from utils.visualizations import (
1510
create_summary_table_ops,
1611
plot_timeseries_auto,
1712
)
1813

19-
st.set_page_config(page_title="算子测试分析 | InfiniMetrics", page_icon="⚡", layout="wide")
20-
21-
if "data_loader" not in st.session_state:
22-
st.session_state.data_loader = InfiniMetricsDataLoader()
14+
init_page("算子测试分析 | InfiniMetrics", "⚡")
2315

2416

2517
def main():

dashboard/utils/metrics.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,36 @@ def extract_core_metrics(run: dict) -> dict:
1010
"latency_ms": None,
1111
"ttft_ms": None,
1212
"throughput": None,
13+
"duration_ms": None,
1314
}
1415

1516
for m in metrics:
1617
name = m.get("name", "")
1718
val = m.get("value")
19+
df = m.get("data")
1820

19-
if "bandwidth" in name and val is not None:
21+
# Process bandwidth
22+
if name == "comm.bandwidth" and df is not None:
23+
if hasattr(df, "columns") and "bandwidth_gbs" in df.columns:
24+
out["bandwidth_gbps"] = df["bandwidth_gbs"].max()
25+
26+
# Process communication delay
27+
elif name == "comm.latency" and df is not None:
28+
if hasattr(df, "columns") and "latency_us" in df.columns:
29+
out["latency_us"] = df["latency_us"].mean()
30+
31+
# Process duration
32+
elif name == "comm.duration" and val is not None:
33+
out["duration_ms"] = val
34+
35+
elif "bandwidth" in name and val is not None:
2036
out["bandwidth_gbps"] = val
2137
elif "latency_us" in name:
2238
out["latency_us"] = val
2339
elif "latency_ms" in name:
2440
out["latency_ms"] = val
41+
elif "duration" in name:
42+
out["duration_ms"] = val
2543
elif "ttft" in name:
2644
out["ttft_ms"] = val
2745
elif "throughput" in name:

0 commit comments

Comments
 (0)