Skip to content

Commit

Permalink
In repl tests factor out common elements, add conversation tests (#1753)
Browse files Browse the repository at this point in the history
Conversation tests fail (as expected - to be fixed), so skipped for now
  • Loading branch information
sergekh2 authored Dec 8, 2024
1 parent e92b4f3 commit 6f0298d
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 63 deletions.
84 changes: 56 additions & 28 deletions core/node/rpc/repl_multiclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,44 +25,27 @@ func newServiceTesterForReplication(t *testing.T) *serviceTester {
)
}

func TestReplMulticlientSimple(t *testing.T) {
func TestReplMcSimple(t *testing.T) {
tt := newServiceTesterForReplication(t)

alice := tt.newTestClient(0)

_ = alice.createUserStream()
spaceId, _ := alice.createSpace()
channelId, _ := alice.createChannel(spaceId)

bob := tt.newTestClient(1)
user1LastMb := bob.createUserStream()
bob.joinChannel(spaceId, channelId, user1LastMb)

allClients := testClients{alice, bob}
allClients.requireMembership(channelId)

carol := tt.newTestClient(2)
user2LastMb := carol.createUserStream()
carol.joinChannel(spaceId, channelId, user2LastMb)

allClients = append(allClients, carol)
allClients.requireMembership(channelId)

clients := tt.newTestClients(3)
spaceId, _ := clients[0].createSpace()
channelId := clients.createChannelAndJoin(spaceId)
phrases1 := []string{"hello from Alice", "hello from Bob", "hello from Carol"}
allClients.say(channelId, phrases1...)
clients.say(channelId, phrases1...)

allClients.listen(channelId, [][]string{phrases1})
clients.listen(channelId, [][]string{phrases1})

phrases2 := []string{"hello from Alice 2", "hello from Bob 2", "hello from Carol 2"}
allClients.say(channelId, phrases2...)
allClients.listen(channelId, [][]string{phrases1, phrases2})
clients.say(channelId, phrases2...)
clients.listen(channelId, [][]string{phrases1, phrases2})

phrases3 := []string{"", "hello from Bob 3", ""}
allClients.say(channelId, phrases3...)
allClients.listen(channelId, [][]string{phrases1, phrases2, phrases3})
clients.say(channelId, phrases3...)
clients.listen(channelId, [][]string{phrases1, phrases2, phrases3})
}

func TestReplSpeakUntilMbTrim(t *testing.T) {
func TestReplMcSpeakUntilMbTrim(t *testing.T) {
tt := newServiceTesterForReplication(t)
require := tt.require

Expand All @@ -82,3 +65,48 @@ func TestReplSpeakUntilMbTrim(t *testing.T) {
}
require.Fail("failed to trim miniblocks")
}

func testReplMcConversation(t *testing.T, numClients int, numSteps int, listenInterval int) {
tt := newServiceTesterForReplication(t)
clients := tt.newTestClients(numClients)
spaceId, _ := clients[0].createSpace()
channelId := clients.createChannelAndJoin(spaceId)

messages := make([][]string, numSteps)
for i := range messages {
messages[i] = make([]string, numClients)
for j := range messages[i] {
messages[i][j] = fmt.Sprintf("message %d from client %s", i, clients[j].name)
}
}

var i int
var m []string
defer func() {
if i+1 < len(messages) {
t.Errorf("got through %d steps out of %d", i+1, len(messages))
}
}()
for i, m = range messages {
clients.say(channelId, m...)
if listenInterval > 0 && (i+1)%listenInterval == 0 {
clients.listen(channelId, messages[:i+1])
}
}

if listenInterval <= 0 || numSteps%listenInterval != 0 {
clients.listen(channelId, messages)
}
}

func TestReplMcConversation(t *testing.T) {
t.Skip("SKIPPED: TODO: REPLICATION: fix")

t.Parallel()
t.Run("5x5", func(t *testing.T) {
testReplMcConversation(t, 5, 5, 1)
})
t.Run("debug", func(t *testing.T) {
testReplMcConversation(t, 5, 12, 1)
})
}
194 changes: 159 additions & 35 deletions core/node/rpc/tester_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import (
"hash/fnv"
"io"
"log"
"maps"
"math/big"
"net"
"net/http"
"slices"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -507,6 +509,7 @@ type testClient struct {
wallet *crypto.Wallet
userId common.Address
userStreamId StreamId
name string
}

func (st *serviceTester) newTestClient(i int) *testClient {
Expand All @@ -521,9 +524,22 @@ func (st *serviceTester) newTestClient(i int) *testClient {
wallet: wallet,
userId: wallet.Address,
userStreamId: UserStreamIdFromAddr(wallet.Address),
name: fmt.Sprintf("%d-%s", i, wallet.Address.Hex()[2:8]),
}
}

// newTestClients creates a testClients with clients connected to nodes in round-robin fashion.
func (st *serviceTester) newTestClients(numClients int) testClients {
clients := make(testClients, numClients)
for i := range clients {
clients[i] = st.newTestClient(i % st.opts.numNodes)
}
clients.parallelForAll(func(tc *testClient) {
tc.createUserStream()
})
return clients
}

func (tc *testClient) withRequireFor(t require.TestingT) *testClient {
var tcc testClient = *tc
tcc.require = require.New(t)
Expand Down Expand Up @@ -633,10 +649,89 @@ type usersMessage struct {
message string
}

func (tc *testClient) getAllMessages(channelId StreamId) []usersMessage {
func (um usersMessage) String() string {
return fmt.Sprintf("%s: '%s'\n", um.userId.Hex()[2:8], um.message)
}

type userMessages []usersMessage

func flattenUserMessages(userIds []common.Address, messages [][]string) userMessages {
um := userMessages{}
for _, msg := range messages {
for j, m := range msg {
if m != "" {
um = append(um, usersMessage{userId: userIds[j], message: m})
}
}
}
return um
}

func (um userMessages) String() string {
if len(um) == 0 {
return " EMPTY"
}
lines := []string{"\n[[[\n"}
for _, m := range um {
lines = append(lines, m.String())
}
lines = append(lines, "]]]\n")
return strings.Join(lines, "")
}

func diffUserMessages(expected, actual userMessages) (userMessages, userMessages) {
expectedSet := map[string]usersMessage{}
for _, m := range expected {
expectedSet[m.String()] = m
}
actualExtra := userMessages{}
for _, m := range actual {
key := m.String()
_, ok := expectedSet[key]
if ok {
delete(expectedSet, key)
} else {
actualExtra = append(actualExtra, m)
}
}
expectedExtra := slices.Collect(maps.Values(expectedSet))
return expectedExtra, actualExtra
}

func TestDiffUserMessages(t *testing.T) {
assert := assert.New(t)

um1 := usersMessage{common.Address{0x1}, "A"}
um2 := usersMessage{common.Address{0x1}, "B"}
um3 := usersMessage{common.Address{0x2}, "A"}
um4 := usersMessage{common.Address{0x2}, "B"}
umAll := userMessages{um1, um2, um3, um4}

a, b := diffUserMessages(umAll, umAll)
assert.Len(a, 0)
assert.Len(b, 0)

a, b = diffUserMessages(umAll, umAll[:3])
assert.ElementsMatch(a, umAll[3:])
assert.Len(b, 0)

a, b = diffUserMessages(umAll[1:], umAll)
assert.Len(a, 0)
assert.ElementsMatch(b, umAll[:1])

a, b = diffUserMessages(umAll[1:], umAll[:3])
assert.ElementsMatch(a, umAll[3:])
assert.ElementsMatch(b, umAll[:1])

a, b = diffUserMessages(umAll[2:], umAll[:2])
assert.ElementsMatch(a, umAll[2:])
assert.ElementsMatch(b, umAll[:2])
}

func (tc *testClient) getAllMessages(channelId StreamId) userMessages {
_, view := tc.getStreamAndView(channelId, true)

messages := []usersMessage{}
messages := userMessages{}
for e := range view.AllEvents() {
payload := e.GetChannelMessage()
if payload != nil {
Expand All @@ -650,34 +745,44 @@ func (tc *testClient) getAllMessages(channelId StreamId) []usersMessage {
return messages
}

// messages are partially sorted, i.e. messages in the channel that match sub-slices can be in any order; each []string should match userIds saying it.
func (tc *testClient) eventually(f func(*testClient), t ...time.Duration) {
waitFor := 5 * time.Second
if len(t) > 0 {
waitFor = t[0]
}
tick := 100 * time.Millisecond
if len(t) > 1 {
tick = t[1]
}
tc.require.EventuallyWithT(func(t *assert.CollectT) {
f(tc.withRequireFor(t))
}, waitFor, tick)
}

func (tc *testClient) listen(channelId StreamId, userIds []common.Address, messages [][]string) {
msgs := tc.getAllMessages(channelId)
for _, expected := range messages {
notEmptyCount := 0
for _, e := range expected {
if e != "" {
notEmptyCount++
}
expected := flattenUserMessages(userIds, messages)
tc.listenImpl(channelId, expected)
}

var _ = (*testClient)(nil).listen // Suppress unused warning TODO: remove once used

func (tc *testClient) listenImpl(channelId StreamId, expected userMessages) {
actual := tc.getAllMessages(channelId)
tc.eventually(func(tc *testClient) {
expectedExtra, actualExtra := diffUserMessages(expected, actual)
if len(expectedExtra) > 0 {
tc.require.FailNow(
"Didn't receive all messages",
"client %s\nexpectedExtra:%vactualExtra:%v",
tc.name,
expectedExtra,
actualExtra,
)
}
tc.require.NotZero(notEmptyCount, "internal: conversation can't have empty step")
tc.require.GreaterOrEqual(
len(msgs),
notEmptyCount,
"Not enough was said, left %#v, expected %#v",
msgs,
expected,
)
current := msgs[:notEmptyCount]
msgs = msgs[notEmptyCount:]
expectedWithUserIds := []usersMessage{}
for i, e := range expected {
if e != "" {
expectedWithUserIds = append(expectedWithUserIds, usersMessage{userId: userIds[i], message: e})
}
if len(actualExtra) > 0 {
tc.require.FailNow("Received unexpected messages", "actualExtra:%v", actualExtra)
}
tc.require.ElementsMatch(expectedWithUserIds, current)
}
})
}

func (tc *testClient) getStream(streamId StreamId) *protocol.StreamAndCookie {
Expand Down Expand Up @@ -740,17 +845,16 @@ func (tc *testClient) addHistoryToView(
}

func (tc *testClient) requireMembership(streamId StreamId, expectedMemberships []common.Address) {
tc.require.EventuallyWithT(func(t *assert.CollectT) {
tcc := tc.withRequireFor(t)
_, view := tcc.getStreamAndView(streamId)
tc.eventually(func(tc *testClient) {
_, view := tc.getStreamAndView(streamId)
members, err := view.GetChannelMembers()
tcc.require.NoError(err)
tc.require.NoError(err)
actualMembers := []common.Address{}
for _, a := range members.ToSlice() {
actualMembers = append(actualMembers, common.HexToAddress(a))
}
tcc.require.ElementsMatch(expectedMemberships, actualMembers)
}, 5*time.Second, 100*time.Millisecond)
tc.require.ElementsMatch(expectedMemberships, actualMembers)
})
}

type testClients []*testClient
Expand All @@ -776,9 +880,9 @@ func (tcs testClients) userIds() []common.Address {
}

func (tcs testClients) listen(channelId StreamId, messages [][]string) {
userIds := tcs.userIds()
expected := flattenUserMessages(tcs.userIds(), messages)
tcs.parallelForAll(func(tc *testClient) {
tc.listen(channelId, userIds, messages)
tc.listenImpl(channelId, expected)
})
}

Expand All @@ -805,6 +909,7 @@ func parallel[Params any](tcs testClients, f func(*testClient, Params), params .
for range params {
i := <-resultC
if tcs[i].t.Failed() {
tcs[i].t.Fatalf("client %s failed", tcs[i].name)
return
}
}
Expand All @@ -823,7 +928,26 @@ func (tcs testClients) parallelForAll(f func(*testClient)) {
for range tcs {
i := <-resultC
if tcs[i].t.Failed() {
tcs[i].t.Fatalf("client %s failed", tcs[i].name)
return
}
}
}

// setupChannelWithClients creates a channel and returns a testClients with clients connected to it.
// First client is creator of both space and channel.
// Other clients join the channel.
// Clients are connected to nodes in round-robin fashion.
func (tcs testClients) createChannelAndJoin(spaceId StreamId) StreamId {
alice := tcs[0]
channelId, _ := alice.createChannel(spaceId)

tcs[1:].parallelForAll(func(tc *testClient) {
userLastMb := tc.getLastMiniblockHash(tc.userStreamId)
tc.joinChannel(spaceId, channelId, userLastMb)
})

tcs.requireMembership(channelId)

return channelId
}

0 comments on commit 6f0298d

Please sign in to comment.