|
1 | 1 | import { createConnection } from "net" |
| 2 | +import { createServer } from "http" |
2 | 3 | import { Log } from "../util/log" |
3 | 4 | import { OAUTH_CALLBACK_PORT, OAUTH_CALLBACK_PATH } from "./oauth-provider" |
4 | 5 |
|
@@ -52,105 +53,105 @@ interface PendingAuth { |
52 | 53 | } |
53 | 54 |
|
54 | 55 | export namespace McpOAuthCallback { |
55 | | - let server: ReturnType<typeof Bun.serve> | undefined |
| 56 | + let server: ReturnType<typeof createServer> | undefined |
56 | 57 | const pendingAuths = new Map<string, PendingAuth>() |
57 | 58 | // Reverse index: mcpName → oauthState, so cancelPending(mcpName) can |
58 | 59 | // find the right entry in pendingAuths (which is keyed by oauthState). |
59 | 60 | const mcpNameToState = new Map<string, string>() |
60 | 61 |
|
61 | 62 | const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes |
62 | 63 |
|
63 | | - export async function ensureRunning(): Promise<void> { |
64 | | - if (server) return |
| 64 | + function cleanupStateIndex(oauthState: string) { |
| 65 | + for (const [name, state] of mcpNameToState) { |
| 66 | + if (state === oauthState) { |
| 67 | + mcpNameToState.delete(name) |
| 68 | + break |
| 69 | + } |
| 70 | + } |
| 71 | + } |
65 | 72 |
|
66 | | - const running = await isPortInUse() |
67 | | - if (running) { |
68 | | - log.info("oauth callback server already running on another instance", { port: OAUTH_CALLBACK_PORT }) |
| 73 | + function handleRequest(req: import("http").IncomingMessage, res: import("http").ServerResponse) { |
| 74 | + const url = new URL(req.url || "/", `http://localhost:${OAUTH_CALLBACK_PORT}`) |
| 75 | + |
| 76 | + if (url.pathname !== OAUTH_CALLBACK_PATH) { |
| 77 | + res.writeHead(404) |
| 78 | + res.end("Not found") |
69 | 79 | return |
70 | 80 | } |
71 | 81 |
|
72 | | - server = Bun.serve({ |
73 | | - port: OAUTH_CALLBACK_PORT, |
74 | | - fetch(req) { |
75 | | - const url = new URL(req.url) |
| 82 | + const code = url.searchParams.get("code") |
| 83 | + const state = url.searchParams.get("state") |
| 84 | + const error = url.searchParams.get("error") |
| 85 | + const errorDescription = url.searchParams.get("error_description") |
76 | 86 |
|
77 | | - if (url.pathname !== OAUTH_CALLBACK_PATH) { |
78 | | - return new Response("Not found", { status: 404 }) |
79 | | - } |
| 87 | + log.info("received oauth callback", { hasCode: !!code, state, error }) |
80 | 88 |
|
81 | | - const code = url.searchParams.get("code") |
82 | | - const state = url.searchParams.get("state") |
83 | | - const error = url.searchParams.get("error") |
84 | | - const errorDescription = url.searchParams.get("error_description") |
85 | | - |
86 | | - log.info("received oauth callback", { hasCode: !!code, state, error }) |
87 | | - |
88 | | - // Enforce state parameter presence |
89 | | - if (!state) { |
90 | | - const errorMsg = "Missing required state parameter - potential CSRF attack" |
91 | | - log.error("oauth callback missing state parameter", { url: url.toString() }) |
92 | | - return new Response(HTML_ERROR(errorMsg), { |
93 | | - status: 400, |
94 | | - headers: { "Content-Type": "text/html" }, |
95 | | - }) |
96 | | - } |
| 89 | + // Enforce state parameter presence |
| 90 | + if (!state) { |
| 91 | + const errorMsg = "Missing required state parameter - potential CSRF attack" |
| 92 | + log.error("oauth callback missing state parameter", { url: url.toString() }) |
| 93 | + res.writeHead(400, { "Content-Type": "text/html" }) |
| 94 | + res.end(HTML_ERROR(errorMsg)) |
| 95 | + return |
| 96 | + } |
97 | 97 |
|
98 | | - if (error) { |
99 | | - const errorMsg = errorDescription || error |
100 | | - if (pendingAuths.has(state)) { |
101 | | - const pending = pendingAuths.get(state)! |
102 | | - clearTimeout(pending.timeout) |
103 | | - pendingAuths.delete(state) |
104 | | - for (const [name, s] of mcpNameToState) { |
105 | | - if (s === state) { |
106 | | - mcpNameToState.delete(name) |
107 | | - break |
108 | | - } |
109 | | - } |
110 | | - pending.reject(new Error(errorMsg)) |
111 | | - } |
112 | | - return new Response(HTML_ERROR(errorMsg), { |
113 | | - headers: { "Content-Type": "text/html" }, |
114 | | - }) |
115 | | - } |
| 98 | + if (error) { |
| 99 | + const errorMsg = errorDescription || error |
| 100 | + if (pendingAuths.has(state)) { |
| 101 | + const pending = pendingAuths.get(state)! |
| 102 | + clearTimeout(pending.timeout) |
| 103 | + pendingAuths.delete(state) |
| 104 | + cleanupStateIndex(state) |
| 105 | + pending.reject(new Error(errorMsg)) |
| 106 | + } |
| 107 | + res.writeHead(200, { "Content-Type": "text/html" }) |
| 108 | + res.end(HTML_ERROR(errorMsg)) |
| 109 | + return |
| 110 | + } |
116 | 111 |
|
117 | | - if (!code) { |
118 | | - return new Response(HTML_ERROR("No authorization code provided"), { |
119 | | - status: 400, |
120 | | - headers: { "Content-Type": "text/html" }, |
121 | | - }) |
122 | | - } |
| 112 | + if (!code) { |
| 113 | + res.writeHead(400, { "Content-Type": "text/html" }) |
| 114 | + res.end(HTML_ERROR("No authorization code provided")) |
| 115 | + return |
| 116 | + } |
123 | 117 |
|
124 | | - // Validate state parameter |
125 | | - if (!pendingAuths.has(state)) { |
126 | | - const errorMsg = "Invalid or expired state parameter - potential CSRF attack" |
127 | | - log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) }) |
128 | | - return new Response(HTML_ERROR(errorMsg), { |
129 | | - status: 400, |
130 | | - headers: { "Content-Type": "text/html" }, |
131 | | - }) |
132 | | - } |
| 118 | + // Validate state parameter |
| 119 | + if (!pendingAuths.has(state)) { |
| 120 | + const errorMsg = "Invalid or expired state parameter - potential CSRF attack" |
| 121 | + log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) }) |
| 122 | + res.writeHead(400, { "Content-Type": "text/html" }) |
| 123 | + res.end(HTML_ERROR(errorMsg)) |
| 124 | + return |
| 125 | + } |
133 | 126 |
|
134 | | - const pending = pendingAuths.get(state)! |
| 127 | + const pending = pendingAuths.get(state)! |
135 | 128 |
|
136 | | - clearTimeout(pending.timeout) |
137 | | - pendingAuths.delete(state) |
138 | | - // Clean up reverse index |
139 | | - for (const [name, s] of mcpNameToState) { |
140 | | - if (s === state) { |
141 | | - mcpNameToState.delete(name) |
142 | | - break |
143 | | - } |
144 | | - } |
145 | | - pending.resolve(code) |
| 129 | + clearTimeout(pending.timeout) |
| 130 | + pendingAuths.delete(state) |
| 131 | + cleanupStateIndex(state) |
| 132 | + pending.resolve(code) |
146 | 133 |
|
147 | | - return new Response(HTML_SUCCESS, { |
148 | | - headers: { "Content-Type": "text/html" }, |
149 | | - }) |
150 | | - }, |
151 | | - }) |
| 134 | + res.writeHead(200, { "Content-Type": "text/html" }) |
| 135 | + res.end(HTML_SUCCESS) |
| 136 | + } |
| 137 | + |
| 138 | + export async function ensureRunning(): Promise<void> { |
| 139 | + if (server) return |
| 140 | + |
| 141 | + const running = await isPortInUse() |
| 142 | + if (running) { |
| 143 | + log.info("oauth callback server already running on another instance", { port: OAUTH_CALLBACK_PORT }) |
| 144 | + return |
| 145 | + } |
152 | 146 |
|
153 | | - log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT }) |
| 147 | + server = createServer(handleRequest) |
| 148 | + await new Promise<void>((resolve, reject) => { |
| 149 | + server!.listen(OAUTH_CALLBACK_PORT, () => { |
| 150 | + log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT }) |
| 151 | + resolve() |
| 152 | + }) |
| 153 | + server!.on("error", reject) |
| 154 | + }) |
154 | 155 | } |
155 | 156 |
|
156 | 157 | export function waitForCallback(oauthState: string, mcpName?: string): Promise<string> { |
@@ -196,7 +197,7 @@ export namespace McpOAuthCallback { |
196 | 197 |
|
197 | 198 | export async function stop(): Promise<void> { |
198 | 199 | if (server) { |
199 | | - server.stop() |
| 200 | + await new Promise<void>((resolve) => server!.close(() => resolve())) |
200 | 201 | server = undefined |
201 | 202 | log.info("oauth callback server stopped") |
202 | 203 | } |
|
0 commit comments