diff --git a/go.mod b/go.mod index d4ce1010..359c6cc0 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/go-chi/cors v1.2.1 github.com/gofrs/uuid v4.4.0+incompatible github.com/gorilla/websocket v1.5.0 + github.com/lxzan/gws v1.6.13 github.com/mailru/easyjson v0.7.7 github.com/stretchr/testify v1.8.4 github.com/subosito/gotenv v1.6.0 @@ -20,6 +21,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/josharian/intern v1.0.0 // indirect + github.com/klauspost/compress v1.16.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 87ac1eb7..29ce0320 100644 --- a/go.sum +++ b/go.sum @@ -22,6 +22,10 @@ github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWm github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/klauspost/compress v1.16.5 h1:IFV2oUNUzZaz+XyusxpLzpzS8Pt5rh0Z16For/djlyI= +github.com/klauspost/compress v1.16.5/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/lxzan/gws v1.6.13 h1:85UaBsL5alQOiDao0tlupZYSJCl1Yp6u+un/VHLkVfY= +github.com/lxzan/gws v1.6.13/go.mod h1:dsC6S7kJNh+iWqqu2HiO8tnNCji04HwyJCYfTOS+6iY= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/internal/api/v1.go b/internal/api/v1.go index a9fe78ea..0353482d 100644 --- a/internal/api/v1.go +++ b/internal/api/v1.go @@ -21,6 +21,7 @@ import ( var ErrLobbyNotExistent = errors.New("the requested lobby doesn't exist") +//easyjson:skip type V1Handler struct { cfg *config.Config } diff --git a/internal/api/ws.go b/internal/api/ws.go index 9ab64a08..1c08cec9 100644 --- a/internal/api/ws.go +++ b/internal/api/ws.go @@ -4,13 +4,13 @@ import ( "errors" "fmt" "log" - "net" "net/http" "runtime/debug" + "time" "github.com/go-chi/chi/v5" "github.com/gofrs/uuid" - "github.com/gorilla/websocket" + "github.com/lxzan/gws" "github.com/mailru/easyjson" "github.com/scribble-rs/scribble.rs/internal/game" @@ -20,12 +20,11 @@ import ( var ( ErrPlayerNotConnected = errors.New("player not connected") - upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(_ *http.Request) bool { return true }, - EnableCompression: true, - } + upgrader = gws.NewUpgrader(&socketHandler{}, &gws.ServerOption{ + ReadAsyncEnabled: true, + CompressEnabled: true, + Recovery: gws.Recovery, + }) ) func (handler *V1Handler) websocketUpgrade(writer http.ResponseWriter, request *http.Request) { @@ -57,77 +56,118 @@ func (handler *V1Handler) websocketUpgrade(writer http.ResponseWriter, request * return } - socket, err := upgrader.Upgrade(writer, request, nil) + socket, err := upgrader.Upgrade(writer, request) if err != nil { http.Error(writer, err.Error(), http.StatusInternalServerError) return } log.Printf("%s(%s) has connected\n", player.Name, player.ID) - player.SetWebsocket(socket) + socket.Session().Store("player", player) + socket.Session().Store("lobby", lobby) lobby.OnPlayerConnectUnsynchronized(player) - socket.SetCloseHandler(func(code int, text string) error { - lobby.OnPlayerDisconnect(player) - return nil - }) - - go wsListen(lobby, player, socket) + go socket.ReadLoop() }) } -func wsListen(lobby *game.Lobby, player *game.Player, socket *websocket.Conn) { - // Workaround to prevent crash, since not all kind of - // disconnect errors are cleanly caught by gorilla websockets. +const ( + pingInterval = 10 * time.Second + pingWait = 5 * time.Second +) + +type socketHandler struct{} + +func (c *socketHandler) resetDeadline(socket *gws.Conn) { + if err := socket.SetDeadline(time.Now().Add(pingInterval + pingWait)); err != nil { + log.Printf("error resetting deadline: %s\n", err) + } +} + +func (c *socketHandler) OnOpen(socket *gws.Conn) { + c.resetDeadline(socket) +} + +func (c *socketHandler) OnClose(socket *gws.Conn, err error) { + val, ok := socket.Session().Load("player") + if ok { + if player, ok := val.(*game.Player); ok { + lobby, ok := socket.Session().Load("lobby") + if ok { + if lobby, ok := lobby.(*game.Lobby); ok { + lobby.OnPlayerDisconnect(player) + } + } + + player.SetWebsocket(nil) + } + } + socket.Session().Delete("player") + socket.Session().Delete("lobby") +} + +func (c *socketHandler) OnPing(socket *gws.Conn, _ []byte) { + c.resetDeadline(socket) + _ = socket.WritePong(nil) +} + +func (c *socketHandler) OnPong(socket *gws.Conn, _ []byte) { + c.resetDeadline(socket) +} + +func (c *socketHandler) OnMessage(socket *gws.Conn, message *gws.Message) { + defer message.Close() + defer c.resetDeadline(socket) + + val, ok := socket.Session().Load("player") + if !ok { + return + } + + player, ok := val.(*game.Player) + if !ok { + return + } + + val, ok = socket.Session().Load("lobby") + if !ok { + return + } + + lobby, ok := val.(*game.Lobby) + if !ok { + return + } + + bytes := message.Bytes() + message.Close() + wsListen(lobby, player, socket, bytes) +} + +func wsListen(lobby *game.Lobby, player *game.Player, socket *gws.Conn, data []byte) { defer func() { if err := recover(); err != nil { log.Printf("Error occurred in wsListen.\n\tError: %s\n\tPlayer: %s(%s)\nStack %s\n", err, player.Name, player.ID, string(debug.Stack())) - lobby.OnPlayerDisconnect(player) + // FIXME Should this lead to a disconnect? } }() var event game.EventTypeOnly - - for { - messageType, data, err := socket.ReadMessage() + if err := easyjson.Unmarshal(data, &event); err != nil { + log.Printf("Error unmarshalling message: %s\n", err) + err := WriteObject(player, game.Event{ + Type: game.EventTypeSystemMessage, + Data: fmt.Sprintf("error parsing message, please report this issue via Github: %s!", err), + }) if err != nil { - if websocket.IsCloseError(err) || websocket.IsUnexpectedCloseError(err) { - lobby.OnPlayerDisconnect(player) - return - } - - // This way, we should catch repeated reads on closed connections - // on both linux and windows. Previously we did this by searching - // for certain text in the error message, which was neither - // cross-platform nor translation aware. - if netErr, ok := err.(*net.OpError); ok && !netErr.Temporary() { - lobby.OnPlayerDisconnect(player) - return - } - - log.Printf("Error reading from socket: %s\n", err) - // If the error doesn't seem fatal we attempt listening for more messages. - continue + log.Printf("Error sending errormessage: %s\n", err) } + return + } - if messageType == websocket.TextMessage { - if err := easyjson.Unmarshal(data, &event); err != nil { - log.Printf("Error unmarshalling message: %s\n", err) - err := WriteObject(player, game.Event{ - Type: game.EventTypeSystemMessage, - Data: fmt.Sprintf("error parsing message, please report this issue via Github: %s!", err), - }) - if err != nil { - log.Printf("Error sending errormessage: %s\n", err) - } - continue - } - - if err := lobby.HandleEvent(event.Type, data, player); err != nil { - log.Printf("Error handling event: %s\n", err) - } - } + if err := lobby.HandleEvent(event.Type, data, player); err != nil { + log.Printf("Error handling event: %s\n", err) } } @@ -145,10 +185,12 @@ func WriteObject(player *game.Player, object easyjson.Marshaler) error { return fmt.Errorf("error marshalling payload: %w", err) } - return socket.WriteMessage(websocket.TextMessage, bytes) + // We write async, as broadcast always uses the queue. If we use write, the + // order will become messed up, potentially causing issues in the frontend. + return socket.WriteAsync(gws.OpcodeText, bytes) } -func WritePreparedMessage(player *game.Player, message *websocket.PreparedMessage) error { +func WritePreparedMessage(player *game.Player, message *gws.Broadcaster) error { player.GetWebsocketMutex().Lock() defer player.GetWebsocketMutex().Unlock() @@ -157,5 +199,5 @@ func WritePreparedMessage(player *game.Player, message *websocket.PreparedMessag return ErrPlayerNotConnected } - return socket.WritePreparedMessage(message) + return message.Broadcast(socket) } diff --git a/internal/game/data.go b/internal/game/data.go index 5b9b3de9..e6b2e8cc 100644 --- a/internal/game/data.go +++ b/internal/game/data.go @@ -7,7 +7,7 @@ import ( discordemojimap "github.com/Bios-Marcel/discordemojimap/v2" "github.com/gofrs/uuid" - "github.com/gorilla/websocket" + "github.com/lxzan/gws" easyjson "github.com/mailru/easyjson" "golang.org/x/text/cases" ) @@ -91,7 +91,7 @@ type Lobby struct { mutex *sync.Mutex WriteObject func(*Player, easyjson.Marshaler) error - WritePreparedMessage func(*Player, *websocket.PreparedMessage) error + WritePreparedMessage func(*Player, *gws.Broadcaster) error } // MaxPlayerNameLength defines how long a string can be at max when used @@ -112,12 +112,12 @@ func (player *Player) SetLastKnownAddress(address string) { // GetWebsocket simply returns the players websocket connection. This method // exists to encapsulate the websocket field and prevent accidental sending // the websocket data via the network. -func (player *Player) GetWebsocket() *websocket.Conn { +func (player *Player) GetWebsocket() *gws.Conn { return player.ws } // SetWebsocket sets the given connection as the players websocket connection. -func (player *Player) SetWebsocket(socket *websocket.Conn) { +func (player *Player) SetWebsocket(socket *gws.Conn) { player.ws = socket } diff --git a/internal/game/lobby.go b/internal/game/lobby.go index 390d1552..5c1b6ec7 100644 --- a/internal/game/lobby.go +++ b/internal/game/lobby.go @@ -12,7 +12,7 @@ import ( "time" "unicode/utf8" - "github.com/gorilla/websocket" + "github.com/lxzan/gws" "github.com/mailru/easyjson" "github.com/scribble-rs/scribble.rs/internal/config" "github.com/scribble-rs/scribble.rs/internal/sanitize" @@ -362,32 +362,28 @@ func (lobby *Lobby) Broadcast(data easyjson.Marshaler) { return } - message, err := websocket.NewPreparedMessage(websocket.TextMessage, bytes) - if err != nil { - log.Println("error preparing message", err) - return - } - + message := gws.NewBroadcaster(gws.OpcodeText, bytes) for _, player := range lobby.GetPlayers() { lobby.WritePreparedMessage(player, message) } } func (lobby *Lobby) broadcastConditional(data easyjson.Marshaler, condition func(*Player) bool) { - bytes, err := easyjson.Marshal(data) - if err != nil { - log.Println("error marshalling broadcastConditional message", err) - return - } - - message, err := websocket.NewPreparedMessage(websocket.TextMessage, bytes) - if err != nil { - log.Println("error preparing message", err) - return - } - + var message *gws.Broadcaster for _, player := range lobby.players { if condition(player) { + if message == nil { + bytes, err := easyjson.Marshal(data) + if err != nil { + log.Println("error marshalling broadcastConditional message", err) + return + } + + // Message is created lazily, since the conditional events could + // potentially not be sent at all. The cost of the nil-check is + // much lower than the cost of creating the message. + message = gws.NewBroadcaster(gws.OpcodeText, bytes) + } lobby.WritePreparedMessage(player, message) } } @@ -453,9 +449,7 @@ func handleKickVoteEvent(lobby *Lobby, player *Player, toKickID uuid.UUID) { func kickPlayer(lobby *Lobby, playerToKick *Player, playerToKickIndex int) { // Avoiding nilpointer in case playerToKick disconnects during this event unluckily. if playerToKickSocket := playerToKick.ws; playerToKickSocket != nil { - if err := playerToKickSocket.Close(); err != nil { - log.Printf("Error disconnecting kicked player:\n\t%s\n", err) - } + playerToKickSocket.WriteClose(1000, nil) } // Since the player is already kicked, we first clean up the kicking information related to that player diff --git a/internal/game/lobby_test.go b/internal/game/lobby_test.go index b166964b..ee89a2e0 100644 --- a/internal/game/lobby_test.go +++ b/internal/game/lobby_test.go @@ -8,7 +8,7 @@ import ( "unsafe" "github.com/gofrs/uuid" - "github.com/gorilla/websocket" + "github.com/lxzan/gws" easyjson "github.com/mailru/easyjson" "github.com/scribble-rs/scribble.rs/internal/sanitize" ) @@ -33,7 +33,7 @@ func noOpWriteObject(_ *Player, _ easyjson.Marshaler) error { return nil } -func noOpWritePreparedMessage(_ *Player, _ *websocket.PreparedMessage) error { +func noOpWritePreparedMessage(_ *Player, _ *gws.Broadcaster) error { return nil } @@ -233,8 +233,8 @@ func Test_wordSelectionEvent(t *testing.T) { return nil } - lobby.WritePreparedMessage = func(player *Player, message *websocket.PreparedMessage) error { - data := getUnexportedField(reflect.ValueOf(message).Elem().FieldByName("data")).([]byte) + lobby.WritePreparedMessage = func(player *Player, message *gws.Broadcaster) error { + data := getUnexportedField(reflect.ValueOf(message).Elem().FieldByName("payload")).([]byte) type event struct { Type string `json:"type"` Data json.RawMessage `json:"data"` diff --git a/internal/game/shared.go b/internal/game/shared.go index 9ecb1a57..1b35f720 100644 --- a/internal/game/shared.go +++ b/internal/game/shared.go @@ -6,7 +6,7 @@ import ( "time" "github.com/gofrs/uuid" - "github.com/gorilla/websocket" + "github.com/lxzan/gws" ) // @@ -197,7 +197,7 @@ type Ready struct { type Player struct { // userSession uniquely identifies the player. userSession uuid.UUID - ws *websocket.Conn + ws *gws.Conn socketMutex *sync.Mutex // disconnectTime is used to kick a player in case the lobby doesn't have // space for new players. The player with the oldest disconnect.Time will