diff --git a/oracle/oracles/pd.go b/oracle/oracles/pd.go index bc5922e88..60f0ee7cb 100644 --- a/oracle/oracles/pd.go +++ b/oracle/oracles/pd.go @@ -296,9 +296,12 @@ func (o *pdOracle) setLastTS(ts uint64, txnScope string) { lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO]) for { last := lastTSPointer.Load() - if current.tso <= last.tso || !current.arrival.After(last.arrival) { + if current.tso <= last.tso { return } + if last.arrival.After(current.arrival) { + current.arrival = last.arrival + } if lastTSPointer.CompareAndSwap(last, current) { return } diff --git a/oracle/oracles/pd_test.go b/oracle/oracles/pd_test.go index 5110cc266..01c67c5e4 100644 --- a/oracle/oracles/pd_test.go +++ b/oracle/oracles/pd_test.go @@ -551,3 +551,38 @@ func TestValidateReadTSForNormalReadDoNotAffectUpdateInterval(t *testing.T) { assert.NoError(t, err) mustNoNotify() } + +func TestSetLastTSAlwaysPushTS(t *testing.T) { + oracleInterface, err := NewPdOracle(&MockPdClient{}, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + NoUpdateTS: true, + }) + assert.NoError(t, err) + o := oracleInterface.(*pdOracle) + defer o.Close() + + var wg sync.WaitGroup + cancel := make(chan struct{}) + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ctx := context.Background() + for { + select { + case <-cancel: + return + default: + } + ts, err := o.GetTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + assert.NoError(t, err) + lastTS, found := o.getLastTS(oracle.GlobalTxnScope) + assert.True(t, found) + assert.GreaterOrEqual(t, lastTS, ts) + } + }() + } + time.Sleep(time.Second) + close(cancel) + wg.Wait() +}