diff --git a/.claude/rules/sim-testing.md b/.claude/rules/sim-testing.md index 36b19dc0d6..4e28169449 100644 --- a/.claude/rules/sim-testing.md +++ b/.claude/rules/sim-testing.md @@ -217,13 +217,20 @@ it('reads a row', async () => { ``` **Default chains supported:** -- `select()/selectDistinct()/selectDistinctOn() → from() → where()/innerJoin()/leftJoin() → where() → limit()/orderBy()/returning()/groupBy()` +- `select()/selectDistinct()/selectDistinctOn() → from() → where()/innerJoin()/leftJoin() → where() → limit()/orderBy()/returning()/groupBy()/for()` - `insert() → values() → returning()/onConflictDoUpdate()/onConflictDoNothing()` -- `update() → set() → where() → limit()/orderBy()/returning()` -- `delete() → where() → limit()/orderBy()/returning()` +- `update() → set() → where() → limit()/orderBy()/returning()/for()` +- `delete() → where() → limit()/orderBy()/returning()/for()` - `db.execute()` resolves `[]` - `db.transaction(cb)` calls cb with `dbChainMock.db` +`.for('update')` (Postgres row-level locking) is supported on `where` +builders. It returns a thenable with `.limit` / `.orderBy` / `.returning` / +`.groupBy` attached, so both `await .where().for('update')` (terminal) and +`await .where().for('update').limit(1)` (chained) work. Override the terminal +result with `dbChainMockFns.for.mockResolvedValueOnce([...])`; for the chained +form, mock the downstream terminal (e.g. `dbChainMockFns.limit.mockResolvedValueOnce([...])`). + All terminals default to `Promise.resolve([])`. Override per-test with `dbChainMockFns..mockResolvedValueOnce(...)`. Use `resetDbChainMock()` in `beforeEach` only when tests replace wiring with `.mockReturnValue` / `.mockResolvedValue` (permanent). Tests using only `...Once` variants don't need it. diff --git a/apps/sim/app/api/users/me/subscription/[id]/transfer/route.test.ts b/apps/sim/app/api/users/me/subscription/[id]/transfer/route.test.ts index 43182d88d2..d9c464661b 100644 --- a/apps/sim/app/api/users/me/subscription/[id]/transfer/route.test.ts +++ b/apps/sim/app/api/users/me/subscription/[id]/transfer/route.test.ts @@ -1,185 +1,138 @@ /** * @vitest-environment node */ -import { createSession, loggerMock } from '@sim/testing' +import { + authMock, + authMockFns, + createSession, + dbChainMock, + dbChainMockFns, + resetDbChainMock, +} from '@sim/testing' import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockDbState, mockGetSession, mockHasPaidSubscription } = vi.hoisted(() => ({ - mockDbState: { - selectResults: [] as any[], - updateCalls: [] as Array<{ table: unknown; values: Record }>, - }, - mockGetSession: vi.fn(), - mockHasPaidSubscription: vi.fn(), -})) - -vi.mock('@sim/db', () => ({ - db: { - select: vi.fn().mockImplementation(() => { - const chain: any = {} - chain.from = vi.fn().mockReturnValue(chain) - chain.where = vi.fn().mockReturnValue(chain) - chain.then = vi - .fn() - .mockImplementation((callback: (rows: any[]) => any) => - Promise.resolve(callback(mockDbState.selectResults.shift() ?? [])) - ) - return chain - }), - update: vi.fn().mockImplementation((table: unknown) => ({ - set: vi.fn().mockImplementation((values: Record) => { - mockDbState.updateCalls.push({ table, values }) - return { - where: vi.fn().mockResolvedValue(undefined), - } - }), - })), - }, -})) - -vi.mock('@sim/db/schema', () => ({ - member: { - userId: 'member.userId', - organizationId: 'member.organizationId', - }, - organization: { - id: 'organization.id', - }, - subscription: { - id: 'subscription.id', - referenceId: 'subscription.referenceId', - }, -})) - -vi.mock('drizzle-orm', () => ({ - and: vi.fn((...conditions: unknown[]) => ({ type: 'and', conditions })), - eq: vi.fn((field: unknown, value: unknown) => ({ field, value })), -})) - -vi.mock('@sim/logger', () => loggerMock) - -vi.mock('@/lib/auth', () => ({ - getSession: mockGetSession, -})) - -vi.mock('@/lib/billing', () => ({ - hasPaidSubscription: mockHasPaidSubscription, -})) +vi.mock('@sim/db', () => dbChainMock) +vi.mock('@/lib/auth', () => authMock) vi.mock('@/lib/billing/plan-helpers', () => ({ isOrgPlan: (plan: string) => plan === 'team' || plan === 'enterprise', })) vi.mock('@/lib/billing/subscriptions/utils', () => ({ + ENTITLED_SUBSCRIPTION_STATUSES: ['active', 'past_due'], hasPaidSubscriptionStatus: (status: string) => status === 'active' || status === 'past_due', })) import { POST } from '@/app/api/users/me/subscription/[id]/transfer/route' +function makeRequest(body: unknown, id = 'sub-1') { + return POST( + new Request(`http://localhost/api/users/me/subscription/${id}/transfer`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(body), + }) as any, + { params: Promise.resolve({ id }) } + ) +} + describe('POST /api/users/me/subscription/[id]/transfer', () => { beforeEach(() => { vi.clearAllMocks() - mockDbState.selectResults = [] - mockDbState.updateCalls = [] - mockHasPaidSubscription.mockResolvedValue(false) - }) - - it('rejects transfers for non-organization subscriptions', async () => { - mockGetSession.mockResolvedValue( + resetDbChainMock() + authMockFns.mockGetSession.mockResolvedValue( createSession({ userId: 'user-1', email: 'owner@example.com', name: 'Owner', }) ) - mockDbState.selectResults = [ - [{ id: 'sub-1', referenceId: 'user-1', plan: 'pro', status: 'active' }], - ] - - const response = await POST( - new Request('http://localhost/api/users/me/subscription/sub-1/transfer', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ organizationId: 'org-1' }), - }) as any, - { params: Promise.resolve({ id: 'sub-1' }) } - ) + }) + + it('rejects transfers for non-organization subscriptions', async () => { + dbChainMockFns.for.mockResolvedValueOnce([ + { id: 'sub-1', referenceId: 'user-1', plan: 'pro', status: 'active' }, + ]) + + const response = await makeRequest({ organizationId: 'org-1' }) expect(response.status).toBe(400) await expect(response.json()).resolves.toEqual({ error: 'Only active Team or Enterprise subscriptions can be transferred to an organization.', }) - expect(mockDbState.updateCalls).toEqual([]) + expect(dbChainMockFns.update).not.toHaveBeenCalled() }) it('transfers an active organization subscription to an admin-owned organization', async () => { - mockGetSession.mockResolvedValue( - createSession({ - userId: 'user-1', - email: 'owner@example.com', - name: 'Owner', - }) - ) - mockDbState.selectResults = [ - [{ id: 'sub-1', referenceId: 'user-1', plan: 'team', status: 'active' }], - [{ id: 'org-1' }], - [{ id: 'member-1', role: 'owner' }], - ] - - const response = await POST( - new Request('http://localhost/api/users/me/subscription/sub-1/transfer', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ organizationId: 'org-1' }), - }) as any, - { params: Promise.resolve({ id: 'sub-1' }) } - ) + dbChainMockFns.for + .mockResolvedValueOnce([ + { id: 'sub-1', referenceId: 'user-1', plan: 'team', status: 'active' }, + ]) + .mockResolvedValueOnce([{ id: 'org-1' }]) + dbChainMockFns.limit.mockResolvedValueOnce([{ role: 'owner' }]).mockResolvedValueOnce([]) + + const response = await makeRequest({ organizationId: 'org-1' }) expect(response.status).toBe(200) await expect(response.json()).resolves.toEqual({ success: true, message: 'Subscription transferred successfully', }) - expect(mockDbState.updateCalls).toEqual([ - { - table: expect.objectContaining({ - id: 'subscription.id', - referenceId: 'subscription.referenceId', - }), - values: { referenceId: 'org-1' }, - }, - ]) + expect(dbChainMockFns.update).toHaveBeenCalled() + expect(dbChainMockFns.set).toHaveBeenCalledWith({ referenceId: 'org-1' }) }) it('treats an already-transferred organization subscription as a successful no-op', async () => { - mockGetSession.mockResolvedValue( - createSession({ - userId: 'user-1', - email: 'owner@example.com', - name: 'Owner', - }) - ) - mockDbState.selectResults = [ - [{ id: 'sub-1', referenceId: 'org-1', plan: 'team', status: 'active' }], - [{ id: 'org-1' }], - [{ id: 'member-1', role: 'owner' }], - ] - - const response = await POST( - new Request('http://localhost/api/users/me/subscription/sub-1/transfer', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ organizationId: 'org-1' }), - }) as any, - { params: Promise.resolve({ id: 'sub-1' }) } - ) + dbChainMockFns.for + .mockResolvedValueOnce([ + { id: 'sub-1', referenceId: 'org-1', plan: 'team', status: 'active' }, + ]) + .mockResolvedValueOnce([{ id: 'org-1' }]) + dbChainMockFns.limit.mockResolvedValueOnce([{ role: 'owner' }]) + + const response = await makeRequest({ organizationId: 'org-1' }) expect(response.status).toBe(200) await expect(response.json()).resolves.toEqual({ success: true, message: 'Subscription already belongs to this organization', }) - expect(mockDbState.updateCalls).toEqual([]) - expect(mockHasPaidSubscription).not.toHaveBeenCalled() + expect(dbChainMockFns.update).not.toHaveBeenCalled() + }) + + it('rejects the noop probe when the requester is not a member of the target organization', async () => { + dbChainMockFns.for + .mockResolvedValueOnce([ + { id: 'sub-1', referenceId: 'org-1', plan: 'team', status: 'active' }, + ]) + .mockResolvedValueOnce([{ id: 'org-1' }]) + dbChainMockFns.limit.mockResolvedValueOnce([]) + + const response = await makeRequest({ organizationId: 'org-1' }) + + expect(response.status).toBe(403) + await expect(response.json()).resolves.toEqual({ + error: 'Unauthorized - user is not admin of organization', + }) + expect(dbChainMockFns.update).not.toHaveBeenCalled() + }) + + it('rejects the transfer when the target organization already has an active subscription', async () => { + dbChainMockFns.for + .mockResolvedValueOnce([ + { id: 'sub-1', referenceId: 'user-1', plan: 'team', status: 'active' }, + ]) + .mockResolvedValueOnce([{ id: 'org-1' }]) + dbChainMockFns.limit + .mockResolvedValueOnce([{ role: 'owner' }]) + .mockResolvedValueOnce([{ id: 'existing-sub' }]) + + const response = await makeRequest({ organizationId: 'org-1' }) + + expect(response.status).toBe(409) + await expect(response.json()).resolves.toEqual({ + error: 'Organization already has an active subscription', + }) + expect(dbChainMockFns.update).not.toHaveBeenCalled() }) }) diff --git a/apps/sim/app/api/users/me/subscription/[id]/transfer/route.ts b/apps/sim/app/api/users/me/subscription/[id]/transfer/route.ts index df3dddd1af..46b798b493 100644 --- a/apps/sim/app/api/users/me/subscription/[id]/transfer/route.ts +++ b/apps/sim/app/api/users/me/subscription/[id]/transfer/route.ts @@ -2,13 +2,15 @@ import { db } from '@sim/db' import { member, organization, subscription } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { toError } from '@sim/utils/errors' -import { and, eq } from 'drizzle-orm' +import { and, eq, inArray } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { z } from 'zod' import { getSession } from '@/lib/auth' -import { hasPaidSubscription } from '@/lib/billing' import { isOrgPlan } from '@/lib/billing/plan-helpers' -import { hasPaidSubscriptionStatus } from '@/lib/billing/subscriptions/utils' +import { + ENTITLED_SUBSCRIPTION_STATUSES, + hasPaidSubscriptionStatus, +} from '@/lib/billing/subscriptions/utils' const logger = createLogger('SubscriptionTransferAPI') @@ -16,6 +18,11 @@ const transferSubscriptionSchema = z.object({ organizationId: z.string().min(1), }) +type TransferOutcome = + | { kind: 'error'; status: number; error: string } + | { kind: 'noop'; message: string } + | { kind: 'success'; message: string } + export async function POST(request: NextRequest, { params }: { params: Promise<{ id: string }> }) { try { const subscriptionId = (await params).id @@ -50,88 +57,105 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{ } const { organizationId } = validationResult.data + const userId = session.user.id logger.info('Processing subscription transfer', { subscriptionId, organizationId }) - const sub = await db - .select() - .from(subscription) - .where(eq(subscription.id, subscriptionId)) - .then((rows) => rows[0]) - - if (!sub) { - return NextResponse.json({ error: 'Subscription not found' }, { status: 404 }) - } - - if (!isOrgPlan(sub.plan) || !hasPaidSubscriptionStatus(sub.status)) { - return NextResponse.json( - { + const outcome = await db.transaction(async (tx): Promise => { + const [sub] = await tx + .select() + .from(subscription) + .where(eq(subscription.id, subscriptionId)) + .for('update') + + if (!sub) { + return { kind: 'error', status: 404, error: 'Subscription not found' } + } + + if (!isOrgPlan(sub.plan) || !hasPaidSubscriptionStatus(sub.status)) { + return { + kind: 'error', + status: 400, error: 'Only active Team or Enterprise subscriptions can be transferred to an organization.', - }, - { status: 400 } - ) - } - - const org = await db - .select() - .from(organization) - .where(eq(organization.id, organizationId)) - .then((rows) => rows[0]) + } + } + + const [org] = await tx + .select({ id: organization.id }) + .from(organization) + .where(eq(organization.id, organizationId)) + .for('update') + + if (!org) { + return { kind: 'error', status: 404, error: 'Organization not found' } + } + + const [mem] = await tx + .select({ role: member.role }) + .from(member) + .where(and(eq(member.userId, userId), eq(member.organizationId, organizationId))) + .limit(1) + + if (!mem || (mem.role !== 'owner' && mem.role !== 'admin')) { + return { + kind: 'error', + status: 403, + error: 'Unauthorized - user is not admin of organization', + } + } + + if (sub.referenceId === organizationId) { + return { kind: 'noop', message: 'Subscription already belongs to this organization' } + } + + if (sub.referenceId !== userId) { + return { + kind: 'error', + status: 403, + error: 'Unauthorized - subscription does not belong to user', + } + } + + const [existingOrgSub] = await tx + .select({ id: subscription.id }) + .from(subscription) + .where( + and( + eq(subscription.referenceId, organizationId), + inArray(subscription.status, ENTITLED_SUBSCRIPTION_STATUSES) + ) + ) + .limit(1) + + if (existingOrgSub) { + return { + kind: 'error', + status: 409, + error: 'Organization already has an active subscription', + } + } + + await tx + .update(subscription) + .set({ referenceId: organizationId }) + .where(eq(subscription.id, subscriptionId)) + + return { kind: 'success', message: 'Subscription transferred successfully' } + }) - if (!org) { - return NextResponse.json({ error: 'Organization not found' }, { status: 404 }) + if (outcome.kind === 'error') { + return NextResponse.json({ error: outcome.error }, { status: outcome.status }) } - const mem = await db - .select() - .from(member) - .where(and(eq(member.userId, session.user.id), eq(member.organizationId, organizationId))) - .then((rows) => rows[0]) - - if (!mem || (mem.role !== 'owner' && mem.role !== 'admin')) { - return NextResponse.json( - { error: 'Unauthorized - user is not admin of organization' }, - { status: 403 } - ) - } - - if (sub.referenceId === organizationId) { - return NextResponse.json({ - success: true, - message: 'Subscription already belongs to this organization', + if (outcome.kind === 'success') { + logger.info('Subscription transfer completed', { + subscriptionId, + organizationId, + userId, }) } - if (sub.referenceId !== session.user.id) { - return NextResponse.json( - { error: 'Unauthorized - subscription does not belong to user' }, - { status: 403 } - ) - } - - // Check if org already has an active subscription (prevent duplicates) - if (await hasPaidSubscription(organizationId)) { - return NextResponse.json( - { error: 'Organization already has an active subscription' }, - { status: 409 } - ) - } - - await db - .update(subscription) - .set({ referenceId: organizationId }) - .where(eq(subscription.id, subscriptionId)) - - logger.info('Subscription transfer completed', { - subscriptionId, - organizationId, - userId: session.user.id, - }) - - return NextResponse.json({ - success: true, - message: 'Subscription transferred successfully', - }) + return NextResponse.json({ success: true, message: outcome.message }) } catch (error) { logger.error('Error transferring subscription', { error: toError(error).message, diff --git a/apps/sim/lib/billing/core/subscription.test.ts b/apps/sim/lib/billing/core/subscription.test.ts index ada13963d4..2f49aa6ba2 100644 --- a/apps/sim/lib/billing/core/subscription.test.ts +++ b/apps/sim/lib/billing/core/subscription.test.ts @@ -1,37 +1,10 @@ /** * @vitest-environment node */ -import { schemaMock } from '@sim/testing' +import { dbChainMock, dbChainMockFns, urlsMock } from '@sim/testing' import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockDbResults } = vi.hoisted(() => ({ - mockDbResults: { value: [] as Array }, -})) - -vi.mock('@sim/db', () => ({ - db: { - select: vi.fn().mockImplementation(() => { - const chain: any = {} - chain.from = vi.fn().mockReturnValue(chain) - chain.where = vi.fn().mockReturnValue(chain) - chain.limit = vi.fn().mockImplementation(async () => { - const result = mockDbResults.value.shift() - if (result instanceof Error) { - throw result - } - return result ?? [] - }) - return chain - }), - update: vi.fn().mockReturnValue({ - set: vi.fn().mockReturnValue({ - where: vi.fn().mockResolvedValue(undefined), - }), - }), - }, -})) - -vi.mock('@sim/db/schema', () => schemaMock) +vi.mock('@sim/db', () => dbChainMock) vi.mock('@/lib/billing/core/access', () => ({ getEffectiveBillingStatus: vi.fn(), @@ -66,9 +39,7 @@ vi.mock('@/lib/core/config/feature-flags', () => ({ isSsoEnabled: false, })) -vi.mock('@/lib/core/utils/urls', () => ({ - getBaseUrl: vi.fn().mockReturnValue('https://test.sim.ai'), -})) +vi.mock('@/lib/core/utils/urls', () => urlsMock) import { getOrganizationIdForSubscriptionReference, @@ -78,29 +49,28 @@ import { describe('hasPaidSubscription', () => { beforeEach(() => { vi.clearAllMocks() - mockDbResults.value = [] }) it('returns true when an entitled subscription exists', async () => { - mockDbResults.value = [[{ id: 'sub-1' }]] + dbChainMockFns.limit.mockResolvedValueOnce([{ id: 'sub-1' }]) await expect(hasPaidSubscription('org-1')).resolves.toBe(true) }) it('returns false when no entitled subscription exists', async () => { - mockDbResults.value = [[]] + dbChainMockFns.limit.mockResolvedValueOnce([]) await expect(hasPaidSubscription('org-1')).resolves.toBe(false) }) it('fails closed by default when the lookup errors', async () => { - mockDbResults.value = [new Error('db unavailable')] + dbChainMockFns.limit.mockRejectedValueOnce(new Error('db unavailable')) await expect(hasPaidSubscription('org-1')).resolves.toBe(true) }) it('throws when requested so callers can retry instead of skipping cleanup', async () => { - mockDbResults.value = [new Error('db unavailable')] + dbChainMockFns.limit.mockRejectedValueOnce(new Error('db unavailable')) await expect(hasPaidSubscription('org-1', { onError: 'throw' })).rejects.toThrow( 'db unavailable' @@ -111,25 +81,18 @@ describe('hasPaidSubscription', () => { describe('getOrganizationIdForSubscriptionReference', () => { beforeEach(() => { vi.clearAllMocks() - mockDbResults.value = [] }) it('returns an organization id directly when the reference already points to one', async () => { - mockDbResults.value = [[{ id: 'org-1' }]] + dbChainMockFns.limit.mockResolvedValueOnce([{ id: 'org-1' }]) await expect(getOrganizationIdForSubscriptionReference('org-1')).resolves.toBe('org-1') }) it('falls back to the admin-owned organization when the reference is still user-scoped', async () => { - mockDbResults.value = [ - [], - [ - { - organizationId: 'org-1', - role: 'owner', - }, - ], - ] + dbChainMockFns.limit + .mockResolvedValueOnce([]) + .mockResolvedValueOnce([{ organizationId: 'org-1', role: 'owner' }]) await expect(getOrganizationIdForSubscriptionReference('user-1')).resolves.toBe('org-1') }) diff --git a/apps/sim/lib/billing/credits/daily-refresh.test.ts b/apps/sim/lib/billing/credits/daily-refresh.test.ts index 1a74ae9a0c..932f487fe9 100644 --- a/apps/sim/lib/billing/credits/daily-refresh.test.ts +++ b/apps/sim/lib/billing/credits/daily-refresh.test.ts @@ -1,34 +1,19 @@ /** * @vitest-environment node */ +import { dbChainMock, dbChainMockFns, drizzleOrmMock } from '@sim/testing' import { beforeEach, describe, expect, it, vi } from 'vitest' -const mockDbSelect = vi.fn() - -vi.mock('@sim/db', () => ({ - db: { - select: () => ({ - from: () => ({ - where: () => ({ - groupBy: mockDbSelect, - }), - }), - }), - }, -})) +vi.mock('@sim/db', () => dbChainMock) vi.mock('drizzle-orm', () => { const sqlTag = () => { - const obj = { as: () => obj } + const obj: { as: () => typeof obj } = { as: () => obj } return obj } - sqlTag.as = sqlTag return { + ...drizzleOrmMock, sql: Object.assign(sqlTag, { raw: sqlTag }), - and: vi.fn(), - gte: vi.fn(), - lt: vi.fn(), - inArray: vi.fn(), sum: () => ({ as: () => 'sum' }), } }) @@ -37,7 +22,10 @@ vi.mock('@/lib/billing/constants', () => ({ DAILY_REFRESH_RATE: 0.01, })) -import { computeDailyRefreshConsumed, getDailyRefreshDollars } from './daily-refresh' +import { + computeDailyRefreshConsumed, + getDailyRefreshDollars, +} from '@/lib/billing/credits/daily-refresh' describe('computeDailyRefreshConsumed', () => { beforeEach(() => { @@ -51,7 +39,7 @@ describe('computeDailyRefreshConsumed', () => { planDollars: 0, }) expect(result).toBe(0) - expect(mockDbSelect).not.toHaveBeenCalled() + expect(dbChainMockFns.groupBy).not.toHaveBeenCalled() }) it('returns 0 when userIds is empty', async () => { @@ -61,7 +49,7 @@ describe('computeDailyRefreshConsumed', () => { planDollars: 25, }) expect(result).toBe(0) - expect(mockDbSelect).not.toHaveBeenCalled() + expect(dbChainMockFns.groupBy).not.toHaveBeenCalled() }) it('returns 0 when periodEnd is before periodStart', async () => { @@ -75,7 +63,7 @@ describe('computeDailyRefreshConsumed', () => { }) it('caps each day at the daily refresh allowance', async () => { - mockDbSelect.mockResolvedValue([ + dbChainMockFns.groupBy.mockResolvedValueOnce([ { dayIndex: 0, dayTotal: '0.50' }, { dayIndex: 1, dayTotal: '0.10' }, { dayIndex: 2, dayTotal: '1.00' }, @@ -97,7 +85,7 @@ describe('computeDailyRefreshConsumed', () => { }) it('returns 0 when no usage rows exist', async () => { - mockDbSelect.mockResolvedValue([]) + dbChainMockFns.groupBy.mockResolvedValueOnce([]) const result = await computeDailyRefreshConsumed({ userIds: ['user-1'], @@ -110,7 +98,7 @@ describe('computeDailyRefreshConsumed', () => { }) it('multiplies daily refresh by seats', async () => { - mockDbSelect.mockResolvedValue([{ dayIndex: 0, dayTotal: '2.00' }]) + dbChainMockFns.groupBy.mockResolvedValueOnce([{ dayIndex: 0, dayTotal: '2.00' }]) const result = await computeDailyRefreshConsumed({ userIds: ['user-1', 'user-2', 'user-3'], @@ -126,7 +114,7 @@ describe('computeDailyRefreshConsumed', () => { }) it('caps at refresh even with high usage and multiple seats', async () => { - mockDbSelect.mockResolvedValue([{ dayIndex: 0, dayTotal: '50.00' }]) + dbChainMockFns.groupBy.mockResolvedValueOnce([{ dayIndex: 0, dayTotal: '50.00' }]) const result = await computeDailyRefreshConsumed({ userIds: ['user-1', 'user-2'], @@ -142,7 +130,7 @@ describe('computeDailyRefreshConsumed', () => { }) it('handles null dayTotal gracefully', async () => { - mockDbSelect.mockResolvedValue([{ dayIndex: 0, dayTotal: null }]) + dbChainMockFns.groupBy.mockResolvedValueOnce([{ dayIndex: 0, dayTotal: null }]) const result = await computeDailyRefreshConsumed({ userIds: ['user-1'], diff --git a/apps/sim/lib/billing/organization.test.ts b/apps/sim/lib/billing/organization.test.ts index c65aba2eb4..46c962df92 100644 --- a/apps/sim/lib/billing/organization.test.ts +++ b/apps/sim/lib/billing/organization.test.ts @@ -1,48 +1,20 @@ /** * @vitest-environment node */ -import { schemaMock } from '@sim/testing' +import { dbChainMock } from '@sim/testing' import { beforeEach, describe, expect, it, vi } from 'vitest' const { - mockDbState, mockCreateOrganizationWithOwner, mockAttachOwnedWorkspacesToOrganization, mockGetOrganizationIdForSubscriptionReference, } = vi.hoisted(() => ({ - mockDbState: { - selectResults: [] as any[], - }, mockCreateOrganizationWithOwner: vi.fn(), mockAttachOwnedWorkspacesToOrganization: vi.fn(), mockGetOrganizationIdForSubscriptionReference: vi.fn(), })) -vi.mock('@sim/db', () => ({ - db: { - select: vi.fn().mockImplementation(() => { - const chain: any = {} - chain.from = vi.fn().mockReturnValue(chain) - chain.where = vi.fn().mockReturnValue(chain) - chain.limit = vi - .fn() - .mockImplementation(() => Promise.resolve(mockDbState.selectResults.shift() ?? [])) - chain.then = vi - .fn() - .mockImplementation((callback: (rows: any[]) => any) => - Promise.resolve(callback(mockDbState.selectResults.shift() ?? [])) - ) - return chain - }), - update: vi.fn(), - }, -})) - -vi.mock('@sim/db/schema', () => schemaMock) - -vi.mock('@/lib/billing', () => ({ - hasPaidSubscription: vi.fn(), -})) +vi.mock('@sim/db', () => dbChainMock) vi.mock('@/lib/billing/core/billing', () => ({ getPlanPricing: vi.fn(), @@ -74,7 +46,6 @@ import { ensureOrganizationForTeamSubscription } from '@/lib/billing/organizatio describe('ensureOrganizationForTeamSubscription', () => { beforeEach(() => { vi.clearAllMocks() - mockDbState.selectResults = [] mockGetOrganizationIdForSubscriptionReference.mockResolvedValue(null) }) diff --git a/apps/sim/lib/billing/organization.ts b/apps/sim/lib/billing/organization.ts index 9593e14d70..e87968dd0a 100644 --- a/apps/sim/lib/billing/organization.ts +++ b/apps/sim/lib/billing/organization.ts @@ -8,12 +8,12 @@ import { } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { and, eq, inArray, sql } from 'drizzle-orm' -import { hasPaidSubscription } from '@/lib/billing' import { getPlanPricing } from '@/lib/billing/core/billing' import { getOrganizationIdForSubscriptionReference } from '@/lib/billing/core/subscription' import { syncUsageLimitsFromSubscription } from '@/lib/billing/core/usage' import { createOrganizationWithOwner } from '@/lib/billing/organizations/create-organization' import { isEnterprise, isOrgPlan, isPaid } from '@/lib/billing/plan-helpers' +import { ENTITLED_SUBSCRIPTION_STATUSES } from '@/lib/billing/subscriptions/utils' import { toDecimal, toNumber } from '@/lib/billing/utils/decimal' import { attachOwnedWorkspacesToOrganization } from '@/lib/workspaces/organization-workspaces' @@ -134,26 +134,75 @@ export async function ensureOrganizationForTeamSubscription( if (existingMembership.length > 0) { const membership = existingMembership[0] if (membership.role === 'owner' || membership.role === 'admin') { - // Check if org already has an active subscription (prevent duplicates) - if (await hasPaidSubscription(membership.organizationId)) { - logger.error('Organization already has an active subscription', { - userId, - organizationId: membership.organizationId, - newSubscriptionId: subscription.id, - }) - throw new Error('Organization already has an active subscription') - } + /** + * Atomic duplicate-subscription check + referenceId transfer. + * + * Row-level locks (`FOR UPDATE`) on the subscription and target + * organization rows prevent a TOCTOU race between the "org has no + * paid subscription" check and the transfer write — which could + * otherwise let two concurrent webhook deliveries or org-creation + * flows both pass the check and attach two subscriptions to the + * same organization. + */ + await db.transaction(async (tx) => { + const [lockedSub] = await tx + .select({ + id: subscriptionTable.id, + referenceId: subscriptionTable.referenceId, + }) + .from(subscriptionTable) + .where(eq(subscriptionTable.id, subscription.id)) + .for('update') + + if (!lockedSub) { + throw new Error(`Subscription ${subscription.id} not found during transfer`) + } + + if (lockedSub.referenceId === membership.organizationId) { + return + } + + const [lockedOrg] = await tx + .select({ id: organization.id }) + .from(organization) + .where(eq(organization.id, membership.organizationId)) + .for('update') + + if (!lockedOrg) { + throw new Error(`Organization ${membership.organizationId} not found during transfer`) + } + + const [existingOrgSub] = await tx + .select({ id: subscriptionTable.id }) + .from(subscriptionTable) + .where( + and( + eq(subscriptionTable.referenceId, membership.organizationId), + inArray(subscriptionTable.status, ENTITLED_SUBSCRIPTION_STATUSES) + ) + ) + .limit(1) + + if (existingOrgSub) { + logger.error('Organization already has an active subscription', { + userId, + organizationId: membership.organizationId, + newSubscriptionId: subscription.id, + }) + throw new Error('Organization already has an active subscription') + } + + await tx + .update(subscriptionTable) + .set({ referenceId: membership.organizationId }) + .where(eq(subscriptionTable.id, subscription.id)) + }) logger.info('User already owns/admins an org, using it', { userId, organizationId: membership.organizationId, }) - await db - .update(subscriptionTable) - .set({ referenceId: membership.organizationId }) - .where(eq(subscriptionTable.id, subscription.id)) - await attachOwnedWorkspacesToOrganization({ ownerUserId: userId, organizationId: membership.organizationId, diff --git a/apps/sim/lib/billing/organizations/create-organization.test.ts b/apps/sim/lib/billing/organizations/create-organization.test.ts index 3d58df25c4..a3fb3bfa74 100644 --- a/apps/sim/lib/billing/organizations/create-organization.test.ts +++ b/apps/sim/lib/billing/organizations/create-organization.test.ts @@ -1,49 +1,14 @@ /** * @vitest-environment node */ -import { schemaMock } from '@sim/testing' +import { dbChainMock, dbChainMockFns } from '@sim/testing' import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockDbState, mockGenerateId } = vi.hoisted(() => ({ - mockDbState: { - selectResults: [] as any[], - insertedOrganizations: [] as any[], - insertedMembers: [] as any[], - }, +const { mockGenerateId } = vi.hoisted(() => ({ mockGenerateId: vi.fn(), })) -vi.mock('@sim/db', () => ({ - db: { - transaction: vi.fn(async (callback: any) => { - const tx = { - select: vi.fn().mockReturnValue({ - from: vi.fn().mockReturnValue({ - where: vi.fn().mockReturnValue({ - limit: vi - .fn() - .mockImplementation(() => Promise.resolve(mockDbState.selectResults.shift() ?? [])), - }), - }), - }), - insert: vi.fn().mockReturnValue({ - values: vi.fn().mockImplementation(async (values: Record) => { - if ('slug' in values) { - mockDbState.insertedOrganizations.push(values) - return - } - - mockDbState.insertedMembers.push(values) - }), - }), - } - - return callback(tx) - }), - }, -})) - -vi.mock('@sim/db/schema', () => schemaMock) +vi.mock('@sim/db', () => dbChainMock) vi.mock('@sim/utils/id', () => ({ generateId: mockGenerateId, @@ -56,17 +21,20 @@ import { validateOrganizationSlugOrThrow, } from '@/lib/billing/organizations/create-organization' +function insertedValuesFor(predicate: (values: Record) => boolean) { + return dbChainMockFns.values.mock.calls + .map((call) => call[0] as Record) + .filter(predicate) +} + describe('createOrganizationWithOwner', () => { beforeEach(() => { vi.clearAllMocks() - mockDbState.selectResults = [] - mockDbState.insertedOrganizations = [] - mockDbState.insertedMembers = [] }) it('creates an organization with a Better Auth-compatible id prefix', async () => { mockGenerateId.mockReturnValueOnce('abc123').mockReturnValueOnce('member456') - mockDbState.selectResults = [[]] + dbChainMockFns.limit.mockResolvedValueOnce([]) const result = await createOrganizationWithOwner({ ownerUserId: 'user-1', @@ -79,7 +47,7 @@ describe('createOrganizationWithOwner', () => { organizationId: 'org_abc123', memberId: 'member456', }) - expect(mockDbState.insertedOrganizations).toEqual([ + expect(insertedValuesFor((v) => 'slug' in v)).toEqual([ expect.objectContaining({ id: 'org_abc123', name: 'My Org', @@ -87,7 +55,7 @@ describe('createOrganizationWithOwner', () => { metadata: { source: 'test' }, }), ]) - expect(mockDbState.insertedMembers).toEqual([ + expect(insertedValuesFor((v) => !('slug' in v))).toEqual([ expect.objectContaining({ id: 'member456', userId: 'user-1', @@ -99,7 +67,7 @@ describe('createOrganizationWithOwner', () => { it('throws a typed error when the organization slug is already taken', async () => { mockGenerateId.mockReturnValueOnce('abc123').mockReturnValueOnce('member456') - mockDbState.selectResults = [[{ id: 'existing-org' }]] + dbChainMockFns.limit.mockResolvedValueOnce([{ id: 'existing-org' }]) await expect( createOrganizationWithOwner({ @@ -109,8 +77,7 @@ describe('createOrganizationWithOwner', () => { }) ).rejects.toBeInstanceOf(OrganizationSlugTakenError) - expect(mockDbState.insertedOrganizations).toEqual([]) - expect(mockDbState.insertedMembers).toEqual([]) + expect(insertedValuesFor(() => true)).toEqual([]) }) it('rejects invalid organization slugs before writing anything', () => { diff --git a/apps/sim/lib/billing/validation/seat-management.test.ts b/apps/sim/lib/billing/validation/seat-management.test.ts index 3035051a58..305cc401b1 100644 --- a/apps/sim/lib/billing/validation/seat-management.test.ts +++ b/apps/sim/lib/billing/validation/seat-management.test.ts @@ -1,34 +1,15 @@ /** * @vitest-environment node */ -import { schemaMock } from '@sim/testing' +import { dbChainMock, dbChainMockFns, resetDbChainMock } from '@sim/testing' import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockDbResults, mockFeatureFlags, mockGetOrganizationSubscription } = vi.hoisted(() => ({ - mockDbResults: { value: [] as any[] }, +const { mockFeatureFlags, mockGetOrganizationSubscription } = vi.hoisted(() => ({ mockFeatureFlags: { isBillingEnabled: false }, mockGetOrganizationSubscription: vi.fn(), })) -vi.mock('@sim/db', () => ({ - db: { - select: vi.fn().mockImplementation(() => { - const chain: any = {} - chain.from = vi.fn().mockReturnValue(chain) - chain.where = vi.fn().mockReturnValue(chain) - chain.limit = vi - .fn() - .mockImplementation(() => Promise.resolve(mockDbResults.value.shift() ?? [])) - chain.then = vi.fn().mockImplementation((callback: (rows: any[]) => unknown) => { - const rows = mockDbResults.value.shift() ?? [] - return Promise.resolve(callback(rows)) - }) - return chain - }), - }, -})) - -vi.mock('@sim/db/schema', () => schemaMock) +vi.mock('@sim/db', () => dbChainMock) vi.mock('@/lib/billing/core/billing', () => ({ getOrganizationSubscription: mockGetOrganizationSubscription, @@ -56,16 +37,36 @@ vi.mock('@/lib/messaging/email/validation', () => ({ import { getOrganizationSeatInfo } from '@/lib/billing/validation/seat-management' +/** + * Queues the next N responses for `db.select().from(...).where(...)` calls, + * supporting both `.limit(1)` and directly-awaited `where` chains. + */ +function queueSelectResponses(responses: unknown[][]) { + const queue = [...responses] + dbChainMockFns.where.mockImplementation(() => { + const result = queue.shift() ?? [] + const thenable = { + limit: vi.fn(() => Promise.resolve(result)), + orderBy: vi.fn(() => Promise.resolve(result)), + returning: vi.fn(() => Promise.resolve(result)), + groupBy: vi.fn(() => Promise.resolve(result)), + then: (onFulfilled: (rows: unknown) => unknown, onRejected?: (reason: unknown) => unknown) => + Promise.resolve(result).then(onFulfilled, onRejected), + } + return thenable as unknown as ReturnType + }) +} + describe('getOrganizationSeatInfo', () => { beforeEach(() => { vi.clearAllMocks() - mockDbResults.value = [] + resetDbChainMock() mockFeatureFlags.isBillingEnabled = false mockGetOrganizationSubscription.mockResolvedValue(null) }) it('returns unlimited seat info when billing is disabled', async () => { - mockDbResults.value = [[{ id: 'org-1', name: 'Acme' }], [{ count: 3 }], [{ count: 2 }]] + queueSelectResponses([[{ id: 'org-1', name: 'Acme' }], [{ count: 3 }], [{ count: 2 }]]) const result = await getOrganizationSeatInfo('org-1') diff --git a/apps/sim/lib/billing/webhooks/invoices.test.ts b/apps/sim/lib/billing/webhooks/invoices.test.ts index 0642bf47b9..9d601c33a8 100644 --- a/apps/sim/lib/billing/webhooks/invoices.test.ts +++ b/apps/sim/lib/billing/webhooks/invoices.test.ts @@ -1,57 +1,27 @@ /** * @vitest-environment node */ -import { urlsMock, urlsMockFns } from '@sim/testing' +import { + createMockStripeEvent, + dbChainMock, + dbChainMockFns, + drizzleOrmMock, + resetDbChainMock, + stripeClientMock, + stripePaymentMethodMock, + urlsMock, + urlsMockFns, +} from '@sim/testing' import type Stripe from 'stripe' import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockBlockOrgMembers, mockDbSelect, mockUnblockOrgMembers, selectResponses } = vi.hoisted( - () => { - const selectResponses: Array<{ limitResult?: unknown; whereResult?: unknown }> = [] - const mockDbSelect = vi.fn(() => { - const nextResponse = selectResponses.shift() - - if (!nextResponse) { - throw new Error('No queued db.select response') - } - - const builder = { - from: vi.fn(() => builder), - where: vi.fn(() => builder), - limit: vi.fn(async () => nextResponse.limitResult ?? nextResponse.whereResult ?? []), - then: (resolve: (value: unknown) => unknown, reject?: (reason: unknown) => unknown) => - Promise.resolve(nextResponse.whereResult ?? nextResponse.limitResult ?? []).then( - resolve, - reject - ), - } - - return builder - }) - - return { - mockBlockOrgMembers: vi.fn(), - mockDbSelect, - mockUnblockOrgMembers: vi.fn(), - selectResponses, - } - } -) - -vi.mock('@sim/db', () => ({ - db: { - select: mockDbSelect, - }, +const { mockBlockOrgMembers, mockUnblockOrgMembers } = vi.hoisted(() => ({ + mockBlockOrgMembers: vi.fn(), + mockUnblockOrgMembers: vi.fn(), })) -vi.mock('drizzle-orm', () => ({ - and: vi.fn(() => 'and'), - eq: vi.fn(() => 'eq'), - inArray: vi.fn(() => 'inArray'), - isNull: vi.fn(() => 'isNull'), - ne: vi.fn(() => 'ne'), - or: vi.fn(() => 'or'), -})) +vi.mock('@sim/db', () => dbChainMock) +vi.mock('drizzle-orm', () => drizzleOrmMock) vi.mock('@/components/emails', () => ({ PaymentFailedEmail: vi.fn(), @@ -85,18 +55,8 @@ vi.mock('@/lib/billing/plan-helpers', () => ({ isTeam: vi.fn((plan: string | null | undefined) => Boolean(plan?.startsWith('team'))), })) -vi.mock('@/lib/billing/stripe-client', () => ({ - requireStripeClient: vi.fn(), -})) - -vi.mock('@/lib/billing/stripe-payment-method', () => ({ - resolveDefaultPaymentMethod: vi.fn(async () => ({ - paymentMethodId: undefined, - collectionMethod: 'charge_automatically', - })), - getPaymentMethodId: vi.fn(), - getCustomerId: vi.fn(), -})) +vi.mock('@/lib/billing/stripe-client', () => stripeClientMock) +vi.mock('@/lib/billing/stripe-payment-method', () => stripePaymentMethodMock) vi.mock('@/lib/billing/subscriptions/utils', () => ({ ENTITLED_SUBSCRIPTION_STATUSES: ['active', 'trialing', 'past_due'], @@ -140,29 +100,57 @@ vi.mock('@react-email/render', () => ({ render: vi.fn(), })) -import { handleInvoicePaymentFailed, handleInvoicePaymentSucceeded } from './invoices' +import { + handleInvoicePaymentFailed, + handleInvoicePaymentSucceeded, +} from '@/lib/billing/webhooks/invoices' -function queueSelectResponse(response: { limitResult?: unknown; whereResult?: unknown }) { +interface SelectResponse { + limitResult?: unknown + whereResult?: unknown +} + +const selectResponses: SelectResponse[] = [] + +function queueSelectResponse(response: SelectResponse) { selectResponses.push(response) } +/** + * Override `where` so that each select-then-where chain pops the next queued + * response. Supports both `.limit(1)` terminals and directly-awaited `where()`. + */ +function installSelectResponseQueue() { + dbChainMockFns.where.mockImplementation(() => { + const next = selectResponses.shift() + if (!next) { + throw new Error('No queued db.select response') + } + const builder = { + limit: vi.fn(async () => next.limitResult ?? next.whereResult ?? []), + orderBy: vi.fn(async () => next.limitResult ?? next.whereResult ?? []), + returning: vi.fn(async () => next.limitResult ?? next.whereResult ?? []), + groupBy: vi.fn(async () => next.limitResult ?? next.whereResult ?? []), + then: (resolve: (value: unknown) => unknown, reject?: (reason: unknown) => unknown) => + Promise.resolve(next.whereResult ?? next.limitResult ?? []).then(resolve, reject), + } + return builder as unknown as ReturnType + }) +} + function createInvoiceEvent( type: 'invoice.payment_failed' | 'invoice.payment_succeeded', invoice: Partial ): Stripe.Event { - return { - data: { - object: invoice as Stripe.Invoice, - }, - id: `evt_${type}`, - type, - } as Stripe.Event + return createMockStripeEvent(type, invoice) } describe('invoice billing recovery', () => { beforeEach(() => { vi.clearAllMocks() + resetDbChainMock() selectResponses.length = 0 + installSelectResponseQueue() urlsMockFns.mockGetBaseUrl.mockReturnValue('https://sim.test') mockBlockOrgMembers.mockResolvedValue(2) mockUnblockOrgMembers.mockResolvedValue(2) diff --git a/bun.lock b/bun.lock index e38f683870..3847a5c931 100644 --- a/bun.lock +++ b/bun.lock @@ -1,5 +1,6 @@ { "lockfileVersion": 1, + "configVersion": 0, "workspaces": { "": { "name": "simstudio", diff --git a/packages/testing/src/mocks/database.mock.ts b/packages/testing/src/mocks/database.mock.ts index e22c98bd29..1aafba0a54 100644 --- a/packages/testing/src/mocks/database.mock.ts +++ b/packages/testing/src/mocks/database.mock.ts @@ -64,13 +64,19 @@ export function createMockSqlOperators() { * are wired at module load time: * * - `select().from().where()` → returns a builder with `.limit` / `.orderBy` / - * `.returning` / `.groupBy` terminals + * `.returning` / `.groupBy` / `.for` terminals * - `select().from().innerJoin()|leftJoin()` → returns the same where-builder * - `insert().values().returning()` / `update().set().where()` / `delete().where()` * - * Terminals (`limit`, `orderBy`, `returning`, `groupBy`, `values`) default to - * resolving `[]` (or `undefined` for `values`). Override per-test with - * `dbChainMockFns.limit.mockResolvedValueOnce([...])`. + * Terminals (`limit`, `orderBy`, `returning`, `groupBy`, `for`, `values`) + * default to resolving `[]` (or `undefined` for `values`). Override per-test + * with `dbChainMockFns.limit.mockResolvedValueOnce([...])`. `for` mirrors + * drizzle's `.for('update')` — it returns a Promise with `.limit` / `.orderBy` + * / `.returning` / `.groupBy` attached, so both `await .where().for('update')` + * (terminal) and `await .where().for('update').limit(1)` (chained) work. + * Override the terminal result with `dbChainMockFns.for.mockResolvedValueOnce( + * [...])`; override the chained result by mocking the downstream terminal + * (e.g. `dbChainMockFns.limit.mockResolvedValueOnce([...])`). * * `vi.clearAllMocks()` clears call history but preserves default wiring. Tests * that replace a wiring with `mockReturnValue(...)` (not `...Once`) must re-wire @@ -94,10 +100,20 @@ const returning = vi.fn(() => Promise.resolve([] as unknown[])) const groupBy = vi.fn(() => Promise.resolve([] as unknown[])) const execute = vi.fn(() => Promise.resolve([] as unknown[])) +const forBuilder = () => { + const thenable: any = Promise.resolve([] as unknown[]) + thenable.limit = limit + thenable.orderBy = orderBy + thenable.returning = returning + thenable.groupBy = groupBy + return thenable +} +const forClause = vi.fn(forBuilder) + const onConflictDoUpdate = vi.fn(() => ({ returning }) as unknown as Promise) const onConflictDoNothing = vi.fn(() => ({ returning }) as unknown as Promise) -const whereBuilder = () => ({ limit, orderBy, returning, groupBy }) +const whereBuilder = () => ({ limit, orderBy, returning, groupBy, for: forClause }) const where = vi.fn(whereBuilder) const joinBuilder = (): { where: typeof where; innerJoin: any; leftJoin: any } => ({ @@ -134,6 +150,7 @@ export const dbChainMockFns = { leftJoin, groupBy, execute, + for: forClause, insert, values, onConflictDoUpdate, @@ -173,6 +190,7 @@ export function resetDbChainMock(): void { returning.mockImplementation(() => Promise.resolve([] as unknown[])) groupBy.mockImplementation(() => Promise.resolve([] as unknown[])) execute.mockImplementation(() => Promise.resolve([] as unknown[])) + forClause.mockImplementation(forBuilder) transaction.mockImplementation(async (cb: (tx: typeof dbChainMock.db) => unknown) => cb(dbChainMock.db) ) diff --git a/packages/testing/src/mocks/index.ts b/packages/testing/src/mocks/index.ts index 9767d44e7f..c3ab30078c 100644 --- a/packages/testing/src/mocks/index.ts +++ b/packages/testing/src/mocks/index.ts @@ -108,6 +108,14 @@ export { } from './socket.mock' // Storage mocks export { clearStorageMocks, createMockStorage, setupGlobalStorageMocks } from './storage.mock' +// Stripe mocks +export { + createMockStripeEvent, + stripeClientMock, + stripeClientMockFns, + stripePaymentMethodMock, + stripePaymentMethodMockFns, +} from './stripe.mock' // Telemetry mocks export { telemetryMock } from './telemetry.mock' // URL mocks diff --git a/packages/testing/src/mocks/stripe.mock.ts b/packages/testing/src/mocks/stripe.mock.ts new file mode 100644 index 0000000000..8786dbe1a1 --- /dev/null +++ b/packages/testing/src/mocks/stripe.mock.ts @@ -0,0 +1,71 @@ +import type Stripe from 'stripe' +import { vi } from 'vitest' + +/** + * Mock for `@/lib/billing/stripe-client`. + * + * @example + * ```ts + * import { stripeClientMock, stripeClientMockFns } from '@sim/testing' + * vi.mock('@/lib/billing/stripe-client', () => stripeClientMock) + * + * stripeClientMockFns.mockRequireStripeClient.mockReturnValue(fakeStripe) + * ``` + */ +export const stripeClientMockFns = { + mockRequireStripeClient: vi.fn(), + mockGetStripeClient: vi.fn(), + mockHasValidStripeCredentials: vi.fn(() => true), +} + +export const stripeClientMock = { + requireStripeClient: stripeClientMockFns.mockRequireStripeClient, + getStripeClient: stripeClientMockFns.mockGetStripeClient, + hasValidStripeCredentials: stripeClientMockFns.mockHasValidStripeCredentials, +} + +/** + * Mock for `@/lib/billing/stripe-payment-method`. + * + * @example + * ```ts + * import { stripePaymentMethodMock, stripePaymentMethodMockFns } from '@sim/testing' + * vi.mock('@/lib/billing/stripe-payment-method', () => stripePaymentMethodMock) + * ``` + */ +export const stripePaymentMethodMockFns = { + mockResolveDefaultPaymentMethod: vi.fn(async () => ({ + paymentMethodId: undefined as string | undefined, + collectionMethod: 'charge_automatically' as 'charge_automatically' | 'send_invoice' | null, + })), + mockGetCustomerId: vi.fn(), +} + +export const stripePaymentMethodMock = { + resolveDefaultPaymentMethod: stripePaymentMethodMockFns.mockResolveDefaultPaymentMethod, + getCustomerId: stripePaymentMethodMockFns.mockGetCustomerId, +} + +/** + * Build a minimal `Stripe.Event` with the given type and object payload. + * Fills in a deterministic `id` (`evt_${type}`) and nests `object` under + * `data.object` as Stripe does. + */ +export function createMockStripeEvent( + type: string, + object: T, + overrides: Partial = {} +): Stripe.Event { + return { + id: `evt_${type}`, + object: 'event', + api_version: '2024-06-20', + created: Math.floor(Date.now() / 1000), + livemode: false, + pending_webhooks: 0, + request: null, + type, + data: { object: object as unknown as Stripe.Event.Data.Object }, + ...overrides, + } as Stripe.Event +}