Skip to content

Commit

Permalink
vitest-pool-workers: Support AbortSignal in fetch-mock (#7032)
Browse files Browse the repository at this point in the history
* feat: Support AbortSignal in fetch-mock

* test: remove duplicate test
  • Loading branch information
Codex- authored Nov 27, 2024
1 parent c7361b1 commit a18ed4e
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 4 deletions.
5 changes: 5 additions & 0 deletions .changeset/hip-penguins-kiss.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@cloudflare/vitest-pool-workers": patch
---

Add support for AbortSignal to fetch-mock
100 changes: 97 additions & 3 deletions fixtures/vitest-pool-workers-examples/misc/test/fetch-mock.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
import { fetchMock } from "cloudflare:test";
import { afterEach, beforeAll, expect, it } from "vitest";
import {
afterAll,
afterEach,
beforeAll,
beforeEach,
describe,
expect,
it,
vi,
} from "vitest";
import type { MockInstance } from "vitest";

beforeAll(() => fetchMock.activate());
afterEach(() => fetchMock.assertNoPendingInterceptors());
beforeEach(() => fetchMock.activate());
afterEach(() => {
fetchMock.assertNoPendingInterceptors();
fetchMock.deactivate();
});

it("falls through to global fetch() if unmatched", async () => {
fetchMock
Expand All @@ -18,3 +31,84 @@ it("falls through to global fetch() if unmatched", async () => {
expect(response.url).toEqual("https://example.com/bad");
expect(await response.text()).toBe("fallthrough:GET https://example.com/bad");
});

describe("AbortSignal", () => {
let abortSignalTimeoutMock: MockInstance;

beforeAll(() => {
// Fake Timers does not mock AbortSignal.timeout
abortSignalTimeoutMock = vi
.spyOn(AbortSignal, "timeout")
.mockImplementation((ms: number) => {
const controller = new AbortController();
setTimeout(() => {
controller.abort();
}, ms);
return controller.signal;
});
});

afterAll(() => abortSignalTimeoutMock.mockRestore());

beforeEach(() => vi.useFakeTimers());

afterEach(() => vi.useRealTimers());

it("aborts if an AbortSignal timeout is exceeded", async () => {
fetchMock
.get("https://example.com")
.intercept({ path: "/" })
.reply(200, async () => {
await new Promise((resolve) => setTimeout(resolve, 5000));
return "Delayed response";
});

const fetchPromise = fetch("https://example.com", {
signal: AbortSignal.timeout(2000),
});

vi.advanceTimersByTime(10_000);
await expect(fetchPromise).rejects.toThrowErrorMatchingInlineSnapshot(
`[AbortError: The operation was aborted]`
);
});

it("does not abort if an AbortSignal timeout is not exceeded", async () => {
fetchMock
.get("https://example.com")
.intercept({ path: "/" })
.reply(200, async () => {
await new Promise((resolve) => setTimeout(resolve, 1000));
return "Delayed response";
});

const fetchPromise = fetch("https://example.com", {
signal: AbortSignal.timeout(2000),
});

vi.advanceTimersByTime(1500);
const response = await fetchPromise;
expect(response.status).toStrictEqual(200);
expect(await response.text()).toMatchInlineSnapshot(`"Delayed response"`);
});

it("aborts if an AbortSignal is already aborted", async () => {
const controller = new AbortController();
controller.abort();

fetchMock
.get("https://example.com")
.intercept({ path: "/" })
.reply(200, async () => {
return "Delayed response";
});

const fetchPromise = fetch("https://example.com", {
signal: controller.signal,
});

await expect(fetchPromise).rejects.toThrowErrorMatchingInlineSnapshot(
`[AbortError: The operation was aborted]`
);
});
});
37 changes: 36 additions & 1 deletion packages/vitest-pool-workers/src/worker/fetch-mock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@ import type { Dispatcher } from "undici";

const DECODER = new TextDecoder();

/**
* Mutate an Error instance so it passes either of the checks in isAbortError
*/
export function castAsAbortError(err: Error): Error {
(err as Error & { code: string }).code = "ABORT_ERR";
err.name = "AbortError";
return err;
}

// See public facing `cloudflare:test` types for docs
export const fetchMock = new MockAgent({ connections: 1 });

Expand Down Expand Up @@ -47,6 +56,13 @@ globalThis.fetch = async (input, init) => {
const request = new Request(input, init);
const url = new URL(request.url);

// Use a signal and the aborted value if provided
const abortSignal = init?.signal;
let abortSignalAborted = abortSignal?.aborted ?? false;
abortSignal?.addEventListener("abort", () => {
abortSignalAborted = true;
});

// Don't allow mocked `Upgrade` requests
if (request.headers.get("Upgrade") !== null) {
return originalFetch.call(globalThis, request);
Expand Down Expand Up @@ -101,7 +117,11 @@ globalThis.fetch = async (input, init) => {

// Dispatch the request through the mock agent
const dispatchHandlers: Dispatcher.DispatchHandlers = {
onConnect(_abort) {}, // (ignored)
onConnect(abort) {
if (abortSignalAborted) {
abort();
}
},
onError(error) {
responseReject(error);
},
Expand All @@ -110,6 +130,10 @@ globalThis.fetch = async (input, init) => {
},
// `onHeaders` and `onData` will only be called if the response was mocked
onHeaders(statusCode, headers, _resume, statusText) {
if (abortSignalAborted) {
return false;
}

responseStatusCode = statusCode;
responseStatusText = statusText;

Expand All @@ -123,10 +147,21 @@ globalThis.fetch = async (input, init) => {
return true;
},
onData(chunk) {
if (abortSignalAborted) {
return false;
}

responseChunks.push(chunk);
return true;
},
onComplete(_trailers) {
if (abortSignalAborted) {
responseReject(
castAsAbortError(new Error("The operation was aborted"))
);
return;
}

// `maybeResponse` will be `undefined` if we mocked the request
const maybeResponse = responses.get(dispatchOptions);
if (maybeResponse === undefined) {
Expand Down

0 comments on commit a18ed4e

Please sign in to comment.