Skip to content

Commit

Permalink
fix hanging get traversal (#176)
Browse files Browse the repository at this point in the history
* Tests

* cleanup resources properly

* separate context for republish

* remove race condition
  • Loading branch information
decentralgabe authored Apr 11, 2024
1 parent f5a1022 commit 9eb197a
Show file tree
Hide file tree
Showing 15 changed files with 161 additions and 75 deletions.
2 changes: 1 addition & 1 deletion impl/cmd/cli/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ var identityGetCmd = &cobra.Command{
}

// get the identity from the dht
gotResp, err := d.Get(context.Background(), id)
gotResp, err := d.GetFull(context.Background(), id)
if err != nil {
logrus.WithError(err).Error("failed to get identity from dht")
return err
Expand Down
1 change: 0 additions & 1 deletion impl/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ require (
github.com/stretchr/testify v1.9.0
github.com/swaggo/files v1.0.1
github.com/swaggo/gin-swagger v1.6.0
github.com/toorop/gin-logrus v0.0.0-20210225092905-2c785434f26f
github.com/tv42/zbase32 v0.0.0-20220222190657-f76a9fc892fa
go.etcd.io/bbolt v1.3.9
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.50.0
Expand Down
2 changes: 0 additions & 2 deletions impl/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,6 @@ github.com/swaggo/swag v1.8.12/go.mod h1:lNfm6Gg+oAq3zRJQNEMBE66LIJKM44mxFqhEEgy
github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
github.com/tinylib/msgp v1.1.0/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
github.com/tinylib/msgp v1.1.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
github.com/toorop/gin-logrus v0.0.0-20210225092905-2c785434f26f h1:oqdnd6OGlOUu1InG37hWcCB3a+Jy3fwjylyVboaNMwY=
github.com/toorop/gin-logrus v0.0.0-20210225092905-2c785434f26f/go.mod h1:X3Dd1SB8Gt1V968NTzpKFjMM6O8ccta2NPC6MprOxZQ=
github.com/tursodatabase/libsql-client-go v0.0.0-20240220085343-4ae0eb9d0898 h1:1MvEhzI5pvP27e9Dzz861mxk9WzXZLSJwzOU67cKTbU=
github.com/tursodatabase/libsql-client-go v0.0.0-20240220085343-4ae0eb9d0898/go.mod h1:9bKuHS7eZh/0mJndbUOrCx8Ej3PlsRDszj4L7oVYMPQ=
github.com/tv42/zbase32 v0.0.0-20220222190657-f76a9fc892fa h1:2EwhXkNkeMjX9iFYGWLPQLPhw9O58BhnYgtYKeqybcY=
Expand Down
92 changes: 79 additions & 13 deletions impl/internal/dht/get.go → impl/internal/dht/getput.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"crypto/sha1"
"errors"
"math"
"sync"

"github.com/anacrolix/log"
k_nearest_nodes "github.com/anacrolix/dht/v2/k-nearest-nodes"
"github.com/anacrolix/torrent/bencode"
"github.com/sirupsen/logrus"

"github.com/anacrolix/dht/v2"
"github.com/anacrolix/dht/v2/bep44"
Expand All @@ -16,7 +18,7 @@ import (
)

// Copied from https://github.com/anacrolix/dht/blob/master/exts/getput/getput.go and modified
// to return signature data
// to return signature data and allow for context cancellations

type FullGetResult struct {
Seq int64
Expand All @@ -26,7 +28,7 @@ type FullGetResult struct {
}

func startGetTraversal(
target bep44.Target, s *dht.Server, seq *int64, salt []byte,
ctx context.Context, target bep44.Target, s *dht.Server, seq *int64, salt []byte,
) (
vChan chan FullGetResult, op *traversal.Operation, err error,
) {
Expand All @@ -35,34 +37,35 @@ func startGetTraversal(
Alpha: 15,
Target: target,
DoQuery: func(ctx context.Context, addr krpc.NodeAddr) traversal.QueryResult {
logger := log.ContextLogger(ctx)
res := s.Get(ctx, dht.NewAddr(addr.UDP()), target, seq, dht.QueryRateLimiting{})
err := res.ToError()
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, dht.TransactionTimeout) {
logger.Levelf(log.Debug, "error querying %v: %v", addr, err)
logrus.WithContext(ctx).WithError(err).Debugf("error querying %v", addr)
}
if r := res.Reply.R; r != nil {
rv := r.V
bv := rv
if sha1.Sum(bv) == target {
select {
case vChan <- FullGetResult{
V: rv,
Sig: r.Sig,
V: rv,
Sig: r.Sig,
Mutable: false,
}:
case <-ctx.Done():
}
} else if sha1.Sum(append(r.K[:], salt...)) == target && bep44.Verify(r.K[:], salt, *r.Seq, bv, r.Sig[:]) {
select {
case vChan <- FullGetResult{
Seq: *r.Seq,
V: rv,
Sig: r.Sig,
Seq: *r.Seq,
V: rv,
Sig: r.Sig,
Mutable: true,
}:
case <-ctx.Done():
}
} else if rv != nil {
logger.Levelf(log.Debug, "get response item hash didn't match target: %q", rv)
logrus.WithContext(ctx).Debugf("get response item hash didn't match target: %q", rv)
}
}
tqr := res.TraversalQueryResult(addr)
Expand All @@ -76,6 +79,16 @@ func startGetTraversal(
},
NodeFilter: s.TraversalNodeFilter,
})

// list for context cancellation or stalled traversal
go func() {
select {
case <-ctx.Done():
op.Stop()
case <-op.Stalled():
}
}()

nodes, err := s.TraversalStartingNodes()
op.AddNodes(nodes)
return
Expand All @@ -86,7 +99,7 @@ func Get(
) (
ret FullGetResult, stats *traversal.Stats, err error,
) {
vChan, op, err := startGetTraversal(target, s, seq, salt)
vChan, op, err := startGetTraversal(ctx, target, s, seq, salt)
if err != nil {
return
}
Expand All @@ -99,7 +112,7 @@ receiveResults:
err = errors.New("value not found")
}
case v := <-vChan:
log.ContextLogger(ctx).Levelf(log.Debug, "received %#v", v)
logrus.WithContext(ctx).Debugf("received %#v", v)
gotValue = true
if !v.Mutable {
ret = v
Expand All @@ -116,3 +129,56 @@ receiveResults:
stats = op.Stats()
return
}

type SeqToPut func(seq int64) bep44.Put

func Put(
ctx context.Context, target krpc.ID, s *dht.Server, salt []byte, seqToPut SeqToPut,
) (
stats *traversal.Stats, err error,
) {
vChan, op, err := startGetTraversal(ctx, target, s,
// When we do a get traversal for a put, we don't care what seq the peers have?
nil,
// This is duplicated with the put, but we need it to filter responses for autoSeq.
salt)
if err != nil {
return
}
var autoSeq int64
notDone:
select {
case v := <-vChan:
if v.Mutable && v.Seq > autoSeq {
autoSeq = v.Seq
}
// There are more optimizations that can be done here. We can set CAS automatically, and we
// can skip updating the sequence number if the existing content already matches (and
// presumably republish the existing seq).
goto notDone
case <-op.Stalled():
case <-ctx.Done():
err = ctx.Err()
}
op.Stop()
var wg sync.WaitGroup
put := seqToPut(autoSeq)
op.Closest().Range(func(elem k_nearest_nodes.Elem) {
wg.Add(1)
go func() {
defer wg.Done()
// This is enforced by startGetTraversal.
token := elem.Data.(string)
res := s.Put(ctx, dht.NewAddr(elem.Addr.UDP()), put, token, dht.QueryRateLimiting{})
err = res.ToError()
if err != nil {
logrus.WithContext(ctx).WithError(err).Warnf("error putting to %v [token=%q]", elem.Addr, token)
} else {
logrus.WithContext(ctx).WithError(err).Debugf("put to %v [token=%q]", elem.Addr, token)
}
}()
})
wg.Wait()
stats = op.Stats()
return
}
5 changes: 4 additions & 1 deletion impl/internal/did/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"net/http"
"net/url"
"time"

"github.com/TBD54566975/ssi-sdk/did"
"github.com/anacrolix/dht/v2/bep44"
Expand All @@ -24,9 +25,11 @@ func NewGatewayClient(gatewayURL string) (*GatewayClient, error) {
if _, err := url.Parse(gatewayURL); err != nil {
return nil, err
}
client := http.DefaultClient
client.Timeout = time.Second * 10
return &GatewayClient{
gatewayURL: gatewayURL,
client: http.DefaultClient,
client: client,
}, nil
}

Expand Down
23 changes: 2 additions & 21 deletions impl/pkg/dht/dht.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
errutil "github.com/TBD54566975/ssi-sdk/util"
"github.com/anacrolix/dht/v2"
"github.com/anacrolix/dht/v2/bep44"
"github.com/anacrolix/dht/v2/exts/getput"
"github.com/anacrolix/log"
"github.com/anacrolix/torrent/types/infohash"
"github.com/pkg/errors"
Expand All @@ -33,7 +32,6 @@ func NewDHT(bootstrapPeers []string) (*DHT, error) {
logrus.WithField("bootstrap_peers", len(bootstrapPeers)).Info("initializing DHT")

c := dht.NewDefaultServerConfig()
// change default expire to 24 hours
c.Exp = time.Hour * 24
c.NoSecurity = false
conn, err := net.ListenPacket("udp", "0.0.0.0:6881")
Expand Down Expand Up @@ -94,7 +92,7 @@ func (d *DHT) Put(ctx context.Context, request bep44.Put) (string, error) {
}

key := util.Z32Encode(request.K[:])
t, err := getput.Put(ctx, request.Target(), d.Server, nil, func(int64) bep44.Put {
t, err := dhtint.Put(ctx, request.Target(), d.Server, nil, func(int64) bep44.Put {
return request
})
if err != nil {
Expand All @@ -108,23 +106,6 @@ func (d *DHT) Put(ctx context.Context, request bep44.Put) (string, error) {
return util.Z32Encode(request.K[:]), nil
}

// Get returns the BEP-44 result for the given key from the DHT.
// The key is a z32-encoded string, such as "yj47pezutnpw9pyudeeai8cx8z8d6wg35genrkoqf9k3rmfzy58o".
func (d *DHT) Get(ctx context.Context, key string) (*getput.GetResult, error) {
ctx, span := telemetry.GetTracer().Start(ctx, "DHT.Get")
defer span.End()

z32Decoded, err := util.Z32Decode(key)
if err != nil {
return nil, errors.Wrapf(err, "failed to decode key [%s]", key)
}
res, t, err := getput.Get(ctx, infohash.HashBytes(z32Decoded), d.Server, nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get key[%s] from dht; tried %d nodes, got %d responses", key, t.NumAddrsTried, t.NumResponses)
}
return &res, nil
}

// GetFull returns the full BEP-44 result for the given key from the DHT, using our modified
// implementation of getput.Get. It should ONLY be used when it's needed to get the signature
// data for a record.
Expand All @@ -134,7 +115,7 @@ func (d *DHT) GetFull(ctx context.Context, key string) (*dhtint.FullGetResult, e

z32Decoded, err := util.Z32Decode(key)
if err != nil {
return nil, errutil.LoggingCtxErrorMsgf(ctx, err, "failed to decode key [%s]", key)
return nil, errors.Wrapf(err, "failed to decode key [%s]", key)
}
res, t, err := dhtint.Get(ctx, infohash.HashBytes(z32Decoded), d.Server, nil, nil)
if err != nil {
Expand Down
12 changes: 3 additions & 9 deletions impl/pkg/dht/dht_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ import (

func TestGetPutDHT(t *testing.T) {
ctx := context.Background()

d := dhtclient.NewTestDHT(t)
defer d.Close()

pubKey, privKey, err := util.GenerateKeypair()
require.NoError(t, err)
Expand All @@ -34,18 +34,12 @@ func TestGetPutDHT(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, id)

got, err := d.Get(ctx, id)
got, err := d.GetFull(ctx, id)
require.NoError(t, err)
require.NotEmpty(t, got)
require.Equal(t, bencode.Bytes(put.V.([]byte)), got.V[2:])
require.Equal(t, put.Seq, got.Seq)

full, err := d.GetFull(ctx, id)
require.NoError(t, err)
require.NotEmpty(t, full)
require.Equal(t, bencode.Bytes(put.V.([]byte)), full.V[2:])
require.Equal(t, put.Seq, full.Seq)
require.False(t, full.Mutable)
require.True(t, got.Mutable)

var payload string
err = bencode.Unmarshal(got.V, &payload)
Expand Down
9 changes: 5 additions & 4 deletions impl/pkg/dht/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
)

func init() {
logrus.SetFormatter(&logrus.JSONFormatter{})
log.Default.Handlers = []log.Handler{logrusHandler{}}
}

Expand All @@ -19,12 +20,12 @@ func (logrusHandler) Handle(record log.Record) {

switch record.Level {
case log.Debug:
entry.Debug(msg)
entry.Debugf("%s\n", msg)
case log.Info:
entry.Info(msg)
entry.Infof("%s\n", msg)
case log.Warning, log.Error:
entry.Warn(msg)
entry.Warnf("%s\n", msg)
default:
entry.Debug(msg)
entry.Debugf("%s\n", msg)
}
}
5 changes: 3 additions & 2 deletions impl/pkg/dht/pkarr.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import (

"github.com/TBD54566975/ssi-sdk/util"
"github.com/anacrolix/dht/v2/bep44"
"github.com/anacrolix/dht/v2/exts/getput"
"github.com/anacrolix/torrent/bencode"
"github.com/miekg/dns"

"github.com/TBD54566975/did-dht-method/internal/dht"
)

// CreatePkarrPublishRequest creates a put request for the given records. Requires a public/private keypair and the records to put.
Expand Down Expand Up @@ -50,7 +51,7 @@ func CreatePkarrPublishRequest(privateKey ed25519.PrivateKey, msg dns.Msg) (*bep

// ParsePkarrGetResponse parses the response from a get request.
// The response is expected to be a slice of DNS resource records.
func ParsePkarrGetResponse(response getput.GetResult) (*dns.Msg, error) {
func ParsePkarrGetResponse(response dht.FullGetResult) (*dns.Msg, error) {
var payload string
if err := bencode.Unmarshal(response.V, &payload); err != nil {
return nil, util.LoggingErrorMsg(err, "failed to unmarshal payload value")
Expand Down
12 changes: 7 additions & 5 deletions impl/pkg/dht/pkarr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ import (
"github.com/TBD54566975/did-dht-method/internal/util"
)

func TestGetPutPKARRDHT(t *testing.T) {
d := NewTestDHT(t)
func TestGetPutPkarrDHT(t *testing.T) {
dht := NewTestDHT(t)
defer dht.Close()

_, privKey, err := util.GenerateKeypair()
require.NoError(t, err)
Expand All @@ -44,11 +45,11 @@ func TestGetPutPKARRDHT(t *testing.T) {
put, err := CreatePkarrPublishRequest(privKey, msg)
require.NoError(t, err)

id, err := d.Put(context.Background(), *put)
id, err := dht.Put(context.Background(), *put)
require.NoError(t, err)
require.NotEmpty(t, id)

got, err := d.Get(context.Background(), id)
got, err := dht.GetFull(context.Background(), id)
require.NoError(t, err)
require.NotEmpty(t, got)

Expand All @@ -61,6 +62,7 @@ func TestGetPutPKARRDHT(t *testing.T) {

func TestGetPutDIDDHT(t *testing.T) {
dht := NewTestDHT(t)
defer dht.Close()

pubKey, _, err := crypto.GenerateSECP256k1Key()
require.NoError(t, err)
Expand Down Expand Up @@ -108,7 +110,7 @@ func TestGetPutDIDDHT(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, gotID)

got, err := dht.Get(context.Background(), gotID)
got, err := dht.GetFull(context.Background(), gotID)
require.NoError(t, err)
require.NotEmpty(t, got)

Expand Down
Loading

0 comments on commit 9eb197a

Please sign in to comment.