Skip to content

Commit d48e666

Browse files
JAORMXclaude
andcommitted
Add SSH host key pinning support
Add GenerateHostKeyPair for in-memory ECDSA P-256 host key generation, WithHostKey client option for ssh.FixedHostKey verification, and HostKey field on sshd.Config for injected host keys. The guest boot sequence loads a host key from disk, deletes the file, and passes the signer to the SSH server. All changes are backward-compatible with nil/zero-value fallbacks. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent bd53f6b commit d48e666

7 files changed

Lines changed: 262 additions & 13 deletions

File tree

guest/boot/boot.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,23 @@ func Run(logger *slog.Logger, opts ...Option) (shutdown func(), err error) {
8686
return nil, fmt.Errorf("parsing authorized keys: %w", err)
8787
}
8888

89+
// 7b. Load injected host key (if present). The key is deleted from
90+
// disk after loading into memory so it cannot be read by the sandbox
91+
// user. If the file does not exist, hostKeySigner remains nil and the
92+
// SSH server will generate an ephemeral key.
93+
var hostKeySigner ssh.Signer
94+
if hostKeyPEM, readErr := os.ReadFile(cfg.sshHostKeyPath); readErr == nil {
95+
signer, parseErr := ssh.ParsePrivateKey(hostKeyPEM)
96+
if parseErr != nil {
97+
logger.Warn("failed to parse injected host key, falling back to ephemeral",
98+
"path", cfg.sshHostKeyPath, "error", parseErr)
99+
} else {
100+
hostKeySigner = signer
101+
logger.Info("loaded injected SSH host key", "path", cfg.sshHostKeyPath)
102+
}
103+
_ = os.Remove(cfg.sshHostKeyPath)
104+
}
105+
89106
// 8. Drop unneeded capabilities from the bounding set.
90107
logger.Info("dropping unnecessary capabilities")
91108
if err := harden.DropBoundingCaps(
@@ -113,6 +130,7 @@ func Run(logger *slog.Logger, opts ...Option) (shutdown func(), err error) {
113130
DefaultShell: cfg.userShell,
114131
DefaultWorkDir: cfg.workspaceMountPoint,
115132
AgentForwarding: cfg.sshAgentForwarding,
133+
HostKey: hostKeySigner,
116134
Logger: logger,
117135
}
118136
srv, err := sshd.New(sshdCfg)

guest/boot/options.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type config struct {
2323
mountRetries int
2424
sshPort int
2525
sshKeysPath string
26+
sshHostKeyPath string
2627
envFilePath string
2728
userName string
2829
userHome string
@@ -42,6 +43,7 @@ func defaultConfig() *config {
4243
mountRetries: 5,
4344
sshPort: 22,
4445
sshKeysPath: "/home/sandbox/.ssh/authorized_keys",
46+
sshHostKeyPath: "/etc/ssh/ssh_host_ecdsa_key",
4547
envFilePath: "/etc/sandbox-env",
4648
userName: "sandbox",
4749
userHome: "/home/sandbox",
@@ -98,6 +100,14 @@ func WithLockdownRoot(enabled bool) Option {
98100
return optionFunc(func(c *config) { c.lockdownRoot = enabled })
99101
}
100102

103+
// WithSSHHostKeyPath sets the path to a PEM-encoded host private key
104+
// injected into the guest rootfs. If the file exists at boot, the key
105+
// is loaded into memory, the file is deleted, and the key is used as
106+
// the SSH server's host key (enabling client-side pinning).
107+
func WithSSHHostKeyPath(path string) Option {
108+
return optionFunc(func(c *config) { c.sshHostKeyPath = path })
109+
}
110+
101111
// WithSSHAgentForwarding controls whether the SSH server supports
102112
// agent forwarding. When enabled and the client requests it, the
103113
// server creates a Unix socket and sets SSH_AUTH_SOCK for the session.

guest/sshd/server.go

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ type Config struct {
6666
// per-session agent sockets.
6767
AgentForwarding bool
6868

69+
// HostKey is an optional pre-generated host key signer. When non-nil,
70+
// the server uses this key instead of generating an ephemeral one.
71+
// This enables host key pinning by the client.
72+
HostKey ssh.Signer
73+
6974
// Logger is the structured logger. If nil, slog.Default() is used.
7075
Logger *slog.Logger
7176
}
@@ -112,18 +117,23 @@ func New(cfg Config) (*Server, error) {
112117
},
113118
}
114119

115-
// Generate an ephemeral ECDSA P-256 host key.
116-
hostKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
117-
if err != nil {
118-
return nil, fmt.Errorf("generate host key: %w", err)
119-
}
120+
if cfg.HostKey != nil {
121+
// Use the injected host key (enables client-side pinning).
122+
sshCfg.AddHostKey(cfg.HostKey)
123+
} else {
124+
// Generate an ephemeral ECDSA P-256 host key.
125+
hostKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
126+
if err != nil {
127+
return nil, fmt.Errorf("generate host key: %w", err)
128+
}
120129

121-
signer, err := ssh.NewSignerFromKey(hostKey)
122-
if err != nil {
123-
return nil, fmt.Errorf("create host key signer: %w", err)
124-
}
130+
signer, err := ssh.NewSignerFromKey(hostKey)
131+
if err != nil {
132+
return nil, fmt.Errorf("create host key signer: %w", err)
133+
}
125134

126-
sshCfg.AddHostKey(signer)
135+
sshCfg.AddHostKey(signer)
136+
}
127137

128138
return &Server{
129139
cfg: cfg,

ssh/client.go

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,17 @@ func (s *sshSessionAdapter) SetStderr(w io.Writer) { s.sess.Stderr = w }
4545
func (s *sshSessionAdapter) SetStdin(r io.Reader) { s.sess.Stdin = r }
4646
func (s *sshSessionAdapter) Close() error { return s.sess.Close() }
4747

48+
// ClientOption configures optional Client behavior.
49+
type ClientOption func(*Client)
50+
51+
// WithHostKey pins the expected SSH host key. When set, the client uses
52+
// ssh.FixedHostKey for host key verification instead of accepting any key.
53+
func WithHostKey(pubKey ssh.PublicKey) ClientOption {
54+
return func(c *Client) {
55+
c.expectedHostKey = pubKey
56+
}
57+
}
58+
4859
// Client provides a high-level SSH interface for communicating with a
4960
// microVM guest.
5061
type Client struct {
@@ -53,6 +64,11 @@ type Client struct {
5364
user string
5465
keyPath string
5566

67+
// expectedHostKey, when non-nil, enables host key pinning via
68+
// ssh.FixedHostKey. When nil, InsecureIgnoreHostKey is used for
69+
// backward compatibility.
70+
expectedHostKey ssh.PublicKey
71+
5672
// readFile reads a file from disk. Defaults to os.ReadFile.
5773
// Injected for testability.
5874
readFile func(string) ([]byte, error)
@@ -68,7 +84,8 @@ type Client struct {
6884

6985
// NewClient creates a new SSH Client configured to connect to the given
7086
// host and port using the specified user and private key file.
71-
func NewClient(host string, port uint16, user, keyPath string) *Client {
87+
// Options are applied after defaults to allow host key pinning, etc.
88+
func NewClient(host string, port uint16, user, keyPath string, opts ...ClientOption) *Client {
7289
c := &Client{
7390
host: host,
7491
port: port,
@@ -77,6 +94,9 @@ func NewClient(host string, port uint16, user, keyPath string) *Client {
7794
readFile: os.ReadFile,
7895
writeFile: os.WriteFile,
7996
}
97+
for _, o := range opts {
98+
o(c)
99+
}
80100
c.createSession = func(ctx context.Context) (remoteSession, func(), error) {
81101
sess, client, err := c.newSession(ctx)
82102
if err != nil {
@@ -278,13 +298,20 @@ func (c *Client) sshConfig() (*ssh.ClientConfig, error) {
278298
return nil, fmt.Errorf("parse SSH key %s: %w", c.keyPath, err)
279299
}
280300

301+
var hostKeyCallback ssh.HostKeyCallback
302+
if c.expectedHostKey != nil {
303+
hostKeyCallback = ssh.FixedHostKey(c.expectedHostKey)
304+
} else {
305+
//nolint:gosec // Backward compat when host key not available.
306+
hostKeyCallback = ssh.InsecureIgnoreHostKey()
307+
}
308+
281309
return &ssh.ClientConfig{
282310
User: c.user,
283311
Auth: []ssh.AuthMethod{
284312
ssh.PublicKeys(signer),
285313
},
286-
//nolint:gosec // We trust the VM we just created; host key checking is unnecessary.
287-
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
314+
HostKeyCallback: hostKeyCallback,
288315
Timeout: defaultSSHTimeout,
289316
}, nil
290317
}

ssh/client_test.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,3 +491,109 @@ func TestSSHConfig_ValidKeyParsing(t *testing.T) {
491491
assert.Equal(t, "testuser", config.User)
492492
assert.NotEmpty(t, config.Auth)
493493
}
494+
495+
func TestWithHostKey_SetsExpectedHostKey(t *testing.T) {
496+
t.Parallel()
497+
498+
_, pubKey, err := GenerateHostKeyPair()
499+
require.NoError(t, err)
500+
501+
c := NewClient("127.0.0.1", 22, "user", "/tmp/key", WithHostKey(pubKey))
502+
assert.NotNil(t, c.expectedHostKey)
503+
assert.Equal(t, pubKey.Marshal(), c.expectedHostKey.Marshal())
504+
}
505+
506+
func TestWithHostKey_NilFallback(t *testing.T) {
507+
t.Parallel()
508+
509+
// No options → expectedHostKey should be nil.
510+
c := NewClient("127.0.0.1", 22, "user", "/tmp/key")
511+
assert.Nil(t, c.expectedHostKey)
512+
}
513+
514+
func TestSSHConfig_WithHostKey_AcceptsMatchingKey(t *testing.T) {
515+
t.Parallel()
516+
517+
keyDir := t.TempDir()
518+
privKeyPath, _, err := GenerateKeyPair(keyDir)
519+
require.NoError(t, err)
520+
521+
_, hostPubKey, err := GenerateHostKeyPair()
522+
require.NoError(t, err)
523+
524+
c := &Client{
525+
host: "testhost",
526+
port: 2222,
527+
user: "testuser",
528+
keyPath: privKeyPath,
529+
readFile: os.ReadFile,
530+
expectedHostKey: hostPubKey,
531+
}
532+
533+
config, err := c.sshConfig()
534+
require.NoError(t, err)
535+
require.NotNil(t, config.HostKeyCallback)
536+
537+
// Matching key should be accepted.
538+
err = config.HostKeyCallback("testhost:2222", nil, hostPubKey)
539+
assert.NoError(t, err, "matching host key should be accepted")
540+
}
541+
542+
func TestSSHConfig_WithHostKey_RejectsMismatchedKey(t *testing.T) {
543+
t.Parallel()
544+
545+
keyDir := t.TempDir()
546+
privKeyPath, _, err := GenerateKeyPair(keyDir)
547+
require.NoError(t, err)
548+
549+
_, hostPubKey, err := GenerateHostKeyPair()
550+
require.NoError(t, err)
551+
552+
// Generate a different key to simulate an impersonator.
553+
_, wrongPubKey, err := GenerateHostKeyPair()
554+
require.NoError(t, err)
555+
556+
c := &Client{
557+
host: "testhost",
558+
port: 2222,
559+
user: "testuser",
560+
keyPath: privKeyPath,
561+
readFile: os.ReadFile,
562+
expectedHostKey: hostPubKey,
563+
}
564+
565+
config, err := c.sshConfig()
566+
require.NoError(t, err)
567+
require.NotNil(t, config.HostKeyCallback)
568+
569+
// Mismatched key should be rejected.
570+
err = config.HostKeyCallback("testhost:2222", nil, wrongPubKey)
571+
assert.Error(t, err, "mismatched host key should be rejected")
572+
}
573+
574+
func TestSSHConfig_WithoutHostKey_AcceptsAnyKey(t *testing.T) {
575+
t.Parallel()
576+
577+
keyDir := t.TempDir()
578+
privKeyPath, _, err := GenerateKeyPair(keyDir)
579+
require.NoError(t, err)
580+
581+
c := &Client{
582+
host: "testhost",
583+
port: 2222,
584+
user: "testuser",
585+
keyPath: privKeyPath,
586+
readFile: os.ReadFile,
587+
}
588+
589+
config, err := c.sshConfig()
590+
require.NoError(t, err)
591+
require.NotNil(t, config.HostKeyCallback)
592+
593+
// Without host key pinning, any key should be accepted.
594+
_, anyPubKey, err := GenerateHostKeyPair()
595+
require.NoError(t, err)
596+
597+
err = config.HostKeyCallback("testhost:2222", nil, anyPubKey)
598+
assert.NoError(t, err, "insecure callback should accept any key")
599+
}

ssh/keygen.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,33 @@ func GenerateKeyPair(keyDir string) (privateKeyPath, publicKeyPath string, err e
7373
return privateKeyPath, publicKeyPath, nil
7474
}
7575

76+
// GenerateHostKeyPair generates an ECDSA P-256 host key pair in memory.
77+
// It returns the private key as PEM-encoded bytes and the corresponding
78+
// SSH public key. No files are written to disk.
79+
func GenerateHostKeyPair() (privateKeyPEM []byte, publicKey ssh.PublicKey, err error) {
80+
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
81+
if err != nil {
82+
return nil, nil, fmt.Errorf("generate host ECDSA key: %w", err)
83+
}
84+
85+
derBytes, err := x509.MarshalECPrivateKey(key)
86+
if err != nil {
87+
return nil, nil, fmt.Errorf("marshal host private key: %w", err)
88+
}
89+
90+
pemBlock := &pem.Block{
91+
Type: "EC PRIVATE KEY",
92+
Bytes: derBytes,
93+
}
94+
95+
pubKey, err := ssh.NewPublicKey(&key.PublicKey)
96+
if err != nil {
97+
return nil, nil, fmt.Errorf("convert to SSH public key: %w", err)
98+
}
99+
100+
return pem.EncodeToMemory(pemBlock), pubKey, nil
101+
}
102+
76103
// GetPublicKeyContent reads an SSH public key file and returns its content
77104
// as a string suitable for inclusion in authorized_keys.
78105
func GetPublicKeyContent(publicKeyPath string) (string, error) {

ssh/keygen_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package ssh
55

66
import (
7+
"encoding/pem"
78
"os"
89
"path/filepath"
910
"strings"
@@ -138,3 +139,53 @@ func TestGenerateKeyPair_CreatesKeyDirectory(t *testing.T) {
138139
assert.FileExists(t, privPath)
139140
assert.FileExists(t, pubPath)
140141
}
142+
143+
func TestGenerateHostKeyPair_Success(t *testing.T) {
144+
t.Parallel()
145+
146+
pemBytes, pubKey, err := GenerateHostKeyPair()
147+
require.NoError(t, err)
148+
149+
// PEM should be parseable.
150+
block, rest := pem.Decode(pemBytes)
151+
require.NotNil(t, block, "PEM block should not be nil")
152+
assert.Equal(t, "EC PRIVATE KEY", block.Type)
153+
assert.Empty(t, rest, "no trailing data after PEM block")
154+
155+
// PEM should be parseable as an SSH private key.
156+
signer, err := gossh.ParsePrivateKey(pemBytes)
157+
require.NoError(t, err)
158+
assert.NotNil(t, signer)
159+
160+
// Public key should be non-nil and ECDSA P-256.
161+
require.NotNil(t, pubKey)
162+
assert.Equal(t, "ecdsa-sha2-nistp256", pubKey.Type())
163+
}
164+
165+
func TestGenerateHostKeyPair_Unique(t *testing.T) {
166+
t.Parallel()
167+
168+
pem1, pub1, err := GenerateHostKeyPair()
169+
require.NoError(t, err)
170+
171+
pem2, pub2, err := GenerateHostKeyPair()
172+
require.NoError(t, err)
173+
174+
assert.NotEqual(t, pem1, pem2, "private keys should differ")
175+
assert.NotEqual(t, pub1.Marshal(), pub2.Marshal(), "public keys should differ")
176+
}
177+
178+
func TestGenerateHostKeyPair_RoundTrip(t *testing.T) {
179+
t.Parallel()
180+
181+
pemBytes, pubKey, err := GenerateHostKeyPair()
182+
require.NoError(t, err)
183+
184+
// Parse the PEM back and verify the public key matches.
185+
signer, err := gossh.ParsePrivateKey(pemBytes)
186+
require.NoError(t, err)
187+
188+
signerPub := signer.PublicKey()
189+
assert.Equal(t, pubKey.Type(), signerPub.Type())
190+
assert.Equal(t, pubKey.Marshal(), signerPub.Marshal())
191+
}

0 commit comments

Comments
 (0)