diff --git a/.gitignore b/.gitignore index 1f6cfa4..2c13530 100644 --- a/.gitignore +++ b/.gitignore @@ -208,7 +208,7 @@ __marimo__/ # local test outputs summary_output/ -test_output/ +output/ traces/ test_*.json format*.json diff --git a/dashboard/app.py b/dashboard/app.py index 26205c9..8e7efc8 100644 --- a/dashboard/app.py +++ b/dashboard/app.py @@ -3,9 +3,11 @@ import streamlit as st import pandas as pd +import plotly.graph_objects as go from pathlib import Path import sys from datetime import datetime +from typing import Optional from infinimetrics.common.constants import AcceleratorType # Add project root to path @@ -33,17 +35,31 @@ st.session_state.use_mongodb = False +def parse_timestamp(ts) -> Optional[datetime]: + """Parse timestamp, support multiple formats""" + try: + ts_str = str(ts) + if "_" in ts_str and len(ts_str) == 15: + return datetime.strptime(ts_str, "%Y%m%d_%H%M%S") + return datetime.fromisoformat(ts_str.replace("Z", "+00:00")) + except Exception: + return None + + +def format_time(ts) -> str: + """Format timestamp as display string""" + dt = parse_timestamp(ts) + if dt: + return dt.strftime("%Y-%m-%d %H:%M:%S") + return str(ts)[:19] if ts else "未知" + + def main(): render_header() - # ========================= - # Sidebar - # ========================= - with st.sidebar: st.markdown("## ⚙️ 设置") - # Data source selection use_mongodb = st.toggle( "使用 MongoDB", value=st.session_state.use_mongodb, @@ -59,9 +75,7 @@ def main(): else: st.session_state.data_loader = InfiniMetricsDataLoader() - # Show current data source show_data_source_info(style="sidebar") - st.markdown("---") results_dir = st.text_input( @@ -80,10 +94,7 @@ def main(): st.markdown("---") st.markdown("## 🧠 筛选条件") - # Base accelerator types from constants.py ACCELERATOR_OPTIONS = ["cpu"] + [a.value for a in AcceleratorType] - - # UI display names (only labels live here) ACCELERATOR_LABELS = { "cpu": "CPU", AcceleratorType.NVIDIA.value: "NVIDIA", @@ -102,7 +113,6 @@ def main(): 
st.session_state.selected_accelerators = selected_accs run_id_filter = st.text_input("Run ID 模糊搜索") - # test_type / testcase filtering will be applied dynamically after runs are loaded render_dashboard(run_id_filter) @@ -146,8 +156,8 @@ def render_dashboard(run_id_filter: str): try: runs = st.session_state.data_loader.list_test_runs() + ci_summaries = st.session_state.data_loader.load_summaries() - # ========== Accelerator filtering ========== selected_accs = st.session_state.get("selected_accelerators", []) if selected_accs: runs = [ @@ -156,185 +166,349 @@ def render_dashboard(run_id_filter: str): if set(r.get("accelerator_types", [])) & set(selected_accs) ] - # ========== run_id filtering ========== if run_id_filter: runs = [r for r in runs if run_id_filter in r.get("run_id", "")] - if not runs: - st.warning("No test results match the current filters.") - return - - # ========== Sort by time (latest first) ========== - def _parse_time(t): - try: - return datetime.fromisoformat(t) - except Exception: - return datetime.min - - runs = sorted(runs, key=lambda r: _parse_time(r.get("time", "")), reverse=True) - - total = len(runs) - success = sum(1 for r in runs if r.get("success")) - fail = total - success - - # ========== Categorize runs ========== - comm_runs = [r for r in runs if r.get("testcase", "").startswith("comm")] - infer_runs = [r for r in runs if r.get("testcase", "").startswith("infer")] - train_runs = [r for r in runs if r.get("testcase", "").startswith("train")] - - ops_runs, hw_runs = [], [] - for r in runs: - p = str(r.get("path", "")).replace("\\", "/").lower() - tc = (r.get("testcase", "") or "").lower() - if "/operators/" in p or tc.startswith(("operator", "operators", "ops")): - ops_runs.append(r) - if "/hardware/" in p or tc.startswith("hardware"): - hw_runs.append(r) - - # ========== KPI ========== - c1, c2, c3, c4, c5, c6, c7 = st.columns(7) - c1.metric("总测试数", total) - c2.metric("成功率", f"{(success/total*100):.1f}%") - c3.metric("通信测试", 
len(comm_runs)) - c4.metric("推理测试", len(infer_runs)) - c5.metric("训练测试", len(train_runs)) - c6.metric("算子测试", len(ops_runs)) - c7.metric("硬件检测", len(hw_runs)) + render_test_run_stats(runs, ci_summaries, selected_accs) + render_ci_stats(ci_summaries) + render_latest_results(runs) + render_dispatcher_summary(ci_summaries) + render_ci_detailed_table(ci_summaries) + render_failure_details(ci_summaries) + + except Exception as e: + st.error(f"Dashboard 加载失败: {e}") + +def render_test_run_stats(runs, ci_summaries, selected_accs): + """Render test run statistics""" + total = len(runs) + success = sum(1 for r in runs if r.get("success")) + fail = total - success + + # Category statistics + comm_runs = [r for r in runs if r.get("testcase", "").startswith("comm")] + infer_runs = [r for r in runs if r.get("testcase", "").startswith("infer")] + train_runs = [r for r in runs if r.get("testcase", "").startswith("train")] + + ops_runs, hw_runs = [], [] + for r in runs: + p = str(r.get("path", "")).replace("\\", "/").lower() + tc = (r.get("testcase", "") or "").lower() + if "/operators/" in p or tc.startswith(("operator", "operators", "ops")): + ops_runs.append(r) + if "/hardware/" in p or tc.startswith("hardware"): + hw_runs.append(r) + + st.markdown("### 📝 测试运行统计") + st.caption("*基于当前筛选条件的测试运行记录*") + + col1, col2, col3, col4, col5, col6, col7, col8 = st.columns(8) + col1.metric("总运行数", total) + col2.metric("成功率", f"{(success/total*100):.1f}%" if total > 0 else "-") + col3.metric("通信", len(comm_runs)) + col4.metric("推理", len(infer_runs)) + col5.metric("训练", len(train_runs)) + col6.metric("算子", len(ops_runs)) + col7.metric("硬件", len(hw_runs)) + col8.metric("CI记录", len(ci_summaries)) + + with st.expander("📈 详细统计"): st.caption(f"失败测试数:{fail}") st.caption(f"当前筛选:加速卡={','.join(selected_accs) or '全部'}") + st.divider() - st.divider() - # ========== Latest results ========== - def _latest(lst): - return lst[0] if lst else None +def render_ci_stats(ci_summaries): + """Render CI 
statistics""" + if not ci_summaries: + return - latest_comm = _latest(comm_runs) - latest_infer = _latest(infer_runs) - latest_ops = _latest(ops_runs) - latest_train = _latest(train_runs) - latest_hw = _latest(hw_runs) + total_test_cases = sum(s.get("total_tests", 0) for s in ci_summaries) + total_success = sum(s.get("successful_tests", 0) for s in ci_summaries) + total_failed = sum(s.get("failed_tests", 0) for s in ci_summaries) + total_runs = len(ci_summaries) + avg_success_rate = ( + (total_success / total_test_cases * 100) if total_test_cases > 0 else 0 + ) - colA, colB, colC, colD, colE = st.columns(5) + recent_summaries = ci_summaries[:10] + recent_success = sum(s.get("successful_tests", 0) for s in recent_summaries) + recent_total = sum(s.get("total_tests", 0) for s in recent_summaries) + recent_rate = (recent_success / recent_total * 100) if recent_total > 0 else 0 + + st.markdown("### 📈 CI运行统计") + st.caption("*基于Dispatcher汇总的历史CI运行记录*") + + col1, col2, col3, col4, col5, col6 = st.columns(6) + col1.metric("CI运行次数", total_runs) + col2.metric("测试用例总数", f"{total_test_cases:,}") + col3.metric("通过用例", f"{total_success:,}") + col4.metric("失败用例", f"{total_failed:,}") + col5.metric("平均成功率", f"{avg_success_rate:.1f}%") + col6.metric( + "最近10次", f"{recent_rate:.1f}%", delta=f"{recent_rate - avg_success_rate:.1f}%" + ) - with colA: - st.markdown("#### 🔗 通信(最新)") - if not latest_comm: - st.info("暂无通信结果") - else: - st.write(f"- testcase: `{latest_comm.get('testcase','')}`") - st.write(f"- time: {latest_comm.get('time','')}") - st.write(f"- status: {'✅' if latest_comm.get('success') else '❌'}") - - with colB: - st.markdown("#### 🚀 推理(最新)") - if not latest_infer: - st.info("暂无推理结果") - else: - st.write(f"- testcase: `{latest_infer.get('testcase','')}`") - st.write(f"- time: {latest_infer.get('time','')}") - st.write(f"- status: {'✅' if latest_infer.get('success') else '❌'}") - - with colC: - st.markdown("#### ⚡ 算子(最新)") - if not latest_ops: - st.info("暂无算子结果") - else: - 
st.write(f"- testcase: `{latest_ops.get('testcase','')}`") - st.write(f"- time: {latest_ops.get('time','')}") - st.write(f"- status: {'✅' if latest_ops.get('success') else '❌'}") - - with colD: - st.markdown("#### 🏋️ 训练(最新)") - if not latest_train: - st.info("暂无训练结果") - else: - framework = latest_train.get("config", {}).get("framework", "unknown") - model = latest_train.get("config", {}).get("model", "unknown") - st.write(f"- 框架/模型: `{framework}/{model}`") - st.write(f"- time: {latest_train.get('time','')}") - st.write(f"- status: {'✅' if latest_train.get('success') else '❌'}") - - with colE: - st.markdown("#### 🔧 硬件(最新)") - if not latest_hw: - st.info("暂无硬件结果") + render_daily_trend(ci_summaries) + st.divider() + + +def render_daily_trend(ci_summaries): + """Render daily trend chart""" + daily_stats = {} + for s in ci_summaries: + dt = parse_timestamp(s.get("timestamp", "")) + if not dt: + continue + date_key = dt.strftime("%Y-%m-%d") + daily_stats[date_key] = { + "total": daily_stats.get(date_key, {}).get("total", 0) + + s.get("total_tests", 0), + "success": daily_stats.get(date_key, {}).get("success", 0) + + s.get("successful_tests", 0), + } + + if not daily_stats: + return + + dates = sorted(daily_stats.keys()) + totals = [daily_stats[d]["total"] for d in dates] + successes = [daily_stats[d]["success"] for d in dates] + failures = [totals[i] - successes[i] for i in range(len(dates))] + + fig = go.Figure() + fig.add_trace(go.Bar(x=dates, y=totals, name="总测试数", marker_color="lightblue")) + fig.add_trace(go.Bar(x=dates, y=successes, name="成功", marker_color="lightgreen")) + fig.add_trace(go.Bar(x=dates, y=failures, name="失败", marker_color="lightcoral")) + + fig.update_layout( + title="每日测试用例分布趋势", + barmode="group", + xaxis_title="日期", + yaxis_title="测试用例数", + template="plotly_white", + height=400, + xaxis_tickangle=-45, + ) + st.plotly_chart(fig, use_container_width=True) + + +def render_latest_results(runs): + """Render latest test results""" + st.markdown("### 🚀 
最新测试结果") + + def get_latest_by_type(runs, test_type): + filtered = [r for r in runs if r.get("testcase", "").startswith(test_type)] + return filtered[0] if filtered else None + + categories = [ + { + "name": "通信", + "icon": "🔗", + "page": "pages/communication.py", + "run": get_latest_by_type(runs, "comm"), + }, + { + "name": "推理", + "icon": "🚀", + "page": "pages/inference.py", + "run": get_latest_by_type(runs, "infer"), + }, + { + "name": "算子", + "icon": "⚡", + "page": "pages/operator.py", + "run": get_latest_by_type(runs, "operator"), + }, + { + "name": "训练", + "icon": "🏋️", + "page": "pages/training.py", + "run": get_latest_by_type(runs, "train"), + }, + { + "name": "硬件", + "icon": "🔧", + "page": "pages/hardware.py", + "run": get_latest_by_type(runs, "hardware"), + }, + ] + + cols = st.columns(5) + for idx, cat in enumerate(categories): + with cols[idx]: + st.markdown(f"#### {cat['icon']} {cat['name']}(最新)") + if cat["run"]: + run = cat["run"] + st.write(f"- testcase: `{run.get('testcase', '')}`") + st.write(f"- time: {run.get('time', '')}") + st.write(f"- status: {'✅' if run.get('success') else '❌'}") + if st.button( + f"查看详情 →", key=f"btn_{cat['name']}", use_container_width=True + ): + st.switch_page(cat["page"]) else: - st.write(f"- testcase: `{latest_hw.get('testcase','')}`") - st.write(f"- time: {latest_hw.get('time','')}") - st.write(f"- status: {'✅' if latest_hw.get('success') else '❌'}") + st.info("暂无结果") + st.divider() + + +def render_dispatcher_summary(ci_summaries): + """Render dispatcher summary records""" + if not ci_summaries: + return + + st.markdown("### 🧾 Dispatcher 汇总记录") + st.caption("*每次CI运行的原始汇总文件记录*") + + rows = [] + for s in ci_summaries[:15]: + rows.append( + { + "时间": format_time(s.get("timestamp", "")), + "总测试数": s.get("total_tests", 0), + "成功": s.get("successful_tests", 0), + "失败": s.get("failed_tests", 0), + "成功率": f"{(s.get('successful_tests', 0)/s.get('total_tests', 1)*100):.1f}%" + if s.get("total_tests", 0) > 0 + else "-", + "文件": 
s.get("file", ""), + } + ) - st.divider() + st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True) + st.divider() - # ========== Recent runs table ========== - st.markdown("### 🕒 最近测试运行") - df = pd.DataFrame( - [ - { - "类型": (r.get("testcase", "").split(".")[0] or "UNKNOWN").upper(), - "加速卡": ", ".join(r.get("accelerator_types", [])), - "时间": r.get("time", ""), - "状态": "✅" if r.get("success") else "❌", - "run_id": r.get("run_id", "")[:32], - } - for r in runs[:15] - ] + +def render_ci_detailed_table(ci_summaries): + """Render CI detailed records table""" + if not ci_summaries: + st.info("未找到 CI 汇总记录") + return + + st.markdown("### 📋 CI 详细记录") + st.caption("*包含Git信息、测试统计和时长的详细记录*") + + rows = [] + for s in ci_summaries[:30]: + time_str = format_time(s.get("timestamp", "")) + total = s.get("total_tests", 0) + success = s.get("successful_tests", 0) + failed = s.get("failed_tests", 0) + duration = s.get("duration", s.get("total_duration_seconds", 0)) + + # Run ID + results = s.get("results", []) + run_id_display = results[0].get("run_id", "-") if results else "-" + + # Git info + git_info = s.get("git", {}) + short_commit = git_info.get("short_commit", "") + commit_display = ( + short_commit[:8] if short_commit and short_commit != "unknown" else "本地运行" + ) + + commit_msg = git_info.get("commit_message", "") + msg_display = ( + commit_msg[:50] + "..." 
+ if commit_msg and len(commit_msg) > 50 + else ( + commit_msg + if commit_msg and not commit_msg.startswith("unknown") + else "-" + ) + ) + + branch = git_info.get("branch", "") + branch_display = ( + branch if branch and not branch.startswith("unknown") else "local" + ) + + author = git_info.get("commit_author", "") + author_display = author if author and not author.startswith("unknown") else "-" + + rows.append( + { + "时间": time_str, + "Run ID": run_id_display, + "Commit": commit_display, + "提交信息": msg_display, + "分支": branch_display, + "作者": author_display, + "总数": total, + "✅": success, + "❌": failed, + "成功率": f"{(success/total*100):.1f}%" if total > 0 else "-", + "状态": "✅ 成功" if failed == 0 else "❌ 失败", + "时长": f"{duration:.1f}s" if duration > 0 else "-", + } ) - st.dataframe(df, use_container_width=True, hide_index=True) - # ========== Dispatcher summary ========== - summaries = st.session_state.data_loader.load_summaries() + df = pd.DataFrame(rows) + st.dataframe(df, use_container_width=True, hide_index=True) + st.divider() - if not summaries: - st.info("未找到 Dispatcher 汇总记录") - return - st.markdown("### 🧾 Dispatcher 汇总记录") +def render_failure_details(ci_summaries): + """Render failure details""" + failed_records = [] + for s in ci_summaries: + failed = s.get("failed_tests", 0) + if failed == 0: + continue - rows = [] - for s in summaries: - rows.append( + time_str = format_time(s.get("timestamp", "")) + git_info = s.get("git", {}) + short_commit = git_info.get("short_commit", "") + commit_display = ( + short_commit[:8] if short_commit and short_commit != "unknown" else "本地运行" + ) + + # Get failure details + failed_details = s.get("failed_tests_details", []) + if not failed_details and "results" in s: + failed_details = [ { - "时间": s.get("timestamp"), - "总测试数": s.get("total_tests"), - "成功": s.get("successful_tests"), - "失败": s.get("failed_tests"), - "成功率": ( - f"{s['successful_tests'] / s['total_tests'] * 100:.1f}%" - if s.get("total_tests") - else "-" - ), 
- "文件": s.get("file"), + "testcase": r.get("testcase", "unknown"), + "run_id": r.get("run_id", "unknown"), + "result_code": r.get("result_code", -1), + "result_file": r.get("result_file", ""), + "error_msg": r.get("error_msg", ""), } - ) + for r in s.get("results", []) + if r.get("result_code", 0) != 0 + ] - df = pd.DataFrame(rows).sort_values("时间", ascending=False) + if failed_details: + failed_records.append( + { + "时间": time_str, + "Commit": commit_display, + "失败数": len(failed_details), + "失败详情": failed_details, + } + ) - st.dataframe( - df, - use_container_width=True, - hide_index=True, - ) + if not failed_records: + return - # ========== Quick navigation ========== - st.markdown("---") - st.markdown("### 🚀 快速导航") - - col1, col2, col3, col4, col5 = st.columns(5) - if col1.button("🔗 通信测试分析", use_container_width=True): - st.switch_page("pages/communication.py") - if col2.button("⚡ 算子测试分析", use_container_width=True): - st.switch_page("pages/operator.py") - if col3.button("🚀 推理测试分析", use_container_width=True): - st.switch_page("pages/inference.py") - if col4.button("🏋️ 训练测试分析", use_container_width=True): - st.switch_page("pages/training.py") - if col5.button("🔧 硬件测试分析", use_container_width=True): - st.switch_page("pages/hardware.py") + st.markdown("### 🔍 失败详情") + st.caption("点击展开查看失败测试的详细信息") - except Exception as e: - st.error(f"Dashboard 加载失败: {e}") + for record in failed_records[:15]: + with st.expander( + f"📅 {record['时间']} - 失败: {record['失败数']}个测试 (Commit: {record['Commit']})" + ): + for i, fail in enumerate(record["失败详情"][:20]): + st.markdown(f"**{i+1}. 
{fail.get('testcase', 'unknown')}**") + st.markdown(f"- Run ID: `{fail.get('run_id', 'unknown')}`") + st.markdown(f"- Result Code: {fail.get('result_code', -1)}") + if fail.get("result_file"): + st.markdown(f"- Result File: `{fail.get('result_file')}`") + st.divider() + + if len(record["失败详情"]) > 20: + st.info(f"还有 {len(record['失败详情']) - 20} 个失败测试未显示") if __name__ == "__main__": diff --git a/dashboard/common.py b/dashboard/common.py index 0a59f07..451c932 100644 --- a/dashboard/common.py +++ b/dashboard/common.py @@ -35,21 +35,29 @@ def init_page(page_title: str, page_icon: str): ) -def show_data_source_info(style: str = "caption"): +def show_data_source_info(style: str = "caption", show_detailed: bool = False): """ Display current data source info (MongoDB or file system). Args: style: Display style - "caption" for pages, "sidebar" for main app sidebar + show_detailed: Whether to show detailed statistics """ dl = st.session_state.data_loader + if dl.source_type == "mongodb": if style == "sidebar": - st.success("🟢 数据源: MongoDB") + st.success("🟢 **数据源: MongoDB**") + if show_detailed: + st.caption("实时CI数据 | 支持完整历史查询") else: - st.caption("数据源: MongoDB") + st.caption("🟢 数据源: MongoDB (实时CI数据)") else: if style == "sidebar": - st.info(f"📁 数据源: 文件系统 ({dl.results_dir})") + st.info(f"📁 **数据源: 文件系统**") + st.caption(f"结果目录: `{dl.results_dir}`") + if show_detailed: + summary_dir = dl.results_dir.parent / "summary_output" + st.caption(f"摘要目录: `{summary_dir}`") else: - st.caption(f"数据源: 文件系统 ({dl.results_dir})") + st.caption(f"📁 数据源: 文件系统 ({dl.results_dir})") diff --git a/dashboard/utils/data_loader.py b/dashboard/utils/data_loader.py index f843617..6c24148 100644 --- a/dashboard/utils/data_loader.py +++ b/dashboard/utils/data_loader.py @@ -157,6 +157,76 @@ def load_summaries(self) -> List[Dict[str, Any]]: return [] return self._source.load_summaries() + def load_ci_history(self, limit: int = 100) -> List[Dict[str, Any]]: + """ + Load CI history with detailed execution 
information. + + Args: + limit: Maximum number of CI runs to load + + Returns: + List of CI run summaries with enhanced information + """ + if self._source is None: + return [] + + # Check if the source has load_ci_history method + if hasattr(self._source, "load_ci_history"): + return self._source.load_ci_history(limit) + + # Fallback: use load_summaries and enhance them + summaries = self._source.load_summaries() + return self._enhance_summaries(summaries[:limit]) + + def _enhance_summaries( + self, summaries: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + Enhance summaries with CI information and failure details. + This is a fallback method when the data source doesn't provide enhanced data. + """ + enhanced = [] + for summary in summaries: + # Normalize fields + enhanced_summary = summary.copy() + + # Set default CI fields with placeholder markers + ci_fields = [ + "commit", + "branch", + "commit_message", + "author", + "pipeline_id", + "ci_url", + ] + for field in ci_fields: + if field not in enhanced_summary: + enhanced_summary[field] = f"unknown_{field}" + enhanced_summary[f"{field}_placeholder"] = True + + # Calculate status + total = enhanced_summary.get("total_tests", 0) + failed = enhanced_summary.get("failed_tests", 0) + if total == 0: + enhanced_summary["status"] = "无测试" + elif failed == 0: + enhanced_summary["status"] = "成功" + elif enhanced_summary.get("successful_tests", 0) > 0: + enhanced_summary["status"] = "部分成功" + else: + enhanced_summary["status"] = "失败" + + # Add failed tests details if not present + if "failed_tests_details" not in enhanced_summary: + enhanced_summary["failed_tests_details"] = [] + + # Add data source marker + enhanced_summary["_data_source"] = self.source_type + + enhanced.append(enhanced_summary) + + return enhanced + # Re-export from sibling modules from .data_sources import DataSource, FileDataSource diff --git a/dashboard/utils/data_sources.py b/dashboard/utils/data_sources.py index 7ce2269..c07dae8 100644 --- 
a/dashboard/utils/data_sources.py +++ b/dashboard/utils/data_sources.py @@ -10,7 +10,13 @@ import pandas as pd -from .data_utils import extract_accelerator_types, extract_run_info, load_summary_file +from .data_utils import ( + extract_accelerator_types, + extract_run_info, + load_summary_file, + normalize_ci_summary, + extract_failed_tests_details, +) # Add project root to path for db module access (works regardless of cwd) _project_root = Path(__file__).parent.parent.parent @@ -116,12 +122,84 @@ def load_test_result(self, json_path: Path) -> Dict[str, Any]: return data - def _is_test_result_file(self, data: Dict[str, Any]) -> bool: - """Check if JSON file is a test result (not a summary).""" - required = ["run_id", "testcase", "config"] - return all(key in data for key in required) and "metrics" in data - def load_summaries(self) -> List[Dict[str, Any]]: """Load dispatcher summary files from summary_output directory.""" summary_dir = self.results_dir.parent / "summary_output" return load_summary_file(str(summary_dir)) + + def load_ci_history(self, limit: int = 100) -> List[Dict[str, Any]]: + """ + Load the CI history from the summary file + """ + summaries = self.load_summaries() + enhanced_summaries = [] + + for summary in summaries[:limit]: + # Normalize CI metadata + summary = normalize_ci_summary(summary) + + # Extract detailed information of failed test cases + summary["failed_tests_details"] = extract_failed_tests_details(summary) + + # Add data source marker + summary["_data_source"] = "file" + summary["_summary_file"] = summary.get("file", "unknown") + + # Compute overall status + total = summary.get("total_tests", 0) + failed = summary.get("failed_tests", 0) + if total == 0: + summary["status"] = "无测试" + elif failed == 0: + summary["status"] = "成功" + elif summary.get("successful_tests", 0) > 0: + summary["status"] = "部分成功" + else: + summary["status"] = "失败" + + enhanced_summaries.append(summary) + + return enhanced_summaries + + def 
_is_test_result_file(self, data: Dict[str, Any]) -> bool: + """Check if JSON file is a test result.""" + required = ["run_id", "testcase", "config"] + return all(key in data for key in required) and "metrics" in data + + def _get_csv_base_dir(self, json_data: Dict[str, Any], json_path: Path) -> Path: + """Get the correct base directory for CSV files.""" + config = json_data.get("config", {}) + output_dir = config.get("output_dir") + + if output_dir: + output_path = Path(output_dir) + if output_path.is_absolute(): + return output_path + return json_path.parent / output_dir + + return json_path.parent + + def _resolve_csv_path(self, csv_url: str, base_dir: Path) -> Optional[Path]: + """Resolve CSV path from raw_data_url.""" + try: + if not csv_url: + return None + + if csv_url.startswith("./"): + csv_url = csv_url[2:] + + # Try a variety of possible paths + candidates = [ + base_dir / csv_url, + base_dir / Path(csv_url).name, + base_dir.parent / csv_url, + base_dir.parent / Path(csv_url).name, + ] + + for p in candidates: + if p.exists(): + return p + + return None + except Exception: + return None diff --git a/dashboard/utils/data_utils.py b/dashboard/utils/data_utils.py index 33e6e50..c749451 100644 --- a/dashboard/utils/data_utils.py +++ b/dashboard/utils/data_utils.py @@ -30,6 +30,102 @@ def load_summary_file(summary_path: str = "../summary_output") -> List[Dict[str, return summaries +def normalize_ci_summary(data: Dict[str, Any]) -> Dict[str, Any]: + """ + Normalize CI summary information, prioritizing Git metadata when available. 
+ """ + # Extract Git and CI environment information + git_info = data.get("git", {}) + ci_info = data.get("ci_environment", {}) + + # Prefer values from git_info if this is a valid Git repository + if git_info.get("is_git_repo"): + data["commit"] = git_info.get("commit", "unknown") + data["short_commit"] = git_info.get("short_commit", "unknown") + data["branch"] = git_info.get("branch", "unknown") + data["commit_message"] = git_info.get("commit_message", "unknown") + data["commit_author"] = git_info.get("commit_author", "unknown") + data["commit_date"] = git_info.get("commit_date", "unknown") + data["_has_real_git_info"] = True + else: + data["commit"] = git_info.get("commit", "not_in_git_repo") + data["short_commit"] = git_info.get("short_commit", "not_in_git_repo") + data["branch"] = git_info.get("branch", "not_in_git_repo") + data["commit_message"] = git_info.get("commit_message", "Not in Git repository") + data["_has_real_git_info"] = False + + # Attach CI environment metadata + if ci_info: + data["ci_provider"] = ci_info.get("ci_provider", "unknown") + data["ci_pipeline_id"] = ci_info.get( + "ci_pipeline_id", ci_info.get("ci_run_id", "") + ) + + # Compute duration if not already present + if "duration" not in data and "total_duration_seconds" in data: + data["duration"] = data["total_duration_seconds"] + + # Derive overall CI status + total = data.get("total_tests", 0) + failed = data.get("failed_tests", 0) + if total == 0: + data["status"] = "无测试" + elif failed == 0: + data["status"] = "✅ 成功" + elif data.get("successful_tests", 0) > 0: + data["status"] = "⚠️ 部分成功" + else: + data["status"] = "❌ 失败" + + return data + + +def extract_failed_tests_details(data: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Extract failed test details from CI summary data. + + Supports two formats: + 1. Directly using the `failed_tests_details` field + 2. 
Extracting failed cases from `test_results` or `tests` lists + """ + if "failed_tests_details" in data and data["failed_tests_details"]: + return data["failed_tests_details"] + + failed_details = [] + + # Attempt to extract from test_results list + if "test_results" in data: + for test in data["test_results"]: + if not test.get("success", True): + failed_details.append( + { + "test_name": test.get("name", test.get("testcase", "unknown")), + "error": test.get( + "error", test.get("error_msg", "Unknown error") + ), + "duration": test.get("duration", 0), + "logs": test.get("logs", test.get("output", "")), + } + ) + + # Attempt to extract from tests list + elif "tests" in data: + for test in data["tests"]: + if test.get("status") in ["failed", "error"]: + failed_details.append( + { + "test_name": test.get("name", test.get("testcase", "unknown")), + "error": test.get( + "error", test.get("message", "Unknown error") + ), + "duration": test.get("duration", 0), + "logs": test.get("logs", test.get("output", "")), + } + ) + + return failed_details + + def get_friendly_size(size_bytes: int) -> str: """Convert bytes to human-readable size.""" for unit in ["B", "KB", "MB", "GB", "TB"]: diff --git a/dashboard/utils/mongo_data_source.py b/dashboard/utils/mongo_data_source.py index 85daa28..526fb95 100644 --- a/dashboard/utils/mongo_data_source.py +++ b/dashboard/utils/mongo_data_source.py @@ -9,7 +9,12 @@ import pandas as pd from .data_sources import DataSource -from .data_utils import extract_accelerator_types, extract_run_info +from .data_utils import ( + extract_accelerator_types, + extract_run_info, + normalize_ci_summary, + extract_failed_tests_details, +) logger = logging.getLogger(__name__) @@ -33,7 +38,7 @@ def _connect(self): if str(project_root) not in sys.path: sys.path.insert(0, str(project_root)) - from db import MongoDBClient, TestRunRepository + from db import MongoDBClient, TestRunRepository, DispatcherSummaryRepository if self._config: self._client = 
MongoDBClient(self._config) @@ -47,6 +52,10 @@ def _connect(self): self._repository = TestRunRepository( self._client.get_collection(config.collection_name) ) + + self._summary_repo = DispatcherSummaryRepository( + self._client.get_collection(config.summary_collection_name) + ) self._connected = True logger.info("Connected to MongoDB data source") else: @@ -130,3 +139,106 @@ def load_summaries(self) -> List[Dict[str, Any]]: except Exception as e: logger.warning(f"Failed to load summaries from MongoDB: {e}") return [] + + def load_ci_history(self, limit: int = 100) -> List[Dict[str, Any]]: + """ + Load CI history with enhanced information from MongoDB. + """ + if not self._connect(): + logger.warning("MongoDB not connected, returning empty list") + return [] + + try: + # Retrieve CI summaries from the summary repository + summaries = self._summary_repo.list_summaries(limit=limit) + enhanced_summaries = [] + + for summary in summaries: + # Remove internal MongoDB fields + summary.pop("_id", None) + summary.pop("_metadata", None) + + # Normalize CI summary format + summary = normalize_ci_summary(summary) + + # Extract failed test details + if "failed_tests_details" not in summary: + # Try to load failure details from associated test results + failed_details = self._load_failed_tests_for_summary(summary) + summary["failed_tests_details"] = failed_details + + # Add data source marker + summary["_data_source"] = "mongodb" + + # Derive overall status + total = summary.get("total_tests", 0) + failed = summary.get("failed_tests", 0) + if total == 0: + summary["status"] = "无测试" + elif failed == 0: + summary["status"] = "成功" + elif summary.get("successful_tests", 0) > 0: + summary["status"] = "部分成功" + else: + summary["status"] = "失败" + + enhanced_summaries.append(summary) + + return enhanced_summaries + + except Exception as e: + logger.warning(f"Failed to load CI history from MongoDB: {e}") + return [] + + def _load_failed_tests_for_summary( + self, summary: Dict[str, Any] + 
def _aggregate_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate results from executors.

    Args:
        results: One dictionary per executed test. ``run_id``, ``testcase``,
            ``result_code`` and ``result_file`` are read with ``[]`` and must
            be present; ``skipped``, ``duration`` and ``error_msg`` are
            optional.

    Returns:
        A summary dictionary with the test counts, a per-test ``results``
        list, an ISO ``timestamp`` and ``total_duration_seconds``.
        ``failed_tests_details`` is added only when at least one test has a
        non-zero ``result_code``.

    Raises:
        KeyError: If a result is missing one of the required keys.
    """
    entries = []
    failed_details = []
    successful = 0
    total_duration = 0

    for r in results:
        # ``duration`` may be missing or explicitly None; count both as 0 so
        # the running total never fails on ``+= None``.
        duration = r.get("duration") or 0
        total_duration += duration

        entries.append(
            {
                "run_id": r["run_id"],
                "testcase": r["testcase"],
                "result_code": r["result_code"],
                "result_file": r["result_file"],
                "skipped": r.get("skipped", False),
                "duration": duration,
            }
        )

        if r["result_code"] == 0:
            successful += 1
            continue

        failed_details.append(
            {
                "testcase": r["testcase"],
                "run_id": r["run_id"],
                "result_code": r["result_code"],
                # The key can be present with a None value, in which case a
                # plain ``.get`` default would never apply.
                "error_msg": r.get("error_msg") or "Unknown error",
                "result_file": r.get("result_file"),
            }
        )

    total = len(results)
    aggregated = {
        "total_tests": total,
        "successful_tests": successful,
        "failed_tests": total - successful,
        "results": entries,
        "timestamp": datetime.now().isoformat(),
        "total_duration_seconds": total_duration,
    }

    if failed_details:
        aggregated["failed_tests_details"] = failed_details

    return aggregated
directory + summary_dir = Path("./summary_output") + summary_dir.mkdir(parents=True, exist_ok=True) + with open(summary_dir / filename, "w", encoding="utf-8") as f: - json.dump(aggregated, f, indent=2, ensure_ascii=False) + json.dump(enhanced_summary, f, indent=2, ensure_ascii=False) logger.info(f"Summary saved to {summary_dir / filename}") + if git_info.get("is_git_repo"): + logger.info( + f"Git info: commit={git_info.get('short_commit')}, branch={git_info.get('branch')}" + ) + else: + logger.info("Git info: not in a Git repository") diff --git a/infinimetrics/executor.py b/infinimetrics/executor.py index 3d44b28..fe95b12 100644 --- a/infinimetrics/executor.py +++ b/infinimetrics/executor.py @@ -5,18 +5,18 @@ import json import logging -import subprocess +import time from dataclasses import dataclass from datetime import datetime from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional + +import subprocess from infinimetrics.adapter import BaseAdapter -from infinimetrics.common.error_handler import ErrorHandler -from infinimetrics.common.hardware_info import collect_hardware_info -from infinimetrics.utils.path_utils import sanitize_filename from infinimetrics.common.constants import ErrorCode, TEST_CATEGORIES - +from infinimetrics.utils.hardware_detector import HardwareDetector +from infinimetrics.utils.path_utils import sanitize_filename logger = logging.getLogger(__name__) @@ -38,6 +38,8 @@ class TestResult: result_file: Optional[str] = None skipped: bool = False config: Optional[Dict[str, Any]] = None + duration: float = 0.0 + error_msg: Optional[str] = None def to_dict(self) -> Dict[str, Any]: """Convert to lightweight dictionary format for Dispatcher aggregation.""" @@ -48,6 +50,8 @@ def to_dict(self) -> Dict[str, Any]: "result_file": self.result_file, "skipped": self.skipped, "config": self.config, + "duration": self.duration, + "error_msg": self.error_msg, } @@ -75,9 +79,9 @@ def __init__(self, payload: 
Dict[str, Any], adapter: BaseAdapter): self.run_id = payload.get("run_id", "") self.test_input = None + # Setup output directory from config config = payload.get("config", {}) - output_dir = config.get("output_dir", "./output") - self.output_dir = Path(output_dir) + self.output_dir = Path(config.get("output_dir", "./output")) self.output_dir.mkdir(parents=True, exist_ok=True) logger.debug(f"Executor initialized: testcase={self.testcase}") @@ -145,101 +149,111 @@ def execute(self) -> TestResult: """ logger.info(f"Executor: Running {self.testcase}") - # Initialize TestResult directly (default: result_code=0) + start_time = time.time() config = self.payload.get("config", {}) test_result = TestResult( run_id=self.run_id, testcase=self.testcase, - result_code=0, # Default to success + result_code=0, result_file=None, config=config, + duration=0.0, + error_msg=None, ) - response: Dict[str, Any] = {} + response = {} try: - # Phase 1: Setup self.setup() - - # Phase 2: Process - logger.debug(f"Executor: Calling adapter.process()") response = self.adapter.process(self.test_input) - # Enrich environment ONLY if missing + # Enrich environment if missing if isinstance(response, dict) and "environment" not in response: - env = self._build_environment(response) - - # rebuild ordered dict (py3.7+ preserves insertion order) - ordered: Dict[str, Any] = {} - for k in [ - "run_id", - "time", - "testcase", - "success", - "environment", - "result_code", - "config", - "metrics", - ]: - if k == "environment": - ordered["environment"] = env - elif k in response: - ordered[k] = response[k] - - # append remaining keys in original order (skip those already set) - for k, v in response.items(): - if k not in ordered: - ordered[k] = v - - response = ordered - - # Phase 3: Teardown (cleanup, save result) + response = self._enrich_environment(response) + result_file = self.teardown(response) test_result.result_file = result_file + test_result.duration = time.time() - start_time logger.info( - 
def _enrich_environment(self, response: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``response`` with an ``environment`` entry inserted.

    The environment comes from ``self._build_environment(response)``. The
    returned dictionary lists a fixed set of leading keys first (skipping any
    the response does not have, except ``environment`` which is always
    written) and then every remaining response key in its original order.
    """
    leading_keys = (
        "run_id",
        "time",
        "testcase",
        "success",
        "environment",
        "result_code",
        "config",
        "metrics",
    )
    environment = self._build_environment(response)

    enriched = {}
    for name in leading_keys:
        if name == "environment":
            enriched[name] = environment
            continue
        if name in response:
            enriched[name] = response[name]

    # Everything not placed above keeps its relative order from the response.
    leftovers = {
        name: value for name, value in response.items() if name not in enriched
    }
    enriched.update(leftovers)
    return enriched
def _build_error_response(self, error_msg: str, result_code: int) -> Dict[str, Any]:
    """Build the error record that is written to disk for a failed test.

    Args:
        error_msg: Text describing the failure.
        result_code: Result code stored in the record.

    Returns:
        Dictionary with the run id, testcase, a ``%Y-%m-%d %H:%M:%S``
        timestamp, the result code and message, the config without keys that
        start with ``_``, and the device info taken from the full config.
    """
    raw_config = self.payload.get("config", {})

    # Keys starting with "_" are injected metadata and are left out.
    public_config = {}
    for name, value in raw_config.items():
        if name.startswith("_"):
            continue
        public_config[name] = value

    # Device info is extracted from the unfiltered config.
    device_info = self._extract_device_info(raw_config)

    record = {"run_id": self.run_id, "testcase": self.testcase}
    record["time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    record["result_code"] = result_code
    record["error_msg"] = error_msg
    record["success"] = 1  # 1 = failure
    record["config"] = public_config
    record["resolved"] = device_info
    return record
def get_git_info(project_root: Path = None) -> Dict[str, object]:
    """Collect metadata about the current Git checkout.

    Args:
        project_root: Directory in which the ``git`` commands are run. When
            omitted, the parents of this file are searched upwards for a
            ``.git`` entry and the first match is used; if none is found the
            commands run in the process working directory.

    Returns:
        A dictionary with ``is_git_repo`` (bool) and the string fields
        ``commit``, ``short_commit``, ``branch``, ``commit_message``,
        ``commit_author``, ``commit_date`` and ``commit_body``. The same keys
        are present whether or not a repository was found. A field whose
        command fails or prints nothing is ``"unknown"``.
    """
    if project_root is None:
        # Look up the project root starting from the current file
        current = Path(__file__).parent
        while current != current.parent:
            if (current / ".git").exists():
                project_root = current
                break
            current = current.parent

    def run_git(*args):
        """Run ``git <args>`` and return its stripped output or "unknown"."""
        try:
            # An argument list (no shell) avoids quoting issues and does not
            # depend on a shell being available.
            raw = subprocess.check_output(
                ["git", *args],
                cwd=project_root,
                stderr=subprocess.DEVNULL,
                timeout=5,
            )
        except (OSError, subprocess.SubprocessError):
            # git missing, cwd missing, non-zero exit or timeout
            return "unknown"
        # Replace undecodable bytes instead of discarding the whole value.
        text = raw.decode("utf-8", errors="replace").strip()
        return text or "unknown"

    # Check whether we are inside a Git repository
    if run_git("rev-parse", "--git-dir") == "unknown":
        return {
            "is_git_repo": False,
            "commit": "not_in_git_repo",
            "short_commit": "not_in_git_repo",
            "branch": "not_in_git_repo",
            "commit_message": "Not in Git repository",
            "commit_author": "unknown",
            "commit_date": "unknown",
            "commit_body": "unknown",
        }

    return {
        "is_git_repo": True,
        "commit": run_git("rev-parse", "HEAD"),
        "short_commit": run_git("rev-parse", "--short", "HEAD"),
        "branch": run_git("rev-parse", "--abbrev-ref", "HEAD"),
        "commit_message": run_git("log", "-1", "--pretty=%s"),
        "commit_author": run_git("log", "-1", "--pretty=%an"),
        "commit_date": run_git("log", "-1", "--pretty=%ci"),
        "commit_body": run_git("log", "-1", "--pretty=%b"),
    }
def _which(cmd: str) -> Optional[str]:
    """Return the path of ``cmd`` if it is found on PATH, otherwise ``None``."""
    try:
        from shutil import which

        return which(cmd)
    except Exception:
        return None


class HardwareDetector:
    """Best-effort detection of CPU, memory and accelerator information.

    Every probe swallows its own errors and reports failure by returning
    ``False`` (or by leaving the defaults in place), so ``detect`` never
    raises because a vendor tool is missing or misbehaves.
    """

    NVIDIA_SMI_QUERY = [
        "nvidia-smi",
        "--query-gpu=name,memory.total,driver_version",
        "--format=csv,noheader",
    ]
    AMD_SMI_CANDIDATES = ["amd-smi", "rocm-smi"]
    # Default probing order; "generic" is the fallback that always matches.
    _PROBE_ORDER = ("nvidia", "amd", "ascend", "cambricon", "generic")

    @classmethod
    def detect(cls, accel_type_hint: str = "") -> Dict[str, Any]:
        """Detect hardware information.

        Args:
            accel_type_hint: Accelerator type to probe first
                (nvidia/amd/ascend/cambricon/generic). Unknown, empty or
                ``None`` hints fall back to the default order.

        Returns:
            Dictionary with the keys produced by ``_init_hardware_dict``.
            ``accelerator_type`` is the first probe that succeeded, or
            ``"generic"`` when none did.
        """
        hw = cls._init_hardware_dict()
        cls._detect_cpu(hw)
        cls._detect_memory(hw)

        vendor_probes = {
            "nvidia": cls._probe_nvidia,
            "amd": cls._probe_amd,
            "ascend": cls._probe_ascend,
            "cambricon": cls._probe_cambricon,
        }

        for probe in cls._get_probe_order(accel_type_hint):
            if probe == "generic":
                hw["accelerator_type"] = "generic"
                return hw
            if vendor_probes[probe](hw):
                hw["accelerator_type"] = probe
                if probe == "nvidia":
                    # Keep the default when nvcc is unavailable.
                    hw["cuda_version"] = cls._get_cuda_version() or hw["cuda_version"]
                return hw

        return hw

    @classmethod
    def _init_hardware_dict(cls) -> Dict[str, Any]:
        """Return the result dictionary pre-filled with "unknown" defaults."""
        return {
            "cpu_model": "Unknown",
            "memory_gb": 0,
            "gpu_model": "Unknown",
            "gpu_count": 0,
            "gpu_memory_gb": 0,
            "driver_version": "Unknown",
            "cuda_version": "Unknown",
            "accelerator_type": "generic",
        }

    @classmethod
    def _detect_cpu(cls, hw: Dict[str, Any]) -> None:
        """Fill ``cpu_model`` from the first "model name" line of /proc/cpuinfo."""
        try:
            with open("/proc/cpuinfo", "r") as f:
                for line in f:
                    if "model name" in line:
                        hw["cpu_model"] = line.split(":", 1)[1].strip()
                        break
        except (OSError, ValueError, IndexError):
            # File missing/unreadable or line malformed: keep the default.
            pass

    @classmethod
    def _detect_memory(cls, hw: Dict[str, Any]) -> None:
        """Fill ``memory_gb`` (whole GiB, rounded down) from /proc/meminfo."""
        try:
            with open("/proc/meminfo", "r") as f:
                for line in f:
                    if "MemTotal" in line:
                        mem_kb = int(line.split()[1])
                        hw["memory_gb"] = mem_kb // (1024 * 1024)
                        break
        except (OSError, ValueError, IndexError):
            # File missing/unreadable or line malformed: keep the default.
            pass

    @classmethod
    def _get_probe_order(cls, hint: str) -> List[str]:
        """Return the probe names with a valid ``hint`` moved to the front.

        ``hint`` is matched case-insensitively after stripping whitespace.
        ``None``, empty or unrecognised hints yield the default order.
        """
        hint = str(hint or "").lower().strip()
        if hint not in cls._PROBE_ORDER:
            return list(cls._PROBE_ORDER)
        return [hint] + [p for p in cls._PROBE_ORDER if p != hint]

    @classmethod
    def _probe_nvidia(cls, hw: Dict[str, Any]) -> bool:
        """Query nvidia-smi; fill GPU count, model, driver and memory."""
        try:
            r = subprocess.run(
                cls.NVIDIA_SMI_QUERY, capture_output=True, text=True, timeout=5
            )
            if r.returncode != 0 or not r.stdout.strip():
                return False

            # One CSV row per GPU: name, memory.total, driver_version
            lines = [x.strip() for x in r.stdout.strip().splitlines() if x.strip()]
            hw["gpu_count"] = len(lines)

            # Model, driver and memory are taken from the first GPU only.
            p = [x.strip() for x in lines[0].split(",")]
            if len(p) >= 3:
                hw["gpu_model"] = p[0]
                hw["driver_version"] = p[2]
                mm = re.search(r"(\d+)\s*MiB", p[1])
                if mm:
                    hw["gpu_memory_gb"] = int(mm.group(1)) // 1024
            return True
        except Exception:
            # Deliberate best effort: tool missing, timeout, bad output.
            return False

    @classmethod
    def _probe_amd(cls, hw: Dict[str, Any]) -> bool:
        """Query the first available AMD tool and count lines mentioning GPU."""
        try:
            tool = next((c for c in cls.AMD_SMI_CANDIDATES if _which(c)), None)
            if not tool:
                return False

            cmd = ["amd-smi", "list"] if tool == "amd-smi" else ["rocm-smi", "-i"]
            r = subprocess.run(cmd, capture_output=True, text=True, timeout=5)
            if r.returncode != 0 or not r.stdout.strip():
                return False

            count = sum(
                1
                for line in r.stdout.splitlines()
                if re.search(r"\bGPU\b", line, re.IGNORECASE)
            )
            if count:
                hw["gpu_count"] = max(hw["gpu_count"], count)
            if hw["gpu_model"] == "Unknown":
                hw["gpu_model"] = "AMD GPU"
            return True
        except Exception:
            # Deliberate best effort: tool missing, timeout, bad output.
            return False

    @classmethod
    def _probe_vendor_tool(
        cls,
        hw: Dict[str, Any],
        cmd: List[str],
        line_pattern: str,
        default_model: str,
    ) -> bool:
        """Run a vendor info command and count output lines matching a pattern.

        Returns ``False`` without touching ``hw`` when ``cmd[0]`` is not on
        PATH, the command fails, or it prints nothing. Otherwise raises
        ``gpu_count`` to the number of matching lines (if any), sets
        ``gpu_model`` to ``default_model`` when still unknown, and returns
        ``True``.
        """
        try:
            if not _which(cmd[0]):
                return False
            r = subprocess.run(cmd, capture_output=True, text=True, timeout=5)
            if r.returncode != 0 or not r.stdout.strip():
                return False

            count = sum(
                1 for line in r.stdout.splitlines() if re.search(line_pattern, line)
            )
            if count:
                hw["gpu_count"] = max(hw["gpu_count"], count)
            if hw["gpu_model"] == "Unknown":
                hw["gpu_model"] = default_model
            return True
        except Exception:
            # Deliberate best effort: timeout, decode error, bad output.
            return False

    @classmethod
    def _probe_ascend(cls, hw: Dict[str, Any]) -> bool:
        """Probe for Ascend devices through ``npu-smi info``."""
        return cls._probe_vendor_tool(
            hw, ["npu-smi", "info"], r"\bNPU\b|\bDevice\b", "Ascend NPU"
        )

    @classmethod
    def _probe_cambricon(cls, hw: Dict[str, Any]) -> bool:
        """Probe for Cambricon devices through ``cnmon info``."""
        return cls._probe_vendor_tool(
            hw, ["cnmon", "info"], r"\bMLU\b|\bDevice\b", "Cambricon MLU"
        )

    @classmethod
    def _get_cuda_version(cls) -> Optional[str]:
        """Return the "major.minor" release printed by ``nvcc --version``, if any."""
        try:
            r = subprocess.run(
                ["nvcc", "--version"], capture_output=True, text=True, timeout=2
            )
            if r.returncode == 0:
                for line in r.stdout.splitlines():
                    if "release" in line:
                        m = re.search(r"release\s+(\d+\.\d+)", line)
                        if m:
                            return m.group(1)
        except Exception:
            # nvcc missing or timed out: report no version.
            pass
        return None