Skip to content

Commit

Permalink
Fixes #14345 (#14374)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jarred-Sumner authored Oct 5, 2024
1 parent 6ca68ca commit 65a6803
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/bun.js/bindings/webcore/JSWebSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ static inline JSC::EncodedJSValue constructJSWebSocket3(JSGlobalObject* lexicalG
int rejectUnauthorized = -1;
auto headersInit = std::optional<Converter<IDLUnion<IDLSequence<IDLSequence<IDLByteString>>, IDLRecord<IDLByteString, IDLByteString>>>::ReturnType>();
if (JSC::JSObject* options = optionsObjectValue.getObject()) {
if (JSValue headersValue = options->getIfPropertyExists(globalObject, PropertyName(Identifier::fromString(vm, "headers"_s)))) {
const auto& builtinnames = WebCore::builtinNames(vm);
if (JSValue headersValue = options->getIfPropertyExists(globalObject, builtinnames.headersPublicName())) {
if (!headersValue.isUndefinedOrNull()) {
headersInit = convert<IDLUnion<IDLSequence<IDLSequence<IDLByteString>>, IDLRecord<IDLByteString, IDLByteString>>>(*lexicalGlobalObject, headersValue);
RETURN_IF_EXCEPTION(throwScope, {});
Expand Down
100 changes: 98 additions & 2 deletions src/js/thirdparty/ws.js
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class BunWebSocket extends EventEmitter {
#paused = false;
#fragments = false;
#binaryType = "nodebuffer";
static [Symbol.toStringTag] = "WebSocket";

// Bitset to track whether event handlers are set.
#eventId = 0;
Expand All @@ -69,9 +70,104 @@ class BunWebSocket extends EventEmitter {
if (!WebSocket) {
WebSocket = $cpp("JSWebSocket.cpp", "getWebSocketConstructor");
}
let ws = (this.#ws = new WebSocket(url, protocols));

if (protocols === undefined) {
protocols = [];
} else if (!Array.isArray(protocols)) {
if (typeof protocols === "object" && protocols !== null) {
options = protocols;
protocols = [];
} else {
protocols = [protocols];
}
}

let headers;
let method = "GET";
// https://github.com/websockets/ws/blob/0d1b5e6c4acad16a6b1a1904426eb266a5ba2f72/lib/websocket.js#L741-L747
if ($isObject(options)) {
headers = options?.headers;
}

const finishRequest = options?.finishRequest;
if ($isCallable(finishRequest)) {
if (headers) {
headers = {
__proto__: null,
...headers,
};
}
let lazyRawHeaders;
let didCallEnd = false;
const nodeHttpClientRequestSimulated = {
__proto__: Object.create(EventEmitter.prototype),
setHeader: function (name, value) {
if (!headers) headers = Object.create(null);
headers[name.toLowerCase()] = value;
},
getHeader: function (name) {
return headers ? headers[name.toLowerCase()] : undefined;
},
removeHeader: function (name) {
if (headers) delete headers[name.toLowerCase()];
},
getHeaders: function () {
return { ...headers };
},
hasHeader: function (name) {
return headers ? name.toLowerCase() in headers : false;
},
headersSent: false,
method: method,
path: url,
abort: function () {
// No-op for now, as we don't have a real request to abort
},
end: () => {
if (!didCallEnd) {
didCallEnd = true;
this.#createWebSocket(url, protocols, headers, method);
}
},
write() {},
writeHead() {},
[Symbol.toStringTag]: "ClientRequest",
get rawHeaders() {
if (lazyRawHeaders === undefined) {
lazyRawHeaders = [];
for (const key in headers) {
lazyRawHeaders.push(key, headers[key]);
}
}
return lazyRawHeaders;
},
set rawHeaders(value) {
lazyRawHeaders = value;
},
rawTrailers: [],
trailers: null,
finished: false,
socket: undefined,
_header: null,
_headerSent: false,
_last: null,
};
EventEmitter.$call(nodeHttpClientRequestSimulated);
finishRequest(nodeHttpClientRequestSimulated);
if (!didCallEnd) {
this.#createWebSocket(url, protocols, headers, method);
}
return;
}

this.#createWebSocket(url, protocols, headers, method);
}

#createWebSocket(url, protocols, headers, method) {
let ws = (this.#ws = new WebSocket(url, headers ? { headers, method, protocols } : protocols));
ws.binaryType = "nodebuffer";
// TODO: options

return ws;
}

#onOrOnce(event, listener, once) {
Expand Down
41 changes: 41 additions & 0 deletions test/js/first_party/ws/ws.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,47 @@ it("close event", async () => {
wss.close();
});

// https://github.com/oven-sh/bun/issues/14345
it("WebSocket finishRequest mocked", async () => {
const { promise, resolve, reject } = Promise.withResolvers();

using server = Bun.serve({
port: 0,
websocket: {
open() {},
close() {},
message() {},
},
fetch(req, server) {
expect(req.headers.get("X-Custom-Header")).toBe("CustomValue");
expect(req.headers.get("Another-Header")).toBe("AnotherValue");
return server.upgrade(req);
},
});

const customHeaders = {
"X-Custom-Header": "CustomValue",
"Another-Header": "AnotherValue",
};

const ws = new WebSocket(server.url, [], {
finishRequest: req => {
Object.entries(customHeaders).forEach(([key, value]) => {
req.setHeader(key, value);
});
req.end();
},
});

ws.once("open", () => {
ws.send("Hello");
ws.close();
resolve();
});

await promise;
});

function test(label: string, fn: (ws: WebSocket, done: (err?: unknown) => void) => void, timeout?: number) {
it(
label,
Expand Down

0 comments on commit 65a6803

Please sign in to comment.