From 906d5fafb6b0574f251195df6ae05fea149043c4 Mon Sep 17 00:00:00 2001 From: Daniel Sutton <45313566+d-cs@users.noreply.github.com> Date: Mon, 18 May 2026 14:58:38 +0100 Subject: [PATCH 01/25] =?UTF-8?q?feat(mollifier):=20trigger=20burst=20smoo?= =?UTF-8?q?thing=20=E2=80=94=20Phase=201=20(monitoring)=20(#3614)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Introduce the Mollifier: a Redis-backed buffer for `trigger()` API calls during traffic spikes, with a per-env trip evaluator and a drainer ack-loop. - Phase 1 is dual-write monitoring — every mollified trigger is buffered to Redis AND continues to `engine.trigger`. No customer-facing behaviour change. - Telemetry events: `mollifier.would_mollify`, `mollifier.buffered`, `mollifier.drained`, plus the `mollifier.decisions` counter. - Gated behind a feature flag (default off). ## Test plan - [x] `pnpm run test --filter @trigger.dev/redis-worker` - [x] `pnpm run test --filter webapp -- mollifier` - [x] Manual: with flag off, no behaviour change vs main - [x] Manual: with flag on + threshold lowered, observe `mollifier.buffered` + `mollifier.drained` log pairs with matching `runId` --------- Co-authored-by: Claude Opus 4.7 (1M context) --- .../mollifier-redis-worker-primitives.md | 9 + .server-changes/mollifier-burst-protection.md | 6 + apps/webapp/app/entry.server.tsx | 3 + apps/webapp/app/env.server.ts | 41 + .../runEngine/services/triggerTask.server.ts | 114 ++ apps/webapp/app/services/worker.server.ts | 21 + apps/webapp/app/v3/featureFlags.ts | 2 + .../bufferedTriggerPayload.server.ts | 107 ++ .../v3/mollifier/mollifierBuffer.server.ts | 32 + .../v3/mollifier/mollifierDrainer.server.ts | 120 ++ .../app/v3/mollifier/mollifierGate.server.ts | 209 +++ .../v3/mollifier/mollifierTelemetry.server.ts | 17 + .../mollifierTripEvaluator.server.ts | 47 + .../app/v3/mollifier/readFallback.server.ts | 16 + .../app/v3/mollifierDrainerWorker.server.ts | 123 ++ .../test/bufferedTriggerPayload.test.ts | 96 ++ apps/webapp/test/engine/triggerTask.test.ts | 570 +++++++ .../test/mollifierDrainerWorker.test.ts | 72 + apps/webapp/test/mollifierGate.test.ts | 434 ++++++ .../test/mollifierTripEvaluator.test.ts | 90 ++ apps/webapp/test/setup.ts | 6 + apps/webapp/vitest.config.ts | 1 + packages/redis-worker/src/index.ts | 1 + .../redis-worker/src/mollifier/buffer.test.ts | 1027 +++++++++++++ packages/redis-worker/src/mollifier/buffer.ts | 399 +++++ .../src/mollifier/drainer.test.ts | 1322 +++++++++++++++++ .../redis-worker/src/mollifier/drainer.ts | 289 ++++ packages/redis-worker/src/mollifier/index.ts | 15 + .../redis-worker/src/mollifier/schemas.ts | 58 + 29 files changed, 5247 insertions(+) create mode 100644 .changeset/mollifier-redis-worker-primitives.md create mode 100644 .server-changes/mollifier-burst-protection.md create mode 100644 apps/webapp/app/v3/mollifier/bufferedTriggerPayload.server.ts create mode 100644 apps/webapp/app/v3/mollifier/mollifierBuffer.server.ts create mode 100644 apps/webapp/app/v3/mollifier/mollifierDrainer.server.ts create mode 100644 apps/webapp/app/v3/mollifier/mollifierGate.server.ts create mode 100644 apps/webapp/app/v3/mollifier/mollifierTelemetry.server.ts create mode 100644 apps/webapp/app/v3/mollifier/mollifierTripEvaluator.server.ts create mode 100644 apps/webapp/app/v3/mollifier/readFallback.server.ts create mode 100644 apps/webapp/app/v3/mollifierDrainerWorker.server.ts create mode 100644 apps/webapp/test/bufferedTriggerPayload.test.ts create mode 100644 apps/webapp/test/mollifierDrainerWorker.test.ts create mode 100644 apps/webapp/test/mollifierGate.test.ts create mode 100644 apps/webapp/test/mollifierTripEvaluator.test.ts create mode 100644 apps/webapp/test/setup.ts create mode 100644 packages/redis-worker/src/mollifier/buffer.test.ts create mode 100644 packages/redis-worker/src/mollifier/buffer.ts create mode 100644 packages/redis-worker/src/mollifier/drainer.test.ts create mode 100644 packages/redis-worker/src/mollifier/drainer.ts create mode 100644 packages/redis-worker/src/mollifier/index.ts create mode 100644 packages/redis-worker/src/mollifier/schemas.ts diff --git a/.changeset/mollifier-redis-worker-primitives.md b/.changeset/mollifier-redis-worker-primitives.md new file mode 100644 index 00000000000..a209e530c24 --- /dev/null +++ b/.changeset/mollifier-redis-worker-primitives.md @@ -0,0 +1,9 @@ +--- +"@trigger.dev/redis-worker": patch +--- + +Add MollifierBuffer and MollifierDrainer primitives for trigger burst smoothing. + +MollifierBuffer (`accept`, `pop`, `ack`, `requeue`, `fail`, `evaluateTrip`) is a per-env FIFO over Redis with atomic Lua transitions for status tracking. `evaluateTrip` is a sliding-window trip evaluator the webapp gate uses to detect per-env trigger bursts. + +MollifierDrainer pops entries through a polling loop with a user-supplied handler. The loop survives transient Redis errors via capped exponential backoff (up to 5s), and per-env pop failures don't poison the rest of the batch — one env's blip is logged and counted as failed for that tick. Rotation is two-level: orgs at the top, envs within each org. The buffer maintains `mollifier:orgs` and `mollifier:org-envs:${orgId}` atomically with per-env queues, so the drainer walks orgs → envs directly without an in-memory cache. The `maxOrgsPerTick` option (default 500) caps how many orgs are scheduled per tick; for each picked org, one env is popped (rotating round-robin within the org). An org with N envs gets the same per-tick scheduling slot as an org with 1 env, so tenant-level drainage throughput is determined by org count rather than env count. diff --git a/.server-changes/mollifier-burst-protection.md b/.server-changes/mollifier-burst-protection.md new file mode 100644 index 00000000000..182811d68fd --- /dev/null +++ b/.server-changes/mollifier-burst-protection.md @@ -0,0 +1,6 @@ +--- +area: webapp +type: feature +--- + +Lay the groundwork for an opt-in burst-protection layer on the trigger hot path. This release ships **monitoring only** — operators can observe per-env trigger storms via two opt-in modes, but no trigger calls are diverted or rate-limited yet (active burst smoothing follows in a later release). All new env vars are prefixed `TRIGGER_MOLLIFIER_*` and default off, so existing deployments see no behaviour change. With `TRIGGER_MOLLIFIER_SHADOW_MODE=1`, each trigger evaluates a per-env rate counter and logs `mollifier.would_mollify` when the threshold is crossed. With `TRIGGER_MOLLIFIER_ENABLED=1` plus a per-org `mollifierEnabled` flag, over-threshold triggers are also recorded in a Redis audit buffer alongside the normal `engine.trigger` call, drained by a background no-op consumer. The drainer has its own switch (`TRIGGER_MOLLIFIER_DRAINER_ENABLED`) so multi-replica deployments can pin the polling loop to a single worker service while every replica still produces into the buffer; unset, it inherits `TRIGGER_MOLLIFIER_ENABLED` so single-container self-hosters need only one flag. Drainer misconfiguration (shutdown-timeout reconciliation against `GRACEFUL_SHUTDOWN_TIMEOUT`, or `TRIGGER_MOLLIFIER_ENABLED=1` with no buffer Redis) now throws `MollifierConfigurationError` at boot and crashes the process, so the misconfig surfaces to the orchestrator instead of disappearing into a log line; transient init failures (Redis blip) are still logged-and-swallowed. Emits the `mollifier.decisions` OTel counter for per-env rate visibility. diff --git a/apps/webapp/app/entry.server.tsx b/apps/webapp/app/entry.server.tsx index 436ec288211..11c3274e865 100644 --- a/apps/webapp/app/entry.server.tsx +++ b/apps/webapp/app/entry.server.tsx @@ -6,6 +6,7 @@ import isbot from "isbot"; import { renderToPipeableStream } from "react-dom/server"; import { PassThrough } from "stream"; import * as Worker from "~/services/worker.server"; +import { initMollifierDrainerWorker } from "~/v3/mollifierDrainerWorker.server"; import { bootstrap } from "./bootstrap"; import { LocaleContextProvider } from "./components/primitives/LocaleProvider"; import { @@ -247,6 +248,8 @@ Worker.init().catch((error) => { logError(error); }); +initMollifierDrainerWorker(); + bootstrap().catch((error) => { logError(error); }); diff --git a/apps/webapp/app/env.server.ts b/apps/webapp/app/env.server.ts index 8eacb9634e1..6fb6c4ac283 100644 --- a/apps/webapp/app/env.server.ts +++ b/apps/webapp/app/env.server.ts @@ -1054,6 +1054,47 @@ const EnvironmentSchema = z COMMON_WORKER_REDIS_TLS_DISABLED: z.string().default(process.env.REDIS_TLS_DISABLED ?? "false"), COMMON_WORKER_REDIS_CLUSTER_MODE_ENABLED: z.string().default("0"), + TRIGGER_MOLLIFIER_ENABLED: z.string().default("0"), + // Separate switch for the drainer (consumer side) so it can be split + // off onto a dedicated worker service. Unset → inherits + // TRIGGER_MOLLIFIER_ENABLED, so single-container self-hosters don't have to + // flip two switches. In multi-replica deployments, set this to "0" + // explicitly on every replica except the one dedicated drainer + // service — otherwise every replica's polling loop races for the + // same buffer entries. `TRIGGER_MOLLIFIER_ENABLED` is still the master kill + // switch; setting this to "1" while `TRIGGER_MOLLIFIER_ENABLED` is "0" is a + // no-op because the gate-side singleton refuses to construct a + // buffer when the system is off. + TRIGGER_MOLLIFIER_DRAINER_ENABLED: z.string().default(process.env.TRIGGER_MOLLIFIER_ENABLED ?? "0"), + TRIGGER_MOLLIFIER_SHADOW_MODE: z.string().default("0"), + TRIGGER_MOLLIFIER_REDIS_HOST: z + .string() + .optional() + .transform((v) => v ?? process.env.REDIS_HOST), + TRIGGER_MOLLIFIER_REDIS_PORT: z.coerce + .number() + .optional() + .transform( + (v) => v ?? (process.env.REDIS_PORT ? parseInt(process.env.REDIS_PORT) : undefined), + ), + TRIGGER_MOLLIFIER_REDIS_USERNAME: z + .string() + .optional() + .transform((v) => v ?? process.env.REDIS_USERNAME), + TRIGGER_MOLLIFIER_REDIS_PASSWORD: z + .string() + .optional() + .transform((v) => v ?? process.env.REDIS_PASSWORD), + TRIGGER_MOLLIFIER_REDIS_TLS_DISABLED: z.string().default(process.env.REDIS_TLS_DISABLED ?? "false"), + TRIGGER_MOLLIFIER_TRIP_WINDOW_MS: z.coerce.number().int().positive().default(200), + TRIGGER_MOLLIFIER_TRIP_THRESHOLD: z.coerce.number().int().positive().default(100), + TRIGGER_MOLLIFIER_HOLD_MS: z.coerce.number().int().positive().default(500), + TRIGGER_MOLLIFIER_DRAIN_CONCURRENCY: z.coerce.number().int().positive().default(50), + TRIGGER_MOLLIFIER_ENTRY_TTL_S: z.coerce.number().int().positive().default(600), + TRIGGER_MOLLIFIER_DRAIN_MAX_ATTEMPTS: z.coerce.number().int().positive().default(3), + TRIGGER_MOLLIFIER_DRAIN_SHUTDOWN_TIMEOUT_MS: z.coerce.number().int().positive().default(30_000), + TRIGGER_MOLLIFIER_DRAIN_MAX_ORGS_PER_TICK: z.coerce.number().int().positive().default(500), + BATCH_TRIGGER_PROCESS_JOB_VISIBILITY_TIMEOUT_MS: z.coerce .number() .int() diff --git a/apps/webapp/app/runEngine/services/triggerTask.server.ts b/apps/webapp/app/runEngine/services/triggerTask.server.ts index bbfdc3956c2..2d9eeec0943 100644 --- a/apps/webapp/app/runEngine/services/triggerTask.server.ts +++ b/apps/webapp/app/runEngine/services/triggerTask.server.ts @@ -40,6 +40,18 @@ import type { TriggerTaskRequest, TriggerTaskValidator, } from "../types"; +import { env } from "~/env.server"; +import { + evaluateGate as defaultEvaluateGate, + type GateOutcome, + type MollifierEvaluateGate, +} from "~/v3/mollifier/mollifierGate.server"; +import { + getMollifierBuffer as defaultGetMollifierBuffer, + type MollifierGetBuffer, +} from "~/v3/mollifier/mollifierBuffer.server"; +import { buildBufferedTriggerPayload } from "~/v3/mollifier/bufferedTriggerPayload.server"; +import { serialiseSnapshot } from "@trigger.dev/redis-worker"; import { QueueSizeLimitExceededError, ServiceValidationError } from "~/v3/services/common.server"; class NoopTriggerRacepointSystem implements TriggerRacepointSystem { @@ -59,6 +71,14 @@ export class RunEngineTriggerTaskService { private readonly traceEventConcern: TraceEventConcern; private readonly triggerRacepointSystem: TriggerRacepointSystem; private readonly metadataMaximumSize: number; + // Mollifier hooks are DI'd so tests can drive the call-site's mollify branch + // deterministically (stub the gate to return mollify, inject a real or fake + // buffer, force the global-enabled predicate to true so the call site + // doesn't short-circuit on an unset env). In production all three default + // to the live module-level singletons + env read. + private readonly evaluateGate: MollifierEvaluateGate; + private readonly getMollifierBuffer: MollifierGetBuffer; + private readonly isMollifierGloballyEnabled: () => boolean; constructor(opts: { prisma: PrismaClientOrTransaction; @@ -71,6 +91,9 @@ export class RunEngineTriggerTaskService { tracer: Tracer; metadataMaximumSize: number; triggerRacepointSystem?: TriggerRacepointSystem; + evaluateGate?: MollifierEvaluateGate; + getMollifierBuffer?: MollifierGetBuffer; + isMollifierGloballyEnabled?: () => boolean; }) { this.prisma = opts.prisma; this.engine = opts.engine; @@ -82,6 +105,10 @@ export class RunEngineTriggerTaskService { this.traceEventConcern = opts.traceEventConcern; this.metadataMaximumSize = opts.metadataMaximumSize; this.triggerRacepointSystem = opts.triggerRacepointSystem ?? new NoopTriggerRacepointSystem(); + this.evaluateGate = opts.evaluateGate ?? defaultEvaluateGate; + this.getMollifierBuffer = opts.getMollifierBuffer ?? defaultGetMollifierBuffer; + this.isMollifierGloballyEnabled = + opts.isMollifierGloballyEnabled ?? (() => env.TRIGGER_MOLLIFIER_ENABLED === "1"); } public async call({ @@ -316,6 +343,25 @@ export class RunEngineTriggerTaskService { taskKind: taskKind ?? "STANDARD", }; + // Short-circuit before the gate when mollifier is globally off (the + // default for every deployment that hasn't opted in). Avoids the + // GateInputs allocation, the deps spread inside `evaluateGate`, and + // the `mollifier.decisions{outcome=pass_through}` OTel increment on + // every trigger — `triggerTask` is the highest-throughput code path + // in the system. The check goes through a DI'd predicate so unit + // tests that inject a custom `evaluateGate` can also override the + // gate-on check (the default reads `env.TRIGGER_MOLLIFIER_ENABLED`, + // which is "0" in CI where no .env file is present). + const mollifierOutcome: GateOutcome | null = this.isMollifierGloballyEnabled() + ? await this.evaluateGate({ + envId: environment.id, + orgId: environment.organizationId, + taskId, + orgFeatureFlags: + (environment.organization.featureFlags as Record | null) ?? null, + }) + : null; + try { return await this.traceEventConcern.traceRun( triggerRequest, @@ -328,6 +374,74 @@ export class RunEngineTriggerTaskService { const payloadPacket = await this.payloadProcessor.process(triggerRequest); + // Phase 1 dual-write: if the org has the mollifier feature flag + // enabled and the per-env trip evaluator says divert, write the + // canonical replay payload to the buffer AND continue through + // engine.trigger as normal. The buffer entry is an audit/preview + // copy; the drainer's no-op handler consumes it to prove the + // dequeue mechanism works. Phase 2 will replace engine.trigger + // (below) with a synthesised 200 response and rely on the + // drainer to perform the Postgres write via replay. + if (mollifierOutcome?.action === "mollify") { + const buffer = this.getMollifierBuffer(); + if (buffer) { + const canonicalPayload = buildBufferedTriggerPayload({ + runFriendlyId, + taskId, + envId: environment.id, + envType: environment.type, + envSlug: environment.slug, + orgId: environment.organizationId, + orgSlug: environment.organization.slug, + projectId: environment.projectId, + projectRef: environment.project.externalRef, + body, + idempotencyKey: idempotencyKey ?? null, + idempotencyKeyExpiresAt: idempotencyKey + ? idempotencyKeyExpiresAt ?? null + : null, + tags, + parentRunFriendlyId: parentRun?.friendlyId ?? null, + traceContext: event.traceContext, + triggerSource, + triggerAction, + serviceOptions: options, + createdAt: new Date(), + }); + + try { + const serialisedPayload = serialiseSnapshot(canonicalPayload); + await buffer.accept({ + runId: runFriendlyId, + envId: environment.id, + orgId: environment.organizationId, + payload: serialisedPayload, + }); + // Light log on the hot path — keep this synchronous work + // O(1) per trigger. The drainer computes the payload hash + // off-path; operators correlate `mollifier.buffered` → + // `mollifier.drained` by runId. + logger.debug("mollifier.buffered", { + runId: runFriendlyId, + envId: environment.id, + orgId: environment.organizationId, + taskId, + payloadBytes: serialisedPayload.length, + }); + } catch (err) { + // Fail-open: buffer write must never block the customer's + // trigger. engine.trigger below is the primary write path + // in Phase 1 — the customer still gets a valid run. + logger.error("mollifier.buffer_accept_failed", { + runId: runFriendlyId, + envId: environment.id, + taskId, + err: err instanceof Error ? err.message : String(err), + }); + } + } + } + const taskRun = await this.engine.trigger( { friendlyId: runFriendlyId, diff --git a/apps/webapp/app/services/worker.server.ts b/apps/webapp/app/services/worker.server.ts index 902d752ed0a..7de2c7cb2e7 100644 --- a/apps/webapp/app/services/worker.server.ts +++ b/apps/webapp/app/services/worker.server.ts @@ -1,3 +1,24 @@ +/** + * ⚠️ LEGACY — Graphile-worker / ZodWorker setup. Do not touch. + * + * This file wires the original background-job system the webapp was + * built on (`@internal/zod-worker` → graphile-worker → Postgres). It is + * now in deprecation mode: every task in `workerCatalog` below is + * annotated with `@deprecated, moved to ` and the live jobs + * for new features all run on `@trigger.dev/redis-worker` instead. + * + * Where to put new things: + * - Background jobs / queues → use redis-worker, alongside + * `~/v3/commonWorker.server.ts`, `~/v3/alertsWorker.server.ts`, or + * `~/v3/batchTriggerWorker.server.ts`. + * - Run lifecycle → `@internal/run-engine` via `~/v3/runEngine.server`. + * - Custom polling loops with their own Redis connection → keep them + * in their own lifecycle module (e.g. `~/v3/mollifierDrainerWorker.server.ts`) + * and wire the bootstrap from `entry.server.tsx`. Don't reach into + * `init()` below. + * + * Edit only when removing legacy paths. + */ import { ZodWorker } from "@internal/zod-worker"; import { DeliverEmailSchema } from "emails"; import { z } from "zod"; diff --git a/apps/webapp/app/v3/featureFlags.ts b/apps/webapp/app/v3/featureFlags.ts index b40a83c3a35..67033a74f8f 100644 --- a/apps/webapp/app/v3/featureFlags.ts +++ b/apps/webapp/app/v3/featureFlags.ts @@ -8,6 +8,7 @@ export const FEATURE_FLAG = { hasAiAccess: "hasAiAccess", hasComputeAccess: "hasComputeAccess", hasPrivateConnections: "hasPrivateConnections", + mollifierEnabled: "mollifierEnabled", } as const; export const FeatureFlagCatalog = { @@ -18,6 +19,7 @@ export const FeatureFlagCatalog = { [FEATURE_FLAG.hasAiAccess]: z.coerce.boolean(), [FEATURE_FLAG.hasComputeAccess]: z.coerce.boolean(), [FEATURE_FLAG.hasPrivateConnections]: z.coerce.boolean(), + [FEATURE_FLAG.mollifierEnabled]: z.coerce.boolean(), }; export type FeatureFlagKey = keyof typeof FeatureFlagCatalog; diff --git a/apps/webapp/app/v3/mollifier/bufferedTriggerPayload.server.ts b/apps/webapp/app/v3/mollifier/bufferedTriggerPayload.server.ts new file mode 100644 index 00000000000..d251e9f98e8 --- /dev/null +++ b/apps/webapp/app/v3/mollifier/bufferedTriggerPayload.server.ts @@ -0,0 +1,107 @@ +import type { TriggerTaskRequestBody } from "@trigger.dev/core/v3"; +import type { TriggerTaskServiceOptions } from "~/v3/services/triggerTask.server"; + +// Canonical payload shape written to the mollifier buffer when the gate +// decides to mollify a trigger. Phase 1 ALSO calls engine.trigger directly +// (dual-write) so this is currently an audit/preview record. Phase 2 will +// make the buffer the primary write path: the drainer's handler will read +// this payload and replay it through engine.trigger to create the run in +// Postgres, and read-fallback endpoints will synthesise a Run view from it +// while it is still QUEUED. +// +// CONTRACT: this shape must contain everything needed for Phase 2's +// drainer-replay to reconstruct an equivalent engine.trigger call. Phase 1 +// emits it to logs; Phase 2 will serialise it into Redis and rebuild it on +// the drain side. Keep it serialisable — no functions, no class instances. +export type BufferedTriggerPayload = { + runFriendlyId: string; + + // Routing identifiers — let the drainer re-fetch full AuthenticatedEnvironment + // at replay time rather than embedding it in the payload. + envId: string; + envType: string; + envSlug: string; + orgId: string; + orgSlug: string; + projectId: string; + projectRef: string; + + // Task identifier — looked up against the locked BackgroundWorkerTask + // at replay time to recover task-defaults. + taskId: string; + + // Customer-supplied trigger body (payload, options, context). + body: TriggerTaskRequestBody; + + // Resolved values from upstream concerns. The drainer should NOT re-resolve + // these — that would create a second idempotency-key check, etc. + idempotencyKey: string | null; + idempotencyKeyExpiresAt: string | null; + tags: string[]; + + // Parent/root linkage for nested triggers. + parentRunFriendlyId: string | null; + + // Trace context — propagates the original triggering span across the + // buffer→drain boundary so the run's lifecycle stays under one trace. + traceContext: Record; + + // Annotations + service options that influence routing/replay. + triggerSource: string; + triggerAction: string; + serviceOptions: TriggerTaskServiceOptions; + + // Wall-clock instants relevant to the run. + createdAt: string; +}; + +// Assemble the canonical payload from the inputs available at the point +// `evaluateGate` returns "mollify" in `RunEngineTriggerTaskService.call`. +// All fields must be derivable from data already in scope at that call site; +// nothing should require an extra DB lookup. +export function buildBufferedTriggerPayload(input: { + runFriendlyId: string; + taskId: string; + envId: string; + envType: string; + envSlug: string; + orgId: string; + orgSlug: string; + projectId: string; + projectRef: string; + body: TriggerTaskRequestBody; + idempotencyKey: string | null; + idempotencyKeyExpiresAt: Date | null; + tags: string[]; + parentRunFriendlyId: string | null; + traceContext: Record; + triggerSource: string; + triggerAction: string; + serviceOptions: TriggerTaskServiceOptions; + createdAt: Date; +}): BufferedTriggerPayload { + return { + runFriendlyId: input.runFriendlyId, + envId: input.envId, + envType: input.envType, + envSlug: input.envSlug, + orgId: input.orgId, + orgSlug: input.orgSlug, + projectId: input.projectId, + projectRef: input.projectRef, + taskId: input.taskId, + body: input.body, + idempotencyKey: input.idempotencyKey, + idempotencyKeyExpiresAt: + input.idempotencyKey && input.idempotencyKeyExpiresAt + ? input.idempotencyKeyExpiresAt.toISOString() + : null, + tags: input.tags, + parentRunFriendlyId: input.parentRunFriendlyId, + traceContext: input.traceContext, + triggerSource: input.triggerSource, + triggerAction: input.triggerAction, + serviceOptions: input.serviceOptions, + createdAt: input.createdAt.toISOString(), + }; +} diff --git a/apps/webapp/app/v3/mollifier/mollifierBuffer.server.ts b/apps/webapp/app/v3/mollifier/mollifierBuffer.server.ts new file mode 100644 index 00000000000..9c8917623e4 --- /dev/null +++ b/apps/webapp/app/v3/mollifier/mollifierBuffer.server.ts @@ -0,0 +1,32 @@ +import { MollifierBuffer } from "@trigger.dev/redis-worker"; +import { env } from "~/env.server"; +import { logger } from "~/services/logger.server"; +import { singleton } from "~/utils/singleton"; + +// DI seam type for consumers (e.g. triggerTask.server.ts) that need a +// nullable buffer accessor at construction time. +export type MollifierGetBuffer = () => MollifierBuffer | null; + +function initializeMollifierBuffer(): MollifierBuffer { + logger.debug("Initializing mollifier buffer", { + host: env.TRIGGER_MOLLIFIER_REDIS_HOST, + }); + + return new MollifierBuffer({ + redisOptions: { + keyPrefix: "", + host: env.TRIGGER_MOLLIFIER_REDIS_HOST, + port: env.TRIGGER_MOLLIFIER_REDIS_PORT, + username: env.TRIGGER_MOLLIFIER_REDIS_USERNAME, + password: env.TRIGGER_MOLLIFIER_REDIS_PASSWORD, + enableAutoPipelining: true, + ...(env.TRIGGER_MOLLIFIER_REDIS_TLS_DISABLED === "true" ? {} : { tls: {} }), + }, + entryTtlSeconds: env.TRIGGER_MOLLIFIER_ENTRY_TTL_S, + }); +} + +export function getMollifierBuffer(): MollifierBuffer | null { + if (env.TRIGGER_MOLLIFIER_ENABLED !== "1") return null; + return singleton("mollifierBuffer", initializeMollifierBuffer); +} diff --git a/apps/webapp/app/v3/mollifier/mollifierDrainer.server.ts b/apps/webapp/app/v3/mollifier/mollifierDrainer.server.ts new file mode 100644 index 00000000000..139aeaf9a6e --- /dev/null +++ b/apps/webapp/app/v3/mollifier/mollifierDrainer.server.ts @@ -0,0 +1,120 @@ +import { createHash } from "node:crypto"; +import { MollifierDrainer, serialiseSnapshot } from "@trigger.dev/redis-worker"; +import { env } from "~/env.server"; +import { logger } from "~/services/logger.server"; +import { singleton } from "~/utils/singleton"; +import { getMollifierBuffer } from "./mollifierBuffer.server"; +import type { BufferedTriggerPayload } from "./bufferedTriggerPayload.server"; + +// Distinct error class for the deterministic "fail loud at boot" throws +// below. The bootstrap in `mollifierDrainerWorker.server.ts` catches +// transient/init errors and logs them so an unrelated Redis blip doesn't +// crash the webapp, but it RETHROWS this class — a misconfigured +// shutdown timeout or missing buffer is a deploy-time mistake that +// should fail health checks and roll back, not silently disable a +// half-rolled-out feature. +// +// The `name` getter is set explicitly so cross-realm `instanceof` checks +// (e.g. when Remix dev hot-reloads the module and the consumer keeps a +// reference to the old class) can fall back to `error.name === ...` and +// still recognise the marker. +export class MollifierConfigurationError extends Error { + constructor(message: string) { + super(message); + this.name = "MollifierConfigurationError"; + } +} + +function initializeMollifierDrainer(): MollifierDrainer { + const buffer = getMollifierBuffer(); + if (!buffer) { + // Unreachable in normal config: getMollifierDrainer() gates on the + // same env flag as getMollifierBuffer(). If we hit this, fail loud + // — the operator has set TRIGGER_MOLLIFIER_ENABLED=1 on a worker pod but + // the buffer can't initialise (e.g. TRIGGER_MOLLIFIER_REDIS_HOST resolves + // to nothing). Crashing surfaces the misconfig immediately rather + // than silently leaving entries un-drained. + throw new MollifierConfigurationError( + "MollifierDrainer initialised without a buffer — env vars inconsistent", + ); + } + + // Validate BEFORE start() so a misconfigured shutdown timeout fails + // loud at module-load time and the singleton is never cached. If start() + // ran first and the throw propagated out, the loop would already be + // polling with no SIGTERM handler registered by the caller — exactly + // the failure mode the validation is supposed to prevent. + // + // The SIGTERM handler in mollifierDrainerWorker.server.ts is sync fire-and-forget: + // `drainer.stop({ timeoutMs })` returns a promise that keeps the event + // loop alive, but in cluster mode the primary runs its own + // GRACEFUL_SHUTDOWN_TIMEOUT and will call `process.exit(0)` + // independently. If the drainer's deadline exceeds the primary's, the + // drainer is cut off mid-wait — "log a warning on timeout" turns into + // "hard exit with no log". 1s margin gives the primary room to finish + // its own teardown after the drainer settles. + const shutdownMarginMs = 1_000; + if ( + env.TRIGGER_MOLLIFIER_DRAIN_SHUTDOWN_TIMEOUT_MS >= + env.GRACEFUL_SHUTDOWN_TIMEOUT - shutdownMarginMs + ) { + throw new MollifierConfigurationError( + `TRIGGER_MOLLIFIER_DRAIN_SHUTDOWN_TIMEOUT_MS (${env.TRIGGER_MOLLIFIER_DRAIN_SHUTDOWN_TIMEOUT_MS}) must be at least ${shutdownMarginMs}ms below GRACEFUL_SHUTDOWN_TIMEOUT (${env.GRACEFUL_SHUTDOWN_TIMEOUT}); otherwise the primary's hard exit shadows the drainer's deadline.`, + ); + } + + logger.debug("Initializing mollifier drainer", { + concurrency: env.TRIGGER_MOLLIFIER_DRAIN_CONCURRENCY, + maxAttempts: env.TRIGGER_MOLLIFIER_DRAIN_MAX_ATTEMPTS, + }); + + // Phase 1 handler: no-op ack. The trigger has ALREADY been written to + // Postgres via engine.trigger (dual-write at the call site). Popping + + // acking here proves the dequeue mechanism works end-to-end without + // duplicating the work. Phase 2 will replace this with an engine.trigger + // replay that performs the actual Postgres write. + const drainer = new MollifierDrainer({ + buffer, + handler: async (input) => { + // Hash the (re-serialised, canonical) payload on the drain side rather + // than on the trigger hot path. Burst-time CPU stays with engine.trigger; + // the drainer is the natural place for the audit-equivalence checksum. + // Re-serialisation is identity for the BufferedTriggerPayload shape + // (only strings/numbers/plain objects), so this hash matches what the + // call site wrote into Redis. + const reserialised = serialiseSnapshot(input.payload); + const payloadHash = createHash("sha256").update(reserialised).digest("hex"); + logger.info("mollifier.drained", { + runId: input.runId, + envId: input.envId, + orgId: input.orgId, + taskId: input.payload.taskId, + attempts: input.attempts, + ageMs: Date.now() - input.createdAt.getTime(), + payloadBytes: reserialised.length, + payloadHash, + }); + }, + concurrency: env.TRIGGER_MOLLIFIER_DRAIN_CONCURRENCY, + maxAttempts: env.TRIGGER_MOLLIFIER_DRAIN_MAX_ATTEMPTS, + maxOrgsPerTick: env.TRIGGER_MOLLIFIER_DRAIN_MAX_ORGS_PER_TICK, + // A no-op handler shouldn't throw, but if something does (e.g. an + // unexpected deserialise failure), don't loop — let it FAIL terminally + // so the entry is observable in metrics. + isRetryable: () => false, + }); + + return drainer; +} + +// Returns a configured-but-stopped drainer. Callers MUST register their +// SIGTERM / SIGINT shutdown handlers before invoking `drainer.start()` — +// see `apps/webapp/app/v3/mollifierDrainerWorker.server.ts`. Starting +// inside the singleton factory would put the polling loop ahead of +// handler registration, leaving a narrow window where a SIGTERM landing +// between `start()` and `process.once("SIGTERM", ...)` would skip the +// graceful stop. The split is intentional. +export function getMollifierDrainer(): MollifierDrainer | null { + if (env.TRIGGER_MOLLIFIER_ENABLED !== "1") return null; + return singleton("mollifierDrainer", initializeMollifierDrainer); +} diff --git a/apps/webapp/app/v3/mollifier/mollifierGate.server.ts b/apps/webapp/app/v3/mollifier/mollifierGate.server.ts new file mode 100644 index 00000000000..28b0a7f88cf --- /dev/null +++ b/apps/webapp/app/v3/mollifier/mollifierGate.server.ts @@ -0,0 +1,209 @@ +import { env } from "~/env.server"; +import { logger } from "~/services/logger.server"; +import { FEATURE_FLAG, FeatureFlagCatalog } from "~/v3/featureFlags"; +import { getMollifierBuffer } from "./mollifierBuffer.server"; +import { createRealTripEvaluator } from "./mollifierTripEvaluator.server"; +import { + recordDecision, + type DecisionOutcome, + type DecisionReason, +} from "./mollifierTelemetry.server"; + +// `count` is the fleet-wide fixed-window counter for the env (INCR with a +// PEXPIRE armed on the first tick of each window — see +// `mollifierEvaluateTrip` in `packages/redis-worker/src/mollifier/buffer.ts`). +// All webapp replicas pointing at the same Redis share the key +// `mollifier:rate:${envId}`, so the threshold is the fleet-wide ceiling +// rather than a per-instance one. At a window boundary an env can briefly +// admit up to ~2x threshold across the fleet before tripping (fixed-window +// not sliding-window). The tripped marker is refreshed on every overage +// call, so a sustained burst holds the divert state until the rate falls +// below threshold within a window. +export type TripDecision = + | { divert: false } + | { + divert: true; + reason: "per_env_rate"; + count: number; + threshold: number; + windowMs: number; + holdMs: number; + }; + +export type GateOutcome = + | { action: "pass_through" } + | { action: "mollify"; decision: Extract } + | { action: "shadow_log"; decision: Extract }; + +export type GateInputs = { + envId: string; + orgId: string; + taskId: string; + // Org-scoped flag overrides — taken from `Organization.featureFlags` on the + // AuthenticatedEnvironment at the call site. The repo-wide `flag()` helper + // queries the global `FeatureFlag` table; passing per-org overrides lets the + // mollifier opt in a single org without touching the global row, matching + // the pattern used by `canAccessAi`, `canAccessPrivateConnections`, and the + // compute-template beta gate. + orgFeatureFlags: Record | null; +}; + +export type TripEvaluator = (inputs: GateInputs) => Promise; + +// DI seam type for consumers (e.g. triggerTask.server.ts) that inject the +// gate at construction time. Deliberately narrower than `evaluateGate`'s +// real signature — no `deps` param — because consumers only call it with +// inputs and rely on the module-level defaults. +export type MollifierEvaluateGate = (inputs: GateInputs) => Promise; + +export type GateDependencies = { + isMollifierEnabled: () => boolean; + isShadowModeOn: () => boolean; + resolveOrgFlag: (inputs: GateInputs) => Promise; + evaluator: TripEvaluator; + logShadow: ( + inputs: GateInputs, + decision: Extract, + ) => void; + logMollified: ( + inputs: GateInputs, + decision: Extract, + ) => void; + recordDecision: (outcome: DecisionOutcome, reason?: DecisionReason) => void; +}; + +// `options` is a thunk so env reads happen per-evaluation, not at module load. +// Don't "simplify" to a plain object — Phase 2 dynamic config relies on the +// gate observing whichever env values are live at trigger time. +const defaultEvaluator = createRealTripEvaluator({ + getBuffer: () => getMollifierBuffer(), + options: () => ({ + windowMs: env.TRIGGER_MOLLIFIER_TRIP_WINDOW_MS, + threshold: env.TRIGGER_MOLLIFIER_TRIP_THRESHOLD, + holdMs: env.TRIGGER_MOLLIFIER_HOLD_MS, + }), +}); + +function logDivertDecision( + message: "mollifier.would_mollify" | "mollifier.mollified", + inputs: GateInputs, + decision: Extract, +): void { + logger.debug(message, { + envId: inputs.envId, + orgId: inputs.orgId, + taskId: inputs.taskId, + reason: decision.reason, + count: decision.count, + threshold: decision.threshold, + windowMs: decision.windowMs, + holdMs: decision.holdMs, + }); +} + +// Resolve the per-org mollifier flag purely from the in-memory +// `Organization.featureFlags` JSON. No DB query — `triggerTask` is the +// trigger hot path and the webapp CLAUDE.md forbids adding Prisma calls +// there. The fleet-wide kill switch lives in `TRIGGER_MOLLIFIER_ENABLED`; rollout +// is per-org via the JSON, matching the pattern used by `canAccessAi`, +// `hasComputeAccess`, etc. There is no global `FeatureFlag` table read +// in this path by design. +export function makeResolveMollifierFlag(): (inputs: GateInputs) => Promise { + return (inputs) => { + const override = inputs.orgFeatureFlags?.[FEATURE_FLAG.mollifierEnabled]; + if (override !== undefined) { + const parsed = FeatureFlagCatalog[FEATURE_FLAG.mollifierEnabled].safeParse(override); + if (parsed.success) { + return Promise.resolve(parsed.data); + } + } + return Promise.resolve(false); + }; +} + +const resolveMollifierFlag = makeResolveMollifierFlag(); + +export const defaultGateDependencies: GateDependencies = { + isMollifierEnabled: () => env.TRIGGER_MOLLIFIER_ENABLED === "1", + isShadowModeOn: () => env.TRIGGER_MOLLIFIER_SHADOW_MODE === "1", + resolveOrgFlag: resolveMollifierFlag, + evaluator: defaultEvaluator, + logShadow: (inputs, decision) => + logDivertDecision("mollifier.would_mollify", inputs, decision), + logMollified: (inputs, decision) => + logDivertDecision("mollifier.mollified", inputs, decision), + recordDecision, +}; + +export async function evaluateGate( + inputs: GateInputs, + deps: Partial = {}, +): Promise { + const d = { ...defaultGateDependencies, ...deps }; + + if (!d.isMollifierEnabled()) { + d.recordDecision("pass_through"); + return { action: "pass_through" }; + } + + // Fail open: a transient DB error resolving the per-org flag must not + // block triggers. Mirror the evaluator's fail-open posture in + // `mollifierTripEvaluator.server.ts`. + let orgFlagEnabled: boolean; + try { + orgFlagEnabled = await d.resolveOrgFlag(inputs); + } catch (error) { + logger.warn("mollifier.resolve_org_flag_failed", { + envId: inputs.envId, + orgId: inputs.orgId, + taskId: inputs.taskId, + error: error instanceof Error ? error.message : String(error), + }); + orgFlagEnabled = false; + } + const shadowOn = d.isShadowModeOn(); + + if (!orgFlagEnabled && !shadowOn) { + d.recordDecision("pass_through"); + return { action: "pass_through" }; + } + + // Fail open on evaluator errors too. The default `createRealTripEvaluator` + // catches its own errors and returns `{ divert: false }`, but injected or + // future evaluators may not — keep the contract symmetric with the org + // flag resolution above so the trigger hot path can never be broken by a + // gate-internal failure. + // + // Note: the evaluator INCRs the per-env Redis counter (`mollifier:rate:${envId}`) + // in *both* shadow-only and flag-on modes — shadow mode is observation-only at + // the user-visible level (no diversion), but not Redis-passive. It has to write + // because the threshold is computed from a counter, and a counter that doesn't + // increment isn't a counter. There's no cross-org bleed: `RuntimeEnvironment` + // is 1:1 with `Organization`, so the per-env counter is effectively per-org. + let decision: TripDecision; + try { + decision = await d.evaluator(inputs); + } catch (error) { + logger.warn("mollifier.evaluator_failed", { + envId: inputs.envId, + orgId: inputs.orgId, + taskId: inputs.taskId, + error: error instanceof Error ? error.message : String(error), + }); + decision = { divert: false }; + } + if (!decision.divert) { + d.recordDecision("pass_through"); + return { action: "pass_through" }; + } + + if (orgFlagEnabled) { + d.logMollified(inputs, decision); + d.recordDecision("mollify", decision.reason); + return { action: "mollify", decision }; + } + + d.logShadow(inputs, decision); + d.recordDecision("shadow_log", decision.reason); + return { action: "shadow_log", decision }; +} diff --git a/apps/webapp/app/v3/mollifier/mollifierTelemetry.server.ts b/apps/webapp/app/v3/mollifier/mollifierTelemetry.server.ts new file mode 100644 index 00000000000..0fe302584ce --- /dev/null +++ b/apps/webapp/app/v3/mollifier/mollifierTelemetry.server.ts @@ -0,0 +1,17 @@ +import { getMeter } from "@internal/tracing"; + +const meter = getMeter("mollifier"); + +export const mollifierDecisionsCounter = meter.createCounter("mollifier.decisions", { + description: "Count of mollifier gate decisions by outcome", +}); + +export type DecisionOutcome = "pass_through" | "shadow_log" | "mollify"; +export type DecisionReason = "per_env_rate"; + +export function recordDecision(outcome: DecisionOutcome, reason?: DecisionReason): void { + mollifierDecisionsCounter.add(1, { + outcome, + ...(reason ? { reason } : {}), + }); +} diff --git a/apps/webapp/app/v3/mollifier/mollifierTripEvaluator.server.ts b/apps/webapp/app/v3/mollifier/mollifierTripEvaluator.server.ts new file mode 100644 index 00000000000..4bd9a34d412 --- /dev/null +++ b/apps/webapp/app/v3/mollifier/mollifierTripEvaluator.server.ts @@ -0,0 +1,47 @@ +import type { MollifierBuffer } from "@trigger.dev/redis-worker"; +import { logger } from "~/services/logger.server"; +import type { GateInputs, TripDecision, TripEvaluator } from "./mollifierGate.server"; + +export type TripEvaluatorOptions = { + windowMs: number; + threshold: number; + holdMs: number; +}; + +export type CreateRealTripEvaluatorDeps = { + getBuffer: () => MollifierBuffer | null; + options: () => TripEvaluatorOptions; +}; + +export function createRealTripEvaluator(deps: CreateRealTripEvaluatorDeps): TripEvaluator { + return async (inputs: GateInputs): Promise => { + const buffer = deps.getBuffer(); + if (!buffer) return { divert: false }; + + const opts = deps.options(); + + try { + const { tripped, count } = await buffer.evaluateTrip(inputs.envId, opts); + if (!tripped) return { divert: false }; + + return { + divert: true, + reason: "per_env_rate", + count, + threshold: opts.threshold, + windowMs: opts.windowMs, + holdMs: opts.holdMs, + }; + } catch (err) { + // Deliberate: no error counter here. Shadow mode means a silent miss is + // harmless — fail-open is the safe direction. The error log + Sentry + // capture is sufficient operability for Phase 1. Revisit in Phase 2 + // when buffer writes are the primary path and a missed evaluation has cost. + logger.error("mollifier trip evaluator: fail-open on error", { + envId: inputs.envId, + err: err instanceof Error ? err.message : String(err), + }); + return { divert: false }; + } + }; +} diff --git a/apps/webapp/app/v3/mollifier/readFallback.server.ts b/apps/webapp/app/v3/mollifier/readFallback.server.ts new file mode 100644 index 00000000000..34a8b48f970 --- /dev/null +++ b/apps/webapp/app/v3/mollifier/readFallback.server.ts @@ -0,0 +1,16 @@ +import { logger } from "~/services/logger.server"; + +export type ReadFallbackInput = { + runId: string; + environmentId: string; + organizationId: string; +}; + +export async function findRunByIdWithMollifierFallback( + input: ReadFallbackInput, +): Promise { + logger.debug("mollifier read-fallback called (phase 1 stub)", { + runId: input.runId, + }); + return null; +} diff --git a/apps/webapp/app/v3/mollifierDrainerWorker.server.ts b/apps/webapp/app/v3/mollifierDrainerWorker.server.ts new file mode 100644 index 00000000000..313e9af6719 --- /dev/null +++ b/apps/webapp/app/v3/mollifierDrainerWorker.server.ts @@ -0,0 +1,123 @@ +import { env } from "~/env.server"; +import { logger } from "~/services/logger.server"; +import { signalsEmitter } from "~/services/signals.server"; +import { + getMollifierDrainer, + MollifierConfigurationError, +} from "./mollifier/mollifierDrainer.server"; + +declare global { + // eslint-disable-next-line no-var + var __mollifierShutdownRegistered__: boolean | undefined; +} + +/** + * Bootstraps the mollifier drainer. + * + * Two-step lifecycle: + * 1. Construct the drainer via the gated singleton in + * `mollifierDrainer.server.ts`. That factory validates the + * shutdown-timeout reconciliation against `GRACEFUL_SHUTDOWN_TIMEOUT` + * and throws BEFORE returning if it's misconfigured; the returned + * drainer is configured-but-stopped. + * 2. Register SIGTERM/SIGINT shutdown handlers, then call + * `drainer.start()`. Doing this in the bootstrap (and not in the + * factory) guarantees a signal landing during boot can never find + * the polling loop running without a graceful-stop path. + * + * The drainer is intentionally NOT wired through `~/services/worker.server` + * — that file is the legacy ZodWorker / graphile-worker setup. The + * mollifier drainer is a custom polling loop over `MollifierBuffer`, not + * a graphile-worker job, so it gets its own lifecycle file alongside the + * redis-worker workers (`commonWorker`, `alertsWorker`, + * `batchTriggerWorker`). + * + * Gating order: + * - `TRIGGER_MOLLIFIER_DRAINER_ENABLED !== "1"` → early return. Unset defaults + * to `TRIGGER_MOLLIFIER_ENABLED`, so single-container self-hosters still get + * the drainer for free with one flag. In multi-replica deployments, + * set this to "0" explicitly on every replica except the dedicated + * drainer service so the polling loop doesn't race across replicas. + * - `TRIGGER_MOLLIFIER_ENABLED !== "1"` → `getMollifierDrainer()` returns null + * and the bootstrap is a no-op. `TRIGGER_MOLLIFIER_ENABLED` remains the + * master kill switch; the new flag only controls WHICH replicas + * run the drainer when the system is on. + */ +export function initMollifierDrainerWorker( + opts: { + // Test seams. Production callers pass nothing; the defaults read the + // live env and resolve the live singleton. Tests inject overrides so + // the misconfig-rethrow / transient-swallow branches can be driven + // without manipulating module-level env state. + isEnabled?: () => boolean; + getDrainer?: typeof getMollifierDrainer; + } = {}, +): void { + const isEnabled = opts.isEnabled ?? (() => env.TRIGGER_MOLLIFIER_DRAINER_ENABLED === "1"); + const getDrainer = opts.getDrainer ?? getMollifierDrainer; + + if (!isEnabled()) { + return; + } + + try { + const drainer = getDrainer(); + if (drainer && !global.__mollifierShutdownRegistered__) { + // `__mollifierShutdownRegistered__` guards against double-register + // on dev hot-reloads (this bootstrap is called from + // entry.server.tsx, which Remix dev re-evaluates on every change). + // Same guard owns both the handler registration and the start() + // call so the two never get out of sync. + // + // Registers through `signalsEmitter` (the webapp-wide singleton in + // `~/services/signals.server`) rather than `process.once` directly: + // - matches the codebase convention (runsReplicationInstance, + // llmPricingRegistry, dynamicFlushScheduler etc. all listen on + // the same emitter); + // - `.on` (not `.once`) means a second SIGTERM still reaches us if + // the orchestrator delivers more than one signal before SIGKILL; + // - if SIGTERM lands in the gap between this listener attaching + // and `drainer.start()` below, the first invocation no-ops + // (stop() returns early because the drainer isn't running yet) + // but the listener stays attached for a subsequent signal, + // rather than being consumed by `once`. + const stopDrainer = () => { + drainer + .stop({ timeoutMs: env.TRIGGER_MOLLIFIER_DRAIN_SHUTDOWN_TIMEOUT_MS }) + .catch((error) => { + logger.error("Failed to stop mollifier drainer", { error }); + }); + }; + signalsEmitter.on("SIGTERM", stopDrainer); + signalsEmitter.on("SIGINT", stopDrainer); + global.__mollifierShutdownRegistered__ = true; + drainer.start(); + } + } catch (error) { + // Deterministic misconfig (shutdown-timeout vs GRACEFUL_SHUTDOWN_TIMEOUT, + // missing buffer client) is a deploy-time mistake the operator must + // see immediately — rethrow so the process crashes, health checks + // fail, and the orchestrator rolls the deploy back. Phase 1 is + // monitoring-only and the silent-fallback was tempting, but Phase 2/3 + // make the drainer the source of truth for diverted triggers, where a + // silently-disabled drainer means data loss. Better to fail loud now + // than retrofit later. + // + // We accept both `instanceof` and `error.name === ...` so Remix dev + // hot-reload (where the consumer can hold a stale class reference) + // still recognises the marker. + if ( + error instanceof MollifierConfigurationError || + (error instanceof Error && error.name === "MollifierConfigurationError") + ) { + logger.error("Mollifier drainer misconfiguration — failing loud", { + error: error.message, + }); + throw error; + } + // Anything else (transient Redis blip, unexpected runtime error) is + // logged but kept non-fatal — the rest of the webapp shouldn't go + // down because the buffer's Redis cluster is briefly unreachable. + logger.error("Failed to initialise mollifier drainer", { error }); + } +} diff --git a/apps/webapp/test/bufferedTriggerPayload.test.ts b/apps/webapp/test/bufferedTriggerPayload.test.ts new file mode 100644 index 00000000000..6280acd4c63 --- /dev/null +++ b/apps/webapp/test/bufferedTriggerPayload.test.ts @@ -0,0 +1,96 @@ +import { describe, expect, it } from "vitest"; +import { buildBufferedTriggerPayload } from "~/v3/mollifier/bufferedTriggerPayload.server"; + +describe("buildBufferedTriggerPayload", () => { + const baseInput = { + runFriendlyId: "run_abc", + taskId: "my-task", + envId: "env_1", + envType: "DEVELOPMENT", + envSlug: "dev", + orgId: "org_1", + orgSlug: "acme", + projectId: "proj_db_id", + projectRef: "proj_xyz", + body: { payload: { hello: "world" }, options: { tags: ["t1"] } } as any, + idempotencyKey: null, + idempotencyKeyExpiresAt: null, + tags: ["t1"], + parentRunFriendlyId: null, + traceContext: { traceparent: "00-abc-def-01" }, + triggerSource: "api" as const, + triggerAction: "trigger" as const, + serviceOptions: {} as any, + createdAt: new Date("2026-05-13T09:00:00.000Z"), + }; + + it("captures all routing identifiers without losing data", () => { + const payload = buildBufferedTriggerPayload(baseInput); + + expect(payload.runFriendlyId).toBe("run_abc"); + expect(payload.envId).toBe("env_1"); + expect(payload.envType).toBe("DEVELOPMENT"); + expect(payload.envSlug).toBe("dev"); + expect(payload.orgId).toBe("org_1"); + expect(payload.orgSlug).toBe("acme"); + expect(payload.projectId).toBe("proj_db_id"); + expect(payload.projectRef).toBe("proj_xyz"); + expect(payload.taskId).toBe("my-task"); + }); + + it("serialises idempotencyKeyExpiresAt to ISO string only when key is present", () => { + const withKey = buildBufferedTriggerPayload({ + ...baseInput, + idempotencyKey: "ik_1", + idempotencyKeyExpiresAt: new Date("2026-05-13T10:00:00.000Z"), + }); + expect(withKey.idempotencyKey).toBe("ik_1"); + expect(withKey.idempotencyKeyExpiresAt).toBe("2026-05-13T10:00:00.000Z"); + + const noKey = buildBufferedTriggerPayload(baseInput); + expect(noKey.idempotencyKey).toBeNull(); + expect(noKey.idempotencyKeyExpiresAt).toBeNull(); + + // Defensive: an expiresAt without an accompanying key is an impossible + // idempotency state — drop the expiresAt rather than serialise it. + const orphanExpiry = buildBufferedTriggerPayload({ + ...baseInput, + idempotencyKey: null, + idempotencyKeyExpiresAt: new Date("2026-05-13T10:00:00.000Z"), + }); + expect(orphanExpiry.idempotencyKey).toBeNull(); + expect(orphanExpiry.idempotencyKeyExpiresAt).toBeNull(); + }); + + it("preserves customer body byte-equivalent (drainer replay must match Postgres)", () => { + const body = { + payload: { quotes: 'a"b', newline: "x\ny", unicode: "🚀", nested: { n: 1 } }, + options: { tags: ["a"], maxAttempts: 3, machine: "small-1x" }, + } as any; + const payload = buildBufferedTriggerPayload({ ...baseInput, body }); + expect(payload.body).toEqual(body); + + // JSON round-trip is the storage path; verify no information loss. + const roundtripped = JSON.parse(JSON.stringify(payload.body)); + expect(roundtripped).toEqual(body); + }); + + it("createdAt is serialised to ISO 8601", () => { + const payload = buildBufferedTriggerPayload(baseInput); + expect(payload.createdAt).toBe("2026-05-13T09:00:00.000Z"); + }); + + it("preserves traceContext (OTel continuity across buffer→drain boundary)", () => { + const traceContext = { traceparent: "00-x-y-01", tracestate: "vendor=foo" }; + const payload = buildBufferedTriggerPayload({ ...baseInput, traceContext }); + expect(payload.traceContext).toEqual(traceContext); + }); + + it("nullable parentRunFriendlyId — present and absent", () => { + expect(buildBufferedTriggerPayload(baseInput).parentRunFriendlyId).toBeNull(); + expect( + buildBufferedTriggerPayload({ ...baseInput, parentRunFriendlyId: "run_parent" }) + .parentRunFriendlyId, + ).toBe("run_parent"); + }); +}); diff --git a/apps/webapp/test/engine/triggerTask.test.ts b/apps/webapp/test/engine/triggerTask.test.ts index 798e39e0601..d07909d2907 100644 --- a/apps/webapp/test/engine/triggerTask.test.ts +++ b/apps/webapp/test/engine/triggerTask.test.ts @@ -1174,6 +1174,576 @@ describe("RunEngineTriggerTaskService", () => { await engine.quit(); } ); + + // ─── Mollifier integration ────────────────────────────────────────────────── + // + // These tests pin the call-site behaviour of the mollifier hooks inside + // RunEngineTriggerTaskService.call. They use the optional DI ports + // (`evaluateGate`, `getMollifierBuffer`) added on the service constructor — + // production wiring is unchanged (defaults to the live module-level imports). + // Each test's regression intent lives in its own setup comment. + + class CapturingMollifierBuffer { + public accepted: Array<{ runId: string; envId: string; orgId: string; payload: string }> = []; + async accept(input: { runId: string; envId: string; orgId: string; payload: string }) { + this.accepted.push(input); + return true; + } + async pop() { return null; } + async ack() {} + async requeue() {} + async fail() { return false; } + async getEntry() { return null; } + async listEnvs(): Promise { return []; } + async getEntryTtlSeconds(): Promise { return -1; } + async evaluateTrip() { return { tripped: false, count: 0 }; } + async close() {} + } + + containerTest( + "mollifier · validation throws before the gate is consulted; no buffer write", + async ({ prisma, redisOptions }) => { + const engine = new RunEngine({ + prisma, + worker: { redis: redisOptions, workers: 1, tasksPerWorker: 10, pollIntervalMs: 100 }, + queue: { redis: redisOptions }, + runLock: { redis: redisOptions }, + machines: { + defaultMachine: "small-1x", + machines: { "small-1x": { name: "small-1x" as const, cpu: 0.5, memory: 0.5, centsPerMs: 0.0001 } }, + baseCostInCents: 0.0005, + }, + tracer: trace.getTracer("test", "0.0.0"), + }); + + const authenticatedEnvironment = await setupAuthenticatedEnvironment(prisma, "PRODUCTION"); + const taskIdentifier = "test-task"; + await setupBackgroundWorker(engine, authenticatedEnvironment, taskIdentifier); + + // Validator that fails on maxAttempts. Any validation throw must abort + // the call BEFORE the gate runs — otherwise the gate could leak a + // buffer write for an invalid request. + class FailingMaxAttemptsValidator extends MockTriggerTaskValidator { + validateMaxAttempts(): ValidationResult { + return { ok: false, error: new Error("synthetic max-attempts failure") }; + } + } + + const buffer = new CapturingMollifierBuffer(); + const evaluateGateSpy = vi.fn(async () => ({ action: "mollify" as const, decision: { + divert: true as const, reason: "per_env_rate" as const, count: 99, threshold: 1, windowMs: 200, holdMs: 500, + } })); + + const triggerTaskService = new RunEngineTriggerTaskService({ + engine, + prisma, + payloadProcessor: new MockPayloadProcessor(), + queueConcern: new DefaultQueueManager(prisma, engine), + idempotencyKeyConcern: new IdempotencyKeyConcern(prisma, engine, new MockTraceEventConcern()), + validator: new FailingMaxAttemptsValidator(), + traceEventConcern: new MockTraceEventConcern(), + tracer: trace.getTracer("test", "0.0.0"), + metadataMaximumSize: 1024 * 1024, + evaluateGate: evaluateGateSpy, + getMollifierBuffer: () => buffer as never, + isMollifierGloballyEnabled: () => true, + }); + + await expect( + triggerTaskService.call({ + taskId: taskIdentifier, + environment: authenticatedEnvironment, + body: { payload: { test: "x" } }, + }), + ).rejects.toThrow(/synthetic max-attempts failure/); + + // Critical: the gate must NEVER be consulted when validation fails. + // If this assertion fires, validation has been re-ordered after the + // mollifier gate — a regression that would let invalid triggers land + // in the buffer. + expect(evaluateGateSpy).not.toHaveBeenCalled(); + expect(buffer.accepted).toHaveLength(0); + + await engine.quit(); + }, + ); + + containerTest( + "mollifier · mollify action triggers dual-write (buffer.accept + engine.trigger)", + async ({ prisma, redisOptions }) => { + const engine = new RunEngine({ + prisma, + worker: { redis: redisOptions, workers: 1, tasksPerWorker: 10, pollIntervalMs: 100 }, + queue: { redis: redisOptions }, + runLock: { redis: redisOptions }, + machines: { + defaultMachine: "small-1x", + machines: { "small-1x": { name: "small-1x" as const, cpu: 0.5, memory: 0.5, centsPerMs: 0.0001 } }, + baseCostInCents: 0.0005, + }, + tracer: trace.getTracer("test", "0.0.0"), + }); + + const authenticatedEnvironment = await setupAuthenticatedEnvironment(prisma, "PRODUCTION"); + const taskIdentifier = "test-task"; + await setupBackgroundWorker(engine, authenticatedEnvironment, taskIdentifier); + + const buffer = new CapturingMollifierBuffer(); + const trippedDecision = { + divert: true as const, + reason: "per_env_rate" as const, + count: 150, + threshold: 100, + windowMs: 200, + holdMs: 500, + }; + + const triggerTaskService = new RunEngineTriggerTaskService({ + engine, + prisma, + payloadProcessor: new MockPayloadProcessor(), + queueConcern: new DefaultQueueManager(prisma, engine), + idempotencyKeyConcern: new IdempotencyKeyConcern(prisma, engine, new MockTraceEventConcern()), + validator: new MockTriggerTaskValidator(), + traceEventConcern: new MockTraceEventConcern(), + tracer: trace.getTracer("test", "0.0.0"), + metadataMaximumSize: 1024 * 1024, + evaluateGate: async () => ({ action: "mollify", decision: trippedDecision }), + getMollifierBuffer: () => buffer as never, + isMollifierGloballyEnabled: () => true, + }); + + const result = await triggerTaskService.call({ + taskId: taskIdentifier, + environment: authenticatedEnvironment, + body: { payload: { hello: "world" } }, + }); + + // engine.trigger ran — Postgres has the run + expect(result).toBeDefined(); + expect(result?.run.friendlyId).toBeDefined(); + const pgRun = await prisma.taskRun.findFirst({ where: { id: result!.run.id } }); + expect(pgRun).not.toBeNull(); + expect(pgRun!.friendlyId).toBe(result!.run.friendlyId); + + // buffer.accept ran — Redis has the audit copy under the same friendlyId + expect(buffer.accepted).toHaveLength(1); + expect(buffer.accepted[0]!.runId).toBe(result!.run.friendlyId); + expect(buffer.accepted[0]!.envId).toBe(authenticatedEnvironment.id); + expect(buffer.accepted[0]!.orgId).toBe(authenticatedEnvironment.organizationId); + + // payload is the canonical replay shape + const payload = JSON.parse(buffer.accepted[0]!.payload); + expect(payload.runFriendlyId).toBe(result!.run.friendlyId); + expect(payload.taskId).toBe(taskIdentifier); + expect(payload.envId).toBe(authenticatedEnvironment.id); + expect(payload.body).toEqual({ payload: { hello: "world" } }); + + await engine.quit(); + }, + ); + + containerTest( + "mollifier · pass_through action does NOT call buffer.accept", + async ({ prisma, redisOptions }) => { + const engine = new RunEngine({ + prisma, + worker: { redis: redisOptions, workers: 1, tasksPerWorker: 10, pollIntervalMs: 100 }, + queue: { redis: redisOptions }, + runLock: { redis: redisOptions }, + machines: { + defaultMachine: "small-1x", + machines: { "small-1x": { name: "small-1x" as const, cpu: 0.5, memory: 0.5, centsPerMs: 0.0001 } }, + baseCostInCents: 0.0005, + }, + tracer: trace.getTracer("test", "0.0.0"), + }); + + const authenticatedEnvironment = await setupAuthenticatedEnvironment(prisma, "PRODUCTION"); + const taskIdentifier = "test-task"; + await setupBackgroundWorker(engine, authenticatedEnvironment, taskIdentifier); + + const buffer = new CapturingMollifierBuffer(); + const getBufferSpy = vi.fn(() => buffer as never); + + const triggerTaskService = new RunEngineTriggerTaskService({ + engine, + prisma, + payloadProcessor: new MockPayloadProcessor(), + queueConcern: new DefaultQueueManager(prisma, engine), + idempotencyKeyConcern: new IdempotencyKeyConcern(prisma, engine, new MockTraceEventConcern()), + validator: new MockTriggerTaskValidator(), + traceEventConcern: new MockTraceEventConcern(), + tracer: trace.getTracer("test", "0.0.0"), + metadataMaximumSize: 1024 * 1024, + evaluateGate: async () => ({ action: "pass_through" }), + getMollifierBuffer: getBufferSpy, + isMollifierGloballyEnabled: () => true, + }); + + const result = await triggerTaskService.call({ + taskId: taskIdentifier, + environment: authenticatedEnvironment, + body: { payload: { test: "x" } }, + }); + + expect(result).toBeDefined(); + // Postgres has the run, no buffer side-effects + expect(buffer.accepted).toHaveLength(0); + // getMollifierBuffer must not be called either — the call site short-circuits + // before touching the singleton when the gate says pass_through. + expect(getBufferSpy).not.toHaveBeenCalled(); + + await engine.quit(); + }, + ); + + containerTest( + "mollifier · engine.trigger throwing AFTER buffer.accept leaves an orphan entry (documented behaviour)", + async ({ prisma, redisOptions }) => { + // SCENARIO: dual-write where buffer.accept succeeds but engine.trigger + // throws. The throw propagates to the caller (correct: customer sees + // the same 4xx as today), and the buffer entry remains as an "orphan" + // — Phase 1's no-op drainer will pop+ack it on its next poll, so the + // orphan is bounded (~drainer pollIntervalMs) but observable in the + // audit trail (mollifier.buffered with no matching TaskRun). + // + // Why engine.trigger can throw post-buffer: + // - RunDuplicateIdempotencyKeyError (Prisma P2002 on idempotencyKey): + // a concurrent non-mollified trigger with the same idempotencyKey + // wins the DB UNIQUE constraint between IdempotencyKeyConcern's + // pre-check and engine.trigger's INSERT. + // - RunOneTimeUseTokenError (Prisma P2002 on oneTimeUseToken). + // - Transient Prisma errors (FK constraint, connection drop, etc.). + // + // Why we don't "fix" this race in Phase 1: + // The customer correctly gets the error. State eventually converges + // (drainer pops the orphan). The audit-trail explicitly surfaces + // "buffered without TaskRun" entries to operators. A real fix is + // Phase 2's responsibility once the buffer becomes the primary write + // — at that point we add the mollifier-specific idempotency index. + // + // This test pins the current ordering: buffer.accept fires synchronously + // BEFORE engine.trigger, and engine.trigger failure does NOT roll back + // the buffer write. Any future change that reverses the order or adds + // a silent rollback will fail this assertion and force a design + // decision rather than a silent behaviour change. + + const engine = new RunEngine({ + prisma, + worker: { redis: redisOptions, workers: 1, tasksPerWorker: 10, pollIntervalMs: 100 }, + queue: { redis: redisOptions }, + runLock: { redis: redisOptions }, + machines: { + defaultMachine: "small-1x", + machines: { "small-1x": { name: "small-1x" as const, cpu: 0.5, memory: 0.5, centsPerMs: 0.0001 } }, + baseCostInCents: 0.0005, + }, + tracer: trace.getTracer("test", "0.0.0"), + }); + + const authenticatedEnvironment = await setupAuthenticatedEnvironment(prisma, "PRODUCTION"); + const taskIdentifier = "test-task"; + await setupBackgroundWorker(engine, authenticatedEnvironment, taskIdentifier); + + const buffer = new CapturingMollifierBuffer(); + + // Force engine.trigger to throw on this single call. We spy AFTER + // setupBackgroundWorker so the worker setup still uses the real + // engine.trigger (which has its own engine.trigger-ish calls for + // worker bootstrap — though in practice setupBackgroundWorker doesn't + // call trigger). + const simulatedFailure = new Error("simulated engine.trigger failure post-buffer"); + vi.spyOn(engine, "trigger").mockRejectedValueOnce(simulatedFailure); + + const triggerTaskService = new RunEngineTriggerTaskService({ + engine, + prisma, + payloadProcessor: new MockPayloadProcessor(), + queueConcern: new DefaultQueueManager(prisma, engine), + idempotencyKeyConcern: new IdempotencyKeyConcern(prisma, engine, new MockTraceEventConcern()), + validator: new MockTriggerTaskValidator(), + traceEventConcern: new MockTraceEventConcern(), + tracer: trace.getTracer("test", "0.0.0"), + metadataMaximumSize: 1024 * 1024, + evaluateGate: async () => ({ + action: "mollify", + decision: { + divert: true, + reason: "per_env_rate", + count: 150, + threshold: 100, + windowMs: 200, + holdMs: 500, + }, + }), + getMollifierBuffer: () => buffer as never, + isMollifierGloballyEnabled: () => true, + }); + + await expect( + triggerTaskService.call({ + taskId: taskIdentifier, + environment: authenticatedEnvironment, + body: { payload: { test: "x" } }, + }), + ).rejects.toThrow(/simulated engine.trigger failure post-buffer/); + + // The buffer write happened BEFORE engine.trigger threw. The orphan + // remains; the audit-trail will surface it (mollifier.buffered with + // no matching TaskRun row). Phase 1's no-op drainer cleans it up. + expect(buffer.accepted).toHaveLength(1); + const orphanPayload = JSON.parse(buffer.accepted[0]!.payload); + expect(orphanPayload.taskId).toBe(taskIdentifier); + + await engine.quit(); + }, + ); + + containerTest( + "mollifier · idempotency-key match short-circuits BEFORE the gate is consulted", + async ({ prisma, redisOptions }) => { + // SCENARIO: a trigger arrives with an idempotency key matching an + // already-created run. `IdempotencyKeyConcern.handleTriggerRequest` + // (line 236 of triggerTask.server.ts) detects the match BEFORE the + // mollifier gate runs and returns `{ isCached: true, run }`. The + // service early-returns. The gate is never consulted, buffer.accept + // never fires, no orphan entry is created. + // + // Regression intent: if IdempotencyKeyConcern were re-ordered to run + // AFTER evaluateGate, every idempotent retry on a flagged org would + // produce an orphan buffer entry — the audit-trail invariant ("every + // buffered runId has a matching TaskRun") would silently start failing + // for retries. This test pins the current order. + + const engine = new RunEngine({ + prisma, + worker: { redis: redisOptions, workers: 1, tasksPerWorker: 10, pollIntervalMs: 100 }, + queue: { redis: redisOptions }, + runLock: { redis: redisOptions }, + machines: { + defaultMachine: "small-1x", + machines: { "small-1x": { name: "small-1x" as const, cpu: 0.5, memory: 0.5, centsPerMs: 0.0001 } }, + baseCostInCents: 0.0005, + }, + tracer: trace.getTracer("test", "0.0.0"), + }); + + const authenticatedEnvironment = await setupAuthenticatedEnvironment(prisma, "PRODUCTION"); + const taskIdentifier = "test-task"; + await setupBackgroundWorker(engine, authenticatedEnvironment, taskIdentifier); + + const idempotencyKeyConcern = new IdempotencyKeyConcern( + prisma, + engine, + new MockTraceEventConcern(), + ); + + // Setup: normal trigger to create the cached run (no mollifier). + const baseline = new RunEngineTriggerTaskService({ + engine, + prisma, + payloadProcessor: new MockPayloadProcessor(), + queueConcern: new DefaultQueueManager(prisma, engine), + idempotencyKeyConcern, + validator: new MockTriggerTaskValidator(), + traceEventConcern: new MockTraceEventConcern(), + tracer: trace.getTracer("test", "0.0.0"), + metadataMaximumSize: 1024 * 1024, + }); + const first = await baseline.call({ + taskId: taskIdentifier, + environment: authenticatedEnvironment, + body: { payload: { test: "x" }, options: { idempotencyKey: "regression-key-5" } }, + }); + expect(first?.isCached).toBe(false); + + // Action: same idempotency key, with a mollify-stub gate that WOULD + // create an orphan if reached. The concern must short-circuit first. + const buffer = new CapturingMollifierBuffer(); + const evaluateGateSpy = vi.fn(async () => ({ + action: "mollify" as const, + decision: { + divert: true as const, + reason: "per_env_rate" as const, + count: 150, + threshold: 100, + windowMs: 200, + holdMs: 500, + }, + })); + + const mollifierService = new RunEngineTriggerTaskService({ + engine, + prisma, + payloadProcessor: new MockPayloadProcessor(), + queueConcern: new DefaultQueueManager(prisma, engine), + idempotencyKeyConcern, + validator: new MockTriggerTaskValidator(), + traceEventConcern: new MockTraceEventConcern(), + tracer: trace.getTracer("test", "0.0.0"), + metadataMaximumSize: 1024 * 1024, + evaluateGate: evaluateGateSpy, + getMollifierBuffer: () => buffer as never, + isMollifierGloballyEnabled: () => true, + }); + + const cached = await mollifierService.call({ + taskId: taskIdentifier, + environment: authenticatedEnvironment, + body: { payload: { test: "x" }, options: { idempotencyKey: "regression-key-5" } }, + }); + + // Customer sees the cached run, isCached=true + expect(cached).toBeDefined(); + expect(cached?.isCached).toBe(true); + expect(cached?.run.friendlyId).toBe(first?.run.friendlyId); + + // Critical: the gate must NEVER be consulted on a cached-idempotency replay. + expect(evaluateGateSpy).not.toHaveBeenCalled(); + expect(buffer.accepted).toHaveLength(0); + + await engine.quit(); + }, + ); + + containerTest( + "mollifier · debounce match produces an orphan buffer entry (documented behaviour)", + async ({ prisma, redisOptions }) => { + // SCENARIO: a trigger with a debounce key arrives while a matching + // debounced run already exists. `debounceSystem.handleDebounce` runs + // INSIDE `engine.trigger` (line ~514 of run-engine/src/engine/index.ts), + // AFTER buffer.accept has already written the new friendlyId. The + // service correctly returns the existing run id to the customer, but + // the buffer is left with an orphan entry for the new friendlyId. + // + // Why this is acceptable in Phase 1: + // - Customer-facing behaviour is unchanged from today: they receive + // the existing run id, same as the non-mollified path. + // - The orphan is bounded — the drainer's no-op-ack handler pops + // and acks it on its next poll. + // - The audit-trail surfaces it: a `mollifier.buffered` log line + // with `runId` that has no matching TaskRun in Postgres. + // + // Why Phase 2 cares: + // - When the buffer becomes the primary write path, debounce can + // no longer be allowed to run AFTER buffer.accept. The drainer's + // engine.trigger replay would observe "existing" and skip the + // persist — the customer's synthesised 200 (with the new + // friendlyId) would never get a TaskRun, and the audit-trail + // divergence becomes a real data-loss bug. + // - Phase 2 must lift `handleDebounce` into the call site BEFORE + // buffer.accept: + // 1. handleDebounce → if existing, return existing run; do NOT + // touch the buffer. + // 2. Otherwise, accept with `claimId` threaded into the + // canonical payload so the drainer's replay can + // `registerDebouncedRun` after persisting. + // + // This test pins the current ordering. A future change that "fixes" + // it by lifting handleDebounce upfront will fail the orphan + // assertion below and force an explicit choice (update the test, + // remove this scenario, or stage the lift behind a flag). + + const engine = new RunEngine({ + prisma, + worker: { redis: redisOptions, workers: 1, tasksPerWorker: 10, pollIntervalMs: 100 }, + queue: { redis: redisOptions }, + runLock: { redis: redisOptions }, + machines: { + defaultMachine: "small-1x", + machines: { "small-1x": { name: "small-1x" as const, cpu: 0.5, memory: 0.5, centsPerMs: 0.0001 } }, + baseCostInCents: 0.0005, + }, + tracer: trace.getTracer("test", "0.0.0"), + }); + + const authenticatedEnvironment = await setupAuthenticatedEnvironment(prisma, "PRODUCTION"); + const taskIdentifier = "test-task"; + await setupBackgroundWorker(engine, authenticatedEnvironment, taskIdentifier); + + const idempotencyKeyConcern = new IdempotencyKeyConcern( + prisma, + engine, + new MockTraceEventConcern(), + ); + + // Setup: trigger with debounce — creates the existing run + Redis claim. + const baseline = new RunEngineTriggerTaskService({ + engine, + prisma, + payloadProcessor: new MockPayloadProcessor(), + queueConcern: new DefaultQueueManager(prisma, engine), + idempotencyKeyConcern, + validator: new MockTriggerTaskValidator(), + traceEventConcern: new MockTraceEventConcern(), + tracer: trace.getTracer("test", "0.0.0"), + metadataMaximumSize: 1024 * 1024, + }); + const first = await baseline.call({ + taskId: taskIdentifier, + environment: authenticatedEnvironment, + body: { + payload: { test: "x" }, + options: { debounce: { key: "regression-debounce-6", delay: "30s" } }, + }, + }); + expect(first?.run.friendlyId).toBeDefined(); + + // Action: same debounce key, mollify-stub gate. + const buffer = new CapturingMollifierBuffer(); + const mollifierService = new RunEngineTriggerTaskService({ + engine, + prisma, + payloadProcessor: new MockPayloadProcessor(), + queueConcern: new DefaultQueueManager(prisma, engine), + idempotencyKeyConcern, + validator: new MockTriggerTaskValidator(), + traceEventConcern: new MockTraceEventConcern(), + tracer: trace.getTracer("test", "0.0.0"), + metadataMaximumSize: 1024 * 1024, + evaluateGate: async () => ({ + action: "mollify", + decision: { + divert: true, + reason: "per_env_rate", + count: 150, + threshold: 100, + windowMs: 200, + holdMs: 500, + }, + }), + getMollifierBuffer: () => buffer as never, + isMollifierGloballyEnabled: () => true, + }); + + const debounced = await mollifierService.call({ + taskId: taskIdentifier, + environment: authenticatedEnvironment, + body: { + payload: { test: "x" }, + options: { debounce: { key: "regression-debounce-6", delay: "30s" } }, + }, + }); + + // Customer-facing behaviour: the existing run is returned (correct). + expect(debounced).toBeDefined(); + expect(debounced?.run.friendlyId).toBe(first?.run.friendlyId); + + // Orphan: buffer.accept fired with the new friendlyId we generated + // upfront, and that friendlyId has no matching TaskRun in Postgres + // because engine.trigger returned the existing run via debounce. + expect(buffer.accepted).toHaveLength(1); + expect(buffer.accepted[0]!.runId).not.toBe(first?.run.friendlyId); + const orphanFriendlyId = buffer.accepted[0]!.runId; + const orphanRow = await prisma.taskRun.findFirst({ + where: { friendlyId: orphanFriendlyId }, + }); + expect(orphanRow).toBeNull(); + + await engine.quit(); + }, + ); }); describe("DefaultQueueManager task metadata cache", () => { diff --git a/apps/webapp/test/mollifierDrainerWorker.test.ts b/apps/webapp/test/mollifierDrainerWorker.test.ts new file mode 100644 index 00000000000..e5f38229d8f --- /dev/null +++ b/apps/webapp/test/mollifierDrainerWorker.test.ts @@ -0,0 +1,72 @@ +import { describe, expect, it } from "vitest"; +import { MollifierConfigurationError } from "~/v3/mollifier/mollifierDrainer.server"; +import { initMollifierDrainerWorker } from "~/v3/mollifierDrainerWorker.server"; + +// Pins the error-classification policy inside the bootstrap's catch: +// deterministic misconfig errors propagate (so a deploy fails loud +// rather than silently disabling the drainer), and anything else is +// logged-and-swallowed (so a transient Redis blip during boot doesn't +// take the whole webapp down). The corresponding production-path +// integration is the call at `entry.server.tsx`: a sync throw out of +// `initMollifierDrainerWorker` propagates to the module top level +// BEFORE `process.on("uncaughtException", ...)` is registered, so Node +// crashes with a stack trace and exit code 1 — which is exactly what we +// want from the orchestrator's health-check perspective. +describe("initMollifierDrainerWorker error classification", () => { + it("rethrows MollifierConfigurationError so the process can crash on misconfig", () => { + const misconfig = new MollifierConfigurationError( + "TRIGGER_MOLLIFIER_DRAIN_SHUTDOWN_TIMEOUT_MS must be at least 1000ms below GRACEFUL_SHUTDOWN_TIMEOUT", + ); + + expect(() => + initMollifierDrainerWorker({ + isEnabled: () => true, + getDrainer: () => { + throw misconfig; + }, + }), + ).toThrow(MollifierConfigurationError); + }); + + it("rethrows when the error carries the marker name even if instanceof fails (dev-realm hot-reload fallback)", () => { + // Simulate the cross-realm case where the consumer's instanceof + // check sees a different class instance from the one the throw + // site used. The bootstrap's `.name === "MollifierConfigurationError"` + // fallback must catch this so dev hot-reload doesn't silently + // suppress misconfig errors. + const cousin = new Error("buffer not initialised"); + cousin.name = "MollifierConfigurationError"; + + expect(() => + initMollifierDrainerWorker({ + isEnabled: () => true, + getDrainer: () => { + throw cousin; + }, + }), + ).toThrow(cousin); + }); + + it("swallows non-configuration errors so transient init failures don't take the webapp down", () => { + expect(() => + initMollifierDrainerWorker({ + isEnabled: () => true, + getDrainer: () => { + throw new Error("transient redis blip during buffer init"); + }, + }), + ).not.toThrow(); + }); + + it("is a no-op when the drainer is disabled for this replica", () => { + let factoryCalled = false; + initMollifierDrainerWorker({ + isEnabled: () => false, + getDrainer: () => { + factoryCalled = true; + return null; + }, + }); + expect(factoryCalled).toBe(false); + }); +}); diff --git a/apps/webapp/test/mollifierGate.test.ts b/apps/webapp/test/mollifierGate.test.ts new file mode 100644 index 00000000000..b81df7f0c5b --- /dev/null +++ b/apps/webapp/test/mollifierGate.test.ts @@ -0,0 +1,434 @@ +import { describe, expect, it, vi } from "vitest"; + +// Stub `~/db.server` before importing anything that transitively imports it. +// The real module eagerly calls `prisma.$connect()` at singleton construction +// (db.server.ts), so loading it under vitest tries to reach localhost:5432 +// and surfaces as an unhandled rejection that fails the whole shard — even +// though no test in this file actually uses the default prisma client. +vi.mock("~/db.server", () => ({ + prisma: {}, + $replica: {}, +})); + +import { + evaluateGate, + makeResolveMollifierFlag, + type GateDependencies, + type GateInputs, + type TripDecision, +} from "~/v3/mollifier/mollifierGate.server"; +import type { DecisionOutcome, DecisionReason } from "~/v3/mollifier/mollifierTelemetry.server"; + +// We deliberately don't use vi.fn here. Per repo policy tests shouldn't lean on +// mock frameworks for behaviours that are pure functions of the inputs — the +// gate is pure decision logic, so a hand-rolled "deps + spy log" wired with +// plain closures gives exactly the assertions we need without the indirection. +type Spies = { + evaluatorCalls: number; + logShadowCalls: Array<{ inputs: GateInputs; decision: Extract }>; + logMollifiedCalls: Array<{ inputs: GateInputs; decision: Extract }>; + recordDecisionCalls: Array<{ outcome: DecisionOutcome; reason?: DecisionReason }>; +}; + +type Toggles = { + enabled: boolean; + shadow: boolean; + flag: boolean; + decision: TripDecision; +}; + +function makeDeps(toggles: Toggles): { deps: GateDependencies; spies: Spies } { + const spies: Spies = { + evaluatorCalls: 0, + logShadowCalls: [], + logMollifiedCalls: [], + recordDecisionCalls: [], + }; + const deps: GateDependencies = { + isMollifierEnabled: () => toggles.enabled, + isShadowModeOn: () => toggles.shadow, + resolveOrgFlag: async () => toggles.flag, + evaluator: async () => { + spies.evaluatorCalls += 1; + return toggles.decision; + }, + logShadow: (inputs, decision) => { + spies.logShadowCalls.push({ inputs, decision }); + }, + logMollified: (inputs, decision) => { + spies.logMollifiedCalls.push({ inputs, decision }); + }, + recordDecision: (outcome, reason) => { + spies.recordDecisionCalls.push({ outcome, reason }); + }, + }; + return { deps, spies }; +} + +const trippedDecision = { + divert: true as const, + reason: "per_env_rate" as const, + count: 150, + threshold: 100, + windowMs: 200, + holdMs: 500, +}; + +const passDecision: TripDecision = { divert: false }; + +const inputs: GateInputs = { + envId: "e1", + orgId: "o1", + taskId: "t1", + orgFeatureFlags: null, +}; + +// Cascade truth table. Every combination of (enabled, shadow, flag, divert) is +// enumerated. `evaluatorCalls` is the expected count, not arbitrary: the gate +// short-circuits before the evaluator if `!enabled` or (`!flag && !shadow`). +// `expectedReason` is the optional second arg to `recordDecision` — only +// divert-true paths attach a reason. +type Row = { + id: number; + enabled: boolean; + shadow: boolean; + flag: boolean; + divert: boolean; + expected: { + action: "pass_through" | "shadow_log" | "mollify"; + evaluatorCalls: 0 | 1; + logShadowCalls: 0 | 1; + logMollifiedCalls: 0 | 1; + recordedOutcome: "pass_through" | "shadow_log" | "mollify"; + expectedReason: "per_env_rate" | undefined; + }; +}; + +// 16 rows = 2^4 input combinations. Comment column shows which gate branch +// each row exercises so reviewers can map row → code at a glance. +const cascade: Row[] = [ + // enabled=F → kill-switch wins; evaluator+flag never consulted (rows 1-8) + { id: 1, enabled: false, shadow: false, flag: false, divert: false, expected: { action: "pass_through", evaluatorCalls: 0, logShadowCalls: 0, logMollifiedCalls: 0, recordedOutcome: "pass_through", expectedReason: undefined } }, + { id: 2, enabled: false, shadow: false, flag: false, divert: true, expected: { action: "pass_through", evaluatorCalls: 0, logShadowCalls: 0, logMollifiedCalls: 0, recordedOutcome: "pass_through", expectedReason: undefined } }, + { id: 3, enabled: false, shadow: false, flag: true, divert: false, expected: { action: "pass_through", evaluatorCalls: 0, logShadowCalls: 0, logMollifiedCalls: 0, recordedOutcome: "pass_through", expectedReason: undefined } }, + { id: 4, enabled: false, shadow: false, flag: true, divert: true, expected: { action: "pass_through", evaluatorCalls: 0, logShadowCalls: 0, logMollifiedCalls: 0, recordedOutcome: "pass_through", expectedReason: undefined } }, + { id: 5, enabled: false, shadow: true, flag: false, divert: false, expected: { action: "pass_through", evaluatorCalls: 0, logShadowCalls: 0, logMollifiedCalls: 0, recordedOutcome: "pass_through", expectedReason: undefined } }, + { id: 6, enabled: false, shadow: true, flag: false, divert: true, expected: { action: "pass_through", evaluatorCalls: 0, logShadowCalls: 0, logMollifiedCalls: 0, recordedOutcome: "pass_through", expectedReason: undefined } }, + { id: 7, enabled: false, shadow: true, flag: true, divert: false, expected: { action: "pass_through", evaluatorCalls: 0, logShadowCalls: 0, logMollifiedCalls: 0, recordedOutcome: "pass_through", expectedReason: undefined } }, + { id: 8, enabled: false, shadow: true, flag: true, divert: true, expected: { action: "pass_through", evaluatorCalls: 0, logShadowCalls: 0, logMollifiedCalls: 0, recordedOutcome: "pass_through", expectedReason: undefined } }, + // enabled=T, flag=F, shadow=F → both opt-ins off; evaluator never called (rows 9-10) + { id: 9, enabled: true, shadow: false, flag: false, divert: false, expected: { action: "pass_through", evaluatorCalls: 0, logShadowCalls: 0, logMollifiedCalls: 0, recordedOutcome: "pass_through", expectedReason: undefined } }, + { id: 10, enabled: true, shadow: false, flag: false, divert: true, expected: { action: "pass_through", evaluatorCalls: 0, logShadowCalls: 0, logMollifiedCalls: 0, recordedOutcome: "pass_through", expectedReason: undefined } }, + // enabled=T, flag=F, shadow=T → shadow path; divert routes outcome (rows 11-12) + { id: 11, enabled: true, shadow: true, flag: false, divert: false, expected: { action: "pass_through", evaluatorCalls: 1, logShadowCalls: 0, logMollifiedCalls: 0, recordedOutcome: "pass_through", expectedReason: undefined } }, + { id: 12, enabled: true, shadow: true, flag: false, divert: true, expected: { action: "shadow_log", evaluatorCalls: 1, logShadowCalls: 1, logMollifiedCalls: 0, recordedOutcome: "shadow_log", expectedReason: "per_env_rate" } }, + // enabled=T, flag=T, shadow=F → mollify path (rows 13-14) + { id: 13, enabled: true, shadow: false, flag: true, divert: false, expected: { action: "pass_through", evaluatorCalls: 1, logShadowCalls: 0, logMollifiedCalls: 0, recordedOutcome: "pass_through", expectedReason: undefined } }, + { id: 14, enabled: true, shadow: false, flag: true, divert: true, expected: { action: "mollify", evaluatorCalls: 1, logShadowCalls: 0, logMollifiedCalls: 1, recordedOutcome: "mollify", expectedReason: "per_env_rate" } }, + // enabled=T, flag=T, shadow=T → flag wins over shadow (rows 15-16) + { id: 15, enabled: true, shadow: true, flag: true, divert: false, expected: { action: "pass_through", evaluatorCalls: 1, logShadowCalls: 0, logMollifiedCalls: 0, recordedOutcome: "pass_through", expectedReason: undefined } }, + { id: 16, enabled: true, shadow: true, flag: true, divert: true, expected: { action: "mollify", evaluatorCalls: 1, logShadowCalls: 0, logMollifiedCalls: 1, recordedOutcome: "mollify", expectedReason: "per_env_rate" } }, +]; + +describe("evaluateGate cascade — exhaustive truth table", () => { + it.each(cascade)( + "row $id: enabled=$enabled shadow=$shadow flag=$flag divert=$divert → action=$expected.action", + async (row) => { + const { deps, spies } = makeDeps({ + enabled: row.enabled, + shadow: row.shadow, + flag: row.flag, + decision: row.divert ? trippedDecision : passDecision, + }); + + const outcome = await evaluateGate(inputs, deps); + + expect(outcome.action).toBe(row.expected.action); + expect(spies.evaluatorCalls).toBe(row.expected.evaluatorCalls); + expect(spies.logShadowCalls).toHaveLength(row.expected.logShadowCalls); + expect(spies.logMollifiedCalls).toHaveLength(row.expected.logMollifiedCalls); + + // Every evaluation records exactly one decision. + expect(spies.recordDecisionCalls).toHaveLength(1); + expect(spies.recordDecisionCalls[0].outcome).toBe(row.expected.recordedOutcome); + expect(spies.recordDecisionCalls[0].reason).toBe(row.expected.expectedReason); + }, + ); + + it("divert log carries the full decision (envId, orgId, taskId, reason, count, threshold, windowMs, holdMs)", async () => { + const { deps, spies } = makeDeps({ + enabled: true, + shadow: true, + flag: false, + decision: trippedDecision, + }); + + await evaluateGate(inputs, deps); + + expect(spies.logShadowCalls).toEqual([{ inputs, decision: trippedDecision }]); + }); + + it("mollify log carries the full decision (mirrors shadow log)", async () => { + const { deps, spies } = makeDeps({ + enabled: true, + shadow: false, + flag: true, + decision: trippedDecision, + }); + + await evaluateGate(inputs, deps); + + expect(spies.logMollifiedCalls).toEqual([{ inputs, decision: trippedDecision }]); + }); +}); + +// Hot-path guard: `triggerTask.server.ts` calls `evaluateGate` on every +// trigger when `TRIGGER_MOLLIFIER_ENABLED=1`. The per-org override path must resolve +// without a Prisma round-trip — otherwise the gate adds a DB query to the +// highest-throughput code path in the system (see apps/webapp/CLAUDE.md). +describe("resolveMollifierFlag — hot path", () => { + it("returns the per-org override when it's set", async () => { + const resolve = makeResolveMollifierFlag(); + + const enabled = await resolve({ + envId: "e", + orgId: "o", + taskId: "t", + orgFeatureFlags: { mollifierEnabled: true }, + }); + const disabled = await resolve({ + envId: "e", + orgId: "o", + taskId: "t", + orgFeatureFlags: { mollifierEnabled: false }, + }); + + expect(enabled).toBe(true); + expect(disabled).toBe(false); + }); + + it("returns false when the org has no override for the key — no DB query, ever", async () => { + // Regression intent: the resolver MUST NOT call `flag()` (which would + // query `FeatureFlag` via Prisma) on the trigger hot path. Per-org + // rollout via `Organization.featureFlags` JSON is the only enable + // path; the fleet-wide kill switch is `TRIGGER_MOLLIFIER_ENABLED`. + const resolve = makeResolveMollifierFlag(); + + const fromNull = await resolve({ + envId: "e", + orgId: "o", + taskId: "t", + orgFeatureFlags: null, + }); + const fromUnrelatedKeys = await resolve({ + envId: "e", + orgId: "o", + taskId: "t", + orgFeatureFlags: { hasAiAccess: true }, + }); + + expect(fromNull).toBe(false); + expect(fromUnrelatedKeys).toBe(false); + }); +}); + +describe("evaluateGate — fail open on evaluator error", () => { + it("treats a throwing evaluator as no-divert (pass_through), and never blocks the trigger", async () => { + const spies: Spies = { + evaluatorCalls: 0, + logShadowCalls: [], + logMollifiedCalls: [], + recordDecisionCalls: [], + }; + const deps: Partial = { + isMollifierEnabled: () => true, + isShadowModeOn: () => false, + resolveOrgFlag: async () => true, + evaluator: async () => { + spies.evaluatorCalls += 1; + throw new Error("simulated evaluator failure"); + }, + logShadow: (inputs, decision) => { + spies.logShadowCalls.push({ inputs, decision }); + }, + logMollified: (inputs, decision) => { + spies.logMollifiedCalls.push({ inputs, decision }); + }, + recordDecision: (outcome, reason) => { + spies.recordDecisionCalls.push({ outcome, reason }); + }, + }; + + const outcome = await evaluateGate(inputs, deps); + + expect(outcome.action).toBe("pass_through"); + expect(spies.evaluatorCalls).toBe(1); + expect(spies.logMollifiedCalls).toHaveLength(0); + expect(spies.logShadowCalls).toHaveLength(0); + expect(spies.recordDecisionCalls).toEqual([{ outcome: "pass_through", reason: undefined }]); + }); +}); + +describe("evaluateGate — fail open on resolveOrgFlag error", () => { + it("treats org flag as false when resolveOrgFlag throws, and does not block triggers", async () => { + const spies: Spies = { + evaluatorCalls: 0, + logShadowCalls: [], + logMollifiedCalls: [], + recordDecisionCalls: [], + }; + const deps: Partial = { + isMollifierEnabled: () => true, + isShadowModeOn: () => false, + resolveOrgFlag: async () => { + throw new Error("simulated prisma timeout"); + }, + evaluator: async () => { + spies.evaluatorCalls += 1; + return trippedDecision; + }, + logShadow: (inputs, decision) => { + spies.logShadowCalls.push({ inputs, decision }); + }, + logMollified: (inputs, decision) => { + spies.logMollifiedCalls.push({ inputs, decision }); + }, + recordDecision: (outcome, reason) => { + spies.recordDecisionCalls.push({ outcome, reason }); + }, + }; + + const outcome = await evaluateGate(inputs, deps); + + expect(outcome.action).toBe("pass_through"); + expect(spies.evaluatorCalls).toBe(0); + expect(spies.recordDecisionCalls).toEqual([{ outcome: "pass_through", reason: undefined }]); + }); +}); + +describe("evaluateGate — per-org isolation via Organization.featureFlags", () => { + function makeIsolationDeps( + resolveOrgFlag: GateDependencies["resolveOrgFlag"], + ): { deps: Partial; spies: Spies } { + const spies: Spies = { + evaluatorCalls: 0, + logShadowCalls: [], + logMollifiedCalls: [], + recordDecisionCalls: [], + }; + // Override lifecycle bits and inject the production resolveOrgFlag. + // Evaluator returns a fixed tripped decision so the outcome is purely a + // function of the flag resolution (which is what we're isolating on). + const deps: Partial = { + isMollifierEnabled: () => true, + isShadowModeOn: () => false, + resolveOrgFlag, + evaluator: async () => { + spies.evaluatorCalls += 1; + return trippedDecision; + }, + logShadow: (inputs, decision) => { + spies.logShadowCalls.push({ inputs, decision }); + }, + logMollified: (inputs, decision) => { + spies.logMollifiedCalls.push({ inputs, decision }); + }, + recordDecision: (outcome, reason) => { + spies.recordDecisionCalls.push({ outcome, reason }); + }, + }; + return { deps, spies }; + } + + // The production resolver — purely in-memory, no Prisma. Mirrors + // `defaultGateDependencies.resolveOrgFlag` exactly. + const resolve = makeResolveMollifierFlag(); + + it("opts in only the org whose featureFlags has mollifierEnabled=true", async () => { + const orgA = { ...inputs, orgId: "org_a", orgFeatureFlags: { mollifierEnabled: true } }; + const orgB = { ...inputs, orgId: "org_b", orgFeatureFlags: { mollifierEnabled: false } }; + const orgC = { ...inputs, orgId: "org_c", orgFeatureFlags: null }; + + const a = makeIsolationDeps(resolve); + const b = makeIsolationDeps(resolve); + const c = makeIsolationDeps(resolve); + + const [outcomeA, outcomeB, outcomeC] = await Promise.all([ + evaluateGate(orgA, a.deps), + evaluateGate(orgB, b.deps), + evaluateGate(orgC, c.deps), + ]); + + // Only org A's flag is on → only org A mollifies. Orgs B and C never + // reach the evaluator because both flag and shadow-mode are off. + expect(outcomeA.action).toBe("mollify"); + expect(outcomeB.action).toBe("pass_through"); + expect(outcomeC.action).toBe("pass_through"); + + expect(a.spies.evaluatorCalls).toBe(1); + expect(b.spies.evaluatorCalls).toBe(0); + expect(c.spies.evaluatorCalls).toBe(0); + + expect(a.spies.logMollifiedCalls).toHaveLength(1); + expect(b.spies.logMollifiedCalls).toHaveLength(0); + expect(c.spies.logMollifiedCalls).toHaveLength(0); + }); + + it("another org's beta flags must not opt them into mollifier", async () => { + // Org A has mollifier on (plus an unrelated beta). + const orgA = { + ...inputs, + orgId: "org_a", + orgFeatureFlags: { mollifierEnabled: true, hasComputeAccess: true }, + }; + // Org B has *other* betas on but mollifier remains off — keys that gate + // compute/AI/query must not bleed across into the mollifier decision. + const orgB = { + ...inputs, + orgId: "org_b", + orgFeatureFlags: { hasComputeAccess: true, hasAiAccess: true }, + }; + + const a = makeIsolationDeps(resolve); + const b = makeIsolationDeps(resolve); + + const outcomeA = await evaluateGate(orgA, a.deps); + const outcomeB = await evaluateGate(orgB, b.deps); + + expect(outcomeA.action).toBe("mollify"); + expect(outcomeB.action).toBe("pass_through"); + }); + + it("orgs without an explicit override stay off — no global FeatureFlag fallback", async () => { + // Regression intent: the resolver MUST NOT consult the global + // `FeatureFlag` table on the hot path. An org with `orgFeatureFlags` + // unset (the default for almost every org during rollout) gets + // pass_through, period. The fleet-wide kill switch lives in + // `TRIGGER_MOLLIFIER_ENABLED`, not the FeatureFlag table. + const orgInherits = { ...inputs, orgId: "org_inherits", orgFeatureFlags: null }; + const orgEmpty = { ...inputs, orgId: "org_empty", orgFeatureFlags: {} }; + const orgUnrelated = { + ...inputs, + orgId: "org_unrelated", + orgFeatureFlags: { hasAiAccess: true }, + }; + + const inheritsDeps = makeIsolationDeps(resolve); + const emptyDeps = makeIsolationDeps(resolve); + const unrelatedDeps = makeIsolationDeps(resolve); + + const [outInherits, outEmpty, outUnrelated] = await Promise.all([ + evaluateGate(orgInherits, inheritsDeps.deps), + evaluateGate(orgEmpty, emptyDeps.deps), + evaluateGate(orgUnrelated, unrelatedDeps.deps), + ]); + + expect(outInherits.action).toBe("pass_through"); + expect(outEmpty.action).toBe("pass_through"); + expect(outUnrelated.action).toBe("pass_through"); + // None of these reached the evaluator (flag off, shadow off). + expect(inheritsDeps.spies.evaluatorCalls).toBe(0); + expect(emptyDeps.spies.evaluatorCalls).toBe(0); + expect(unrelatedDeps.spies.evaluatorCalls).toBe(0); + }); +}); diff --git a/apps/webapp/test/mollifierTripEvaluator.test.ts b/apps/webapp/test/mollifierTripEvaluator.test.ts new file mode 100644 index 00000000000..b9a9bf8c94a --- /dev/null +++ b/apps/webapp/test/mollifierTripEvaluator.test.ts @@ -0,0 +1,90 @@ +import { redisTest } from "@internal/testcontainers"; +import { MollifierBuffer } from "@trigger.dev/redis-worker"; +import { describe, expect, vi } from "vitest"; +import { createRealTripEvaluator } from "~/v3/mollifier/mollifierTripEvaluator.server"; + +vi.setConfig({ testTimeout: 30_000 }); + +// Use a real MollifierBuffer backed by a Redis testcontainer — repo policy +// is no mocks for Redis. Per-test envIds keep keys disjoint without explicit +// cleanup. We close() the buffer in a finally to release the client. +const inputs = { envId: "env_a", orgId: "org_1", taskId: "t1" } as const; + +describe("createRealTripEvaluator", () => { + redisTest( + "returns divert=false when the sliding window stays under threshold", + async ({ redisOptions }) => { + const buffer = new MollifierBuffer({ redisOptions, entryTtlSeconds: 600 }); + try { + const evaluator = createRealTripEvaluator({ + getBuffer: () => buffer, + options: () => ({ windowMs: 1000, threshold: 100, holdMs: 500 }), + }); + + const decision = await evaluator({ ...inputs, envId: "env_under" }); + expect(decision).toEqual({ divert: false }); + } finally { + await buffer.close(); + } + }, + ); + + redisTest( + "returns divert=true with reason per_env_rate once the window trips", + async ({ redisOptions }) => { + const buffer = new MollifierBuffer({ redisOptions, entryTtlSeconds: 600 }); + try { + // threshold=2 → the 3rd call within windowMs is the first that trips. + const options = { windowMs: 5000, threshold: 2, holdMs: 5000 } as const; + const evaluator = createRealTripEvaluator({ + getBuffer: () => buffer, + options: () => options, + }); + + const envId = "env_trip"; + await evaluator({ ...inputs, envId }); + await evaluator({ ...inputs, envId }); + const decision = await evaluator({ ...inputs, envId }); + + expect(decision.divert).toBe(true); + if (decision.divert) { + expect(decision.reason).toBe("per_env_rate"); + expect(decision.threshold).toBe(options.threshold); + expect(decision.windowMs).toBe(options.windowMs); + expect(decision.holdMs).toBe(options.holdMs); + expect(decision.count).toBeGreaterThan(options.threshold); + } + } finally { + await buffer.close(); + } + }, + ); + + redisTest("returns divert=false when getBuffer returns null (fail-open)", async () => { + const evaluator = createRealTripEvaluator({ + getBuffer: () => null, + options: () => ({ windowMs: 200, threshold: 100, holdMs: 500 }), + }); + + const decision = await evaluator(inputs); + expect(decision).toEqual({ divert: false }); + }); + + redisTest( + "returns divert=false when buffer throws (fail-open)", + async ({ redisOptions }) => { + const buffer = new MollifierBuffer({ redisOptions, entryTtlSeconds: 600 }); + // Closing the client up front means evaluateTrip will throw on the first + // Redis command — a real failure mode, not a stub. + await buffer.close(); + + const evaluator = createRealTripEvaluator({ + getBuffer: () => buffer, + options: () => ({ windowMs: 200, threshold: 100, holdMs: 500 }), + }); + + const decision = await evaluator(inputs); + expect(decision).toEqual({ divert: false }); + }, + ); +}); diff --git a/apps/webapp/test/setup.ts b/apps/webapp/test/setup.ts new file mode 100644 index 00000000000..607ad78f3a9 --- /dev/null +++ b/apps/webapp/test/setup.ts @@ -0,0 +1,6 @@ +// Load apps/webapp/.env into process.env so env.server's top-level +// EnvironmentSchema.parse(process.env) succeeds in vitest workers. +import { config } from "dotenv"; +import path from "node:path"; + +config({ path: path.resolve(__dirname, "../.env") }); diff --git a/apps/webapp/vitest.config.ts b/apps/webapp/vitest.config.ts index 66f697706a5..6a6b550fc64 100644 --- a/apps/webapp/vitest.config.ts +++ b/apps/webapp/vitest.config.ts @@ -10,6 +10,7 @@ export default defineConfig({ exclude: ["test/**/*.e2e.test.ts", "test/**/*.e2e.full.test.ts"], globals: true, pool: "forks", + setupFiles: ["./test/setup.ts"], // load apps/webapp/.env }, // @ts-ignore plugins: [tsconfigPaths({ projects: ["./tsconfig.json"] })], diff --git a/packages/redis-worker/src/index.ts b/packages/redis-worker/src/index.ts index 1c5147ea48d..e5e3db32f12 100644 --- a/packages/redis-worker/src/index.ts +++ b/packages/redis-worker/src/index.ts @@ -4,3 +4,4 @@ export * from "./utils.js"; // Fair Queue System export * from "./fair-queue/index.js"; +export * from "./mollifier/index.js"; diff --git a/packages/redis-worker/src/mollifier/buffer.test.ts b/packages/redis-worker/src/mollifier/buffer.test.ts new file mode 100644 index 00000000000..c8f7b95c97a --- /dev/null +++ b/packages/redis-worker/src/mollifier/buffer.test.ts @@ -0,0 +1,1027 @@ +import { describe, expect, it } from "vitest"; +import { BufferEntrySchema, serialiseSnapshot, deserialiseSnapshot } from "./schemas.js"; +import { redisTest } from "@internal/testcontainers"; +import { Logger } from "@trigger.dev/core/logger"; +import { MollifierBuffer } from "./buffer.js"; + +describe("schemas", () => { + it("serialiseSnapshot then deserialiseSnapshot is identity for plain objects", () => { + const snapshot = { taskId: "my-task", payload: { foo: 42, bar: "baz" } }; + const round = deserialiseSnapshot(serialiseSnapshot(snapshot)); + expect(round).toEqual(snapshot); + }); + + it("BufferEntrySchema parses a complete entry", () => { + const raw = { + runId: "run_abc", + envId: "env_1", + orgId: "org_1", + payload: serialiseSnapshot({ taskId: "t" }), + status: "QUEUED", + attempts: "0", + createdAt: "2026-05-11T10:00:00.000Z", + }; + const parsed = BufferEntrySchema.parse(raw); + expect(parsed.runId).toBe("run_abc"); + expect(parsed.status).toBe("QUEUED"); + expect(parsed.attempts).toBe(0); + expect(parsed.createdAt).toBeInstanceOf(Date); + }); + + it("BufferEntrySchema parses a FAILED entry with lastError", () => { + const raw = { + runId: "run_abc", + envId: "env_1", + orgId: "org_1", + payload: serialiseSnapshot({}), + status: "FAILED", + attempts: "3", + createdAt: "2026-05-11T10:00:00.000Z", + lastError: JSON.stringify({ code: "P2024", message: "connection lost" }), + }; + const parsed = BufferEntrySchema.parse(raw); + expect(parsed.lastError).toEqual({ code: "P2024", message: "connection lost" }); + }); +}); + +describe("MollifierBuffer construction", () => { + redisTest("constructs and closes cleanly", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + await buffer.close(); + }); +}); + +describe("MollifierBuffer.accept", () => { + redisTest("accept writes entry, enqueues, and tracks env", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + await buffer.accept({ + runId: "run_1", + envId: "env_a", + orgId: "org_1", + payload: serialiseSnapshot({ taskId: "t" }), + }); + + const entry = await buffer.getEntry("run_1"); + expect(entry).not.toBeNull(); + expect(entry!.runId).toBe("run_1"); + expect(entry!.envId).toBe("env_a"); + expect(entry!.orgId).toBe("org_1"); + expect(entry!.status).toBe("QUEUED"); + expect(entry!.attempts).toBe(0); + expect(entry!.createdAt).toBeInstanceOf(Date); + + const envs = await buffer.listEnvsForOrg("org_1"); + expect(envs).toContain("env_a"); + } finally { + await buffer.close(); + } + }); +}); + +describe("MollifierBuffer.pop", () => { + redisTest("pop returns next QUEUED entry and transitions to DRAINING", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + await buffer.accept({ runId: "run_1", envId: "env_a", orgId: "org_1", payload: "{}" }); + await buffer.accept({ runId: "run_2", envId: "env_a", orgId: "org_1", payload: "{}" }); + + const popped = await buffer.pop("env_a"); + expect(popped).not.toBeNull(); + expect(popped!.runId).toBe("run_1"); + expect(popped!.status).toBe("DRAINING"); + + const stored = await buffer.getEntry("run_1"); + expect(stored!.status).toBe("DRAINING"); + } finally { + await buffer.close(); + } + }); + + redisTest("pop returns null when env queue is empty", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + const popped = await buffer.pop("env_nonexistent"); + expect(popped).toBeNull(); + } finally { + await buffer.close(); + } + }); + + redisTest("atomic RPOP across two parallel pops on the same env", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + await buffer.accept({ runId: "only", envId: "env_a", orgId: "org_1", payload: "{}" }); + + const [a, b] = await Promise.all([buffer.pop("env_a"), buffer.pop("env_a")]); + const winners = [a, b].filter((x) => x !== null); + expect(winners).toHaveLength(1); + expect(winners[0]!.runId).toBe("only"); + } finally { + await buffer.close(); + } + }); +}); + +describe("MollifierBuffer.ack", () => { + redisTest("ack deletes the entry", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + await buffer.accept({ runId: "run_x", envId: "env_a", orgId: "org_1", payload: "{}" }); + await buffer.pop("env_a"); + await buffer.ack("run_x"); + + const after = await buffer.getEntry("run_x"); + expect(after).toBeNull(); + } finally { + await buffer.close(); + } + }); +}); + +describe("MollifierBuffer.pop orphan handling", () => { + redisTest( + "pop skips orphan queue references (runId in queue but entry hash expired)", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + // Simulate a TTL-expired orphan: queue ref exists, entry hash does not. + await buffer["redis"].lpush("mollifier:queue:env_a", "run_orphan"); + + const popped = await buffer.pop("env_a"); + expect(popped).toBeNull(); + + // Critical: no partial hash was created for the orphan. + const raw = await buffer["redis"].hgetall("mollifier:entries:run_orphan"); + expect(Object.keys(raw)).toHaveLength(0); + + // Queue is drained — the loop pops orphans until empty. + const qLen = await buffer["redis"].llen("mollifier:queue:env_a"); + expect(qLen).toBe(0); + } finally { + await buffer.close(); + } + }, + ); + + redisTest( + "pop skips orphans then returns the first valid entry behind them", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + // Layout (oldest-first, since RPOP takes from tail): orphan, valid, orphan. + // LPUSH puts items at the head, so to get RPOP order [orphan_a, valid, orphan_b] + // we LPUSH in reverse: orphan_b first, then valid, then orphan_a. + await buffer["redis"].lpush("mollifier:queue:env_a", "orphan_b"); + await buffer.accept({ runId: "valid", envId: "env_a", orgId: "org_1", payload: "{}" }); + await buffer["redis"].lpush("mollifier:queue:env_a", "orphan_a"); + + const popped = await buffer.pop("env_a"); + expect(popped).not.toBeNull(); + expect(popped!.runId).toBe("valid"); + expect(popped!.status).toBe("DRAINING"); + + // The trailing orphan_b is still in the queue (single pop call). + const remaining = await buffer["redis"].llen("mollifier:queue:env_a"); + expect(remaining).toBe(1); + + // A second pop drains the trailing orphan_b. The queue is now + // empty. NOTE: the pop's no-runId branch can't read orgId from + // a popped entry (it never got one), so it doesn't prune the + // org-envs SET. env_a remains in `mollifier:org-envs:org_1` as + // a stale entry until the next accept-or-success-pop cycle + // recovers it. This is the deliberate trade-off documented in + // popAndMarkDraining's Lua. + const second = await buffer.pop("env_a"); + expect(second).toBeNull(); + } finally { + await buffer.close(); + } + }, + ); +}); + +describe("MollifierBuffer.requeue", () => { + redisTest("requeue increments attempts, restores QUEUED, re-LPUSHes", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + await buffer.accept({ runId: "run_r", envId: "env_a", orgId: "org_1", payload: "{}" }); + await buffer.pop("env_a"); + await buffer.requeue("run_r"); + + const entry = await buffer.getEntry("run_r"); + expect(entry!.status).toBe("QUEUED"); + expect(entry!.attempts).toBe(1); + + const popped = await buffer.pop("env_a"); + expect(popped!.runId).toBe("run_r"); + } finally { + await buffer.close(); + } + }); +}); + +describe("MollifierBuffer.fail", () => { + redisTest("fail transitions to FAILED and stores lastError", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + await buffer.accept({ runId: "run_f", envId: "env_a", orgId: "org_1", payload: "{}" }); + await buffer.pop("env_a"); + const failed = await buffer.fail("run_f", { code: "VALIDATION", message: "boom" }); + expect(failed).toBe(true); + + const entry = await buffer.getEntry("run_f"); + expect(entry!.status).toBe("FAILED"); + expect(entry!.lastError).toEqual({ code: "VALIDATION", message: "boom" }); + } finally { + await buffer.close(); + } + }); + + redisTest( + "fail on missing entry is a no-op (returns false; no partial hash created)", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + const result = await buffer.fail("run_ghost", { code: "VALIDATION", message: "boom" }); + expect(result).toBe(false); + + // Critical: no partial entry hash was created. + const stored = await buffer.getEntry("run_ghost"); + expect(stored).toBeNull(); + const raw = await buffer["redis"].hgetall("mollifier:entries:run_ghost"); + expect(Object.keys(raw)).toHaveLength(0); + } finally { + await buffer.close(); + } + }, + ); +}); + +describe("MollifierBuffer TTL", () => { + redisTest("entry has TTL applied on accept", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + await buffer.accept({ runId: "run_t", envId: "env_a", orgId: "org_1", payload: "{}" }); + + const ttl = await buffer.getEntryTtlSeconds("run_t"); + expect(ttl).toBeGreaterThan(0); + expect(ttl).toBeLessThanOrEqual(600); + } finally { + await buffer.close(); + } + }); +}); + +describe("MollifierBuffer payload encoding", () => { + redisTest( + "pop round-trips payloads with quotes, backslashes, control chars, unicode", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + const tricky = { + quotes: 'a"b\'c', + backslash: "x\\y\\z", + newlines: "line1\nline2\r\nline3", + tab: "col1\tcol2", + unicode: "héllo 🦀 世界", + lineSep: "before
after
end", + nested: { arr: ["a", "b", 1, true, null], n: 3.14 }, + }; + const payload = serialiseSnapshot(tricky); + + try { + await buffer.accept({ runId: "tricky", envId: "env_a", orgId: "org_1", payload }); + + const popped = await buffer.pop("env_a"); + expect(popped).not.toBeNull(); + expect(popped!.payload).toBe(payload); + + const decoded = JSON.parse(popped!.payload); + expect(decoded).toEqual(tricky); + } finally { + await buffer.close(); + } + }, + ); +}); + +describe("MollifierBuffer.requeue on missing entry", () => { + redisTest( + "requeue on a non-existent runId is a no-op (Lua returns 0; no queue push)", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + await buffer.requeue("run_does_not_exist"); + + // Critical: no queue keys were created from this no-op requeue. + const queueKeys = await buffer["redis"].keys("mollifier:queue:*"); + expect(queueKeys).toHaveLength(0); + const envs = await buffer.listEnvsForOrg("org_1"); + expect(envs).toHaveLength(0); + } finally { + await buffer.close(); + } + }, + ); +}); + +describe("MollifierBuffer.requeue ordering", () => { + redisTest( + "requeued entry is popped AFTER other queued entries on the same env (FIFO retry)", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + await buffer.accept({ runId: "a", envId: "env_a", orgId: "org_1", payload: "{}" }); + await buffer.accept({ runId: "b", envId: "env_a", orgId: "org_1", payload: "{}" }); + await buffer.accept({ runId: "c", envId: "env_a", orgId: "org_1", payload: "{}" }); + + const first = await buffer.pop("env_a"); + expect(first!.runId).toBe("a"); + + await buffer.requeue("a"); + + const next = await buffer.pop("env_a"); + expect(next!.runId).toBe("b"); + const after = await buffer.pop("env_a"); + expect(after!.runId).toBe("c"); + const last = await buffer.pop("env_a"); + expect(last!.runId).toBe("a"); + } finally { + await buffer.close(); + } + }, + ); +}); + +describe("MollifierBuffer.evaluateTrip", () => { + const tripOptions = { + windowMs: 200, + threshold: 5, + holdMs: 100, + }; + + redisTest("under threshold: not tripped, count increments", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + const r1 = await buffer.evaluateTrip("env_a", tripOptions); + expect(r1).toEqual({ tripped: false, count: 1 }); + + const r2 = await buffer.evaluateTrip("env_a", tripOptions); + expect(r2).toEqual({ tripped: false, count: 2 }); + } finally { + await buffer.close(); + } + }); + + redisTest("crossing threshold sets the tripped marker", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + for (let i = 0; i < 5; i++) { + const r = await buffer.evaluateTrip("env_a", tripOptions); + expect(r.tripped).toBe(false); + } + + const after = await buffer.evaluateTrip("env_a", tripOptions); + expect(after).toEqual({ tripped: true, count: 6 }); + + const sticky = await buffer.evaluateTrip("env_a", tripOptions); + expect(sticky.tripped).toBe(true); + } finally { + await buffer.close(); + } + }); + + redisTest("hold-down marker expires after holdMs and env resets", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + const fastWindow = { windowMs: 100, threshold: 2, holdMs: 100 }; + await buffer.evaluateTrip("env_a", fastWindow); + await buffer.evaluateTrip("env_a", fastWindow); + const tripped = await buffer.evaluateTrip("env_a", fastWindow); + expect(tripped.tripped).toBe(true); + + // Wait past windowMs AND holdMs so both rate counter and tripped marker expire + await new Promise((r) => setTimeout(r, 220)); + + const recovered = await buffer.evaluateTrip("env_a", fastWindow); + expect(recovered).toEqual({ tripped: false, count: 1 }); + } finally { + await buffer.close(); + } + }); + + redisTest("env isolation: tripping env_a does not affect env_b", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + for (let i = 0; i < 6; i++) { + await buffer.evaluateTrip("env_a", tripOptions); + } + const aTripped = await buffer.evaluateTrip("env_a", tripOptions); + expect(aTripped.tripped).toBe(true); + + const b = await buffer.evaluateTrip("env_b", tripOptions); + expect(b).toEqual({ tripped: false, count: 1 }); + } finally { + await buffer.close(); + } + }); + + redisTest("window expires and counter resets when no traffic", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + const fastWindow = { windowMs: 100, threshold: 100, holdMs: 100 }; + await buffer.evaluateTrip("env_x", fastWindow); + await buffer.evaluateTrip("env_x", fastWindow); + // both incremented within a fresh window — count should be 2 + + await new Promise((r) => setTimeout(r, 150)); + const fresh = await buffer.evaluateTrip("env_x", fastWindow); + expect(fresh.count).toBe(1); + } finally { + await buffer.close(); + } + }); + + redisTest( + "tripped marker outlives the rate counter window", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + const opts = { windowMs: 50, threshold: 2, holdMs: 1000 }; + await buffer.evaluateTrip("env_a", opts); + await buffer.evaluateTrip("env_a", opts); + const tripped = await buffer.evaluateTrip("env_a", opts); + expect(tripped.tripped).toBe(true); + + // Wait past windowMs (rate counter expires) but well inside holdMs (marker persists). + await new Promise((r) => setTimeout(r, 120)); + + const after = await buffer.evaluateTrip("env_a", opts); + expect(after.tripped).toBe(true); + expect(after.count).toBeLessThanOrEqual(2); + } finally { + await buffer.close(); + } + }, + ); + + redisTest( + "INCR is atomic under 100 concurrent calls (no lost increments)", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + // Wide window so all 100 calls land in the same window. High threshold + // so trip semantics don't interfere with the count assertion. + const opts = { windowMs: 5000, threshold: 1_000_000, holdMs: 100 }; + const results = await Promise.all( + Array.from({ length: 100 }, () => buffer.evaluateTrip("env_atomic", opts)), + ); + + // Every return value is unique (no two callers saw the same INCR result). + const counts = results.map((r) => r.count).sort((a, b) => a - b); + expect(counts).toEqual(Array.from({ length: 100 }, (_, i) => i + 1)); + + // No call tripped (we set threshold absurdly high). + expect(results.every((r) => !r.tripped)).toBe(true); + } finally { + await buffer.close(); + } + }, + ); +}); + +describe("MollifierBuffer entry lifecycle invariants", () => { + redisTest( + "entry TTL is preserved across pop (DRAINING entries don't lose their TTL)", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + await buffer.accept({ runId: "run_ttl", envId: "env_a", orgId: "org_1", payload: "{}" }); + const beforeTtl = await buffer.getEntryTtlSeconds("run_ttl"); + expect(beforeTtl).toBeGreaterThan(0); + + await buffer.pop("env_a"); + const afterTtl = await buffer.getEntryTtlSeconds("run_ttl"); + + // TTL must still be present (>0). Redis returns -1 if the key has no + // TTL — that's the leak shape we're guarding against. + expect(afterTtl).toBeGreaterThan(0); + expect(afterTtl).toBeLessThanOrEqual(beforeTtl); + } finally { + await buffer.close(); + } + }, + ); + + redisTest( + "envs set membership tracks queue+DRAINING presence across the full lifecycle", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + // Empty start + expect(await buffer.listEnvsForOrg("org_1")).not.toContain("env_lc"); + + // accept → SADD + await buffer.accept({ runId: "r1", envId: "env_lc", orgId: "org_1", payload: "{}" }); + expect(await buffer.listEnvsForOrg("org_1")).toContain("env_lc"); + + // second accept (different runId) → still SADD (idempotent) + await buffer.accept({ runId: "r2", envId: "env_lc", orgId: "org_1", payload: "{}" }); + expect(await buffer.listEnvsForOrg("org_1")).toContain("env_lc"); + + // pop r1 → queue still has r2 → env stays + await buffer.pop("env_lc"); + expect(await buffer.listEnvsForOrg("org_1")).toContain("env_lc"); + + // ack r1 → no queue change, env still tracked (r2 still queued) + await buffer.ack("r1"); + expect(await buffer.listEnvsForOrg("org_1")).toContain("env_lc"); + + // pop r2 → queue empties → SREM + await buffer.pop("env_lc"); + expect(await buffer.listEnvsForOrg("org_1")).not.toContain("env_lc"); + + // requeue r2 → SADD back + await buffer.requeue("r2"); + expect(await buffer.listEnvsForOrg("org_1")).toContain("env_lc"); + + // fail r2 → entry FAILED but queue empty → next pop should SREM + await buffer.pop("env_lc"); + await buffer.fail("r2", { code: "X", message: "boom" }); + const afterFailEnvs = await buffer.listEnvsForOrg("org_1"); + // Queue is empty, env was SREM'd by the pop above. + expect(afterFailEnvs).not.toContain("env_lc"); + } finally { + await buffer.close(); + } + }, + ); +}); + +describe("MollifierBuffer.accept idempotency", () => { + redisTest( + "duplicate runId is refused; queue not double-LPUSHed; existing entry not overwritten", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + const first = await buffer.accept({ + runId: "run_dup", + envId: "env_a", + orgId: "org_1", + payload: serialiseSnapshot({ first: true }), + }); + const second = await buffer.accept({ + runId: "run_dup", + envId: "env_a", + orgId: "org_1", + payload: serialiseSnapshot({ first: false }), + }); + + expect(first).toBe(true); + expect(second).toBe(false); + + // First payload preserved; second was a no-op. + const stored = await buffer.getEntry("run_dup"); + expect(stored).not.toBeNull(); + const decoded = JSON.parse(stored!.payload); + expect(decoded).toEqual({ first: true }); + + // Exactly one queue entry, not two. + const popped1 = await buffer.pop("env_a"); + expect(popped1).not.toBeNull(); + expect(popped1!.runId).toBe("run_dup"); + const popped2 = await buffer.pop("env_a"); + expect(popped2).toBeNull(); + } finally { + await buffer.close(); + } + }, + ); + + redisTest( + "accept refused while existing entry is DRAINING", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + await buffer.accept({ runId: "run_dr", envId: "env_a", orgId: "org_1", payload: "{}" }); + await buffer.pop("env_a"); // now DRAINING + const stored = await buffer.getEntry("run_dr"); + expect(stored!.status).toBe("DRAINING"); + + const dup = await buffer.accept({ runId: "run_dr", envId: "env_a", orgId: "org_1", payload: "{}" }); + expect(dup).toBe(false); + + const afterDup = await buffer.getEntry("run_dr"); + expect(afterDup!.status).toBe("DRAINING"); // unchanged + } finally { + await buffer.close(); + } + }, + ); + + redisTest( + "accept refused while existing entry is FAILED", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + await buffer.accept({ runId: "run_fl", envId: "env_a", orgId: "org_1", payload: "{}" }); + await buffer.pop("env_a"); + await buffer.fail("run_fl", { code: "VALIDATION", message: "boom" }); + const stored = await buffer.getEntry("run_fl"); + expect(stored!.status).toBe("FAILED"); + + const dup = await buffer.accept({ runId: "run_fl", envId: "env_a", orgId: "org_1", payload: "{}" }); + expect(dup).toBe(false); + + const afterDup = await buffer.getEntry("run_fl"); + expect(afterDup!.status).toBe("FAILED"); // unchanged + expect(afterDup!.lastError).toEqual({ code: "VALIDATION", message: "boom" }); + } finally { + await buffer.close(); + } + }, + ); + + redisTest( + "re-accept after ack works (terminal entry can be re-accepted)", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + const first = await buffer.accept({ + runId: "run_x", + envId: "env_a", + orgId: "org_1", + payload: "{}", + }); + await buffer.pop("env_a"); + await buffer.ack("run_x"); + + // Entry is gone — re-accept should succeed. + const reAccept = await buffer.accept({ + runId: "run_x", + envId: "env_a", + orgId: "org_1", + payload: "{}", + }); + + expect(first).toBe(true); + expect(reAccept).toBe(true); + } finally { + await buffer.close(); + } + }, + ); +}); + +describe("MollifierBuffer envs set lifecycle", () => { + redisTest( + "pop SREMs envId when it drains the queue", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + await buffer.accept({ runId: "r1", envId: "env_a", orgId: "org_1", payload: "{}" }); + expect(await buffer.listEnvsForOrg("org_1")).toContain("env_a"); + + await buffer.pop("env_a"); + expect(await buffer.listEnvsForOrg("org_1")).not.toContain("env_a"); + } finally { + await buffer.close(); + } + }, + ); + + redisTest( + "pop keeps envId in set while items remain; SREMs only on the draining pop", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + await buffer.accept({ runId: "r1", envId: "env_a", orgId: "org_1", payload: "{}" }); + await buffer.accept({ runId: "r2", envId: "env_a", orgId: "org_1", payload: "{}" }); + expect(await buffer.listEnvsForOrg("org_1")).toContain("env_a"); + + await buffer.pop("env_a"); + expect(await buffer.listEnvsForOrg("org_1")).toContain("env_a"); + + await buffer.pop("env_a"); + expect(await buffer.listEnvsForOrg("org_1")).not.toContain("env_a"); + } finally { + await buffer.close(); + } + }, + ); + + redisTest( + "requeue re-SADDs the envId if pop had previously cleaned it", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + entryTtlSeconds: 600, + logger: new Logger("test", "log"), + }); + + try { + await buffer.accept({ runId: "r1", envId: "env_a", orgId: "org_1", payload: "{}" }); + await buffer.pop("env_a"); + // Queue drained → env_a SREM'd. + expect(await buffer.listEnvsForOrg("org_1")).not.toContain("env_a"); + + await buffer.requeue("r1"); + // requeue must put env_a back so the drainer notices the retry. + expect(await buffer.listEnvsForOrg("org_1")).toContain("env_a"); + } finally { + await buffer.close(); + } + }, + ); +}); diff --git a/packages/redis-worker/src/mollifier/buffer.ts b/packages/redis-worker/src/mollifier/buffer.ts new file mode 100644 index 00000000000..f739e3ff362 --- /dev/null +++ b/packages/redis-worker/src/mollifier/buffer.ts @@ -0,0 +1,399 @@ +import { + createRedisClient, + type Callback, + type Redis, + type RedisOptions, + type Result, +} from "@internal/redis"; +import { Logger } from "@trigger.dev/core/logger"; +import { BufferEntry, BufferEntrySchema } from "./schemas.js"; + +export type MollifierBufferOptions = { + redisOptions: RedisOptions; + entryTtlSeconds: number; + logger?: Logger; +}; + +export class MollifierBuffer { + private readonly redis: Redis; + private readonly entryTtlSeconds: number; + private readonly logger: Logger; + + constructor(options: MollifierBufferOptions) { + this.entryTtlSeconds = options.entryTtlSeconds; + this.logger = options.logger ?? new Logger("MollifierBuffer", "debug"); + + this.redis = createRedisClient( + { + ...options.redisOptions, + retryStrategy(times) { + const delay = Math.min(times * 50, 1000); + return delay; + }, + maxRetriesPerRequest: 20, + }, + { + onError: (error) => { + this.logger.error("MollifierBuffer redis client error:", { error }); + }, + }, + ); + this.#registerCommands(); + } + + // Returns true if the entry was newly written; false if a duplicate runId + // was already buffered (idempotent no-op). Callers can use the boolean to + // record a duplicate-accept metric without affecting buffer state. + async accept(input: { + runId: string; + envId: string; + orgId: string; + payload: string; + }): Promise { + const entryKey = `mollifier:entries:${input.runId}`; + const queueKey = `mollifier:queue:${input.envId}`; + const orgsKey = "mollifier:orgs"; + const createdAt = new Date().toISOString(); + const result = await this.redis.acceptMollifierEntry( + entryKey, + queueKey, + orgsKey, + input.runId, + input.envId, + input.orgId, + input.payload, + createdAt, + String(this.entryTtlSeconds), + "mollifier:org-envs:", + ); + return result === 1; + } + + async pop(envId: string): Promise { + const queueKey = `mollifier:queue:${envId}`; + const orgsKey = "mollifier:orgs"; + const entryPrefix = "mollifier:entries:"; + const encoded = (await this.redis.popAndMarkDraining( + queueKey, + orgsKey, + entryPrefix, + envId, + "mollifier:org-envs:", + )) as string | null; + if (!encoded) return null; + + let raw: unknown; + try { + raw = JSON.parse(encoded); + } catch { + this.logger.error("MollifierBuffer.pop: failed to parse script result", { envId }); + return null; + } + + const parsed = BufferEntrySchema.safeParse(raw); + if (!parsed.success) { + this.logger.error("MollifierBuffer.pop: invalid entry shape", { + envId, + errors: parsed.error.flatten(), + }); + return null; + } + return parsed.data; + } + + async getEntry(runId: string): Promise { + const raw = await this.redis.hgetall(`mollifier:entries:${runId}`); + if (!raw || Object.keys(raw).length === 0) return null; + + const parsed = BufferEntrySchema.safeParse(raw); + if (!parsed.success) { + this.logger.error("MollifierBuffer.getEntry: invalid entry shape", { + runId, + errors: parsed.error.flatten(), + }); + return null; + } + return parsed.data; + } + + // Drainer walks these two methods to schedule pops with org-level + // fairness: one env per org per tick. The Lua scripts maintain both + // sets atomically with the per-env queues, so an org/env appears here + // exactly when at least one of its envs has a queued entry. + async listOrgs(): Promise { + return this.redis.smembers("mollifier:orgs"); + } + + async listEnvsForOrg(orgId: string): Promise { + return this.redis.smembers(`mollifier:org-envs:${orgId}`); + } + + async ack(runId: string): Promise { + await this.redis.del(`mollifier:entries:${runId}`); + } + + async requeue(runId: string): Promise { + await this.redis.requeueMollifierEntry( + `mollifier:entries:${runId}`, + "mollifier:orgs", + "mollifier:queue:", + runId, + "mollifier:org-envs:", + ); + } + + // Returns true if the entry transitioned to FAILED; false if the entry no + // longer exists (TTL expired between pop and fail). Caller can use the + // boolean to skip downstream FAILED handling for ghost entries. + async fail(runId: string, error: { code: string; message: string }): Promise { + const result = await this.redis.failMollifierEntry( + `mollifier:entries:${runId}`, + JSON.stringify(error), + ); + return result === 1; + } + + async getEntryTtlSeconds(runId: string): Promise { + return this.redis.ttl(`mollifier:entries:${runId}`); + } + + async evaluateTrip( + envId: string, + options: { windowMs: number; threshold: number; holdMs: number }, + ): Promise<{ tripped: boolean; count: number }> { + const rateKey = `mollifier:rate:${envId}`; + const trippedKey = `mollifier:tripped:${envId}`; + const result = (await this.redis.mollifierEvaluateTrip( + rateKey, + trippedKey, + String(options.windowMs), + String(options.threshold), + String(options.holdMs), + )) as [number, number]; + + return { count: result[0], tripped: result[1] === 1 }; + } + + async close(): Promise { + await this.redis.quit(); + } + + #registerCommands() { + this.redis.defineCommand("acceptMollifierEntry", { + numberOfKeys: 3, + lua: ` + local entryKey = KEYS[1] + local queueKey = KEYS[2] + local orgsKey = KEYS[3] + local runId = ARGV[1] + local envId = ARGV[2] + local orgId = ARGV[3] + local payload = ARGV[4] + local createdAt = ARGV[5] + local ttlSeconds = tonumber(ARGV[6]) + local orgEnvsPrefix = ARGV[7] + + -- Idempotent: refuse if an entry for this runId already exists in any + -- state. Caller-side dedup is also enforced via API idempotency keys, + -- but the buffer must not double-enqueue if a caller retries. + if redis.call('EXISTS', entryKey) == 1 then + return 0 + end + + redis.call('HSET', entryKey, + 'runId', runId, + 'envId', envId, + 'orgId', orgId, + 'payload', payload, + 'status', 'QUEUED', + 'attempts', '0', + 'createdAt', createdAt) + redis.call('EXPIRE', entryKey, ttlSeconds) + redis.call('LPUSH', queueKey, runId) + -- Org-level membership: maintained atomically with the per-env + -- queue so the drainer can walk orgs → envs-for-org and + -- schedule one env per org per tick. SADDs are idempotent if the + -- org/env are already tracked. + redis.call('SADD', orgsKey, orgId) + redis.call('SADD', orgEnvsPrefix .. orgId, envId) + return 1 + `, + }); + + this.redis.defineCommand("requeueMollifierEntry", { + numberOfKeys: 2, + lua: ` + local entryKey = KEYS[1] + local orgsKey = KEYS[2] + local queuePrefix = ARGV[1] + local runId = ARGV[2] + local orgEnvsPrefix = ARGV[3] + + local envId = redis.call('HGET', entryKey, 'envId') + local orgId = redis.call('HGET', entryKey, 'orgId') + if not envId then + return 0 + end + + local currentAttempts = redis.call('HGET', entryKey, 'attempts') + local nextAttempts = tonumber(currentAttempts or '0') + 1 + + redis.call('HSET', entryKey, 'status', 'QUEUED', 'attempts', tostring(nextAttempts)) + redis.call('LPUSH', queuePrefix .. envId, runId) + -- Re-track the org/env: pop may have SREM'd them when the queue + -- last emptied. SADDs are idempotent if the values are still + -- present. + if orgId then + redis.call('SADD', orgsKey, orgId) + redis.call('SADD', orgEnvsPrefix .. orgId, envId) + end + return 1 + `, + }); + + this.redis.defineCommand("popAndMarkDraining", { + numberOfKeys: 2, + lua: ` + local queueKey = KEYS[1] + local orgsKey = KEYS[2] + local entryPrefix = ARGV[1] + local envId = ARGV[2] + local orgEnvsPrefix = ARGV[3] + + -- Helper: prune org-level membership when an env's queue empties. + -- Called only from the success branch where we know orgId from the + -- popped entry. The no-runId branch below can't reach this because + -- it has no entry to read orgId from — accept any stale org-envs + -- entries that result (bounded by env count, recovered next accept). + local function pruneOrgMembership(orgId) + if not orgId then return end + local orgEnvsKey = orgEnvsPrefix .. orgId + redis.call('SREM', orgEnvsKey, envId) + if redis.call('SCARD', orgEnvsKey) == 0 then + redis.call('SREM', orgsKey, orgId) + end + end + + -- Loop to skip orphan queue references — runIds whose entry hash has + -- expired (TTL hit). HSET on a missing key would CREATE a partial + -- hash without a TTL, leaking memory. The loop is bounded by queue + -- length; entire Lua script remains atomic. + while true do + local runId = redis.call('RPOP', queueKey) + if not runId then + -- Queue is empty AND we have no entry to read orgId from, so + -- skip org-level cleanup. Stale org-envs entries are bounded + -- by env count and recovered on the next accept. + return nil + end + + local entryKey = entryPrefix .. runId + if redis.call('EXISTS', entryKey) == 1 then + redis.call('HSET', entryKey, 'status', 'DRAINING') + local raw = redis.call('HGETALL', entryKey) + local result = {} + for i = 1, #raw, 2 do + result[raw[i]] = raw[i + 1] + end + -- Prune org-level membership if this pop drained the queue. + -- Atomic with the RPOP above — a concurrent accept AFTER this + -- script will SADD both back along with its LPUSH. + if redis.call('LLEN', queueKey) == 0 then + pruneOrgMembership(result['orgId']) + end + return cjson.encode(result) + end + -- Orphan queue reference: entry TTL expired while runId was queued. + -- Discard the reference and loop to the next. + end + `, + }); + + this.redis.defineCommand("failMollifierEntry", { + numberOfKeys: 1, + lua: ` + local entryKey = KEYS[1] + local errorPayload = ARGV[1] + + -- Guard: never create a partial entry. If the hash expired between + -- pop and fail, the run is gone — nothing to mark FAILED. + if redis.call('EXISTS', entryKey) == 0 then + return 0 + end + + redis.call('HSET', entryKey, 'status', 'FAILED', 'lastError', errorPayload) + return 1 + `, + }); + + this.redis.defineCommand("mollifierEvaluateTrip", { + numberOfKeys: 2, + lua: ` + local rateKey = KEYS[1] + local trippedKey = KEYS[2] + local windowMs = tonumber(ARGV[1]) + local threshold = tonumber(ARGV[2]) + local holdMs = tonumber(ARGV[3]) + + local count = redis.call('INCR', rateKey) + if count == 1 then + redis.call('PEXPIRE', rateKey, windowMs) + end + + if count > threshold then + redis.call('PSETEX', trippedKey, holdMs, '1') + end + + local tripped = redis.call('EXISTS', trippedKey) + return {count, tripped} + `, + }); + } +} + +declare module "@internal/redis" { + interface RedisCommander { + acceptMollifierEntry( + entryKey: string, + queueKey: string, + orgsKey: string, + runId: string, + envId: string, + orgId: string, + payload: string, + createdAt: string, + ttlSeconds: string, + orgEnvsPrefix: string, + callback?: Callback, + ): Result; + popAndMarkDraining( + queueKey: string, + orgsKey: string, + entryPrefix: string, + envId: string, + orgEnvsPrefix: string, + callback?: Callback, + ): Result; + requeueMollifierEntry( + entryKey: string, + orgsKey: string, + queuePrefix: string, + runId: string, + orgEnvsPrefix: string, + callback?: Callback, + ): Result; + failMollifierEntry( + entryKey: string, + errorPayload: string, + callback?: Callback, + ): Result; + mollifierEvaluateTrip( + rateKey: string, + trippedKey: string, + windowMs: string, + threshold: string, + holdMs: string, + callback?: Callback<[number, number]>, + ): Result<[number, number], Context>; + } +} diff --git a/packages/redis-worker/src/mollifier/drainer.test.ts b/packages/redis-worker/src/mollifier/drainer.test.ts new file mode 100644 index 00000000000..c8f68977f69 --- /dev/null +++ b/packages/redis-worker/src/mollifier/drainer.test.ts @@ -0,0 +1,1322 @@ +import { redisTest } from "@internal/testcontainers"; +import { describe, expect, it } from "vitest"; +import { Logger } from "@trigger.dev/core/logger"; +import { MollifierBuffer } from "./buffer.js"; +import { MollifierDrainer } from "./drainer.js"; +import { serialiseSnapshot } from "./schemas.js"; + +const noopOptions = { + entryTtlSeconds: 600, + logger: new Logger("test", "log"), +}; + +// Module-scope stub helpers used by the unit tests below (no real Redis). +type StubBuffer = Partial & { [K in keyof MollifierBuffer]?: any }; + +function makeStubBuffer(overrides: StubBuffer): MollifierBuffer { + const base: StubBuffer = { + listOrgs: async () => [], + listEnvsForOrg: async () => [], + pop: async () => null, + ack: async () => {}, + requeue: async () => {}, + fail: async () => true, + getEntry: async () => null, + close: async () => {}, + }; + return { ...base, ...overrides } as unknown as MollifierBuffer; +} + +// Convenience for tests that don't care about org grouping: treat each +// env as its own org. `listOrgs` returns the env list verbatim; +// `listEnvsForOrg(envId)` returns `[envId]`. Spread into makeStubBuffer +// alongside the test's own `pop` override. +function eachEnvAsOwnOrg(envs: string[]): Partial { + return { + listOrgs: async () => envs, + listEnvsForOrg: async (orgId: string) => (envs.includes(orgId) ? [orgId] : []), + }; +} + +describe("MollifierDrainer.runOnce", () => { + redisTest("drains one queued entry through the handler and acks", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + ...noopOptions, + }); + + const handlerCalls: Array<{ runId: string; envId: string; orgId: string; payload: unknown }> = + []; + const handler = async (input: { + runId: string; + envId: string; + orgId: string; + payload: unknown; + }) => { + handlerCalls.push(input); + }; + const drainer = new MollifierDrainer({ + buffer, + handler, + concurrency: 5, + maxAttempts: 3, + isRetryable: () => false, + logger: new Logger("test-drainer", "log"), + }); + + try { + await buffer.accept({ + runId: "run_1", + envId: "env_a", + orgId: "org_1", + payload: serialiseSnapshot({ foo: 1 }), + }); + + const result = await drainer.runOnce(); + expect(result.drained).toBe(1); + expect(result.failed).toBe(0); + expect(handlerCalls).toHaveLength(1); + expect(handlerCalls[0]).toMatchObject({ + runId: "run_1", + envId: "env_a", + orgId: "org_1", + payload: { foo: 1 }, + }); + + const entry = await buffer.getEntry("run_1"); + expect(entry).toBeNull(); + } finally { + await buffer.close(); + } + }); + + redisTest("runOnce with no entries does nothing", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + ...noopOptions, + }); + + let handlerCalls = 0; + const handler = async () => { + handlerCalls++; + }; + const drainer = new MollifierDrainer({ + buffer, + handler, + concurrency: 5, + maxAttempts: 3, + isRetryable: () => false, + logger: new Logger("test-drainer", "log"), + }); + + try { + const result = await drainer.runOnce(); + expect(result.drained).toBe(0); + expect(result.failed).toBe(0); + expect(handlerCalls).toBe(0); + } finally { + await buffer.close(); + } + }); +}); + +describe("MollifierDrainer error handling", () => { + redisTest("retryable error requeues and increments attempts", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + ...noopOptions, + }); + + let calls = 0; + const handler = async () => { + calls++; + throw new Error("transient"); + }; + + const drainer = new MollifierDrainer({ + buffer, + handler, + concurrency: 1, + maxAttempts: 3, + isRetryable: () => true, + logger: new Logger("test-drainer", "log"), + }); + + try { + await buffer.accept({ runId: "run_r", envId: "env_a", orgId: "org_1", payload: "{}" }); + + await drainer.runOnce(); + const after1 = await buffer.getEntry("run_r"); + expect(after1!.status).toBe("QUEUED"); + expect(after1!.attempts).toBe(1); + + await drainer.runOnce(); + const after2 = await buffer.getEntry("run_r"); + expect(after2!.status).toBe("QUEUED"); + expect(after2!.attempts).toBe(2); + + await drainer.runOnce(); + const after3 = await buffer.getEntry("run_r"); + expect(after3!.status).toBe("FAILED"); + expect(calls).toBe(3); + } finally { + await buffer.close(); + } + }); + + redisTest("non-retryable error transitions directly to FAILED", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + ...noopOptions, + }); + + const handler = async () => { + throw new Error("validation failure"); + }; + + const drainer = new MollifierDrainer({ + buffer, + handler, + concurrency: 1, + maxAttempts: 3, + isRetryable: () => false, + logger: new Logger("test-drainer", "log"), + }); + + try { + await buffer.accept({ runId: "run_nr", envId: "env_a", orgId: "org_1", payload: "{}" }); + + await drainer.runOnce(); + + const entry = await buffer.getEntry("run_nr"); + expect(entry!.status).toBe("FAILED"); + expect(entry!.lastError).toEqual({ code: "Error", message: "validation failure" }); + } finally { + await buffer.close(); + } + }); + + redisTest( + "multi-org round-robin: drains one item per org per runOnce", + { timeout: 20_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + ...noopOptions, + }); + + const handled: string[] = []; + const handler = async (input: { runId: string }) => { + handled.push(input.runId); + }; + + const drainer = new MollifierDrainer({ + buffer, + handler, + concurrency: 10, + maxAttempts: 3, + isRetryable: () => false, + logger: new Logger("test-drainer", "log"), + }); + + try { + // org_A has two envs (env_a, env_b) → drainer picks one per tick + // via the per-org env cursor. org_B has one env (env_c) → it's + // always picked when org_B is in the slice. + await buffer.accept({ runId: "a1", envId: "env_a", orgId: "org_A", payload: "{}" }); + await buffer.accept({ runId: "b1", envId: "env_b", orgId: "org_A", payload: "{}" }); + await buffer.accept({ runId: "c1", envId: "env_c", orgId: "org_B", payload: "{}" }); + + // Tick 1: 2 orgs in slice → 2 pops, one from org_A's rotating env + // pick and one from org_B's only env. + const r1 = await drainer.runOnce(); + expect(r1.drained).toBe(2); + expect(handled).toContain("c1"); + // Org_A contributed exactly one of {a1, b1}. + const orgADrainedTick1 = handled.filter((h) => h === "a1" || h === "b1"); + expect(orgADrainedTick1).toHaveLength(1); + + handled.length = 0; + // Tick 2: org_B's queue is empty (only had 1 entry, drained tick 1). + // listOrgs returns [org_A] only. Drain the remaining org_A env. + const r2 = await drainer.runOnce(); + expect(r2.drained).toBe(1); + expect(handled).toHaveLength(1); + expect(["a1", "b1"]).toContain(handled[0]); + } finally { + await buffer.close(); + } + }, + ); +}); + +// Transient Redis errors used to permanently kill the loop because +// `processOneFromEnv` didn't catch `buffer.pop()` rejections — the error +// bubbled through `Promise.all` → `runOnce` → `loop`'s outer catch and +// left `isRunning = false`. These tests use a stubbed buffer (no Redis +// container) so we can deterministically inject failures from `listEnvs` +// and `pop` without racing against a real client. +describe("MollifierDrainer resilience to transient buffer errors", () => { + it("survives a transient listOrgs failure and resumes draining", async () => { + let listCalls = 0; + const popped: string[] = []; + const buffer = makeStubBuffer({ + listOrgs: async () => { + listCalls += 1; + if (listCalls === 1) { + throw new Error("simulated redis blip"); + } + return ["env_a"]; + }, + listEnvsForOrg: async (orgId: string) => (orgId === "env_a" ? ["env_a"] : []), + pop: async () => { + const runId = `run_${popped.length + 1}`; + if (popped.length >= 2) return null; + popped.push(runId); + return { + runId, + envId: "env_a", + orgId: "org_1", + payload: "{}", + attempts: 0, + createdAt: new Date(), + } as any; + }, + }); + + const handled: string[] = []; + const drainer = new MollifierDrainer({ + buffer, + handler: async (input) => { + handled.push(input.runId); + }, + concurrency: 1, + maxAttempts: 3, + isRetryable: () => false, + pollIntervalMs: 20, + logger: new Logger("test-drainer", "log"), + }); + + drainer.start(); + const deadline = Date.now() + 3_000; + while (handled.length < 2 && Date.now() < deadline) { + await new Promise((r) => setTimeout(r, 20)); + } + await drainer.stop({ timeoutMs: 1_000 }); + + expect(handled).toEqual(["run_1", "run_2"]); + expect(listCalls).toBeGreaterThan(1); + }); + + it("a pop failure for one env doesn't poison the rest of the batch", async () => { + const buffer = makeStubBuffer({ + ...eachEnvAsOwnOrg(["bad", "good"]), + pop: async (envId: string) => { + if (envId === "bad") { + throw new Error("simulated pop failure on bad env"); + } + return { + runId: "run_good", + envId: "good", + orgId: "org_1", + payload: "{}", + attempts: 0, + createdAt: new Date(), + } as any; + }, + }); + + const handled: string[] = []; + const drainer = new MollifierDrainer({ + buffer, + handler: async (input) => { + handled.push(input.runId); + }, + concurrency: 5, + maxAttempts: 3, + isRetryable: () => false, + logger: new Logger("test-drainer", "log"), + }); + + const result = await drainer.runOnce(); + expect(result.drained).toBe(1); + expect(result.failed).toBe(1); + expect(handled).toEqual(["run_good"]); + }); + + it("a requeue failure during retry recovery doesn't poison the rest of the batch", async () => { + // Regression: handler throws a retryable error → processEntry calls + // buffer.requeue() inside its catch block. If requeue() itself throws + // (Redis blip during error recovery), the rejection used to escape + // processOneFromEnv unwrapped and reject the runOnce Promise.all, + // dropping handler results from sibling envs in the same tick. + const handled: string[] = []; + const buffer = makeStubBuffer({ + ...eachEnvAsOwnOrg(["bad", "good"]), + pop: async (envId: string) => + ({ + runId: envId === "bad" ? "run_bad" : "run_good", + envId, + orgId: "org_1", + payload: "{}", + attempts: 0, + createdAt: new Date(), + }) as any, + requeue: async () => { + throw new Error("simulated requeue failure"); + }, + }); + + const drainer = new MollifierDrainer({ + buffer, + handler: async (input) => { + handled.push(input.runId); + if (input.runId === "run_bad") throw new Error("transient"); + }, + concurrency: 5, + maxAttempts: 3, + isRetryable: () => true, + logger: new Logger("test-drainer", "log"), + }); + + const result = await drainer.runOnce(); + // Two envs scheduled, one handler succeeded (drained), one handler threw + // and its recovery requeue threw too — counted as failed, batch not poisoned. + expect(result.drained).toBe(1); + expect(result.failed).toBe(1); + expect(new Set(handled)).toEqual(new Set(["run_bad", "run_good"])); + }); + + it("a fail() throw during terminal recovery doesn't poison the rest of the batch", async () => { + // Regression: handler throws a non-retryable error → processEntry calls + // buffer.fail() inside its catch block. If fail() itself throws, the + // rejection used to escape unwrapped and reject runOnce's Promise.all. + const handled: string[] = []; + const buffer = makeStubBuffer({ + ...eachEnvAsOwnOrg(["bad", "good"]), + pop: async (envId: string) => + ({ + runId: envId === "bad" ? "run_bad" : "run_good", + envId, + orgId: "org_1", + payload: "{}", + attempts: 0, + createdAt: new Date(), + }) as any, + fail: async () => { + throw new Error("simulated fail() failure"); + }, + }); + + const drainer = new MollifierDrainer({ + buffer, + handler: async (input) => { + handled.push(input.runId); + if (input.runId === "run_bad") throw new Error("terminal"); + }, + concurrency: 5, + maxAttempts: 3, + isRetryable: () => false, + logger: new Logger("test-drainer", "log"), + }); + + const result = await drainer.runOnce(); + expect(result.drained).toBe(1); + expect(result.failed).toBe(1); + expect(new Set(handled)).toEqual(new Set(["run_bad", "run_good"])); + }); +}); + +describe("MollifierDrainer per-tick org cap", () => { + // Bounding fan-out prevents one runOnce from queuing thousands of + // processOneFromEnv jobs when the org set is unexpectedly large. + // These tests use a stub buffer so we can drive the org/env counts + // deterministically without provisioning a real Redis with thousands + // of envs. + + it("processes at most maxOrgsPerTick envs per runOnce", async () => { + const allEnvs = Array.from({ length: 20 }, (_, i) => `env_${i}`); + const popped: string[] = []; + const buffer = makeStubBuffer({ + ...eachEnvAsOwnOrg(allEnvs), + pop: async (envId: string) => { + popped.push(envId); + return null; // empty queue — runOnce records this as "empty" + }, + }); + + const drainer = new MollifierDrainer({ + buffer, + handler: async () => {}, + concurrency: 5, + maxAttempts: 3, + isRetryable: () => false, + maxOrgsPerTick: 5, + logger: new Logger("test-drainer", "log"), + }); + + await drainer.runOnce(); + expect(popped).toHaveLength(5); + }); + + it("covers the full env set across `envs.length` ticks when sliced", async () => { + const allEnvs = Array.from({ length: 12 }, (_, i) => `env_${i}`); + const popped: string[] = []; + const buffer = makeStubBuffer({ + ...eachEnvAsOwnOrg(allEnvs), + pop: async (envId: string) => { + popped.push(envId); + return null; + }, + }); + + const drainer = new MollifierDrainer({ + buffer, + handler: async () => {}, + concurrency: 4, + maxAttempts: 3, + isRetryable: () => false, + maxOrgsPerTick: 4, + logger: new Logger("test-drainer", "log"), + }); + + // Cursor advances by 1 each tick. Over envs.length ticks every env + // appears in exactly `sliceSize` of them (slices overlap — intentional, + // see the head-of-line fairness test below). + for (let i = 0; i < allEnvs.length; i++) { + await drainer.runOnce(); + } + + expect(new Set(popped)).toEqual(new Set(allEnvs)); + expect(popped).toHaveLength(allEnvs.length * 4); // envs.length × sliceSize + const perEnvCounts = popped.reduce>((acc, e) => { + acc[e] = (acc[e] ?? 0) + 1; + return acc; + }, {}); + for (const env of allEnvs) { + expect(perEnvCounts[env]).toBe(4); + } + }); + + it("preserves head-of-line fairness when sliced: every env reaches every slice position", async () => { + // Regression test for the bias that advance-by-sliceSize would + // reintroduce. With fixed disjoint slices, env_0 would always be at + // position 0 (first into pLimit) and env_(sliceSize-1) would always + // be last. Advance-by-1 spreads each env across every slot. + const allEnvs = Array.from({ length: 8 }, (_, i) => `env_${i}`); + const sliceSize = 4; + const positionsByEnv = new Map>(); + for (const env of allEnvs) positionsByEnv.set(env, new Set()); + + let currentTick: string[] = []; + const popOrderBuffer = makeStubBuffer({ + ...eachEnvAsOwnOrg(allEnvs), + pop: async (envId: string) => { + currentTick.push(envId); + return null; + }, + }); + + const drainer = new MollifierDrainer({ + buffer: popOrderBuffer, + handler: async () => {}, + // Concurrency >= sliceSize so pLimit doesn't reorder — pop call order + // matches the slice's scheduling order (i.e. the env's slot position). + concurrency: sliceSize, + maxAttempts: 3, + isRetryable: () => false, + maxOrgsPerTick: sliceSize, + logger: new Logger("test-drainer", "log"), + }); + + for (let tick = 0; tick < allEnvs.length; tick++) { + currentTick = []; + await drainer.runOnce(); + currentTick.forEach((env, position) => { + positionsByEnv.get(env)!.add(position); + }); + } + + // Each env should have occupied every slot 0..sliceSize-1 across the + // cycle. If we'd regressed to advance-by-sliceSize, env_0 would only + // ever be at position 0 and env_3 only at position 3. + for (const env of allEnvs) { + const positions = positionsByEnv.get(env)!; + expect(positions.size).toBe(sliceSize); + for (let p = 0; p < sliceSize; p++) { + expect(positions.has(p)).toBe(true); + } + } + }); + + it("takes all envs and rotates by 1 when the set fits within the cap", async () => { + const allEnvs = ["env_a", "env_b", "env_c"]; + const popsPerTick: string[][] = []; + let tick: string[] = []; + const buffer = makeStubBuffer({ + ...eachEnvAsOwnOrg(allEnvs), + pop: async (envId: string) => { + tick.push(envId); + return null; + }, + }); + + const drainer = new MollifierDrainer({ + buffer, + handler: async () => {}, + concurrency: 3, + maxAttempts: 3, + isRetryable: () => false, + maxOrgsPerTick: 100, // way above n + logger: new Logger("test-drainer", "log"), + }); + + for (let i = 0; i < 3; i++) { + tick = []; + await drainer.runOnce(); + popsPerTick.push(tick); + } + + // Every tick covers every env (because cap > n), but the head-of-line + // env rotates by 1 each tick — preserves the original fairness behaviour. + for (const popped of popsPerTick) { + expect(new Set(popped)).toEqual(new Set(allEnvs)); + } + const [tick0, tick1, tick2] = popsPerTick; + expect(tick0?.[0]).not.toEqual(tick1?.[0]); + expect(tick1?.[0]).not.toEqual(tick2?.[0]); + }); + + it("a light env is not starved behind heavy envs", async () => { + // The buffer's atomic Lua removes an env from `mollifier:envs` the + // moment its queue becomes empty, so a heavy env with thousands of + // pending entries stays in listEnvs and a light env with a single + // entry only stays until that one entry pops. Combined with the + // advance-by-1 cursor, this means the light env can't be parked + // behind heavy envs indefinitely — it gets popped within at most + // `envs.length - sliceSize + 1` ticks regardless of how many + // entries the heavy envs have queued. + const heavy = Array.from({ length: 6 }, (_, i) => `env_heavy_${i}`); + const light = "env_light"; + const queues = new Map(); + for (const h of heavy) { + queues.set( + h, + Array.from({ length: 100 }, (_, i) => `${h}_run_${i}`), + ); + } + queues.set(light, [`${light}_run_0`]); + + const activeEnvs = () => + [...queues.keys()].filter((k) => (queues.get(k)?.length ?? 0) > 0); + const buffer = makeStubBuffer({ + listOrgs: async () => activeEnvs(), + listEnvsForOrg: async (orgId: string) => + activeEnvs().includes(orgId) ? [orgId] : [], + pop: async (envId: string) => { + const q = queues.get(envId); + if (!q || q.length === 0) return null; + const runId = q.shift()!; + return { + runId, + envId, + orgId: "org_1", + payload: "{}", + status: "DRAINING", + attempts: 0, + createdAt: new Date(), + } as any; + }, + }); + + const drainer = new MollifierDrainer({ + buffer, + handler: async () => {}, + concurrency: 4, + maxAttempts: 3, + isRetryable: () => false, + maxOrgsPerTick: 4, // < 7 envs so we exercise slicing + logger: new Logger("test-drainer", "log"), + }); + + // 7 envs, sliceSize=4 → worst-case wait for env_light is 4 ticks + // (it appears in the slice in exactly 4 of every 7 ticks). Run 7 to + // give the upper bound a wide margin. + const ticksUntilLightDrained = await (async () => { + for (let tick = 1; tick <= 7; tick++) { + await drainer.runOnce(); + if ((queues.get(light)?.length ?? 0) === 0) return tick; + } + return Infinity; + })(); + + expect(ticksUntilLightDrained).toBeLessThanOrEqual(4); + // Sanity: heavy envs are being worked on (not starved themselves) but + // are far from drained — confirms we measured the right property. + for (const h of heavy) { + const remaining = queues.get(h)!.length; + expect(remaining).toBeGreaterThan(0); + expect(remaining).toBeLessThan(100); + } + }); + + it("a light org is not starved behind a heavy org with many envs", async () => { + // Org-level no-starvation: org_B's single entry drains within ~1 + // tick because the drainer walks orgs at the top level. Org_A + // having many envs doesn't give it extra rotation slots. + const orgAEnvs = Array.from({ length: 6 }, (_, i) => `env_orgA_${i}`); + const orgBEnv = "env_orgB_only"; + const envOrg = new Map(); + for (const e of orgAEnvs) envOrg.set(e, "org_A"); + envOrg.set(orgBEnv, "org_B"); + const queues = new Map>(); + for (const e of orgAEnvs) { + queues.set( + e, + Array.from({ length: 100 }, (_, i) => ({ + runId: `${e}_run_${i}`, + orgId: "org_A", + })), + ); + } + queues.set(orgBEnv, [{ runId: `${orgBEnv}_run_0`, orgId: "org_B" }]); + + const drainedByOrg: Record = { org_A: 0, org_B: 0 }; + const buffer = makeStubBuffer({ + listOrgs: async () => { + const orgs = new Set(); + for (const [envId, items] of queues.entries()) { + if (items.length > 0) orgs.add(envOrg.get(envId)!); + } + return [...orgs]; + }, + listEnvsForOrg: async (orgId: string) => { + const envs: string[] = []; + for (const [envId, items] of queues.entries()) { + if (items.length > 0 && envOrg.get(envId) === orgId) envs.push(envId); + } + return envs; + }, + pop: async (envId: string) => { + const q = queues.get(envId); + if (!q || q.length === 0) return null; + const entry = q.shift()!; + return { + runId: entry.runId, + envId, + orgId: entry.orgId, + payload: "{}", + status: "DRAINING", + attempts: 0, + createdAt: new Date(), + } as any; + }, + }); + + const drainer = new MollifierDrainer({ + buffer, + handler: async (input) => { + drainedByOrg[input.orgId] = (drainedByOrg[input.orgId] ?? 0) + 1; + }, + concurrency: 4, + maxAttempts: 3, + isRetryable: () => false, + maxOrgsPerTick: 4, + logger: new Logger("test-drainer", "log"), + }); + + // Only 2 orgs in play → both are in every tick's slice. Org_B's + // single env is popped on tick 1. + const ticksUntilOrgBDrained = await (async () => { + for (let tick = 1; tick <= 7; tick++) { + await drainer.runOnce(); + if ((drainedByOrg["org_B"] ?? 0) > 0) return tick; + } + return Infinity; + })(); + + expect(ticksUntilOrgBDrained).toBe(1); + // Sanity: org_A is being drained too (not starved itself) but its many + // envs are far from empty. + expect(drainedByOrg["org_A"]).toBeGreaterThan(0); + for (const e of orgAEnvs) { + expect(queues.get(e)!.length).toBeGreaterThan(0); + } + }); + + it("a heavy org with many envs gets ~1 slot per tick, not N slots", async () => { + // Hierarchical rotation property: an org with N envs gets the SAME + // per-tick scheduling slot as an org with 1 env, instead of N slots + // (which is what per-env rotation would give). Sustained-run drainage + // rate is therefore determined by org count, not env count. + // + // Org_A: 6 envs × 100 entries (a noisy tenant). + // Org_B: 1 env × 100 entries (a quiet tenant). + // Per-env rotation would drain org_A 6× faster than org_B. The org- + // level walk via listOrgs → listEnvsForOrg drains them at ~1:1 over + // a sustained window. + const orgAEnvs = Array.from({ length: 6 }, (_, i) => `env_orgA_${i}`); + const orgBEnv = "env_orgB_only"; + const envOrg = new Map(); + for (const e of orgAEnvs) envOrg.set(e, "org_A"); + envOrg.set(orgBEnv, "org_B"); + const queues = new Map>(); + for (const e of orgAEnvs) { + queues.set( + e, + Array.from({ length: 100 }, (_, i) => ({ + runId: `${e}_run_${i}`, + orgId: "org_A", + })), + ); + } + queues.set( + orgBEnv, + Array.from({ length: 100 }, (_, i) => ({ + runId: `${orgBEnv}_run_${i}`, + orgId: "org_B", + })), + ); + + const drainedByOrg: Record = { org_A: 0, org_B: 0 }; + const buffer = makeStubBuffer({ + listOrgs: async () => { + const orgs = new Set(); + for (const [envId, items] of queues.entries()) { + if (items.length > 0) orgs.add(envOrg.get(envId)!); + } + return [...orgs]; + }, + listEnvsForOrg: async (orgId: string) => { + const envs: string[] = []; + for (const [envId, items] of queues.entries()) { + if (items.length > 0 && envOrg.get(envId) === orgId) envs.push(envId); + } + return envs; + }, + pop: async (envId: string) => { + const q = queues.get(envId); + if (!q || q.length === 0) return null; + const entry = q.shift()!; + return { + runId: entry.runId, + envId, + orgId: entry.orgId, + payload: "{}", + status: "DRAINING", + attempts: 0, + createdAt: new Date(), + } as any; + }, + }); + + const drainer = new MollifierDrainer({ + buffer, + handler: async (input) => { + drainedByOrg[input.orgId] = (drainedByOrg[input.orgId] ?? 0) + 1; + }, + concurrency: 10, + maxAttempts: 3, + isRetryable: () => false, + maxOrgsPerTick: 100, // unsliced — every org gets a slot every tick + logger: new Logger("test-drainer", "log"), + }); + + for (let i = 0; i < 20; i++) { + await drainer.runOnce(); + } + + // Under per-env rotation, drainedByOrg.org_A would be ~6× larger than + // drainedByOrg.org_B. Under hierarchical, the ratio is ~1. + expect(drainedByOrg["org_A"]).toBeGreaterThan(0); + expect(drainedByOrg["org_B"]).toBeGreaterThan(0); + const ratio = drainedByOrg["org_A"]! / drainedByOrg["org_B"]!; + expect(ratio).toBeGreaterThan(0.7); + expect(ratio).toBeLessThan(1.5); + }); + + it("within an org, envs are rotated round-robin across ticks", async () => { + // An org with N envs picks one env per tick, cycling through its + // envs via the per-org env cursor. Inner cursor advances by 1 per + // visit to the org (analogous to head-of-line fairness within a + // slice, but at the env-within-org layer). + const orgEnvs = ["env_x", "env_y", "env_z"]; + const orgId = "org_solo"; + const queues = new Map(); + for (const e of orgEnvs) queues.set(e, 100); + + const poppedSequence: string[] = []; + const buffer = makeStubBuffer({ + listOrgs: async () => { + const anyEnvActive = [...queues.values()].some((n) => n > 0); + return anyEnvActive ? [orgId] : []; + }, + listEnvsForOrg: async (org: string) => + org === orgId + ? [...queues.keys()].filter((k) => (queues.get(k) ?? 0) > 0) + : [], + pop: async (envId: string) => { + const remaining = queues.get(envId) ?? 0; + if (remaining === 0) return null; + queues.set(envId, remaining - 1); + poppedSequence.push(envId); + return { + runId: `${envId}_${remaining}`, + envId, + orgId, + payload: "{}", + status: "DRAINING", + attempts: 0, + createdAt: new Date(), + } as any; + }, + }); + + const drainer = new MollifierDrainer({ + buffer, + handler: async () => {}, + concurrency: 1, + maxAttempts: 3, + isRetryable: () => false, + maxOrgsPerTick: 100, + logger: new Logger("test-drainer", "log"), + }); + + // 6 ticks × 1 env per tick = 6 pops, cycling x, y, z, x, y, z. Every + // env should be picked exactly twice across the 6 ticks. + for (let i = 0; i < 6; i++) { + await drainer.runOnce(); + } + + expect(poppedSequence).toHaveLength(6); + const counts = poppedSequence.reduce>((acc, e) => { + acc[e] = (acc[e] ?? 0) + 1; + return acc; + }, {}); + for (const env of orgEnvs) { + expect(counts[env]).toBe(2); + } + }); +}); + +describe("MollifierDrainer additional coverage", () => { + + it("a malformed payload is treated as a non-retryable handler error and goes terminal", async () => { + // The deserialise call lives inside processEntry's try, so a JSON parse + // failure is caught by the same handler-error branch. With + // isRetryable=false, the entry transitions directly to FAILED — the + // handler is never invoked because the throw happens before the + // handler call. + let handlerCalled = false; + const failedEntries: Array<{ runId: string; error: { code: string; message: string } }> = []; + const buffer = makeStubBuffer({ + ...eachEnvAsOwnOrg(["env_a"]), + pop: async () => + ({ + runId: "run_malformed", + envId: "env_a", + orgId: "org_1", + payload: "not valid json {", + status: "DRAINING", + attempts: 0, + createdAt: new Date(), + }) as any, + fail: async (runId: string, error: { code: string; message: string }) => { + failedEntries.push({ runId, error }); + return true; + }, + }); + + const drainer = new MollifierDrainer({ + buffer, + handler: async () => { + handlerCalled = true; + }, + concurrency: 1, + maxAttempts: 3, + isRetryable: () => false, + logger: new Logger("test-drainer", "log"), + }); + + const result = await drainer.runOnce(); + + expect(handlerCalled).toBe(false); + expect(result.failed).toBe(1); + expect(result.drained).toBe(0); + expect(failedEntries).toHaveLength(1); + expect(failedEntries[0]?.runId).toBe("run_malformed"); + }); + + it("an ack failure after a successful handler is currently treated as a handler error (documented behaviour)", async () => { + // CAVEAT: this pins a known behaviour gap, not the ideal behaviour. + // ack() lives inside the same try as the handler call, so if the + // handler succeeds but ack throws (e.g. transient Redis blip), the + // entry is routed through the retry/terminal path even though the + // handler-side work completed. Phase 2's engine-replay handler will + // need idempotency to absorb the re-execution this implies on retry, + // OR ack should be lifted out of the try block. + let handlerCalls = 0; + const failedEntries: string[] = []; + const buffer = makeStubBuffer({ + ...eachEnvAsOwnOrg(["env_a"]), + pop: async () => + ({ + runId: "run_x", + envId: "env_a", + orgId: "org_1", + payload: "{}", + status: "DRAINING", + attempts: 0, + createdAt: new Date(), + }) as any, + ack: async () => { + throw new Error("simulated ack failure"); + }, + fail: async (runId: string) => { + failedEntries.push(runId); + return true; + }, + }); + + const drainer = new MollifierDrainer({ + buffer, + handler: async () => { + handlerCalls += 1; + }, + concurrency: 1, + maxAttempts: 3, + isRetryable: () => false, + logger: new Logger("test-drainer", "log"), + }); + + await drainer.runOnce(); + + expect(handlerCalls).toBe(1); // handler did run + expect(failedEntries).toEqual(["run_x"]); // but entry was marked failed anyway + }); + + it("start() called twice does not spawn a second loop", async () => { + let listEnvsCalls = 0; + const buffer = makeStubBuffer({ + listOrgs: async () => { + listEnvsCalls += 1; + return []; + }, + }); + + const drainer = new MollifierDrainer({ + buffer, + handler: async () => {}, + concurrency: 1, + maxAttempts: 3, + isRetryable: () => false, + pollIntervalMs: 50, + logger: new Logger("test-drainer", "log"), + }); + + drainer.start(); + drainer.start(); // no-op + await new Promise((r) => setTimeout(r, 150)); + await drainer.stop({ timeoutMs: 500 }); + + // One loop's worth of polling, not two. Allow a small fudge for timing — + // a doubled loop would produce ~2x the calls in the same window. + expect(listEnvsCalls).toBeLessThan(10); + }); + + it("stop() is idempotent and safe to call when never started", async () => { + const buffer = makeStubBuffer({}); + const drainer = new MollifierDrainer({ + buffer, + handler: async () => {}, + concurrency: 1, + maxAttempts: 3, + isRetryable: () => false, + logger: new Logger("test-drainer", "log"), + }); + + // Never started. + await expect(drainer.stop()).resolves.toBeUndefined(); + + // Started then stopped twice. + drainer.start(); + await expect(drainer.stop()).resolves.toBeUndefined(); + await expect(drainer.stop()).resolves.toBeUndefined(); + }); + + it("rotation cursors reset on start() so a stop+start cycle begins fresh", async () => { + const allEnvs = ["env_a", "env_b", "env_c", "env_d", "env_e", "env_f"]; + const popLog: string[] = []; + const buffer = makeStubBuffer({ + ...eachEnvAsOwnOrg(allEnvs), + pop: async (envId: string) => { + popLog.push(envId); + return null; + }, + }); + + const drainer = new MollifierDrainer({ + buffer, + handler: async () => {}, + concurrency: 3, + maxAttempts: 3, + isRetryable: () => false, + maxOrgsPerTick: 3, + // Long sleep so the loop ticks exactly once between start() and stop(). + pollIntervalMs: 10_000, + logger: new Logger("test-drainer", "log"), + }); + + // Advance the cursor via runOnce so it's nonzero before start(). + await drainer.runOnce(); + await drainer.runOnce(); + popLog.length = 0; + + drainer.start(); + // Wait long enough for the loop's first tick to complete. + await new Promise((r) => setTimeout(r, 100)); + await drainer.stop({ timeoutMs: 1_000 }); + + // The first slice after start() should begin at envs[0] (cursor reset) + // — the slice is [env_a, env_b, env_c]. Without the reset, it would + // start at env_c (cursor was 2). + expect(popLog.slice(0, 3)).toEqual(["env_a", "env_b", "env_c"]); + }); + + it("loop backoff grows with consecutive runOnce failures and resets on success", async () => { + // The loop catches runOnce-level errors (e.g. listEnvs blip), increments + // `consecutiveErrors`, and delays for backoffMs(consecutiveErrors) — + // capped at 5s. This test pins the growth curve by failing N times in a + // row and observing increasing inter-tick gaps, then succeeding to + // verify the counter resets. + const tickTimestamps: number[] = []; + let listEnvsCalls = 0; + const buffer = makeStubBuffer({ + listOrgs: async () => { + listEnvsCalls += 1; + tickTimestamps.push(Date.now()); + if (listEnvsCalls <= 4) { + throw new Error("simulated sustained outage"); + } + return []; // success — resets consecutiveErrors + }, + }); + + const drainer = new MollifierDrainer({ + buffer, + handler: async () => {}, + concurrency: 1, + maxAttempts: 3, + isRetryable: () => false, + pollIntervalMs: 100, + logger: new Logger("test-drainer", "log"), + }); + + drainer.start(); + // Allow time for 4 failures + first success + a few subsequent successes. + // Backoff schedule on errors 1..4: 200ms, 400ms, 800ms, 1.6s ≈ 3s total + // worst case. Add headroom for jitter. + await new Promise((r) => setTimeout(r, 4_000)); + await drainer.stop({ timeoutMs: 1_000 }); + + expect(listEnvsCalls).toBeGreaterThanOrEqual(5); + // Inter-tick gaps during the failure run should grow (exponential). + const gap1 = tickTimestamps[1]! - tickTimestamps[0]!; + const gap2 = tickTimestamps[2]! - tickTimestamps[1]!; + const gap3 = tickTimestamps[3]! - tickTimestamps[2]!; + expect(gap2).toBeGreaterThan(gap1); + expect(gap3).toBeGreaterThan(gap2); + + // After the first success (tick 5), counter resets, so the gap between + // tick 5 and tick 6 should drop back to pollIntervalMs-ish — much + // smaller than gap3 (which was the longest backoff). + if (tickTimestamps.length >= 6) { + const postRecoveryGap = tickTimestamps[5]! - tickTimestamps[4]!; + expect(postRecoveryGap).toBeLessThan(gap3); + } + }); +}); + +describe("MollifierDrainer.start/stop", () => { + redisTest("start polls and processes, stop halts the loop", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + ...noopOptions, + }); + + const handled: string[] = []; + const handler = async (input: { runId: string }) => { + handled.push(input.runId); + }; + + const drainer = new MollifierDrainer({ + buffer, + handler, + concurrency: 5, + maxAttempts: 3, + isRetryable: () => false, + pollIntervalMs: 20, + logger: new Logger("test-drainer", "log"), + }); + + try { + await buffer.accept({ runId: "live_1", envId: "env_a", orgId: "org_1", payload: "{}" }); + await buffer.accept({ runId: "live_2", envId: "env_a", orgId: "org_1", payload: "{}" }); + + drainer.start(); + + const deadline = Date.now() + 5_000; + while (handled.length < 2 && Date.now() < deadline) { + await new Promise((r) => setTimeout(r, 50)); + } + + await drainer.stop(); + + expect(new Set(handled)).toEqual(new Set(["live_1", "live_2"])); + } finally { + await buffer.close(); + } + }); + + redisTest("stop returns after timeoutMs even if a handler is hung", { timeout: 20_000 }, async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + ...noopOptions, + }); + + let handlerStarted = false; + const handler = async () => { + handlerStarted = true; + await new Promise(() => {}); + }; + + const drainer = new MollifierDrainer({ + buffer, + handler, + concurrency: 1, + maxAttempts: 3, + isRetryable: () => false, + pollIntervalMs: 20, + logger: new Logger("test-drainer", "log"), + }); + + try { + await buffer.accept({ runId: "hung", envId: "env_a", orgId: "org_1", payload: "{}" }); + + drainer.start(); + + const deadline = Date.now() + 2_000; + while (!handlerStarted && Date.now() < deadline) { + await new Promise((r) => setTimeout(r, 25)); + } + expect(handlerStarted).toBe(true); + + const stopStart = Date.now(); + await drainer.stop({ timeoutMs: 500 }); + const stopElapsed = Date.now() - stopStart; + + // Allow a small jitter window below `timeoutMs` — Node's setTimeout can + // fire a millisecond or two early under CI load. The behaviour we're + // pinning is "stop honors the deadline instead of waiting for the hung + // handler indefinitely", not millisecond-precise timing. + expect(stopElapsed).toBeGreaterThanOrEqual(450); + expect(stopElapsed).toBeLessThan(2_000); + } finally { + await buffer.close(); + } + }); +}); + +describe("MollifierDrainer concurrency cap", () => { + redisTest( + "runOnce never exceeds configured concurrency in flight", + { timeout: 30_000 }, + async ({ redisContainer }) => { + const buffer = new MollifierBuffer({ + redisOptions: { + host: redisContainer.getHost(), + port: redisContainer.getPort(), + password: redisContainer.getPassword(), + }, + ...noopOptions, + }); + + const concurrency = 3; + const envCount = 12; + let inflight = 0; + let peak = 0; + let handlerCalls = 0; + const handler = async () => { + handlerCalls++; + inflight++; + if (inflight > peak) peak = inflight; + // Sleep long enough that handlers definitely overlap if scheduling + // allowed it — the assertion is meaningful only if multiple handlers + // would be running simultaneously without the cap. + await new Promise((r) => setTimeout(r, 75)); + inflight--; + }; + + const drainer = new MollifierDrainer({ + buffer, + handler, + concurrency, + maxAttempts: 1, + isRetryable: () => false, + logger: new Logger("test-drainer", "log"), + }); + + try { + // One entry per (env, org) so runOnce sees `envCount` distinct + // orgs as scheduling candidates and pLimits them through + // pLimit(concurrency). Spread across orgs (not envs in one org) + // because the drainer picks one env per org per tick — a single + // org with 12 envs would only see 1 pop per tick. + for (let i = 0; i < envCount; i++) { + await buffer.accept({ + runId: `run_${i}`, + envId: `env_${i}`, + orgId: `org_${i}`, + payload: "{}", + }); + } + + const result = await drainer.runOnce(); + expect(result.drained).toBe(envCount); + expect(handlerCalls).toBe(envCount); + expect(peak).toBeGreaterThan(1); // concurrency is real, not serialised + expect(peak).toBeLessThanOrEqual(concurrency); + } finally { + await buffer.close(); + } + }, + ); +}); diff --git a/packages/redis-worker/src/mollifier/drainer.ts b/packages/redis-worker/src/mollifier/drainer.ts new file mode 100644 index 00000000000..407b389e14e --- /dev/null +++ b/packages/redis-worker/src/mollifier/drainer.ts @@ -0,0 +1,289 @@ +import { Logger } from "@trigger.dev/core/logger"; +import pLimit from "p-limit"; +import { MollifierBuffer } from "./buffer.js"; +import { BufferEntry, deserialiseSnapshot } from "./schemas.js"; + +export type MollifierDrainerHandler = (input: { + runId: string; + envId: string; + orgId: string; + payload: TPayload; + attempts: number; + createdAt: Date; +}) => Promise; + +export type MollifierDrainerOptions = { + buffer: MollifierBuffer; + handler: MollifierDrainerHandler; + concurrency: number; + maxAttempts: number; + isRetryable: (err: unknown) => boolean; + pollIntervalMs?: number; + // Cap on how many ORGS `runOnce` processes per tick. The drainer rotates + // through orgs at the top level and picks one env per org per tick, so + // the actual per-tick pop count is at most `maxOrgsPerTick`. Tune for + // "typical orgs with pending entries" rather than total system org + // count. Defaults to 500. + // + // The buffer maintains `mollifier:orgs` and `mollifier:org-envs:${orgId}` + // atomically with per-env queues, so the drainer can walk orgs → envs + // directly. An org with N envs gets the same per-tick scheduling slot + // as an org with 1 env — tenant-level drainage throughput is determined + // by org count, not env count. + maxOrgsPerTick?: number; + logger?: Logger; +}; + +export type DrainResult = { + drained: number; + failed: number; +}; + +export class MollifierDrainer { + private readonly buffer: MollifierBuffer; + private readonly handler: MollifierDrainerHandler; + private readonly maxAttempts: number; + private readonly isRetryable: (err: unknown) => boolean; + private readonly pollIntervalMs: number; + private readonly maxOrgsPerTick: number; + private readonly logger: Logger; + private readonly limit: ReturnType; + // Rotation state. `orgCursor` advances through the active-orgs list. + // Each org has its own internal cursor in `perOrgEnvCursors` for + // cycling through that org's envs. Both reset on `start()`. + private orgCursor = 0; + private perOrgEnvCursors = new Map(); + private isRunning = false; + private stopping = false; + private loopPromise: Promise | null = null; + + constructor(options: MollifierDrainerOptions) { + this.buffer = options.buffer; + this.handler = options.handler; + this.maxAttempts = options.maxAttempts; + this.isRetryable = options.isRetryable; + this.pollIntervalMs = options.pollIntervalMs ?? 100; + this.maxOrgsPerTick = options.maxOrgsPerTick ?? 500; + this.logger = options.logger ?? new Logger("MollifierDrainer", "debug"); + this.limit = pLimit(options.concurrency); + } + + async runOnce(): Promise { + const orgs = await this.buffer.listOrgs(); + if (orgs.length === 0) return { drained: 0, failed: 0 }; + + const orgSlice = this.takeOrgSlice(orgs); + + // Fan the per-org SMEMBERS out in a single pipelined round-trip. Serial + // awaits would otherwise add `orgSlice.length × RTT` of dead time before + // pops start — at the default `maxOrgsPerTick=500` and a ~1ms ElastiCache + // RTT that's a ~500ms per-tick floor. ioredis auto-pipelines concurrent + // commands into one batch, so the burst is cheap; SMEMBERS on a small set + // is O(N) per org and trivial at this scale. `Promise.all` preserves + // order, so the org→envs pairing below stays deterministic. + const envsByOrg = await Promise.all( + orgSlice.map((orgId) => this.buffer.listEnvsForOrg(orgId)), + ); + const targets: string[] = []; + for (let i = 0; i < orgSlice.length; i++) { + const orgId = orgSlice[i]!; + const envsForOrg = envsByOrg[i]!; + if (envsForOrg.length === 0) continue; + const envId = this.pickEnvForOrg(orgId, envsForOrg); + targets.push(envId); + } + + const inflight: Promise<"drained" | "failed" | "empty">[] = []; + for (const envId of targets) { + inflight.push(this.limit(() => this.processOneFromEnv(envId))); + } + + const results = await Promise.all(inflight); + return { + drained: results.filter((r) => r === "drained").length, + failed: results.filter((r) => r === "failed").length, + }; + } + + start(): void { + if (this.isRunning) return; + this.isRunning = true; + this.stopping = false; + // Reset rotation state on each (re)start. A stop+start cycle means + // operator intent to "begin clean" — between-restart cursor drift + // would otherwise carry implicit state across what should look like + // a fresh boot. + this.orgCursor = 0; + this.perOrgEnvCursors = new Map(); + this.loopPromise = this.loop(); + } + + // Signal the loop to exit (`stopping = true`) and wait for it. With no + // timeout, wait indefinitely for the in-flight `runOnce` and its handlers + // to settle — same semantic as FairQueue / BatchQueue's `stop()`. With a + // timeout, race the loop promise against a deadline so a hung handler + // can't wedge the process past its termination grace period. + async stop(options: { timeoutMs?: number } = {}): Promise { + if (!this.isRunning || !this.loopPromise) return; + this.stopping = true; + if (options.timeoutMs == null) { + await this.loopPromise; + return; + } + // Hold the timer handle so we can clearTimeout() it after the race. + // Without this, when the loop wins the race, the discarded timer is + // still ref'd and pins the Node event loop for up to `timeoutMs`, + // delaying process shutdown by exactly the slack we were trying to + // bound. try/finally clears the handle in every exit path (loop-won, + // timeout-won, or exception). + const timeoutSentinel = Symbol("mollifier.stop.timeout"); + let timeoutHandle: ReturnType | undefined; + const timeoutPromise = new Promise((resolve) => { + timeoutHandle = setTimeout(() => resolve(timeoutSentinel), options.timeoutMs); + }); + try { + const winner = await Promise.race([ + this.loopPromise.then(() => "done" as const), + timeoutPromise, + ]); + if (winner === timeoutSentinel) { + this.logger.warn( + "MollifierDrainer.stop: deadline exceeded; returning while loop iteration is in flight", + { timeoutMs: options.timeoutMs }, + ); + } + } finally { + if (timeoutHandle) clearTimeout(timeoutHandle); + } + } + + // Transient Redis errors (e.g. a connection blip in `listOrgs` / + // `listEnvsForOrg` / `pop`) must not kill the polling loop permanently. + // We log each `runOnce` failure, back off so we don't spin tight on a + // sustained outage, and resume. The loop only exits when `stop()` flips + // `stopping`. + private async loop(): Promise { + try { + let consecutiveErrors = 0; + while (!this.stopping) { + try { + const result = await this.runOnce(); + consecutiveErrors = 0; + if (result.drained === 0 && result.failed === 0) { + await this.delay(this.pollIntervalMs); + } + } catch (err) { + consecutiveErrors += 1; + this.logger.error("MollifierDrainer.runOnce failed; backing off", { + err, + consecutiveErrors, + }); + await this.delay(this.backoffMs(consecutiveErrors)); + } + } + } finally { + this.isRunning = false; + } + } + + // Exponential backoff capped at 5s. Keeps the loop responsive after a + // brief blip while preventing a tight retry loop during a long Redis + // outage. 1 → 200ms, 2 → 400ms, 3 → 800ms, 4 → 1.6s, 5 → 3.2s, 6+ → 5s. + private backoffMs(consecutiveErrors: number): number { + const base = Math.max(this.pollIntervalMs, 100); + const capped = Math.min(base * 2 ** (consecutiveErrors - 1), 5_000); + return capped; + } + + private delay(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)); + } + + // Take up to `maxOrgsPerTick` orgs starting at the current cursor, with + // wrap-around. Cursor advances by 1 each tick so every org reaches + // every slot position (0..sliceSize-1) over a full cycle — no + // head-of-line bias within the slice. Orgs are sorted before slicing + // so rotation is deterministic regardless of Redis SET iteration order. + private takeOrgSlice(orgs: string[]): string[] { + const sorted = [...orgs].sort(); + const n = sorted.length; + const sliceSize = Math.min(this.maxOrgsPerTick, n); + const start = this.orgCursor % n; + this.orgCursor = (this.orgCursor + 1) % Math.max(n, 1); + const end = start + sliceSize; + if (end <= n) return sorted.slice(start, end); + return [...sorted.slice(start), ...sorted.slice(0, end - n)]; + } + + // Pick one env from the org's active-envs list, rotating per org via + // the per-org cursor. Each org's cursor advances by 1 each visit, so + // an org with N envs cycles through them across N visits. + private pickEnvForOrg(orgId: string, envsForOrg: string[]): string { + const sorted = [...envsForOrg].sort(); + const cursor = this.perOrgEnvCursors.get(orgId) ?? 0; + const idx = cursor % sorted.length; + this.perOrgEnvCursors.set(orgId, (cursor + 1) % sorted.length); + return sorted[idx]!; + } + + // A failure for one env (e.g. a Redis hiccup mid-batch in `pop`, or in + // `requeue`/`fail` during error recovery inside `processEntry`) must not + // poison the rest of the batch — `Promise.all` would otherwise reject and + // bubble all the way to `loop()`. Catch both stages here so the failed env + // is just counted as "failed" for this tick and we move on. + private async processOneFromEnv(envId: string): Promise<"drained" | "failed" | "empty"> { + let entry: BufferEntry | null; + try { + entry = await this.buffer.pop(envId); + } catch (err) { + this.logger.error("MollifierDrainer.pop failed", { envId, err }); + return "failed"; + } + if (!entry) return "empty"; + try { + return await this.processEntry(entry); + } catch (err) { + this.logger.error("MollifierDrainer.processEntry failed", { + envId, + runId: entry.runId, + err, + }); + return "failed"; + } + } + + private async processEntry(entry: BufferEntry): Promise<"drained" | "failed"> { + try { + const payload = deserialiseSnapshot(entry.payload); + await this.handler({ + runId: entry.runId, + envId: entry.envId, + orgId: entry.orgId, + payload, + attempts: entry.attempts, + createdAt: entry.createdAt, + }); + await this.buffer.ack(entry.runId); + return "drained"; + } catch (err) { + const nextAttempts = entry.attempts + 1; + if (this.isRetryable(err) && nextAttempts < this.maxAttempts) { + await this.buffer.requeue(entry.runId); + this.logger.warn("MollifierDrainer: retryable error, requeued", { + runId: entry.runId, + attempts: nextAttempts, + }); + return "failed"; + } + const code = err instanceof Error ? err.name : "Unknown"; + const message = err instanceof Error ? err.message : String(err); + await this.buffer.fail(entry.runId, { code, message }); + this.logger.error("MollifierDrainer: terminal failure", { + runId: entry.runId, + code, + message, + }); + return "failed"; + } + } +} diff --git a/packages/redis-worker/src/mollifier/index.ts b/packages/redis-worker/src/mollifier/index.ts new file mode 100644 index 00000000000..5e6fe202e3d --- /dev/null +++ b/packages/redis-worker/src/mollifier/index.ts @@ -0,0 +1,15 @@ +export { MollifierBuffer, type MollifierBufferOptions } from "./buffer.js"; +export { + MollifierDrainer, + type MollifierDrainerOptions, + type MollifierDrainerHandler, + type DrainResult, +} from "./drainer.js"; +export { + BufferEntrySchema, + BufferEntryStatus, + BufferEntryError, + serialiseSnapshot, + deserialiseSnapshot, + type BufferEntry, +} from "./schemas.js"; diff --git a/packages/redis-worker/src/mollifier/schemas.ts b/packages/redis-worker/src/mollifier/schemas.ts new file mode 100644 index 00000000000..f93b0f0a3c3 --- /dev/null +++ b/packages/redis-worker/src/mollifier/schemas.ts @@ -0,0 +1,58 @@ +import { z } from "zod"; + +export const BufferEntryStatus = z.enum(["QUEUED", "DRAINING", "FAILED"]); +export type BufferEntryStatus = z.infer; + +export const BufferEntryError = z.object({ + code: z.string(), + message: z.string(), +}); +export type BufferEntryError = z.infer; + +const stringToInt = z.string().transform((v, ctx) => { + const n = Number(v); + if (!Number.isInteger(n) || n < 0) { + ctx.addIssue({ code: z.ZodIssueCode.custom, message: "expected non-negative integer string" }); + return z.NEVER; + } + return n; +}); + +const stringToDate = z.string().transform((v, ctx) => { + const d = new Date(v); + if (Number.isNaN(d.getTime())) { + ctx.addIssue({ code: z.ZodIssueCode.custom, message: "expected ISO date string" }); + return z.NEVER; + } + return d; +}); + +const stringToError = z.string().transform((v, ctx) => { + try { + return BufferEntryError.parse(JSON.parse(v)); + } catch { + ctx.addIssue({ code: z.ZodIssueCode.custom, message: "expected JSON-encoded BufferEntryError" }); + return z.NEVER; + } +}); + +export const BufferEntrySchema = z.object({ + runId: z.string().min(1), + envId: z.string().min(1), + orgId: z.string().min(1), + payload: z.string(), + status: BufferEntryStatus, + attempts: stringToInt, + createdAt: stringToDate, + lastError: stringToError.optional(), +}); + +export type BufferEntry = z.infer; + +export function serialiseSnapshot(snapshot: unknown): string { + return JSON.stringify(snapshot); +} + +export function deserialiseSnapshot(serialised: string): T { + return JSON.parse(serialised) as T; +} From 8b98e21b4bc8e188af3e78767aea8c8672bdc50b Mon Sep 17 00:00:00 2001 From: Eric Allam Date: Mon, 18 May 2026 16:41:58 +0100 Subject: [PATCH 02/25] fix(sdk,core): cache realtime-stream credentials per slot with refresh on writer failure (#3658) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Hot-loop writers — `streams.writer` / `streams.pipe` on the run-scoped side, `chat.response.write` / `chat.stream.*` on the session side — were issuing a fresh `PUT` to mint S2 credentials for every chunk. On run streams, each PUT also pushed the streamId onto `TaskRun.realtimeStreams`, so a chat-agent turn writing N chunks produced N PUTs and N duplicate array pushes against the same row. The SDK now caches the initialize response per cache slot: `(runId, key)` for run streams, the session id for session streams. First call PUTs as before; subsequent calls reuse the cached promise. Hot-loop writers do one PUT per slot for the lifetime of the cache. S2 access tokens have a 1-day TTL. If a writer's `wait()` rejects (auth error, expired token, network blip), the cache evicts the matching slot so the next call re-PUTs and mints fresh credentials, identity-checked so a concurrent caller's fresh promise isn't accidentally cleared. ## chat.agent guardrail `streams.pipe / writer / append / read` called inside a `chat.agent` run now logs a one-time warning pointing at `chat.response.write` / `chat.stream.*` — `streams.*` is run-scoped and isn't visible on the chat session. The ai-chat docs are updated to drop the old guidance toward run-scoped streams. --- packages/core/src/v3/realtime-streams-api.ts | 5 +- packages/core/src/v3/realtimeStreams/index.ts | 5 +- .../src/v3/realtimeStreams/manager.test.ts | 147 ++++++++++++++ .../core/src/v3/realtimeStreams/manager.ts | 83 +++++++- .../realtimeStreams/sessionStreamInstance.ts | 27 ++- .../src/v3/realtimeStreams/streamInstance.ts | 30 ++- packages/trigger-sdk/src/v3/ai.ts | 4 +- packages/trigger-sdk/src/v3/sessions.test.ts | 186 ++++++++++++++++++ packages/trigger-sdk/src/v3/sessions.ts | 73 ++++++- packages/trigger-sdk/src/v3/streams.ts | 37 ++++ 10 files changed, 576 insertions(+), 21 deletions(-) create mode 100644 packages/core/src/v3/realtimeStreams/manager.test.ts create mode 100644 packages/trigger-sdk/src/v3/sessions.test.ts diff --git a/packages/core/src/v3/realtime-streams-api.ts b/packages/core/src/v3/realtime-streams-api.ts index d9cd9ecfb45..728399bea6e 100644 --- a/packages/core/src/v3/realtime-streams-api.ts +++ b/packages/core/src/v3/realtime-streams-api.ts @@ -6,7 +6,10 @@ export const realtimeStreams = RealtimeStreamsAPI.getInstance(); export * from "./realtimeStreams/types.js"; export { SessionStreamInstance } from "./realtimeStreams/sessionStreamInstance.js"; -export type { SessionStreamInstanceOptions } from "./realtimeStreams/sessionStreamInstance.js"; +export type { + SessionStreamInstanceOptions, + InitializeSessionStreamResponseLike, +} from "./realtimeStreams/sessionStreamInstance.js"; export { trimSessionStream, writeSessionControlRecord, diff --git a/packages/core/src/v3/realtimeStreams/index.ts b/packages/core/src/v3/realtimeStreams/index.ts index b1c20735808..71854888ee5 100644 --- a/packages/core/src/v3/realtimeStreams/index.ts +++ b/packages/core/src/v3/realtimeStreams/index.ts @@ -10,7 +10,10 @@ import { // `SessionOutputChannel.pipe` / `.writer` can construct it without reaching // into the core package's internals. export { SessionStreamInstance } from "./sessionStreamInstance.js"; -export type { SessionStreamInstanceOptions } from "./sessionStreamInstance.js"; +export type { + SessionStreamInstanceOptions, + InitializeSessionStreamResponseLike, +} from "./sessionStreamInstance.js"; export { trimSessionStream, writeSessionControlRecord, diff --git a/packages/core/src/v3/realtimeStreams/manager.test.ts b/packages/core/src/v3/realtimeStreams/manager.test.ts new file mode 100644 index 00000000000..179754bc752 --- /dev/null +++ b/packages/core/src/v3/realtimeStreams/manager.test.ts @@ -0,0 +1,147 @@ +import { describe, expect, it, vi } from "vitest"; +import type { ApiClient } from "../apiClient/index.js"; +import { StandardRealtimeStreamsManager } from "./manager.js"; + +// The cache lives on a private method to keep `pipe()` callers from having +// to thread cache concerns. Tests exercise it via bracket-notation to keep +// the assertions tight on cache contracts and avoid spinning up real +// `StreamsWriterV1`/`StreamsWriterV2` infrastructure (HTTP requests, S2 +// connections) for what is purely an in-memory dedup check. +type GetCached = ( + runId: string, + key: string, + requestOptions?: undefined +) => Promise<{ version: string; headers?: Record }>; + +function getCached(manager: StandardRealtimeStreamsManager, runId: string, key: string) { + return (manager as unknown as { getCachedCreateStream: GetCached }).getCachedCreateStream( + runId, + key + ); +} + +function makeApiClient(impl: () => Promise<{ version: string; headers?: Record }>) { + const spy = vi.fn(impl); + const client = { createStream: spy } as unknown as ApiClient; + return { client, spy }; +} + +describe("StandardRealtimeStreamsManager createStream cache", () => { + it("dedupes repeated calls for the same (runId, key)", async () => { + const { client, spy } = makeApiClient(async () => ({ version: "v1", headers: {} })); + const manager = new StandardRealtimeStreamsManager(client, "http://localhost"); + + const p1 = getCached(manager, "run-1", "chat"); + const p2 = getCached(manager, "run-1", "chat"); + + expect(p1).toBe(p2); + expect(spy).toHaveBeenCalledTimes(1); + await Promise.all([p1, p2]); + expect(spy).toHaveBeenCalledTimes(1); + }); + + it("issues a separate PUT for each distinct stream key on the same run", async () => { + const { client, spy } = makeApiClient(async () => ({ version: "v1", headers: {} })); + const manager = new StandardRealtimeStreamsManager(client, "http://localhost"); + + await Promise.all([ + getCached(manager, "run-1", "chat"), + getCached(manager, "run-1", "tool-output"), + ]); + + expect(spy).toHaveBeenCalledTimes(2); + expect(spy).toHaveBeenNthCalledWith(1, "run-1", "self", "chat", undefined); + expect(spy).toHaveBeenNthCalledWith(2, "run-1", "self", "tool-output", undefined); + }); + + it("issues a separate PUT for each distinct run, even with the same key", async () => { + const { client, spy } = makeApiClient(async () => ({ version: "v1", headers: {} })); + const manager = new StandardRealtimeStreamsManager(client, "http://localhost"); + + await Promise.all([ + getCached(manager, "run-1", "chat"), + getCached(manager, "run-2", "chat"), + ]); + + expect(spy).toHaveBeenCalledTimes(2); + }); + + it("evicts on failure so the next call retries instead of returning a poisoned entry", async () => { + const spy = vi + .fn() + .mockRejectedValueOnce(new Error("boom")) + .mockResolvedValueOnce({ version: "v1", headers: {} }); + const client = { createStream: spy } as unknown as ApiClient; + const manager = new StandardRealtimeStreamsManager(client, "http://localhost"); + + await expect(getCached(manager, "run-1", "chat")).rejects.toThrow("boom"); + + const retried = await getCached(manager, "run-1", "chat"); + + expect(retried).toEqual({ version: "v1", headers: {} }); + expect(spy).toHaveBeenCalledTimes(2); + }); + + it("reset() clears cached entries so the next call re-PUTs", async () => { + const { client, spy } = makeApiClient(async () => ({ version: "v1", headers: {} })); + const manager = new StandardRealtimeStreamsManager(client, "http://localhost"); + + await getCached(manager, "run-1", "chat"); + expect(spy).toHaveBeenCalledTimes(1); + + manager.reset(); + + await getCached(manager, "run-1", "chat"); + expect(spy).toHaveBeenCalledTimes(2); + }); + + it("evictCreateStreamIfStale clears the matching entry so the next call re-PUTs", async () => { + const { client, spy } = makeApiClient(async () => ({ version: "v1", headers: {} })); + const manager = new StandardRealtimeStreamsManager(client, "http://localhost"); + + // Prime the cache and capture which promise was stored. + const cachedPromise = getCached(manager, "run-1", "chat"); + await cachedPromise; + expect(spy).toHaveBeenCalledTimes(1); + + // Simulate the reactive invalidation path that `pipe()` runs when a + // writer's `wait()` rejects. + ( + manager as unknown as { + evictCreateStreamIfStale: ( + runId: string, + key: string, + expected: Promise + ) => void; + } + ).evictCreateStreamIfStale("run-1", "chat", cachedPromise); + + await getCached(manager, "run-1", "chat"); + expect(spy).toHaveBeenCalledTimes(2); + }); + + it("evictCreateStreamIfStale is a no-op when the cache holds a different promise", async () => { + const { client, spy } = makeApiClient(async () => ({ version: "v1", headers: {} })); + const manager = new StandardRealtimeStreamsManager(client, "http://localhost"); + + const original = getCached(manager, "run-1", "chat"); + await original; + + // A different promise (e.g. from a concurrent caller that already + // refreshed) shouldn't trigger eviction. + const stalePromise = Promise.resolve({ version: "v1", headers: {} }); + ( + manager as unknown as { + evictCreateStreamIfStale: ( + runId: string, + key: string, + expected: Promise + ) => void; + } + ).evictCreateStreamIfStale("run-1", "chat", stalePromise); + + // Cache should still hold the original entry; next call is a hit. + await getCached(manager, "run-1", "chat"); + expect(spy).toHaveBeenCalledTimes(1); + }); +}); diff --git a/packages/core/src/v3/realtimeStreams/manager.ts b/packages/core/src/v3/realtimeStreams/manager.ts index beda3535fb4..f4d915acc3f 100644 --- a/packages/core/src/v3/realtimeStreams/manager.ts +++ b/packages/core/src/v3/realtimeStreams/manager.ts @@ -1,7 +1,8 @@ import { ApiClient } from "../apiClient/index.js"; import { ensureAsyncIterable, ensureReadableStream } from "../streams/asyncIterableStream.js"; +import { AnyZodFetchOptions } from "../zodfetch.js"; import { taskContext } from "../task-context-api.js"; -import { StreamInstance } from "./streamInstance.js"; +import { CreateStreamResponseLike, StreamInstance } from "./streamInstance.js"; import { RealtimeStreamInstance, RealtimeStreamOperationOptions, @@ -21,8 +22,60 @@ export class StandardRealtimeStreamsManager implements RealtimeStreamsManager { abortController: AbortController; }>(); + // Cache of in-flight / resolved `createStream` responses, keyed by + // `${runId}:${key}`. S2 v2 access tokens are scoped to the org basin + // (default 1-day TTL server-side) so reusing them across repeated + // `pipe()` calls for the same `(runId, key)` is safe, and avoids the + // per-call PUT that pushes `streamId` onto `TaskRun.realtimeStreams`, + // which under chat-agent-style hot-loop writers caused row-lock + // contention on the writer DB. + private createStreamCache = new Map>(); + reset(): void { this.activeStreams.clear(); + this.createStreamCache.clear(); + } + + private getCachedCreateStream( + runId: string, + key: string, + requestOptions: AnyZodFetchOptions | undefined + ): Promise { + const cacheKey = `${runId}:${key}`; + const cached = this.createStreamCache.get(cacheKey); + if (cached) { + return cached; + } + + const promise = this.apiClient.createStream(runId, "self", key, requestOptions); + this.createStreamCache.set(cacheKey, promise); + // Evict on failure so the next call retries instead of returning a + // poisoned cache entry forever. + promise.catch((err) => { + if (this.createStreamCache.get(cacheKey) === promise) { + this.createStreamCache.delete(cacheKey); + } + }); + return promise; + } + + /** + * Reactive invalidation: a writer's `wait()` rejecting can mean the + * cached S2 credentials have gone stale (expired token, revoked + * access, basin retired), so evict the cached `createStream` response + * for `(runId, key)` and let the next `pipe()` re-PUT to mint fresh + * credentials. Compare by identity so a fresh promise installed by a + * concurrent caller isn't accidentally cleared. + */ + private evictCreateStreamIfStale( + runId: string, + key: string, + expected: Promise + ): void { + const cacheKey = `${runId}:${key}`; + if (this.createStreamCache.get(cacheKey) === expected) { + this.createStreamCache.delete(cacheKey); + } } public pipe( @@ -48,6 +101,15 @@ export class StandardRealtimeStreamsManager implements RealtimeStreamsManager { ? AbortSignal.any?.([options.signal, abortController.signal]) ?? abortController.signal : abortController.signal; + // Capture which cached promise this writer uses so reactive + // invalidation below evicts only if the cache still holds it (a + // concurrent caller may have already refreshed it). + const activeCreatePromise = this.getCachedCreateStream( + runId, + key, + options?.requestOptions + ); + const streamInstance = new StreamInstance({ apiClient: this.apiClient, baseUrl: this.baseUrl, @@ -58,14 +120,29 @@ export class StandardRealtimeStreamsManager implements RealtimeStreamsManager { requestOptions: options?.requestOptions, target: options?.target, debug: this.debug, + createStream: () => activeCreatePromise, }); // Register this stream const streamInfo = { wait: () => streamInstance.wait(), abortController }; this.activeStreams.add(streamInfo); - // Clean up when stream completes - streamInstance.wait().finally(() => this.activeStreams.delete(streamInfo)); + // Single internal chain that handles activeStreams cleanup AND + // reactive invalidation. On rejection we evict the cached + // `createStream` entry so the next pipe() for the same `(runId, key)` + // re-PUTs and recovers (e.g. when a cached S2 access token expired + // mid-process). Customer awaiters still observe the rejection via + // the returned `wait()`; this chain just keeps the cleanup path + // from surfacing as unhandled. + streamInstance.wait().then( + () => { + this.activeStreams.delete(streamInfo); + }, + (err) => { + this.evictCreateStreamIfStale(runId, key, activeCreatePromise); + this.activeStreams.delete(streamInfo); + } + ); return { wait: () => streamInstance.wait(), diff --git a/packages/core/src/v3/realtimeStreams/sessionStreamInstance.ts b/packages/core/src/v3/realtimeStreams/sessionStreamInstance.ts index 11eb7290edc..73bec591d9e 100644 --- a/packages/core/src/v3/realtimeStreams/sessionStreamInstance.ts +++ b/packages/core/src/v3/realtimeStreams/sessionStreamInstance.ts @@ -4,6 +4,10 @@ import { AnyZodFetchOptions } from "../zodfetch.js"; import { StreamsWriterV2 } from "./streamsWriterV2.js"; import { StreamsWriter, StreamWriteResult } from "./types.js"; +export type InitializeSessionStreamResponseLike = { + headers?: Record; +}; + export type SessionStreamInstanceOptions = { apiClient: ApiClient; baseUrl: string; @@ -13,6 +17,14 @@ export type SessionStreamInstanceOptions = { signal?: AbortSignal; requestOptions?: AnyZodFetchOptions; debug?: boolean; + /** + * Optional override for the initialize-session-stream call. Defaults to + * `apiClient.initializeSessionStream(sessionId, io, requestOptions)`. The + * channel passes a cached version so repeated `pipe()` / `writer()` + * calls for the same `(sessionId, io)` share a single PUT instead of + * hammering the server on every chunk. + */ + initializeSession?: () => Promise; }; /** @@ -31,11 +43,16 @@ export class SessionStreamInstance implements StreamsWriter { } private async initializeWriter(): Promise> { - const response = await this.options.apiClient.initializeSessionStream( - this.options.sessionId, - this.options.io, - this.options?.requestOptions - ); + const initializeFn = + this.options.initializeSession ?? + (() => + this.options.apiClient.initializeSessionStream( + this.options.sessionId, + this.options.io, + this.options?.requestOptions + )); + + const response = await initializeFn(); const headers = response.headers ?? {}; const accessToken = headers["x-s2-access-token"]; diff --git a/packages/core/src/v3/realtimeStreams/streamInstance.ts b/packages/core/src/v3/realtimeStreams/streamInstance.ts index 07ee0158bfb..e5cd3f84aea 100644 --- a/packages/core/src/v3/realtimeStreams/streamInstance.ts +++ b/packages/core/src/v3/realtimeStreams/streamInstance.ts @@ -5,6 +5,11 @@ import { StreamsWriterV1 } from "./streamsWriterV1.js"; import { StreamsWriterV2 } from "./streamsWriterV2.js"; import { StreamsWriter, StreamWriteResult } from "./types.js"; +export type CreateStreamResponseLike = { + version: string; + headers?: Record; +}; + export type StreamInstanceOptions = { apiClient: ApiClient; baseUrl: string; @@ -15,6 +20,14 @@ export type StreamInstanceOptions = { requestOptions?: AnyZodFetchOptions; target?: "self" | "parent" | "root" | string; debug?: boolean; + /** + * Optional override for the create-stream call. Defaults to + * `apiClient.createStream(runId, "self", key, requestOptions)`. The + * manager passes a cached version so repeated `pipe()` calls for the + * same `(runId, key)` share a single PUT instead of hammering the + * server on every chunk. + */ + createStream?: () => Promise; }; type StreamsWriterInstance = StreamsWriterV1 | StreamsWriterV2; @@ -27,12 +40,17 @@ export class StreamInstance implements StreamsWriter { } private async initializeWriter(): Promise> { - const { version, headers } = await this.options.apiClient.createStream( - this.options.runId, - "self", - this.options.key, - this.options?.requestOptions - ); + const createStreamFn = + this.options.createStream ?? + (() => + this.options.apiClient.createStream( + this.options.runId, + "self", + this.options.key, + this.options?.requestOptions + )); + + const { version, headers } = await createStreamFn(); const parsedResponse = parseCreateStreamResponse(version, headers); diff --git a/packages/trigger-sdk/src/v3/ai.ts b/packages/trigger-sdk/src/v3/ai.ts index 81994a03685..a7ac052b3f5 100644 --- a/packages/trigger-sdk/src/v3/ai.ts +++ b/packages/trigger-sdk/src/v3/ai.ts @@ -77,7 +77,7 @@ import type { ResolvedSkill } from "./skill.js"; // never touches `ai.ts`'s module graph, so the `node:*` builtins // pulled in transitively here never reach a client chunk. import { runBashInSkill, readFileInSkill } from "./agentSkillsRuntime.js"; -import { streams } from "./streams.js"; +import { streams, markChatAgentRunForStreamsWarning } from "./streams.js"; import { sessions, type SessionHandle, @@ -4495,6 +4495,7 @@ function chatCustomAgent< // No client-side upsert needed. locals.set(chatSessionHandleKey, sessions.open(payload.chatId)); locals.set(chatAgentRunContextKey, runOptions.ctx); + markChatAgentRunForStreamsWarning(); taskContext.setConversationId(payload.chatId); stampConversationIdOnActiveSpan(payload.chatId); return userRun(payload, runOptions); @@ -4591,6 +4592,7 @@ function chatAgent< // Mutable holder; advances in `writeTurnCompleteChunk` after each turn // and is the trim target for the NEXT turn's trim record. locals.set(lastTurnCompleteSeqNumKey, { value: undefined }); + markChatAgentRunForStreamsWarning(); taskContext.setConversationId(payload.chatId); // Stamp `gen_ai.conversation.id` on the run-level span. Every diff --git a/packages/trigger-sdk/src/v3/sessions.test.ts b/packages/trigger-sdk/src/v3/sessions.test.ts new file mode 100644 index 00000000000..abeccb0c12d --- /dev/null +++ b/packages/trigger-sdk/src/v3/sessions.test.ts @@ -0,0 +1,186 @@ +import { describe, expect, it, vi } from "vitest"; + +// Per-test override for the stubbed SessionStreamInstance's wait() so a +// test can simulate downstream writer failures (e.g. S2 auth error after +// initializeSessionStream returned a stale token). Reset at the top of +// each test that touches it. +let stubWaitImpl: (() => Promise<{ lastEventId?: string }>) | undefined; + +// Stub `SessionStreamInstance` so constructing a channel writer doesn't try +// to reach S2. The stub still invokes the `initializeSession` callback the +// channel passes in, which is the whole point: that's how the cache gets +// exercised. wait() resolves immediately by default; tests can override it +// via `stubWaitImpl` to verify reactive invalidation on writer failure. +vi.mock("@trigger.dev/core/v3", async (importActual) => { + const actual = (await importActual()) as Record; + class StubSessionStreamInstance { + private waitPromise: Promise<{ lastEventId?: string }>; + constructor(opts: { + source: ReadableStream; + initializeSession?: () => Promise<{ headers?: Record }>; + }) { + // Drain the source so the upstream tee doesn't backpressure-stall. + void (async () => { + const reader = opts.source.getReader(); + try { + while (true) { + const { done } = await reader.read(); + if (done) break; + } + } finally { + reader.releaseLock(); + } + })(); + // Trigger the initializeSession callback so the cache path runs. + opts.initializeSession?.().catch(() => { + // Failures are observed via the spy; swallow here so unhandled + // rejection warnings don't leak through the stub. + }); + // Capture the wait outcome once at construction (mirrors real + // SessionStreamInstance which kicks off initializeWriter from the + // ctor). All subsequent wait() calls return the same promise so + // a single failure is observable by every consumer in the channel + // (`.finally`, reactive `.catch`, and customer `waitUntilComplete`). + this.waitPromise = stubWaitImpl + ? stubWaitImpl() + : Promise.resolve({ lastEventId: undefined }); + // Claim any rejection so test runs don't surface as unhandled. + // Real awaiters still observe the rejection when they `await` it. + this.waitPromise.catch(() => {}); + } + async wait() { + return this.waitPromise; + } + get stream() { + return new ReadableStream({ start: (c) => c.close() }); + } + } + return { ...actual, SessionStreamInstance: StubSessionStreamInstance }; +}); + +import { SessionOutputChannel } from "./sessions.js"; +import { apiClientManager } from "@trigger.dev/core/v3"; + +type ApiClientStub = { + initializeSessionStream: ReturnType; +}; + +function installStubApiClient(impl: ApiClientStub["initializeSessionStream"]): ApiClientStub { + const stub: ApiClientStub = { initializeSessionStream: impl }; + // `apiClientManager.clientOrThrow()` is what `#pipeInternal` reaches for. + vi.spyOn(apiClientManager, "clientOrThrow").mockReturnValue( + stub as unknown as ReturnType + ); + return stub; +} + +function emptyStream(): ReadableStream { + return new ReadableStream({ start: (c) => c.close() }); +} + +describe("SessionOutputChannel initializeSessionStream cache", () => { + it("dedupes repeated pipe()/writer() calls for the same channel", async () => { + stubWaitImpl = undefined; + const spy = vi.fn(async () => ({ version: "v2", headers: {} })); + installStubApiClient(spy); + + const channel = new SessionOutputChannel("session-1"); + const p1 = channel.pipe(emptyStream()); + const p2 = channel.pipe(emptyStream()); + const p3 = channel.writer({ + execute: ({ write }) => { + write({ chunk: 1 }); + }, + }); + + await Promise.all([p1.waitUntilComplete(), p2.waitUntilComplete(), p3.waitUntilComplete()]); + + expect(spy).toHaveBeenCalledTimes(1); + expect(spy).toHaveBeenCalledWith("session-1", "out", undefined); + }); + + it("evicts on initialize failure so the next call retries instead of returning a poisoned entry", async () => { + stubWaitImpl = undefined; + const spy = vi + .fn() + .mockRejectedValueOnce(new Error("boom")) + .mockResolvedValueOnce({ version: "v2", headers: {} }); + installStubApiClient(spy); + + const channel = new SessionOutputChannel("session-1"); + const firstAttempt = channel.pipe(emptyStream()); + // First call fails — the stub swallows the rejection on the + // initializeSession callback, but the cache eviction handler still runs. + await firstAttempt.waitUntilComplete(); + // Settle pending microtasks so the .catch() eviction fires. + await new Promise((resolve) => setTimeout(resolve, 0)); + + const retried = channel.pipe(emptyStream()); + await retried.waitUntilComplete(); + + expect(spy).toHaveBeenCalledTimes(2); + }); + + it("reset() clears cached entries so the next call re-PUTs", async () => { + stubWaitImpl = undefined; + const spy = vi.fn(async () => ({ version: "v2", headers: {} })); + installStubApiClient(spy); + + const channel = new SessionOutputChannel("session-1"); + await channel.pipe(emptyStream()).waitUntilComplete(); + expect(spy).toHaveBeenCalledTimes(1); + + channel.reset(); + + await channel.pipe(emptyStream()).waitUntilComplete(); + expect(spy).toHaveBeenCalledTimes(2); + }); + + it("scopes the cache per channel instance", async () => { + stubWaitImpl = undefined; + const spy = vi.fn(async () => ({ version: "v2", headers: {} })); + installStubApiClient(spy); + + const channelA = new SessionOutputChannel("session-a"); + const channelB = new SessionOutputChannel("session-b"); + + await Promise.all([ + channelA.pipe(emptyStream()).waitUntilComplete(), + channelB.pipe(emptyStream()).waitUntilComplete(), + ]); + + expect(spy).toHaveBeenCalledTimes(2); + expect(spy).toHaveBeenCalledWith("session-a", "out", undefined); + expect(spy).toHaveBeenCalledWith("session-b", "out", undefined); + }); + + it("evicts the cache when a writer's wait() rejects (simulated stale-token failure)", async () => { + const spy = vi.fn(async () => ({ version: "v2", headers: {} })); + installStubApiClient(spy); + + // First writer's wait() rejects (e.g. S2 returned 401 after the cached + // token expired mid-process); subsequent writers' wait() resolve cleanly. + let waitCallCount = 0; + stubWaitImpl = async () => { + waitCallCount++; + if (waitCallCount === 1) throw new Error("S2 auth failed: token expired"); + return { lastEventId: undefined }; + }; + + const channel = new SessionOutputChannel("session-1"); + + const failed = channel.pipe(emptyStream()); + await expect(failed.waitUntilComplete()).rejects.toThrow(/token expired/); + + // Settle microtasks so the reactive .catch eviction handler fires. + await new Promise((resolve) => setTimeout(resolve, 0)); + + const recovered = channel.pipe(emptyStream()); + await recovered.waitUntilComplete(); + + // Cache evicted ⇒ second pipe() re-PUT ⇒ two distinct initialize calls. + expect(spy).toHaveBeenCalledTimes(2); + + stubWaitImpl = undefined; + }); +}); diff --git a/packages/trigger-sdk/src/v3/sessions.ts b/packages/trigger-sdk/src/v3/sessions.ts index 663dbbebc30..18023535f52 100644 --- a/packages/trigger-sdk/src/v3/sessions.ts +++ b/packages/trigger-sdk/src/v3/sessions.ts @@ -34,7 +34,11 @@ import { trimSessionStream, writeSessionControlRecord, } from "@trigger.dev/core/v3"; -import type { ControlEvent, StreamWriteResult } from "@trigger.dev/core/v3"; +import type { + ControlEvent, + InitializeSessionStreamResponseLike, + StreamWriteResult, +} from "@trigger.dev/core/v3"; import { conditionallyImportAndParsePacket } from "@trigger.dev/core/v3/utils/ioSerialization"; import { SpanStatusCode } from "@opentelemetry/api"; import { tracer } from "./tracer.js"; @@ -266,8 +270,30 @@ export type SessionPipeStreamOptions = Omit; * internally by `pipe`/`writer` — there's no public `initialize()`. */ export class SessionOutputChannel { + // Cache of the in-flight / resolved `initializeSessionStream` PUT for + // this channel. Every `pipe()` / `writer()` call needs the same S2 + // credentials, so we share a single promise instead of re-PUTing on + // every chunk. Hot-loop writers (per-chunk `chat.response.write` / + // direct `session.out.writer` calls) drop from N PUTs to 1 PUT for + // the lifetime of the channel. The S2 access token has a 1-day TTL + // server-side so reusing it across calls within a single run is safe. + // Evicts on failure (so the next call retries) and on `reset()`. + #initPromise?: Promise; + constructor(public readonly sessionId: string) {} + /** + * Drop the cached `initializeSessionStream` response. Surfaces for + * tests and lifecycle hooks that need the next write to re-mint S2 + * credentials. The cache also self-evicts on `initializeSession` + * rejection, so callers don't need to invoke this on failures. + * + * @internal + */ + reset(): void { + this.#initPromise = undefined; + } + /** * Append a single record. Routes through {@link writer} internally so * subscribers receive the same parsed-object shape as multi-record @@ -429,6 +455,28 @@ export class SessionOutputChannel { ? AbortSignal.any?.([options.signal, abortController.signal]) ?? abortController.signal : abortController.signal; + // Resolve the init promise eagerly so we can capture which one this + // writer uses for reactive invalidation below. + const writerInitPromise = ((): Promise => { + if (this.#initPromise) { + return this.#initPromise; + } + const fresh = apiClient.initializeSessionStream( + this.sessionId, + "out", + options?.requestOptions + ); + this.#initPromise = fresh; + // Evict on failure so the next call retries instead of returning a + // poisoned cache entry forever. + fresh.catch((err) => { + if (this.#initPromise === fresh) { + this.#initPromise = undefined; + } + }); + return fresh; + })(); + try { const instance = new SessionStreamInstance({ apiClient, @@ -438,11 +486,28 @@ export class SessionOutputChannel { source: readableStreamSource, signal: combinedSignal, requestOptions: options?.requestOptions, + initializeSession: () => writerInitPromise, }); - instance.wait().finally(() => { - span.end(); - }); + // Single internal chain that handles span lifecycle AND reactive + // invalidation. On rejection we evict the cached init promise so + // the next pipe()/writer() re-PUTs and recovers (e.g. when a + // cached S2 access token expired mid-process). Compare by identity + // so a concurrent caller's fresh promise isn't accidentally cleared. + // Customer awaiters still observe the rejection via the returned + // `waitUntilComplete()`; this chain just keeps the cleanup path + // from surfacing as unhandled. + instance.wait().then( + () => { + span.end(); + }, + () => { + if (this.#initPromise === writerInitPromise) { + this.#initPromise = undefined; + } + span.end(); + } + ); return { stream: instance.stream, diff --git a/packages/trigger-sdk/src/v3/streams.ts b/packages/trigger-sdk/src/v3/streams.ts index 6ccaea8891a..f987872d80a 100644 --- a/packages/trigger-sdk/src/v3/streams.ts +++ b/packages/trigger-sdk/src/v3/streams.ts @@ -19,6 +19,7 @@ import { ManualWaitpointPromise, WaitpointTimeoutError, runtime, + logger, type RealtimeDefinedInputStream, type InputStreamSubscription, type InputStreamOnceOptions, @@ -32,10 +33,43 @@ import { } from "@trigger.dev/core/v3"; import { conditionallyImportAndParsePacket } from "@trigger.dev/core/v3/utils/ioSerialization"; import { tracer } from "./tracer.js"; +import { locals } from "./locals.js"; import { SpanStatusCode } from "@opentelemetry/api"; const DEFAULT_STREAM_KEY = "default"; +// `chat.agent` sets this once at the top of every run via +// `markChatAgentRunForStreamsWarning`. The flag lives on the run's +// AsyncLocalStorage frame, so it naturally resets between runs and stays +// invisible to subtasks (where `streams.*` is a normal API). +const inChatAgentRunKey = locals.create("streams.inChatAgentRun"); +// Once-per-run dedup. `streams.*` callers inside a chat.agent run get the +// nudge on the first call and silence afterwards; a single tight loop +// won't spam the logs. +const chatAgentStreamsWarnedKey = locals.create("streams.chatAgentWarned"); + +/** + * Marks the current run as a `chat.agent` run so subsequent `streams.pipe` / + * `streams.append` / `streams.read` calls can warn the user that they're + * writing to a run-scoped stream rather than the chat's `session.out`. + * + * Called from inside the `chat.agent` task wrapper at the top of every run. + * + * @internal + */ +export function markChatAgentRunForStreamsWarning(): void { + locals.set(inChatAgentRunKey, true); +} + +function warnIfChatAgentStreamsMisuse(method: "pipe" | "append" | "read" | "writer"): void { + if (!locals.get(inChatAgentRunKey)) return; + if (locals.get(chatAgentStreamsWarnedKey)) return; + locals.set(chatAgentStreamsWarnedKey, true); + logger.warn( + `streams.${method}() was called inside a chat.agent run. This writes to a run-scoped realtime stream and is NOT visible on the chat session, so the chat UI will not see these chunks. For chat output use chat.response.write() or chat.stream.* instead. See https://trigger.dev/docs/ai-chat/patterns/large-payloads. (Logged once per run; subsequent streams.${method}() calls in this run are silent.)` + ); +} + /** * Pipes data to a realtime stream using the default stream key (`"default"`). * @@ -154,6 +188,7 @@ function pipeInternal( opts: PipeStreamOptions | undefined, spanName: string ): PipeStreamResult { + warnIfChatAgentStreamsMisuse(spanName === "streams.writer()" ? "writer" : "pipe"); const runId = getRunIdForOptions(opts); if (!runId) { @@ -325,6 +360,7 @@ async function readStreamImpl( key: string, options?: ReadStreamOptions ): Promise> { + warnIfChatAgentStreamsMisuse("read"); const apiClient = apiClientManager.clientOrThrow(); const span = tracer.startSpan("streams.read()", { @@ -403,6 +439,7 @@ async function appendInternal( part: TPart, options?: AppendStreamOptions ): Promise { + warnIfChatAgentStreamsMisuse("append"); const runId = getRunIdForOptions(options); if (!runId) { From 427d9e078a399ede5bb351f89ccae706b54553a7 Mon Sep 17 00:00:00 2001 From: Eric Allam Date: Mon, 18 May 2026 16:42:18 +0100 Subject: [PATCH 03/25] feat(sdk): functional baseURL and fetch override on chat transports (#3655) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary `TriggerChatTransport`, `AgentChat`, and `chat.createStartSessionAction` now accept a string-or-function `baseURL` so callers can route per endpoint — e.g. `.in/append` through a trusted edge proxy while keeping `.out` SSE direct. The same surfaces add a `fetch` override for header injection, custom retries, or proxy rewrites that go beyond URL routing. SSE GETs are covered too via a new `fetchClient` option on `SSEStreamSubscription`. ```ts // TriggerChatTransport / AgentChat — endpoints: "in" | "out" baseURL: ({ endpoint }) => endpoint === "out" ? DIRECT : PROXY, fetch: (url, init, ctx) => { init.headers = new Headers(init.headers); init.headers.set("traceparent", currentTraceparent()); return globalThis.fetch(url, init); }, // chat.createStartSessionAction — endpoints: "sessions" | "auth" chat.createStartSessionAction("my-agent", { baseURL: ({ endpoint }) => (endpoint === "sessions" ? PROXY : DIRECT), }); ``` `streamBaseURL` on `TriggerChatTransport` is kept as a backwards-compat alias and continues to win for the `"out"` endpoint when set. Plain-string `baseURL` still applies to every endpoint, matching prior behavior. --- packages/core/src/v3/apiClient/runStream.ts | 7 +- packages/trigger-sdk/src/v3/ai.ts | 177 ++++++++++++++++- packages/trigger-sdk/src/v3/chat-client.ts | 175 +++++++++++------ packages/trigger-sdk/src/v3/chat.test.ts | 88 +++++++++ packages/trigger-sdk/src/v3/chat.ts | 185 ++++++++++++++---- references/ai-chat/cf-worker/.gitignore | 3 + references/ai-chat/cf-worker/README.md | 33 ++++ references/ai-chat/cf-worker/package.json | 14 ++ references/ai-chat/cf-worker/src/index.ts | 124 ++++++++++++ references/ai-chat/cf-worker/tsconfig.json | 14 ++ references/ai-chat/cf-worker/wrangler.toml | 10 + references/ai-chat/src/app/actions.ts | 11 +- .../ai-chat/src/components/chat-sidebar.tsx | 1 + references/ai-chat/src/trigger/chat.ts | 38 ++++ 14 files changed, 772 insertions(+), 108 deletions(-) create mode 100644 references/ai-chat/cf-worker/.gitignore create mode 100644 references/ai-chat/cf-worker/README.md create mode 100644 references/ai-chat/cf-worker/package.json create mode 100644 references/ai-chat/cf-worker/src/index.ts create mode 100644 references/ai-chat/cf-worker/tsconfig.json create mode 100644 references/ai-chat/cf-worker/wrangler.toml diff --git a/packages/core/src/v3/apiClient/runStream.ts b/packages/core/src/v3/apiClient/runStream.ts index 217b7a51082..b52283eae9e 100644 --- a/packages/core/src/v3/apiClient/runStream.ts +++ b/packages/core/src/v3/apiClient/runStream.ts @@ -236,6 +236,10 @@ export class SSEStreamSubscription implements StreamSubscription { // permanently. `404` (stream gone) and `410` (session closed) // are sensible defaults; tune per-caller for other 4xx. nonRetryableStatuses?: readonly number[]; + // Optional fetch override. Used by transports that need to route + // the SSE connect through a custom path (proxy, custom headers, + // tracing). Defaults to global `fetch`. + fetchClient?: typeof fetch; } ) { this.lastEventId = options.lastEventId; @@ -331,7 +335,8 @@ export class SSEStreamSubscription implements StreamSubscription { headers["Timeout-Seconds"] = this.options.timeoutInSeconds.toString(); } - const response = await fetch(this.url, { + const fetchClient = this.options.fetchClient ?? fetch; + const response = await fetchClient(this.url, { headers, signal: this.internalAbort.signal, }); diff --git a/packages/trigger-sdk/src/v3/ai.ts b/packages/trigger-sdk/src/v3/ai.ts index a7ac052b3f5..d80e40b83d4 100644 --- a/packages/trigger-sdk/src/v3/ai.ts +++ b/packages/trigger-sdk/src/v3/ai.ts @@ -35,6 +35,7 @@ import { type TaskWithSchema, SESSION_IN_EVENT_ID_HEADER, TRIGGER_CONTROL_SUBTYPE, + generateJWT, type WriterStreamOptions, } from "@trigger.dev/core/v3"; import type { @@ -8413,6 +8414,32 @@ export type { InferChatClientData, InferChatUIMessage } from "./ai-shared.js"; /** * Options for {@link createChatStartSessionAction}. */ +/** + * Discriminator for per-endpoint `baseURL` / `fetch` callbacks on + * `createChatStartSessionAction`. + * + * - `"sessions"` — `POST /api/v1/sessions` (session create + first run trigger). + * - `"auth"` — `POST /api/v1/auth/jwt/claims` (only fired when + * `tokenTTL` is set; otherwise the publicAccessToken from session create + * is reused as-is). + */ +export type ChatStartSessionEndpoint = "sessions" | "auth"; + +export type ChatStartSessionEndpointContext = { + endpoint: ChatStartSessionEndpoint; + chatId: string; +}; + +export type ChatStartSessionBaseURLResolver = ( + ctx: ChatStartSessionEndpointContext +) => string; + +export type ChatStartSessionFetchOverride = ( + url: string, + init: RequestInit, + ctx: ChatStartSessionEndpointContext +) => Promise; + export type CreateChatStartSessionActionOptions = { /** TTL for the session-scoped public access token. @default "1h" */ tokenTTL?: string | number | Date; @@ -8421,6 +8448,21 @@ export type CreateChatStartSessionActionOptions = { * Per-call `params.triggerConfig` shallow-merges on top. */ triggerConfig?: Partial; + /** + * Override the Trigger.dev API base URL. String applies to both + * `/api/v1/sessions` and `/api/v1/auth/jwt/claims`; function picks per + * endpoint. When unset, falls back to `apiClientManager.baseURL` + * (typically the `TRIGGER_API_URL` env var). Set this to route session + * create through a trusted edge proxy that injects server-side signal + * into `basePayload.metadata` before forwarding upstream. + */ + baseURL?: string | ChatStartSessionBaseURLResolver; + /** + * Per-request fetch override. Receives the resolved URL, RequestInit, + * and endpoint context. Use for header injection, proxy routing, or + * custom retry. Applies to both session-create and JWT-claims POSTs. + */ + fetch?: ChatStartSessionFetchOverride; }; /** @@ -8544,13 +8586,26 @@ function createChatStartSessionAction( : {}), }; - const created = await sessions.start({ - type: "chat.agent", + const startBody = { + type: "chat.agent" as const, externalId: params.chatId, taskIdentifier: taskId, triggerConfig, metadata: params.metadata, - }); + }; + + const baseURLOption = options?.baseURL; + const fetchOverride = options?.fetch; + const hasOverride = baseURLOption !== undefined || fetchOverride !== undefined; + + const created: { id: string; runId: string; publicAccessToken: string } = hasOverride + ? await callSessionsCreateWithOverride({ + chatId: params.chatId, + body: startBody, + baseURLOption, + fetchOverride, + }) + : await sessions.start(startBody); // Session create returns a session PAT directly when called with a // start token, but when the SDK call goes via the secret key we still @@ -8558,13 +8613,20 @@ function createChatStartSessionAction( // re-minting here lets the customer override `tokenTTL`). const publicAccessToken = options?.tokenTTL !== undefined - ? await auth.createPublicToken({ - scopes: { - read: { sessions: params.chatId }, - write: { sessions: params.chatId }, - }, - expirationTime: options.tokenTTL, - }) + ? hasOverride + ? await mintPublicTokenWithOverride({ + chatId: params.chatId, + expirationTime: options.tokenTTL, + baseURLOption, + fetchOverride, + }) + : await auth.createPublicToken({ + scopes: { + read: { sessions: params.chatId }, + write: { sessions: params.chatId }, + }, + expirationTime: options.tokenTTL, + }) : created.publicAccessToken; return { @@ -8575,6 +8637,101 @@ function createChatStartSessionAction( }; } +function resolveChatStartBaseURL( + endpoint: ChatStartSessionEndpoint, + chatId: string, + option: string | ChatStartSessionBaseURLResolver | undefined +): string { + const fallback = apiClientManager.baseURL ?? "https://api.trigger.dev"; + const raw = + typeof option === "function" + ? option({ endpoint, chatId }) + : option ?? fallback; + return raw.replace(/\/$/, ""); +} + +function overrideRequestHeaders(accessToken: string): Record { + const headers: Record = { + "Content-Type": "application/json", + Authorization: `Bearer ${accessToken}`, + "x-trigger-source": "sdk", + }; + // Forward the preview-branch hint so override-mode requests land on the + // same env the standard ApiClient path would have routed to. Mirrors + // ApiClient.#getHeaders. Read from TRIGGER_PREVIEW_BRANCH / + // VERCEL_GIT_COMMIT_REF via apiClientManager.branchName. + if (apiClientManager.branchName) { + headers["x-trigger-branch"] = apiClientManager.branchName; + } + return headers; +} + +async function callSessionsCreateWithOverride(args: { + chatId: string; + body: { type: "chat.agent"; externalId: string; taskIdentifier: string; triggerConfig: SessionTriggerConfig; metadata?: Record }; + baseURLOption: string | ChatStartSessionBaseURLResolver | undefined; + fetchOverride: ChatStartSessionFetchOverride | undefined; +}): Promise<{ id: string; runId: string; publicAccessToken: string }> { + const accessToken = apiClientManager.accessToken; + if (!accessToken) { + throw new Error( + "chat.createStartSessionAction: no API access token configured. Set TRIGGER_SECRET_KEY or call apiClientManager.setGlobalAPIClientConfiguration before invoking the action." + ); + } + const ctx: ChatStartSessionEndpointContext = { endpoint: "sessions", chatId: args.chatId }; + const url = `${resolveChatStartBaseURL("sessions", args.chatId, args.baseURLOption)}/api/v1/sessions`; + const init: RequestInit = { + method: "POST", + headers: overrideRequestHeaders(accessToken), + body: JSON.stringify(args.body), + }; + const response = args.fetchOverride + ? await args.fetchOverride(url, init, ctx) + : await fetch(url, init); + if (!response.ok) { + const text = await response.text().catch(() => ""); + throw new Error(`sessions.start failed: ${response.status} ${text}`); + } + const json = (await response.json()) as { id: string; runId: string; publicAccessToken: string }; + return json; +} + +async function mintPublicTokenWithOverride(args: { + chatId: string; + expirationTime: string | number | Date; + baseURLOption: string | ChatStartSessionBaseURLResolver | undefined; + fetchOverride: ChatStartSessionFetchOverride | undefined; +}): Promise { + const accessToken = apiClientManager.accessToken; + if (!accessToken) { + throw new Error( + "chat.createStartSessionAction: no API access token configured for JWT mint." + ); + } + const ctx: ChatStartSessionEndpointContext = { endpoint: "auth", chatId: args.chatId }; + const url = `${resolveChatStartBaseURL("auth", args.chatId, args.baseURLOption)}/api/v1/auth/jwt/claims`; + const init: RequestInit = { + method: "POST", + headers: overrideRequestHeaders(accessToken), + }; + const response = args.fetchOverride + ? await args.fetchOverride(url, init, ctx) + : await fetch(url, init); + if (!response.ok) { + const text = await response.text().catch(() => ""); + throw new Error(`auth.createPublicToken failed: ${response.status} ${text}`); + } + const claims = (await response.json()) as Record; + return generateJWT({ + secretKey: accessToken, + payload: { + ...claims, + scopes: [`read:sessions:${args.chatId}`, `write:sessions:${args.chatId}`], + }, + expirationTime: args.expirationTime, + }); +} + export const chat = { /** Create a chat agent. See {@link chatAgent}. */ agent: chatAgent, diff --git a/packages/trigger-sdk/src/v3/chat-client.ts b/packages/trigger-sdk/src/v3/chat-client.ts index 40132a624e1..98380f1e8be 100644 --- a/packages/trigger-sdk/src/v3/chat-client.ts +++ b/packages/trigger-sdk/src/v3/chat-client.ts @@ -20,7 +20,6 @@ import type { SessionTriggerConfig, Task } from "@trigger.dev/core/v3"; import type { ModelMessage, UIMessage, UIMessageChunk } from "ai"; import { readUIMessageStream } from "ai"; import { - ApiClient, apiClientManager, controlSubtype, SSEStreamSubscription, @@ -53,6 +52,26 @@ export type ChatSession = { lastEventId?: string; }; +/** + * Discriminator passed to per-endpoint `baseURL` and `fetch` callbacks on + * `AgentChat`. Same shape as the type on `TriggerChatTransport` — these + * mirror so customers can share a single resolver between the two clients. + */ +export type AgentChatEndpoint = "in" | "out"; + +export type AgentChatEndpointContext = { + endpoint: AgentChatEndpoint; + chatId: string; +}; + +export type AgentChatBaseURLResolver = (ctx: AgentChatEndpointContext) => string; + +export type AgentChatFetchOverride = ( + url: string, + init: RequestInit, + ctx: AgentChatEndpointContext +) => Promise; + export type AgentChatOptions = { /** The agent task ID to trigger. */ agent: string; @@ -89,6 +108,26 @@ export type AgentChatOptions = { * chat. Folded into `sessions.start({...triggerConfig})` body. */ triggerConfig?: SessionTriggerConfig; + /** + * Override the Trigger.dev API base URL for the chat's `.in/append` and + * `.out` SSE endpoints. String form applies to both; pass a function to + * pick per endpoint. Defaults to `apiClientManager.baseURL` (whatever + * `@trigger.dev/sdk` was configured with — typically `TRIGGER_API_URL` + * env var). + * + * Session creation (`POST /api/v1/sessions`) and token mint + * (`POST /api/v1/auth/jwt/claims`) still flow through + * `apiClientManager` — pass equivalent options to + * `chat.createStartSessionAction` if you need those routed too. + */ + baseURL?: string | AgentChatBaseURLResolver; + /** + * Optional per-request fetch override. Receives the resolved URL, the + * RequestInit, and endpoint context. Use this for header injection + * (tracing), proxy routing, or custom retries. Applies to both the + * `.in/append` POSTs and the `.out` SSE GET. + */ + fetch?: AgentChatFetchOverride; }; // ─── ChatStream ──────────────────────────────────────────────────── @@ -272,6 +311,8 @@ export class AgentChat { private readonly triggerConfigDefault: SessionTriggerConfig | undefined; private readonly onTriggered: AgentChatOptions["onTriggered"]; private readonly onTurnComplete: AgentChatOptions["onTurnComplete"]; + private readonly baseURLResolver: AgentChatBaseURLResolver; + private readonly fetchOverride: AgentChatFetchOverride | undefined; private state: SessionState; @@ -283,6 +324,11 @@ export class AgentChat { this.triggerConfigDefault = options.triggerConfig; this.onTriggered = options.onTriggered; this.onTurnComplete = options.onTurnComplete; + const baseURLOption = options.baseURL; + this.baseURLResolver = typeof baseURLOption === "function" + ? baseURLOption + : () => baseURLOption ?? apiClientManager.baseURL ?? "https://api.trigger.dev"; + this.fetchOverride = options.fetch; // Hydration: a non-empty `session` means the caller knows the // session already exists (started in a previous request). Mark @@ -378,12 +424,7 @@ export class AgentChat { metadata: this.clientData, } as ChatTaskWirePayload; - const api = this.createApiClient(); - await api.appendToSessionStream( - this.chatId, - "in", - serializeInputChunk({ kind: "message", payload }) - ); + await this.appendInputChunk(serializeInputChunk({ kind: "message", payload })); return this.subscribeToSessionStream(options?.abortSignal); } @@ -404,15 +445,7 @@ export class AgentChat { }; try { - const api = this.createApiClient(); - await api.appendToSessionStream( - this.chatId, - "in", - serializeInputChunk({ - kind: "message", - payload, - }) - ); + await this.appendInputChunk(serializeInputChunk({ kind: "message", payload })); return true; } catch { return false; @@ -424,14 +457,7 @@ export class AgentChat { if (!this.state.started) return; this.state.skipToTurnComplete = true; - const api = this.createApiClient(); - await api - .appendToSessionStream( - this.chatId, - "in", - serializeInputChunk({ kind: "stop" }) - ) - .catch(() => {}); + await this.appendInputChunk(serializeInputChunk({ kind: "stop" })).catch(() => {}); } /** @@ -459,10 +485,7 @@ export class AgentChat { */ isFinal: boolean; }): Promise { - const api = this.createApiClient(); - await api.appendToSessionStream( - this.chatId, - "in", + await this.appendInputChunk( serializeInputChunk({ kind: "handover", partialAssistantMessage: args.partialAssistantMessage, @@ -481,12 +504,7 @@ export class AgentChat { * surface. */ async sendHandoverSkip(): Promise { - const api = this.createApiClient(); - await api.appendToSessionStream( - this.chatId, - "in", - serializeInputChunk({ kind: "handover-skip" }) - ); + await this.appendInputChunk(serializeInputChunk({ kind: "handover-skip" })); } /** @@ -531,15 +549,7 @@ export class AgentChat { }; try { - const api = this.createApiClient(); - await api.appendToSessionStream( - this.chatId, - "in", - serializeInputChunk({ - kind: "message", - payload, - }) - ); + await this.appendInputChunk(serializeInputChunk({ kind: "message", payload })); } catch { throw new Error("Failed to send action. The session may have ended."); } @@ -553,10 +563,7 @@ export class AgentChat { if (!this.state.started) return false; try { - const api = this.createApiClient(); - await api.appendToSessionStream( - this.chatId, - "in", + await this.appendInputChunk( serializeInputChunk({ kind: "message", payload: { @@ -582,10 +589,41 @@ export class AgentChat { // ─── Private ─────────────────────────────────────────────────── - private createApiClient(): ApiClient { - const baseURL = apiClientManager.baseURL ?? "https://api.trigger.dev"; + private resolveBaseURL(endpoint: AgentChatEndpoint): string { + return this.baseURLResolver({ endpoint, chatId: this.chatId }).replace(/\/$/, ""); + } + + private async doFetch( + ctx: AgentChatEndpointContext, + url: string, + init: RequestInit + ): Promise { + return this.fetchOverride ? this.fetchOverride(url, init, ctx) : fetch(url, init); + } + + private async appendInputChunk(body: string): Promise { const accessToken = apiClientManager.accessToken ?? ""; - return new ApiClient(baseURL, accessToken); + const ctx: AgentChatEndpointContext = { endpoint: "in", chatId: this.chatId }; + const url = `${this.resolveBaseURL("in")}/realtime/v1/sessions/${encodeURIComponent(this.chatId)}/in/append`; + const headers: Record = { + "Content-Type": "application/json", + Authorization: `Bearer ${accessToken}`, + "x-trigger-source": "sdk", + }; + const response = await this.doFetch(ctx, url, { method: "POST", headers, body }); + if (!response.ok) { + const text = await response.text().catch(() => ""); + // Match the error shape that ApiClient/zodfetch produced before the + // inline-POST refactor so callers inspecting `error.name === + // "TriggerApiError"` or `error.status` keep working. + const err = new Error(`appendToSessionStream failed: ${response.status} ${text}`) as Error & { + name: string; + status: number; + }; + err.name = "TriggerApiError"; + err.status = response.status; + throw err; + } } /** @@ -650,10 +688,33 @@ export class AgentChat { options?: { sendStopOnAbort?: boolean } ): ReadableStream { const state = this.state; - const baseURL = apiClientManager.baseURL ?? "https://api.trigger.dev"; const accessToken = apiClientManager.accessToken ?? ""; const onTurnComplete = this.onTurnComplete; const chatId = this.chatId; + const sseCtx: AgentChatEndpointContext = { endpoint: "out", chatId }; + const fetchOverride = this.fetchOverride; + const sseFetchClient: typeof fetch | undefined = fetchOverride + ? ((input, init) => { + if (typeof input === "string") { + return fetchOverride(input, init ?? {}, sseCtx); + } + if (input instanceof URL) { + return fetchOverride(input.toString(), init ?? {}, sseCtx); + } + // Request — preserve its url + intrinsic init, let any provided + // init override on top (matches fetch(Request, init) semantics). + return fetchOverride( + input.url, + { + method: input.method, + headers: input.headers, + signal: input.signal, + ...(init ?? {}), + }, + sseCtx + ); + }) as typeof fetch + : undefined; const internalAbort = new AbortController(); const combinedSignal = abortSignal @@ -666,14 +727,7 @@ export class AgentChat { () => { if (options?.sendStopOnAbort !== false) { state.skipToTurnComplete = true; - const api = new ApiClient(baseURL, accessToken); - api - .appendToSessionStream( - chatId, - "in", - serializeInputChunk({ kind: "stop" }) - ) - .catch(() => {}); + this.appendInputChunk(serializeInputChunk({ kind: "stop" })).catch(() => {}); } internalAbort.abort(); }, @@ -681,7 +735,7 @@ export class AgentChat { ); } - const streamUrl = `${baseURL}/realtime/v1/sessions/${encodeURIComponent(chatId)}/out`; + const streamUrl = `${this.resolveBaseURL("out")}/realtime/v1/sessions/${encodeURIComponent(chatId)}/out`; return new ReadableStream({ start: async (controller) => { @@ -693,6 +747,7 @@ export class AgentChat { signal: combinedSignal, timeoutInSeconds: this.streamTimeoutSeconds, lastEventId: state.lastEventId, + fetchClient: sseFetchClient, }); const sseStream = await subscription.subscribe(); const reader = sseStream.getReader(); diff --git a/packages/trigger-sdk/src/v3/chat.test.ts b/packages/trigger-sdk/src/v3/chat.test.ts index 5f50854ec41..6469f1ac86c 100644 --- a/packages/trigger-sdk/src/v3/chat.test.ts +++ b/packages/trigger-sdk/src/v3/chat.test.ts @@ -609,6 +609,94 @@ describe("TriggerChatTransport", () => { expect(subscribe!).toContain("/realtime/v1/sessions/chat-by-chatid/out"); }); + it("functional baseURL dispatches per endpoint (in vs out)", async () => { + const requests: Array<{ url: string; ctxEndpoint: string | undefined }> = []; + global.fetch = vi.fn().mockImplementation(async (url: string | URL) => { + const urlStr = typeof url === "string" ? url : url.toString(); + requests.push({ url: urlStr, ctxEndpoint: undefined }); + if (isSessionStreamAppendUrl(urlStr)) return defaultAppendResponse(); + if (isSessionOutSubscribeUrl(urlStr)) return defaultSseResponse(); + throw new Error(`Unexpected URL: ${urlStr}`); + }); + + const baseURLFn = vi.fn(({ endpoint }: { endpoint: "in" | "out"; chatId: string }) => + endpoint === "out" + ? "https://stream.example.com" + : "https://api.example.com" + ); + + const transport = new TriggerChatTransport({ + task: "my-chat-task", + accessToken: () => "pat", + baseURL: baseURLFn, + sessions: { "chat-fn": { publicAccessToken: "p" } }, + }); + + const stream = await transport.sendMessages({ + trigger: "submit-message", + chatId: "chat-fn", + messageId: undefined, + messages: [createUserMessage("Hi")], + abortSignal: undefined, + }); + await drainChunks(stream); + + const appendCalls = baseURLFn.mock.calls.filter((c) => c[0].endpoint === "in"); + const outCalls = baseURLFn.mock.calls.filter((c) => c[0].endpoint === "out"); + expect(appendCalls.length).toBeGreaterThanOrEqual(1); + expect(outCalls.length).toBeGreaterThanOrEqual(1); + expect(appendCalls[0]![0].chatId).toBe("chat-fn"); + expect(outCalls[0]![0].chatId).toBe("chat-fn"); + + const append = requests.find((r) => isSessionStreamAppendUrl(r.url)); + const subscribe = requests.find((r) => isSessionOutSubscribeUrl(r.url)); + expect(append!.url.startsWith("https://api.example.com/")).toBe(true); + expect(subscribe!.url.startsWith("https://stream.example.com/")).toBe(true); + }); + + it("fetch override is invoked for both .in/append and .out SSE with endpoint ctx", async () => { + const fetchCalls: Array<{ url: string; endpoint: string; chatId: string }> = []; + + const customFetch = vi.fn( + async ( + url: string, + init: RequestInit, + ctx: { endpoint: "in" | "out"; chatId: string } + ) => { + fetchCalls.push({ url, endpoint: ctx.endpoint, chatId: ctx.chatId }); + if (isSessionStreamAppendUrl(url)) return defaultAppendResponse(); + if (isSessionOutSubscribeUrl(url)) return defaultSseResponse(); + throw new Error(`Unexpected URL: ${url}`); + } + ); + + global.fetch = vi.fn().mockRejectedValue(new Error("global fetch should not be called")); + + const transport = new TriggerChatTransport({ + task: "my-chat-task", + accessToken: () => "pat", + baseURL: "https://api.test.trigger.dev", + fetch: customFetch, + sessions: { "chat-fetch": { publicAccessToken: "p" } }, + }); + + const stream = await transport.sendMessages({ + trigger: "submit-message", + chatId: "chat-fetch", + messageId: undefined, + messages: [createUserMessage("Hi")], + abortSignal: undefined, + }); + await drainChunks(stream); + + const inCalls = fetchCalls.filter((c) => c.endpoint === "in"); + const outCalls = fetchCalls.filter((c) => c.endpoint === "out"); + expect(inCalls.length).toBeGreaterThanOrEqual(1); + expect(outCalls.length).toBeGreaterThanOrEqual(1); + expect(inCalls[0]!.chatId).toBe("chat-fetch"); + expect(outCalls[0]!.chatId).toBe("chat-fetch"); + }); + it("routes .out SSE through streamBaseURL while appends stay on baseURL", async () => { const requests: string[] = []; global.fetch = vi.fn().mockImplementation(async (url: string | URL) => { diff --git a/packages/trigger-sdk/src/v3/chat.ts b/packages/trigger-sdk/src/v3/chat.ts index a979b8f2b11..2aefc2bb800 100644 --- a/packages/trigger-sdk/src/v3/chat.ts +++ b/packages/trigger-sdk/src/v3/chat.ts @@ -25,7 +25,6 @@ import type { ChatTransport, UIMessage, UIMessageChunk, ChatRequestOptions } from "ai"; import { - ApiClient, controlSubtype, headerValue, PUBLIC_ACCESS_TOKEN_HEADER, @@ -38,6 +37,43 @@ import type { ChatInputChunk, ChatTaskWirePayload } from "./ai-shared.js"; const DEFAULT_BASE_URL = "https://api.trigger.dev"; const DEFAULT_STREAM_TIMEOUT_SECONDS = 120; +/** + * Discriminator passed to per-endpoint `baseURL` and `fetch` callbacks. + * + * - `"in"` — `POST /realtime/v1/sessions/{chatId}/in/append` (user messages, + * stops, actions). + * - `"out"` — `GET /realtime/v1/sessions/{chatId}/out` (SSE response stream). + * + * Other endpoints (`/api/v1/sessions`, `/api/v1/auth/jwt/claims`) are reached + * from the server-side `chat.createStartSessionAction` and `accessToken` + * callback, not the transport — they accept the same callback shape on their + * own option objects. + */ +export type ChatTransportEndpoint = "in" | "out"; + +/** Context passed to `baseURL` and `fetch` callbacks. */ +export type ChatTransportEndpointContext = { + endpoint: ChatTransportEndpoint; + chatId: string; +}; + +/** Resolver form of `baseURL` — return the base for the given endpoint. */ +export type ChatBaseURLResolver = (ctx: ChatTransportEndpointContext) => string; + +/** + * Per-request fetch override. Receives the fully-resolved URL and the + * RequestInit the transport would have used, plus endpoint context for + * routing decisions. Customers can rewrite the URL, inject headers, or + * delegate to a custom transport (e.g. a Cloudflare worker fronting + * `api.trigger.dev`). Must return a `Response` semantically equivalent to + * what `globalThis.fetch(url, init)` would have returned. + */ +export type ChatFetchOverride = ( + url: string, + init: RequestInit, + ctx: ChatTransportEndpointContext +) => Promise; + /** * Detect 401/403 from realtime/input-stream calls without relying on `instanceof` * (Vitest can load duplicate `@trigger.dev/core` copies, which breaks subclass checks). @@ -229,18 +265,45 @@ export type TriggerChatTransportOptions = { > ) => Promise; - /** Base URL for the Trigger.dev API. @default "https://api.trigger.dev" */ - baseURL?: string; + /** + * Base URL for the Trigger.dev API. Either a single string applied to every + * endpoint, or a function called per request that picks a base URL from the + * endpoint discriminator and chat ID. @default "https://api.trigger.dev" + * + * @example Route appends through a proxy, SSE direct: + * ```ts + * baseURL: ({ endpoint }) => + * endpoint === "out" ? "https://api.trigger.dev" : "https://proxy.example.com", + * ``` + */ + baseURL?: string | ChatBaseURLResolver; /** * Base URL for the SSE stream subscription only (`GET .../sessions/{chatId}/out`). - * Falls back to `baseURL` when unset. Set this to route the long-lived - * stream through a custom proxy (e.g. a Cloudflare worker capturing JA4 - * fingerprints for bot detection) while keeping append POSTs direct to - * `baseURL` to avoid an extra hop on every user message. + * @deprecated Pass a function for `baseURL` instead and branch on + * `endpoint === "out"`. `streamBaseURL` continues to work for backwards + * compatibility and wins over `baseURL` for the SSE endpoint when both + * are set. */ streamBaseURL?: string; + /** + * Optional per-request fetch override. Called with the resolved URL and the + * RequestInit the transport built, plus endpoint context. Use this to + * inject custom headers (e.g. distributed tracing), redirect via a proxy, + * or wrap fetch with retries/logging. + * + * @example Add a tracing header to every chat request: + * ```ts + * fetch: (url, init, ctx) => { + * init.headers = new Headers(init.headers); + * init.headers.set("traceparent", currentTraceparent()); + * return globalThis.fetch(url, init); + * }, + * ``` + */ + fetch?: ChatFetchOverride; + /** Additional headers included in every API request. */ headers?: Record; @@ -361,8 +424,8 @@ export class TriggerChatTransport implements ChatTransport { private readonly resolveStartSession: | ((params: StartSessionParams>) => Promise) | undefined; - private readonly baseURL: string; - private readonly streamBaseURL: string; + private readonly resolveBaseURLFn: ChatBaseURLResolver; + private readonly fetchOverride: ChatFetchOverride | undefined; private readonly extraHeaders: Record; private readonly streamTimeoutSeconds: number; private defaultMetadata: Record | undefined; @@ -383,8 +446,12 @@ export class TriggerChatTransport implements ChatTransport { this.resolveStartSession = options.startSession as | ((params: StartSessionParams>) => Promise) | undefined; - this.baseURL = options.baseURL ?? DEFAULT_BASE_URL; - this.streamBaseURL = options.streamBaseURL ?? this.baseURL; + const baseURLOption = options.baseURL ?? DEFAULT_BASE_URL; + const streamOverride = options.streamBaseURL; + this.resolveBaseURLFn = typeof baseURLOption === "function" + ? (ctx) => (ctx.endpoint === "out" && streamOverride ? streamOverride : baseURLOption(ctx)) + : (ctx) => (ctx.endpoint === "out" && streamOverride ? streamOverride : baseURLOption); + this.fetchOverride = options.fetch; this.extraHeaders = options.headers ?? {}; this.streamTimeoutSeconds = options.streamTimeoutSeconds ?? DEFAULT_STREAM_TIMEOUT_SECONDS; this.defaultMetadata = options.clientData; @@ -528,10 +595,9 @@ export class TriggerChatTransport implements ChatTransport { const state = await this.ensureSessionState(chatId); const sendChatMessage = async (token: string) => { - const apiClient = new ApiClient(this.baseURL, token); - await apiClient.appendToSessionStream( + await this.appendInputChunk( chatId, - "in", + token, this.serializeInputChunk({ kind: "message", payload: wirePayload }) ); }; @@ -708,10 +774,9 @@ export class TriggerChatTransport implements ChatTransport { }; const send = async (token: string) => { - const apiClient = new ApiClient(this.baseURL, token); - await apiClient.appendToSessionStream( + await this.appendInputChunk( chatId, - "in", + token, this.serializeInputChunk({ kind: "message", payload: wirePayload }) ); }; @@ -768,12 +833,7 @@ export class TriggerChatTransport implements ChatTransport { if (!state) return false; const send = async (token: string) => { - const api = new ApiClient(this.baseURL, token); - await api.appendToSessionStream( - chatId, - "in", - this.serializeInputChunk({ kind: "stop" }) - ); + await this.appendInputChunk(chatId, token, this.serializeInputChunk({ kind: "stop" })); }; try { @@ -822,8 +882,7 @@ export class TriggerChatTransport implements ChatTransport { const body = this.serializeInputChunk({ kind: "message", payload: wirePayload }); const send = async (token: string) => { - const apiClient = new ApiClient(this.baseURL, token); - await apiClient.appendToSessionStream(chatId, "in", body); + await this.appendInputChunk(chatId, token, body); }; await this.callWithAuthRetry(chatId, state, send); @@ -978,6 +1037,41 @@ export class TriggerChatTransport implements ChatTransport { * Run `op` with the session's stored PAT. On 401/403, refresh the PAT * via `accessToken` and retry once. Surfaces non-auth errors as-is. */ + private resolveBaseURL(ctx: ChatTransportEndpointContext): string { + const raw = this.resolveBaseURLFn(ctx); + return raw.replace(/\/$/, ""); + } + + private async doFetch( + ctx: ChatTransportEndpointContext, + url: string, + init: RequestInit + ): Promise { + return this.fetchOverride ? this.fetchOverride(url, init, ctx) : fetch(url, init); + } + + private async appendInputChunk(chatId: string, token: string, body: string): Promise { + const ctx: ChatTransportEndpointContext = { endpoint: "in", chatId }; + const url = `${this.resolveBaseURL(ctx)}/realtime/v1/sessions/${encodeURIComponent(chatId)}/in/append`; + const headers: Record = { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + "x-trigger-source": "sdk", + ...this.extraHeaders, + }; + const response = await this.doFetch(ctx, url, { method: "POST", headers, body }); + if (!response.ok) { + const text = await response.text().catch(() => ""); + const err = new Error(`appendToSessionStream failed: ${response.status} ${text}`) as Error & { + name: string; + status: number; + }; + err.name = "TriggerApiError"; + err.status = response.status; + throw err; + } + } + private async callWithAuthRetry( chatId: string, state: ChatSessionState, @@ -1026,14 +1120,11 @@ export class TriggerChatTransport implements ChatTransport { () => { if (options?.sendStopOnAbort !== false) { state.skipToTurnComplete = true; - const api = new ApiClient(this.baseURL, state.publicAccessToken); - api - .appendToSessionStream( - chatId, - "in", - this.serializeInputChunk({ kind: "stop" }) - ) - .catch(() => {}); + this.appendInputChunk( + chatId, + state.publicAccessToken, + this.serializeInputChunk({ kind: "stop" }) + ).catch(() => {}); } internalAbort.abort(); }, @@ -1041,7 +1132,7 @@ export class TriggerChatTransport implements ChatTransport { ); } - const streamUrl = `${this.streamBaseURL}/realtime/v1/sessions/${encodeURIComponent(chatId)}/out`; + const streamUrl = `${this.resolveBaseURL({ endpoint: "out", chatId })}/realtime/v1/sessions/${encodeURIComponent(chatId)}/out`; return new ReadableStream({ start: async (controller) => { @@ -1099,6 +1190,31 @@ export class TriggerChatTransport implements ChatTransport { })() : () => {}; + const sseCtx: ChatTransportEndpointContext = { endpoint: "out", chatId }; + const fetchOverride = this.fetchOverride; + const sseFetchClient: typeof fetch | undefined = fetchOverride + ? ((input, init) => { + if (typeof input === "string") { + return fetchOverride(input, init ?? {}, sseCtx); + } + if (input instanceof URL) { + return fetchOverride(input.toString(), init ?? {}, sseCtx); + } + // Request — preserve its url + intrinsic init, let any + // provided init override on top (matches fetch(Request, init) + // semantics). + return fetchOverride( + input.url, + { + method: input.method, + headers: input.headers, + signal: input.signal, + ...(init ?? {}), + }, + sseCtx + ); + }) as typeof fetch + : undefined; const connectSseOnce = async (token: string) => { const subscription = new SSEStreamSubscription(streamUrl, { headers: { @@ -1113,6 +1229,7 @@ export class TriggerChatTransport implements ChatTransport { // keepalive) arrives in 60s, force reconnect. Sized // generously over typical agent thinking pauses. stallTimeoutMs: 60_000, + fetchClient: sseFetchClient, }); currentSubscription = subscription; const sseStream = await subscription.subscribe(); diff --git a/references/ai-chat/cf-worker/.gitignore b/references/ai-chat/cf-worker/.gitignore new file mode 100644 index 00000000000..8619bbe6b27 --- /dev/null +++ b/references/ai-chat/cf-worker/.gitignore @@ -0,0 +1,3 @@ +.wrangler/ +node_modules/ +*.log diff --git a/references/ai-chat/cf-worker/README.md b/references/ai-chat/cf-worker/README.md new file mode 100644 index 00000000000..8c9a733d73a --- /dev/null +++ b/references/ai-chat/cf-worker/README.md @@ -0,0 +1,33 @@ +# cf-trust-test worker + +A minimal Cloudflare Worker that demonstrates the trusted-edge-signals pattern from [`docs/ai-chat/patterns/trusted-edge-signals`](../../../docs/ai-chat/patterns/trusted-edge-signals.mdx). The worker sits in front of the Trigger.dev API, intercepts the two body-write paths (`POST /api/v1/sessions` and `POST /realtime/v1/sessions/{id}/in/append`), and injects a server-trusted `__cf` namespace into the wire payload's `metadata` field. Everything else (SSE, auth, dashboard) passes through untouched. + +Pairs with the `cfTrustTestAgent` (task id `cf-trust-test`) defined in `src/trigger/chat.ts`, which declares the `__cf` namespace in its `clientDataSchema` and echoes the values back so the round-trip is visible in the streamed response. + +## Run it + +```bash +# In references/ai-chat/cf-worker +pnpm install +pnpm run dev # serves on http://localhost:8787, proxies to TRIGGER_API_UPSTREAM +``` + +Point the Next.js reference app at the worker by setting `TRIGGER_API_URL` and `NEXT_PUBLIC_TRIGGER_API_URL` to `http://localhost:8787` in `references/ai-chat/.env`. Then start trigger-dev and Next.js as usual. + +`wrangler dev` populates `request.cf` with the developer's real Cloudflare edge metadata even in local mode; the worker falls back to hardcoded sample values if `request.cf` is unset. + +## Wire-up for `.out` SSE direct (optional) + +By default the reference app routes every request through `NEXT_PUBLIC_TRIGGER_API_URL`, so SSE also flows through the worker. To skip the worker on the long-lived `.out` channel — which gives no body-mutation benefit and adds one extra edge hop per reconnect — switch the transport's `baseURL` to the function form: + +```ts +const transport = useTriggerChatTransport({ + // ... + baseURL: ({ endpoint }) => + endpoint === "out" + ? "https://api.trigger.dev" + : process.env.NEXT_PUBLIC_TRIGGER_API_URL!, +}); +``` + +See [`docs/ai-chat/patterns/trusted-edge-signals`](../../../docs/ai-chat/patterns/trusted-edge-signals.mdx) for the full design — threat model, agent-side schema, deploy considerations. diff --git a/references/ai-chat/cf-worker/package.json b/references/ai-chat/cf-worker/package.json new file mode 100644 index 00000000000..3e1f8debe99 --- /dev/null +++ b/references/ai-chat/cf-worker/package.json @@ -0,0 +1,14 @@ +{ + "name": "cf-trust-test-worker", + "version": "0.0.0", + "private": true, + "type": "module", + "scripts": { + "dev": "wrangler dev", + "deploy": "wrangler deploy" + }, + "devDependencies": { + "@cloudflare/workers-types": "4.20240909.0", + "wrangler": "3.78.0" + } +} diff --git a/references/ai-chat/cf-worker/src/index.ts b/references/ai-chat/cf-worker/src/index.ts new file mode 100644 index 00000000000..1e449153d1e --- /dev/null +++ b/references/ai-chat/cf-worker/src/index.ts @@ -0,0 +1,124 @@ +/** + * cf-trust-test proxy. Validates that a trusted edge proxy can inject a + * namespaced metadata field (`__cf`) into trigger.dev's chat session-create + * and follow-up message wire payloads — and that the trigger.dev server passes + * it through to the agent untouched. + * + * Local dev: `wrangler dev` exposes the worker on http://localhost:8787 and + * forwards to TRIGGER_API_UPSTREAM. With `wrangler dev --remote` the worker + * runs on the CF edge and `request.cf` is populated with real signals; the + * --local default leaves request.cf undefined, so we fall back to hardcoded + * trust values that prove the plumbing without depending on a real CF edge. + */ + +export interface Env { + TRIGGER_API_UPSTREAM: string; +} + +type CfTrustData = { + botScore: number; + ja4: string; + asn: number; + country: string; +}; + +function readCfTrustData(request: Request): CfTrustData { + const cf = (request as Request & { cf?: Record }).cf; + const bm = (cf?.botManagement as Record | undefined) ?? undefined; + return { + botScore: (bm?.score as number | undefined) ?? 95, + ja4: (bm?.ja4 as string | undefined) ?? "t13d1715h2_5b57614c22b0_5c2c4ed3e2d9", + asn: (cf?.asn as number | undefined) ?? 13335, + country: (cf?.country as string | undefined) ?? "US", + }; +} + +function withCors(response: Response, request: Request): Response { + const headers = new Headers(response.headers); + const origin = request.headers.get("origin") ?? "*"; + const reqHeaders = request.headers.get("access-control-request-headers"); + headers.set("Access-Control-Allow-Origin", origin); + headers.set("Vary", "Origin"); + headers.set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, PATCH, DELETE"); + if (reqHeaders) headers.set("Access-Control-Allow-Headers", reqHeaders); + headers.set("Access-Control-Expose-Headers", "*"); + headers.set("Access-Control-Allow-Credentials", "true"); + return new Response(response.body, { status: response.status, statusText: response.statusText, headers }); +} + +function handlePreflight(request: Request): Response { + return withCors(new Response(null, { status: 204 }), request); +} + +function setCfNamespace( + metadata: Record | undefined, + cf: CfTrustData +): Record { + const stripped: Record = { ...(metadata ?? {}) }; + delete stripped.__cf; + return { ...stripped, __cf: cf }; +} + +async function rewriteSessionCreateBody(body: string, cf: CfTrustData): Promise { + const parsed = JSON.parse(body) as Record; + const triggerConfig = (parsed.triggerConfig as Record | undefined) ?? {}; + const basePayload = (triggerConfig.basePayload as Record | undefined) ?? {}; + const metadata = basePayload.metadata as Record | undefined; + parsed.triggerConfig = { + ...triggerConfig, + basePayload: { ...basePayload, metadata: setCfNamespace(metadata, cf) }, + }; + return JSON.stringify(parsed); +} + +async function rewriteAppendBody(body: string, cf: CfTrustData): Promise { + let parsed: Record; + try { + parsed = JSON.parse(body) as Record; + } catch { + return body; + } + if (parsed.kind !== "message") return body; + const payload = (parsed.payload as Record | undefined) ?? {}; + const metadata = payload.metadata as Record | undefined; + parsed.payload = { ...payload, metadata: setCfNamespace(metadata, cf) }; + return JSON.stringify(parsed); +} + +export default { + async fetch(request: Request, env: Env): Promise { + if (request.method === "OPTIONS") return handlePreflight(request); + + const upstream = new URL(env.TRIGGER_API_UPSTREAM); + const incoming = new URL(request.url); + const target = new URL(incoming.pathname + incoming.search, upstream); + + const cf = readCfTrustData(request); + const isAppend = + request.method === "POST" && + /^\/realtime\/v1\/sessions\/[^/]+\/in\/append$/.test(incoming.pathname); + const isSessionsCreate = + request.method === "POST" && incoming.pathname === "/api/v1/sessions"; + + let body: BodyInit | null = null; + if (request.method !== "GET" && request.method !== "HEAD") { + const raw = await request.text(); + if (isSessionsCreate && raw) body = await rewriteSessionCreateBody(raw, cf); + else if (isAppend && raw) body = await rewriteAppendBody(raw, cf); + else body = raw; + } + + const headers = new Headers(request.headers); + headers.delete("host"); + headers.delete("content-length"); + + const upstreamResponse = await fetch(target.toString(), { + method: request.method, + headers, + body, + redirect: "manual", + }); + + return withCors(upstreamResponse, request); + }, +}; diff --git a/references/ai-chat/cf-worker/tsconfig.json b/references/ai-chat/cf-worker/tsconfig.json new file mode 100644 index 00000000000..7d45444baef --- /dev/null +++ b/references/ai-chat/cf-worker/tsconfig.json @@ -0,0 +1,14 @@ +{ + "compilerOptions": { + "target": "es2022", + "module": "es2022", + "moduleResolution": "bundler", + "lib": ["es2022"], + "types": ["@cloudflare/workers-types"], + "strict": true, + "noEmit": true, + "esModuleInterop": true, + "skipLibCheck": true + }, + "include": ["src/**/*.ts"] +} diff --git a/references/ai-chat/cf-worker/wrangler.toml b/references/ai-chat/cf-worker/wrangler.toml new file mode 100644 index 00000000000..e62a10cc1bf --- /dev/null +++ b/references/ai-chat/cf-worker/wrangler.toml @@ -0,0 +1,10 @@ +name = "cf-trust-test-worker" +main = "src/index.ts" +compatibility_date = "2024-09-23" +compatibility_flags = ["nodejs_compat"] + +[vars] +TRIGGER_API_UPSTREAM = "https://api.trigger.dev" + +[dev] +port = 8787 diff --git a/references/ai-chat/src/app/actions.ts b/references/ai-chat/src/app/actions.ts index 0ef650cfc8c..4586aa3eeb3 100644 --- a/references/ai-chat/src/app/actions.ts +++ b/references/ai-chat/src/app/actions.ts @@ -8,6 +8,7 @@ import type { aiChatRaw, aiChatSession, upgradeTestAgent, + cfTrustTestAgent, } from "@/trigger/chat"; import type { ChatUiMessage } from "@/lib/chat-tools-schemas"; import { prisma } from "@/lib/prisma"; @@ -20,7 +21,8 @@ export type ChatReferenceTaskId = | "ai-chat-hydrated" | "ai-chat-raw" | "ai-chat-session" - | "upgrade-test"; + | "upgrade-test" + | "cf-trust-test"; function isChatReferenceTaskId(id: string): id is ChatReferenceTaskId { return ( @@ -28,7 +30,8 @@ function isChatReferenceTaskId(id: string): id is ChatReferenceTaskId { id === "ai-chat-hydrated" || id === "ai-chat-raw" || id === "ai-chat-session" || - id === "upgrade-test" + id === "upgrade-test" || + id === "cf-trust-test" ); } @@ -38,7 +41,8 @@ type TaskIdentifierForChat = | (typeof aiChatHydrated)["id"] | (typeof aiChatRaw)["id"] | (typeof aiChatSession)["id"] - | (typeof upgradeTestAgent)["id"]; + | (typeof upgradeTestAgent)["id"] + | (typeof cfTrustTestAgent)["id"]; /** * Server-mediated start: creates the Session row + triggers the first @@ -70,6 +74,7 @@ const startActionByTaskId: Record< "ai-chat-raw": startChatSessionFor("ai-chat-raw"), "ai-chat-session": startChatSessionFor("ai-chat-session"), "upgrade-test": startChatSessionFor("upgrade-test"), + "cf-trust-test": startChatSessionFor("cf-trust-test"), }; export async function startChatSession(input: { diff --git a/references/ai-chat/src/components/chat-sidebar.tsx b/references/ai-chat/src/components/chat-sidebar.tsx index e036eebc71a..9707b61ac36 100644 --- a/references/ai-chat/src/components/chat-sidebar.tsx +++ b/references/ai-chat/src/components/chat-sidebar.tsx @@ -118,6 +118,7 @@ export function ChatSidebar({ +