From a18ed4ef14d3b2a23d6336091f37b35e59725cbc Mon Sep 17 00:00:00 2001 From: Alex Miller Date: Thu, 28 Nov 2024 02:46:40 +1300 Subject: [PATCH] vitest-pool-workers: Support AbortSignal in fetch-mock (#7032) * feat: Support AbortSignal in fetch-mock * test: remove duplicate test --- .changeset/hip-penguins-kiss.md | 5 + .../misc/test/fetch-mock.test.ts | 100 +++++++++++++++++- .../src/worker/fetch-mock.ts | 37 ++++++- 3 files changed, 138 insertions(+), 4 deletions(-) create mode 100644 .changeset/hip-penguins-kiss.md diff --git a/.changeset/hip-penguins-kiss.md b/.changeset/hip-penguins-kiss.md new file mode 100644 index 000000000000..a3b00ccd25cf --- /dev/null +++ b/.changeset/hip-penguins-kiss.md @@ -0,0 +1,5 @@ +--- +"@cloudflare/vitest-pool-workers": patch +--- + +Add support for AbortSignal to fetch-mock diff --git a/fixtures/vitest-pool-workers-examples/misc/test/fetch-mock.test.ts b/fixtures/vitest-pool-workers-examples/misc/test/fetch-mock.test.ts index 9f32a61485d4..9c181bcb2f50 100644 --- a/fixtures/vitest-pool-workers-examples/misc/test/fetch-mock.test.ts +++ b/fixtures/vitest-pool-workers-examples/misc/test/fetch-mock.test.ts @@ -1,8 +1,21 @@ import { fetchMock } from "cloudflare:test"; -import { afterEach, beforeAll, expect, it } from "vitest"; +import { + afterAll, + afterEach, + beforeAll, + beforeEach, + describe, + expect, + it, + vi, +} from "vitest"; +import type { MockInstance } from "vitest"; -beforeAll(() => fetchMock.activate()); -afterEach(() => fetchMock.assertNoPendingInterceptors()); +beforeEach(() => fetchMock.activate()); +afterEach(() => { + fetchMock.assertNoPendingInterceptors(); + fetchMock.deactivate(); +}); it("falls through to global fetch() if unmatched", async () => { fetchMock @@ -18,3 +31,84 @@ it("falls through to global fetch() if unmatched", async () => { expect(response.url).toEqual("https://example.com/bad"); expect(await response.text()).toBe("fallthrough:GET https://example.com/bad"); }); + +describe("AbortSignal", () => { + let abortSignalTimeoutMock: MockInstance; + + beforeAll(() => { + // Fake Timers does not mock AbortSignal.timeout + abortSignalTimeoutMock = vi + .spyOn(AbortSignal, "timeout") + .mockImplementation((ms: number) => { + const controller = new AbortController(); + setTimeout(() => { + controller.abort(); + }, ms); + return controller.signal; + }); + }); + + afterAll(() => abortSignalTimeoutMock.mockRestore()); + + beforeEach(() => vi.useFakeTimers()); + + afterEach(() => vi.useRealTimers()); + + it("aborts if an AbortSignal timeout is exceeded", async () => { + fetchMock + .get("https://example.com") + .intercept({ path: "/" }) + .reply(200, async () => { + await new Promise((resolve) => setTimeout(resolve, 5000)); + return "Delayed response"; + }); + + const fetchPromise = fetch("https://example.com", { + signal: AbortSignal.timeout(2000), + }); + + vi.advanceTimersByTime(10_000); + await expect(fetchPromise).rejects.toThrowErrorMatchingInlineSnapshot( + `[AbortError: The operation was aborted]` + ); + }); + + it("does not abort if an AbortSignal timeout is not exceeded", async () => { + fetchMock + .get("https://example.com") + .intercept({ path: "/" }) + .reply(200, async () => { + await new Promise((resolve) => setTimeout(resolve, 1000)); + return "Delayed response"; + }); + + const fetchPromise = fetch("https://example.com", { + signal: AbortSignal.timeout(2000), + }); + + vi.advanceTimersByTime(1500); + const response = await fetchPromise; + expect(response.status).toStrictEqual(200); + expect(await response.text()).toMatchInlineSnapshot(`"Delayed response"`); + }); + + it("aborts if an AbortSignal is already aborted", async () => { + const controller = new AbortController(); + controller.abort(); + + fetchMock + .get("https://example.com") + .intercept({ path: "/" }) + .reply(200, async () => { + return "Delayed response"; + }); + + const fetchPromise = fetch("https://example.com", { + signal: controller.signal, + }); + + await expect(fetchPromise).rejects.toThrowErrorMatchingInlineSnapshot( + `[AbortError: The operation was aborted]` + ); + }); +}); diff --git a/packages/vitest-pool-workers/src/worker/fetch-mock.ts b/packages/vitest-pool-workers/src/worker/fetch-mock.ts index 993523996091..e0b89b0b995a 100644 --- a/packages/vitest-pool-workers/src/worker/fetch-mock.ts +++ b/packages/vitest-pool-workers/src/worker/fetch-mock.ts @@ -5,6 +5,15 @@ import type { Dispatcher } from "undici"; const DECODER = new TextDecoder(); +/** + * Mutate an Error instance so it passes either of the checks in isAbortError + */ +export function castAsAbortError(err: Error): Error { + (err as Error & { code: string }).code = "ABORT_ERR"; + err.name = "AbortError"; + return err; +} + // See public facing `cloudflare:test` types for docs export const fetchMock = new MockAgent({ connections: 1 }); @@ -47,6 +56,13 @@ globalThis.fetch = async (input, init) => { const request = new Request(input, init); const url = new URL(request.url); + // Use a signal and the aborted value if provided + const abortSignal = init?.signal; + let abortSignalAborted = abortSignal?.aborted ?? false; + abortSignal?.addEventListener("abort", () => { + abortSignalAborted = true; + }); + // Don't allow mocked `Upgrade` requests if (request.headers.get("Upgrade") !== null) { return originalFetch.call(globalThis, request); @@ -101,7 +117,11 @@ globalThis.fetch = async (input, init) => { // Dispatch the request through the mock agent const dispatchHandlers: Dispatcher.DispatchHandlers = { - onConnect(_abort) {}, // (ignored) + onConnect(abort) { + if (abortSignalAborted) { + abort(); + } + }, onError(error) { responseReject(error); }, @@ -110,6 +130,10 @@ globalThis.fetch = async (input, init) => { }, // `onHeaders` and `onData` will only be called if the response was mocked onHeaders(statusCode, headers, _resume, statusText) { + if (abortSignalAborted) { + return false; + } + responseStatusCode = statusCode; responseStatusText = statusText; @@ -123,10 +147,21 @@ globalThis.fetch = async (input, init) => { return true; }, onData(chunk) { + if (abortSignalAborted) { + return false; + } + responseChunks.push(chunk); return true; }, onComplete(_trailers) { + if (abortSignalAborted) { + responseReject( + castAsAbortError(new Error("The operation was aborted")) + ); + return; + } + // `maybeResponse` will be `undefined` if we mocked the request const maybeResponse = responses.get(dispatchOptions); if (maybeResponse === undefined) {