Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: refreshOnchainPaymentState arg #1012

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/meterer/meterer.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (m *Meterer) Start(ctx context.Context) {
for {
select {
case <-ticker.C:
if err := m.ChainPaymentState.RefreshOnchainPaymentState(ctx, nil); err != nil {
if err := m.ChainPaymentState.RefreshOnchainPaymentState(ctx); err != nil {
m.logger.Error("Failed to refresh on-chain state", "error", err)
}
case <-ctx.Done():
Expand Down
2 changes: 1 addition & 1 deletion core/meterer/meterer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func setup(_ *testing.M) {
}

paymentChainState.On("RefreshOnchainPaymentState", testifymock.Anything).Return(nil).Maybe()
if err := paymentChainState.RefreshOnchainPaymentState(context.Background(), nil); err != nil {
if err := paymentChainState.RefreshOnchainPaymentState(context.Background()); err != nil {
panic("failed to make initial query to the on-chain state")
}

Expand Down
45 changes: 22 additions & 23 deletions core/meterer/onchain_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

// OnchainPaymentState is an interface for getting information about the current chain state for payments.
type OnchainPayment interface {
RefreshOnchainPaymentState(ctx context.Context, tx *eth.Reader) error
RefreshOnchainPaymentState(ctx context.Context) error
GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error)
GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error)
GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error)
Expand Down Expand Up @@ -49,49 +49,45 @@ type PaymentVaultParams struct {
}

func NewOnchainPaymentState(ctx context.Context, tx *eth.Reader) (*OnchainPaymentState, error) {
paymentVaultParams, err := GetPaymentVaultParams(ctx, tx)
if err != nil {
return nil, err
}

state := OnchainPaymentState{
tx: tx,
ReservedPayments: make(map[gethcommon.Address]*core.ReservedPayment),
OnDemandPayments: make(map[gethcommon.Address]*core.OnDemandPayment),
PaymentVaultParams: atomic.Pointer[PaymentVaultParams]{},
}
state.PaymentVaultParams.Store(paymentVaultParams)

return &state, nil
}

func GetPaymentVaultParams(ctx context.Context, tx *eth.Reader) (*PaymentVaultParams, error) {
blockNumber, err := tx.GetCurrentBlockNumber(ctx)
paymentVaultParams, err := state.GetPaymentVaultParams(ctx)
if err != nil {
return nil, err
}

quorumNumbers, err := tx.GetRequiredQuorumNumbers(ctx, blockNumber)
state.PaymentVaultParams.Store(paymentVaultParams)

return &state, nil
}

func (pcs *OnchainPaymentState) GetPaymentVaultParams(ctx context.Context) (*PaymentVaultParams, error) {
quorumNumbers, err := pcs.GetOnDemandQuorumNumbers(ctx)
if err != nil {
return nil, err
}

globalSymbolsPerSecond, err := tx.GetGlobalSymbolsPerSecond(ctx)
globalSymbolsPerSecond, err := pcs.tx.GetGlobalSymbolsPerSecond(ctx)
if err != nil {
return nil, err
}

minNumSymbols, err := tx.GetMinNumSymbols(ctx)
minNumSymbols, err := pcs.tx.GetMinNumSymbols(ctx)
if err != nil {
return nil, err
}

pricePerSymbol, err := tx.GetPricePerSymbol(ctx)
pricePerSymbol, err := pcs.tx.GetPricePerSymbol(ctx)
if err != nil {
return nil, err
}

reservationWindow, err := tx.GetReservationWindow(ctx)
reservationWindow, err := pcs.tx.GetReservationWindow(ctx)
if err != nil {
return nil, err
}
Expand All @@ -106,8 +102,8 @@ func GetPaymentVaultParams(ctx context.Context, tx *eth.Reader) (*PaymentVaultPa
}

// RefreshOnchainPaymentState returns the current onchain payment state
func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, tx *eth.Reader) error {
paymentVaultParams, err := GetPaymentVaultParams(ctx, tx)
func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context) error {
paymentVaultParams, err := pcs.GetPaymentVaultParams(ctx)
if err != nil {
return err
}
Expand All @@ -120,7 +116,7 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context,
accountIDs = append(accountIDs, accountID)
}

reservedPayments, err := tx.GetReservedPayments(ctx, accountIDs)
reservedPayments, err := pcs.tx.GetReservedPayments(ctx, accountIDs)
if err != nil {
return err
}
Expand All @@ -133,7 +129,7 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context,
accountIDs = append(accountIDs, accountID)
}

onDemandPayments, err := tx.GetOnDemandPayments(ctx, accountIDs)
onDemandPayments, err := pcs.tx.GetOnDemandPayments(ctx, accountIDs)
if err != nil {
return err
}
Expand All @@ -146,10 +142,11 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context,
// GetReservedPaymentByAccount returns a pointer to the active reservation for the given account ID; no writes will be made to the reservation
func (pcs *OnchainPaymentState) GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) {
pcs.ReservationsLock.RLock()
defer pcs.ReservationsLock.RUnlock()
if reservation, ok := (pcs.ReservedPayments)[accountID]; ok {
pcs.ReservationsLock.RUnlock()
return reservation, nil
}
pcs.ReservationsLock.RUnlock()

// pulls the chain state
res, err := pcs.tx.GetReservedPaymentByAccount(ctx, accountID)
Expand All @@ -166,10 +163,12 @@ func (pcs *OnchainPaymentState) GetReservedPaymentByAccount(ctx context.Context,
// GetOnDemandPaymentByAccount returns a pointer to the on-demand payment for the given account ID; no writes will be made to the payment
func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) {
pcs.OnDemandLocks.RLock()
defer pcs.OnDemandLocks.RUnlock()
if payment, ok := (pcs.OnDemandPayments)[accountID]; ok {
pcs.OnDemandLocks.RUnlock()
return payment, nil
}
pcs.OnDemandLocks.RUnlock()

// pulls the chain state
res, err := pcs.tx.GetOnDemandPaymentByAccount(ctx, accountID)
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions core/meterer/onchain_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"testing"

"github.com/Layr-Labs/eigenda/core"
"github.com/Layr-Labs/eigenda/core/eth"
"github.com/Layr-Labs/eigenda/core/mock"
gethcommon "github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/assert"
Expand All @@ -30,7 +29,7 @@ func TestRefreshOnchainPaymentState(t *testing.T) {
ctx := context.Background()
mockState.On("RefreshOnchainPaymentState", testifymock.Anything, testifymock.Anything).Return(nil)

err := mockState.RefreshOnchainPaymentState(ctx, &eth.Reader{})
err := mockState.RefreshOnchainPaymentState(ctx)
assert.NoError(t, err)
}

Expand Down
3 changes: 1 addition & 2 deletions core/mock/payment_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"

"github.com/Layr-Labs/eigenda/core"
"github.com/Layr-Labs/eigenda/core/eth"
"github.com/Layr-Labs/eigenda/core/meterer"
gethcommon "github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/mock"
Expand All @@ -25,7 +24,7 @@ func (m *MockOnchainPaymentState) GetCurrentBlockNumber(ctx context.Context) (ui
return value, args.Error(1)
}

func (m *MockOnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, tx *eth.Reader) error {
func (m *MockOnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context) error {
args := m.Called()
return args.Error(0)
}
Expand Down
4 changes: 2 additions & 2 deletions disperser/apiserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ func newTestServer(transactor core.Writer, testName string) *apiserver.Dispersal

mockState := &mock.MockOnchainPaymentState{}
mockState.On("RefreshOnchainPaymentState", tmock.Anything).Return(nil).Maybe()
if err := mockState.RefreshOnchainPaymentState(context.Background(), nil); err != nil {
if err := mockState.RefreshOnchainPaymentState(context.Background()); err != nil {
panic("failed to make initial query to the on-chain state")
}

Expand Down Expand Up @@ -798,7 +798,7 @@ func newTestServer(transactor core.Writer, testName string) *apiserver.Dispersal
panic("failed to create offchain store")
}
mt := meterer.NewMeterer(meterer.Config{}, mockState, store, logger)
err = mt.ChainPaymentState.RefreshOnchainPaymentState(context.Background(), nil)
err = mt.ChainPaymentState.RefreshOnchainPaymentState(context.Background())
if err != nil {
panic("failed to make initial query to the on-chain state")
}
Expand Down
2 changes: 1 addition & 1 deletion disperser/apiserver/server_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ func newTestServerV2(t *testing.T) *testComponents {
mockState.On("GetOnDemandPaymentByAccount", tmock.Anything, tmock.Anything).Return(&core.OnDemandPayment{CumulativePayment: big.NewInt(3864)}, nil)
mockState.On("GetOnDemandQuorumNumbers", tmock.Anything).Return([]uint8{0, 1}, nil)

if err := mockState.RefreshOnchainPaymentState(context.Background(), nil); err != nil {
if err := mockState.RefreshOnchainPaymentState(context.Background()); err != nil {
panic("failed to make initial query to the on-chain state")
}
table_names := []string{"reservations_server_" + t.Name(), "ondemand_server_" + t.Name(), "global_server_" + t.Name()}
Expand Down
2 changes: 1 addition & 1 deletion disperser/cmd/apiserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func RunDisperserServer(ctx *cli.Context) error {
if err != nil {
return fmt.Errorf("failed to create onchain payment state: %w", err)
}
if err := paymentChainState.RefreshOnchainPaymentState(context.Background(), nil); err != nil {
if err := paymentChainState.RefreshOnchainPaymentState(context.Background()); err != nil {
return fmt.Errorf("failed to make initial query to the on-chain state: %w", err)
}

Expand Down
5 changes: 3 additions & 2 deletions test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/hex"
"errors"
"fmt"
"github.com/stretchr/testify/require"
"log"
"math"
"math/big"
Expand All @@ -18,6 +17,8 @@ import (
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/Layr-Labs/eigenda/common/pubip"
"github.com/Layr-Labs/eigenda/encoding/kzg"
"github.com/Layr-Labs/eigenda/encoding/kzg/prover"
Expand Down Expand Up @@ -289,7 +290,7 @@ func mustMakeDisperser(t *testing.T, cst core.IndexedChainState, store disperser
}

mockState.On("RefreshOnchainPaymentState", mock.Anything).Return(nil).Maybe()
if err := mockState.RefreshOnchainPaymentState(context.Background(), nil); err != nil {
if err := mockState.RefreshOnchainPaymentState(context.Background()); err != nil {
panic("failed to make initial query to the on-chain state")
}

Expand Down
Loading