diff --git a/state/chain.go b/state/chain.go index dc4b4b5..d5d19de 100644 --- a/state/chain.go +++ b/state/chain.go @@ -40,22 +40,24 @@ func newChainIterator( startPoint ocommon.Point, inclusive bool, ) (*ChainIterator, error) { - // Lookup start block in metadata DB - tmpBlock, err := models.BlockByPoint(ls.db, startPoint) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrBlockNotFound - } - return nil, err - } ci := &ChainIterator{ - ls: ls, - startPoint: startPoint, - blockNumber: tmpBlock.Number, + ls: ls, + startPoint: startPoint, } - // Increment next block number is non-inclusive - if !inclusive { - ci.blockNumber++ + // Lookup start block in metadata DB if not origin + if startPoint.Slot > 0 || len(startPoint.Hash) > 0 { + tmpBlock, err := models.BlockByPoint(ls.db, startPoint) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrBlockNotFound + } + return nil, err + } + ci.blockNumber = tmpBlock.Number + // Increment next block number if non-inclusive + if !inclusive { + ci.blockNumber++ + } } return ci, nil } diff --git a/state/state.go b/state/state.go index c36bea8..61ae94d 100644 --- a/state/state.go +++ b/state/state.go @@ -618,6 +618,7 @@ func (ls *LedgerState) GetIntersectPoint( ) (*ocommon.Point, error) { tip := ls.Tip() var ret ocommon.Point + foundOrigin := false txn := ls.db.Transaction(false) err := txn.Do(func(txn *database.Txn) error { for _, point := range points { @@ -629,6 +630,11 @@ func (ls *LedgerState) GetIntersectPoint( if point.Slot < ret.Slot { continue } + // Check for special origin point + if point.Slot == 0 && len(point.Hash) == 0 { + foundOrigin = true + continue + } // Lookup block in metadata DB tmpBlock, err := models.BlockByPoint(ls.db, point) if err != nil { @@ -646,7 +652,7 @@ func (ls *LedgerState) GetIntersectPoint( if err != nil { return nil, err } - if ret.Slot > 0 { + if ret.Slot > 0 || foundOrigin { return &ret, nil } return nil, nil