Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supporting streaming API calls to callable functions #1629

Merged
merged 13 commits into from
Nov 7, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- Add support for callable function to return streaming response (#1629)
93 changes: 74 additions & 19 deletions spec/common/providers/https.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,33 @@ async function runCallableTest(test: CallTest): Promise<any> {
cors: { origin: true, methods: "POST" },
...test.callableOption,
};
const callableFunctionV1 = https.onCallHandler(opts, (data, context) => {
expect(data).to.deep.equal(test.expectedData);
return test.callableFunction(data, context);
});
const callableFunctionV1 = https.onCallHandler(
opts,
(data, context) => {
expect(data).to.deep.equal(test.expectedData);
return test.callableFunction(data, context);
},
"gcfv1"
);

const responseV1 = await runHandler(callableFunctionV1, test.httpRequest);

expect(responseV1.body).to.deep.equal(test.expectedHttpResponse.body);
expect(responseV1.body).to.deep.equal(JSON.stringify(test.expectedHttpResponse.body));
expect(responseV1.headers).to.deep.equal(test.expectedHttpResponse.headers);
expect(responseV1.status).to.equal(test.expectedHttpResponse.status);

const callableFunctionV2 = https.onCallHandler(opts, (request) => {
expect(request.data).to.deep.equal(test.expectedData);
return test.callableFunction2(request);
});
const callableFunctionV2 = https.onCallHandler(
opts,
(request) => {
expect(request.data).to.deep.equal(test.expectedData);
return test.callableFunction2(request);
},
"gcfv2"
);

const responseV2 = await runHandler(callableFunctionV2, test.httpRequest);

expect(responseV2.body).to.deep.equal(test.expectedHttpResponse.body);
expect(responseV2.body).to.deep.equal(JSON.stringify(test.expectedHttpResponse.body));
expect(responseV2.headers).to.deep.equal(test.expectedHttpResponse.headers);
expect(responseV2.status).to.equal(test.expectedHttpResponse.status);
}
Expand Down Expand Up @@ -165,7 +173,7 @@ describe("onCallHandler", () => {
status: 400,
headers: expectedResponseHeaders,
body: {
error: { status: "INVALID_ARGUMENT", message: "Bad Request" },
error: { message: "Bad Request", status: "INVALID_ARGUMENT" },
},
},
});
Expand Down Expand Up @@ -203,7 +211,7 @@ describe("onCallHandler", () => {
status: 400,
headers: expectedResponseHeaders,
body: {
error: { status: "INVALID_ARGUMENT", message: "Bad Request" },
error: { message: "Bad Request", status: "INVALID_ARGUMENT" },
},
},
});
Expand All @@ -225,7 +233,7 @@ describe("onCallHandler", () => {
status: 400,
headers: expectedResponseHeaders,
body: {
error: { status: "INVALID_ARGUMENT", message: "Bad Request" },
error: { message: "Bad Request", status: "INVALID_ARGUMENT" },
},
},
});
Expand All @@ -244,7 +252,7 @@ describe("onCallHandler", () => {
expectedHttpResponse: {
status: 500,
headers: expectedResponseHeaders,
body: { error: { status: "INTERNAL", message: "INTERNAL" } },
body: { error: { message: "INTERNAL", status: "INTERNAL" } },
},
});
});
Expand All @@ -262,7 +270,7 @@ describe("onCallHandler", () => {
expectedHttpResponse: {
status: 500,
headers: expectedResponseHeaders,
body: { error: { status: "INTERNAL", message: "INTERNAL" } },
body: { error: { message: "INTERNAL", status: "INTERNAL" } },
},
});
});
Expand All @@ -280,7 +288,7 @@ describe("onCallHandler", () => {
expectedHttpResponse: {
status: 404,
headers: expectedResponseHeaders,
body: { error: { status: "NOT_FOUND", message: "i am error" } },
body: { error: { message: "i am error", status: "NOT_FOUND" } },
},
});
});
Expand Down Expand Up @@ -364,8 +372,8 @@ describe("onCallHandler", () => {
headers: expectedResponseHeaders,
body: {
error: {
status: "UNAUTHENTICATED",
message: "Unauthenticated",
status: "UNAUTHENTICATED",
},
},
},
Expand All @@ -391,8 +399,8 @@ describe("onCallHandler", () => {
headers: expectedResponseHeaders,
body: {
error: {
status: "UNAUTHENTICATED",
message: "Unauthenticated",
status: "UNAUTHENTICATED",
},
},
},
Expand Down Expand Up @@ -461,8 +469,8 @@ describe("onCallHandler", () => {
headers: expectedResponseHeaders,
body: {
error: {
status: "UNAUTHENTICATED",
message: "Unauthenticated",
status: "UNAUTHENTICATED",
},
},
},
Expand Down Expand Up @@ -748,6 +756,53 @@ describe("onCallHandler", () => {
});
});
});

