Skip to content

Commit c47a270

Browse files
committed
Add auto-parallelism for RDS refresh dump and restore
Resolve optimal -j values automatically: use EC2 DescribeInstanceTypes to determine vCPU count of the RDS clone (for pg_dump parallelism) and runtime.NumCPU for the local machine (for pg_restore parallelism). Pass both values through the existing ConfigProjection when updating DBLab config during refresh. https://claude.ai/code/session_01AhnBVCBWjk24T7BBQtmkbq
1 parent 7631aab commit c47a270

File tree

7 files changed

+417
-21
lines changed

7 files changed

+417
-21
lines changed

engine/go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ require (
99
github.com/aws/aws-sdk-go v1.55.8
1010
github.com/aws/aws-sdk-go-v2 v1.41.5
1111
github.com/aws/aws-sdk-go-v2/config v1.32.14
12+
github.com/aws/aws-sdk-go-v2/service/ec2 v1.297.0
1213
github.com/aws/aws-sdk-go-v2/service/rds v1.117.1
1314
github.com/containerd/errdefs v1.0.0
1415
github.com/docker/cli v28.5.2+incompatible

engine/go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgq
2828
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps=
2929
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw=
3030
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY=
31+
github.com/aws/aws-sdk-go-v2/service/ec2 v1.297.0 h1:A+7NViqbMUCoTQFWjbSXdbzE4K5Ziu2zWJtZzAusm+A=
32+
github.com/aws/aws-sdk-go-v2/service/ec2 v1.297.0/go.mod h1:R+2BNtUfTfhPY0RH18oL02q116bakeBWjanrbnVBqkM=
3133
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY=
3234
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI=
3335
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto=

engine/internal/rdsrefresh/dblab.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,10 @@ type SourceConfigUpdate struct {
180180
Password string
181181
// RDSIAMDBInstance is the RDS DB instance identifier for IAM auth. When empty, this field is omitted from the config update.
182182
RDSIAMDBInstance string
183+
// DumpParallelJobs sets the -j flag for pg_dump. When zero, the existing value is preserved.
184+
DumpParallelJobs int
185+
// RestoreParallelJobs sets the -j flag for pg_restore. When zero, the existing value is preserved.
186+
RestoreParallelJobs int
183187
}
184188

185189
// UpdateSourceConfig updates the source database connection in DBLab config.
@@ -198,6 +202,16 @@ func (c *DBLabClient) UpdateSourceConfig(ctx context.Context, update SourceConfi
198202
proj.RDSIAMDBInstance = &update.RDSIAMDBInstance
199203
}
200204

205+
if update.DumpParallelJobs > 0 {
206+
dumpJobs := int64(update.DumpParallelJobs)
207+
proj.DumpParallelJobs = &dumpJobs
208+
}
209+
210+
if update.RestoreParallelJobs > 0 {
211+
restoreJobs := int64(update.RestoreParallelJobs)
212+
proj.RestoreParallelJobs = &restoreJobs
213+
}
214+
201215
nested := map[string]interface{}{}
202216

203217
// defensive error check: StoreJSON only fails if target is not an addressable struct,

engine/internal/rdsrefresh/dblab_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,70 @@ func TestDBLabClientUpdateSourceConfig(t *testing.T) {
191191
assert.Nil(t, receivedConfig.RDSIAMDBInstance)
192192
})
193193

