diff --git a/pkg/roles/dhcp/dhcp_handler4_discover.go b/pkg/roles/dhcp/dhcp_handler4_discover.go index 16a94ea3e..a0e4a8485 100644 --- a/pkg/roles/dhcp/dhcp_handler4_discover.go +++ b/pkg/roles/dhcp/dhcp_handler4_discover.go @@ -18,10 +18,16 @@ func (r *Role) HandleDHCPDiscover4(req *Request4) *dhcpv4.DHCPv4 { if match == nil { return nil } - err := match.Put(req.Context, int64(r.cfg.LeaseNegotiateTimeout)) + match, created, err := r.CreateLeaseIfAbsent(req.Context, match, int64(r.cfg.LeaseNegotiateTimeout)) if err != nil { req.log.Warn("failed to update lease during discover creation", zap.Error(err)) } + if match == nil { + return nil + } + if !created { + r.ensureLeaseScope(req, match) + } } else { err := match.Put(req.Context, match.scope.TTL) if err != nil { diff --git a/pkg/roles/dhcp/dhcp_handler4_discover_internal_test.go b/pkg/roles/dhcp/dhcp_handler4_discover_internal_test.go new file mode 100644 index 000000000..0fe1c0a21 --- /dev/null +++ b/pkg/roles/dhcp/dhcp_handler4_discover_internal_test.go @@ -0,0 +1,208 @@ +package dhcp + +import ( + "context" + "testing" + "time" + + "beryju.io/gravity/pkg/roles/dhcp/types" + "beryju.io/gravity/pkg/storage/watcher" + "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/stretchr/testify/assert" + "go.etcd.io/etcd/api/v3/mvccpb" +) + +func TestDHCPDiscover_ReusesExistingLeaseWithoutDowngradingTTL(t *testing.T) { + ctx := setupDHCPInternalTest(t) + inst := newDHCPTestInstance(ctx) + role := New(inst) + + panicIfError(inst.KV().Put( + ctx, + inst.KV().Key( + types.KeyRole, + types.KeyScopes, + "test", + ).String(), + mustJSON(Scope{ + SubnetCIDR: "10.100.0.0/24", + Default: true, + TTL: 86400, + IPAM: map[string]string{ + "type": "internal", + "range_start": "10.100.0.100", + "range_end": "10.100.0.250", + }, + }), + )) + + panicIfError(role.Start(ctx, []byte(mustJSON(RoleConfig{ + Port: 0, + LeaseNegotiateTimeout: 30, + })))) + defer role.Stop() + + scope, ok := role.scopes.GetPrefix("test") + assert.True(t, ok) + assert.NotNil(t, scope) + + lease := role.NewLease("b2:b7:86:2c:d3:fa") + lease.scope = scope + lease.ScopeKey = scope.Name + lease.Address = "10.100.0.100" + panicIfError(lease.Put(ctx, 3600)) + + assert.Eventually(t, func() bool { + match, ok := role.leases.GetPrefix(lease.Identifier) + return ok && match != nil && match.Address == lease.Address + }, time.Second, 10*time.Millisecond) + + role.leases = watcher.New( + func(kv *mvccpb.KeyValue) (*Lease, error) { + return role.leaseFromKV(kv) + }, + inst.KV(), + inst.KV().Key( + types.KeyRole, + types.KeyLeases, + ).Prefix(true), + ) + + req := &dhcpv4.DHCPv4{ + OpCode: dhcpv4.OpcodeBootRequest, + ClientHWAddr: []byte{0xb2, 0xb7, 0x86, 0x2c, 0xd3, 0xfa}, + } + req.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeDiscover)) + + req4 := role.NewRequest4(req) + res := role.HandleDHCPDiscover4(req4) + assert.NotNil(t, res) + assert.Equal(t, lease.Address, res.YourIPAddr.String()) + + stored := role.FindLeaseInStore(req4) + assert.NotNil(t, stored) + assert.Greater(t, stored.Expiry, time.Now().Add(5*time.Minute).Unix()) +} + +func TestDHCPDiscover_ReturnsNilWhenNoScopeMatches(t *testing.T) { + ctx := setupDHCPInternalTest(t) + inst := newDHCPTestInstance(ctx) + role := New(inst) + + panicIfError(role.Start(ctx, []byte(mustJSON(RoleConfig{Port: 0})))) + defer role.Stop() + + req := &dhcpv4.DHCPv4{ + OpCode: dhcpv4.OpcodeBootRequest, + ClientHWAddr: []byte{0xb2, 0xb7, 0x86, 0x2c, 0xd3, 0xfa}, + } + req.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeDiscover)) + + req4 := role.NewRequest4(req) + assert.Nil(t, role.HandleDHCPDiscover4(req4)) +} + +func TestDHCPDiscover_ReturnsNilWhenCreateLeaseFails(t *testing.T) { + ctx := setupDHCPInternalTest(t) + inst := newDHCPTestInstance(ctx) + role := New(inst) + + panicIfError(inst.KV().Put( + ctx, + inst.KV().Key( + types.KeyRole, + types.KeyScopes, + "test", + ).String(), + mustJSON(Scope{ + SubnetCIDR: "10.100.0.0/24", + Default: true, + TTL: 86400, + IPAM: map[string]string{ + "type": "internal", + "range_start": "10.100.0.100", + "range_end": "10.100.0.250", + }, + }), + )) + + panicIfError(role.Start(ctx, []byte(mustJSON(RoleConfig{ + Port: 0, + LeaseNegotiateTimeout: 30, + })))) + defer role.Stop() + + req := &dhcpv4.DHCPv4{ + OpCode: dhcpv4.OpcodeBootRequest, + ClientHWAddr: []byte{0xb2, 0xb7, 0x86, 0x2c, 0xd3, 0xfa}, + } + req.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeDiscover)) + + req4 := role.NewRequest4(req) + cancelledCtx, cancel := context.WithCancel(req4.Context) + cancel() + req4.Context = cancelledCtx + + assert.Nil(t, role.HandleDHCPDiscover4(req4)) +} + +func TestDHCPDiscover_ReturnsOfferWhenExistingLeaseRefreshFails(t *testing.T) { + ctx := setupDHCPInternalTest(t) + inst := newDHCPTestInstance(ctx) + role := New(inst) + + panicIfError(inst.KV().Put( + ctx, + inst.KV().Key( + types.KeyRole, + types.KeyScopes, + "test", + ).String(), + mustJSON(Scope{ + SubnetCIDR: "10.100.0.0/24", + Default: true, + TTL: 86400, + IPAM: map[string]string{ + "type": "internal", + "range_start": "10.100.0.100", + "range_end": "10.100.0.250", + }, + }), + )) + + panicIfError(role.Start(ctx, []byte(mustJSON(RoleConfig{ + Port: 0, + LeaseNegotiateTimeout: 30, + })))) + defer role.Stop() + + scope, ok := role.scopes.GetPrefix("test") + assert.True(t, ok) + assert.NotNil(t, scope) + + lease := role.NewLease("b2:b7:86:2c:d3:fa") + lease.scope = scope + lease.ScopeKey = scope.Name + lease.Address = "10.100.0.100" + panicIfError(lease.Put(ctx, 3600)) + + assert.Eventually(t, func() bool { + match, ok := role.leases.GetPrefix(lease.Identifier) + return ok && match != nil && match.Address == lease.Address + }, time.Second, 10*time.Millisecond) + + req := &dhcpv4.DHCPv4{ + OpCode: dhcpv4.OpcodeBootRequest, + ClientHWAddr: []byte{0xb2, 0xb7, 0x86, 0x2c, 0xd3, 0xfa}, + } + req.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeDiscover)) + + req4 := role.NewRequest4(req) + cancelledCtx, cancel := context.WithCancel(req4.Context) + cancel() + req4.Context = cancelledCtx + + res := role.HandleDHCPDiscover4(req4) + assert.NotNil(t, res) + assert.Equal(t, lease.Address, res.YourIPAddr.String()) +} diff --git a/pkg/roles/dhcp/dhcp_handler4_request.go b/pkg/roles/dhcp/dhcp_handler4_request.go index c5ceeab95..e27bb14f1 100644 --- a/pkg/roles/dhcp/dhcp_handler4_request.go +++ b/pkg/roles/dhcp/dhcp_handler4_request.go @@ -18,6 +18,15 @@ func (r *Role) HandleDHCPRequest4(req *Request4) *dhcpv4.DHCPv4 { if match == nil { return nil } + match, _, err := r.CreateLeaseIfAbsent(req.Context, match, match.scope.TTL) + if err != nil { + req.log.Warn("failed to create dhcp lease", zap.Error(err)) + return nil + } + if match == nil { + return nil + } + r.ensureLeaseScope(req, match) } err := match.Put(req.Context, match.scope.TTL) diff --git a/pkg/roles/dhcp/dhcp_handler4_request_internal_test.go b/pkg/roles/dhcp/dhcp_handler4_request_internal_test.go new file mode 100644 index 000000000..7c4e71f25 --- /dev/null +++ b/pkg/roles/dhcp/dhcp_handler4_request_internal_test.go @@ -0,0 +1,174 @@ +package dhcp + +import ( + "context" + "net" + "testing" + "time" + + "beryju.io/gravity/pkg/roles/dhcp/types" + "beryju.io/gravity/pkg/storage/watcher" + "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/stretchr/testify/assert" + "go.etcd.io/etcd/api/v3/mvccpb" +) + +func TestDHCPRequest_ReusesStoredLeaseAndReassignsScopeOnWatcherMiss(t *testing.T) { + ctx := setupDHCPInternalTest(t) + inst := newDHCPTestInstance(ctx) + role := New(inst) + + panicIfError(inst.KV().Put( + ctx, + inst.KV().Key( + types.KeyRole, + types.KeyScopes, + "test", + ).String(), + mustJSON(Scope{ + SubnetCIDR: "10.100.0.0/24", + Default: true, + TTL: 86400, + IPAM: map[string]string{ + "type": "internal", + "range_start": "10.100.0.100", + "range_end": "10.100.0.250", + }, + }), + )) + panicIfError(inst.KV().Put( + ctx, + inst.KV().Key( + types.KeyRole, + types.KeyScopes, + "test2", + ).String(), + mustJSON(Scope{ + SubnetCIDR: "10.200.0.0/24", + TTL: 86400, + IPAM: map[string]string{ + "type": "internal", + "range_start": "10.200.0.100", + "range_end": "10.200.0.250", + }, + }), + )) + + panicIfError(role.Start(ctx, []byte(mustJSON(RoleConfig{ + Port: 0, + LeaseNegotiateTimeout: 30, + })))) + defer role.Stop() + + scope, ok := role.scopes.GetPrefix("test") + assert.True(t, ok) + assert.NotNil(t, scope) + + lease := role.NewLease("b2:b7:86:2c:d3:fa") + lease.scope = scope + lease.ScopeKey = scope.Name + lease.Address = "10.100.0.100" + panicIfError(lease.Put(ctx, 3600)) + + assert.Eventually(t, func() bool { + match, ok := role.leases.GetPrefix(lease.Identifier) + return ok && match != nil && match.Address == lease.Address + }, time.Second, 10*time.Millisecond) + + role.leases = watcher.New( + func(kv *mvccpb.KeyValue) (*Lease, error) { + return role.leaseFromKV(kv) + }, + inst.KV(), + inst.KV().Key( + types.KeyRole, + types.KeyLeases, + ).Prefix(true), + ) + + req := &dhcpv4.DHCPv4{ + OpCode: dhcpv4.OpcodeBootRequest, + GatewayIPAddr: net.ParseIP("10.200.0.1"), + ClientHWAddr: []byte{0xb2, 0xb7, 0x86, 0x2c, 0xd3, 0xfa}, + } + req.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeRequest)) + + req4 := role.NewRequest4(req) + res := role.HandleDHCPRequest4(req4) + assert.NotNil(t, res) + assert.Equal(t, "10.200.0.100", res.YourIPAddr.String()) + + storedKey := inst.KV().Key( + types.KeyRole, + types.KeyLeases, + lease.Identifier, + ) + storedResp, err := inst.KV().Get(ctx, storedKey.String()) + panicIfError(err) + assert.Len(t, storedResp.Kvs, 1) + + stored, err := role.leaseFromKV(storedResp.Kvs[0]) + panicIfError(err) + assert.Equal(t, "test2", stored.ScopeKey) + assert.Equal(t, "10.200.0.100", stored.Address) + assert.Greater(t, stored.Expiry, time.Now().Add(time.Hour).Unix()) +} + +func TestDHCPRequest_ReturnsNilWhenNoScopeMatches(t *testing.T) { + ctx := setupDHCPInternalTest(t) + inst := newDHCPTestInstance(ctx) + role := New(inst) + + panicIfError(role.Start(ctx, []byte(mustJSON(RoleConfig{Port: 0})))) + defer role.Stop() + + req := &dhcpv4.DHCPv4{ + OpCode: dhcpv4.OpcodeBootRequest, + ClientHWAddr: []byte{0xb2, 0xb7, 0x86, 0x2c, 0xd3, 0xfa}, + } + req.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeRequest)) + + req4 := role.NewRequest4(req) + assert.Nil(t, role.HandleDHCPRequest4(req4)) +} + +func TestDHCPRequest_ReturnsNilWhenCreateLeaseFails(t *testing.T) { + ctx := setupDHCPInternalTest(t) + inst := newDHCPTestInstance(ctx) + role := New(inst) + + panicIfError(inst.KV().Put( + ctx, + inst.KV().Key( + types.KeyRole, + types.KeyScopes, + "test", + ).String(), + mustJSON(Scope{ + SubnetCIDR: "10.100.0.0/24", + Default: true, + TTL: 86400, + IPAM: map[string]string{ + "type": "internal", + "range_start": "10.100.0.100", + "range_end": "10.100.0.250", + }, + }), + )) + + panicIfError(role.Start(ctx, []byte(mustJSON(RoleConfig{Port: 0})))) + defer role.Stop() + + req := &dhcpv4.DHCPv4{ + OpCode: dhcpv4.OpcodeBootRequest, + ClientHWAddr: []byte{0xb2, 0xb7, 0x86, 0x2c, 0xd3, 0xfa}, + } + req.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeRequest)) + + req4 := role.NewRequest4(req) + cancelledCtx, cancel := context.WithCancel(req4.Context) + cancel() + req4.Context = cancelledCtx + + assert.Nil(t, role.HandleDHCPRequest4(req4)) +} diff --git a/pkg/roles/dhcp/internal_test_helpers_test.go b/pkg/roles/dhcp/internal_test_helpers_test.go new file mode 100644 index 000000000..b66a9020d --- /dev/null +++ b/pkg/roles/dhcp/internal_test_helpers_test.go @@ -0,0 +1,95 @@ +package dhcp + +import ( + "context" + "encoding/json" + "testing" + + "beryju.io/gravity/pkg/extconfig" + "beryju.io/gravity/pkg/roles" + "beryju.io/gravity/pkg/storage" + "github.com/getsentry/sentry-go" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" +) + +type dhcpTestMigrator struct{} + +func (dhcpTestMigrator) AddMigration(roles.Migration) {} + +func (dhcpTestMigrator) Run(context.Context) (*storage.Client, error) { + return extconfig.Get().EtcdClient(), nil +} + +type dhcpTestInstance struct { + ctx context.Context + log *zap.Logger + kv *storage.Client +} + +func newDHCPTestInstance(ctx context.Context) *dhcpTestInstance { + return &dhcpTestInstance{ + ctx: ctx, + log: extconfig.Get().Logger().Named("role.dhcp.test"), + kv: extconfig.Get().EtcdClient(), + } +} + +func (i *dhcpTestInstance) KV() *storage.Client { + return i.kv +} + +func (i *dhcpTestInstance) Log() *zap.Logger { + return i.log +} + +func (i *dhcpTestInstance) DispatchEvent(string, *roles.Event) {} + +func (i *dhcpTestInstance) AddEventListener(string, roles.EventHandler) {} + +func (i *dhcpTestInstance) Context() context.Context { + return i.ctx +} + +func (i *dhcpTestInstance) ExecuteHook(roles.HookOptions, ...interface{}) interface{} { + return nil +} + +func (i *dhcpTestInstance) Migrator() roles.RoleMigrator { + return dhcpTestMigrator{} +} + +func setupDHCPInternalTest(t testing.TB) context.Context { + t.Helper() + + ctx, cancel := context.WithCancel(t.Context()) + tx := sentry.StartTransaction(ctx, "test") + + _, err := extconfig.Get().EtcdClient().Delete(tx.Context(), "/", clientv3.WithPrefix()) + if err != nil { + t.Fatalf("failed to reset etcd: %v", err) + } + + t.Cleanup(func() { + tx.Finish() + cancel() + }) + + return tx.Context() +} + +func mustJSON(v any) string { + raw, err := json.Marshal(v) + if err != nil { + panic(err) + } + return string(raw) +} + +func panicIfError(args ...any) { + for _, arg := range args { + if err, ok := arg.(error); ok && err != nil { + panic(err) + } + } +} diff --git a/pkg/roles/dhcp/leases.go b/pkg/roles/dhcp/leases.go index b6fef7fb4..fb3f1dc96 100644 --- a/pkg/roles/dhcp/leases.go +++ b/pkg/roles/dhcp/leases.go @@ -47,6 +47,10 @@ func (r *Role) FindLease(req *Request4) *Lease { if !ok { return nil } + return r.ensureLeaseScope(req, lease) +} + +func (r *Role) ensureLeaseScope(req *Request4, lease *Lease) *Lease { // Check if the leases's scope matches the expected scope to handle this request expectedScope := r.findScopeForRequest(req) if expectedScope != nil && lease.scope != expectedScope { @@ -63,6 +67,28 @@ func (r *Role) FindLease(req *Request4) *Lease { return lease } +func (r *Role) FindLeaseInStore(req *Request4) *Lease { + leaseKey := r.i.KV().Key( + types.KeyRole, + types.KeyLeases, + r.DeviceIdentifier(req.DHCPv4), + ) + res, err := r.i.KV().Get(req.Context, leaseKey.String()) + if err != nil { + r.log.Warn("failed to fetch lease from store", zap.Error(err)) + return nil + } + if len(res.Kvs) < 1 { + return nil + } + lease, err := r.leaseFromKV(res.Kvs[0]) + if err != nil { + r.log.Warn("failed to parse lease from store", zap.Error(err)) + return nil + } + return r.ensureLeaseScope(req, lease) +} + func (r *Role) NewLease(identifier string) *Lease { return &Lease{ inst: r.i, @@ -142,7 +168,7 @@ func (l *Lease) Put(ctx context.Context, expiry int64, opts ...clientv3.OpOption opts = append(opts, clientv3.WithLease(exp.ID)) } - raw, err := json.Marshal(&l) + raw, err := json.Marshal(l) if err != nil { return err } @@ -162,6 +188,73 @@ func (l *Lease) Put(ctx context.Context, expiry int64, opts ...clientv3.OpOption return err } + l.afterPut(ctx, expiry, opts...) + return nil +} + +func (r *Role) CreateLeaseIfAbsent(ctx context.Context, lease *Lease, expiry int64) (*Lease, bool, error) { + opts := []clientv3.OpOption{} + var leaseGrant *clientv3.LeaseGrantResponse + var err error + if expiry > 0 && !lease.IsReservation() { + lease.Expiry = time.Now().Add(time.Duration(expiry) * time.Second).Unix() + + leaseGrant, err = lease.inst.KV().Grant(ctx, expiry) + if err != nil { + return nil, false, err + } + opts = append(opts, clientv3.WithLease(leaseGrant.ID)) + } + + raw, err := json.Marshal(lease) + if err != nil { + return nil, false, err + } + + leaseKey := lease.inst.KV().Key( + types.KeyRole, + types.KeyLeases, + lease.Identifier, + ) + res, err := lease.inst.KV().Txn(ctx). + If(clientv3.Compare(clientv3.CreateRevision(leaseKey.String()), "=", 0)). + Then(clientv3.OpPut(leaseKey.String(), string(raw), opts...)). + Else(clientv3.OpGet(leaseKey.String())). + Commit() + if err != nil { + return nil, false, err + } + if res.Succeeded { + lease.afterPut(ctx, expiry, opts...) + return lease, true, nil + } + if leaseGrant != nil { + _, err := lease.inst.KV().Revoke(ctx, leaseGrant.ID) + if err != nil { + lease.log.Warn("failed to revoke unused lease grant", zap.Error(err)) + } + } + if len(res.Responses) < 1 { + return nil, false, nil + } + rangeResp := res.Responses[0].GetResponseRange() + if rangeResp == nil || len(rangeResp.Kvs) < 1 { + return nil, false, nil + } + existing, err := r.leaseFromKV(rangeResp.Kvs[0]) + if err != nil { + return nil, false, err + } + return existing, false, nil +} + +func (l *Lease) afterPut(ctx context.Context, expiry int64, opts ...clientv3.OpOption) { + leaseKey := l.inst.KV().Key( + types.KeyRole, + types.KeyLeases, + l.Identifier, + ) + var zone string if l.scope != nil && l.scope.DNS != nil { zone = l.scope.DNS.Zone @@ -185,7 +278,6 @@ func (l *Lease) Put(ctx context.Context, expiry int64, opts ...clientv3.OpOption l.log.Debug("put lease", zap.Int64("expiry", expiry)) go l.scope.calculateUsage() - return nil } func (l *Lease) createReply(req *Request4) *dhcpv4.DHCPv4 { diff --git a/pkg/roles/dhcp/leases_internal_test.go b/pkg/roles/dhcp/leases_internal_test.go new file mode 100644 index 000000000..613f428c7 --- /dev/null +++ b/pkg/roles/dhcp/leases_internal_test.go @@ -0,0 +1,64 @@ +package dhcp + +import ( + "context" + "errors" + "testing" + + "beryju.io/gravity/pkg/roles/dhcp/types" + "beryju.io/gravity/pkg/storage" + "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/stretchr/testify/assert" + clientv3 "go.etcd.io/etcd/client/v3" +) + +func TestFindLeaseInStore_GetError(t *testing.T) { + ctx := setupDHCPInternalTest(t) + inst := newDHCPTestInstance(ctx) + inst.kv = inst.KV().WithHooks(storage.StorageHook{ + GetPre: func(context.Context, string, ...clientv3.OpOption) error { + return errors.New("boom") + }, + }) + role := New(inst) + + req := role.NewRequest4(&dhcpv4.DHCPv4{ + ClientHWAddr: []byte{0xb2, 0xb7, 0x86, 0x2c, 0xd3, 0xfa}, + }) + + assert.Nil(t, role.FindLeaseInStore(req)) +} + +func TestFindLeaseInStore_EmptyResult(t *testing.T) { + ctx := setupDHCPInternalTest(t) + inst := newDHCPTestInstance(ctx) + role := New(inst) + + req := role.NewRequest4(&dhcpv4.DHCPv4{ + ClientHWAddr: []byte{0xb2, 0xb7, 0x86, 0x2c, 0xd3, 0xfa}, + }) + + assert.Nil(t, role.FindLeaseInStore(req)) +} + +func TestFindLeaseInStore_ParseError(t *testing.T) { + ctx := setupDHCPInternalTest(t) + inst := newDHCPTestInstance(ctx) + role := New(inst) + + panicIfError(inst.KV().Put( + ctx, + inst.KV().Key( + types.KeyRole, + types.KeyLeases, + "b2:b7:86:2c:d3:fa", + ).String(), + "{", + )) + + req := role.NewRequest4(&dhcpv4.DHCPv4{ + ClientHWAddr: []byte{0xb2, 0xb7, 0x86, 0x2c, 0xd3, 0xfa}, + }) + + assert.Nil(t, role.FindLeaseInStore(req)) +}