From 5dd4c4129180c4335414ae5f17dc821be40ce72f Mon Sep 17 00:00:00 2001 From: Aarnav Tale Date: Thu, 6 Mar 2025 17:32:33 -0500 Subject: [PATCH] feat: use a new ws implementation thats encapsulated --- app/routes.ts | 3 + app/routes/api/agent.ts | 39 +++++++ app/routes/machines/components/machine.tsx | 20 ++-- app/routes/machines/machine.tsx | 20 ++-- app/routes/machines/overview.tsx | 12 +- app/utils/useAgent.ts | 25 ++++ server/context/app.ts | 9 +- server/context/parser.ts | 11 ++ server/entry.ts | 26 +++-- server/utils/ws.ts | 60 ---------- server/ws/cache.ts | 126 +++++++++++++++++++++ server/ws/data.ts | 61 ++++++++++ server/ws/socket.ts | 59 ++++++++++ 13 files changed, 370 insertions(+), 101 deletions(-) create mode 100644 app/routes/api/agent.ts create mode 100644 app/utils/useAgent.ts delete mode 100644 server/utils/ws.ts create mode 100644 server/ws/cache.ts create mode 100644 server/ws/data.ts create mode 100644 server/ws/socket.ts diff --git a/app/routes.ts b/app/routes.ts index 51a6363..7fcf112 100644 --- a/app/routes.ts +++ b/app/routes.ts @@ -11,6 +11,9 @@ export default [ route('/oidc/callback', 'routes/auth/oidc-callback.ts'), route('/oidc/start', 'routes/auth/oidc-start.ts'), + // API + route('/api/agent', 'routes/api/agent.ts'), + // All the main logged-in dashboard routes // Double nested to separate error propagations layout('layouts/shell.tsx', [ diff --git a/app/routes/api/agent.ts b/app/routes/api/agent.ts new file mode 100644 index 0000000..34810a3 --- /dev/null +++ b/app/routes/api/agent.ts @@ -0,0 +1,39 @@ +import { LoaderFunctionArgs } from 'react-router'; +import type { AppContext } from '~server/context/app'; + +export async function loader({ + request, + context, +}: LoaderFunctionArgs) { + if (!context?.agentData) { + return new Response(JSON.stringify({ error: 'Agent data unavailable' }), { + status: 400, + headers: { + 'Content-Type': 'application/json', + }, + }); + } + + const qp = new URLSearchParams(request.url.split('?')[1]); + const nodeIds = qp.get('node_ids')?.split(','); + if (!nodeIds) { + return new Response(JSON.stringify({ error: 'No node IDs provided' }), { + status: 400, + headers: { + 'Content-Type': 'application/json', + }, + }); + } + + const entries = context.agentData.toJSON(); + const missing = nodeIds.filter((nodeID) => !entries[nodeID]); + if (missing.length > 0) { + await context.hp_agentRequest(missing); + } + + return new Response(JSON.stringify(context.agentData), { + headers: { + 'Content-Type': 'application/json', + }, + }); +} diff --git a/app/routes/machines/components/machine.tsx b/app/routes/machines/components/machine.tsx index db955f7..f804cac 100644 --- a/app/routes/machines/components/machine.tsx +++ b/app/routes/machines/components/machine.tsx @@ -146,24 +146,18 @@ export default function MachineRow({ - {/** - {stats !== undefined ? ( - <> -

- {hinfo.getTSVersion(stats)} -

+ {stats !== undefined ? ( + <> +

{hinfo.getTSVersion(stats)}

{hinfo.getOSInfo(stats)}

- - ) : ( -

- Unknown -

- )} + + ) : ( +

Unknown

+ )} - **/} -
-

- Status -

-
- {tags.map((tag) => ( - - ))} + {tags.length > 0 ? ( +
+

+ Status +

+
+ {tags.map((tag) => ( + + ))} +
-
+ ) : undefined}

Subnets & Routing

