Skip to content

Commit

Permalink
refactor: refreshOnchainPaymentState arg
Browse files Browse the repository at this point in the history
  • Loading branch information
hopeyen committed Dec 17, 2024
1 parent b679a14 commit f5db445
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 35 deletions.
2 changes: 1 addition & 1 deletion core/meterer/meterer.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (m *Meterer) Start(ctx context.Context) {
for {
select {
case <-ticker.C:
if err := m.ChainPaymentState.RefreshOnchainPaymentState(ctx, nil); err != nil {
if err := m.ChainPaymentState.RefreshOnchainPaymentState(ctx); err != nil {
m.logger.Error("Failed to refresh on-chain state", "error", err)
}
case <-ctx.Done():
Expand Down
2 changes: 1 addition & 1 deletion core/meterer/meterer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func setup(_ *testing.M) {
}

paymentChainState.On("RefreshOnchainPaymentState", testifymock.Anything).Return(nil).Maybe()
if err := paymentChainState.RefreshOnchainPaymentState(context.Background(), nil); err != nil {
if err := paymentChainState.RefreshOnchainPaymentState(context.Background()); err != nil {
panic("failed to make initial query to the on-chain state")
}

Expand Down
45 changes: 22 additions & 23 deletions core/meterer/onchain_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

// OnchainPaymentState is an interface for getting information about the current chain state for payments.
type OnchainPayment interface {
RefreshOnchainPaymentState(ctx context.Context, tx *eth.Reader) error
RefreshOnchainPaymentState(ctx context.Context) error
GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error)
GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error)
GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error)
Expand Down Expand Up @@ -49,49 +49,45 @@ type PaymentVaultParams struct {
}

func NewOnchainPaymentState(ctx context.Context, tx *eth.Reader) (*OnchainPaymentState, error) {
paymentVaultParams, err := GetPaymentVaultParams(ctx, tx)
if err != nil {
return nil, err
}

state := OnchainPaymentState{
tx: tx,
ReservedPayments: make(map[gethcommon.Address]*core.ReservedPayment),
OnDemandPayments: make(map[gethcommon.Address]*core.OnDemandPayment),
PaymentVaultParams: atomic.Pointer[PaymentVaultParams]{},
}
state.PaymentVaultParams.Store(paymentVaultParams)

return &state, nil
}

func GetPaymentVaultParams(ctx context.Context, tx *eth.Reader) (*PaymentVaultParams, error) {
blockNumber, err := tx.GetCurrentBlockNumber(ctx)
paymentVaultParams, err := state.GetPaymentVaultParams(ctx)
if err != nil {
return nil, err
}

quorumNumbers, err := tx.GetRequiredQuorumNumbers(ctx, blockNumber)
state.PaymentVaultParams.Store(paymentVaultParams)

return &state, nil
}

func (pcs *OnchainPaymentState) GetPaymentVaultParams(ctx context.Context) (*PaymentVaultParams, error) {
quorumNumbers, err := pcs.GetOnDemandQuorumNumbers(ctx)
if err != nil {
return nil, err
}

globalSymbolsPerSecond, err := tx.GetGlobalSymbolsPerSecond(ctx)
globalSymbolsPerSecond, err := pcs.tx.GetGlobalSymbolsPerSecond(ctx)
if err != nil {
return nil, err
}

minNumSymbols, err := tx.GetMinNumSymbols(ctx)
minNumSymbols, err := pcs.tx.GetMinNumSymbols(ctx)
if err != nil {
return nil, err
}

pricePerSymbol, err := tx.GetPricePerSymbol(ctx)
pricePerSymbol, err := pcs.tx.GetPricePerSymbol(ctx)
if err != nil {
return nil, err
}

reservationWindow, err := tx.GetReservationWindow(ctx)
reservationWindow, err := pcs.tx.GetReservationWindow(ctx)
if err != nil {
return nil, err
}
Expand All @@ -106,8 +102,8 @@ func GetPaymentVaultParams(ctx context.Context, tx *eth.Reader) (*PaymentVaultPa
}

// RefreshOnchainPaymentState returns the current onchain payment state
func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, tx *eth.Reader) error {
paymentVaultParams, err := GetPaymentVaultParams(ctx, tx)
func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context) error {
paymentVaultParams, err := pcs.GetPaymentVaultParams(ctx)
if err != nil {
return err
}
Expand All @@ -120,7 +116,7 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context,
accountIDs = append(accountIDs, accountID)
}

reservedPayments, err := tx.GetReservedPayments(ctx, accountIDs)
reservedPayments, err := pcs.tx.GetReservedPayments(ctx, accountIDs)
if err != nil {
return err
}
Expand All @@ -133,7 +129,7 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context,
accountIDs = append(accountIDs, accountID)
}

onDemandPayments, err := tx.GetOnDemandPayments(ctx, accountIDs)
onDemandPayments, err := pcs.tx.GetOnDemandPayments(ctx, accountIDs)
if err != nil {
return err
}
Expand All @@ -146,10 +142,11 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context,
// GetReservedPaymentByAccount returns a pointer to the active reservation for the given account ID; no writes will be made to the reservation
func (pcs *OnchainPaymentState) GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) {
pcs.ReservationsLock.RLock()
defer pcs.ReservationsLock.RUnlock()
if reservation, ok := (pcs.ReservedPayments)[accountID]; ok {
pcs.ReservationsLock.RUnlock()
return reservation, nil
}
pcs.ReservationsLock.RUnlock()

// pulls the chain state
res, err := pcs.tx.GetReservedPaymentByAccount(ctx, accountID)
Expand All @@ -166,10 +163,12 @@ func (pcs *OnchainPaymentState) GetReservedPaymentByAccount(ctx context.Context,
// GetOnDemandPaymentByAccount returns a pointer to the on-demand payment for the given account ID; no writes will be made to the payment
func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) {
pcs.OnDemandLocks.RLock()
defer pcs.OnDemandLocks.RUnlock()
if payment, ok := (pcs.OnDemandPayments)[accountID]; ok {
pcs.OnDemandLocks.RUnlock()
return payment, nil
}
pcs.OnDemandLocks.RUnlock()

// pulls the chain state
res, err := pcs.tx.GetOnDemandPaymentByAccount(ctx, accountID)
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions core/meterer/onchain_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"testing"

"github.com/Layr-Labs/eigenda/core"
"github.com/Layr-Labs/eigenda/core/eth"
"github.com/Layr-Labs/eigenda/core/mock"
gethcommon "github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/assert"
Expand All @@ -30,7 +29,7 @@ func TestRefreshOnchainPaymentState(t *testing.T) {
ctx := context.Background()
mockState.On("RefreshOnchainPaymentState", testifymock.Anything, testifymock.Anything).Return(nil)

err := mockState.RefreshOnchainPaymentState(ctx, &eth.Reader{})
err := mockState.RefreshOnchainPaymentState(ctx)
assert.NoError(t, err)
}

Expand Down
3 changes: 1 addition & 2 deletions core/mock/payment_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"

"github.com/Layr-Labs/eigenda/core"
"github.com/Layr-Labs/eigenda/core/eth"
"github.com/Layr-Labs/eigenda/core/meterer"
gethcommon "github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/mock"
Expand All @@ -25,7 +24,7 @@ func (m *MockOnchainPaymentState) GetCurrentBlockNumber(ctx context.Context) (ui
return value, args.Error(1)
}

func (m *MockOnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, tx *eth.Reader) error {
func (m *MockOnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context) error {
args := m.Called()
return args.Error(0)
}
Expand Down
4 changes: 2 additions & 2 deletions disperser/apiserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ func newTestServer(transactor core.Writer, testName string) *apiserver.Dispersal

mockState := &mock.MockOnchainPaymentState{}
mockState.On("RefreshOnchainPaymentState", tmock.Anything).Return(nil).Maybe()
if err := mockState.RefreshOnchainPaymentState(context.Background(), nil); err != nil {
if err := mockState.RefreshOnchainPaymentState(context.Background()); err != nil {
panic("failed to make initial query to the on-chain state")
}

Expand Down Expand Up @@ -798,7 +798,7 @@ func newTestServer(transactor core.Writer, testName string) *apiserver.Dispersal
panic("failed to create offchain store")
}
mt := meterer.NewMeterer(meterer.Config{}, mockState, store, logger)
err = mt.ChainPaymentState.RefreshOnchainPaymentState(context.Background(), nil)
err = mt.ChainPaymentState.RefreshOnchainPaymentState(context.Background())
if err != nil {
panic("failed to make initial query to the on-chain state")
}
Expand Down
2 changes: 1 addition & 1 deletion disperser/apiserver/server_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ func newTestServerV2(t *testing.T) *testComponents {
mockState.On("GetOnDemandPaymentByAccount", tmock.Anything, tmock.Anything).Return(&core.OnDemandPayment{CumulativePayment: big.NewInt(3864)}, nil)
mockState.On("GetOnDemandQuorumNumbers", tmock.Anything).Return([]uint8{0, 1}, nil)

if err := mockState.RefreshOnchainPaymentState(context.Background(), nil); err != nil {
if err := mockState.RefreshOnchainPaymentState(context.Background()); err != nil {
panic("failed to make initial query to the on-chain state")
}
table_names := []string{"reservations_server_" + t.Name(), "ondemand_server_" + t.Name(), "global_server_" + t.Name()}
Expand Down
2 changes: 1 addition & 1 deletion disperser/cmd/apiserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func RunDisperserServer(ctx *cli.Context) error {
if err != nil {
return fmt.Errorf("failed to create onchain payment state: %w", err)
}
if err := paymentChainState.RefreshOnchainPaymentState(context.Background(), nil); err != nil {
if err := paymentChainState.RefreshOnchainPaymentState(context.Background()); err != nil {
return fmt.Errorf("failed to make initial query to the on-chain state: %w", err)
}

Expand Down
5 changes: 3 additions & 2 deletions test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/hex"
"errors"
"fmt"
"github.com/stretchr/testify/require"
"log"
"math"
"math/big"
Expand All @@ -18,6 +17,8 @@ import (
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/Layr-Labs/eigenda/common/pubip"
"github.com/Layr-Labs/eigenda/encoding/kzg"
"github.com/Layr-Labs/eigenda/encoding/kzg/prover"
Expand Down Expand Up @@ -289,7 +290,7 @@ func mustMakeDisperser(t *testing.T, cst core.IndexedChainState, store disperser
}

mockState.On("RefreshOnchainPaymentState", mock.Anything).Return(nil).Maybe()
if err := mockState.RefreshOnchainPaymentState(context.Background(), nil); err != nil {
if err := mockState.RefreshOnchainPaymentState(context.Background()); err != nil {
panic("failed to make initial query to the on-chain state")
}

Expand Down

0 comments on commit f5db445

Please sign in to comment.