Skip to content

Commit d6e412b

Browse files
committed
Refactor agent to use per-instance credentials with registration, revocation handling, and platform-specific file locking
1 parent a800f69 commit d6e412b

File tree

6 files changed

+531
-38
lines changed

6 files changed

+531
-38
lines changed

agent/client.go

Lines changed: 263 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,251 @@ package agent
22

33
import (
44
"bytes"
5+
"crypto/sha256"
6+
"encoding/hex"
57
"encoding/json"
8+
"errors"
69
"fmt"
710
"io"
811
"net/http"
12+
"os"
13+
"path/filepath"
14+
"runtime"
15+
"strings"
916
"time"
17+
18+
"github.com/actionforge/actrun-cli/build"
1019
)
1120

21+
// Sentinel errors surfaced by the agent client.
var (
	// ErrInstanceRevoked is returned when the gateway reports that this
	// process's instance secret is no longer valid. The worker unwinds,
	// re-registers with the agent token, and starts a fresh loop.
	ErrInstanceRevoked = errors.New("agent instance revoked: re-register required")

	// ErrAgentAlreadyRunning is returned by AcquireInstanceLock when another
	// process on this machine already holds the lock for the same
	// (server, token) pair.
	ErrAgentAlreadyRunning = errors.New("another agent is already running for this token")

	// ErrInvalidAgentToken is returned by Register when the gateway rejects
	// the enrollment token (HTTP 401). It is fatal and must not be retried.
	ErrInvalidAgentToken = errors.New("register: invalid agent token")
)
33+
34+
// InstanceCred is the per-process credential that is persisted between
// restarts. The agent token itself is never written to disk; only the
// server-minted secret is, which is narrower in scope (one instance) and
// can be revoked.
type InstanceCred struct {
	// InstanceID and InstanceSecret identify and authenticate this instance.
	InstanceID     string `json:"instance_id"`
	InstanceSecret string `json:"instance_secret"`

	// PoolID and PoolName are copied from the register response.
	PoolID   string `json:"pool_id"`
	PoolName string `json:"pool_name"`

	// ServerURL is the gateway URL the credential was issued for.
	ServerURL string `json:"server_url"`

	// PoolFingerprint is a hash of the enrollment token, stored so a
	// credential written for a different token can be detected.
	PoolFingerprint string `json:"pool_fingerprint"`
}
45+
46+
// registerRequest is the JSON body sent to the register endpoint. It
// mirrors CIRunnerRegisterRequest on the gateway side. Every field is
// optional and is omitted from the payload when empty.
type registerRequest struct {
	Name     string `json:"name,omitempty"`
	Hostname string `json:"hostname,omitempty"`
	OS       string `json:"os,omitempty"`      // runtime.GOOS of the agent
	Arch     string `json:"arch,omitempty"`    // runtime.GOARCH of the agent
	Version  string `json:"version,omitempty"` // agent build version
}
54+
55+
// registerResponse is the JSON body returned by the register endpoint. It
// mirrors CIRunnerRegisterResponse on the gateway side.
type registerResponse struct {
	InstanceID     string `json:"instance_id"`
	InstanceSecret string `json:"instance_secret"`
	PoolID         string `json:"pool_id"`
	PoolName       string `json:"pool_name"`
}
62+
63+
// Client talks to the gateway's runner protocol endpoints using an
64+
// instance secret as the bearer. Construct via NewClientFromCred after
65+
// obtaining a credential from Register or loadInstanceCred.
1266
type Client struct {
1367
serverURL string
14-
token string
68+
cred *InstanceCred
1569
httpClient *http.Client
1670
}
1771

