From 247d1ac4a83d00cb262f286fd9a57ae45f4ae4b0 Mon Sep 17 00:00:00 2001 From: Dave Voutila Date: Fri, 1 Dec 2023 09:55:14 -0500 Subject: [PATCH] refactor to address underwrites/underreads --- client_test.c | 5 +- dws.c | 262 +++++++++++++++++++++++++++++++--------------- dws.h | 2 +- go-test/server.go | 1 + 4 files changed, 184 insertions(+), 86 deletions(-) diff --git a/client_test.c b/client_test.c index 5e3080d..5d1a860 100644 --- a/client_test.c +++ b/client_test.c @@ -98,12 +98,11 @@ main(int argc, char **argv) printf("received payload of " SSIZE_T_PARAM " bytes:\n---\n%s\n---\n", len, out); - dumb_close(&ws); - //assert(0 == dumb_close(&ws)); + assert(0 == dumb_close(&ws)); printf("sent a CLOSE frame!\n"); // Our socket should be closed now - assert(-2 == dumb_recv(&ws, buf, sizeof(buf))); + assert(-1 == dumb_recv(&ws, buf, sizeof(buf))); printf("socket looks closed!\n"); return 0; diff --git a/dws.c b/dws.c index 9a0e7e7..c9f03fd 100644 --- a/dws.c +++ b/dws.c @@ -42,6 +42,8 @@ #include "dws.h" +#define MIN(a,b) (((a)<(b))?(a):(b)) + // It's ludicrous to think we'd have a server handshake response larger #define HANDSHAKE_BUF_SIZE 1024 @@ -63,6 +65,10 @@ static int rng_initialized = 0; static const char B64[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +static ssize_t ws_read(struct websocket *, void *, size_t); +static ssize_t ws_read_txt(struct websocket *, void *, size_t); + + static void __attribute__((noreturn)) crap(int code, const char *fmt, ...) { @@ -163,66 +169,136 @@ dumb_mask(uint8_t mask[4]) } /* - * Safely read as much as we can into a given buffer since it may take - * multiple read(2)/tls_read(3) calls. + * Safely read at most `n` bytes into the given buffer. */ static ssize_t ws_read(struct websocket *ws, void *buf, size_t buflen) { - ssize_t ret = 0; - char *_buf = (char*) buf; + ssize_t _buflen, sz, len = 0; + char *_buf; if (buflen > INT_MAX) crap(1, "ws_read: buflen too large"); - if (ws->ctx) { - do { - ret = tls_read(ws->ctx, _buf, buflen); - } while (ret == TLS_WANT_POLLIN); - } else { - ret = read(ws->s, _buf, buflen); + _buf = (char*) buf; + _buflen = (ssize_t) buflen; + + while (_buflen > 0) { + if (ws->ctx) { + sz = tls_read(ws->ctx, _buf, _buflen); + if (sz == TLS_WANT_POLLIN) + continue; + if (sz < 0) + crap(1, "tls_read: %s", tls_error(ws->ctx)); + } else { + sz = read(ws->s, _buf, _buflen); + if (sz == -1 && errno == EAGAIN) + continue; + else if (sz == -1) + return -1; + } + + _buf += sz; + _buflen -= sz; + len += sz; } // TODO: figure out how we want to handle errors... // win32 spits out a different error than posix systems, btw. - return ret; + return len; +} + +/* + * Read up to buflen bytes into buf, looking for `\r\n` terminators. + */ +static ssize_t +ws_read_txt(struct websocket *ws, void *buf, size_t buflen) +{ + ssize_t _buflen, sz, len = 0; + char *_buf, *end = NULL; + + if (buflen > INT_MAX) + crap(1, "ws_read: buflen too large"); + + _buf = (char*) buf; + _buflen = (ssize_t) buflen; + + while (_buflen > 0) { + if (ws->ctx) { + sz = tls_read(ws->ctx, _buf, _buflen); + if (sz == TLS_WANT_POLLIN) + continue; + if (sz < 0) + crap(1, "tls_read: %s", tls_error(ws->ctx)); + } else { + sz = read(ws->s, _buf, _buflen); + if (sz == -1 && errno == EAGAIN) + continue; + else if (sz == -1) + return -1; + } + + _buf += sz; + _buflen -= sz; + len += sz; + + if (len >= 4) { + // Look for terminator pattern. + end = _buf - 4; + if (memcmp(end, "\r\n\r\n", 4) == 0) + break; + } + } + + // TODO: figure out how we want to handle errors... + // win32 spits out a different error than posix systems, btw. + + return len; } /* * Safely write the given buf up to buflen via the socket. * + * Will write the entirety of the given buffer. Does not currently use any + * poll like functionality, so will busy poll the socket! + * * XXX: for now failures to write are fatal :X */ static ssize_t -ws_write(struct websocket *ws, void *buf, size_t buflen) +ws_write(struct websocket *ws, const void *buf, size_t buflen) { - ssize_t _buflen, ret, len = 0; + ssize_t _buflen, sz, len = 0; char *_buf; if (buflen > INT_MAX) crap(1, "%s: buflen too large", __func__); + if (buflen == 0) + return 0; _buf = (char *)buf; _buflen = (ssize_t) buflen; - while (_buflen > 0) { - if (ws->ctx) { - ret = tls_write(ws->ctx, _buf, (size_t) _buflen); - if (ret == TLS_WANT_POLLOUT) + while (_buflen > 0) { + if (ws->ctx) { + sz = tls_write(ws->ctx, _buf, (size_t) _buflen); + if (sz == TLS_WANT_POLLOUT) continue; - if (ret < 0) + if (sz < 0) crap(1, "tls_write: %s", tls_error(ws->ctx)); - } else { - ret = write(ws->s, _buf, (size_t) _buflen); - if (ret < 0) - crap(1, "write: %s", strerror(errno)); - } + } else { + sz = write(ws->s, _buf, (size_t) _buflen); + printf("%s: wrote %zd bytes\n", __func__, sz); + if (sz == -1 && errno == EAGAIN) + continue; + else if (sz == -1) + return -1; + } - _buf += ret; - _buflen -= ret; - len += ret; - } + _buf += sz; + _buflen -= sz; + len += sz; + } return len; } @@ -324,7 +400,7 @@ dump_frame(uint8_t *frame, size_t len) * */ static ssize_t -dumb_frame(uint8_t *frame, uint8_t *data, size_t len) +dumb_frame(uint8_t *frame, const uint8_t *data, size_t len) { int i; ssize_t header_len; @@ -372,19 +448,23 @@ dumb_handshake(struct websocket *ws, const char *host, const char *path, { int len, ret = 0; char key[25], buf[HANDSHAKE_BUF_SIZE]; + ssize_t sz = 0; memset(key, 0, sizeof(key)); dumb_key(key); len = snprintf(buf, sizeof(buf), HANDSHAKE_TEMPLATE, - path, host, key, proto); + path, host, key, proto); if (len < 1) return -1; - ws_write(ws, buf, (size_t) len); + // Send our upgrade request. + sz = ws_write(ws, buf, len); + if (sz != len) + crap(1, "dumb_handshake"); memset(buf, 0, sizeof(buf)); - ws_read(ws, buf, sizeof(buf)); + ws_read_txt(ws, buf, sizeof(buf)); /* XXX: If we gave a crap, we'd validate the returned key per the * requirements of RFC6455 sec. 4.1, but we don't. @@ -507,23 +587,19 @@ dumb_connect_tls(struct websocket *ws, char *host, char *port, int insecure) * Returns: * the amount of bytes sent, * -1 on failure to calloc(3) a buffer for the dumb websocket frame, - * or whatever send(2) might return on error (zero or a negative value) + * or whatever ws_write might return on error (zero or a negative value) */ ssize_t -dumb_send(struct websocket *ws, void *payload, size_t len) +dumb_send(struct websocket *ws, const void *payload, size_t len) { uint8_t *frame; - uint8_t mask[4]; ssize_t frame_len, n; // We need payload size + 14 bytes minimum, but pad a little extra - frame = calloc(sizeof(uint8_t), len + 16); + frame = calloc(1, len + 16); if (frame == NULL) return -1; - memset(mask, 0, sizeof(mask)); - dumb_mask(mask); - frame_len = dumb_frame(frame, payload, len); if (frame_len < 0) crap(1, "%s: invalid frame payload length", __func__); @@ -550,51 +626,46 @@ dumb_send(struct websocket *ws, void *payload, size_t len) * * Returns: * the number of bytes received in the payload (not including frame headers), - * -1 on failure to calloc(3) memory for a receive buffer, - * -2 on failure to read(2) data, - * -3 if the frame was sent fractured (unsupported right now!) + * -1 on failure to read(2) data, */ ssize_t -dumb_recv(struct websocket *ws, void *out, size_t len) +dumb_recv(struct websocket *ws, void *buf, size_t buflen) { - uint8_t *frame; + uint8_t frame[4] = { 0 }; ssize_t payload_len; - ssize_t offset = 0, n = 0; + ssize_t n = 0; - frame = calloc(sizeof(uint8_t), len + FRAME_MAX_HEADER_SIZE + 1); - if (frame == NULL) + // Read first 2 bytes to figure out the framing details. + n = ws_read(ws, frame, 2); + if (n < 2) return -1; - // TODO: handle under reads. - n = ws_read(ws, frame, len + FRAME_MAX_HEADER_SIZE + 1); - if (n < 1) { - free(frame); - return -2; - } - // Now to validate the frame... if (!(frame[0] & 0x80)) { // XXX: We don't currently fragmentation - free(frame); - return -3; + crap(1, "dumb_recv: fragmentation unsupported"); } payload_len = frame[1] & 0x7F; - if (payload_len < 126) { - offset = 2; - } else if (payload_len == 126) { - // arrives in network byte order + if (payload_len == 126) { + // Need the next two bytes to get the actual payload size, which + // arrives in network byte order. + n = ws_read(ws, frame + 2, 2); + if (n < 2) + return -1; payload_len = frame[2] << 8; payload_len += frame[3]; - offset = 4; - } else { - free(frame); + } else if (payload_len > 126) crap(1, "%s: unsupported payload size", __func__); - } - memcpy(out, frame + offset, (size_t) payload_len); + // We can now read the the payload. + payload_len = MIN(payload_len, buflen); + if (payload_len == 0) + return payload_len; + n = ws_read(ws, buf, (size_t) payload_len); + if (n < payload_len) + return -1; - free(frame); return payload_len; } @@ -616,9 +687,9 @@ dumb_recv(struct websocket *ws, void *out, size_t len) int dumb_ping(struct websocket *ws) { - ssize_t len; + ssize_t len, payload_len; uint8_t mask[4]; - uint8_t frame[64]; + uint8_t frame[128]; memset(frame, 0, sizeof(frame)); dumb_mask(mask); @@ -631,17 +702,26 @@ dumb_ping(struct websocket *ws) memset(frame, 0, sizeof(frame)); - len = ws_read(ws, frame, sizeof(frame)); - if (len < 1) + // Read first 2 bytes. + len = ws_read(ws, frame, 2); + if (len != 2) return -2; -#ifdef DEBUG - dump_frame(frame, len); -#endif - + // We should have a PONG reply. if (frame[0] != (0x80 + PONG)) return -3; + payload_len = frame[1] & 0x7F; + if (payload_len >= 126) + crap(1, "dumb_ping: unsupported pong payload size > 125"); + + // Dump the rest of the data on the floor. + if (payload_len > 0) { + len = ws_read(ws, frame + 2, MIN(payload_len, sizeof(frame) - 2)); + if (len < 1) + return -3; + } + return 0; } @@ -669,36 +749,54 @@ dumb_ping(struct websocket *ws) * -1 on failure to send(2) the close frame, * -2 on failure to read(2) a response, * -3 on a response being invalid (i.e. not a CLOSE), - * -4 on a failure to shutdown(2) the underlying socket */ int dumb_close(struct websocket *ws) { - ssize_t frame_len, len; + ssize_t len, payload_len; uint8_t mask[4]; uint8_t frame[128]; memset(frame, 0, sizeof(frame)); dumb_mask(mask); - frame_len = init_frame(frame, CLOSE, mask, 0); + len = init_frame(frame, CLOSE, mask, 0); - len = ws_write(ws, frame, (size_t) frame_len); - if (len < frame_len) + len = ws_write(ws, frame, (size_t) len); + if (len < 1) return -1; memset(frame, 0, sizeof(frame)); // A valid RFC6455 websocket server MUST send a Close frame in response - len = ws_read(ws, frame, sizeof(frame)); - if (len < 1) + // Read first 2 bytes. + len = ws_read(ws, frame, 2); + if (len != 2) + return -2; + + if (frame[0] != (0x80 + CLOSE)) return -3; + payload_len = frame[1] & 0x7F; + if (payload_len > 126) + crap(1, "dumb_close: unsupported close payload size > 125"); + + // Dump the rest of the data on the floor. + if (payload_len > 0) { + len = ws_read(ws, frame + 2, MIN(payload_len, sizeof(frame) - 2)); + if (len < 1) + return -3; + } + + // Now close/shutdown our socket. if (ws->ctx) tls_close(ws->ctx); - if (shutdown(ws->s, HOW)) - return -4; + // Don't care if shutdown fails. Other side may have closed some things first. + shutdown(ws->s, HOW); + + ws->ctx = NULL; // XXX does this leak anything? + ws->s = -1; return 0; } diff --git a/dws.h b/dws.h index c000918..828ba2e 100644 --- a/dws.h +++ b/dws.h @@ -56,7 +56,7 @@ int dumb_connect(struct websocket *ws, char*, char*); int dumb_connect_tls(struct websocket *ws, char*, char*, int); int dumb_handshake(struct websocket *s, const char*, const char*, const char*); -ssize_t dumb_send(struct websocket *ws, void*, size_t); +ssize_t dumb_send(struct websocket *ws, const void*, size_t); ssize_t dumb_recv(struct websocket *ws, void*, size_t); int dumb_ping(struct websocket *ws); int dumb_close(struct websocket *ws); diff --git a/go-test/server.go b/go-test/server.go index ed9819b..bdf012f 100644 --- a/go-test/server.go +++ b/go-test/server.go @@ -49,6 +49,7 @@ func handler(w http.ResponseWriter, r *http.Request) { if c.WriteMessage(ws.BinaryMessage, out) != nil { log.Fatal("WriteMessage: ", err) } + log.Printf("sent: %s", out) } default: log.Fatal("dumb message type: ", msgtype)