Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions main/src/app-events/block-quit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ import {
recreateMainWindowForShutdown,
sendToMainWindowRenderer,
} from '../main-window'
import { getToolhivePort, stopToolhive, binPath } from '../toolhive-manager'
import { stopToolhive, binPath } from '../toolhive-manager'
import { stopAllServers } from '../graceful-exit'
import { createMainProcessFetch } from '../unix-socket-fetch'
import { safeTrayDestroy } from '../system-tray'
import { delay } from '../../../utils/delay'
import log from '../logger'
Expand Down Expand Up @@ -39,10 +40,7 @@ export async function blockQuit(source: string, event?: Electron.Event) {
}

try {
const port = getToolhivePort()
if (port) {
await stopAllServers(binPath, port)
}
await stopAllServers(binPath, { createFetch: createMainProcessFetch })
} catch (err) {
log.error('Teardown failed: ', err)
} finally {
Expand Down
8 changes: 3 additions & 5 deletions main/src/app-events/process-signals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ import {
setTearingDownState,
setQuittingState,
} from '../app-state'
import { getToolhivePort, stopToolhive, binPath } from '../toolhive-manager'
import { stopToolhive, binPath } from '../toolhive-manager'
import { stopAllServers } from '../graceful-exit'
import { createMainProcessFetch } from '../unix-socket-fetch'
import { safeTrayDestroy } from '../system-tray'
import log from '../logger'

Expand All @@ -17,10 +18,7 @@ export function register() {
setQuittingState(true)
log.info(`[${sig}] delaying exit for teardown...`)
try {
const port = getToolhivePort()
if (port) {
await stopAllServers(binPath, port)
}
await stopAllServers(binPath, { createFetch: createMainProcessFetch })
} finally {
stopToolhive()
safeTrayDestroy()
Expand Down
9 changes: 6 additions & 3 deletions main/src/app-events/when-ready.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
isToolhiveRunning,
stopToolhive,
} from '../toolhive-manager'
import { registerApiFetchHandlers } from '../unix-socket-fetch'
import { getMainWindow, createMainWindow, hideMainWindow } from '../main-window'
import { extractDeepLinkFromArgs, handleDeepLink } from '../deep-links'
import { getCspString } from '../csp'
Expand Down Expand Up @@ -69,6 +70,9 @@ export function register() {
// Start ToolHive with tray reference
await startToolhive()

// Register IPC handlers for renderer -> main -> thv API bridge
registerApiFetchHandlers()

// Create main window
try {
const mainWindow = await createMainWindow()
Expand Down Expand Up @@ -131,10 +135,9 @@ export function register() {
if (process.env.NODE_ENV === 'development') {
return callback({ responseHeaders: details.responseHeaders })
}
// When using UNIX sockets, API requests go through IPC so no port is
// needed in connect-src. Pass the port only when available (TCP fallback).
const port = getToolhivePort()
if (port == null) {
throw new Error('[content-security-policy] ToolHive port is not set')
}
return callback({
responseHeaders: {
...details.responseHeaders,
Expand Down
17 changes: 3 additions & 14 deletions main/src/auto-update.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,8 @@ import { app, autoUpdater, dialog, ipcMain, type BrowserWindow } from 'electron'
import { updateElectronApp, UpdateSourceType } from 'update-electron-app'
import * as Sentry from '@sentry/electron/main'
import { stopAllServers } from './graceful-exit'
import {
stopToolhive,
getToolhivePort,
binPath,
isToolhiveRunning,
} from './toolhive-manager'
import { stopToolhive, binPath, isToolhiveRunning } from './toolhive-manager'
import { createMainProcessFetch } from './unix-socket-fetch'
import { safeTrayDestroy } from './system-tray'
import { getAppVersion, pollWindowReady } from './util'
import { delay } from '../../utils/delay'
Expand Down Expand Up @@ -35,14 +31,7 @@ let updateState: UpdateState = 'none'

async function safeServerShutdown(): Promise<boolean> {
try {
const port = getToolhivePort()
if (!port) {
log.info('[update] No ToolHive port available, skipping server shutdown')
return true
}

await stopAllServers(binPath, port)

await stopAllServers(binPath, { createFetch: createMainProcessFetch })
log.info('[update] All servers stopped successfully')
return true
} catch (error) {
Expand Down
15 changes: 10 additions & 5 deletions main/src/csp.ts
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
const getCspMap = (port: number, sentryDsn?: string) => {
// In production with Sentry enabled, allow blob workers for replay
const getCspMap = (port: number | undefined, sentryDsn?: string) => {
const hasSentry = Boolean(sentryDsn)
const workerSrc = hasSentry ? "'self' blob:" : "'self'"

// When using UNIX sockets the renderer never makes direct HTTP requests
// to the thv server, so no localhost entry is needed in connect-src.
const connectParts = ["'self'"]
if (port != null) connectParts.push(`http://localhost:${port}`)
connectParts.push('https://api.hsforms.com')
if (hasSentry) connectParts.push('https://*.sentry.io')

return {
'default-src': "'self'",
'script-src': "'self'",
'style-src': "'self' 'unsafe-inline'",
'img-src': "'self' data: blob:",
'font-src': "'self' data:",
'connect-src': `'self' http://localhost:${port} https://api.hsforms.com${hasSentry ? ' https://*.sentry.io' : ''}`,
'connect-src': connectParts.join(' '),
'frame-src': "'none'",
'object-src': "'none'",
'base-uri': "'self'",
'form-action': "'self'",
'frame-ancestors': "'none'",
'manifest-src': "'self'",
'media-src': "'self' blob: data:",
// Allow blob: workers only when Sentry is configured
'worker-src': workerSrc,
'child-src': "'none'",
}
}

export const getCspString = (port: number, sentryDsn?: string) =>
export const getCspString = (port: number | undefined, sentryDsn?: string) =>
Object.entries(getCspMap(port, sentryDsn))
.map(([key, value]) => `${key} ${value}`)
.join('; ')
17 changes: 11 additions & 6 deletions main/src/graceful-exit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@ export const shutdownStore = new Store({
},
})

/** Create API client for the given port */
function createApiClient(port: number) {
/**
* Create API client. When a custom fetch is provided (UNIX socket transport),
* the baseUrl is a dummy since the custom fetch handles routing.
*/
function createApiClient(opts: { port?: number; customFetch?: typeof fetch }) {
return createClient({
baseUrl: `http://localhost:${port}`,
baseUrl: opts.port ? `http://localhost:${opts.port}` : 'http://localhost',
headers: getHeaders(),
...(opts.customFetch ? { fetch: opts.customFetch } : {}),
})
}

Expand Down Expand Up @@ -114,10 +118,11 @@ async function pollUntilAllStopped(

/** Stop every running server in parallel and wait until *all* are down. */
export async function stopAllServers(
_binPath: string, // Kept for backward compatibility
port: number
_binPath: string,
opts: { port?: number; createFetch?: () => typeof fetch }
): Promise<void> {
const client = createApiClient(port)
const customFetch = opts.createFetch?.()
const client = createApiClient({ port: opts.port, customFetch })
const servers = await getRunningServers(client)
log.info(
`Found ${servers.length} running servers: `,
Expand Down
5 changes: 5 additions & 0 deletions main/src/ipc-handlers/toolhive.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@ import { ipcMain } from 'electron'
import {
restartToolhive,
getToolhivePort,
getToolhiveSocketPath,
isToolhiveRunning,
getToolhiveMcpPort,
isUsingCustomPort,
} from '../toolhive-manager'
import { checkContainerEngine } from '../container-engine'
import { getLastShutdownServers, clearShutdownHistory } from '../graceful-exit'
import { registerApiFetchHandlers } from '../unix-socket-fetch'
import log from '../logger'

export function register() {
ipcMain.handle('get-toolhive-port', () => getToolhivePort())
ipcMain.handle('get-toolhive-mcp-port', () => getToolhiveMcpPort())
ipcMain.handle('get-toolhive-socket-path', () => getToolhiveSocketPath())
ipcMain.handle('is-toolhive-running', () => isToolhiveRunning())
ipcMain.handle('is-using-custom-port', () => isUsingCustomPort())

Expand Down Expand Up @@ -41,4 +44,6 @@ export function register() {
clearShutdownHistory()
return { success: true }
})

registerApiFetchHandlers()
}
26 changes: 15 additions & 11 deletions main/src/tests/auto-update.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ vi.mock('../toolhive-manager', () => ({
binPath: '/mock/bin/path',
}))

vi.mock('../unix-socket-fetch', () => ({
createMainProcessFetch: vi.fn(() => vi.fn()),
}))

vi.mock('../system-tray', () => ({
safeTrayDestroy: vi.fn(),
}))
Expand Down Expand Up @@ -156,7 +160,7 @@ vi.mock('../app-state', () => ({
}))

import { stopAllServers } from '../graceful-exit'
import { stopToolhive, getToolhivePort } from '../toolhive-manager'
import { stopToolhive } from '../toolhive-manager'
import { safeTrayDestroy } from '../system-tray'
import { pollWindowReady } from '../util'
import { delay } from '../../../utils/delay'
Expand Down Expand Up @@ -199,7 +203,6 @@ describe('auto-update', () => {
// Setup default mocks
vi.mocked(stopAllServers).mockResolvedValue(undefined)
vi.mocked(stopToolhive).mockReturnValue(undefined)
vi.mocked(getToolhivePort).mockReturnValue(3000)
vi.mocked(pollWindowReady).mockResolvedValue(undefined)
vi.mocked(delay).mockResolvedValue(undefined)
vi.mocked(dialog.showMessageBox).mockResolvedValue({
Expand Down Expand Up @@ -803,8 +806,7 @@ describe('auto-update', () => {
expect(vi.mocked(autoUpdater).quitAndInstall).toHaveBeenCalled()
})

it('integrates with toolhive manager port detection', async () => {
vi.mocked(getToolhivePort).mockReturnValue(undefined)
it('always attempts server shutdown via IPC fetch bridge', async () => {
vi.mocked(dialog.showMessageBox).mockResolvedValue({
response: 0,
checkboxChecked: false,
Expand All @@ -823,13 +825,14 @@ describe('auto-update', () => {

await updatePromise

// Should skip server shutdown when no port is available
expect(vi.mocked(getToolhivePort)).toHaveBeenCalled()
expect(vi.mocked(stopAllServers)).not.toHaveBeenCalled()
// Always attempts server shutdown (connection errors handled internally)
expect(vi.mocked(stopAllServers)).toHaveBeenCalled()
})

it('handles missing toolhive port gracefully', async () => {
vi.mocked(getToolhivePort).mockReturnValue(undefined)
it('handles server shutdown failure gracefully', async () => {
vi.mocked(stopAllServers).mockRejectedValueOnce(
new Error('No ToolHive connection available')
)
vi.mocked(dialog.showMessageBox).mockResolvedValue({
response: 0,
checkboxChecked: false,
Expand All @@ -848,8 +851,9 @@ describe('auto-update', () => {

await updatePromise

expect(vi.mocked(log).info).toHaveBeenCalledWith(
'[update] No ToolHive port available, skipping server shutdown'
expect(vi.mocked(log).error).toHaveBeenCalledWith(
expect.stringContaining('[update] Server shutdown failed'),
expect.anything()
)
})

Expand Down
20 changes: 11 additions & 9 deletions main/src/tests/graceful-exit.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ describe('graceful-exit', () => {
createMockWorkloadsResponse([])
)

await stopAllServers('', 3000)
await stopAllServers('', { port: 3000 })

expect(mockLog.info).toHaveBeenCalledWith(
'No running servers – teardown complete'
Expand All @@ -140,7 +140,7 @@ describe('graceful-exit', () => {

mockPostApiV1BetaWorkloadsStop.mockResolvedValue(createMockStopResponse())

await stopAllServers('', 3000)
await stopAllServers('', { port: 3000 })

expect(mockPostApiV1BetaWorkloadsStop).toHaveBeenCalledTimes(1)
expect(mockPostApiV1BetaWorkloadsStop).toHaveBeenCalledWith({
Expand All @@ -165,7 +165,7 @@ describe('graceful-exit', () => {

mockPostApiV1BetaWorkloadsStop.mockResolvedValue(createMockStopResponse())

await stopAllServers('', 3000)
await stopAllServers('', { port: 3000 })

expect(mockLog.info).toHaveBeenCalledWith(
'All servers have reached final state'
Expand All @@ -182,7 +182,9 @@ describe('graceful-exit', () => {
new Error('Stop failed')
)

await expect(stopAllServers('', 3000)).rejects.toThrow('Stop failed')
await expect(stopAllServers('', { port: 3000 })).rejects.toThrow(
'Stop failed'
)
})

it('handles timeout when servers do not stop', async () => {
Expand All @@ -201,7 +203,7 @@ describe('graceful-exit', () => {

mockPostApiV1BetaWorkloadsStop.mockResolvedValue(createMockStopResponse())

await expect(stopAllServers('', 3000)).rejects.toThrow(
await expect(stopAllServers('', { port: 3000 })).rejects.toThrow(
'Some servers failed to stop within timeout'
)
})
Expand All @@ -213,7 +215,7 @@ describe('graceful-exit', () => {

mockPostApiV1BetaWorkloadsStop.mockResolvedValue(createMockStopResponse())

await stopAllServers('', 3000)
await stopAllServers('', { port: 3000 })

expect(mockWriteShutdownServers).toHaveBeenCalledWith(mockRunningServers)
})
Expand All @@ -234,7 +236,7 @@ describe('graceful-exit', () => {

mockPostApiV1BetaWorkloadsStop.mockResolvedValue(createMockStopResponse())

await stopAllServers('', 3000)
await stopAllServers('', { port: 3000 })

// Should only include the server with a name in the batch call
expect(mockPostApiV1BetaWorkloadsStop).toHaveBeenCalledTimes(1)
Expand Down Expand Up @@ -300,7 +302,7 @@ describe('graceful-exit', () => {

mockPostApiV1BetaWorkloadsStop.mockResolvedValue(createMockStopResponse())

await stopAllServers('', 3000)
await stopAllServers('', { port: 3000 })

expect(mockLog.info).toHaveBeenCalledWith(
'Still waiting for 1 servers to reach final state: server1(stopping)'
Expand All @@ -326,7 +328,7 @@ describe('graceful-exit', () => {

mockPostApiV1BetaWorkloadsStop.mockResolvedValue(createMockStopResponse())

await stopAllServers('', 3000)
await stopAllServers('', { port: 3000 })

// Should call delay between polling attempts (not on first attempt)
expect(mockDelay).toHaveBeenCalledWith(2000)
Expand Down
Loading
Loading