Skip to content

Commit

Permalink
fix: support origin point for chainsync intersect (#296)
Browse files Browse the repository at this point in the history
Fixes #292
  • Loading branch information
agaffney authored Dec 22, 2024
1 parent 2634e8a commit b2dc611
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 15 deletions.
30 changes: 16 additions & 14 deletions state/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,24 @@ func newChainIterator(
startPoint ocommon.Point,
inclusive bool,
) (*ChainIterator, error) {
// Lookup start block in metadata DB
tmpBlock, err := models.BlockByPoint(ls.db, startPoint)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrBlockNotFound
}
return nil, err
}
ci := &ChainIterator{
ls: ls,
startPoint: startPoint,
blockNumber: tmpBlock.Number,
ls: ls,
startPoint: startPoint,
}
// Increment next block number is non-inclusive
if !inclusive {
ci.blockNumber++
// Lookup start block in metadata DB if not origin
if startPoint.Slot > 0 || len(startPoint.Hash) > 0 {
tmpBlock, err := models.BlockByPoint(ls.db, startPoint)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrBlockNotFound
}
return nil, err
}
ci.blockNumber = tmpBlock.Number
// Increment next block number if non-inclusive
if !inclusive {
ci.blockNumber++
}
}
return ci, nil
}
Expand Down
8 changes: 7 additions & 1 deletion state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,7 @@ func (ls *LedgerState) GetIntersectPoint(
) (*ocommon.Point, error) {
tip := ls.Tip()
var ret ocommon.Point
foundOrigin := false
txn := ls.db.Transaction(false)
err := txn.Do(func(txn *database.Txn) error {
for _, point := range points {
Expand All @@ -629,6 +630,11 @@ func (ls *LedgerState) GetIntersectPoint(
if point.Slot < ret.Slot {
continue
}
// Check for special origin point
if point.Slot == 0 && len(point.Hash) == 0 {
foundOrigin = true
continue
}
// Lookup block in metadata DB
tmpBlock, err := models.BlockByPoint(ls.db, point)
if err != nil {
Expand All @@ -646,7 +652,7 @@ func (ls *LedgerState) GetIntersectPoint(
if err != nil {
return nil, err
}
if ret.Slot > 0 {
if ret.Slot > 0 || foundOrigin {
return &ret, nil
}
return nil, nil
Expand Down

0 comments on commit b2dc611

Please sign in to comment.