Skip to content

Commit 971f76a

Browse files
authored
Merge pull request #1154 from ZhouBencheng/issue/1153
Issue/1153 add fused FFN operator and hardware-task mutual awareness analyzer
2 parents 4acc528 + 38d3dc3 commit 971f76a

42 files changed

Lines changed: 4735 additions & 8 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

include/infinicore/analyzer.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#pragma once
2+
3+
// Convenience header — includes all analyzer components.
4+
5+
#include "analyzer/op_type.hpp"
6+
#include "analyzer/op_trace.hpp"
7+
#include "analyzer/optimization_intent.hpp"
8+
#include "analyzer/phase_detector.hpp"
9+
#include "analyzer/resource_sensor.hpp"
10+
#include "analyzer/intent_generator.hpp"
11+
#include "analyzer/mutual_awareness_analyzer.hpp"
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
#pragma once
2+
3+
#include "optimization_intent.hpp"
4+
#include "op_trace.hpp"
5+
6+
#include <algorithm>
7+
#include <vector>
8+
9+
namespace infinicore::analyzer {
10+
11+
/// IntentGenerator — the core "mutual awareness" logic.
12+
///
13+
/// This is where task demand and resource supply are jointly
14+
/// analyzed to produce an OptimizationIntent. It implements
15+
/// the key insight: the same task phase has different optimization
16+
/// needs under different resource conditions, and the same resource
17+
/// state has different supply value under different task phases.
18+
class IntentGenerator {
19+
public:
20+
IntentGenerator() = default;
21+
22+
/// Generate the global semantic intent from phase detection
23+
/// result and op trace window.
24+
GlobalSemanticIntent generateGlobal(
25+
PhaseType phase,
26+
const std::vector<OpTraceEntry> &window,
27+
const std::vector<DeviceLocalIntent> &device_intents) const {
28+
29+
GlobalSemanticIntent intent;
30+
intent.current_phase = phase;
31+
intent.timestamp_ns = OpTraceEntry::now();
32+
33+
if (!window.empty()) {
34+
intent.op_window_start = 0;
35+
intent.op_window_end = static_cast<uint32_t>(window.size());
36+
}
37+
38+
// --- Compute intensity estimation ---
39+
intent.compute_intensity = estimateComputeIntensity(phase, window);
40+
41+
// --- Determine primary bottleneck (mutual awareness) ---
42+
intent.primary_bottleneck = determineGlobalBottleneck(phase, device_intents);
43+
44+
// --- Set optimization goal based on phase + bottleneck ---
45+
intent.goal = determineGoal(phase, intent.primary_bottleneck);
46+
47+
// --- Generate strategy hints ---
48+
intent.strategy = generateStrategy(phase, intent.primary_bottleneck, device_intents);
49+
50+
// --- Confidence ---
51+
intent.confidence = computeConfidence(phase, window);
52+
53+
return intent;
54+
}
55+
56+
/// Build the complete two-layer OptimizationIntent.
57+
OptimizationIntent generate(
58+
PhaseType phase,
59+
const std::vector<OpTraceEntry> &window,
60+
const std::vector<DeviceLocalIntent> &device_intents) const {
61+
62+
OptimizationIntent result;
63+
result.global = generateGlobal(phase, window, device_intents);
64+
result.per_device = device_intents;
65+
return result;
66+
}
67+
68+
private:
69+
/// Estimate compute intensity (higher = more compute-heavy).
70+
/// Uses a simple heuristic based on op type composition.
71+
float estimateComputeIntensity(
72+
PhaseType phase,
73+
const std::vector<OpTraceEntry> &window) const {
74+
75+
if (window.empty()) return 0.0f;
76+
77+
size_t heavy_compute_ops = 0;
78+
for (auto &e : window) {
79+
if (isGemmMlpOp(e.op_type) || isAttentionOp(e.op_type)) {
80+
heavy_compute_ops++;
81+
}
82+
}
83+
return static_cast<float>(heavy_compute_ops) / static_cast<float>(window.size());
84+
}
85+
86+
/// Determine global bottleneck by jointly considering phase and
87+
/// per-device resource state (the core mutual awareness logic).
88+
BottleneckType determineGlobalBottleneck(
89+
PhaseType phase,
90+
const std::vector<DeviceLocalIntent> &device_intents) const {
91+
92+
bool any_memory_bound = false;
93+
bool any_compute_bound = false;
94+
bool any_bandwidth_bound = false;
95+
bool any_communication_bound = false;
96+
for (auto &d : device_intents) {
97+
any_memory_bound = any_memory_bound || d.local_bottleneck == BottleneckType::MEMORY_BOUND;
98+
any_compute_bound = any_compute_bound || d.local_bottleneck == BottleneckType::COMPUTE_BOUND;
99+
any_bandwidth_bound = any_bandwidth_bound || d.local_bottleneck == BottleneckType::BANDWIDTH_BOUND;
100+
any_communication_bound = any_communication_bound || d.local_bottleneck == BottleneckType::COMMUNICATION_BOUND;
101+
}
102+
103+
// --- Mutual awareness logic ---
104+
// The same resource state has different "supply value" depending on phase:
105+
106+
if (any_memory_bound) {
107+
return BottleneckType::MEMORY_BOUND;
108+
}
109+
110+
if (phase == PhaseType::COMMUNICATION || any_communication_bound) {
111+
return BottleneckType::COMMUNICATION_BOUND;
112+
}
113+
114+
switch (phase) {
115+
case PhaseType::ATTENTION_DENSE:
116+
case PhaseType::PREFILL:
117+
// Attention/prefill is dominated by memory movement and KV access,
118+
// so phase semantics should win unless memory/communication already
119+
// forced an earlier return above.
120+
if (any_bandwidth_bound) {
121+
return BottleneckType::BANDWIDTH_BOUND;
122+
}
123+
return BottleneckType::BANDWIDTH_BOUND;
124+
125+
case PhaseType::GEMM_MLP_DENSE:
126+
if (any_compute_bound) {
127+
return BottleneckType::COMPUTE_BOUND;
128+
}
129+
if (any_bandwidth_bound) {
130+
return BottleneckType::BANDWIDTH_BOUND;
131+
}
132+
return BottleneckType::COMPUTE_BOUND;
133+
134+
case PhaseType::DECODE:
135+
if (any_bandwidth_bound) {
136+
return BottleneckType::BANDWIDTH_BOUND;
137+
}
138+
if (any_compute_bound) {
139+
return BottleneckType::COMPUTE_BOUND;
140+
}
141+
return BottleneckType::BANDWIDTH_BOUND;
142+
143+
case PhaseType::KV_CACHE:
144+
if (any_bandwidth_bound) {
145+
return BottleneckType::BANDWIDTH_BOUND;
146+
}
147+
return BottleneckType::MEMORY_BOUND;
148+
149+
default:
150+
if (any_bandwidth_bound) {
151+
return BottleneckType::BANDWIDTH_BOUND;
152+
}
153+
if (any_compute_bound) {
154+
return BottleneckType::COMPUTE_BOUND;
155+
}
156+
return BottleneckType::BALANCED;
157+
}
158+
}
159+
160+
/// Determine optimization goal based on phase and bottleneck.
161+
OptimizationGoal determineGoal(
162+
PhaseType phase,
163+
BottleneckType bottleneck) const {
164+
165+
// Under memory pressure, prioritize memory safety
166+
if (bottleneck == BottleneckType::MEMORY_BOUND) {
167+
return OptimizationGoal::MEMORY_SAFE;
168+
}
169+
170+
if (bottleneck == BottleneckType::COMMUNICATION_BOUND) {
171+
return OptimizationGoal::STABILITY_FIRST;
172+
}
173+
174+
switch (phase) {
175+
case PhaseType::DECODE:
176+
// Decode latency is user-facing → latency first
177+
return OptimizationGoal::LATENCY_FIRST;
178+
179+
case PhaseType::PREFILL:
180+
// Prefill processes a full prompt → throughput first
181+
return OptimizationGoal::THROUGHPUT_FIRST;
182+
183+
case PhaseType::ATTENTION_DENSE:
184+
return OptimizationGoal::LATENCY_FIRST;
185+
186+
case PhaseType::GEMM_MLP_DENSE:
187+
return OptimizationGoal::THROUGHPUT_FIRST;
188+
189+
default:
190+
return OptimizationGoal::LATENCY_FIRST;
191+
}
192+
}
193+
194+
/// Generate strategy hints from phase + bottleneck + resources.
195+
StrategyHint generateStrategy(
196+
PhaseType phase,
197+
BottleneckType bottleneck,
198+
const std::vector<DeviceLocalIntent> &device_intents) const {
199+
200+
StrategyHint hint;
201+
202+
// Fusion is beneficial for bandwidth-bound phases (reduce memory traffic)
203+
hint.prefer_fused_ops = (bottleneck == BottleneckType::BANDWIDTH_BOUND)
204+
|| phase == PhaseType::DECODE;
205+
206+
// In-place when memory is tight
207+
hint.prefer_in_place = (bottleneck == BottleneckType::MEMORY_BOUND);
208+
209+
// Recomputation (activation checkpointing) when memory is critical
210+
bool extreme_memory = false;
211+
for (auto &d : device_intents) {
212+
if (d.memory_usage_ratio >= 0.95f) {
213+
extreme_memory = true;
214+
break;
215+
}
216+
}
217+
hint.prefer_recomputation = extreme_memory;
218+
219+
// Async comm overlap for multi-device and communication phases
220+
hint.prefer_async_comm = (device_intents.size() > 1)
221+
&& (phase == PhaseType::GEMM_MLP_DENSE
222+
|| phase == PhaseType::COMMUNICATION);
223+
224+
return hint;
225+
}
226+
227+
/// Compute confidence based on how clear the phase signal is.
228+
float computeConfidence(
229+
PhaseType phase,
230+
const std::vector<OpTraceEntry> &window) const {
231+
232+
if (window.empty() || phase == PhaseType::UNKNOWN) {
233+
return 0.0f;
234+
}
235+
236+
// Count how many ops in the window match the detected phase
237+
size_t matching = 0;
238+
for (auto &e : window) {
239+
bool match = false;
240+
switch (phase) {
241+
case PhaseType::ATTENTION_DENSE:
242+
case PhaseType::PREFILL:
243+
match = isAttentionOp(e.op_type);
244+
break;
245+
case PhaseType::GEMM_MLP_DENSE:
246+
match = isGemmMlpOp(e.op_type) || isActivationOp(e.op_type);
247+
break;
248+
case PhaseType::KV_CACHE:
249+
match = isKvCacheOp(e.op_type);
250+
break;
251+
case PhaseType::DECODE:
252+
match = isAttentionOp(e.op_type) || isGemmMlpOp(e.op_type);
253+
break;
254+
default:
255+
break;
256+
}
257+
if (match) matching++;
258+
}
259+
260+
return static_cast<float>(matching) / static_cast<float>(window.size());
261+
}
262+
};
263+
264+
} // namespace infinicore::analyzer
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#pragma once

