diff --git a/internal/persistence/message/message.go b/internal/persistence/message/message.go index 471d86d..9c039f1 100644 --- a/internal/persistence/message/message.go +++ b/internal/persistence/message/message.go @@ -127,16 +127,14 @@ func getVariableLength(l int) int { return 0 } -func ToPublish(msg *Message, version packet.Version) *packet.Publish { - pub := &packet.Publish{ - Dup: msg.Dup, - QoS: msg.QoS, - PacketId: msg.PacketId, - Retain: msg.Retained, - TopicName: []byte(msg.Topic), - Payload: msg.Payload, +func (m *Message) ToPublish(version packet.Version) *packet.Publish { + return &packet.Publish{ + Dup: m.Dup, + QoS: m.QoS, + PacketId: m.PacketId, + Retain: m.Retained, + TopicName: []byte(m.Topic), + Payload: m.Payload, Version: version, } - - return pub } diff --git a/internal/server/client.go b/internal/server/client.go index 1fdf7c2..c0c6b25 100644 --- a/internal/server/client.go +++ b/internal/server/client.go @@ -27,8 +27,8 @@ import ( "github.com/yunqi/lighthouse/internal/persistence/message" "github.com/yunqi/lighthouse/internal/persistence/queue" "github.com/yunqi/lighthouse/internal/persistence/subscription" + "github.com/yunqi/lighthouse/internal/persistence/unack" "github.com/yunqi/lighthouse/internal/session" - sub "github.com/yunqi/lighthouse/internal/subscription" "github.com/yunqi/lighthouse/internal/xerror" "github.com/yunqi/lighthouse/internal/xlog" "go.opentelemetry.io/otel/trace" @@ -62,7 +62,7 @@ type ( ConnectedAt() time.Time // Connection returns the raw net.Conn Connection() net.Conn - // Close closes the client connection. + // Close closes the client handleReceiveConnection. Close() error // Disconnect sends a disconnect packet to client, it is use to close v5 client. Disconnect(disconnect *packet.Disconnect) @@ -130,9 +130,11 @@ type ( wg sync.WaitGroup queueStore queue.Queue subscriptionStore subscription.Store + unackStore unack.Store limit *packetIdLimiter log *xlog.Log remoteAddr net.Addr + deliverMessage func(srcClientID string, msg *message.Message, options subscription.IterationOptions) (matched bool) } ) @@ -140,8 +142,9 @@ func (c *client) ClientOption() *ClientOption { panic("implement me") } -func (c *client) Deliver(message message.Message) error { - panic("implement me") +func (c *client) Deliver(ctx context.Context, message *message.Message) error { + c.write(ctx, message.ToPublish(c.version)) + return nil } func (c *client) ClientOptions() *ClientOption { @@ -204,6 +207,7 @@ func newClient(server *server, conn net.Conn) *client { log: xlog.LoggerModule("client"), remoteAddr: conn.RemoteAddr(), subscriptionStore: server.subscriptionStore, + deliverMessage: server.deliverMessage(), } return c } @@ -211,7 +215,7 @@ func newClient(server *server, conn net.Conn) *client { func (c *client) listen() { ctx, span := c.server.tracer.Start(context.Background(), "listen") logger := c.log.WithContext(ctx) - logger.Debug("create a new client connection", zap.Any("IP", c.remoteAddr.String())) + logger.Debug("create a new client handleReceiveConnection", zap.Any("IP", c.remoteAddr.String())) c.wg.Add(1) goroutine.Go(func() { @@ -237,9 +241,14 @@ func (c *client) listen() { // 拉取消息 c.wg.Add(1) goroutine.Go(func() { + defer c.wg.Done() c.pollMessageHandler() - c.wg.Done() }) + ////c.wg.Add(1) + //goroutine.Go(func() { + // defer c.wg.Done() + // //c.handleReceiveConnection() + //}) c.wg.Add(1) goroutine.Go(func() { @@ -291,113 +300,18 @@ func (c *client) auth(ctx context.Context) bool { logger.Debug("invalid package", zap.String("package", p.String())) _ = c.Close() - logger.Debug("close connection", zap.String("IP", c.remoteAddr.String())) + logger.Debug("close handleReceiveConnection", zap.String("IP", c.remoteAddr.String())) return false } -func (c *client) readConn() { - defer func() { - // 关闭 in 通道 - _ = c.Close() - close(c.in) - }() - go func() { - select { - case <-c.closed: - // 立即关闭 - _ = c.clientConn.SetReadDeadline(time.Now()) - return - } - }() - for { - var p packet.Packet - if c.IsConnected() { - if keepAlive := c.opt.KeepAlive; keepAlive != 0 { //KeepAlive - _ = c.clientConn.SetReadDeadline(time.Now().Add(time.Duration(keepAlive/2+keepAlive) * time.Second)) - } - } - p, err := c.packetReader.Read() - if err != nil { - if err != io.EOF && p != nil { - c.log.Error("read error", zap.String("packet_type", reflect.TypeOf(p).String())) - } - select { - case <-c.closed: - c.log.Debug("客户端退出,关闭连接") - default: - c.log.Debug("连接超时,自动关闭") - } - return - } - //if connect, ok := p.(*packet.Connect); ok { - // c.log.Debug("接收认证信息", zap.String("ClientId", string(connect.ClientId))) - //} else { - // //c.log.Debug("Rec data", zap.String("packet", p.String())) - //} - c.in <- p - - // 等待连接认证完成 - //c.waitConnection() - - } -} - -func (c *client) writeConn() { - - defer func() { - }() - for p := range c.out { - //c.log.Debug("Ret data", zap.String("packet", p.String())) - err := c.packetWriter.WritePacketAndFlush(p) - if err != nil { - return - } - } - c.log.Debug("写入操作退出") - -} func (c *client) write(ctx context.Context, packet packet.Packet) { c.log.WithContext(ctx).Debug("write packet", zap.String("packet", packet.String())) - c.out <- packet -} - -//func (c *client) waitConnection() { -// <-c.connected -//} - -func (c *client) connectionDone() { - close(c.connected) -} - -func (c *client) connection() (ok bool) { - defer func() { - c.connectionDone() - }() - timeout := time.NewTimer(5 * time.Second) - defer timeout.Stop() - for { - select { - case p := <-c.in: - //c.log.Debug("从in通道中读出数据", zap.Any("packet", p)) - if p == nil { - return - } - - switch conn := p.(type) { - case *packet.Connect: - if conn == nil { - //err := xerror.ErrProtocol - break - } - return c.connectAuthentication(context.Background(), conn) - default: - } - case <-timeout.C: - return - } - + select { + case <-c.closed: + return + default: + c.out <- packet } - } // TODO 验证客户端连接 @@ -478,15 +392,15 @@ func (c *client) handleConn() { for p := range c.in { switch packetData := p.(type) { case *packet.Publish: - err = c.handlePublish(packetData) + err = c.handleReceivePublish(packetData) case *packet.Pingreq: - c.handlePingreq(packetData) + c.handleReceivePingreq(packetData) case *packet.Pubrel: - c.handlePubrel(packetData) + c.handleReceivePubrel(packetData) case *packet.Subscribe: - c.handleSubscribe(packetData) + c.handleReceiveSubscribe(packetData) case *packet.Unsubscribe: - c.handleUnsubscribe(packetData) + c.handleReceiveUnsubscribe(packetData) case *packet.Disconnect: break default: @@ -496,90 +410,12 @@ func (c *client) handleConn() { } } } + func (c *client) getTraceLog(spanName string) (context.Context, trace.Span, *zap.Logger) { ctx, span := c.server.tracer.Start(context.Background(), spanName) logger := c.log.WithContext(ctx) return ctx, span, logger } -func (c *client) handlePublish(publish *packet.Publish) *xerror.Error { - ctx, span, logger := c.getTraceLog("publish") - defer span.End() - logger.Debug("received publish packet", zap.String("packet", publish.String())) - var ackPacket packet.Packet - switch publish.QoS { - case packet.QoS1: - ackPacket = publish.CreatePuback() - case packet.QoS2: - ackPacket = publish.CreatePubrec() - } - - if ackPacket != nil { - // 返回响应 - c.write(ctx, ackPacket) - } - - return nil -} - -func (c *client) handlePingreq(pingreq *packet.Pingreq) { - ctx, span, logger := c.getTraceLog("ping request") - defer span.End() - logger.Debug("received ping request packet", zap.String("packet", pingreq.String())) - c.write(ctx, pingreq.CreatePingresp()) -} - -func (c *client) handlePubrel(pubrel *packet.Pubrel) { - ctx, span, logger := c.getTraceLog("publish release") - defer span.End() - - logger.Debug("received publish release packet", zap.String("packet", pubrel.String())) - c.write(ctx, pubrel.CreatePubcomp()) -} - -func (c *client) handleSubscribe(subscribe *packet.Subscribe) { - ctx, span, logger := c.getTraceLog("subscribe") - defer span.End() - - logger.Debug("received subscribe packet", zap.String("packet", subscribe.String())) - - var subs = make([]*sub.Subscription, 0, len(subscribe.Topics)) - - for _, topic := range subscribe.Topics { - subs = append(subs, &sub.Subscription{ - //ShareName: topic.Name, - TopicFilter: topic.Name, - //ID: subscribe.PacketId, - QoS: topic.QoS, - NoLocal: topic.NoLocal, - RetainAsPublished: topic.RetainAsPublished, - RetainHandling: topic.RetainHandling, - }) - } - subscribeResult, err := c.subscriptionStore.Subscribe(ctx, c.clientId, subs...) - if err != nil { - logger.Error("err", zap.Error(err)) - return - - } else { - logger.Info("", zap.Any("subscribeResult", subscribeResult)) - } - c.write(ctx, &packet.Suback{ - Version: subscribe.Version, - PacketId: subscribe.PacketId, - Payload: make([]code.Code, len(subscribe.Topics)), - }) -} - -func (c *client) handleUnsubscribe(unsubscribe *packet.Unsubscribe) { - ctx, span, logger := c.getTraceLog("unsubscribe") - defer span.End() - logger.Debug("received unsubscribe packet", zap.String("packet", unsubscribe.String())) - - c.write(ctx, &packet.Unsuback{ - Version: unsubscribe.Version, - PacketId: unsubscribe.PacketId, - }) -} func (c *client) pollMessageHandler() { var err error @@ -615,6 +451,7 @@ func (c *client) pollMessageHandler() { c.limit.batchRelease(ids) } } + func (c *client) pollNewMessages(ids []packet.Id) (unused []packet.Id, err error) { var elems []*queue.Element elems, err = c.queueStore.Read(context.Background(), ids) @@ -627,12 +464,13 @@ func (c *client) pollNewMessages(ids []packet.Id) (unused []packet.Id, err error if m.QoS != packet.QoS0 { ids = ids[1:] } - c.write(context.Background(), message.ToPublish(m.Message, c.version)) + c.write(context.Background(), m.Message.ToPublish(c.version)) case *queue.Pubrel: } } return ids, err } + func (c *client) pollInFlights() (bool, error) { var elems []*queue.Element elems, err := c.queueStore.ReadInflight(context.Background(), uint(c.opt.MaxInflight)) @@ -649,7 +487,7 @@ func (c *client) pollInFlights() (bool, error) { m.SubscriptionIdentifier = nil c.limit.markUsedLocked(id) - c.write(context.Background(), message.ToPublish(m.Message, c.version)) + c.write(context.Background(), m.Message.ToPublish(c.version)) case *queue.Pubrel: c.write(context.Background(), &packet.Pubrel{PacketId: id}) } @@ -660,3 +498,18 @@ func (c *client) pollInFlights() (bool, error) { func (c *client) newPacketIdLimiter(limit uint16) { c.limit = newPacketIDLimiter(limit) } + +func convertError(err error) *xerror.Error { + if err == nil { + return nil + } + if e, ok := err.(*xerror.Error); ok { + return e + } + return &xerror.Error{ + Code: code.UnspecifiedError, + ErrorDetails: xerror.ErrorDetails{ + ReasonString: []byte(err.Error()), + }, + } +} diff --git a/internal/server/client_receive.go b/internal/server/client_receive.go new file mode 100644 index 0000000..4b4d460 --- /dev/null +++ b/internal/server/client_receive.go @@ -0,0 +1,169 @@ +package server + +import ( + "github.com/yunqi/lighthouse/internal/code" + "github.com/yunqi/lighthouse/internal/packet" + "github.com/yunqi/lighthouse/internal/persistence/message" + sub "github.com/yunqi/lighthouse/internal/subscription" + "github.com/yunqi/lighthouse/internal/xerror" + "go.uber.org/zap" + "io" + "reflect" + "time" +) + +func (c *client) readConn() { + defer func() { + // 关闭 in 通道 + _ = c.Close() + close(c.in) + }() + go func() { + select { + case <-c.closed: + // 立即关闭 + _ = c.clientConn.SetReadDeadline(time.Now()) + return + } + }() + for { + var p packet.Packet + if c.IsConnected() { + if keepAlive := c.opt.KeepAlive; keepAlive != 0 { //KeepAlive + _ = c.clientConn.SetReadDeadline(time.Now().Add(time.Duration(keepAlive/2+keepAlive) * time.Second)) + } + } + p, err := c.packetReader.Read() + if err != nil { + if err != io.EOF && p != nil { + c.log.Error("read error", zap.String("packet_type", reflect.TypeOf(p).String())) + } + select { + case <-c.closed: + c.log.Debug("客户端退出,关闭连接") + default: + c.log.Debug("连接超时,自动关闭") + } + return + } + //if connect, ok := p.(*packet.Connect); ok { + // c.log.Debug("接收认证信息", zap.String("ClientId", string(connect.ClientId))) + //} else { + // //c.log.Debug("Rec data", zap.String("packet", p.String())) + //} + c.in <- p + + // 等待连接认证完成 + //c.waitConnection() + + } +} + +func (c *client) handleReceivePublish(publish *packet.Publish) *xerror.Error { + ctx, span, logger := c.getTraceLog("publish") + defer span.End() + logger.Debug("received publish packet", zap.String("packet", publish.String())) + + var ( + dup bool + ackPacket packet.Packet + ) + switch publish.QoS { + case packet.QoS1: + ackPacket = publish.CreatePuback() + case packet.QoS2: + exist, err := c.unackStore.Set(ctx, publish.PacketId) + if err != nil { + return convertError(err) + } + + if exist { + dup = true + } + ackPacket = publish.CreatePubrec() + } + + // 第一次收到数据 + if !dup { + // 分发数据 + topicName := string(publish.TopicName) + msg := message.FromPublish(publish) + + options := defaultIterateOptions(topicName) + _ = c.deliverMessage(c.clientId, msg, options) + + if publish.Retain { + if len(publish.Payload) == 0 { + c.server.retainedStore.Remove(topicName) + } else { + c.server.retainedStore.AddOrReplace(msg) + } + } + } + + if ackPacket != nil { + // 返回响应 + c.write(ctx, ackPacket) + } + + return nil +} + +func (c *client) handleReceivePingreq(pingreq *packet.Pingreq) { + ctx, span, logger := c.getTraceLog("ping request") + defer span.End() + logger.Debug("received ping request packet", zap.String("packet", pingreq.String())) + c.write(ctx, pingreq.CreatePingresp()) +} + +func (c *client) handleReceivePubrel(pubrel *packet.Pubrel) { + ctx, span, logger := c.getTraceLog("publish release") + defer span.End() + logger.Debug("received publish release packet", zap.String("packet", pubrel.String())) + c.write(ctx, pubrel.CreatePubcomp()) +} + +func (c *client) handleReceiveSubscribe(subscribe *packet.Subscribe) { + ctx, span, logger := c.getTraceLog("subscribe") + defer span.End() + + logger.Debug("received subscribe packet", zap.String("packet", subscribe.String())) + + var subs = make([]*sub.Subscription, 0, len(subscribe.Topics)) + + for _, topic := range subscribe.Topics { + subs = append(subs, &sub.Subscription{ + //ShareName: topic.Name, + TopicFilter: topic.Name, + //ID: subscribe.PacketId, + QoS: topic.QoS, + NoLocal: topic.NoLocal, + RetainAsPublished: topic.RetainAsPublished, + RetainHandling: topic.RetainHandling, + }) + } + subscribeResult, err := c.subscriptionStore.Subscribe(ctx, c.clientId, subs...) + if err != nil { + logger.Error("err", zap.Error(err)) + return + + } else { + logger.Info("", zap.Any("subscribeResult", subscribeResult)) + } + c.write(ctx, &packet.Suback{ + Version: subscribe.Version, + PacketId: subscribe.PacketId, + Payload: make([]code.Code, len(subscribe.Topics)), + }) +} + +func (c *client) handleReceiveUnsubscribe(unsubscribe *packet.Unsubscribe) { + ctx, span, logger := c.getTraceLog("unsubscribe") + defer span.End() + logger.Debug("received unsubscribe packet", zap.String("packet", unsubscribe.String())) + + c.write(ctx, &packet.Unsuback{ + Version: unsubscribe.Version, + PacketId: unsubscribe.PacketId, + }) +} diff --git a/internal/server/client_send.go b/internal/server/client_send.go new file mode 100644 index 0000000..fe7bba5 --- /dev/null +++ b/internal/server/client_send.go @@ -0,0 +1,11 @@ +package server + +func (c *client) writeConn() { + for p := range c.out { + err := c.packetWriter.WritePacketAndFlush(p) + if err != nil { + return + } + } + c.log.Debug("写入操作退出") +} diff --git a/internal/server/server.go b/internal/server/server.go index 14faae6..f8fee44 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -22,6 +22,9 @@ import ( "github.com/yunqi/lighthouse/config" "github.com/yunqi/lighthouse/internal/goroutine" "github.com/yunqi/lighthouse/internal/persistence" + "github.com/yunqi/lighthouse/internal/persistence/message" + "github.com/yunqi/lighthouse/internal/persistence/retained" + "github.com/yunqi/lighthouse/internal/persistence/retained/trie" "github.com/yunqi/lighthouse/internal/persistence/session" "github.com/yunqi/lighthouse/internal/persistence/subscription" "github.com/yunqi/lighthouse/internal/xlog" @@ -30,6 +33,7 @@ import ( "go.opentelemetry.io/otel/trace" "go.uber.org/zap" "net" + "sync" "time" ) @@ -52,6 +56,8 @@ type ( websocketListener *websocket.Conn sessionStore session.Store subscriptionStore subscription.Store + retainedStore retained.Store + clients sync.Map // log *xlog.Log tracer trace.Tracer } @@ -141,7 +147,6 @@ func (s *server) init(opts *Options) { if !ok { s.log.Panic("invalid session store") } - if store, err := sessionStore(&opts.persistence.Session); err != nil { s.log.Panic("session store", zap.Error(err)) } else { @@ -154,7 +159,6 @@ func (s *server) init(opts *Options) { if !ok { s.log.Panic("invalid subscriptionStore store") } - if subscriptionStore, err := subscriptionStoreFunc(&opts.persistence.Subscription); err != nil { s.log.Panic("subscriptionStore store", zap.Error(err)) } else { @@ -162,6 +166,7 @@ func (s *server) init(opts *Options) { s.log.Info("subscriptionStore store", zap.String("type", opts.persistence.Session.Type)) } + // tcp ln, err := net.Listen("tcp", s.tcpListen) if err != nil { s.log.Panic("start tcp error", zap.String("tcp", s.tcpListen), zap.Error(err)) @@ -169,4 +174,13 @@ func (s *server) init(opts *Options) { s.log.Info("start tcp", zap.String("TCP", s.tcpListen)) s.tcpListener = ln + // retain msg + s.retainedStore = trie.NewStore() +} + +func (s *server) deliverMessage() func(srcClientId string, msg *message.Message, options subscription.IterationOptions) (matched bool) { + return func(srcClientId string, msg *message.Message, options subscription.IterationOptions) (matched bool) { + + return true + } } diff --git a/internal/server/util.go b/internal/server/util.go new file mode 100644 index 0000000..6978f91 --- /dev/null +++ b/internal/server/util.go @@ -0,0 +1,11 @@ +package server + +import "github.com/yunqi/lighthouse/internal/persistence/subscription" + +func defaultIterateOptions(topicName string) subscription.IterationOptions { + return subscription.IterationOptions{ + Type: subscription.TypeAll, + TopicName: topicName, + MatchType: subscription.MatchFilter, + } +}