diff --git a/conn.go b/conn.go index fe1c7dd..f317eff 100644 --- a/conn.go +++ b/conn.go @@ -93,6 +93,11 @@ func newConn(c net.Conn, client bool, conf *Config, fr fixedreader.FixedReader, return con } +// 返回标准库的net.Conn +func (c *Conn) NetConn() net.Conn { + return c.c +} + func (c *Conn) writeErrAndOnClose(code StatusCode, userErr error) error { defer c.Callback.OnClose(c, userErr) if err := c.WriteTimeout(opcode.Close, statusCodeToBytes(code), 2*time.Second); err != nil { diff --git a/conn_test.go b/conn_test.go index 82864f1..f3e63c9 100644 --- a/conn_test.go +++ b/conn_test.go @@ -623,6 +623,43 @@ func Test_WriteControl(t *testing.T) { }) } +func Test_API(t *testing.T) { + t.Run("NetConn", func(t *testing.T) { + var shandler testPingPongCloseHandler + shandler.data = make(chan string, 1) + upgrade := NewUpgrade(WithServerBufioParseMode(), WithServerCallback(&shandler)) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := upgrade.Upgrade(w, r) + if err != nil { + t.Error(err) + } + if c.NetConn() != c.c { + t.Error("server.not equal") + } + c.StartReadLoop() + })) + + defer ts.Close() + + url := strings.ReplaceAll(ts.URL, "http", "ws") + con, err := Dial(url, WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { + })) + if err != nil { + t.Error(err) + } + defer con.Close() + if con.NetConn() != con.c { + t.Error("client.not equal") + } + + err = con.WriteControl(Close, bytes.Repeat([]byte{1}, 126)) + // 这里必须要报错 + if err == nil { + t.Error("not error") + } + }) +} + // 测试ping pong close control信息 func TestPingPongClose(t *testing.T) { // 写一个超过maxControlFrameSize的消息