Skip to content

Commit aeb7bde

Browse files
authored
Fix data race and cache corruption in GetExternalMetric (#1200)
* Fix data race and cache corruption in GetExternalMetric Only applicable when `--external-metric-cache-ttl` is enabled. Previously, direct pointers to these objects were returned, which allows downstream handlers (e.g. transformResponseObject in k8s.io/apiserver) to mutate values in the cache, which can lead to a data race if two goroutines attempt to mutate it at the same time. This fix uses a deep-copy instead, which allows the caller to mutate their copy without impacting the cached copy. (Internal Google bug: http://b/511329740)
1 parent a2a3a79 commit aeb7bde

3 files changed

Lines changed: 141 additions & 2 deletions

File tree

custom-metrics-stackdriver-adapter/Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ push: docker
2222
test: $(PKG)
2323
CGO_ENABLED=0 go test -mod readonly ./...
2424

25+
test-race: $(PKG)
26+
CGO_ENABLED=1 go test -race -mod readonly ./...
27+
2528
clean:
2629
rm -rf build
2730

custom-metrics-stackdriver-adapter/pkg/adapter/provider/cache.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,12 @@ func (c *externalMetricsCache) get(key cacheKey) (*external_metrics.ExternalMetr
5959
return nil, false
6060
}
6161

62-
return metrics, true
62+
return metrics.DeepCopy(), true
6363
}
6464

6565
func (c *externalMetricsCache) add(key cacheKey, value *external_metrics.ExternalMetricValueList) {
6666
if c.cache == nil {
6767
return
6868
}
69-
c.cache.Add(key, value, c.ttl)
69+
c.cache.Add(key, value.DeepCopy(), c.ttl)
7070
}
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
package provider
2+
3+
import (
4+
"context"
5+
"sync"
6+
"testing"
7+
"time"
8+
9+
"github.com/GoogleCloudPlatform/k8s-stackdriver/custom-metrics-stackdriver-adapter/pkg/adapter/translator"
10+
"github.com/GoogleCloudPlatform/k8s-stackdriver/custom-metrics-stackdriver-adapter/pkg/config"
11+
sd "google.golang.org/api/monitoring/v3"
12+
"k8s.io/apimachinery/pkg/labels"
13+
"sigs.k8s.io/custom-metrics-apiserver/pkg/provider"
14+
)
15+
16+
func TestExternalMetricCache_PreventMutation(t *testing.T) {
17+
projectID := testProjectID // newMockExternalMetricRoundTripper hardcodes this value.
18+
shortMetricName := "testmetric"
19+
fullMetricName := "external.googleapis.com/" + shortMetricName
20+
mockRoundTripper := newMockExternalMetricRoundTripper(
21+
fullMetricName,
22+
labels.Everything(),
23+
&sd.ListTimeSeriesResponse{
24+
TimeSeries: []*sd.TimeSeries{
25+
{
26+
Metric: &sd.Metric{
27+
Type: fullMetricName,
28+
},
29+
Points: []*sd.Point{
30+
{
31+
Value: &sd.TypedValue{Int64Value: new(int64(100))},
32+
Interval: &sd.TimeInterval{
33+
EndTime: time.Now().Format(time.RFC3339Nano),
34+
},
35+
},
36+
},
37+
Resource: &sd.MonitoredResource{},
38+
},
39+
},
40+
},
41+
)
42+
mockSDService := translator.NewMockStackdriverService(t, mockRoundTripper)
43+
fakeTranslator := newFakeTranslator(t, mockSDService)
44+
p := &StackdriverProvider{
45+
stackdriverService: mockSDService,
46+
config: &config.GceConfig{Project: projectID},
47+
translator: fakeTranslator,
48+
}
49+
p.externalMetricsCache = newExternalMetricsCache(1, time.Minute)
50+
51+
// First call: Cache Miss, fetches from SD, stores in cache.
52+
namespace := "default"
53+
selector := labels.SelectorFromSet(labels.Set{"resource.labels.project_id": projectID})
54+
metricInfo := provider.ExternalMetricInfo{Metric: shortMetricName}
55+
resp1, err := p.GetExternalMetric(context.Background(), namespace, selector, metricInfo)
56+
if err != nil {
57+
t.Fatalf("First GetExternalMetric failed: %v", err)
58+
}
59+
60+
if len(resp1.Items) != 1 {
61+
t.Fatalf("Expected 1 item, got %d", len(resp1.Items))
62+
}
63+
64+
// Mutate the returned response (simulating apiserver mutation)
65+
resp1.Items[0].MetricName = "MUTATED"
66+
67+
// Second call: Cache Hit, should return a copy with the MetricName intact
68+
resp2, err := p.GetExternalMetric(context.Background(), namespace, selector, metricInfo)
69+
if err != nil {
70+
t.Fatalf("Second GetExternalMetric failed: %v", err)
71+
}
72+
73+
if resp2.Items[0].MetricName != shortMetricName {
74+
t.Errorf("GetExternalMetric().Items[0].MetricName = %q, want %q", resp2.Items[0].MetricName, shortMetricName)
75+
}
76+
}
77+
78+
func TestExternalMetricCache_ConcurrentAccess(t *testing.T) {
79+
projectID := testProjectID // newMockExternalMetricRoundTripper hardcodes this value.
80+
shortMetricName := "testmetric"
81+
fullMetricName := "external.googleapis.com/" + shortMetricName
82+
mockRoundTripper := newMockExternalMetricRoundTripper(
83+
fullMetricName,
84+
labels.Everything(),
85+
&sd.ListTimeSeriesResponse{
86+
TimeSeries: []*sd.TimeSeries{
87+
{
88+
Metric: &sd.Metric{
89+
Type: fullMetricName,
90+
},
91+
Points: []*sd.Point{
92+
{
93+
Value: &sd.TypedValue{Int64Value: new(int64(100))},
94+
Interval: &sd.TimeInterval{
95+
EndTime: time.Now().Format(time.RFC3339Nano),
96+
},
97+
},
98+
},
99+
Resource: &sd.MonitoredResource{},
100+
},
101+
},
102+
},
103+
)
104+
mockSDService := translator.NewMockStackdriverService(t, mockRoundTripper)
105+
fakeTranslator := newFakeTranslator(t, mockSDService)
106+
p := &StackdriverProvider{
107+
stackdriverService: mockSDService,
108+
config: &config.GceConfig{Project: projectID},
109+
translator: fakeTranslator,
110+
}
111+
p.externalMetricsCache = newExternalMetricsCache(1, time.Minute)
112+
113+
// Warm up cache
114+
namespace := "default"
115+
selector := labels.SelectorFromSet(labels.Set{"resource.labels.project_id": projectID})
116+
metricInfo := provider.ExternalMetricInfo{Metric: shortMetricName}
117+
_, err := p.GetExternalMetric(context.Background(), namespace, selector, metricInfo)
118+
if err != nil {
119+
t.Fatalf("Failed to warm up cache: %v", err)
120+
}
121+
122+
// Concurrently access the same value from multiple goroutines to check for data races
123+
var wg sync.WaitGroup
124+
for i := 0; i < 10; i++ {
125+
wg.Add(1)
126+
go func() {
127+
defer wg.Done()
128+
resp, err := p.GetExternalMetric(context.Background(), namespace, selector, metricInfo)
129+
if err == nil && len(resp.Items) > 0 {
130+
// Mutate concurrently. If cache returns same pointer, this will race under -race
131+
resp.Items[0].MetricName = "MUTATED_CONCURRENT"
132+
}
133+
}()
134+
}
135+
wg.Wait()
136+
}

0 commit comments

Comments
 (0)