#include "intent_generator.hpp"
#include "op_trace.hpp"
#include "optimization_intent.hpp"
#include "phase_detector.hpp"
#include "resource_sensor.hpp"

#include <atomic>
#include <mutex>
#include <vector>
11+
12+
namespace infinicore::analyzer {
13+
14+
/// MutualAwarenessAnalyzer — the top-level facade for the
15+
/// hardware-task mutual awareness requirements analysis module.
16+
///
17+
/// This is the primary entry point exposed to external frameworks
18+
/// (e.g., InfiniLM) via C++ function calls. It orchestrates:
19+
/// 1. Op trace collection (via OpTraceRing)
20+
/// 2. Phase detection (via PhaseDetector)
21+
/// 3. Resource sensing (via ResourceSensor)
22+
/// 4. Intent generation (via IntentGenerator)
23+
///
24+
/// Usage:
25+
/// auto& analyzer = MutualAwarenessAnalyzer::instance();
26+
/// // ... ops execute and get traced automatically ...
27+
/// auto intent = analyzer.analyze(); // Produces OptimizationIntent
28+
///
29+
/// Thread safety: analyze() is safe to call from any thread.
30+
/// The analyzer reads a snapshot of the op trace ring.
31+
class MutualAwarenessAnalyzer {
32+
public:
33+
/// Get the singleton instance.
34+
static MutualAwarenessAnalyzer &instance();
35+
36+
// Non-copyable, non-movable
37+
MutualAwarenessAnalyzer(const MutualAwarenessAnalyzer &) = delete;
38+
MutualAwarenessAnalyzer &operator=(const MutualAwarenessAnalyzer &) = delete;
39+
40+
/// Main analysis entry point.
41+
/// Analyzes the current op trace window + resource state
42+
/// and returns a complete OptimizationIntent.
43+
///
44+
/// This is the function InfiniLM should call.
45+
/// Latency: expected < 1ms for MVP rule-based analysis.
46+
OptimizationIntent analyze();
47+
48+
/// Analyze with explicitly provided memory stats per device.
49+
/// Use this when the caller can provide resource info directly.
50+
OptimizationIntent analyze(const std::vector<std::pair<int, MemoryStats>> &device_stats);
51+
52+
/// Analyze with explicitly provided device resource snapshots.
53+
/// This is the richer input path used by demand-analysis-oriented callers.
54+
OptimizationIntent analyze(const std::vector<DeviceResourceSnapshot> &device_snapshots);
55+
56+
/// Get the current phase without generating full intent.
57+
/// Lightweight query for simple use cases.
58+
PhaseType getCurrentPhase() const;
59+
60+
/// Get the current optimization goal derived from the
61+
/// latest analyzer result.
62+
OptimizationGoal getCurrentOptimizationGoal() const;
63+
64+
/// Get the most recent OptimizationIntent (cached from last analyze()).
65+
const OptimizationIntent &lastIntent() const;
66+
67+
/// Access the underlying components for configuration.
68+
PhaseDetector &phaseDetector() { return phase_detector_; }
69+
ResourceSensor &resourceSensor() { return resource_sensor_; }
70+
OpTraceRing &opTrace() { return getGlobalOpTrace(); }
71+
72+
/// Enable / disable the analyzer.
73+
/// When disabled, analyze() returns a default intent and
74+
/// op trace recording is skipped.
75+
void setEnabled(bool enabled) { enabled_ = enabled; }
76+
bool isEnabled() const { return enabled_; }
77+
78+
/// Graph recording support: when graph recording stops,
79+
/// analyze the recorded op sequence once and cache the result.
80+
/// Subsequent calls return the cached intent without re-analysis.
81+
void onGraphRecordingStop();
82+
void clearGraphCache();
83+
84+
private:
85+
MutualAwarenessAnalyzer();
86+
87+
PhaseDetector phase_detector_;
88+
ResourceSensor resource_sensor_;
89+
IntentGenerator intent_generator_;
90+
91+
OptimizationIntent last_intent_;
92+
mutable std::mutex mutex_;
93+
94+
bool enabled_ = true;
95+
96+
// Graph recording cache
97+
bool graph_intent_cached_ = false;
98+
OptimizationIntent graph_cached_intent_;
99+
};
100+
101+
// ============================================================
// C-style API for external framework integration (e.g., InfiniLM)
// ============================================================
// Free-function wrappers; definitions are not visible in this header —
// presumably they forward to MutualAwarenessAnalyzer::instance()
// (TODO confirm against the corresponding .cpp).

/// Analyze current state and return an OptimizationIntent.
/// This is the simplest API for external frameworks to call.
OptimizationIntent analyzeCurrentState();

/// Get the current detected phase.
PhaseType getCurrentPhase();

/// Get the current optimization goal.
OptimizationGoal getCurrentOptimizationGoal();

/// Enable / disable the mutual awareness analyzer.
void setAnalyzerEnabled(bool enabled);
117+
118+
} // namespace infinicore::analyzer

0 commit comments

Comments
 (0)