Skip to content

Commit

Permalink
removing dht code from communication constructor add custom peer disc…
Browse files Browse the repository at this point in the history
…overy
  • Loading branch information
brewmaster012 committed Oct 25, 2024
1 parent f8b548c commit 7d2b4f3
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 39 deletions.
12 changes: 12 additions & 0 deletions conversion/conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,15 @@ func GetEDDSAPrivateKeyRawBytes(privateKey crypto2.PrivKey) ([]byte, error) {
copy(keyBytesArray[:], pk[:])
return keyBytesArray[:], nil
}

func Bech32PubkeyToPeerID(pubKey string) (peer.ID, error) {
bech32PubKey, err := sdk.UnmarshalPubKey(sdk.AccPK, pubKey)
if err != nil {
return "", err
}
secp256k1PubKey, err := crypto2.UnmarshalSecp256k1PublicKey(bech32PubKey.Bytes())
if err != nil {
return "", err
}
return peer.IDFromPublicKey(secp256k1PubKey)
}
50 changes: 19 additions & 31 deletions p2p/communication.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,11 @@ import (
"time"

libp2p "github.com/libp2p/go-libp2p"
dht "github.com/libp2p/go-libp2p-kad-dht"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
discovery_routing "github.com/libp2p/go-libp2p/p2p/discovery/routing"
discovery_util "github.com/libp2p/go-libp2p/p2p/discovery/util"
rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager"
"github.com/libp2p/go-libp2p/p2p/net/connmgr"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
Expand Down Expand Up @@ -62,10 +59,11 @@ type Communication struct {
BroadcastMsgChan chan *messages.BroadcastMsgChan
externalAddr maddr.Multiaddr
streamMgr *StreamMgr
whitelistedPeers []peer.ID
}

// NewCommunication create a new instance of Communication
func NewCommunication(rendezvous string, bootstrapPeers []maddr.Multiaddr, port int, externalIP string) (*Communication, error) {
func NewCommunication(rendezvous string, bootstrapPeers []maddr.Multiaddr, port int, externalIP string, whitelistedPeers []peer.ID) (*Communication, error) {
addr, err := maddr.NewMultiaddr(fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", port))
if err != nil {
return nil, fmt.Errorf("fail to create listen addr: %w", err)
Expand All @@ -90,6 +88,7 @@ func NewCommunication(rendezvous string, bootstrapPeers []maddr.Multiaddr, port
BroadcastMsgChan: make(chan *messages.BroadcastMsgChan, 1024),
externalAddr: externalAddr,
streamMgr: NewStreamMgr(),
whitelistedPeers: whitelistedPeers,
}, nil
}

Expand Down Expand Up @@ -244,7 +243,7 @@ func (c *Communication) bootStrapConnectivityCheck() error {
}

func (c *Communication) startChannel(privKeyBytes []byte) error {
ctx := context.Background()
c.logger.Warn().Msgf("No DHT enabled")
p2pPriKey, err := crypto.UnmarshalSecp256k1PrivateKey(privKeyBytes)
if err != nil {
c.logger.Error().Msgf("error is %f", err)
Expand Down Expand Up @@ -312,14 +311,6 @@ func (c *Communication) startChannel(privKeyBytes []byte) error {
// client because we want each peer to maintain its own local copy of the
// DHT, so that the bootstrapping node of the DHT can go down without
// inhibiting future peer discovery.
kademliaDHT, err := dht.New(ctx, h, dht.Mode(dht.ModeServer))
if err != nil {
return fmt.Errorf("fail to create DHT: %w", err)
}
c.logger.Debug().Msg("Bootstrapping the DHT")
if err = kademliaDHT.Bootstrap(ctx); err != nil {
return fmt.Errorf("fail to bootstrap DHT: %w", err)
}

var connectionErr error
for i := 0; i < 5; i++ {
Expand All @@ -334,30 +325,27 @@ func (c *Communication) startChannel(privKeyBytes []byte) error {
return fmt.Errorf("fail to connect to bootstrap peer: %w", connectionErr)
}

// We use a rendezvous point "meet me here" to announce our location.
// This is like telling your friends to meet you at the Eiffel Tower.
routingDiscovery := discovery_routing.NewRoutingDiscovery(kademliaDHT)
discovery_util.Advertise(ctx, routingDiscovery, c.rendezvous)

// Create a goroutine to shut down the DHT after 5 minutes
go func() {
select {
case <-time.After(5 * time.Minute):
c.logger.Info().Msg("Closing Kademlia DHT after 5 minutes")
if err := kademliaDHT.Close(); err != nil {
c.logger.Error().Err(err).Msg("Failed to close Kademlia DHT")
}
case <-ctx.Done():
c.logger.Info().Msg("Context done, not waiting for 5 minutes to close DHT")
}
}()

err = c.bootStrapConnectivityCheck()
if err != nil {
return err
}

c.logger.Info().Msg("Successfully announced!")

c.logger.Info().Msg("Start peer discovery/gossip...")
//c.bootstrapPeers
bootstrapPeerAddrInfos := make([]peer.AddrInfo, 0, len(c.bootstrapPeers))
for _, addr := range c.bootstrapPeers {
peerInfo, err := peer.AddrInfoFromP2pAddr(addr)
if err != nil {
c.logger.Error().Err(err).Msgf("fail to convert multiaddr to peer info: %s", addr)
continue
}
bootstrapPeerAddrInfos = append(bootstrapPeerAddrInfos, *peerInfo)
}
discovery := NewPeerDiscovery(c.host, bootstrapPeerAddrInfos)
discovery.Start(context.Background())

return nil
}

Expand Down
175 changes: 175 additions & 0 deletions p2p/discovery.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
package p2p

import (
"context"
"encoding/json"
"io"
"sync"
"time"

"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/multiformats/go-multiaddr"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)

const DiscoveryProtocol = "/tss/discovery/1.0.0"
const GossipInterval = 10 * time.Second

type PeerDiscovery struct {
host host.Host
knownPeers map[peer.ID]peer.AddrInfo
bootstrapPeers []peer.AddrInfo
mu sync.RWMutex
logger zerolog.Logger
}

func NewPeerDiscovery(h host.Host, bootstrapPeers []peer.AddrInfo) *PeerDiscovery {
pd := &PeerDiscovery{
host: h,
knownPeers: make(map[peer.ID]peer.AddrInfo),
bootstrapPeers: bootstrapPeers,
logger: log.With().Str("module", "peer-discovery").Logger(),
}

// Set up discovery protocol handler
h.SetStreamHandler(DiscoveryProtocol, pd.handleDiscovery)

return pd
}

// Start begins the discovery process
func (pd *PeerDiscovery) Start(ctx context.Context) {
pd.logger.Info().Msgf("Starting peer discovery with bootstrap peers: %v", pd.bootstrapPeers)
// Connect to bootstrap peers first
for _, pinfo := range pd.bootstrapPeers {
if err := pd.host.Connect(ctx, pinfo); err != nil {
pd.logger.Error().Err(err).Msgf("Failed to connect to bootstrap peer %s", pinfo.ID)
continue
}
pd.addPeer(pinfo)
}

// Start periodic gossip
go pd.startGossip(ctx)
}

// addPeer adds a peer to known peers
func (pd *PeerDiscovery) addPeer(pinfo peer.AddrInfo) {
pd.mu.Lock()
defer pd.mu.Unlock()

if pinfo.ID == pd.host.ID() {
return // Don't add ourselves
}
pd.knownPeers[pinfo.ID] = pinfo
}

// GetPeers returns all known peers
func (pd *PeerDiscovery) GetPeers() []peer.AddrInfo {
pd.mu.RLock()
defer pd.mu.RUnlock()

peers := make([]peer.AddrInfo, 0, len(pd.knownPeers))
for _, p := range pd.knownPeers {
peers = append(peers, p)
}
return peers
}

// handleDiscovery handles incoming discovery streams
func (pd *PeerDiscovery) handleDiscovery(s network.Stream) {
pd.logger.Debug().Msgf("Received discovery stream from %s", s.Conn().RemotePeer())
defer s.Close()

ma := s.Conn().RemoteMultiaddr()

ai := peer.AddrInfo{
ID: s.Conn().RemotePeer(),
Addrs: []multiaddr.Multiaddr{ma},
}
pd.addPeer(ai)

// Share our known peers
peers := pd.GetPeers()
data, err := json.Marshal(peers)
if err != nil {
pd.logger.Error().Err(err).Msgf("Failed to marshal peers")
return
}
_, err = s.Write(data)
if err != nil {
pd.logger.Error().Err(err).Msgf("Failed to write to stream")
}
}

// startGossip periodically shares peer information
func (pd *PeerDiscovery) startGossip(ctx context.Context) {
ticker := time.NewTicker(GossipInterval)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
return
case <-ticker.C:

pd.gossipPeers(ctx)
}
}
}

func (pd *PeerDiscovery) gossipPeers(ctx context.Context) {
pd.logger.Debug().Msgf("Gossiping known peers")
peers := pd.GetPeers()
pd.logger.Debug().Msgf("current peers: %v", peers)

ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
for _, p := range peers {
if p.ID == pd.host.ID() {
continue
}

err := pd.host.Connect(ctx, p)
if err != nil {
pd.logger.Error().Err(err).Msgf("Failed to connect to peer %s", p)
}
pd.logger.Debug().Msgf("Connected to peer %s", p)

// Open discovery stream
s, err := pd.host.NewStream(ctx, p.ID, DiscoveryProtocol)
if err != nil {
pd.logger.Error().Err(err).Msgf("Failed to open discovery stream to %s", p)
continue
}
pd.logger.Debug().Msgf("Opened discovery stream to %s", p)

// Read peer info from stream
// This is a simplified example - implement proper serialization
buf, err := io.ReadAll(s)
if err != nil {
s.Close()
pd.logger.Error().Err(err).Msgf("Failed to read from stream")
continue
}
pd.logger.Info().Msgf("Received peer data: %s", string(buf))

// Parse received peer info and add to known peers
var recvPeers []peer.AddrInfo
err = json.Unmarshal(buf, &recvPeers)
if err != nil {
s.Close()
pd.logger.Error().Err(err).Msgf("Failed to unmarshal peer data")
continue
}
for _, p := range recvPeers {
pd.logger.Debug().Msgf("Adding peer %s", p)
pd.addPeer(p)
}

s.Close()
}
}
16 changes: 8 additions & 8 deletions tss/tss.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func NewTss(
preParams *bkeygen.LocalPreParams,
externalIP string,
tssPassword string,
whitelistedPeers []peer.ID,
) (*TssServer, error) {
pk := coskey.PubKey{
Key: priKey.PubKey().Bytes()[:],
Expand All @@ -82,7 +83,7 @@ func NewTss(
bootstrapPeers = savedPeers
bootstrapPeers = append(bootstrapPeers, cmdBootstrapPeers...)
}
comm, err := p2p.NewCommunication(rendezvous, bootstrapPeers, p2pPort, externalIP)
comm, err := p2p.NewCommunication(rendezvous, bootstrapPeers, p2pPort, externalIP, whitelistedPeers)
if err != nil {
return nil, fmt.Errorf("fail to create communication layer: %w", err)
}
Expand Down Expand Up @@ -231,17 +232,16 @@ func (t *TssServer) GetLocalPeerID() string {
}

// GetKnownPeers return the the ID and IP address of all peers.
func (t *TssServer) GetKnownPeers() []PeerInfo {
infos := []PeerInfo{}
func (t *TssServer) GetKnownPeers() []peer.AddrInfo {
var infos []peer.AddrInfo
host := t.p2pCommunication.GetHost()

for _, conn := range host.Network().Conns() {
peer := conn.RemotePeer()
p := conn.RemotePeer()
addrs := conn.RemoteMultiaddr()
ip, _ := addrs.ValueForProtocol(maddr.P_IP4)
pi := PeerInfo{
ID: peer.String(),
Address: ip,
pi := peer.AddrInfo{
p,
[]maddr.Multiaddr{addrs},
}
infos = append(infos, pi)
}
Expand Down

0 comments on commit 7d2b4f3

Please sign in to comment.