diff --git a/configs/configs.go b/configs/configs.go index e1178d8d..de6561bb 100644 --- a/configs/configs.go +++ b/configs/configs.go @@ -6,6 +6,7 @@ import ( "io" "os" "path/filepath" + "sync" "time" "github.com/xinliangnote/go-gin-api/pkg/env" @@ -16,6 +17,7 @@ import ( ) var config = new(Config) +var configLock sync.RWMutex type Config struct { MySQL struct { @@ -128,6 +130,8 @@ func init() { viper.WatchConfig() viper.OnConfigChange(func(e fsnotify.Event) { + configLock.Lock() + defer configLock.Unlock() if err := viper.Unmarshal(config); err != nil { panic(err) } @@ -135,5 +139,7 @@ func init() { } func Get() Config { + configLock.RLock() + defer configLock.RUnlock() return *config } diff --git a/configs/configs_test.go b/configs/configs_test.go new file mode 100644 index 00000000..42dcc4a3 --- /dev/null +++ b/configs/configs_test.go @@ -0,0 +1,32 @@ +package configs + +import ( + "testing" + + "github.com/spf13/viper" +) + +func TestRace(t *testing.T) { + done := make(chan bool) + + go func() { + for i := 0; i < 100; i++ { + Get() + } + done <- true + }() + + go func() { + for i := 0; i < 100; i++ { + configLock.Lock() + if err := viper.Unmarshal(config); err != nil { + // ignore error + } + configLock.Unlock() + } + done <- true + }() + + <-done + <-done +} diff --git a/pkg/env/env.go b/pkg/env/env.go index 615bb8cd..a79860bb 100644 --- a/pkg/env/env.go +++ b/pkg/env/env.go @@ -3,6 +3,7 @@ package env import ( "flag" "fmt" + "os" "strings" ) @@ -53,6 +54,7 @@ func (e *environment) IsPro() bool { func (e *environment) t() {} func init() { + flag.CommandLine.Init(os.Args[0], flag.ContinueOnError) env := flag.String("env", "", "请输入运行环境:\n dev:开发环境\n fat:测试环境\n uat:预上线环境\n pro:正式环境\n") flag.Parse()