Skip to content

Commit

Permalink
refactor: refreshOnchainPaymentState arg
Browse files Browse the repository at this point in the history
  • Loading branch information
hopeyen committed Dec 17, 2024
1 parent b679a14 commit f4c4bb0
Showing 1 changed file with 22 additions and 23 deletions.
45 changes: 22 additions & 23 deletions core/meterer/onchain_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

// OnchainPaymentState is an interface for getting information about the current chain state for payments.
type OnchainPayment interface {
RefreshOnchainPaymentState(ctx context.Context, tx *eth.Reader) error
RefreshOnchainPaymentState(ctx context.Context) error
GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error)
GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error)
GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error)
Expand Down Expand Up @@ -49,49 +49,45 @@ type PaymentVaultParams struct {
}

func NewOnchainPaymentState(ctx context.Context, tx *eth.Reader) (*OnchainPaymentState, error) {
paymentVaultParams, err := GetPaymentVaultParams(ctx, tx)
if err != nil {
return nil, err
}

state := OnchainPaymentState{
tx: tx,
ReservedPayments: make(map[gethcommon.Address]*core.ReservedPayment),
OnDemandPayments: make(map[gethcommon.Address]*core.OnDemandPayment),
PaymentVaultParams: atomic.Pointer[PaymentVaultParams]{},
}
state.PaymentVaultParams.Store(paymentVaultParams)

return &state, nil
}

func GetPaymentVaultParams(ctx context.Context, tx *eth.Reader) (*PaymentVaultParams, error) {
blockNumber, err := tx.GetCurrentBlockNumber(ctx)
paymentVaultParams, err := state.GetPaymentVaultParams(ctx)
if err != nil {
return nil, err
}

quorumNumbers, err := tx.GetRequiredQuorumNumbers(ctx, blockNumber)
state.PaymentVaultParams.Store(paymentVaultParams)

return &state, nil
}

func (pcs *OnchainPaymentState) GetPaymentVaultParams(ctx context.Context) (*PaymentVaultParams, error) {
quorumNumbers, err := pcs.GetOnDemandQuorumNumbers(ctx)
if err != nil {
return nil, err
}

globalSymbolsPerSecond, err := tx.GetGlobalSymbolsPerSecond(ctx)
globalSymbolsPerSecond, err := pcs.tx.GetGlobalSymbolsPerSecond(ctx)
if err != nil {
return nil, err
}

minNumSymbols, err := tx.GetMinNumSymbols(ctx)
minNumSymbols, err := pcs.tx.GetMinNumSymbols(ctx)
if err != nil {
return nil, err
}

pricePerSymbol, err := tx.GetPricePerSymbol(ctx)
pricePerSymbol, err := pcs.tx.GetPricePerSymbol(ctx)
if err != nil {
return nil, err
}

reservationWindow, err := tx.GetReservationWindow(ctx)
reservationWindow, err := pcs.tx.GetReservationWindow(ctx)
if err != nil {
return nil, err
}
Expand All @@ -106,8 +102,8 @@ func GetPaymentVaultParams(ctx context.Context, tx *eth.Reader) (*PaymentVaultPa
}

// RefreshOnchainPaymentState returns the current onchain payment state
func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, tx *eth.Reader) error {
paymentVaultParams, err := GetPaymentVaultParams(ctx, tx)
func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context) error {
paymentVaultParams, err := pcs.GetPaymentVaultParams(ctx)
if err != nil {
return err
}
Expand All @@ -120,7 +116,7 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context,
accountIDs = append(accountIDs, accountID)
}

reservedPayments, err := tx.GetReservedPayments(ctx, accountIDs)
reservedPayments, err := pcs.tx.GetReservedPayments(ctx, accountIDs)
if err != nil {
return err
}
Expand All @@ -133,7 +129,7 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context,
accountIDs = append(accountIDs, accountID)
}

onDemandPayments, err := tx.GetOnDemandPayments(ctx, accountIDs)
onDemandPayments, err := pcs.tx.GetOnDemandPayments(ctx, accountIDs)
if err != nil {
return err
}
Expand All @@ -146,10 +142,11 @@ func (pcs *OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context,
// GetReservedPaymentByAccount returns a pointer to the active reservation for the given account ID; no writes will be made to the reservation
func (pcs *OnchainPaymentState) GetReservedPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.ReservedPayment, error) {
pcs.ReservationsLock.RLock()
defer pcs.ReservationsLock.RUnlock()
if reservation, ok := (pcs.ReservedPayments)[accountID]; ok {
pcs.ReservationsLock.RUnlock()
return reservation, nil
}
pcs.ReservationsLock.RUnlock()

// pulls the chain state
res, err := pcs.tx.GetReservedPaymentByAccount(ctx, accountID)
Expand All @@ -166,10 +163,12 @@ func (pcs *OnchainPaymentState) GetReservedPaymentByAccount(ctx context.Context,
// GetOnDemandPaymentByAccount returns a pointer to the on-demand payment for the given account ID; no writes will be made to the payment
func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context, accountID gethcommon.Address) (*core.OnDemandPayment, error) {
pcs.OnDemandLocks.RLock()
defer pcs.OnDemandLocks.RUnlock()
if payment, ok := (pcs.OnDemandPayments)[accountID]; ok {
pcs.OnDemandLocks.RUnlock()
return payment, nil
}
pcs.OnDemandLocks.RUnlock()

// pulls the chain state
res, err := pcs.tx.GetOnDemandPaymentByAccount(ctx, accountID)
if err != nil {
Expand Down

0 comments on commit f4c4bb0

Please sign in to comment.