node.nodeKey)); const ctx = context.context; const { mode, config } = hs_getConfig(); - let magic: string | undefined; if (mode !== 'no') { @@ -53,7 +49,6 @@ export async function loader({ routes: routes.routes, users: users.users, magic, - stats, server: ctx.headscale.url, publicServer: ctx.headscale.public_url, }; @@ -65,6 +60,7 @@ export async function action({ request }: ActionFunctionArgs) { export default function Page() { const data = useLoaderData(); + const { data: stats } = useAgent(data.nodes.map((node) => node.nodeKey)); return ( <> @@ -108,7 +104,7 @@ export default function Page() { ) : undefined} - {/**Version**/} + Version Last Seen @@ -127,7 +123,7 @@ export default function Page() { )} users={data.users} magic={data.magic} - stats={data.stats?.[machine.nodeKey]} + stats={stats?.[machine.nodeKey]} /> ))} diff --git a/app/utils/useAgent.ts b/app/utils/useAgent.ts new file mode 100644 index 0000000..2be16d7 --- /dev/null +++ b/app/utils/useAgent.ts @@ -0,0 +1,25 @@ +import { useEffect } from 'react'; +import { useFetcher } from 'react-router'; +import { HostInfo } from '~/types'; + +export default function useAgent(nodeIds: string[], interval = 3000) { + const fetcher = useFetcher>(); + + useEffect(() => { + const qp = new URLSearchParams({ node_ids: nodeIds.join(',') }); + fetcher.load(`/api/agent?${qp.toString()}`); + + const intervalID = setInterval(() => { + fetcher.load(`/api/agent?${qp.toString()}`); + }, interval); + + return () => { + clearInterval(intervalID); + }; + }, [fetcher, interval, nodeIds]); + + return { + data: fetcher.data, + isLoading: fetcher.state === 'loading', + }; +} diff --git a/server/context/app.ts b/server/context/app.ts index c06e79b..23449ec 100644 --- a/server/context/app.ts +++ b/server/context/app.ts @@ -1,12 +1,19 @@ +import type { HostInfo } from '~/types'; +import { TimedCache } from '~server/ws/cache'; +import { hp_agentRequest, hp_getAgentCache } from '~server/ws/data'; import { hp_getConfig } from './loader'; -import { HeadplaneConfig } from './parser'; +import type { HeadplaneConfig } from './parser'; export interface AppContext { context: HeadplaneConfig; + agentData?: TimedCache; + hp_agentRequest: typeof hp_agentRequest; } export default function appContext() { return { context: hp_getConfig(), + agentData: hp_getAgentCache(), + hp_agentRequest, }; } diff --git a/server/context/parser.ts b/server/context/parser.ts index da9d246..303c6e7 100644 --- a/server/context/parser.ts +++ b/server/context/parser.ts @@ -8,6 +8,17 @@ const serverConfig = type({ port: type('string | number.integer').pipe((v) => Number(v)), cookie_secret: '32 <= string <= 32', cookie_secure: stringToBool, + agent: type({ + authkey: 'string', + ttl: 'number.integer = 180000', // Default to 3 minutes + cache_path: 'string = "/var/lib/headplane/agent_cache.json"', + }) + .onDeepUndeclaredKey('reject') + .default(() => ({ + authkey: '', + ttl: 180000, + cache_path: '/var/lib/headplane/agent_cache.json', + })), }); const oidcConfig = type({ diff --git a/server/entry.ts b/server/entry.ts index baac614..228f426 100644 --- a/server/entry.ts +++ b/server/entry.ts @@ -1,9 +1,10 @@ -// import { initWebsocket } from '~server/ws'; import { constants, access } from 'node:fs/promises'; import { createServer } from 'node:http'; import { hp_getConfig, hp_loadConfig } from '~server/context/loader'; import { listener } from '~server/listener'; import log from '~server/utils/log'; +import { hp_loadAgentCache } from '~server/ws/data'; +import { initWebsocket } from '~server/ws/socket'; log.info('SRVX', 'Running Node.js %s', process.versions.node); @@ -19,16 +20,21 @@ try { await hp_loadConfig(); const server = createServer(listener); -// const ws = initWebsocket(); -// if (ws) { -// server.on('upgrade', (req, socket, head) => { -// ws.handleUpgrade(req, socket, head, (ws) => { -// ws.emit('connection', ws, req); -// }); -// }); -// } - const context = hp_getConfig(); +const ws = initWebsocket(context.server.agent.authkey); +if (ws) { + await hp_loadAgentCache( + context.server.agent.ttl, + context.server.agent.cache_path, + ); + + server.on('upgrade', (req, socket, head) => { + ws.handleUpgrade(req, socket, head, (ws) => { + ws.emit('connection', ws, req); + }); + }); +} + server.listen(context.server.port, context.server.host, () => { log.info( 'SRVX', diff --git a/server/utils/ws.ts b/server/utils/ws.ts deleted file mode 100644 index b01f71c..0000000 --- a/server/utils/ws.ts +++ /dev/null @@ -1,60 +0,0 @@ -import WebSocket, { WebSocketServer } from 'ws'; -import log from '~server/utils/log'; - -const server = new WebSocketServer({ noServer: true }); -export function initWebsocket() { - // TODO: Finish this and make public - return; - - const key = process.env.LOCAL_AGENT_AUTHKEY; - if (!key) { - return; - } - - log.info('CACH', 'Initializing agent WebSocket'); - server.on('connection', (ws, req) => { - // biome-ignore lint: this file is not USED - const auth = req.headers['authorization']; - if (auth !== `Bearer ${key}`) { - log.warn('CACH', 'Invalid agent WebSocket connection'); - ws.close(1008, 'ERR_INVALID_AUTH'); - return; - } - - const nodeID = req.headers['x-headplane-ts-node-id']; - if (!nodeID) { - log.warn('CACH', 'Invalid agent WebSocket connection'); - ws.close(1008, 'ERR_INVALID_NODE_ID'); - return; - } - - const pinger = setInterval(() => { - if (ws.readyState !== WebSocket.OPEN) { - clearInterval(pinger); - return; - } - - ws.ping(); - }, 30000); - - ws.on('close', () => { - clearInterval(pinger); - }); - - ws.on('error', (error) => { - clearInterval(pinger); - log.error('CACH', 'Closing agent WebSocket connection'); - log.error('CACH', 'Agent WebSocket error: %s', error); - ws.close(1011, 'ERR_INTERNAL_ERROR'); - }); - }); - - return server; -} - -export function appContext() { - return { - ws: server, - wsAuthKey: process.env.LOCAL_AGENT_AUTHKEY, - }; -} diff --git a/server/ws/cache.ts b/server/ws/cache.ts new file mode 100644 index 0000000..148018a --- /dev/null +++ b/server/ws/cache.ts @@ -0,0 +1,126 @@ +import { createHash } from 'node:crypto'; +import { readFile, writeFile } from 'node:fs/promises'; +import { type } from 'arktype'; +import log from '~server/utils/log'; +import mutex from '~server/utils/mutex'; + +const diskSchema = type({ + key: 'string', + value: 'unknown', + expires: 'number?', +}).array(); + +// A persistent HashMap with a TTL for each key +export class TimedCache { + private _cache = new Map(); + private _timings = new Map(); + + // Default TTL is 1 minute + private defaultTTL: number; + private filePath: string; + private writeLock = mutex(); + + // Last flush ID is essentially a hash of the flush contents + // Prevents unnecessary flushing if nothing has changed + private lastFlushId = ''; + + constructor(defaultTTL: number, filePath: string) { + this.defaultTTL = defaultTTL; + this.filePath = filePath; + + // Load the cache from disk and then queue flushes every 10 seconds + this.load().then(() => { + setInterval(() => this.flush(), 10000); + }); + } + + set(key: string, value: V, ttl: number = this.defaultTTL) { + this._cache.set(key, value); + this._timings.set(key, Date.now() + ttl); + } + + get(key: string) { + const value = this._cache.get(key); + if (!value) { + return; + } + + const expires = this._timings.get(key); + if (!expires || expires < Date.now()) { + this._cache.delete(key); + this._timings.delete(key); + return; + } + + return value; + } + + // Map into a Record without any TTLs + toJSON() { + const result: Record = {}; + for (const [key, value] of this._cache.entries()) { + result[key] = value; + } + + return result; + } + + // WARNING: This function expects that this.filePath is NOT ENOENT + private async load() { + const data = await readFile(this.filePath, 'utf-8'); + const cache = () => { + try { + return JSON.parse(data); + } catch (e) { + return undefined; + } + }; + + const diskData = cache(); + if (diskData === undefined) { + log.error('CACH', 'Failed to load cache at %s', this.filePath); + return; + } + + const cacheData = diskSchema(diskData); + if (cacheData instanceof type.errors) { + log.error('CACH', 'Failed to load cache at %s', this.filePath); + log.debug('CACHE', 'Error details: %s', cacheData.toString()); + + // Skip loading the cache (it should be overwritten soon) + return; + } + + for (const { key, value, expires } of diskData) { + this._cache.set(key, value); + this._timings.set(key, expires); + } + + log.info('CACH', 'Loaded cache from %s', this.filePath); + } + + private async flush() { + this.writeLock.acquire(); + const data = Array.from(this._cache.entries()).map(([key, value]) => { + return { key, value, expires: this._timings.get(key) }; + }); + + if (data.length === 0) { + this.writeLock.release(); + return; + } + + // Calculate the hash of the data + const dumpData = JSON.stringify(data); + const sha = createHash('sha256').update(dumpData).digest('hex'); + if (sha === this.lastFlushId) { + this.writeLock.release(); + return; + } + + await writeFile(this.filePath, dumpData, 'utf-8'); + this.lastFlushId = sha; + this.writeLock.release(); + log.debug('CACH', 'Flushed cache to %s', this.filePath); + } +} diff --git a/server/ws/data.ts b/server/ws/data.ts new file mode 100644 index 0000000..e2245d3 --- /dev/null +++ b/server/ws/data.ts @@ -0,0 +1,61 @@ +import { open } from 'node:fs/promises'; +import type { HostInfo } from '~/types'; +import log from '~server/utils/log'; +import { TimedCache } from './cache'; +import { hp_getAgents } from './socket'; + +let cache: TimedCache | undefined; +export async function hp_loadAgentCache(defaultTTL: number, filepath: string) { + log.debug('CACH', `Loading agent cache from ${filepath}`); + + try { + const handle = await open(filepath, 'w'); + log.info('CACH', `Using agent cache file at ${filepath}`); + await handle.close(); + } catch (e) { + log.info('CACH', `Agent cache file not found at ${filepath}`); + return; + } + + cache = new TimedCache(defaultTTL, filepath); +} + +export function hp_getAgentCache() { + return cache; +} + +export async function hp_agentRequest(nodeList: string[]) { + // Request to all connected agents (we can have multiple) + // Luckily we can parse all the data at once through message parsing + // and then overlapping cache entries will be overwritten by time + const agents = [...hp_getAgents()]; + console.log(agents); + + // Deduplicate the list of nodes + const NodeIDs = [...new Set(nodeList)]; + NodeIDs.map((node) => { + log.debug('CACH', 'Requesting agent data for', node); + }); + + // Await so that data loads on first request without racing + // Since we do agent.once() we NEED to wait for it to finish + await Promise.allSettled( + agents.map(async (agent) => { + agent.send(JSON.stringify({ NodeIDs })); + await new Promise((resolve) => { + // Just as a safety measure, we set a maximum timeout of 3 seconds + setTimeout(() => resolve(), 3000); + + agent.once('message', (data) => { + const parsed = JSON.parse(data.toString()); + for (const [node, info] of Object.entries(parsed)) { + cache?.set(node, info); + log.debug('CACH', 'Cached %s', node); + } + + resolve(); + }); + }); + }), + ); +} diff --git a/server/ws/socket.ts b/server/ws/socket.ts new file mode 100644 index 0000000..6fc10d0 --- /dev/null +++ b/server/ws/socket.ts @@ -0,0 +1,59 @@ +import WebSocket, { WebSocketServer } from 'ws'; +import log from '~server/utils/log'; + +const server = new WebSocketServer({ noServer: true }); +export function initWebsocket(authKey: string) { + if (authKey.length === 0) { + return; + } + + log.info('SRVX', 'Starting a WebSocket server for agent connections'); + server.on('connection', (ws, req) => { + const tailnetID = req.headers['x-headplane-tailnet-id']; + if (!tailnetID) { + log.warn( + 'SRVX', + 'Rejecting an agent WebSocket connection without a tailnet ID', + ); + ws.close(1008, 'ERR_INVALID_TAILNET_ID'); + return; + } + + if (req.headers.authorization !== `Bearer ${authKey}`) { + log.warn('SRVX', 'Rejecting an unauthorized WebSocket connection'); + if (req.socket.remoteAddress) { + log.warn('SRVX', 'Agent source IP: %s', req.socket.remoteAddress); + } + + ws.close(1008, 'ERR_UNAUTHORIZED'); + return; + } + + const pinger = setInterval(() => { + if (ws.readyState !== WebSocket.OPEN) { + clearInterval(pinger); + return; + } + + ws.ping(); + }, 30000); + + ws.on('close', () => { + clearInterval(pinger); + }); + + ws.on('error', (error) => { + clearInterval(pinger); + log.error('SRVX', 'Agent WebSocket error: %s', error); + log.debug('SRVX', 'Error details: %o', error); + log.error('SRVX', 'Closing agent WebSocket connection'); + ws.close(1011, 'ERR_INTERNAL_ERROR'); + }); + }); + + return server; +} + +export function hp_getAgents() { + return server.clients; +}