From c1a367e17db90d3660a760e110b34bbdcbcd20c5 Mon Sep 17 00:00:00 2001 From: Erson Pereira <38801159+ersonp@users.noreply.github.com> Date: Wed, 10 Jul 2024 18:49:37 +0530 Subject: [PATCH] Feat/ip from server (#272) * feat: Add DialServerForIP method to dmsg client This commit adds a new method `DialServerForIP` to the `Client` struct in the `pkg/dmsg/client.go` file. This method dials to dmsg servers to retrieve the public IP address of the client. It iterates through a list of server entries, attempts to dial each server, and returns the first public IP address it receives. The purpose of this change is to provide a way for the client to obtain its public IP address from dmsg servers. * feat: Add new commandline tool dmsgip * chore: fix linting * feat: Add support for dmsg server public keys in DialServerForIP method This commit modifies the `DialServerForIP` method in the `pkg/dmsg/client.go` file to accept a slice of dmsg server public keys as an argument. If the `servers` argument is nil, the method retrieves the server entries using the `discoverServers` function and populates the `servers` slice with the static public keys from the entries. Then, it iterates through the `servers` slice and attempts to dial each server to retrieve the public IP address of the client. The purpose of this change is to allow the `DialServerForIP` method to support custom dmsg server public keys, providing more flexibility in obtaining the client's public IP address from dmsg servers. * feat: Add error handling for non-public IP address in DialServerForIP method * refactor: Improve error handling and connection logic in DialServerForIP method This commit refactors the `DialServerForIP` method in the `pkg/dmsg/client.go` file to improve error handling and connection logic. It introduces two separate loops to handle delegated servers and attempts to connect to each server individually. Additionally, it properly closes if the session is created just for the IP after dialing the server. The purpose of this change is to enhance the reliability and stability of the `DialServerForIP` method when retrieving the public IP address from dmsg servers. * refactor: Improve error handling and add MinSessions to startDmsg * feat: Add dmsgip commandline tool to dmsg asa a subcommand * refactor: Rename DialServerForIP to LookupIP in dmsg client This commit renames the `DialServerForIP` method to `LookupIP` in the `pkg/dmsg/client.go` file. The functionality remains the same, but the new name better reflects the purpose of the method, which is to lookup the public IP address of the client from dmsg servers. The purpose of this change is to improve the clarity and consistency of the method name, making it more intuitive for developers working with the dmsg client. * chore: Move code around * refactor: Improve error handling in LookupIP method * test: add unit tests for LookupIP * refactor: improve error handling in dmsgip cmd * test: update IP lookup logic in stream_test.go This commit updates the IP lookup logic in the `stream_test.go` file. It introduces conditional checks based on the operating system to handle different IP address formats. On Windows, the IP address is expected to be "127.0.0.1", while on other operating systems, it is expected to be "::1". The purpose of this change is to ensure that the IP lookup tests pass correctly on different operating systems, improving the reliability and consistency of the test suite. --- cmd/dmsg/dmsg.go | 3 + cmd/dmsgip/README.md | 23 ++++++ cmd/dmsgip/commands/dmsgip.go | 133 ++++++++++++++++++++++++++++++++++ cmd/dmsgip/dmsgip.go | 44 +++++++++++ pkg/dmsg/client.go | 73 ++++++++++++++++++- pkg/dmsg/client_session.go | 44 +++++++++++ pkg/dmsg/server_session.go | 35 +++++++++ pkg/dmsg/stream.go | 38 ++++++++++ pkg/dmsg/stream_test.go | 95 ++++++++++++++++++++++++ pkg/dmsg/types.go | 3 + 10 files changed, 490 insertions(+), 1 deletion(-) create mode 100644 cmd/dmsgip/README.md create mode 100644 cmd/dmsgip/commands/dmsgip.go create mode 100644 cmd/dmsgip/dmsgip.go diff --git a/cmd/dmsg/dmsg.go b/cmd/dmsg/dmsg.go index d57651adf..42c6704cc 100644 --- a/cmd/dmsg/dmsg.go +++ b/cmd/dmsg/dmsg.go @@ -15,6 +15,7 @@ import ( dmsgsocks "github.com/skycoin/dmsg/cmd/dmsg-socks5/commands" dmsgcurl "github.com/skycoin/dmsg/cmd/dmsgcurl/commands" dmsghttp "github.com/skycoin/dmsg/cmd/dmsghttp/commands" + dmsgip "github.com/skycoin/dmsg/cmd/dmsgip/commands" dmsgptycli "github.com/skycoin/dmsg/cmd/dmsgpty-cli/commands" dmsgptyhost "github.com/skycoin/dmsg/cmd/dmsgpty-host/commands" dmsgptyui "github.com/skycoin/dmsg/cmd/dmsgpty-ui/commands" @@ -35,6 +36,7 @@ func init() { dmsgcurl.RootCmd, dmsgweb.RootCmd, dmsgsocks.RootCmd, + dmsgip.RootCmd, ) dmsgdisc.RootCmd.Use = "disc" dmsgserver.RootCmd.Use = "server" @@ -45,6 +47,7 @@ func init() { dmsgptycli.RootCmd.Use = "cli" dmsgptyhost.RootCmd.Use = "host" dmsgptyui.RootCmd.Use = "ui" + dmsgip.RootCmd.Use = "ip" var helpflag bool RootCmd.SetUsageTemplate(help) diff --git a/cmd/dmsgip/README.md b/cmd/dmsgip/README.md new file mode 100644 index 000000000..97a9b4987 --- /dev/null +++ b/cmd/dmsgip/README.md @@ -0,0 +1,23 @@ + + + +``` + + + ┌┬┐┌┬┐┌─┐┌─┐ ┬┌─┐ + │││││└─┐│ ┬ │├─┘ + ─┴┘┴ ┴└─┘└─┘ ┴┴ +DMSG ip utility + +Usage: + dmsgip + +Flags: + -c, --dmsg-disc string dmsg discovery url default: + http://dmsgd.skywire.dev + -l, --loglvl string [ debug | warn | error | fatal | panic | trace | info ] (default "fatal") + -s, --sk cipher.SecKey a random key is generated if unspecified + (default 0000000000000000000000000000000000000000000000000000000000000000) + -v, --version version for dmsgip + +``` diff --git a/cmd/dmsgip/commands/dmsgip.go b/cmd/dmsgip/commands/dmsgip.go new file mode 100644 index 000000000..2b6a14ab2 --- /dev/null +++ b/cmd/dmsgip/commands/dmsgip.go @@ -0,0 +1,133 @@ +// Package commands cmd/dmsgcurl/commands/dmsgcurl.go +package commands + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + "path/filepath" + "strings" + + "github.com/skycoin/skywire-utilities/pkg/buildinfo" + "github.com/skycoin/skywire-utilities/pkg/cipher" + "github.com/skycoin/skywire-utilities/pkg/cmdutil" + "github.com/skycoin/skywire-utilities/pkg/logging" + "github.com/skycoin/skywire-utilities/pkg/skyenv" + "github.com/spf13/cobra" + + "github.com/skycoin/dmsg/pkg/disc" + "github.com/skycoin/dmsg/pkg/dmsg" +) + +var ( + dmsgDisc string + sk cipher.SecKey + logLvl string + dmsgServers []string +) + +func init() { + RootCmd.Flags().StringVarP(&dmsgDisc, "dmsg-disc", "c", "", "dmsg discovery url default:\n"+skyenv.DmsgDiscAddr) + RootCmd.Flags().StringVarP(&logLvl, "loglvl", "l", "fatal", "[ debug | warn | error | fatal | panic | trace | info ]\033[0m") + if os.Getenv("DMSGIP_SK") != "" { + sk.Set(os.Getenv("DMSGIP_SK")) //nolint + } + RootCmd.Flags().StringSliceVarP(&dmsgServers, "srv", "d", []string{}, "dmsg server public keys\n\r") + RootCmd.Flags().VarP(&sk, "sk", "s", "a random key is generated if unspecified\n\r") +} + +// RootCmd containsa the root dmsgcurl command +var RootCmd = &cobra.Command{ + Use: func() string { + return strings.Split(filepath.Base(strings.ReplaceAll(strings.ReplaceAll(fmt.Sprintf("%v", os.Args), "[", ""), "]", "")), " ")[0] + }(), + Short: "DMSG ip utility", + Long: ` + ┌┬┐┌┬┐┌─┐┌─┐ ┬┌─┐ + │││││└─┐│ ┬ │├─┘ + ─┴┘┴ ┴└─┘└─┘ ┴┴ +DMSG ip utility`, + SilenceErrors: true, + SilenceUsage: true, + DisableSuggestions: true, + DisableFlagsInUseLine: true, + Version: buildinfo.Version(), + PreRun: func(cmd *cobra.Command, args []string) { + if dmsgDisc == "" { + dmsgDisc = skyenv.DmsgDiscAddr + } + }, + RunE: func(cmd *cobra.Command, args []string) error { + log := logging.MustGetLogger("dmsgip") + + if logLvl != "" { + if lvl, err := logging.LevelFromString(logLvl); err == nil { + logging.SetLevel(lvl) + } + } + + var srvs []cipher.PubKey + for _, srv := range dmsgServers { + var pk cipher.PubKey + if err := pk.Set(srv); err != nil { + return fmt.Errorf("failed to parse server public key: %w", err) + } + srvs = append(srvs, pk) + } + + ctx, cancel := cmdutil.SignalContext(context.Background(), log) + defer cancel() + + pk, err := sk.PubKey() + if err != nil { + pk, sk = cipher.GenerateKeyPair() + } + + dmsgC, closeDmsg, err := startDmsg(ctx, log, pk, sk) + if err != nil { + log.WithError(err).Error("failed to start dmsg") + } + defer closeDmsg() + + ip, err := dmsgC.LookupIP(ctx, srvs) + if err != nil { + log.WithError(err).Error("failed to lookup IP") + } + + fmt.Printf("%v\n", ip) + fmt.Print("\n") + return nil + }, +} + +func startDmsg(ctx context.Context, log *logging.Logger, pk cipher.PubKey, sk cipher.SecKey) (dmsgC *dmsg.Client, stop func(), err error) { + dmsgC = dmsg.NewClient(pk, sk, disc.NewHTTP(dmsgDisc, &http.Client{}, log), &dmsg.Config{MinSessions: dmsg.DefaultMinSessions}) + go dmsgC.Serve(context.Background()) + + stop = func() { + err := dmsgC.Close() + log.WithError(err).Debug("Disconnected from dmsg network.") + fmt.Printf("\n") + } + log.WithField("public_key", pk.String()).WithField("dmsg_disc", dmsgDisc). + Debug("Connecting to dmsg network...") + + select { + case <-ctx.Done(): + stop() + return nil, nil, ctx.Err() + + case <-dmsgC.Ready(): + log.Debug("Dmsg network ready.") + return dmsgC, stop, nil + } +} + +// Execute executes root CLI command. +func Execute() { + if err := RootCmd.Execute(); err != nil { + log.Fatal("Failed to execute command: ", err) + } +} diff --git a/cmd/dmsgip/dmsgip.go b/cmd/dmsgip/dmsgip.go new file mode 100644 index 000000000..12b3ec8ee --- /dev/null +++ b/cmd/dmsgip/dmsgip.go @@ -0,0 +1,44 @@ +// package main cmd/dmsgcurl/dmsgcurl.go +package main + +import ( + cc "github.com/ivanpirog/coloredcobra" + "github.com/spf13/cobra" + + "github.com/skycoin/dmsg/cmd/dmsgip/commands" +) + +func init() { + var helpflag bool + commands.RootCmd.SetUsageTemplate(help) + commands.RootCmd.PersistentFlags().BoolVarP(&helpflag, "help", "h", false, "help for dmsgpty-cli") + commands.RootCmd.SetHelpCommand(&cobra.Command{Hidden: true}) + commands.RootCmd.PersistentFlags().MarkHidden("help") //nolint +} + +func main() { + cc.Init(&cc.Config{ + RootCmd: commands.RootCmd, + Headings: cc.HiBlue + cc.Bold, + Commands: cc.HiBlue + cc.Bold, + CmdShortDescr: cc.HiBlue, + Example: cc.HiBlue + cc.Italic, + ExecName: cc.HiBlue + cc.Bold, + Flags: cc.HiBlue + cc.Bold, + FlagsDescr: cc.HiBlue, + NoExtraNewlines: true, + NoBottomNewline: true, + }) + + commands.Execute() +} + +const help = "Usage:\r\n" + + " {{.UseLine}}{{if .HasAvailableSubCommands}}{{end}} {{if gt (len .Aliases) 0}}\r\n\r\n" + + "{{.NameAndAliases}}{{end}}{{if .HasAvailableSubCommands}}\r\n\r\n" + + "Available Commands:{{range .Commands}}{{if (or .IsAvailableCommand)}}\r\n " + + "{{rpad .Name .NamePadding }} {{.Short}}{{end}}{{end}}{{end}}{{if .HasAvailableLocalFlags}}\r\n\r\n" + + "Flags:\r\n" + + "{{.LocalFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .HasAvailableInheritedFlags}}\r\n\r\n" + + "Global Flags:\r\n" + + "{{.InheritedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}\r\n\r\n" diff --git a/pkg/dmsg/client.go b/pkg/dmsg/client.go index 0973245a7..9b3fd8c99 100644 --- a/pkg/dmsg/client.go +++ b/pkg/dmsg/client.go @@ -366,6 +366,78 @@ func (ce *Client) DialStream(ctx context.Context, addr Addr) (*Stream, error) { return nil, ErrCannotConnectToDelegated } +// LookupIP dails to dmsg servers for public IP of the client. +func (ce *Client) LookupIP(ctx context.Context, servers []cipher.PubKey) (myIP net.IP, err error) { + + cancellabelCtx, cancel := context.WithCancel(ctx) + defer cancel() + + if servers == nil { + entries, err := ce.discoverServers(cancellabelCtx, true) + if err != nil { + return nil, err + } + for _, entry := range entries { + servers = append(servers, entry.Static) + } + } + + // Range client's delegated servers. + // See if we are already connected to a delegated server. + for _, srvPK := range servers { + if dSes, ok := ce.clientSession(ce.porter, srvPK); ok { + ip, err := dSes.LookupIP(Addr{PK: dSes.RemotePK(), Port: 1}) + if err != nil { + ce.log.WithError(err).WithField("server_pk", srvPK).Warn("Failed to dial server for IP.") + continue + } + + // If the client is test client then ignore Public IP check + if ce.conf.ClientType == "test" { + return ip, nil + } + + // Check if the IP is public + if !netutil.IsPublicIP(ip) { + return nil, errors.New("received non-public IP address from dmsg server") + } + return ip, nil + } + } + + // Range client's delegated servers. + // Attempt to connect to a delegated server. + // And Close it after getting the IP. + for _, srvPK := range servers { + dSes, err := ce.EnsureAndObtainSession(ctx, srvPK) + if err != nil { + continue + } + ip, err := dSes.LookupIP(Addr{PK: dSes.RemotePK(), Port: 1}) + if err != nil { + ce.log.WithError(err).WithField("server_pk", srvPK).Warn("Failed to dial server for IP.") + continue + } + err = dSes.Close() + if err != nil { + ce.log.WithError(err).WithField("server_pk", srvPK).Warn("Failed to close session") + } + + // If the client is test client then ignore Public IP check + if ce.conf.ClientType == "test" { + return ip, nil + } + + // Check if the IP is public + if !netutil.IsPublicIP(ip) { + return nil, errors.New("received non-public IP address from dmsg server") + } + return ip, nil + } + + return nil, ErrCannotConnectToDelegated +} + // Session obtains an established session. func (ce *Client) Session(pk cipher.PubKey) (ClientSession, bool) { return ce.clientSession(ce.porter, pk) @@ -403,7 +475,6 @@ func (ce *Client) EnsureAndObtainSession(ctx context.Context, srvPK cipher.PubKe if err != nil { return ClientSession{}, err } - return ce.dialSession(ctx, srvEntry) } diff --git a/pkg/dmsg/client_session.go b/pkg/dmsg/client_session.go index 7c766975a..ba3cfff3a 100644 --- a/pkg/dmsg/client_session.go +++ b/pkg/dmsg/client_session.go @@ -69,6 +69,50 @@ func (cs *ClientSession) DialStream(dst Addr) (dStr *Stream, err error) { return dStr, err } +// LookupIP attempts to dial a stream to the server for the IP address of the client. +func (cs *ClientSession) LookupIP(dst Addr) (myIP net.IP, err error) { + log := cs.log. + WithField("func", "ClientSession.LookupIP"). + WithField("dst_addr", cs.rPK) + + dStr, err := newInitiatingStream(cs) + if err != nil { + return nil, err + } + + // Close stream on failure. + defer func() { + if err != nil { + log.WithError(err). + WithField("close_error", dStr.Close()). + Debug("Stream closed on failure.") + } + }() + + // Prepare deadline. + if err = dStr.SetDeadline(time.Now().Add(HandshakeTimeout)); err != nil { + return nil, err + } + + // Do stream handshake. + req, err := dStr.writeIPRequest(dst) + if err != nil { + return nil, err + } + + myIP, err = dStr.readIPResponse(req) + if err != nil { + return nil, err + } + + err = dStr.Close() + if err != nil { + return nil, err + } + + return myIP, err +} + // serve accepts incoming streams from remote clients. func (cs *ClientSession) serve() error { defer func() { diff --git a/pkg/dmsg/server_session.go b/pkg/dmsg/server_session.go index 57ce2b304..4b8376eec 100644 --- a/pkg/dmsg/server_session.go +++ b/pkg/dmsg/server_session.go @@ -2,6 +2,7 @@ package dmsg import ( + "fmt" "io" "net" @@ -98,6 +99,29 @@ func (ss *ServerSession) serveStream(log logrus.FieldLogger, yStr *yamux.Stream) WithField("dst_addr", req.DstAddr) log.Debug("Read stream request from initiating side.") + if req.IPinfo && req.DstAddr.PK == ss.entity.LocalPK() { + log.Debug("Received IP stream request.") + + ip, err := addrToIP(yStr.RemoteAddr()) + if err != nil { + ss.m.RecordStream(servermetrics.DeltaFailed) // record failed stream + return err + } + + resp := StreamResponse{ + ReqHash: req.raw.Hash(), + Accepted: true, + IP: ip, + } + obj := MakeSignedStreamResponse(&resp, ss.entity.LocalSK()) + + if err := ss.writeObject(yStr, obj); err != nil { + ss.m.RecordStream(servermetrics.DeltaFailed) // record failed stream + return err + } + log.Debug("Wrote IP stream response.") + return nil + } // Obtain next session. ss2, ok := ss.entity.serverSession(req.DstAddr.PK) @@ -129,6 +153,17 @@ func (ss *ServerSession) serveStream(log logrus.FieldLogger, yStr *yamux.Stream) return netutil.CopyReadWriteCloser(yStr, yStr2) } +func addrToIP(addr net.Addr) (net.IP, error) { + switch a := addr.(type) { + case *net.TCPAddr: + return a.IP, nil + case *net.UDPAddr: + return a.IP, nil + default: + return nil, fmt.Errorf("unsupported address type %T", addr) + } +} + func (ss *ServerSession) forwardRequest(req StreamRequest) (yStr *yamux.Stream, respObj SignedObject, err error) { defer func() { if err != nil && yStr != nil { diff --git a/pkg/dmsg/stream.go b/pkg/dmsg/stream.go index 72d1bc4e5..6bc61f1a2 100644 --- a/pkg/dmsg/stream.go +++ b/pkg/dmsg/stream.go @@ -87,6 +87,29 @@ func (s *Stream) writeRequest(rAddr Addr) (req StreamRequest, err error) { return } +func (s *Stream) writeIPRequest(rAddr Addr) (req StreamRequest, err error) { + // Reserve stream in porter. + var lPort uint16 + if lPort, s.close, err = s.ses.porter.ReserveEphemeral(context.Background(), s); err != nil { + return + } + + // Prepare fields. + s.prepareFields(true, Addr{PK: s.ses.LocalPK(), Port: lPort}, rAddr) + + req = StreamRequest{ + Timestamp: time.Now().UnixNano(), + SrcAddr: s.lAddr, + DstAddr: s.rAddr, + IPinfo: true, + } + obj := MakeSignedStreamRequest(&req, s.ses.localSK()) + + // Write request. + err = s.ses.writeObject(s.yStr, obj) + return +} + func (s *Stream) readRequest() (req StreamRequest, err error) { var obj SignedObject if obj, err = s.ses.readObject(s.yStr); err != nil { @@ -158,6 +181,21 @@ func (s *Stream) readResponse(req StreamRequest) error { return s.ns.ProcessHandshakeMessage(resp.NoiseMsg) } +func (s *Stream) readIPResponse(req StreamRequest) (net.IP, error) { + obj, err := s.ses.readObject(s.yStr) + if err != nil { + return nil, err + } + resp, err := obj.ObtainStreamResponse() + if err != nil { + return nil, err + } + if err := resp.Verify(req); err != nil { + return nil, err + } + return resp.IP, nil +} + func (s *Stream) prepareFields(init bool, lAddr, rAddr Addr) { ns, err := noise.New(noise.HandshakeKK, noise.Config{ LocalPK: s.ses.LocalPK(), diff --git a/pkg/dmsg/stream_test.go b/pkg/dmsg/stream_test.go index 2356f4b08..dd2c4b872 100644 --- a/pkg/dmsg/stream_test.go +++ b/pkg/dmsg/stream_test.go @@ -6,12 +6,14 @@ import ( "fmt" "io" "net" + "runtime" "sync" "testing" "time" "github.com/skycoin/skywire-utilities/pkg/cipher" "github.com/skycoin/skywire-utilities/pkg/logging" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/nettest" @@ -209,6 +211,99 @@ func TestStream(t *testing.T) { require.NoError(t, <-chSrv) } +func TestLookupIP(t *testing.T) { + // Prepare mock discovery. + dc := disc.NewMock(0) + const maxSessions = 10 + + // Prepare dmsg server A. + pkSrvA, skSrvB := GenKeyPair(t, "server_A") + srvConf := &ServerConfig{ + MaxSessions: maxSessions, + UpdateInterval: 0, + LimitIP: 200, + } + srv := NewServer(pkSrvA, skSrvB, dc, srvConf, nil) + srv.SetLogger(logging.MustGetLogger("server")) + lisSrv, err := net.Listen("tcp", "") + require.NoError(t, err) + + // Serve dmsg server. + chSrv := make(chan error, 1) + go func() { chSrv <- srv.Serve(lisSrv, "") }() //nolint:errcheck + + // Prepare and serve dmsg client A. + pkA, skA := GenKeyPair(t, "client A") + + clientConfig := &Config{ + MinSessions: DefaultMinSessions, + UpdateInterval: DefaultUpdateInterval * 5, + ClientType: "test", + } + + dmsgC := NewClient(pkA, skA, dc, clientConfig) + go dmsgC.Serve(context.Background()) + t.Cleanup(func() { assert.NoError(t, dmsgC.Close()) }) + <-dmsgC.Ready() + + t.Run("test_connected_server", func(t *testing.T) { + // Ensure all entities are registered in discovery before continuing. + time.Sleep(time.Second * 2) + + // Lookup IP. + srvs := []cipher.PubKey{pkSrvA} + ip, err := dmsgC.LookupIP(context.Background(), srvs) + require.NoError(t, err) + + if runtime.GOOS == "windows" { + require.Equal(t, net.ParseIP("127.0.0.1"), ip) + } else { + require.Equal(t, net.ParseIP("::1"), ip) + } + + // Ensure all entities are deregistered in discovery before continuing. + time.Sleep(time.Second * 2) + }) + + t.Run("test_disconnected_server", func(t *testing.T) { + // Prepare dmsg server B. + pkSrvB, skSrvB := GenKeyPair(t, "server_B") + srvB := NewServer(pkSrvB, skSrvB, dc, srvConf, nil) + srvB.SetLogger(logging.MustGetLogger("server_B")) + lisSrvB, err := net.Listen("tcp", "") + require.NoError(t, err) + + // Serve dmsg server B. + chSrvB := make(chan error, 1) + go func() { chSrvB <- srvB.Serve(lisSrvB, "") }() //nolint:errcheck + + // Ensure all entities are registered in discovery before continuing. + time.Sleep(time.Second * 2) + + srvs := []cipher.PubKey{pkSrvB} + ip, err := dmsgC.LookupIP(context.Background(), srvs) + require.NoError(t, err) + + if runtime.GOOS == "windows" { + require.Equal(t, net.ParseIP("127.0.0.1"), ip) + } else { + require.Equal(t, net.ParseIP("::1"), ip) + } + + // Ensure all entities are deregistered in discovery before continuing. + time.Sleep(time.Second * 2) + + // Ensure the server B entry is deleted and server A entry is still there. + pks := dmsgC.ConnectedServersPK() + require.Equal(t, []string{pkSrvA.String()}, pks) + }) + + // Closing logic. + require.NoError(t, dmsgC.Close()) + require.NoError(t, srv.Close()) + require.NoError(t, <-chSrv) +} + func GenKeyPair(t *testing.T, seed string) (cipher.PubKey, cipher.SecKey) { pk, sk, err := cipher.GenerateDeterministicKeyPair([]byte(seed)) require.NoError(t, err) diff --git a/pkg/dmsg/types.go b/pkg/dmsg/types.go index 1e39ee14f..9fbb2ca3c 100644 --- a/pkg/dmsg/types.go +++ b/pkg/dmsg/types.go @@ -4,6 +4,7 @@ package dmsg import ( "errors" "fmt" + "net" "strings" "time" @@ -167,6 +168,7 @@ type StreamRequest struct { Timestamp int64 SrcAddr Addr DstAddr Addr + IPinfo bool NoiseMsg []byte raw SignedObject `enc:"-"` // back reference. @@ -203,6 +205,7 @@ func (req StreamRequest) Verify(lastTimestamp int64) error { type StreamResponse struct { ReqHash cipher.SHA256 // Hash of associated dial request. Accepted bool // Whether the request is accepted. + IP net.IP // IP address of the node. ErrCode errorCode // Check if not accepted. NoiseMsg []byte