Skip to content

Commit 631c8da

Browse files
committed
Refactor LoadHostKey function: streamline host key loading and improve error handling
1 parent d182f82 commit 631c8da

2 files changed

Lines changed: 22 additions & 38 deletions

File tree

internal/proxy/conn.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,6 @@ type Conn struct {
1313
sshCfg *ssh.ServerConfig
1414
}
1515

16-
func (c *Conn) Close() {
17-
if c.client != nil {
18-
c.client.Close()
19-
}
20-
if c.target != nil {
21-
c.target.Close()
22-
}
23-
}
24-
2516
func (c *Conn) Serve() {
2617
buf := make([]byte, 4096)
2718
n, err := c.client.Read(buf)
@@ -54,3 +45,12 @@ func (c *Conn) Proxy() {
5445
}()
5546
io.Copy(c.client, c.target)
5647
}
48+
49+
func (c *Conn) Close() {
50+
if c.client != nil {
51+
c.client.Close()
52+
}
53+
if c.target != nil {
54+
c.target.Close()
55+
}
56+
}

internal/ssh/keys.go

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,32 @@
11
package ssh
22

33
import (
4-
"crypto/rand"
5-
"crypto/rsa"
6-
"crypto/x509"
7-
"encoding/pem"
84
"fmt"
95
"os"
6+
"path/filepath"
107

118
"golang.org/x/crypto/ssh"
129
)
1310

14-
const HostKeyPath = "/etc/ssh-ify/host_key"
15-
1611
func LoadHostKey() (ssh.Signer, error) {
17-
privateBytes, err := os.ReadFile(HostKeyPath)
12+
home, err := os.UserHomeDir()
1813
if err != nil {
19-
if err := os.MkdirAll("/etc/ssh-ify", 0700); err != nil {
20-
return nil, fmt.Errorf("failed to create config directory: %v", err)
21-
}
22-
if err := GenerateHostKey(HostKeyPath); err != nil {
23-
return nil, fmt.Errorf("failed to generate host key: %v", err)
24-
}
25-
privateBytes, err = os.ReadFile(HostKeyPath)
26-
if err != nil {
27-
return nil, fmt.Errorf("failed to read generated host key: %v", err)
28-
}
14+
return nil, fmt.Errorf("failed to get home directory: %w", err)
2915
}
3016

31-
return ssh.ParsePrivateKey(privateBytes)
32-
}
33-
34-
func GenerateHostKey(keyPath string) error {
35-
privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
17+
keyPath := filepath.Join(home, ".ssh", "id_rsa")
18+
privateBytes, err := os.ReadFile(keyPath)
3619
if err != nil {
37-
return err
20+
if os.IsNotExist(err) {
21+
return nil, fmt.Errorf("host key not found, please generate one: ssh-keygen -t rsa -b 4096 -f %s", keyPath)
22+
}
23+
return nil, fmt.Errorf("failed to read host key: %w", err)
3824
}
3925

40-
privDER := x509.MarshalPKCS1PrivateKey(privateKey)
41-
privBlock := &pem.Block{
42-
Type: "RSA PRIVATE KEY",
43-
Bytes: privDER,
26+
signer, err := ssh.ParsePrivateKey(privateBytes)
27+
if err != nil {
28+
return nil, fmt.Errorf("failed to parse host key: %w", err)
4429
}
45-
privateBytes := pem.EncodeToMemory(privBlock)
4630

47-
return os.WriteFile(keyPath, privateBytes, 0600)
31+
return signer, nil
4832
}

0 commit comments

Comments
 (0)