18-
func NewClient(serverURL, token string) *Client {
72+
// NewClientFromCred wraps an existing credential for subsequent runner
73+
// API calls.
74+
func NewClientFromCred(serverURL string, cred *InstanceCred) *Client {
1975
return &Client{
2076
serverURL: serverURL,
21-
token: token,
77+
cred: cred,
2278
httpClient: &http.Client{
2379
Timeout: 30 * time.Second,
2480
},
2581
}
2682
}
2783

84+
// NewClientFromSecret builds a client from a bare instance secret, used
85+
// by child processes (the node reporter in cmd_root) that inherit
86+
// BUILD_AGENT_TOKEN / BUILD_SERVER_URL in their environment and don't
87+
// need the full credential struct.
88+
func NewClientFromSecret(serverURL, instanceSecret string) *Client {
89+
return NewClientFromCred(serverURL, &InstanceCred{InstanceSecret: instanceSecret})
90+
}
91+
92+
// Cred returns the active instance credential.
93+
func (c *Client) Cred() *InstanceCred {
94+
return c.cred
95+
}
96+
97+
// InstanceCredPath returns the stable on-disk location of the credential
// for a given (serverURL, agentToken) pair:
//
//	<user config dir>/actrun/instance-<sha256 hex>.json
//
// The file name embeds the full SHA-256 digest of the server URL (with
// one trailing "/" removed) and the agent token, separated by a NUL byte.
// Machines that run several pools at once therefore get one file per
// pool, and the raw token never appears in the path.
//
// An error is returned only when the user config directory cannot be
// determined.
func InstanceCredPath(serverURL, agentToken string) (string, error) {
	configDir, err := os.UserConfigDir()
	if err != nil {
		return "", err
	}
	key := strings.TrimSuffix(serverURL, "/") + "\x00" + agentToken
	digest := sha256.Sum256([]byte(key))
	fileName := "instance-" + hex.EncodeToString(digest[:]) + ".json"
	return filepath.Join(configDir, "actrun", fileName), nil
}
109+
110+
// poolFingerprint returns the hex-encoded SHA-256 digest of the enrollment
// token. Storing the digest lets a persisted credential be recognised as
// "written for a different token" without the raw token ever being
// written to disk.
func poolFingerprint(agentToken string) string {
	hasher := sha256.New()
	hasher.Write([]byte(agentToken)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
117+
118+
// LoadInstanceCred returns an existing credential if one is on disk and
119+
// matches the given (serverURL, agentToken). A mismatch on either
120+
// field invalidates the file so the caller can re-register against the
121+
// new server or token.
122+
func LoadInstanceCred(serverURL, agentToken string) (*InstanceCred, error) {
123+
path, err := InstanceCredPath(serverURL, agentToken)
124+
if err != nil {
125+
return nil, err
126+
}
127+
data, err := os.ReadFile(path)
128+
if err != nil {
129+
if os.IsNotExist(err) {
130+
return nil, nil
131+
}
132+
return nil, err
133+
}
134+
var cred InstanceCred
135+
if err := json.Unmarshal(data, &cred); err != nil {
136+
return nil, fmt.Errorf("parse instance cred: %w", err)
137+
}
138+
if cred.ServerURL != serverURL || cred.PoolFingerprint != poolFingerprint(agentToken) {
139+
return nil, nil
140+
}
141+
if cred.InstanceSecret == "" || cred.InstanceID == "" {
142+
return nil, nil
143+
}
144+
return &cred, nil
145+
}
146+
147+
// saveInstanceCred persists a credential to disk with owner-only
148+
// permissions so another user on the host can't steal the secret.
149+
func saveInstanceCred(serverURL, agentToken string, cred *InstanceCred) error {
150+
path, err := InstanceCredPath(serverURL, agentToken)
151+
if err != nil {
152+
return err
153+
}
154+
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
155+
return err
156+
}
157+
data, err := json.Marshal(cred)
158+
if err != nil {
159+
return err
160+
}
161+
return os.WriteFile(path, data, 0o600)
162+
}
163+
164+
// DeleteInstanceCred removes the on-disk credential file for a given
165+
// (serverURL, agentToken). Called after a successful deregister, or
166+
// after a 401 tells us the credential is no longer valid.
167+
func DeleteInstanceCred(serverURL, agentToken string) error {
168+
path, err := InstanceCredPath(serverURL, agentToken)
169+
if err != nil {
170+
return err
171+
}
172+
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
173+
return err
174+
}
175+
return nil
176+
}
177+
178+
// instanceLockPath returns the lockfile path next to the credential file.
179+
func instanceLockPath(serverURL, agentToken string) (string, error) {
180+
credPath, err := InstanceCredPath(serverURL, agentToken)
181+
if err != nil {
182+
return "", err
183+
}
184+
return credPath + ".lock", nil
185+
}
186+
187+
// AcquireInstanceLock and ReleaseInstanceLock are in
188+
// lock_unix.go / lock_windows.go (platform-specific implementations).
189+
190+
// Register exchanges an agent token for a per-process agent credential.
191+
// The resulting credential is persisted to disk so a future restart can
192+
// skip this round trip entirely.
193+
func Register(serverURL, agentToken string) (*InstanceCred, error) {
194+
hostname, _ := os.Hostname()
195+
body := registerRequest{
196+
Name: hostname,
197+
Hostname: hostname,
198+
OS: runtime.GOOS,
199+
Arch: runtime.GOARCH,
200+
Version: build.Version,
201+
}
202+
data, err := json.Marshal(body)
203+
if err != nil {
204+
return nil, err
205+
}
206+
207+
req, err := http.NewRequest(http.MethodPost, serverURL+"/api/v2/ci/runner/register", bytes.NewReader(data))
208+
if err != nil {
209+
return nil, err
210+
}
211+
req.Header.Set("Authorization", "Bearer "+agentToken)
212+
req.Header.Set("Content-Type", "application/json")
213+
214+
httpClient := &http.Client{Timeout: 30 * time.Second}
215+
resp, err := httpClient.Do(req)
216+
if err != nil {
217+
return nil, fmt.Errorf("register request failed: %w", err)
218+
}
219+
defer resp.Body.Close()
220+
221+
switch resp.StatusCode {
222+
case http.StatusCreated:
223+
case http.StatusUnauthorized:
224+
return nil, ErrInvalidAgentToken
225+
default:
226+
body, _ := io.ReadAll(resp.Body)
227+
return nil, fmt.Errorf("register: %s %s", resp.Status, string(body))
228+
}
229+
230+
var out registerResponse
231+
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
232+
return nil, fmt.Errorf("decode register response: %w", err)
233+
}
234+
cred := &InstanceCred{
235+
InstanceID: out.InstanceID,
236+
InstanceSecret: out.InstanceSecret,
237+
PoolID: out.PoolID,
238+
PoolName: out.PoolName,
239+
ServerURL: serverURL,
240+
PoolFingerprint: poolFingerprint(agentToken),
241+
}
242+
if err := saveInstanceCred(serverURL, agentToken, cred); err != nil {
243+
// Persistence failure is non-fatal: the process can still run
244+
// with the in-memory credential, it just won't survive restart.
245+
_ = err
246+
}
247+
return cred, nil
248+
}
249+
28250
func (c *Client) doRequest(method, path string, body interface{}) (*http.Response, error) {
29251
var bodyReader io.Reader
30252
if body != nil {
@@ -39,21 +261,30 @@ func (c *Client) doRequest(method, path string, body interface{}) (*http.Respons
39261
if err != nil {
40262
return nil, err
41263
}
42-
req.Header.Set("Authorization", "Bearer "+c.token)
264+
req.Header.Set("Authorization", "Bearer "+c.cred.InstanceSecret)
43265
if body != nil {
44266
req.Header.Set("Content-Type", "application/json")
45267
}
46-
return c.httpClient.Do(req)
268+
resp, err := c.httpClient.Do(req)
269+
if err != nil {
270+
return nil, err
271+
}
272+
if resp.StatusCode == http.StatusUnauthorized {
273+
// Only treat 401 as revocation on runner-protocol endpoints.
274+
// A proxy or WAF returning 401 should not trigger re-registration.
275+
if strings.HasPrefix(path, "/api/v2/ci/runner/") {
276+
_, _ = io.Copy(io.Discard, resp.Body)
277+
resp.Body.Close()
278+
return nil, ErrInstanceRevoked
279+
}
280+
}
281+
return resp, nil
47282
}
48283

49284
func (c *Client) ServerURL() string {
50285
return c.serverURL
51286
}
52287

53-
func (c *Client) Token() string {
54-
return c.token
55-
}
56-
57288
func (c *Client) Claim() (*ClaimResponse, error) {
58289
resp, err := c.doRequest("POST", "/api/v2/ci/runner/claim", nil)
59290
if err != nil {
@@ -143,17 +374,34 @@ func (c *Client) SubmitActiveNodes(jobID string, nodes []ActiveNode) error {
143374
return drainAndCheck(resp)
144375
}
145376

146-
func (c *Client) Heartbeat(req HeartbeatRequest) error {
377+
// HeartbeatResponse carries pool metadata back from the gateway. Labels
378+
// are included so the agent always has a fresh view without caching.
379+
type HeartbeatResponse struct {
380+
Labels string `json:"labels"`
381+
}
382+
383+
func (c *Client) Heartbeat(req HeartbeatRequest) (*HeartbeatResponse, error) {
147384
resp, err := c.doRequest("POST", "/api/v2/ci/runner/heartbeat", req)
148385
if err != nil {
149-
return err
386+
return nil, err
150387
}
151-
return drainAndCheck(resp)
388+
defer resp.Body.Close()
389+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
390+
_, _ = io.Copy(io.Discard, resp.Body)
391+
return nil, fmt.Errorf("unexpected status: %s", resp.Status)
392+
}
393+
var out HeartbeatResponse
394+
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
395+
return &HeartbeatResponse{}, nil // non-fatal: old server without response body
396+
}
397+
return &out, nil
152398
}
153399

154-
155-
func (c *Client) Disconnect() error {
156-
resp, err := c.doRequest("POST", "/api/v2/ci/runner/disconnect", nil)
400+
// Deregister tells the gateway this process is shutting down so the
401+
// instance row is removed immediately. Graceful shutdown path only —
402+
// crash-exit will be cleaned up by the gateway's stale sweeper.
403+
func (c *Client) Deregister() error {
404+
resp, err := c.doRequest("POST", "/api/v2/ci/runner/deregister", nil)
157405
if err != nil {
158406
return err
159407
}

agent/lock_unix.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//go:build !windows
2+
3+
package agent
4+
5+
import (
6+
"os"
7+
"path/filepath"
8+
"syscall"
9+
)
10+
11+
// AcquireInstanceLock takes an exclusive flock on a file next to the
12+
// credential file. If another process on this machine already holds the
13+
// lock, ErrAgentAlreadyRunning is returned immediately (non-blocking).
14+
// The returned *os.File must be kept open for the lifetime of the process
15+
// to maintain the lock. On process crash the OS releases the flock
16+
// automatically, so no stale lockfiles.
17+
func AcquireInstanceLock(serverURL, agentToken string) (*os.File, error) {
18+
lockPath, err := instanceLockPath(serverURL, agentToken)
19+
if err != nil {
20+
return nil, err
21+
}
22+
if err := os.MkdirAll(filepath.Dir(lockPath), 0o700); err != nil {
23+
return nil, err
24+
}
25+
f, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0o600)
26+
if err != nil {
27+
return nil, err
28+
}
29+
// LOCK_EX | LOCK_NB: exclusive, non-blocking. Fails immediately if held.
30+
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err != nil {
31+
f.Close()
32+
return nil, ErrAgentAlreadyRunning
33+
}
34+
return f, nil
35+
}
36+
37+
// ReleaseInstanceLock releases the flock held on f and closes the file.
// A nil f is a no-op.
//
// The lockfile itself is deliberately left on disk. Unlinking it after
// unlocking is racy: a second process can open and lock the old inode in
// the window before the unlink, after which a third process creates and
// locks a fresh file at the same path, and both believe they hold the
// lock. An unlocked leftover file is harmless because the flock, not the
// file's existence, is what guards against a second agent.
func ReleaseInstanceLock(f *os.File) {
	if f == nil {
		return
	}
	// Best-effort: closing the descriptor drops the lock as well, so
	// failures here leave nothing to recover.
	_ = syscall.Flock(int(f.Fd()), syscall.LOCK_UN)
	_ = f.Close()
}

0 commit comments

Comments
 (0)