194+
t.Run("successful with parallelism settings", func(t *testing.T) {
195+
var receivedConfig models.ConfigProjection
196+
197+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
198+
var nested map[string]interface{}
199+
err := json.NewDecoder(r.Body).Decode(&nested)
200+
require.NoError(t, err)
201+
202+
err = projection.LoadJSON(&receivedConfig, nested, projection.LoadOptions{
203+
Groups: []string{"default", "sensitive"},
204+
})
205+
require.NoError(t, err)
206+
207+
w.WriteHeader(http.StatusOK)
208+
}))
209+
defer server.Close()
210+
211+
client, err := NewDBLabClient(&DBLabConfig{APIEndpoint: server.URL, Token: "test-token"})
212+
require.NoError(t, err)
213+
214+
err = client.UpdateSourceConfig(context.Background(), SourceConfigUpdate{
215+
Host: "clone-host.rds.amazonaws.com", Port: 5432, DBName: "postgres",
216+
Username: "dbuser", Password: "dbpass",
217+
DumpParallelJobs: 4, RestoreParallelJobs: 8,
218+
})
219+
require.NoError(t, err)
220+
221+
require.NotNil(t, receivedConfig.DumpParallelJobs)
222+
assert.Equal(t, int64(4), *receivedConfig.DumpParallelJobs)
223+
require.NotNil(t, receivedConfig.RestoreParallelJobs)
224+
assert.Equal(t, int64(8), *receivedConfig.RestoreParallelJobs)
225+
})
226+
227+
t.Run("omits parallelism when zero", func(t *testing.T) {
228+
var receivedConfig models.ConfigProjection
229+
230+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
231+
var nested map[string]interface{}
232+
err := json.NewDecoder(r.Body).Decode(&nested)
233+
require.NoError(t, err)
234+
235+
err = projection.LoadJSON(&receivedConfig, nested, projection.LoadOptions{
236+
Groups: []string{"default", "sensitive"},
237+
})
238+
require.NoError(t, err)
239+
240+
w.WriteHeader(http.StatusOK)
241+
}))
242+
defer server.Close()
243+
244+
client, err := NewDBLabClient(&DBLabConfig{APIEndpoint: server.URL, Token: "test-token"})
245+
require.NoError(t, err)
246+
247+
err = client.UpdateSourceConfig(context.Background(), SourceConfigUpdate{
248+
Host: "host.rds.amazonaws.com", Port: 5432, DBName: "postgres",
249+
Username: "dbuser", Password: "dbpass",
250+
DumpParallelJobs: 0, RestoreParallelJobs: 0,
251+
})
252+
require.NoError(t, err)
253+
254+
assert.Nil(t, receivedConfig.DumpParallelJobs)
255+
assert.Nil(t, receivedConfig.RestoreParallelJobs)
256+
})
257+
194258
t.Run("error on non-2xx status", func(t *testing.T) {
195259
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
196260
w.WriteHeader(http.StatusBadRequest)
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/*
2+
2026 © PostgresAI
3+
*/
4+
5+
package rdsrefresh
6+
7+
import (
8+
"context"
9+
"fmt"
10+
"runtime"
11+
"strings"
12+
13+
"github.com/aws/aws-sdk-go-v2/aws"
14+
awsconfig "github.com/aws/aws-sdk-go-v2/config"
15+
"github.com/aws/aws-sdk-go-v2/service/ec2"
16+
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
17+
18+
"gitlab.com/postgres-ai/database-lab/v3/pkg/log"
19+
)
20+
21+
const (
22+
// rdsInstanceClassPrefix is stripped to derive the EC2 instance type.
23+
rdsInstanceClassPrefix = "db."
24+
25+
// minParallelJobs is the minimum parallelism level.
26+
minParallelJobs = 1
27+
)
28+
29+
// EC2API defines the interface for EC2 client operations used for vCPU lookup.
30+
type EC2API interface {
31+
DescribeInstanceTypes(ctx context.Context, params *ec2.DescribeInstanceTypesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstanceTypesOutput, error)
32+
}
33+
34+
// ParallelismConfig holds the computed parallelism levels for dump and restore.
35+
type ParallelismConfig struct {
36+
DumpJobs int
37+
RestoreJobs int
38+
}
39+
40+
// ResolveParallelism determines the optimal parallelism levels for pg_dump and pg_restore.
41+
// dump parallelism is based on the vCPU count of the RDS clone instance class.
42+
// restore parallelism is based on the vCPU count of the local machine.
43+
func ResolveParallelism(ctx context.Context, cfg *Config) (*ParallelismConfig, error) {
44+
dumpJobs, err := resolveRDSInstanceVCPUs(ctx, cfg)
45+
if err != nil {
46+
return nil, fmt.Errorf("failed to resolve RDS instance vCPUs: %w", err)
47+
}
48+
49+
restoreJobs := resolveLocalVCPUs()
50+
51+
log.Msg("auto-parallelism: dump jobs =", dumpJobs, "(RDS clone vCPUs), restore jobs =", restoreJobs, "(local vCPUs)")
52+
53+
return &ParallelismConfig{
54+
DumpJobs: dumpJobs,
55+
RestoreJobs: restoreJobs,
56+
}, nil
57+
}
58+
59+
// resolveRDSInstanceVCPUs looks up the vCPU count for the configured RDS instance class
60+
// by querying the EC2 DescribeInstanceTypes API.
61+
func resolveRDSInstanceVCPUs(ctx context.Context, cfg *Config) (int, error) {
62+
ec2Client, err := newEC2Client(ctx, cfg)
63+
if err != nil {
64+
return 0, fmt.Errorf("failed to create EC2 client: %w", err)
65+
}
66+
67+
return lookupInstanceVCPUs(ctx, ec2Client, cfg.RDSClone.InstanceClass)
68+
}
69+
70+
func newEC2Client(ctx context.Context, cfg *Config) (EC2API, error) {
71+
awsCfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(cfg.AWS.Region))
72+
if err != nil {
73+
return nil, fmt.Errorf("failed to load AWS config: %w", err)
74+
}
75+
76+
var opts []func(*ec2.Options)
77+
if cfg.AWS.Endpoint != "" {
78+
opts = append(opts, func(o *ec2.Options) {
79+
o.BaseEndpoint = aws.String(cfg.AWS.Endpoint)
80+
})
81+
}
82+
83+
return ec2.NewFromConfig(awsCfg, opts...), nil
84+
}
85+
86+
// lookupInstanceVCPUs queries EC2 for the vCPU count of the given RDS instance class.
87+
func lookupInstanceVCPUs(ctx context.Context, client EC2API, rdsInstanceClass string) (int, error) {
88+
ec2InstanceType, err := rdsClassToEC2Type(rdsInstanceClass)
89+
if err != nil {
90+
return 0, err
91+
}
92+
93+
result, err := client.DescribeInstanceTypes(ctx, &ec2.DescribeInstanceTypesInput{
94+
InstanceTypes: []ec2types.InstanceType{ec2types.InstanceType(ec2InstanceType)},
95+
})
96+
if err != nil {
97+
return 0, fmt.Errorf("failed to describe EC2 instance type %q: %w", ec2InstanceType, err)
98+
}
99+
100+
if len(result.InstanceTypes) == 0 {
101+
return 0, fmt.Errorf("EC2 instance type %q not found", ec2InstanceType)
102+
}
103+
104+
info := result.InstanceTypes[0]
105+
if info.VCpuInfo == nil || info.VCpuInfo.DefaultVCpus == nil {
106+
return 0, fmt.Errorf("vCPU info not available for instance type %q", ec2InstanceType)
107+
}
108+
109+
vcpus := int(*info.VCpuInfo.DefaultVCpus)
110+
if vcpus < minParallelJobs {
111+
return minParallelJobs, nil
112+
}
113+
114+
return vcpus, nil
115+
}
116+
117+
// rdsClassToEC2Type converts an RDS instance class (e.g. "db.m5.xlarge") to an EC2 instance type ("m5.xlarge").
118+
func rdsClassToEC2Type(rdsClass string) (string, error) {
119+
if !strings.HasPrefix(rdsClass, rdsInstanceClassPrefix) {
120+
return "", fmt.Errorf("invalid RDS instance class %q: expected %q prefix", rdsClass, rdsInstanceClassPrefix)
121+
}
122+
123+
ec2Type := strings.TrimPrefix(rdsClass, rdsInstanceClassPrefix)
124+
if ec2Type == "" {
125+
return "", fmt.Errorf("invalid RDS instance class %q: empty after removing prefix", rdsClass)
126+
}
127+
128+
return ec2Type, nil
129+
}
130+
131+
// resolveLocalVCPUs returns the number of logical CPUs available on the local machine.
132+
func resolveLocalVCPUs() int {
133+
cpus := runtime.NumCPU()
134+
if cpus < minParallelJobs {
135+
return minParallelJobs
136+
}
137+
138+
return cpus
139+
}

0 commit comments

Comments
 (0)