feat: use a new ws implementation thats encapsulated
This commit is contained in:
parent
45537620a6
commit
5dd4c41291
@ -11,6 +11,9 @@ export default [
|
||||
route('/oidc/callback', 'routes/auth/oidc-callback.ts'),
|
||||
route('/oidc/start', 'routes/auth/oidc-start.ts'),
|
||||
|
||||
// API
|
||||
route('/api/agent', 'routes/api/agent.ts'),
|
||||
|
||||
// All the main logged-in dashboard routes
|
||||
// Double nested to separate error propagations
|
||||
layout('layouts/shell.tsx', [
|
||||
|
||||
39
app/routes/api/agent.ts
Normal file
39
app/routes/api/agent.ts
Normal file
@ -0,0 +1,39 @@
|
||||
import { LoaderFunctionArgs } from 'react-router';
|
||||
import type { AppContext } from '~server/context/app';
|
||||
|
||||
export async function loader({
|
||||
request,
|
||||
context,
|
||||
}: LoaderFunctionArgs<AppContext>) {
|
||||
if (!context?.agentData) {
|
||||
return new Response(JSON.stringify({ error: 'Agent data unavailable' }), {
|
||||
status: 400,
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const qp = new URLSearchParams(request.url.split('?')[1]);
|
||||
const nodeIds = qp.get('node_ids')?.split(',');
|
||||
if (!nodeIds) {
|
||||
return new Response(JSON.stringify({ error: 'No node IDs provided' }), {
|
||||
status: 400,
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const entries = context.agentData.toJSON();
|
||||
const missing = nodeIds.filter((nodeID) => !entries[nodeID]);
|
||||
if (missing.length > 0) {
|
||||
await context.hp_agentRequest(missing);
|
||||
}
|
||||
|
||||
return new Response(JSON.stringify(context.agentData), {
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
});
|
||||
}
|
||||
@ -146,24 +146,18 @@ export default function MachineRow({
|
||||
</Menu>
|
||||
</div>
|
||||
</td>
|
||||
{/**
|
||||
<td className="py-2">
|
||||
{stats !== undefined ? (
|
||||
<>
|
||||
<p className="leading-snug">
|
||||
{hinfo.getTSVersion(stats)}
|
||||
</p>
|
||||
<p className="leading-snug">{hinfo.getTSVersion(stats)}</p>
|
||||
<p className="text-sm opacity-50 max-w-48 truncate">
|
||||
{hinfo.getOSInfo(stats)}
|
||||
</p>
|
||||
</>
|
||||
) : (
|
||||
<p className="text-sm opacity-50">
|
||||
Unknown
|
||||
</p>
|
||||
<p className="text-sm opacity-50">Unknown</p>
|
||||
)}
|
||||
</td>
|
||||
**/}
|
||||
<td className="py-2">
|
||||
<span
|
||||
className={cn(
|
||||
|
||||
@ -148,6 +148,7 @@ export default function Page() {
|
||||
{machine.user.name}
|
||||
</div>
|
||||
</div>
|
||||
{tags.length > 0 ? (
|
||||
<div className="p-2 pl-4">
|
||||
<p className="text-sm text-headplane-600 dark:text-headplane-300">
|
||||
Status
|
||||
@ -158,6 +159,7 @@ export default function Page() {
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
) : undefined}
|
||||
</div>
|
||||
<h2 className="text-xl font-medium mb-4 mt-8">Subnets & Routing</h2>
|
||||
<Routes
|
||||
|
||||
@ -9,11 +9,11 @@ import type { Machine, Route, User } from '~/types';
|
||||
import cn from '~/utils/cn';
|
||||
import { pull } from '~/utils/headscale';
|
||||
import { getSession } from '~/utils/sessions.server';
|
||||
import { initAgentSocket, queryAgent } from '~/utils/ws-agent';
|
||||
|
||||
import Tooltip from '~/components/Tooltip';
|
||||
import { hs_getConfig } from '~/utils/config/loader';
|
||||
import { noContext } from '~/utils/log';
|
||||
import useAgent from '~/utils/useAgent';
|
||||
import { AppContext } from '~server/context/app';
|
||||
import { menuAction } from './action';
|
||||
import MachineRow from './components/machine';
|
||||
@ -34,12 +34,8 @@ export async function loader({
|
||||
throw noContext();
|
||||
}
|
||||
|
||||
initAgentSocket(context);
|
||||
|
||||
const stats = await queryAgent(machines.nodes.map((node) => node.nodeKey));
|
||||
const ctx = context.context;
|
||||
const { mode, config } = hs_getConfig();
|
||||
|
||||
let magic: string | undefined;
|
||||
|
||||
if (mode !== 'no') {
|
||||
@ -53,7 +49,6 @@ export async function loader({
|
||||
routes: routes.routes,
|
||||
users: users.users,
|
||||
magic,
|
||||
stats,
|
||||
server: ctx.headscale.url,
|
||||
publicServer: ctx.headscale.public_url,
|
||||
};
|
||||
@ -65,6 +60,7 @@ export async function action({ request }: ActionFunctionArgs) {
|
||||
|
||||
export default function Page() {
|
||||
const data = useLoaderData<typeof loader>();
|
||||
const { data: stats } = useAgent(data.nodes.map((node) => node.nodeKey));
|
||||
|
||||
return (
|
||||
<>
|
||||
@ -108,7 +104,7 @@ export default function Page() {
|
||||
) : undefined}
|
||||
</div>
|
||||
</th>
|
||||
{/**<th className="uppercase text-xs font-bold pb-2">Version</th>**/}
|
||||
<th className="uppercase text-xs font-bold pb-2">Version</th>
|
||||
<th className="uppercase text-xs font-bold pb-2">Last Seen</th>
|
||||
</tr>
|
||||
</thead>
|
||||
@ -127,7 +123,7 @@ export default function Page() {
|
||||
)}
|
||||
users={data.users}
|
||||
magic={data.magic}
|
||||
stats={data.stats?.[machine.nodeKey]}
|
||||
stats={stats?.[machine.nodeKey]}
|
||||
/>
|
||||
))}
|
||||
</tbody>
|
||||
|
||||
25
app/utils/useAgent.ts
Normal file
25
app/utils/useAgent.ts
Normal file
@ -0,0 +1,25 @@
|
||||
import { useEffect } from 'react';
|
||||
import { useFetcher } from 'react-router';
|
||||
import { HostInfo } from '~/types';
|
||||
|
||||
export default function useAgent(nodeIds: string[], interval = 3000) {
|
||||
const fetcher = useFetcher<Record<string, HostInfo>>();
|
||||
|
||||
useEffect(() => {
|
||||
const qp = new URLSearchParams({ node_ids: nodeIds.join(',') });
|
||||
fetcher.load(`/api/agent?${qp.toString()}`);
|
||||
|
||||
const intervalID = setInterval(() => {
|
||||
fetcher.load(`/api/agent?${qp.toString()}`);
|
||||
}, interval);
|
||||
|
||||
return () => {
|
||||
clearInterval(intervalID);
|
||||
};
|
||||
}, [fetcher, interval, nodeIds]);
|
||||
|
||||
return {
|
||||
data: fetcher.data,
|
||||
isLoading: fetcher.state === 'loading',
|
||||
};
|
||||
}
|
||||
@ -1,12 +1,19 @@
|
||||
import type { HostInfo } from '~/types';
|
||||
import { TimedCache } from '~server/ws/cache';
|
||||
import { hp_agentRequest, hp_getAgentCache } from '~server/ws/data';
|
||||
import { hp_getConfig } from './loader';
|
||||
import { HeadplaneConfig } from './parser';
|
||||
import type { HeadplaneConfig } from './parser';
|
||||
|
||||
export interface AppContext {
|
||||
context: HeadplaneConfig;
|
||||
agentData?: TimedCache<HostInfo>;
|
||||
hp_agentRequest: typeof hp_agentRequest;
|
||||
}
|
||||
|
||||
export default function appContext() {
|
||||
return {
|
||||
context: hp_getConfig(),
|
||||
agentData: hp_getAgentCache(),
|
||||
hp_agentRequest,
|
||||
};
|
||||
}
|
||||
|
||||
@ -8,6 +8,17 @@ const serverConfig = type({
|
||||
port: type('string | number.integer').pipe((v) => Number(v)),
|
||||
cookie_secret: '32 <= string <= 32',
|
||||
cookie_secure: stringToBool,
|
||||
agent: type({
|
||||
authkey: 'string',
|
||||
ttl: 'number.integer = 180000', // Default to 3 minutes
|
||||
cache_path: 'string = "/var/lib/headplane/agent_cache.json"',
|
||||
})
|
||||
.onDeepUndeclaredKey('reject')
|
||||
.default(() => ({
|
||||
authkey: '',
|
||||
ttl: 180000,
|
||||
cache_path: '/var/lib/headplane/agent_cache.json',
|
||||
})),
|
||||
});
|
||||
|
||||
const oidcConfig = type({
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
// import { initWebsocket } from '~server/ws';
|
||||
import { constants, access } from 'node:fs/promises';
|
||||
import { createServer } from 'node:http';
|
||||
import { hp_getConfig, hp_loadConfig } from '~server/context/loader';
|
||||
import { listener } from '~server/listener';
|
||||
import log from '~server/utils/log';
|
||||
import { hp_loadAgentCache } from '~server/ws/data';
|
||||
import { initWebsocket } from '~server/ws/socket';
|
||||
|
||||
log.info('SRVX', 'Running Node.js %s', process.versions.node);
|
||||
|
||||
@ -19,16 +20,21 @@ try {
|
||||
await hp_loadConfig();
|
||||
const server = createServer(listener);
|
||||
|
||||
// const ws = initWebsocket();
|
||||
// if (ws) {
|
||||
// server.on('upgrade', (req, socket, head) => {
|
||||
// ws.handleUpgrade(req, socket, head, (ws) => {
|
||||
// ws.emit('connection', ws, req);
|
||||
// });
|
||||
// });
|
||||
// }
|
||||
|
||||
const context = hp_getConfig();
|
||||
const ws = initWebsocket(context.server.agent.authkey);
|
||||
if (ws) {
|
||||
await hp_loadAgentCache(
|
||||
context.server.agent.ttl,
|
||||
context.server.agent.cache_path,
|
||||
);
|
||||
|
||||
server.on('upgrade', (req, socket, head) => {
|
||||
ws.handleUpgrade(req, socket, head, (ws) => {
|
||||
ws.emit('connection', ws, req);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
server.listen(context.server.port, context.server.host, () => {
|
||||
log.info(
|
||||
'SRVX',
|
||||
|
||||
@ -1,60 +0,0 @@
|
||||
import WebSocket, { WebSocketServer } from 'ws';
|
||||
import log from '~server/utils/log';
|
||||
|
||||
const server = new WebSocketServer({ noServer: true });
|
||||
export function initWebsocket() {
|
||||
// TODO: Finish this and make public
|
||||
return;
|
||||
|
||||
const key = process.env.LOCAL_AGENT_AUTHKEY;
|
||||
if (!key) {
|
||||
return;
|
||||
}
|
||||
|
||||
log.info('CACH', 'Initializing agent WebSocket');
|
||||
server.on('connection', (ws, req) => {
|
||||
// biome-ignore lint: this file is not USED
|
||||
const auth = req.headers['authorization'];
|
||||
if (auth !== `Bearer ${key}`) {
|
||||
log.warn('CACH', 'Invalid agent WebSocket connection');
|
||||
ws.close(1008, 'ERR_INVALID_AUTH');
|
||||
return;
|
||||
}
|
||||
|
||||
const nodeID = req.headers['x-headplane-ts-node-id'];
|
||||
if (!nodeID) {
|
||||
log.warn('CACH', 'Invalid agent WebSocket connection');
|
||||
ws.close(1008, 'ERR_INVALID_NODE_ID');
|
||||
return;
|
||||
}
|
||||
|
||||
const pinger = setInterval(() => {
|
||||
if (ws.readyState !== WebSocket.OPEN) {
|
||||
clearInterval(pinger);
|
||||
return;
|
||||
}
|
||||
|
||||
ws.ping();
|
||||
}, 30000);
|
||||
|
||||
ws.on('close', () => {
|
||||
clearInterval(pinger);
|
||||
});
|
||||
|
||||
ws.on('error', (error) => {
|
||||
clearInterval(pinger);
|
||||
log.error('CACH', 'Closing agent WebSocket connection');
|
||||
log.error('CACH', 'Agent WebSocket error: %s', error);
|
||||
ws.close(1011, 'ERR_INTERNAL_ERROR');
|
||||
});
|
||||
});
|
||||
|
||||
return server;
|
||||
}
|
||||
|
||||
export function appContext() {
|
||||
return {
|
||||
ws: server,
|
||||
wsAuthKey: process.env.LOCAL_AGENT_AUTHKEY,
|
||||
};
|
||||
}
|
||||
126
server/ws/cache.ts
Normal file
126
server/ws/cache.ts
Normal file
@ -0,0 +1,126 @@
|
||||
import { createHash } from 'node:crypto';
|
||||
import { readFile, writeFile } from 'node:fs/promises';
|
||||
import { type } from 'arktype';
|
||||
import log from '~server/utils/log';
|
||||
import mutex from '~server/utils/mutex';
|
||||
|
||||
const diskSchema = type({
|
||||
key: 'string',
|
||||
value: 'unknown',
|
||||
expires: 'number?',
|
||||
}).array();
|
||||
|
||||
// A persistent HashMap with a TTL for each key
|
||||
export class TimedCache<V> {
|
||||
private _cache = new Map<string, V>();
|
||||
private _timings = new Map<string, number>();
|
||||
|
||||
// Default TTL is 1 minute
|
||||
private defaultTTL: number;
|
||||
private filePath: string;
|
||||
private writeLock = mutex();
|
||||
|
||||
// Last flush ID is essentially a hash of the flush contents
|
||||
// Prevents unnecessary flushing if nothing has changed
|
||||
private lastFlushId = '';
|
||||
|
||||
constructor(defaultTTL: number, filePath: string) {
|
||||
this.defaultTTL = defaultTTL;
|
||||
this.filePath = filePath;
|
||||
|
||||
// Load the cache from disk and then queue flushes every 10 seconds
|
||||
this.load().then(() => {
|
||||
setInterval(() => this.flush(), 10000);
|
||||
});
|
||||
}
|
||||
|
||||
set(key: string, value: V, ttl: number = this.defaultTTL) {
|
||||
this._cache.set(key, value);
|
||||
this._timings.set(key, Date.now() + ttl);
|
||||
}
|
||||
|
||||
get(key: string) {
|
||||
const value = this._cache.get(key);
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
||||
const expires = this._timings.get(key);
|
||||
if (!expires || expires < Date.now()) {
|
||||
this._cache.delete(key);
|
||||
this._timings.delete(key);
|
||||
return;
|
||||
}
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
// Map into a Record without any TTLs
|
||||
toJSON() {
|
||||
const result: Record<string, V> = {};
|
||||
for (const [key, value] of this._cache.entries()) {
|
||||
result[key] = value;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// WARNING: This function expects that this.filePath is NOT ENOENT
|
||||
private async load() {
|
||||
const data = await readFile(this.filePath, 'utf-8');
|
||||
const cache = () => {
|
||||
try {
|
||||
return JSON.parse(data);
|
||||
} catch (e) {
|
||||
return undefined;
|
||||
}
|
||||
};
|
||||
|
||||
const diskData = cache();
|
||||
if (diskData === undefined) {
|
||||
log.error('CACH', 'Failed to load cache at %s', this.filePath);
|
||||
return;
|
||||
}
|
||||
|
||||
const cacheData = diskSchema(diskData);
|
||||
if (cacheData instanceof type.errors) {
|
||||
log.error('CACH', 'Failed to load cache at %s', this.filePath);
|
||||
log.debug('CACHE', 'Error details: %s', cacheData.toString());
|
||||
|
||||
// Skip loading the cache (it should be overwritten soon)
|
||||
return;
|
||||
}
|
||||
|
||||
for (const { key, value, expires } of diskData) {
|
||||
this._cache.set(key, value);
|
||||
this._timings.set(key, expires);
|
||||
}
|
||||
|
||||
log.info('CACH', 'Loaded cache from %s', this.filePath);
|
||||
}
|
||||
|
||||
private async flush() {
|
||||
this.writeLock.acquire();
|
||||
const data = Array.from(this._cache.entries()).map(([key, value]) => {
|
||||
return { key, value, expires: this._timings.get(key) };
|
||||
});
|
||||
|
||||
if (data.length === 0) {
|
||||
this.writeLock.release();
|
||||
return;
|
||||
}
|
||||
|
||||
// Calculate the hash of the data
|
||||
const dumpData = JSON.stringify(data);
|
||||
const sha = createHash('sha256').update(dumpData).digest('hex');
|
||||
if (sha === this.lastFlushId) {
|
||||
this.writeLock.release();
|
||||
return;
|
||||
}
|
||||
|
||||
await writeFile(this.filePath, dumpData, 'utf-8');
|
||||
this.lastFlushId = sha;
|
||||
this.writeLock.release();
|
||||
log.debug('CACH', 'Flushed cache to %s', this.filePath);
|
||||
}
|
||||
}
|
||||
61
server/ws/data.ts
Normal file
61
server/ws/data.ts
Normal file
@ -0,0 +1,61 @@
|
||||
import { open } from 'node:fs/promises';
|
||||
import type { HostInfo } from '~/types';
|
||||
import log from '~server/utils/log';
|
||||
import { TimedCache } from './cache';
|
||||
import { hp_getAgents } from './socket';
|
||||
|
||||
let cache: TimedCache<HostInfo> | undefined;
|
||||
export async function hp_loadAgentCache(defaultTTL: number, filepath: string) {
|
||||
log.debug('CACH', `Loading agent cache from ${filepath}`);
|
||||
|
||||
try {
|
||||
const handle = await open(filepath, 'w');
|
||||
log.info('CACH', `Using agent cache file at ${filepath}`);
|
||||
await handle.close();
|
||||
} catch (e) {
|
||||
log.info('CACH', `Agent cache file not found at ${filepath}`);
|
||||
return;
|
||||
}
|
||||
|
||||
cache = new TimedCache(defaultTTL, filepath);
|
||||
}
|
||||
|
||||
export function hp_getAgentCache() {
|
||||
return cache;
|
||||
}
|
||||
|
||||
export async function hp_agentRequest(nodeList: string[]) {
|
||||
// Request to all connected agents (we can have multiple)
|
||||
// Luckily we can parse all the data at once through message parsing
|
||||
// and then overlapping cache entries will be overwritten by time
|
||||
const agents = [...hp_getAgents()];
|
||||
console.log(agents);
|
||||
|
||||
// Deduplicate the list of nodes
|
||||
const NodeIDs = [...new Set(nodeList)];
|
||||
NodeIDs.map((node) => {
|
||||
log.debug('CACH', 'Requesting agent data for', node);
|
||||
});
|
||||
|
||||
// Await so that data loads on first request without racing
|
||||
// Since we do agent.once() we NEED to wait for it to finish
|
||||
await Promise.allSettled(
|
||||
agents.map(async (agent) => {
|
||||
agent.send(JSON.stringify({ NodeIDs }));
|
||||
await new Promise<void>((resolve) => {
|
||||
// Just as a safety measure, we set a maximum timeout of 3 seconds
|
||||
setTimeout(() => resolve(), 3000);
|
||||
|
||||
agent.once('message', (data) => {
|
||||
const parsed = JSON.parse(data.toString());
|
||||
for (const [node, info] of Object.entries<HostInfo>(parsed)) {
|
||||
cache?.set(node, info);
|
||||
log.debug('CACH', 'Cached %s', node);
|
||||
}
|
||||
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
}),
|
||||
);
|
||||
}
|
||||
59
server/ws/socket.ts
Normal file
59
server/ws/socket.ts
Normal file
@ -0,0 +1,59 @@
|
||||
import WebSocket, { WebSocketServer } from 'ws';
|
||||
import log from '~server/utils/log';
|
||||
|
||||
const server = new WebSocketServer({ noServer: true });
|
||||
export function initWebsocket(authKey: string) {
|
||||
if (authKey.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
log.info('SRVX', 'Starting a WebSocket server for agent connections');
|
||||
server.on('connection', (ws, req) => {
|
||||
const tailnetID = req.headers['x-headplane-tailnet-id'];
|
||||
if (!tailnetID) {
|
||||
log.warn(
|
||||
'SRVX',
|
||||
'Rejecting an agent WebSocket connection without a tailnet ID',
|
||||
);
|
||||
ws.close(1008, 'ERR_INVALID_TAILNET_ID');
|
||||
return;
|
||||
}
|
||||
|
||||
if (req.headers.authorization !== `Bearer ${authKey}`) {
|
||||
log.warn('SRVX', 'Rejecting an unauthorized WebSocket connection');
|
||||
if (req.socket.remoteAddress) {
|
||||
log.warn('SRVX', 'Agent source IP: %s', req.socket.remoteAddress);
|
||||
}
|
||||
|
||||
ws.close(1008, 'ERR_UNAUTHORIZED');
|
||||
return;
|
||||
}
|
||||
|
||||
const pinger = setInterval(() => {
|
||||
if (ws.readyState !== WebSocket.OPEN) {
|
||||
clearInterval(pinger);
|
||||
return;
|
||||
}
|
||||
|
||||
ws.ping();
|
||||
}, 30000);
|
||||
|
||||
ws.on('close', () => {
|
||||
clearInterval(pinger);
|
||||
});
|
||||
|
||||
ws.on('error', (error) => {
|
||||
clearInterval(pinger);
|
||||
log.error('SRVX', 'Agent WebSocket error: %s', error);
|
||||
log.debug('SRVX', 'Error details: %o', error);
|
||||
log.error('SRVX', 'Closing agent WebSocket connection');
|
||||
ws.close(1011, 'ERR_INTERNAL_ERROR');
|
||||
});
|
||||
});
|
||||
|
||||
return server;
|
||||
}
|
||||
|
||||
export function hp_getAgents() {
|
||||
return server.clients;
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user