diff --git a/cmd/dmsg/dmsg.go b/cmd/dmsg/dmsg.go index d57651ad..42c6704c 100644 --- a/cmd/dmsg/dmsg.go +++ b/cmd/dmsg/dmsg.go @@ -15,6 +15,7 @@ import ( dmsgsocks "github.com/skycoin/dmsg/cmd/dmsg-socks5/commands" dmsgcurl "github.com/skycoin/dmsg/cmd/dmsgcurl/commands" dmsghttp "github.com/skycoin/dmsg/cmd/dmsghttp/commands" + dmsgip "github.com/skycoin/dmsg/cmd/dmsgip/commands" dmsgptycli "github.com/skycoin/dmsg/cmd/dmsgpty-cli/commands" dmsgptyhost "github.com/skycoin/dmsg/cmd/dmsgpty-host/commands" dmsgptyui "github.com/skycoin/dmsg/cmd/dmsgpty-ui/commands" @@ -35,6 +36,7 @@ func init() { dmsgcurl.RootCmd, dmsgweb.RootCmd, dmsgsocks.RootCmd, + dmsgip.RootCmd, ) dmsgdisc.RootCmd.Use = "disc" dmsgserver.RootCmd.Use = "server" @@ -45,6 +47,7 @@ func init() { dmsgptycli.RootCmd.Use = "cli" dmsgptyhost.RootCmd.Use = "host" dmsgptyui.RootCmd.Use = "ui" + dmsgip.RootCmd.Use = "ip" var helpflag bool RootCmd.SetUsageTemplate(help) diff --git a/cmd/dmsgip/README.md b/cmd/dmsgip/README.md new file mode 100644 index 00000000..97a9b498 --- /dev/null +++ b/cmd/dmsgip/README.md @@ -0,0 +1,23 @@ + + + +``` + + + ┌┬┐┌┬┐┌─┐┌─┐ ┬┌─┐ + │││││└─┐│ ┬ │├─┘ + ─┴┘┴ ┴└─┘└─┘ ┴┴ +DMSG ip utility + +Usage: + dmsgip + +Flags: + -c, --dmsg-disc string dmsg discovery url default: + http://dmsgd.skywire.dev + -l, --loglvl string [ debug | warn | error | fatal | panic | trace | info ] (default "fatal") + -s, --sk cipher.SecKey a random key is generated if unspecified + (default 0000000000000000000000000000000000000000000000000000000000000000) + -v, --version version for dmsgip + +``` diff --git a/cmd/dmsgip/commands/dmsgip.go b/cmd/dmsgip/commands/dmsgip.go new file mode 100644 index 00000000..2b6a14ab --- /dev/null +++ b/cmd/dmsgip/commands/dmsgip.go @@ -0,0 +1,133 @@ +// Package commands cmd/dmsgcurl/commands/dmsgcurl.go +package commands + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + "path/filepath" + "strings" + + "github.com/skycoin/skywire-utilities/pkg/buildinfo" + "github.com/skycoin/skywire-utilities/pkg/cipher" + "github.com/skycoin/skywire-utilities/pkg/cmdutil" + "github.com/skycoin/skywire-utilities/pkg/logging" + "github.com/skycoin/skywire-utilities/pkg/skyenv" + "github.com/spf13/cobra" + + "github.com/skycoin/dmsg/pkg/disc" + "github.com/skycoin/dmsg/pkg/dmsg" +) + +var ( + dmsgDisc string + sk cipher.SecKey + logLvl string + dmsgServers []string +) + +func init() { + RootCmd.Flags().StringVarP(&dmsgDisc, "dmsg-disc", "c", "", "dmsg discovery url default:\n"+skyenv.DmsgDiscAddr) + RootCmd.Flags().StringVarP(&logLvl, "loglvl", "l", "fatal", "[ debug | warn | error | fatal | panic | trace | info ]\033[0m") + if os.Getenv("DMSGIP_SK") != "" { + sk.Set(os.Getenv("DMSGIP_SK")) //nolint + } + RootCmd.Flags().StringSliceVarP(&dmsgServers, "srv", "d", []string{}, "dmsg server public keys\n\r") + RootCmd.Flags().VarP(&sk, "sk", "s", "a random key is generated if unspecified\n\r") +} + +// RootCmd containsa the root dmsgcurl command +var RootCmd = &cobra.Command{ + Use: func() string { + return strings.Split(filepath.Base(strings.ReplaceAll(strings.ReplaceAll(fmt.Sprintf("%v", os.Args), "[", ""), "]", "")), " ")[0] + }(), + Short: "DMSG ip utility", + Long: ` + ┌┬┐┌┬┐┌─┐┌─┐ ┬┌─┐ + │││││└─┐│ ┬ │├─┘ + ─┴┘┴ ┴└─┘└─┘ ┴┴ +DMSG ip utility`, + SilenceErrors: true, + SilenceUsage: true, + DisableSuggestions: true, + DisableFlagsInUseLine: true, + Version: buildinfo.Version(), + PreRun: func(cmd *cobra.Command, args []string) { + if dmsgDisc == "" { + dmsgDisc = skyenv.DmsgDiscAddr + } + }, + RunE: func(cmd *cobra.Command, args []string) error { + log := logging.MustGetLogger("dmsgip") + + if logLvl != "" { + if lvl, err := logging.LevelFromString(logLvl); err == nil { + logging.SetLevel(lvl) + } + } + + var srvs []cipher.PubKey + for _, srv := range dmsgServers { + var pk cipher.PubKey + if err := pk.Set(srv); err != nil { + return fmt.Errorf("failed to parse server public key: %w", err) + } + srvs = append(srvs, pk) + } + + ctx, cancel := cmdutil.SignalContext(context.Background(), log) + defer cancel() + + pk, err := sk.PubKey() + if err != nil { + pk, sk = cipher.GenerateKeyPair() + } + + dmsgC, closeDmsg, err := startDmsg(ctx, log, pk, sk) + if err != nil { + log.WithError(err).Error("failed to start dmsg") + } + defer closeDmsg() + + ip, err := dmsgC.LookupIP(ctx, srvs) + if err != nil { + log.WithError(err).Error("failed to lookup IP") + } + + fmt.Printf("%v\n", ip) + fmt.Print("\n") + return nil + }, +} + +func startDmsg(ctx context.Context, log *logging.Logger, pk cipher.PubKey, sk cipher.SecKey) (dmsgC *dmsg.Client, stop func(), err error) { + dmsgC = dmsg.NewClient(pk, sk, disc.NewHTTP(dmsgDisc, &http.Client{}, log), &dmsg.Config{MinSessions: dmsg.DefaultMinSessions}) + go dmsgC.Serve(context.Background()) + + stop = func() { + err := dmsgC.Close() + log.WithError(err).Debug("Disconnected from dmsg network.") + fmt.Printf("\n") + } + log.WithField("public_key", pk.String()).WithField("dmsg_disc", dmsgDisc). + Debug("Connecting to dmsg network...") + + select { + case <-ctx.Done(): + stop() + return nil, nil, ctx.Err() + + case <-dmsgC.Ready(): + log.Debug("Dmsg network ready.") + return dmsgC, stop, nil + } +} + +// Execute executes root CLI command. +func Execute() { + if err := RootCmd.Execute(); err != nil { + log.Fatal("Failed to execute command: ", err) + } +} diff --git a/cmd/dmsgip/dmsgip.go b/cmd/dmsgip/dmsgip.go new file mode 100644 index 00000000..12b3ec8e --- /dev/null +++ b/cmd/dmsgip/dmsgip.go @@ -0,0 +1,44 @@ +// package main cmd/dmsgcurl/dmsgcurl.go +package main + +import ( + cc "github.com/ivanpirog/coloredcobra" + "github.com/spf13/cobra" + + "github.com/skycoin/dmsg/cmd/dmsgip/commands" +) + +func init() { + var helpflag bool + commands.RootCmd.SetUsageTemplate(help) + commands.RootCmd.PersistentFlags().BoolVarP(&helpflag, "help", "h", false, "help for dmsgpty-cli") + commands.RootCmd.SetHelpCommand(&cobra.Command{Hidden: true}) + commands.RootCmd.PersistentFlags().MarkHidden("help") //nolint +} + +func main() { + cc.Init(&cc.Config{ + RootCmd: commands.RootCmd, + Headings: cc.HiBlue + cc.Bold, + Commands: cc.HiBlue + cc.Bold, + CmdShortDescr: cc.HiBlue, + Example: cc.HiBlue + cc.Italic, + ExecName: cc.HiBlue + cc.Bold, + Flags: cc.HiBlue + cc.Bold, + FlagsDescr: cc.HiBlue, + NoExtraNewlines: true, + NoBottomNewline: true, + }) + + commands.Execute() +} + +const help = "Usage:\r\n" + + " {{.UseLine}}{{if .HasAvailableSubCommands}}{{end}} {{if gt (len .Aliases) 0}}\r\n\r\n" + + "{{.NameAndAliases}}{{end}}{{if .HasAvailableSubCommands}}\r\n\r\n" + + "Available Commands:{{range .Commands}}{{if (or .IsAvailableCommand)}}\r\n " + + "{{rpad .Name .NamePadding }} {{.Short}}{{end}}{{end}}{{end}}{{if .HasAvailableLocalFlags}}\r\n\r\n" + + "Flags:\r\n" + + "{{.LocalFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .HasAvailableInheritedFlags}}\r\n\r\n" + + "Global Flags:\r\n" + + "{{.InheritedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}\r\n\r\n" diff --git a/pkg/dmsg/client.go b/pkg/dmsg/client.go index 0973245a..9b3fd8c9 100644 --- a/pkg/dmsg/client.go +++ b/pkg/dmsg/client.go @@ -366,6 +366,78 @@ func (ce *Client) DialStream(ctx context.Context, addr Addr) (*Stream, error) { return nil, ErrCannotConnectToDelegated } +// LookupIP dails to dmsg servers for public IP of the client. +func (ce *Client) LookupIP(ctx context.Context, servers []cipher.PubKey) (myIP net.IP, err error) { + + cancellabelCtx, cancel := context.WithCancel(ctx) + defer cancel() + + if servers == nil { + entries, err := ce.discoverServers(cancellabelCtx, true) + if err != nil { + return nil, err + } + for _, entry := range entries { + servers = append(servers, entry.Static) + } + } + + // Range client's delegated servers. + // See if we are already connected to a delegated server. + for _, srvPK := range servers { + if dSes, ok := ce.clientSession(ce.porter, srvPK); ok { + ip, err := dSes.LookupIP(Addr{PK: dSes.RemotePK(), Port: 1}) + if err != nil { + ce.log.WithError(err).WithField("server_pk", srvPK).Warn("Failed to dial server for IP.") + continue + } + + // If the client is test client then ignore Public IP check + if ce.conf.ClientType == "test" { + return ip, nil + } + + // Check if the IP is public + if !netutil.IsPublicIP(ip) { + return nil, errors.New("received non-public IP address from dmsg server") + } + return ip, nil + } + } + + // Range client's delegated servers. + // Attempt to connect to a delegated server. + // And Close it after getting the IP. + for _, srvPK := range servers { + dSes, err := ce.EnsureAndObtainSession(ctx, srvPK) + if err != nil { + continue + } + ip, err := dSes.LookupIP(Addr{PK: dSes.RemotePK(), Port: 1}) + if err != nil { + ce.log.WithError(err).WithField("server_pk", srvPK).Warn("Failed to dial server for IP.") + continue + } + err = dSes.Close() + if err != nil { + ce.log.WithError(err).WithField("server_pk", srvPK).Warn("Failed to close session") + } + + // If the client is test client then ignore Public IP check + if ce.conf.ClientType == "test" { + return ip, nil + } + + // Check if the IP is public + if !netutil.IsPublicIP(ip) { + return nil, errors.New("received non-public IP address from dmsg server") + } + return ip, nil + } + + return nil, ErrCannotConnectToDelegated +} + // Session obtains an established session. func (ce *Client) Session(pk cipher.PubKey) (ClientSession, bool) { return ce.clientSession(ce.porter, pk) @@ -403,7 +475,6 @@ func (ce *Client) EnsureAndObtainSession(ctx context.Context, srvPK cipher.PubKe if err != nil { return ClientSession{}, err } - return ce.dialSession(ctx, srvEntry) } diff --git a/pkg/dmsg/client_session.go b/pkg/dmsg/client_session.go index 7c766975..ba3cfff3 100644 --- a/pkg/dmsg/client_session.go +++ b/pkg/dmsg/client_session.go @@ -69,6 +69,50 @@ func (cs *ClientSession) DialStream(dst Addr) (dStr *Stream, err error) { return dStr, err } +// LookupIP attempts to dial a stream to the server for the IP address of the client. +func (cs *ClientSession) LookupIP(dst Addr) (myIP net.IP, err error) { + log := cs.log. + WithField("func", "ClientSession.LookupIP"). + WithField("dst_addr", cs.rPK) + + dStr, err := newInitiatingStream(cs) + if err != nil { + return nil, err + } + + // Close stream on failure. + defer func() { + if err != nil { + log.WithError(err). + WithField("close_error", dStr.Close()). + Debug("Stream closed on failure.") + } + }() + + // Prepare deadline. + if err = dStr.SetDeadline(time.Now().Add(HandshakeTimeout)); err != nil { + return nil, err + } + + // Do stream handshake. + req, err := dStr.writeIPRequest(dst) + if err != nil { + return nil, err + } + + myIP, err = dStr.readIPResponse(req) + if err != nil { + return nil, err + } + + err = dStr.Close() + if err != nil { + return nil, err + } + + return myIP, err +} + // serve accepts incoming streams from remote clients. func (cs *ClientSession) serve() error { defer func() { diff --git a/pkg/dmsg/server_session.go b/pkg/dmsg/server_session.go index 57ce2b30..4b8376ee 100644 --- a/pkg/dmsg/server_session.go +++ b/pkg/dmsg/server_session.go @@ -2,6 +2,7 @@ package dmsg import ( + "fmt" "io" "net" @@ -98,6 +99,29 @@ func (ss *ServerSession) serveStream(log logrus.FieldLogger, yStr *yamux.Stream) WithField("dst_addr", req.DstAddr) log.Debug("Read stream request from initiating side.") + if req.IPinfo && req.DstAddr.PK == ss.entity.LocalPK() { + log.Debug("Received IP stream request.") + + ip, err := addrToIP(yStr.RemoteAddr()) + if err != nil { + ss.m.RecordStream(servermetrics.DeltaFailed) // record failed stream + return err + } + + resp := StreamResponse{ + ReqHash: req.raw.Hash(), + Accepted: true, + IP: ip, + } + obj := MakeSignedStreamResponse(&resp, ss.entity.LocalSK()) + + if err := ss.writeObject(yStr, obj); err != nil { + ss.m.RecordStream(servermetrics.DeltaFailed) // record failed stream + return err + } + log.Debug("Wrote IP stream response.") + return nil + } // Obtain next session. ss2, ok := ss.entity.serverSession(req.DstAddr.PK) @@ -129,6 +153,17 @@ func (ss *ServerSession) serveStream(log logrus.FieldLogger, yStr *yamux.Stream) return netutil.CopyReadWriteCloser(yStr, yStr2) } +func addrToIP(addr net.Addr) (net.IP, error) { + switch a := addr.(type) { + case *net.TCPAddr: + return a.IP, nil + case *net.UDPAddr: + return a.IP, nil + default: + return nil, fmt.Errorf("unsupported address type %T", addr) + } +} + func (ss *ServerSession) forwardRequest(req StreamRequest) (yStr *yamux.Stream, respObj SignedObject, err error) { defer func() { if err != nil && yStr != nil { diff --git a/pkg/dmsg/stream.go b/pkg/dmsg/stream.go index 72d1bc4e..6bc61f1a 100644 --- a/pkg/dmsg/stream.go +++ b/pkg/dmsg/stream.go @@ -87,6 +87,29 @@ func (s *Stream) writeRequest(rAddr Addr) (req StreamRequest, err error) { return } +func (s *Stream) writeIPRequest(rAddr Addr) (req StreamRequest, err error) { + // Reserve stream in porter. + var lPort uint16 + if lPort, s.close, err = s.ses.porter.ReserveEphemeral(context.Background(), s); err != nil { + return + } + + // Prepare fields. + s.prepareFields(true, Addr{PK: s.ses.LocalPK(), Port: lPort}, rAddr) + + req = StreamRequest{ + Timestamp: time.Now().UnixNano(), + SrcAddr: s.lAddr, + DstAddr: s.rAddr, + IPinfo: true, + } + obj := MakeSignedStreamRequest(&req, s.ses.localSK()) + + // Write request. + err = s.ses.writeObject(s.yStr, obj) + return +} + func (s *Stream) readRequest() (req StreamRequest, err error) { var obj SignedObject if obj, err = s.ses.readObject(s.yStr); err != nil { @@ -158,6 +181,21 @@ func (s *Stream) readResponse(req StreamRequest) error { return s.ns.ProcessHandshakeMessage(resp.NoiseMsg) } +func (s *Stream) readIPResponse(req StreamRequest) (net.IP, error) { + obj, err := s.ses.readObject(s.yStr) + if err != nil { + return nil, err + } + resp, err := obj.ObtainStreamResponse() + if err != nil { + return nil, err + } + if err := resp.Verify(req); err != nil { + return nil, err + } + return resp.IP, nil +} + func (s *Stream) prepareFields(init bool, lAddr, rAddr Addr) { ns, err := noise.New(noise.HandshakeKK, noise.Config{ LocalPK: s.ses.LocalPK(), diff --git a/pkg/dmsg/stream_test.go b/pkg/dmsg/stream_test.go index 2356f4b0..dd2c4b87 100644 --- a/pkg/dmsg/stream_test.go +++ b/pkg/dmsg/stream_test.go @@ -6,12 +6,14 @@ import ( "fmt" "io" "net" + "runtime" "sync" "testing" "time" "github.com/skycoin/skywire-utilities/pkg/cipher" "github.com/skycoin/skywire-utilities/pkg/logging" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/nettest" @@ -209,6 +211,99 @@ func TestStream(t *testing.T) { require.NoError(t, <-chSrv) } +func TestLookupIP(t *testing.T) { + // Prepare mock discovery. + dc := disc.NewMock(0) + const maxSessions = 10 + + // Prepare dmsg server A. + pkSrvA, skSrvB := GenKeyPair(t, "server_A") + srvConf := &ServerConfig{ + MaxSessions: maxSessions, + UpdateInterval: 0, + LimitIP: 200, + } + srv := NewServer(pkSrvA, skSrvB, dc, srvConf, nil) + srv.SetLogger(logging.MustGetLogger("server")) + lisSrv, err := net.Listen("tcp", "") + require.NoError(t, err) + + // Serve dmsg server. + chSrv := make(chan error, 1) + go func() { chSrv <- srv.Serve(lisSrv, "") }() //nolint:errcheck + + // Prepare and serve dmsg client A. + pkA, skA := GenKeyPair(t, "client A") + + clientConfig := &Config{ + MinSessions: DefaultMinSessions, + UpdateInterval: DefaultUpdateInterval * 5, + ClientType: "test", + } + + dmsgC := NewClient(pkA, skA, dc, clientConfig) + go dmsgC.Serve(context.Background()) + t.Cleanup(func() { assert.NoError(t, dmsgC.Close()) }) + <-dmsgC.Ready() + + t.Run("test_connected_server", func(t *testing.T) { + // Ensure all entities are registered in discovery before continuing. + time.Sleep(time.Second * 2) + + // Lookup IP. + srvs := []cipher.PubKey{pkSrvA} + ip, err := dmsgC.LookupIP(context.Background(), srvs) + require.NoError(t, err) + + if runtime.GOOS == "windows" { + require.Equal(t, net.ParseIP("127.0.0.1"), ip) + } else { + require.Equal(t, net.ParseIP("::1"), ip) + } + + // Ensure all entities are deregistered in discovery before continuing. + time.Sleep(time.Second * 2) + }) + + t.Run("test_disconnected_server", func(t *testing.T) { + // Prepare dmsg server B. + pkSrvB, skSrvB := GenKeyPair(t, "server_B") + srvB := NewServer(pkSrvB, skSrvB, dc, srvConf, nil) + srvB.SetLogger(logging.MustGetLogger("server_B")) + lisSrvB, err := net.Listen("tcp", "") + require.NoError(t, err) + + // Serve dmsg server B. + chSrvB := make(chan error, 1) + go func() { chSrvB <- srvB.Serve(lisSrvB, "") }() //nolint:errcheck + + // Ensure all entities are registered in discovery before continuing. + time.Sleep(time.Second * 2) + + srvs := []cipher.PubKey{pkSrvB} + ip, err := dmsgC.LookupIP(context.Background(), srvs) + require.NoError(t, err) + + if runtime.GOOS == "windows" { + require.Equal(t, net.ParseIP("127.0.0.1"), ip) + } else { + require.Equal(t, net.ParseIP("::1"), ip) + } + + // Ensure all entities are deregistered in discovery before continuing. + time.Sleep(time.Second * 2) + + // Ensure the server B entry is deleted and server A entry is still there. + pks := dmsgC.ConnectedServersPK() + require.Equal(t, []string{pkSrvA.String()}, pks) + }) + + // Closing logic. + require.NoError(t, dmsgC.Close()) + require.NoError(t, srv.Close()) + require.NoError(t, <-chSrv) +} + func GenKeyPair(t *testing.T, seed string) (cipher.PubKey, cipher.SecKey) { pk, sk, err := cipher.GenerateDeterministicKeyPair([]byte(seed)) require.NoError(t, err) diff --git a/pkg/dmsg/types.go b/pkg/dmsg/types.go index 1e39ee14..9fbb2ca3 100644 --- a/pkg/dmsg/types.go +++ b/pkg/dmsg/types.go @@ -4,6 +4,7 @@ package dmsg import ( "errors" "fmt" + "net" "strings" "time" @@ -167,6 +168,7 @@ type StreamRequest struct { Timestamp int64 SrcAddr Addr DstAddr Addr + IPinfo bool NoiseMsg []byte raw SignedObject `enc:"-"` // back reference. @@ -203,6 +205,7 @@ func (req StreamRequest) Verify(lastTimestamp int64) error { type StreamResponse struct { ReqHash cipher.SHA256 // Hash of associated dial request. Accepted bool // Whether the request is accepted. + IP net.IP // IP address of the node. ErrCode errorCode // Check if not accepted. NoiseMsg []byte