Skip to content

Commit f0b088f

Browse files
committed
Let GPU samples be cached and fix bogus frame 0 references
1 parent 0649863 commit f0b088f

5 files changed

Lines changed: 50 additions & 37 deletions

File tree

interpreter/gpu/cuda.go

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,12 @@ var (
3434
)
3535

3636
// SymbolizedCudaTrace holds a symbolized trace awaiting GPU timing information.
37-
// The CPU frames are already symbolized; only the CUDA kernel frame (frame[0])
37+
// The CPU frames are already symbolized; only the CUDA kernel frame
3838
// needs the kernel name from the timing event.
3939
type SymbolizedCudaTrace struct {
4040
Trace *libpf.Trace
4141
Meta *samples.TraceEventMeta
42+
CUDAFrameIdx int // index of CUDAKernelFrame in Trace.Frames
4243
CorrelationID uint32
4344
CBID int32
4445
}
@@ -60,9 +61,9 @@ type CudaTraceOutput struct {
6061
// that launched the kernel."
6162
type gpuTraceFixer struct {
6263
mu sync.Mutex
63-
timesAwaitingTraces map[uint32][]CuptiTimingEvent // keyed by correlation ID
64-
tracesAwaitingTimes map[uint32]*SymbolizedCudaTrace // keyed by correlation ID
65-
maxCorrelationId uint32 // track highest ID for threshold-based clearing
64+
timesAwaitingTraces map[uint32][]CuptiTimingEvent // keyed by correlation ID
65+
tracesAwaitingTimes map[uint32]*SymbolizedCudaTrace // keyed by correlation ID
66+
maxCorrelationId uint32 // track highest ID for threshold-based clearing
6667
}
6768

6869
type data struct {
@@ -343,7 +344,7 @@ func (f *gpuTraceFixer) prepTrace(st *SymbolizedCudaTrace, ev *CuptiTimingEvent)
343344
out.Trace.CustomLabels["cuda_id"] = strconv.FormatUint(uint64(ev.Id), 10)
344345
}
345346

346-
// Extract kernel name from timing event, demangle, and update frame[0]
347+
// Extract kernel name from timing event, demangle, and update the CUDA frame.
347348
nameBytes := ev.KernelName[:]
348349
if idx := bytes.IndexByte(nameBytes, 0); idx >= 0 {
349350
nameBytes = nameBytes[:idx]
@@ -356,11 +357,10 @@ func (f *gpuTraceFixer) prepTrace(st *SymbolizedCudaTrace, ev *CuptiTimingEvent)
356357
funcName = libpf.Intern(demStr)
357358
}
358359

359-
currentFrame := out.Trace.Frames[0].Value()
360-
out.Trace.Frames[0] = unique.Make(libpf.Frame{
361-
Type: currentFrame.Type,
362-
AddressOrLineno: currentFrame.AddressOrLineno,
363-
FunctionName: funcName,
360+
fi := st.CUDAFrameIdx
361+
out.Trace.Frames[fi] = unique.Make(libpf.Frame{
362+
Type: out.Trace.Frames[fi].Value().Type,
363+
FunctionName: funcName,
364364
})
365365
}
366366

@@ -375,7 +375,7 @@ func AddTrace(st *SymbolizedCudaTrace) []CudaTraceOutput {
375375
pid := st.Meta.PID
376376
value, ok := gpuFixers.Load(pid)
377377
if !ok {
378-
log.Warnf("no GPU fixer found for PID %d", pid)
378+
log.Warnf("no GPU fixer found for PID %d in AddTrace", pid)
379379
return nil
380380
}
381381
fixer := value.(*gpuTraceFixer)
@@ -387,7 +387,7 @@ func addTimeSingle(ev *CuptiTimingEvent) (CudaTraceOutput, bool) {
387387
pid := libpf.PID(ev.Pid)
388388
value, ok := gpuFixers.Load(pid)
389389
if !ok {
390-
log.Warnf("no GPU fixer found for PID %d", pid)
390+
log.Warnf("no GPU fixer found for PID %d in AddTime", pid)
391391
return CudaTraceOutput{}, false
392392
}
393393
fixer := value.(*gpuTraceFixer)
@@ -409,6 +409,7 @@ func AddTimes(events []CuptiTimingEvent) []CudaTraceOutput {
409409
pid := libpf.PID(events[0].Pid)
410410
value, ok := gpuFixers.Load(pid)
411411
if !ok {
412+
log.Warnf("no GPU fixer found for PID %d in AddTimes", pid)
412413
return nil
413414
}
414415
fixer := value.(*gpuTraceFixer)
@@ -448,6 +449,7 @@ func MaybeClearAll() []metrics.Metric {
448449
totalTraces += stats.tracesLen
449450
totalTimesCleared += stats.timesCleared
450451
totalTracesCleared += stats.tracesCleared
452+
451453
return true
452454
})
453455

parcagpu/parcagpu.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ func Start(ctx context.Context, tr *tracer.Tracer,
7373
// avoiding duplicate-metric warnings from the metrics system.
7474
metrics.AddSlice(gpu.MaybeClearAll())
7575
case <-ctx.Done():
76+
eventReader.Close()
7677
return
7778
default:
7879
if err := eventReader.ReadInto(&data); err != nil {
@@ -105,7 +106,19 @@ func Start(ctx context.Context, tr *tracer.Tracer,
105106
return false
106107
}
107108

108-
// Find the CUDA kernel frame in the symbolized trace
109+
// Extract correlation ID and CBID from the raw BPF trace (not the
110+
// symbolized trace, which may be a cached template with stale values).
111+
var correlationID uint32
112+
var cbid int32
113+
for i := range rawTrace.Frames {
114+
if rawTrace.Frames[i].Type == libpf.CUDAKernelFrame {
115+
correlationID = uint32(rawTrace.Frames[i].Lineno)
116+
cbid = int32(rawTrace.Frames[i].Lineno >> 32)
117+
break
118+
}
119+
}
120+
121+
// Find the CUDA kernel frame index in the symbolized trace.
109122
cudaFrameIdx := -1
110123
for i, uniqueFrame := range trace.Frames {
111124
if uniqueFrame.Value().Type == libpf.CUDAKernelFrame {
@@ -118,13 +131,10 @@ func Start(ctx context.Context, tr *tracer.Tracer,
118131
return false
119132
}
120133

121-
frame := trace.Frames[cudaFrameIdx].Value()
122-
correlationID := uint32(frame.AddressOrLineno)
123-
cbid := int32(frame.AddressOrLineno >> 32)
124-
125134
st := &gpu.SymbolizedCudaTrace{
126135
Trace: trace,
127136
Meta: meta,
137+
CUDAFrameIdx: cudaFrameIdx,
128138
CorrelationID: correlationID,
129139
CBID: cbid,
130140
}

tracehandler/tracehandler.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,10 @@ func (m *traceHandler) HandleTrace(bpfTrace *host.Trace) {
146146
if trace, exists := m.traceCache.GetAndRefresh(bpfTrace.Hash,
147147
traceCacheLifetime); exists {
148148
m.traceCacheHit++
149-
// Fast path
149+
// Fast path: the interceptor receives the cached trace template together with the fresh raw BPF trace.
150+
if m.interceptor != nil && m.interceptor(&trace, meta, bpfTrace) {
151+
return
152+
}
150153
meta.APMServiceName = m.traceProcessor.MaybeNotifyAPMAgent(bpfTrace, trace.Hash, 1)
151154
if err := m.reporter.ReportTraceEvent(&trace, meta); err != nil {
152155
log.Errorf("Failed to report trace event: %v", err)
@@ -163,14 +166,13 @@ func (m *traceHandler) HandleTrace(bpfTrace *host.Trace) {
163166
}
164167
log.Debugf("Trace hash remap 0x%x -> 0x%x", bpfTrace.Hash, umTrace.Hash)
165168

166-
// If an interceptor is set and consumes the trace, skip caching and reporting.
167-
// CUDA traces are always intercepted here and never cached.
169+
m.traceCache.Add(bpfTrace.Hash, *umTrace)
170+
171+
// If an interceptor consumes the trace, skip reporting.
168172
if m.interceptor != nil && m.interceptor(umTrace, meta, bpfTrace) {
169173
return
170174
}
171175

172-
m.traceCache.Add(bpfTrace.Hash, *umTrace)
173-
174176
meta.APMServiceName = m.traceProcessor.MaybeNotifyAPMAgent(bpfTrace, umTrace.Hash, 1)
175177
if err := m.reporter.ReportTraceEvent(umTrace, meta); err != nil {
176178
log.Errorf("Failed to report trace event: %v", err)

tracehandler/tracehandler_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,15 @@ func TestTraceInterceptor(t *testing.T) {
110110
libpf.NewTraceHash(2, 2): 1,
111111
},
112112
},
113-
"consumed trace not cached on repeat": {
113+
"consumed on both cache miss and hit": {
114114
interceptReturn: map[host.TraceHash]bool{
115115
host.TraceHash(0xcc): true,
116116
},
117117
input: []arguments{
118118
{trace: &host.Trace{Hash: host.TraceHash(0xcc)}},
119119
{trace: &host.Trace{Hash: host.TraceHash(0xcc)}},
120120
},
121-
expectedEvents: nil, // both consumed, nothing reported
121+
expectedEvents: nil, // first miss + second hit, both consumed
122122
},
123123
}
124124

tracer/tracer.go

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -934,19 +934,6 @@ func (t *Tracer) loadBpfTrace(raw []byte, cpu int) *host.Trace {
934934
}
935935
}
936936

937-
// Trace fields included in the hash:
938-
// - PID, kernel stack ID, length & frame array
939-
// Intentionally excluded:
940-
// - ktime, COMM, APM trace, APM transaction ID, Origin and Off Time
941-
ptr.Comm = [16]byte{}
942-
ptr.Apm_trace_id = support.ApmTraceID{}
943-
ptr.Apm_transaction_id = support.ApmSpanID{}
944-
ptr.Ktime = 0
945-
ptr.Origin = 0
946-
ptr.Offtime = 0
947-
ptr.Custom_labels = support.CustomLabelsArray{}
948-
trace.Hash = host.TraceHash(xxh3.Hash128(raw).Lo)
949-
950937
if ptr.Kernel_stack_id >= 0 {
951938
var err error
952939
trace.KernelFrames, err = t.readKernelFrames(ptr.Kernel_stack_id)
@@ -955,6 +942,9 @@ func (t *Tracer) loadBpfTrace(raw []byte, cpu int) *host.Trace {
955942
}
956943
}
957944

945+
// Build host.Trace frames BEFORE zeroing per-sample fields, because
946+
// ZeroPerSampleFields clears the CUDA frame's Addr_or_line which carries
947+
// the correlation ID needed by the GPU interceptor.
958948
trace.Frames = make([]host.Frame, ptr.Stack_len)
959949
for i := 0; i < int(ptr.Stack_len); i++ {
960950
rawFrame := &ptr.Frames[i]
@@ -967,6 +957,15 @@ func (t *Tracer) loadBpfTrace(raw []byte, cpu int) *host.Trace {
967957
LJCallerPC: uint32(rawFrame.Caller_pc_lo) + (uint32(rawFrame.Caller_pc_hi) << 16),
968958
}
969959
}
960+
961+
// Trace fields included in the hash:
962+
// - PID, kernel stack ID, length & frame array
963+
// Intentionally excluded:
964+
// - ktime, COMM, APM trace, APM transaction ID, Origin and Off Time
965+
// - CUDA frame addr_or_line (it carries the correlation ID, which is per-sample and not part of the trace identity)
966+
ptr.ZeroPerSampleFields()
967+
trace.Hash = host.TraceHash(xxh3.Hash128(raw).Lo)
968+
970969
return trace
971970
}
972971

0 commit comments

Comments
 (0)