Skip to content

Commit 799eaa3

Browse files
Merge pull request #84 from peachest/fix/grpc-register-after-serve
Fix device plugin grpc properly stop and restart after kubelet restart
2 parents a27ce4a + 73996f2 commit 799eaa3

3 files changed

Lines changed: 242 additions & 4 deletions

File tree

internal/server/register.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
/*
2+
* Copyright 2024 The HAMi Authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
117
package server
218

319
import (
@@ -19,6 +35,8 @@ import (
1935
)
2036

2137
func (ps *PluginServer) watchAndRegister() {
38+
ps.wg.Add(1)
39+
defer ps.wg.Done()
2240
timer := time.After(1 * time.Second)
2341
for {
2442
select {
@@ -110,6 +128,9 @@ func (ps *PluginServer) getDeviceNetworkID(idx int, deviceType string) (int, err
110128
}
111129

112130
func (ps *PluginServer) registerKubelet() error {
131+
if ps.registerKubeletFunc != nil {
132+
return ps.registerKubeletFunc()
133+
}
113134
conn, err := ps.dial(v1beta1.KubeletSocket, 5*time.Second)
114135
if err != nil {
115136
return err
@@ -135,6 +156,9 @@ func (ps *PluginServer) registerKubelet() error {
135156
}
136157

137158
func (ps *PluginServer) dial(unixSocketPath string, timeout time.Duration) (*grpc.ClientConn, error) {
159+
if ps.dialFunc != nil {
160+
return ps.dialFunc(unixSocketPath, timeout)
161+
}
138162
ctx, cancel := context.WithTimeout(context.Background(), timeout)
139163
defer cancel()
140164
c, _ := grpc.NewClient(unixSocketPath,

internal/server/server.go

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"net"
2424
"os"
2525
"path"
26+
"sync"
2627
"time"
2728

2829
"google.golang.org/grpc"
@@ -64,6 +65,12 @@ type PluginServer struct {
6465
stopCh chan interface{}
6566
healthCh chan int32
6667
checkIdleVNPUInterval int
68+
wg sync.WaitGroup
69+
70+
// test hooks — injected by tests to avoid real socket/kubelet dependencies
71+
dialFunc func(unixSocketPath string, timeout time.Duration) (*grpc.ClientConn, error)
72+
registerKubeletFunc func() error
73+
prepareHostResourcesFunc func() error
6774
}
6875

6976
type RuntimeInfo struct {
@@ -82,7 +89,6 @@ func NewPluginServer(mgr manager.Manager, nodeName string, checkIdleVNPUInterval
8289
handshakeAnno: fmt.Sprintf("hami.io/node-handshake-%s", commonWord),
8390
allocAnno: fmt.Sprintf("huawei.com/%s", commonWord),
8491
toAllocDeviceAnno: fmt.Sprintf("hami.io/%s-devices-to-allocate", commonWord),
85-
grpcServer: grpc.NewServer(),
8692
mgr: mgr,
8793
socket: path.Join(v1beta1.DevicePluginPath, fmt.Sprintf("%s.sock", commonWord)),
8894
stopCh: make(chan interface{}),
@@ -94,14 +100,25 @@ func NewPluginServer(mgr manager.Manager, nodeName string, checkIdleVNPUInterval
94100
return server, nil
95101
}
96102

103+
// prepareHostResources wraps the package-level prepareHostResources() to
104+
// allow test injection via the prepareHostResourcesFunc hook.
105+
func (ps *PluginServer) prepareHostResources() error {
106+
if ps.prepareHostResourcesFunc != nil {
107+
return ps.prepareHostResourcesFunc()
108+
}
109+
return prepareHostResources()
110+
}
111+
97112
func (ps *PluginServer) Start() error {
98113
// Automatically prepare host environment when the plugin starts
99-
if err := prepareHostResources(); err != nil {
114+
if err := ps.prepareHostResources(); err != nil {
100115
klog.Errorf("Failed to prepare host resources: %v. vNPU core functionality will be impaired.", err)
101116
return err
102117
}
103118

104119
ps.stopCh = make(chan interface{})
120+
ps.grpcServer = grpc.NewServer()
121+
105122
err := ps.mgr.UpdateDevice()
106123
if err != nil {
107124
return err
@@ -120,6 +137,8 @@ func (ps *PluginServer) Start() error {
120137
}
121138

122139
func (ps *PluginServer) startPeriodicCheckIdleVNPUs() {
140+
ps.wg.Add(1)
141+
defer ps.wg.Done()
123142
ticker := time.NewTicker(time.Duration(ps.checkIdleVNPUInterval) * time.Second)
124143
defer ticker.Stop()
125144
for {
@@ -137,8 +156,19 @@ func (ps *PluginServer) startPeriodicCheckIdleVNPUs() {
137156
}
138157

139158
func (ps *PluginServer) Stop() error {
140-
close(ps.stopCh)
141-
ps.grpcServer.Stop()
159+
if ps.stopCh != nil {
160+
select {
161+
case <-ps.stopCh:
162+
// already closed; no-op
163+
default:
164+
close(ps.stopCh)
165+
}
166+
}
167+
if ps.grpcServer != nil {
168+
ps.grpcServer.Stop()
169+
}
170+
ps.wg.Wait()
171+
_ = os.Remove(ps.socket)
142172
return nil
143173
}
144174

@@ -158,7 +188,9 @@ func (ps *PluginServer) serve() error {
158188
}
159189
v1beta1.RegisterDevicePluginServer(ps.grpcServer, ps)
160190
resourceName := ps.mgr.ResourceName()
191+
ps.wg.Add(1)
161192
go func() {
193+
defer ps.wg.Done()
162194
lastCrashTime := time.Now()
163195
restartCount := 0
164196
for {

internal/server/server_test.go

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@ import (
2020
"context"
2121
"encoding/json"
2222
"fmt"
23+
"os"
24+
"path"
2325
"strings"
2426
"testing"
2527

28+
"google.golang.org/grpc/grpclog"
2629
v1 "k8s.io/api/core/v1"
2730
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2831
"k8s.io/client-go/kubernetes/fake"
@@ -762,3 +765,182 @@ func TestCleanupIdleVNPUs(t *testing.T) {
762765
})
763766
}
764767
}
768+
769+
// ============================================================================
770+
// gRPC restart tests
771+
// ============================================================================
772+
773+
// panicOnFatalLogger is a gRPC logger that converts Fatalf calls to panics.
774+
// This allows tests to verify that gRPC does NOT call Fatalf (which would
775+
// otherwise call os.Exit(1) and abort the test process).
776+
//
777+
// Usage:
778+
//
779+
// defer grpclog.SetLoggerV2(grpclog.NewLoggerV2(os.Stderr, os.Stderr, os.Stderr))
780+
// grpclog.SetLoggerV2(newPanicOnFatalLogger())
781+
type panicOnFatalLogger struct {
782+
inner grpclog.LoggerV2
783+
}
784+
785+
func newPanicOnFatalLogger() *panicOnFatalLogger {
786+
return &panicOnFatalLogger{
787+
inner: grpclog.NewLoggerV2(os.Stderr, os.Stderr, os.Stderr),
788+
}
789+
}
790+
791+
var _ grpclog.LoggerV2 = (*panicOnFatalLogger)(nil)
792+
793+
func (l *panicOnFatalLogger) Info(args ...interface{}) { l.inner.Info(args...) }
794+
func (l *panicOnFatalLogger) Infoln(args ...interface{}) { l.inner.Infoln(args...) }
795+
func (l *panicOnFatalLogger) Infof(format string, args ...interface{}) {
796+
l.inner.Infof(format, args...)
797+
}
798+
func (l *panicOnFatalLogger) Warning(args ...interface{}) { l.inner.Warning(args...) }
799+
func (l *panicOnFatalLogger) Warningln(args ...interface{}) { l.inner.Warningln(args...) }
800+
func (l *panicOnFatalLogger) Warningf(format string, args ...interface{}) {
801+
l.inner.Warningf(format, args...)
802+
}
803+
func (l *panicOnFatalLogger) Error(args ...interface{}) { l.inner.Error(args...) }
804+
func (l *panicOnFatalLogger) Errorln(args ...interface{}) { l.inner.Errorln(args...) }
805+
func (l *panicOnFatalLogger) Errorf(format string, args ...interface{}) {
806+
l.inner.Errorf(format, args...)
807+
}
808+
func (l *panicOnFatalLogger) V(level int) bool { return l.inner.V(level) }
809+
810+
func (l *panicOnFatalLogger) Fatalf(format string, args ...interface{}) {
811+
panic(fmt.Sprintf("grpc FATAL: "+format, args...))
812+
}
813+
814+
func (l *panicOnFatalLogger) Fatalln(args ...interface{}) {
815+
panic(fmt.Sprintf("grpc FATAL: %v", fmt.Sprintln(args...)))
816+
}
817+
818+
func (l *panicOnFatalLogger) Fatal(args ...interface{}) {
819+
panic(fmt.Sprintf("grpc FATAL: %v", fmt.Sprint(args...)))
820+
}
821+
822+
// setupRestartablePluginServer creates a PluginServer with all test hooks
823+
// injected so that Start()/Stop() work without real socket files or a kubelet.
824+
func setupRestartablePluginServer(t *testing.T) *PluginServer {
825+
t.Helper()
826+
827+
ps := &PluginServer{
828+
commonWord: "test-ascend",
829+
registerAnno: "hami.io/node-register-test-ascend",
830+
handshakeAnno: "hami.io/node-handshake-test-ascend",
831+
allocAnno: "huawei.com/test-ascend",
832+
toAllocDeviceAnno: "hami.io/test-ascend-devices-to-allocate",
833+
mgr: &FakeManager{ResourceNameFunc: func() string { return "test-ascend" }},
834+
socket: path.Join(t.TempDir(), "test-ascend.sock"),
835+
stopCh: make(chan interface{}),
836+
healthCh: make(chan int32),
837+
checkIdleVNPUInterval: 3600,
838+
dialFunc: nil,
839+
registerKubeletFunc: func() error {
840+
return nil
841+
},
842+
prepareHostResourcesFunc: func() error {
843+
return nil
844+
},
845+
}
846+
return ps
847+
}
848+
849+
// TestGrpcServer_RestartDoesNotPanic verifies that a single Stop+Start cycle
850+
// does not trigger the gRPC "RegisterService after Serve" fatal error.
851+
func TestGrpcServer_RestartDoesNotPanic(t *testing.T) {
852+
defer grpclog.SetLoggerV2(grpclog.NewLoggerV2(os.Stderr, os.Stderr, os.Stderr))
853+
grpclog.SetLoggerV2(newPanicOnFatalLogger())
854+
855+
ps := setupRestartablePluginServer(t)
856+
857+
// First Start
858+
if err := ps.Start(); err != nil {
859+
t.Fatalf("first Start() failed: %v", err)
860+
}
861+
862+
// Stop
863+
if err := ps.Stop(); err != nil {
864+
t.Fatalf("Stop() failed: %v", err)
865+
}
866+
867+
// Second Start — this must not trigger grpc Fatalf
868+
if err := ps.Start(); err != nil {
869+
t.Fatalf("second Start() after restart failed: %v", err)
870+
}
871+
872+
// Cleanup
873+
if err := ps.Stop(); err != nil {
874+
t.Fatalf("final Stop() failed: %v", err)
875+
}
876+
}
877+
878+
// TestGrpcServer_MultipleRestarts verifies that the server can survive
879+
// multiple Stop+Start cycles without panic.
880+
func TestGrpcServer_MultipleRestarts(t *testing.T) {
881+
defer grpclog.SetLoggerV2(grpclog.NewLoggerV2(os.Stderr, os.Stderr, os.Stderr))
882+
grpclog.SetLoggerV2(newPanicOnFatalLogger())
883+
884+
ps := setupRestartablePluginServer(t)
885+
886+
for i := 0; i < 5; i++ {
887+
if err := ps.Start(); err != nil {
888+
t.Fatalf("Start() iteration %d failed: %v", i, err)
889+
}
890+
if err := ps.Stop(); err != nil {
891+
t.Fatalf("Stop() iteration %d failed: %v", i, err)
892+
}
893+
}
894+
}
895+
896+
// TestGrpcServer_StopWithoutStart verifies that Stop() is safe when
897+
// Start() was never called (no goroutines to wait for).
898+
func TestGrpcServer_StopWithoutStart(t *testing.T) {
899+
ps := setupRestartablePluginServer(t)
900+
if err := ps.Stop(); err != nil {
901+
t.Fatalf("Stop() without Start() should be safe: %v", err)
902+
}
903+
}
904+
905+
// TestGrpcServer_DoubleStop verifies that calling Stop() twice is safe.
906+
func TestGrpcServer_DoubleStop(t *testing.T) {
907+
ps := setupRestartablePluginServer(t)
908+
909+
if err := ps.Start(); err != nil {
910+
t.Fatalf("Start() failed: %v", err)
911+
}
912+
913+
if err := ps.Stop(); err != nil {
914+
t.Fatalf("first Stop() failed: %v", err)
915+
}
916+
917+
if err := ps.Stop(); err != nil {
918+
t.Fatalf("second Stop() should be safe: %v", err)
919+
}
920+
}
921+
922+
// TestGrpcServer_StopWaitForAllGoroutines verifies that Stop() returns
923+
// only after all goroutines have exited. We verify this indirectly by
924+
// checking that Start() after Stop() does not race (goroutine leak would
925+
// manifest as stale channel reads).
926+
func TestGrpcServer_StopWaitForAllGoroutines(t *testing.T) {
927+
ps := setupRestartablePluginServer(t)
928+
929+
if err := ps.Start(); err != nil {
930+
t.Fatalf("Start() failed: %v", err)
931+
}
932+
933+
if err := ps.Stop(); err != nil {
934+
t.Fatalf("Stop() failed: %v", err)
935+
}
936+
937+
// bgWG and serveWG should be zero after Stop() returns.
938+
// A new cycle confirms there is no deadlock or hang.
939+
if err := ps.Start(); err != nil {
940+
t.Fatalf("Start() after Stop() failed: %v", err)
941+
}
942+
943+
if err := ps.Stop(); err != nil {
944+
t.Fatalf("Stop() failed: %v", err)
945+
}
946+
}

0 commit comments

Comments
 (0)