describe("Streaming callables", () => {
it("returns data in SSE format for requests Accept: text/event-stream header", async () => {
const mockReq = mockRequest(
{ message: "hello streaming" },
"application/json",
{},
{ accept: "text/event-stream" }
) as any;
const fn = https.onCallHandler(
{
cors: { origin: true, methods: "POST" },
},
(req, resp) => {
resp.write("hello");
return "world";
},
"gcfv2"
);

const resp = await runHandler(fn, mockReq);
const data = [`data: {"message":"hello"}`, `data: {"result":"world"}`];
expect(resp.body).to.equal([...data, ""].join("\n"));
});

it("returns error in SSE format", async () => {
const mockReq = mockRequest(
{ message: "hello streaming" },
"application/json",
{},
{ accept: "text/event-stream" }
) as any;
const fn = https.onCallHandler(
{
cors: { origin: true, methods: "POST" },
},
() => {
throw new Error("BOOM");
},
"gcfv2"
);

const resp = await runHandler(fn, mockReq);
const data = [`data: {"error":{"message":"INTERNAL","status":"INTERNAL"}}`];
expect(resp.body).to.equal([...data, ""].join("\n"));
});
});
});

describe("encoding/decoding", () => {
Expand Down
10 changes: 9 additions & 1 deletion spec/helper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export function runHandler(
// MockResponse mocks an express.Response.
// This class lives here so it can reference resolve and reject.
class MockResponse {
private sentBody = "";
private statusCode = 0;
private headers: { [name: string]: string } = {};
private callback: () => void;
Expand All @@ -65,7 +66,10 @@ export function runHandler(
return this.headers[name];
}

public send(body: any) {
public send(sendBody: any) {
const toSend = typeof sendBody === "object" ? JSON.stringify(sendBody) : sendBody;
const body = this.sentBody ? this.sentBody + ((toSend as string) || "") : toSend;

resolve({
status: this.statusCode,
headers: this.headers,
Expand All @@ -76,6 +80,10 @@ export function runHandler(
}
}

public write(writeBody: any) {
this.sentBody += typeof writeBody === "object" ? JSON.stringify(writeBody) : writeBody;
}

public end() {
this.send(undefined);
}
Expand Down
2 changes: 1 addition & 1 deletion spec/v1/providers/https.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ describe("callable CORS", () => {
const response = await runHandler(func, req as any);

expect(response.status).to.equal(200);
expect(response.body).to.be.deep.equal({ result: 42 });
expect(response.body).to.be.deep.equal(JSON.stringify({ result: 42 }));
expect(response.headers).to.deep.equal(expectedResponseHeaders);
});
});
4 changes: 2 additions & 2 deletions spec/v2/providers/https.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ describe("onCall", () => {
req.method = "POST";

const resp = await runHandler(func, req as any);
expect(resp.body).to.deep.equal({ result: 42 });
expect(resp.body).to.deep.equal(JSON.stringify({ result: 42 }));
});

it("should enforce CORS options", async () => {
Expand Down Expand Up @@ -496,7 +496,7 @@ describe("onCall", () => {
const response = await runHandler(func, req as any);

expect(response.status).to.equal(200);
expect(response.body).to.be.deep.equal({ result: 42 });
expect(response.body).to.be.deep.equal(JSON.stringify({ result: 42 }));
expect(response.headers).to.deep.equal(expectedResponseHeaders);
});

Expand Down
64 changes: 53 additions & 11 deletions src/common/providers/https.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,15 @@ export interface CallableRequest<T = any> {
rawRequest: Request;
}

/**
* CallableProxyResponse exposes subset of express.Response object
* to allow writing partial, streaming responses back to the client.
*/
export interface CallableProxyResponse {
write: express.Response["write"];
acceptsStreaming: boolean;
}

/**
* The set of Firebase Functions status codes. The codes are the same at the
* ones exposed by {@link https://github.com/grpc/grpc/blob/master/doc/statuscodes.md | gRPC}.
Expand Down Expand Up @@ -673,7 +682,10 @@ async function checkAppCheckToken(
}

type v1CallableHandler = (data: any, context: CallableContext) => any | Promise<any>;
type v2CallableHandler<Req, Res> = (request: CallableRequest<Req>) => Res;
type v2CallableHandler<Req, Res> = (
request: CallableRequest<Req>,
response?: CallableProxyResponse
) => Res;

/** @internal **/
export interface CallableOptions {
Expand All @@ -685,9 +697,10 @@ export interface CallableOptions {
/** @internal */
export function onCallHandler<Req = any, Res = any>(
options: CallableOptions,
handler: v1CallableHandler | v2CallableHandler<Req, Res>
handler: v1CallableHandler | v2CallableHandler<Req, Res>,
version: "gcfv1" | "gcfv2"
): (req: Request, res: express.Response) => Promise<void> {
const wrapped = wrapOnCallHandler(options, handler);
const wrapped = wrapOnCallHandler(options, handler, version);
return (req: Request, res: express.Response) => {
return new Promise((resolve) => {
res.on("finish", resolve);
Expand All @@ -698,10 +711,15 @@ export function onCallHandler<Req = any, Res = any>(
};
}

function encodeSSE(data: unknown): string {
return `data: ${JSON.stringify(data)}\n`;
}

/** @internal */
function wrapOnCallHandler<Req = any, Res = any>(
options: CallableOptions,
handler: v1CallableHandler | v2CallableHandler<Req, Res>
handler: v1CallableHandler | v2CallableHandler<Req, Res>,
version: "gcfv1" | "gcfv2"
): (req: Request, res: express.Response) => Promise<void> {
return async (req: Request, res: express.Response): Promise<void> => {
try {
Expand All @@ -719,7 +737,7 @@ function wrapOnCallHandler<Req = any, Res = any>(
// The original monkey-patched code lived in the functionsEmulatorRuntime
// (link: https://github.com/firebase/firebase-tools/blob/accea7abda3cc9fa6bb91368e4895faf95281c60/src/emulator/functionsEmulatorRuntime.ts#L480)
// and was not compatible with how monorepos separate out packages (see https://github.com/firebase/firebase-tools/issues/5210).
if (isDebugFeatureEnabled("skipTokenVerification") && handler.length === 2) {
if (isDebugFeatureEnabled("skipTokenVerification") && version === "gcfv1") {
const authContext = context.rawRequest.header(CALLABLE_AUTH_HEADER);
if (authContext) {
logger.debug("Callable functions auth override", {
Expand Down Expand Up @@ -763,26 +781,47 @@ function wrapOnCallHandler<Req = any, Res = any>(
context.instanceIdToken = req.header("Firebase-Instance-ID-Token");
}

const acceptsStreaming = req.header("accept") === "text/event-stream";
const data: Req = decode(req.body.data);
let result: Res;
if (handler.length === 2) {
result = await handler(data, context);
if (version === "gcfv1") {
result = await (handler as v1CallableHandler)(data, context);
} else {
const arg: CallableRequest<Req> = {
...context,
data,
};
// TODO: set up optional heartbeat
const responseProxy: CallableProxyResponse = {
write(chunk): boolean {
if (acceptsStreaming) {
const formattedData = encodeSSE({ message: chunk });
return res.write(formattedData);
}
// if client doesn't accept sse-protocol, response.write() is no-op.
},
acceptsStreaming,
};
if (acceptsStreaming) {
// SSE always responds with 200
res.status(200);
}
// For some reason the type system isn't picking up that the handler
// is a one argument function.
result = await (handler as any)(arg);
result = await (handler as any)(arg, responseProxy);
}

// Encode the result as JSON to preserve types like Dates.
result = encode(result);

// If there was some result, encode it in the body.
const responseBody: HttpResponseBody = { result };
res.status(200).send(responseBody);
if (acceptsStreaming) {
res.write(encodeSSE(responseBody));
res.end();
} else {
res.status(200).send(responseBody);
}
} catch (err) {
let httpErr = err;
if (!(err instanceof HttpsError)) {
Expand All @@ -793,8 +832,11 @@ function wrapOnCallHandler<Req = any, Res = any>(

const { status } = httpErr.httpErrorCode;
const body = { error: httpErr.toJSON() };

res.status(status).send(body);
if (req.header("accept") === "text/event-stream") {
res.send(encodeSSE(body));
} else {
res.status(status).send(body);
}
}
};
}
8 changes: 4 additions & 4 deletions src/v1/providers/https.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,8 @@ export function _onCallWithOptions(
handler: (data: any, context: CallableContext) => any | Promise<any>,
options: DeploymentOptions
): HttpsFunction & Runnable<any> {
// onCallHandler sniffs the function length of the passed-in callback
// and the user could have only tried to listen to data. Wrap their handler
// in another handler to avoid accidentally triggering the v2 API
// fix the length of handler to make the call to handler consistent
// in the onCallHandler
const fixedLen = (data: any, context: CallableContext) => {
return withInit(handler)(data, context);
};
Expand All @@ -115,7 +114,8 @@ export function _onCallWithOptions(
consumeAppCheckToken: options.consumeAppCheckToken,
cors: { origin: true, methods: "POST" },
},
fixedLen
fixedLen,
"gcfv1"
)
);

Expand Down
Loading
Loading