11package internal
22
33import (
4+ "context"
45 "crypto/aes"
56 "crypto/cipher"
67 "crypto/rand"
78 "crypto/sha256"
8- "encoding/base64"
99 "errors"
1010 "fmt"
1111 "io"
1212
13+ "kubecloud/internal/logger"
14+ "kubecloud/models"
15+
16+ "sync"
17+
1318 "golang.org/x/crypto/argon2"
1419)
1520
@@ -56,41 +61,36 @@ func (cm *CryptoManager) deriveKey(passphrase string, userIdentifier string) ([]
5661 return key , nil
5762}
5863
59- func (cm * CryptoManager ) encrypt (plainText string , key []byte ) (string , error ) {
64+ func (cm * CryptoManager ) encrypt (plainText string , key []byte ) ([] byte , error ) {
6065 if len (key ) != 32 {
61- return "" , ErrInvalidKeyLength
66+ return nil , ErrInvalidKeyLength
6267 }
6368
6469 block , err := aes .NewCipher (key )
6570 if err != nil {
66- return "" , fmt .Errorf ("failed to create cipher: %w" , err )
71+ return nil , fmt .Errorf ("failed to create cipher: %w" , err )
6772 }
6873
6974 aesGCM , err := cipher .NewGCM (block )
7075 if err != nil {
71- return "" , fmt .Errorf ("failed to create GCM: %w" , err )
76+ return nil , fmt .Errorf ("failed to create GCM: %w" , err )
7277 }
7378
7479 nonce := make ([]byte , aesGCM .NonceSize ())
7580 if _ , err = io .ReadFull (rand .Reader , nonce ); err != nil {
76- return "" , fmt .Errorf ("failed to generate nonce: %w" , err )
81+ return nil , fmt .Errorf ("failed to generate nonce: %w" , err )
7782 }
7883
7984 ciphertext := aesGCM .Seal (nonce , nonce , []byte (plainText ), nil )
8085
81- return base64 . StdEncoding . EncodeToString ( ciphertext ) , nil
86+ return ciphertext , nil
8287}
8388
84- func (cm * CryptoManager ) decrypt (encryptedText string , key []byte ) (string , error ) {
89+ func (cm * CryptoManager ) decrypt (encryptedBytes [] byte , key []byte ) (string , error ) {
8590 if len (key ) != 32 {
8691 return "" , ErrInvalidKeyLength
8792 }
8893
89- ciphertext , err := base64 .StdEncoding .DecodeString (encryptedText )
90- if err != nil {
91- return "" , fmt .Errorf ("failed to decode base64: %w" , err )
92- }
93-
9494 block , err := aes .NewCipher (key )
9595 if err != nil {
9696 return "" , fmt .Errorf ("failed to create cipher: %w" , err )
@@ -102,11 +102,11 @@ func (cm *CryptoManager) decrypt(encryptedText string, key []byte) (string, erro
102102 }
103103
104104 nonceSize := aesGCM .NonceSize ()
105- if len (ciphertext ) < nonceSize {
105+ if len (encryptedBytes ) < nonceSize {
106106 return "" , ErrInvalidData
107107 }
108108
109- nonce , ciphertext := ciphertext [:nonceSize ], ciphertext [nonceSize :]
109+ nonce , ciphertext := encryptedBytes [:nonceSize ], encryptedBytes [nonceSize :]
110110
111111 plaintext , err := aesGCM .Open (nil , nonce , ciphertext , nil )
112112 if err != nil {
@@ -124,18 +124,73 @@ func (cm *CryptoManager) getMnemonicKey(userAddress string) ([]byte, error) {
124124 return cm .deriveKey (cm .config .MnemonicEncryptionPassphrase , userAddress )
125125}
126126
127- func (cm * CryptoManager ) EncryptMnemonic (plainText string , userAddress string ) (string , error ) {
127+ func (cm * CryptoManager ) EncryptMnemonic (plainText string , userAddress string ) ([] byte , error ) {
128128 key , err := cm .getMnemonicKey (userAddress )
129129 if err != nil {
130- return "" , err
130+ return nil , err
131131 }
132132 return cm .encrypt (plainText , key )
133133}
134134
135- func (cm * CryptoManager ) DecryptMnemonic (encryptedText string , userAddress string ) (string , error ) {
135+ func (cm * CryptoManager ) DecryptMnemonic (encryptedBytes [] byte , userAddress string ) (string , error ) {
136136 key , err := cm .getMnemonicKey (userAddress )
137137 if err != nil {
138138 return "" , err
139139 }
140- return cm .decrypt (encryptedText , key )
140+ return cm .decrypt (encryptedBytes , key )
141+ }
142+
143+ func (cm * CryptoManager ) EnsureMnemonicsEncrypted (ctx context.Context , db models.DB ) error {
144+ users , err := db .ListAllUsers ()
145+ if err != nil {
146+ return fmt .Errorf ("ensure encryption: list users failed: %w" , err )
147+ }
148+
149+ const maxWorkers = 16
150+ sem := make (chan struct {}, maxWorkers )
151+ var wg sync.WaitGroup
152+
153+ for i := range users {
154+ u := users [i ]
155+ wg .Add (1 )
156+ sem <- struct {}{}
157+ go func (u models.User ) {
158+ defer wg .Done ()
159+ defer func () { <- sem }()
160+
161+ select {
162+ case <- ctx .Done ():
163+ return
164+ default :
165+ }
166+
167+ if len (u .Mnemonic ) == 0 {
168+ return
169+ }
170+
171+ // Derive account address if missing and mnemonic appears to be plaintext
172+ if len (u .AccountAddress ) == 0 {
173+ addr , err := AccountFromMnemonic (string (u .Mnemonic ))
174+ if err != nil {
175+ logger .GetLogger ().Error ().Err (err ).Int ("user_id" , u .ID ).Msg ("failed to derive account address from mnemonic" )
176+ return
177+ }
178+ u .AccountAddress = addr
179+ }
180+
181+ if _ , err := cm .DecryptMnemonic (u .Mnemonic , u .AccountAddress ); err == nil {
182+ return
183+ }
184+
185+ encryptedMnemonic , err := cm .EncryptMnemonic (string (u .Mnemonic ), u .AccountAddress )
186+ if err != nil {
187+ return
188+ }
189+ u .Mnemonic = encryptedMnemonic
190+ _ = db .UpdateUserByID (& u )
191+ }(u )
192+ }
193+
194+ wg .Wait ()
195+ return nil
141196}
0 commit comments