From 45656ee206a37b0180c952a7787da6abb61b7805 Mon Sep 17 00:00:00 2001 From: unkernet Date: Wed, 5 Jul 2023 16:41:15 +0700 Subject: [PATCH] UDP support --- bin/wstunnel.js | 18 +++--- lib/WstClient.js | 94 ++++++++++++++++++++++++------- lib/WstServer.js | 128 ++++++++++++++++++++++++++----------------- lib/bindStream.js | 39 ------------- lib/bindUdpStream.js | 15 +++++ package.json | 2 +- readme.md | 13 +++++ test/test.js | 92 ++++++++++++++++++++++++++++--- 8 files changed, 270 insertions(+), 131 deletions(-) create mode 100644 lib/bindUdpStream.js diff --git a/bin/wstunnel.js b/bin/wstunnel.js index b8ceb08..82e1c2e 100755 --- a/bin/wstunnel.js +++ b/bin/wstunnel.js @@ -33,6 +33,8 @@ module.exports = (Server, Client) => { .string('s') .string('t') .string('p') + .boolean('u') + .alias('u', 'udp') .alias('p', 'proxy') .alias('t', 'tunnel') .boolean('c') @@ -52,11 +54,12 @@ module.exports = (Server, Client) => { .describe('c', 'accept any certificates') .describe('http', 'force to use http tunnel').argv; + const proto = argv.u ? 'udp' : 'tcp'; if (argv.s) { let server; if (argv.t) { let [host, port] = argv.t.split(':'); - server = new Server(host, port); + server = new Server({ host, port, proto }); } else { server = new Server(); } @@ -123,11 +126,7 @@ module.exports = (Server, Client) => { remoteAddr = `${toks[2]}:${toks[3]}`; } else if (toks.length === 3) { remoteAddr = `${toks[1]}:${toks[2]}`; - if (toks[0] === 'stdio') { - localHost = toks[0]; - } else { - localPort = toks[0]; - } + localPort = toks[0]; } else if (toks.length === 1) { remoteAddr = ''; localPort = toks[0]; @@ -136,11 +135,10 @@ module.exports = (Server, Client) => { console.log(optimist.help()); process.exit(1); } - localPort = parseInt(localPort); - if (localHost === 'stdio') { - client.startStdio(wsHostUrl, remoteAddr, { 'x-wstclient': machineId }); + if (localPort === 'stdio') { + client.startStdio({ wsHostUrl, remoteAddr, proto }, { 'x-wstclient': machineId }); } else { - client.start(localHost, localPort, wsHostUrl, remoteAddr, { + client.start({ localHost, localPort: parseInt(localPort), wsHostUrl, remoteAddr, proto }, { 'x-wstclient': machineId, }); } diff --git a/lib/WstClient.js b/lib/WstClient.js index a7927a0..00e00c4 100644 --- a/lib/WstClient.js +++ b/lib/WstClient.js @@ -1,9 +1,11 @@ const net = require('net'); +const dgram = require('dgram'); const WsStream = require('./WsStream'); const url = require('url'); const log = require('lawg'); const ClientConn = require('./httptunnel/ClientConn'); const bindStream = require('./bindStream'); +const bindUdpStream = require('./bindUdpStream'); const createWsClient = () => new (require('websocket').client)(); module.exports = wst_client = class wst_client extends require('events') @@ -18,6 +20,7 @@ module.exports = wst_client = class wst_client extends require('events') constructor() { super(); this.tcpServer = net.createServer(); + this.udpServer = dgram.createSocket('udp4'); } verbose() { @@ -40,32 +43,26 @@ module.exports = wst_client = class wst_client extends require('events') setHttpOnly(httpOnly) { this.httpOnly = httpOnly; } - // example: start("localhost", 8081, "wss://ws.domain.com:454", "dst.domain.com:22") + // example: start({ + // localHost: "localhost", + // localPort: 8081, + // wsHostUrl: "wss://ws.domain.com:454", + // remoteAddr: "dst.domain.com:22", + // proto: "tcp" + // }); // meaning: tunnel localhost:8081 to remoteAddr by using websocket connection to wsHost // @wsHostUrl: ws:// denotes standard socket, wss:// denotes ssl socket // may be changed at any time to change websocket server info - start(localHost, localPort, wsHostUrl, remoteAddr, optionalHeaders, cb) { + start({ localHost, localPort, wsHostUrl, remoteAddr, proto }, optionalHeaders, cb) { this.wsHostUrl = wsHostUrl; - - this.tcpServer.listen(localPort, localHost, cb); - this.tcpServer.on('connection', (tcpConn) => { - const bind = (tcp, s) => { - bindStream(tcp, s); - this.emit('tunnel', tcp, s); - }; - this._connect( - this.wsHostUrl, - remoteAddr, - optionalHeaders, - (err, stream) => { - if (err) this.emit('connectFailed', err); - else bind(tcpConn, stream); - } - ); - }); + if (proto === 'udp') { + this.listenUdp(localPort, localHost, remoteAddr, optionalHeaders, cb); + } else { + this.listenTcp(localPort, localHost, remoteAddr, optionalHeaders, cb); + } } - startStdio(wsHostUrl, remoteAddr, optionalHeaders, cb) { + startStdio({ wsHostUrl, remoteAddr, proto }, optionalHeaders, cb) { this.wsHostUrl = wsHostUrl; const bind = (s) => { process.stdin.pipe(s); @@ -76,6 +73,7 @@ module.exports = wst_client = class wst_client extends require('events') this._connect( this.wsHostUrl, remoteAddr, + proto, optionalHeaders, (err, stream) => { if (err) this.emit('connectFailed', err); @@ -85,7 +83,61 @@ module.exports = wst_client = class wst_client extends require('events') ); } - _connect(wsHostUrl, remoteAddr, optionalHeaders, cb) { + listenUdp(localPort, localHost, remoteAddr, optionalHeaders, cb) { + const udpServer = dgram.createSocket('udp4'); + this.connections = new Set(); + udpServer.bind(localPort, localHost, cb); + udpServer.on('message', (data, rinfo) => { + const id = `${rinfo.address}:${rinfo.port}`; + if (!this.connections.has(id)) { + this.connections.add(id); + this._connect( + this.wsHostUrl, + remoteAddr, + 'udp', + optionalHeaders, + (err, stream) => { + if (err) { + this.emit('connectFailed', err); + this.connections.delete(id); + } else { + bindUdpStream(stream, udpServer, rinfo.address, rinfo.port, () => { + this.connections.delete(id); + }); + stream.write(data); + this.emit('tunnel', udpServer, stream); + } + } + ); + } + }); + } + + listenTcp(localPort, localHost, remoteAddr, optionalHeaders, cb) { + const tcpServer = net.createServer(); + tcpServer.listen(localPort, localHost, cb); + tcpServer.on('connection', (tcpConn) => { + this._connect( + this.wsHostUrl, + remoteAddr, + 'tcp', + optionalHeaders, + (err, stream) => { + if (err) { + this.emit('connectFailed', err); + } else { + bindStream(tcpConn, stream); + this.emit('tunnel', tcpConn, stream); + } + } + ); + }); + } + + _connect(wsHostUrl, remoteAddr, proto, optionalHeaders, cb) { + if (remoteAddr && proto) { + remoteAddr = `${remoteAddr}:${proto}`; + } if (this.httpOnly) { return this._httpConnect(wsHostUrl, remoteAddr, optionalHeaders, cb); } else { diff --git a/lib/WstServer.js b/lib/WstServer.js index 1009259..bfdcd88 100644 --- a/lib/WstServer.js +++ b/lib/WstServer.js @@ -2,19 +2,23 @@ const WebSocketServer = require('websocket').server; const http = require('http'); const url = require('url'); const net = require('net'); +const dgram = require('dgram'); const WsStream = require('./WsStream'); const log = require('lawg'); const HttpTunnelServer = require('./httptunnel/Server'); const HttpTunnelReq = require('./httptunnel/ConnRequest'); const ChainedWebApps = require('./ChainedWebApps'); - +const bindStream = require('./bindStream'); +const bindUdpStream = require('./bindUdpStream'); +const httpReqRemoteIp = require('./httpReqRemoteIp'); module.exports = wst_server = class wst_server { // if dstHost, dstPort are specified here, then all tunnel end points are at dstHost:dstPort, regardless what // client requests, for security option // webapp: customize webapp if any, you may use express app - constructor(dstHost, dstPort, webapp) { - this.dstHost = dstHost; - this.dstPort = dstPort; + constructor({ host, port, proto, webapp } = {}) { + this.dstHost = host; + this.dstPort = port; + this.dstProto = proto; this.httpServer = http.createServer(); this.wsServer = new WebSocketServer({ httpServer: this.httpServer, @@ -30,9 +34,61 @@ module.exports = wst_server = class wst_server { apps.bindToHttpServer(this.httpServer); } + accept(request, remote, connWrapperCb) { + let wsConn; + const ip = httpReqRemoteIp(request.httpRequest); + try { + wsConn = request.accept('tunnel-protocol', request.origin); + log( + `Client ${ip} established ${ + request instanceof HttpTunnelReq ? 'http' : 'ws' + } tunnel to ${remote}` + ); + } catch (e) { + log(`Client ${ip} rejected due to ${e.toString()}`); + return; + } + if (connWrapperCb) { + wsConn = connWrapperCb(wsConn); + } + return wsConn; + } + + connectUdp(request, connWrapperCb, host, port) { + const socket = dgram.createSocket('udp4'); + socket.bind(() => { + socket.removeAllListeners('error'); + const wsConn = this.accept(request, `${host}:${port}:udp`, connWrapperCb); + if (wsConn) { + bindUdpStream(wsConn, socket, host, port, () => { + socket.close(); + }); + } + }); + socket.on('error', (err) => + request.reject(500, JSON.stringify(`Tunnel connect error to ${host}:${port}:udp: ` + err)) + ); + } + + connectTcp(request, connWrapperCb, host, port) { + const tcpConn = net.connect( + { host, port, allowHalfOpen: false }, + () => { + tcpConn.removeAllListeners('error'); + const wsConn = this.accept(request, `${host}:${port}`, connWrapperCb); + if (wsConn) { + bindStream(wsConn, tcpConn); + } + } + ); + tcpConn.on('error', (err) => + request.reject(500, JSON.stringify(`Tunnel connect error to ${host}:${port}:tcp: ` + err)) + ); + } + // localAddr: [addr:]port, the local address to listen at, i.e. localhost:8888, 8888, 0.0.0.0:8888 start(localAddr, cb) { - const [localHost, localPort] = Array.from(this._parseAddr(localAddr)); + const [localHost, localPort] = this._parseAddr(localAddr); return this.httpServer.listen(localPort, localHost, (err) => { if (cb) { cb(err); @@ -42,47 +98,16 @@ module.exports = wst_server = class wst_server { const { httpRequest } = request; return this.authenticate( httpRequest, - (rejectReason, target, monitor) => { + (rejectReason, target) => { if (rejectReason) { return request.reject(500, JSON.stringify(rejectReason)); } - const { host, port } = target; - var tcpConn = net.connect( - { host, port, allowHalfOpen: false }, - () => { - tcpConn.removeAllListeners('error'); - const ip = require('./httpReqRemoteIp')(httpRequest); - let wsConn = null; - try { - wsConn = request.accept('tunnel-protocol', request.origin); - log( - `Client ${ip} established ${ - request instanceof HttpTunnelReq ? 'http' : 'ws' - } tunnel to ${host}:${port}` - ); - } catch (e) { - log(`Client ${ip} rejected due to ${e.toString()}`); - tcpConn.end(); - return; - } - if (connWrapperCb) { - wsConn = connWrapperCb(wsConn); - } - require('./bindStream')(wsConn, tcpConn); - if (monitor) { - return monitor.bind(wsConn, tcpConn); - } - } - ); - - tcpConn.on('error', (err) => - request.reject( - 500, - JSON.stringify( - `Tunnel connect error to ${host}:${port}: ` + err - ) - ) - ); + const { host, port, proto } = target; + if (proto === 'udp') { + this.connectUdp(request, connWrapperCb, host, port); + } else { + this.connectTcp(request, connWrapperCb, host, port); + } } ); }; @@ -101,20 +126,21 @@ module.exports = wst_server = class wst_server { }); } - // authCb(rejectReason, {host, port}, monitor) + // authCb(rejectReason, {host, port}) authenticate(httpRequest, authCb) { - let host, port; + let host, port, proto; if (this.dstHost && this.dstPort) { - [host, port] = Array.from([this.dstHost, this.dstPort]); + [host, port, proto] = [this.dstHost, this.dstPort, this.dstProto]; } else { const dst = this.parseUrlDst(httpRequest.url); if (!dst) { return authCb('Unable to determine tunnel target'); } else { - ({ host, port } = dst); + ({ host, port, proto } = dst); } } - return authCb(null, { host, port }); // allow by default + port = parseInt(port); + return authCb(null, { host, port, proto }); // allow by default } // returns {host, port} or undefined @@ -123,8 +149,8 @@ module.exports = wst_server = class wst_server { if (!uri.query.dst) { return undefined; } else { - const [host, port] = Array.from(uri.query.dst.split(':')); - return { host, port }; + const [host, port, proto] = uri.query.dst.split(':'); + return { host, port, proto }; } } @@ -134,7 +160,7 @@ module.exports = wst_server = class wst_server { if (typeof localAddr === 'number') { localPort = localAddr; } else { - [localHost, localPort] = Array.from(localAddr.split(':')); + [localHost, localPort] = localAddr.split(':'); if (/^\d+$/.test(localHost)) { localPort = localHost; localHost = null; diff --git a/lib/bindStream.js b/lib/bindStream.js index fb735c7..c755482 100644 --- a/lib/bindStream.js +++ b/lib/bindStream.js @@ -40,44 +40,6 @@ module.exports = function (s1, s2) { return s1.stop(); }); - const manualPipe = function () { - s1.on('data', function (data) { - if (!s2._stop) { - return s2.write(data); - } - }); - s2.on('data', function (data) { - if (!s1._stop) { - return s1.write(data); - } - }); - - s1.on('finish', function () { - dlog(s1, 'finish'); - return s2.stop(); - }); - s1.on('end', function () { - dlog(s1, 'end'); - return s2.stop(); - }); - s1.on('close', function () { - dlog(s1, 'close'); - return s2.stop(); - }); - s2.on('finish', function () { - dlog(s2, 'finish'); - return s1.stop(); - }); - s2.on('end', function () { - dlog(s2, 'end'); - return s1.stop(); - }); - return s2.on('close', function () { - dlog(s2, 'close'); - return s1.stop(); - }); - }; - const autoPipe = function () { s1.on('close', function () { dlog(s1, 'close'); @@ -91,7 +53,6 @@ module.exports = function (s1, s2) { return s1.pipe(s2, { end }).pipe(s1, { end }); }; - // manualPipe() autoPipe(); class SpeedMeter { diff --git a/lib/bindUdpStream.js b/lib/bindUdpStream.js new file mode 100644 index 0000000..4ee3c7e --- /dev/null +++ b/lib/bindUdpStream.js @@ -0,0 +1,15 @@ +module.exports = function (ws, socket, host, port, onClose) { + const onMessage = (data, rinfo) => { + if (rinfo.address === host && rinfo.port === port) { + ws.write(data); + } + }; + socket.on('message', onMessage); + ws.on('data', (data) => { + socket.send(data, port, host); + }); + ws.on('close', () => { + socket.off('message', onMessage); + onClose(); + }); +} diff --git a/package.json b/package.json index e67828c..b23c0ea 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "wstunnel", - "version": "1.4.0", + "version": "1.5.0", "description": "tunnel over websocket", "main": "./lib/wst.js", "scripts": { diff --git a/readme.md b/readme.md index 0189e7e..8f2daf0 100755 --- a/readme.md +++ b/readme.md @@ -30,6 +30,19 @@ To tell client to connect via http proxy, do: wstunnel -t 33:2.2.2.2:33 -p http://[user:pass@]proxyhost:proxyport wss://server:443 +To pass the UDP traffic instead of TCP use '-u' option for the client or for the server: + + Client: + wstunnel -u -t 33:2.2.2.2:33 ws://host:8080 + + or + + Client: + wstunnel -u -t 33 ws://server:8080 + + Server: + wstunnel -s 0.0.0.0:8080 -u -t 2.2.2.2:33 + For dev/test purpose, client can set '-c' option to disable ssl certificate check. This also makes you vulnerable to MITM attack, so use with caution. diff --git a/test/test.js b/test/test.js index d2a5350..a177f47 100644 --- a/test/test.js +++ b/test/test.js @@ -2,6 +2,7 @@ const { spawn } = require('child_process'); const path = require('path'); const wst = require('../lib/wst'); const net = require('net'); +const dgram = require('dgram'); const log = require('lawg'); const assert = require('assert'); @@ -31,11 +32,12 @@ describe('wstunnel', () => { // setup ws server server.start(config.ws_port, function (err) { if (err) done(err); - return client.start( - 'localhost', - config.s_port, - `ws://localhost:${config.ws_port}`, - `localhost:${config.t_port}`, + return client.start({ + localHost: 'localhost', + localPort: config.s_port, + wsHostUrl: `ws://localhost:${config.ws_port}`, + remoteAddr: `localhost:${config.t_port}`, + }, {}, function (err) { if (err) done(err); @@ -69,14 +71,15 @@ describe('wstunnel', () => { }); }); - const makeBuf = function (size) { + const makeBuf = function (size, seed) { + seed = seed || 0; const b = Buffer.alloc(size); for ( let i = 0, end = size / 4, asc = end >= 0; asc ? i < end : i > end; asc ? i++ : i-- ) { - b.writeInt32LE(i + 1, i * 4); + b.writeInt32LE(i + 1 + seed, i * 4); } return b; }; @@ -121,11 +124,11 @@ describe('wstunnel', () => { it('test echo stream via http tunnel', function (done) { const { authenticate } = server; server.authenticate = (httpRequest, authCb) => - authenticate.call(server, httpRequest, function (err, { host, port }) { + authenticate.call(server, httpRequest, function (err, { host, port, proto }) { if (!('x-htundir' in httpRequest.headers)) { return authCb('reject websocket intentionally'); } else { - return authCb(err, { host, port }); + return authCb(err, { host, port, proto }); } }); @@ -138,6 +141,77 @@ describe('wstunnel', () => { }); }); + it('setup udp echo server', function (done) { + echo_server.close(); + echo_server = dgram.createSocket('udp4'); + echo_server.on('message', (data, rinfo) => { + echo_server.send(data, rinfo.port, rinfo.address); + }); + echo_server.bind(config.t_port, () => done()); + }); + + it('setup udp tunnel', (done) => { + const client = new wst.client(); + client.start({ + localHost: '127.0.0.1', // only ipv4 + localPort: config.s_port, + wsHostUrl: `ws://localhost:${config.ws_port}`, + remoteAddr: `127.0.0.1:${config.t_port}`, + proto: 'udp', + }, + {}, + function (err) { + if (err) done(err); + done(); + }); + }); + + function recvUdpEcho(host, port, sendData, doneCb) { + const size = sendData.length; + const rb = Buffer.alloc(size); + let rbi = 0; + const chunk = 1024; + const socket = dgram.createSocket('udp4'); + socket.on('message', (data, rinfo) => { + assert.equal(host, rinfo.address); + assert.equal(port, rinfo.port); + data.copy(rb, rbi); + rbi += data.length; + if (rbi >= size) { + assert.equal(isSameBuf(rb, sendData), true); + doneCb(); + } else { + socket.send(sendData, rbi, Math.min(chunk, size - rbi), port, host); + } + }); + socket.send(sendData, 0, Math.min(chunk, size), port, host); + } + + it('test udp echo', (done) => { + const data = makeBuf(987648); + recvUdpEcho('127.0.0.1', config.t_port, data, done); + }); + + it('test udp tunnel', (done) => { + const data = makeBuf(987648); + recvUdpEcho('127.0.0.1', config.s_port, data, done); + }); + + it('test multiple udp tunnels', (done) => { + Promise.all([ + new Promise((resolve) => { + const data = makeBuf(987648, 1); + recvUdpEcho('127.0.0.1', config.s_port, data, resolve); + }), + new Promise((resolve) => { + const data = makeBuf(987648, 2); + recvUdpEcho('127.0.0.1', config.s_port, data, resolve); + }), + ]) + .then(() => done()) + .catch(e => done(e)); + }); + it('test end', function (done) { done(); return setTimeout(() => process.exit(0), 100);