Skip to content

Commit 02cecd5

Browse files
waleedlatif1Vikhyath Mondreti
andauthored
fix(sockets): implement longer token expiration for OTT, preventitive token refresh with retries (#566)
* fix(sockets): implement longer token expiration for OTT, preventitive token refresh with retries * cleanup tests * make websocket first choice transport * fix lint --------- Co-authored-by: Vikhyath Mondreti <vikhyathmondreti@vikhyaths-air.lan>
1 parent 00334e5 commit 02cecd5

4 files changed

Lines changed: 331 additions & 5 deletions

File tree

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
/**
2+
* @vitest-environment jsdom
3+
*/
4+
5+
import { act, renderHook, waitFor } from '@testing-library/react'
6+
import { io } from 'socket.io-client'
7+
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
8+
import { SocketProvider, useSocket } from './socket-context'
9+
10+
vi.mock('socket.io-client')
11+
const mockIo = vi.mocked(io)
12+
13+
global.fetch = vi.fn()
14+
const mockFetch = vi.mocked(fetch)
15+
16+
vi.mock('@/lib/logs/console-logger', () => ({
17+
createLogger: () => ({
18+
info: vi.fn(),
19+
error: vi.fn(),
20+
warn: vi.fn(),
21+
debug: vi.fn(),
22+
}),
23+
}))
24+
25+
describe('SocketContext Token Refresh', () => {
26+
let mockSocket: any
27+
let eventHandlers: Record<string, any>
28+
29+
beforeEach(() => {
30+
eventHandlers = {}
31+
mockSocket = {
32+
id: 'test-socket-id',
33+
connected: true,
34+
io: { engine: { transport: { name: 'websocket' } } },
35+
auth: { token: 'initial-token' },
36+
on: vi.fn((event, handler) => {
37+
eventHandlers[event] = handler
38+
}),
39+
connect: vi.fn(),
40+
disconnect: vi.fn(),
41+
emit: vi.fn(),
42+
close: vi.fn(),
43+
}
44+
45+
mockIo.mockReturnValue(mockSocket)
46+
47+
mockFetch.mockResolvedValue({
48+
ok: true,
49+
json: async () => ({ token: 'fresh-token' }),
50+
} as Response)
51+
})
52+
53+
afterEach(() => {
54+
vi.clearAllMocks()
55+
})
56+
57+
const renderSocketProvider = async (user = { id: 'test-user', name: 'Test User' }) => {
58+
const result = renderHook(() => useSocket(), {
59+
wrapper: ({ children }) => <SocketProvider user={user}>{children}</SocketProvider>,
60+
})
61+
62+
await waitFor(() => {
63+
expect(mockSocket.on).toHaveBeenCalledWith('connect_error', expect.any(Function))
64+
})
65+
66+
vi.clearAllMocks()
67+
68+
mockFetch.mockResolvedValue({
69+
ok: true,
70+
json: async () => ({ token: 'fresh-token' }),
71+
} as Response)
72+
73+
return result
74+
}
75+
76+
describe('Token Refresh on Connection Error', () => {
77+
it('should refresh token on authentication failure', async () => {
78+
const { result } = await renderSocketProvider()
79+
80+
const error = { message: 'Token validation failed' }
81+
82+
await act(async () => {
83+
await eventHandlers.connect_error(error)
84+
})
85+
86+
expect(mockFetch).toHaveBeenCalledWith('/api/auth/socket-token', {
87+
method: 'POST',
88+
credentials: 'include',
89+
})
90+
91+
// Should update socket auth and reconnect
92+
expect(mockSocket.auth.token).toBe('fresh-token')
93+
expect(mockSocket.connect).toHaveBeenCalled()
94+
})
95+
96+
it('should limit token refresh attempts to 3', async () => {
97+
const { result } = await renderSocketProvider()
98+
99+
const error = { message: 'Token validation failed' }
100+
101+
for (let i = 0; i < 4; i++) {
102+
await act(async () => {
103+
await eventHandlers.connect_error(error)
104+
})
105+
}
106+
107+
// Should only call fetch 3 times (max attempts)
108+
expect(mockFetch).toHaveBeenCalledTimes(3)
109+
expect(mockSocket.connect).toHaveBeenCalledTimes(3)
110+
})
111+
112+
it('should prevent concurrent token refresh attempts', async () => {
113+
const { result } = await renderSocketProvider()
114+
115+
let resolveTokenFetch!: (value: {
116+
ok: boolean
117+
json: () => Promise<{ token: string }>
118+
}) => void
119+
const slowTokenPromise = new Promise((resolve) => {
120+
resolveTokenFetch = resolve
121+
})
122+
123+
mockFetch.mockReturnValue(slowTokenPromise as any)
124+
125+
const error = { message: 'Authentication failed' }
126+
127+
// Start two concurrent refresh attempts
128+
const promise1 = act(async () => {
129+
await eventHandlers.connect_error(error)
130+
})
131+
132+
const promise2 = act(async () => {
133+
await eventHandlers.connect_error(error)
134+
})
135+
136+
// Resolve the slow fetch
137+
resolveTokenFetch({
138+
ok: true,
139+
json: async () => ({ token: 'fresh-token' }),
140+
})
141+
142+
await Promise.all([promise1, promise2])
143+
144+
// Should only call fetch once (concurrent protection)
145+
expect(mockFetch).toHaveBeenCalledTimes(1)
146+
})
147+
148+
it('should reset retry counter on successful connection', async () => {
149+
const { result } = await renderSocketProvider()
150+
151+
const error = { message: 'Token validation failed' }
152+
153+
// Use up 2 retry attempts
154+
await act(async () => {
155+
await eventHandlers.connect_error(error)
156+
})
157+
await act(async () => {
158+
await eventHandlers.connect_error(error)
159+
})
160+
161+
expect(mockFetch).toHaveBeenCalledTimes(2)
162+
163+
// Simulate successful connection (resets counter)
164+
await act(async () => {
165+
eventHandlers.connect()
166+
})
167+
168+
// Should be able to retry again (counter reset)
169+
await act(async () => {
170+
await eventHandlers.connect_error(error)
171+
})
172+
173+
expect(mockFetch).toHaveBeenCalledTimes(3)
174+
})
175+
176+
it('should handle token refresh failure gracefully', async () => {
177+
const { result } = await renderSocketProvider()
178+
179+
// Mock failed token refresh after initialization
180+
mockFetch.mockResolvedValue({
181+
ok: false,
182+
status: 401,
183+
} as Response)
184+
185+
const error = { message: 'Token validation failed' }
186+
187+
await act(async () => {
188+
await eventHandlers.connect_error(error)
189+
})
190+
191+
// Should attempt refresh but not update auth or reconnect
192+
expect(mockFetch).toHaveBeenCalled()
193+
expect(mockSocket.auth.token).toBe('initial-token') // unchanged
194+
expect(mockSocket.connect).not.toHaveBeenCalled()
195+
})
196+
197+
it('should handle fetch errors gracefully', async () => {
198+
const { result } = await renderSocketProvider()
199+
200+
// Mock fetch error after initialization
201+
mockFetch.mockRejectedValue(new Error('Network error'))
202+
203+
const error = { message: 'Authentication failed' }
204+
205+
// Should not throw error
206+
await act(async () => {
207+
await eventHandlers.connect_error(error)
208+
})
209+
210+
expect(mockFetch).toHaveBeenCalled()
211+
expect(mockSocket.connect).not.toHaveBeenCalled()
212+
})
213+
214+
it('should only refresh token on authentication-related errors', async () => {
215+
const { result } = await renderSocketProvider()
216+
217+
// Non-authentication error
218+
const networkError = { message: 'Network timeout' }
219+
220+
await act(async () => {
221+
await eventHandlers.connect_error(networkError)
222+
})
223+
224+
// Should not attempt token refresh
225+
expect(mockFetch).not.toHaveBeenCalled()
226+
expect(mockSocket.connect).not.toHaveBeenCalled()
227+
})
228+
})
229+
230+
describe('Interaction with Socket.IO Reconnection', () => {
231+
it('should work with Socket.IO built-in reconnection attempts', async () => {
232+
const { result } = await renderSocketProvider()
233+
234+
// Simulate Socket.IO reconnection cycle
235+
await act(async () => {
236+
// Reconnection attempt starts
237+
eventHandlers.reconnect_attempt(1)
238+
})
239+
240+
await act(async () => {
241+
// Fails with auth error
242+
await eventHandlers.connect_error({ message: 'Token validation failed' })
243+
})
244+
245+
// Should refresh token and attempt reconnection
246+
expect(mockFetch).toHaveBeenCalled()
247+
expect(mockSocket.connect).toHaveBeenCalled()
248+
})
249+
250+
it('should reset counters on successful reconnect', async () => {
251+
const { result } = await renderSocketProvider()
252+
253+
// Use up retry attempts
254+
const error = { message: 'Authentication failed' }
255+
await act(async () => {
256+
await eventHandlers.connect_error(error)
257+
})
258+
259+
await act(async () => {
260+
await eventHandlers.connect_error(error)
261+
})
262+
263+
expect(mockFetch).toHaveBeenCalledTimes(2)
264+
265+
// Simulate successful reconnection
266+
await act(async () => {
267+
eventHandlers.reconnect(1)
268+
})
269+
270+
// Should reset and allow new attempts
271+
await act(async () => {
272+
await eventHandlers.connect_error(error)
273+
})
274+
275+
expect(mockFetch).toHaveBeenCalledTimes(3)
276+
})
277+
})
278+
})

apps/sim/contexts/socket-context.tsx

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ export function SocketProvider({ children, user }: SocketProviderProps) {
8787

8888
// Connection state tracking
8989
const reconnectCount = useRef(0)
90+
const tokenRefreshAttempts = useRef(0)
91+
const isRefreshingToken = useRef(false)
9092

9193
// Use refs to store event handlers to avoid stale closures
9294
const eventHandlers = useRef<{
@@ -138,7 +140,9 @@ export function SocketProvider({ children, user }: SocketProviderProps) {
138140
const socketInstance = io(socketUrl, {
139141
transports: ['websocket', 'polling'], // Keep polling fallback for reliability
140142
withCredentials: true,
141-
reconnectionAttempts: 5, // Back to original conservative setting
143+
reconnectionAttempts: 5, // Socket.IO handles base reconnection
144+
reconnectionDelay: 1000, // Start with 1 second delay
145+
reconnectionDelayMax: 5000, // Max 5 second delay
142146
timeout: 10000, // Back to original timeout
143147
auth: {
144148
token, // Send one-time token for authentication
@@ -150,6 +154,7 @@ export function SocketProvider({ children, user }: SocketProviderProps) {
150154
setIsConnected(true)
151155
setIsConnecting(false)
152156
reconnectCount.current = 0
157+
tokenRefreshAttempts.current = 0
153158

154159
logger.info('Socket connected successfully', {
155160
socketId: socketInstance.id,
@@ -172,7 +177,7 @@ export function SocketProvider({ children, user }: SocketProviderProps) {
172177
setPresenceUsers([])
173178
})
174179

175-
socketInstance.on('connect_error', (error: any) => {
180+
socketInstance.on('connect_error', async (error: any) => {
176181
setIsConnecting(false)
177182
logger.error('Socket connection error:', {
178183
message: error.message,
@@ -181,11 +186,54 @@ export function SocketProvider({ children, user }: SocketProviderProps) {
181186
type: error.type,
182187
transport: error.transport,
183188
})
189+
190+
if (
191+
error.message?.includes('Token validation failed') ||
192+
error.message?.includes('Authentication failed')
193+
) {
194+
// Prevent infinite loops - limit refresh attempts
195+
if (tokenRefreshAttempts.current >= 3) {
196+
logger.warn('Max token refresh attempts reached - user needs to refresh page')
197+
return
198+
}
199+
200+
// Prevent concurrent refresh attempts
201+
if (isRefreshingToken.current) {
202+
logger.info('Token refresh already in progress, skipping...')
203+
return
204+
}
205+
206+
isRefreshingToken.current = true
207+
tokenRefreshAttempts.current++
208+
logger.info(`Token expired, attempting refresh (${tokenRefreshAttempts.current}/3)...`)
209+
210+
try {
211+
const tokenResponse = await fetch('/api/auth/socket-token', {
212+
method: 'POST',
213+
credentials: 'include',
214+
})
215+
216+
if (tokenResponse.ok) {
217+
const { token } = await tokenResponse.json()
218+
socketInstance.auth = { ...socketInstance.auth, token }
219+
logger.info('Token refreshed successfully, reconnecting...')
220+
socketInstance.connect()
221+
} else {
222+
logger.warn('Failed to refresh token - user may need to refresh page')
223+
}
224+
} catch (refreshError) {
225+
logger.error('Token refresh failed:', refreshError)
226+
} finally {
227+
isRefreshingToken.current = false
228+
}
229+
}
184230
})
185231

186232
// Add reconnection logging
187233
socketInstance.on('reconnect', (attemptNumber) => {
188234
reconnectCount.current = attemptNumber
235+
// Reset token refresh attempts on successful reconnection
236+
tokenRefreshAttempts.current = 0
189237
logger.info('Socket reconnected', {
190238
attemptNumber,
191239
})

apps/sim/lib/auth.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ export const auth = betterAuth({
187187
plugins: [
188188
nextCookies(),
189189
oneTimeToken({
190-
expiresIn: 10, // 10 minutes - enough time for socket connection
190+
expiresIn: 30, // 30 minutes - covers typical work sessions
191191
}),
192192
emailOTP({
193193
sendVerificationOTP: async (data: {

apps/sim/socket-server/config/socket.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ export function createSocketIOServer(httpServer: HttpServer): Server {
3636
allowedHeaders: ['Content-Type', 'Authorization', 'Cookie', 'socket.io'],
3737
credentials: true, // Enable credentials to accept cookies
3838
},
39-
transports: ['polling', 'websocket'], // Keep both transports for reliability
39+
transports: ['websocket', 'polling'], // WebSocket first, polling as fallback
4040
allowEIO3: true, // Keep legacy support for compatibility
4141
pingTimeout: 60000, // Back to original conservative setting
4242
pingInterval: 25000, // Back to original interval
@@ -52,7 +52,7 @@ export function createSocketIOServer(httpServer: HttpServer): Server {
5252

5353
logger.info('Socket.IO server configured with:', {
5454
allowedOrigins: allowedOrigins.length,
55-
transports: ['polling', 'websocket'],
55+
transports: ['websocket', 'polling'],
5656
pingTimeout: 60000,
5757
pingInterval: 25000,
5858
maxHttpBufferSize: 1e6,

0 commit comments

Comments
 (0)