Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(push): push down message #30

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions internal/persistence/message/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,14 @@ func getVariableLength(l int) int {
return 0
}

func ToPublish(msg *Message, version packet.Version) *packet.Publish {
pub := &packet.Publish{
Dup: msg.Dup,
QoS: msg.QoS,
PacketId: msg.PacketId,
Retain: msg.Retained,
TopicName: []byte(msg.Topic),
Payload: msg.Payload,
func (m *Message) ToPublish(version packet.Version) *packet.Publish {
return &packet.Publish{
Dup: m.Dup,
QoS: m.QoS,
PacketId: m.PacketId,
Retain: m.Retained,
TopicName: []byte(m.Topic),
Payload: m.Payload,
Version: version,
}

return pub
}
239 changes: 46 additions & 193 deletions internal/server/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import (
"github.com/yunqi/lighthouse/internal/persistence/message"
"github.com/yunqi/lighthouse/internal/persistence/queue"
"github.com/yunqi/lighthouse/internal/persistence/subscription"
"github.com/yunqi/lighthouse/internal/persistence/unack"
"github.com/yunqi/lighthouse/internal/session"
sub "github.com/yunqi/lighthouse/internal/subscription"
"github.com/yunqi/lighthouse/internal/xerror"
"github.com/yunqi/lighthouse/internal/xlog"
"go.opentelemetry.io/otel/trace"
Expand Down Expand Up @@ -62,7 +62,7 @@ type (
ConnectedAt() time.Time
// Connection returns the raw net.Conn
Connection() net.Conn
// Close closes the client connection.
// Close closes the client handleReceiveConnection.
Close() error
// Disconnect sends a disconnect packet to client, it is use to close v5 client.
Disconnect(disconnect *packet.Disconnect)
Expand Down Expand Up @@ -130,18 +130,21 @@ type (
wg sync.WaitGroup
queueStore queue.Queue
subscriptionStore subscription.Store
unackStore unack.Store
limit *packetIdLimiter
log *xlog.Log
remoteAddr net.Addr
deliverMessage func(srcClientID string, msg *message.Message, options subscription.IterationOptions) (matched bool)
}
)

func (c *client) ClientOption() *ClientOption {
panic("implement me")
}

func (c *client) Deliver(message message.Message) error {
panic("implement me")
func (c *client) Deliver(ctx context.Context, message *message.Message) error {
c.write(ctx, message.ToPublish(c.version))
return nil
}

func (c *client) ClientOptions() *ClientOption {
Expand Down Expand Up @@ -204,14 +207,15 @@ func newClient(server *server, conn net.Conn) *client {
log: xlog.LoggerModule("client"),
remoteAddr: conn.RemoteAddr(),
subscriptionStore: server.subscriptionStore,
deliverMessage: server.deliverMessage(),
}
return c
}

func (c *client) listen() {
ctx, span := c.server.tracer.Start(context.Background(), "listen")
logger := c.log.WithContext(ctx)
logger.Debug("create a new client connection", zap.Any("IP", c.remoteAddr.String()))
logger.Debug("create a new client handleReceiveConnection", zap.Any("IP", c.remoteAddr.String()))

c.wg.Add(1)
goroutine.Go(func() {
Expand All @@ -237,9 +241,14 @@ func (c *client) listen() {
// 拉取消息
c.wg.Add(1)
goroutine.Go(func() {
defer c.wg.Done()
c.pollMessageHandler()
c.wg.Done()
})
////c.wg.Add(1)
//goroutine.Go(func() {
// defer c.wg.Done()
// //c.handleReceiveConnection()
//})

c.wg.Add(1)
goroutine.Go(func() {
Expand Down Expand Up @@ -291,113 +300,18 @@ func (c *client) auth(ctx context.Context) bool {

logger.Debug("invalid package", zap.String("package", p.String()))
_ = c.Close()
logger.Debug("close connection", zap.String("IP", c.remoteAddr.String()))
logger.Debug("close handleReceiveConnection", zap.String("IP", c.remoteAddr.String()))
return false
}

func (c *client) readConn() {
defer func() {
// 关闭 in 通道
_ = c.Close()
close(c.in)
}()
go func() {
select {
case <-c.closed:
// 立即关闭
_ = c.clientConn.SetReadDeadline(time.Now())
return
}
}()
for {
var p packet.Packet
if c.IsConnected() {
if keepAlive := c.opt.KeepAlive; keepAlive != 0 { //KeepAlive
_ = c.clientConn.SetReadDeadline(time.Now().Add(time.Duration(keepAlive/2+keepAlive) * time.Second))
}
}
p, err := c.packetReader.Read()
if err != nil {
if err != io.EOF && p != nil {
c.log.Error("read error", zap.String("packet_type", reflect.TypeOf(p).String()))
}
select {
case <-c.closed:
c.log.Debug("客户端退出,关闭连接")
default:
c.log.Debug("连接超时,自动关闭")
}
return
}
//if connect, ok := p.(*packet.Connect); ok {
// c.log.Debug("接收认证信息", zap.String("ClientId", string(connect.ClientId)))
//} else {
// //c.log.Debug("Rec data", zap.String("packet", p.String()))
//}
c.in <- p

// 等待连接认证完成
//c.waitConnection()

}
}

func (c *client) writeConn() {

defer func() {
}()
for p := range c.out {
//c.log.Debug("Ret data", zap.String("packet", p.String()))
err := c.packetWriter.WritePacketAndFlush(p)
if err != nil {
return
}
}
c.log.Debug("写入操作退出")

}
func (c *client) write(ctx context.Context, packet packet.Packet) {
c.log.WithContext(ctx).Debug("write packet", zap.String("packet", packet.String()))
c.out <- packet
}

//func (c *client) waitConnection() {
// <-c.connected
//}

func (c *client) connectionDone() {
close(c.connected)
}

func (c *client) connection() (ok bool) {
defer func() {
c.connectionDone()
}()
timeout := time.NewTimer(5 * time.Second)
defer timeout.Stop()
for {
select {
case p := <-c.in:
//c.log.Debug("从in通道中读出数据", zap.Any("packet", p))
if p == nil {
return
}

switch conn := p.(type) {
case *packet.Connect:
if conn == nil {
//err := xerror.ErrProtocol
break
}
return c.connectAuthentication(context.Background(), conn)
default:
}
case <-timeout.C:
return
}

select {
case <-c.closed:
return
default:
c.out <- packet
}

}

// TODO 验证客户端连接
Expand Down Expand Up @@ -478,15 +392,15 @@ func (c *client) handleConn() {
for p := range c.in {
switch packetData := p.(type) {
case *packet.Publish:
err = c.handlePublish(packetData)
err = c.handleReceivePublish(packetData)
case *packet.Pingreq:
c.handlePingreq(packetData)
c.handleReceivePingreq(packetData)
case *packet.Pubrel:
c.handlePubrel(packetData)
c.handleReceivePubrel(packetData)
case *packet.Subscribe:
c.handleSubscribe(packetData)
c.handleReceiveSubscribe(packetData)
case *packet.Unsubscribe:
c.handleUnsubscribe(packetData)
c.handleReceiveUnsubscribe(packetData)
case *packet.Disconnect:
break
default:
Expand All @@ -496,90 +410,12 @@ func (c *client) handleConn() {
}
}
}

func (c *client) getTraceLog(spanName string) (context.Context, trace.Span, *zap.Logger) {
ctx, span := c.server.tracer.Start(context.Background(), spanName)
logger := c.log.WithContext(ctx)
return ctx, span, logger
}
func (c *client) handlePublish(publish *packet.Publish) *xerror.Error {
ctx, span, logger := c.getTraceLog("publish")
defer span.End()
logger.Debug("received publish packet", zap.String("packet", publish.String()))
var ackPacket packet.Packet
switch publish.QoS {
case packet.QoS1:
ackPacket = publish.CreatePuback()
case packet.QoS2:
ackPacket = publish.CreatePubrec()
}

if ackPacket != nil {
// 返回响应
c.write(ctx, ackPacket)
}

return nil
}

func (c *client) handlePingreq(pingreq *packet.Pingreq) {
ctx, span, logger := c.getTraceLog("ping request")
defer span.End()
logger.Debug("received ping request packet", zap.String("packet", pingreq.String()))
c.write(ctx, pingreq.CreatePingresp())
}

func (c *client) handlePubrel(pubrel *packet.Pubrel) {
ctx, span, logger := c.getTraceLog("publish release")
defer span.End()

logger.Debug("received publish release packet", zap.String("packet", pubrel.String()))
c.write(ctx, pubrel.CreatePubcomp())
}

func (c *client) handleSubscribe(subscribe *packet.Subscribe) {
ctx, span, logger := c.getTraceLog("subscribe")
defer span.End()

logger.Debug("received subscribe packet", zap.String("packet", subscribe.String()))

var subs = make([]*sub.Subscription, 0, len(subscribe.Topics))

for _, topic := range subscribe.Topics {
subs = append(subs, &sub.Subscription{
//ShareName: topic.Name,
TopicFilter: topic.Name,
//ID: subscribe.PacketId,
QoS: topic.QoS,
NoLocal: topic.NoLocal,
RetainAsPublished: topic.RetainAsPublished,
RetainHandling: topic.RetainHandling,
})
}
subscribeResult, err := c.subscriptionStore.Subscribe(ctx, c.clientId, subs...)
if err != nil {
logger.Error("err", zap.Error(err))
return

} else {
logger.Info("", zap.Any("subscribeResult", subscribeResult))
}
c.write(ctx, &packet.Suback{
Version: subscribe.Version,
PacketId: subscribe.PacketId,
Payload: make([]code.Code, len(subscribe.Topics)),
})
}

func (c *client) handleUnsubscribe(unsubscribe *packet.Unsubscribe) {
ctx, span, logger := c.getTraceLog("unsubscribe")
defer span.End()
logger.Debug("received unsubscribe packet", zap.String("packet", unsubscribe.String()))

c.write(ctx, &packet.Unsuback{
Version: unsubscribe.Version,
PacketId: unsubscribe.PacketId,
})
}

func (c *client) pollMessageHandler() {
var err error
Expand Down Expand Up @@ -615,6 +451,7 @@ func (c *client) pollMessageHandler() {
c.limit.batchRelease(ids)
}
}

func (c *client) pollNewMessages(ids []packet.Id) (unused []packet.Id, err error) {
var elems []*queue.Element
elems, err = c.queueStore.Read(context.Background(), ids)
Expand All @@ -627,12 +464,13 @@ func (c *client) pollNewMessages(ids []packet.Id) (unused []packet.Id, err error
if m.QoS != packet.QoS0 {
ids = ids[1:]
}
c.write(context.Background(), message.ToPublish(m.Message, c.version))
c.write(context.Background(), m.Message.ToPublish(c.version))
case *queue.Pubrel:
}
}
return ids, err
}

func (c *client) pollInFlights() (bool, error) {
var elems []*queue.Element
elems, err := c.queueStore.ReadInflight(context.Background(), uint(c.opt.MaxInflight))
Expand All @@ -649,7 +487,7 @@ func (c *client) pollInFlights() (bool, error) {

m.SubscriptionIdentifier = nil
c.limit.markUsedLocked(id)
c.write(context.Background(), message.ToPublish(m.Message, c.version))
c.write(context.Background(), m.Message.ToPublish(c.version))
case *queue.Pubrel:
c.write(context.Background(), &packet.Pubrel{PacketId: id})
}
Expand All @@ -660,3 +498,18 @@ func (c *client) pollInFlights() (bool, error) {
func (c *client) newPacketIdLimiter(limit uint16) {
c.limit = newPacketIDLimiter(limit)
}

func convertError(err error) *xerror.Error {
if err == nil {
return nil
}
if e, ok := err.(*xerror.Error); ok {
return e
}
return &xerror.Error{
Code: code.UnspecifiedError,
ErrorDetails: xerror.ErrorDetails{
ReasonString: []byte(err.Error()),
},
}
}
Loading