Skip to content

Commit

Permalink
chore(communities)_: request missing channels' encryption keys in a loop
Browse files Browse the repository at this point in the history
  • Loading branch information
osmaczko committed Jul 16, 2024
1 parent 5059c19 commit 9403475
Show file tree
Hide file tree
Showing 11 changed files with 437 additions and 41 deletions.
17 changes: 15 additions & 2 deletions protocol/communities/community.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ func (o *Community) MarshalJSON() ([]byte, error) {
CategoryID: c.CategoryId,
HideIfPermissionsNotMet: c.HideIfPermissionsNotMet,
Position: int(c.Position),
MissingEncryptionKey: !o.IsMemberInChat(o.MemberIdentity(), id) && o.IsMemberLikelyInChat(id),
MissingEncryptionKey: o.HasMissingEncryptionKey(id),
}

if chat.TokenGated {
Expand Down Expand Up @@ -771,11 +771,15 @@ func (o *Community) HasMember(pk *ecdsa.PublicKey) bool {
return o.hasMember(pk)
}

func (o *Community) isMemberInChat(pk *ecdsa.PublicKey, chatID string) bool {
return o.getChatMember(pk, chatID) != nil
}

func (o *Community) IsMemberInChat(pk *ecdsa.PublicKey, chatID string) bool {
o.mutex.Lock()
defer o.mutex.Unlock()

return o.getChatMember(pk, chatID) != nil
return o.isMemberInChat(pk, chatID)
}

// Uses bloom filter members list to estimate presence in the channel.
Expand Down Expand Up @@ -1915,6 +1919,15 @@ func (o *Community) ChannelEncrypted(channelID string) bool {
return o.channelEncrypted(channelID)
}

func (o *Community) HasMissingEncryptionKey(channelID string) bool {
o.mutex.Lock()
defer o.mutex.Unlock()

return o.channelEncrypted(channelID) &&
!o.isMemberInChat(o.MemberIdentity(), channelID) &&
o.IsMemberLikelyInChat(channelID)
}

func TokenPermissionsByType(permissions map[string]*CommunityTokenPermission, permissionType protobuf.CommunityTokenPermission_Type) []*CommunityTokenPermission {
result := make([]*CommunityTokenPermission, 0)
for _, tokenPermission := range permissions {
Expand Down
89 changes: 89 additions & 0 deletions protocol/communities/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -5159,3 +5159,92 @@ func (c *Community) ToStatusLinkPreview() (*common.StatusCommunityLinkPreview, e

return communityLinkPreview, nil
}

func (m *Manager) determineChannelsForHRKeysRequest(c *Community, now int64) ([]string, error) {
result := []string{}

channelsWithMissingKeys := func() map[string]struct{} {
r := map[string]struct{}{}
for id := range c.Chats() {
if c.HasMissingEncryptionKey(id) {
r[id] = struct{}{}
}
}
return r
}()

if len(channelsWithMissingKeys) == 0 {
return result, nil
}

requests, err := m.persistence.GetEncryptionKeyRequests(c.ID(), channelsWithMissingKeys)
if err != nil {
return nil, err
}

for channelID := range channelsWithMissingKeys {
request, ok := requests[channelID]
if !ok {
// If there's no prior request, ask for encryption key now
result = append(result, channelID)
continue
}

// Exponential backoff formula: initial delay * 2^(requestCount - 1)
initialDelay := int64(10 * 60 * 1000) // 10 minutes in milliseconds
backoffDuration := initialDelay * (1 << (request.requestedCount - 1))
nextRequestTime := request.requestedAt + backoffDuration

if now >= nextRequestTime {
result = append(result, channelID)
}
}

return result, nil
}

type CommunityWithChannelIDs struct {
Community *Community
ChannelIDs []string
}

// DetermineChannelsForHRKeysRequest identifies channels in a community that
// should ask for encryption keys based on their current state and past request records,
// as determined by exponential backoff.
func (m *Manager) DetermineChannelsForHRKeysRequest() ([]*CommunityWithChannelIDs, error) {
communities, err := m.Joined()
if err != nil {
return nil, err
}

result := []*CommunityWithChannelIDs{}
now := time.Now().UnixMilli()

for _, c := range communities {
if c.IsControlNode() {
continue
}

channelsToRequest, err := m.determineChannelsForHRKeysRequest(c, now)
if err != nil {
return nil, err
}

if len(channelsToRequest) > 0 {
result = append(result, &CommunityWithChannelIDs{
Community: c,
ChannelIDs: channelsToRequest,
})
}
}

return result, nil
}

func (m *Manager) updateEncryptionKeysRequests(communityID types.HexBytes, channelIDs []string, now int64) error {
return m.persistence.UpdateAndPruneEncryptionKeyRequests(communityID, channelIDs, now)
}

func (m *Manager) UpdateEncryptionKeysRequests(communityID types.HexBytes, channelIDs []string) error {
return m.updateEncryptionKeysRequests(communityID, channelIDs, time.Now().UnixMilli())
}
90 changes: 90 additions & 0 deletions protocol/communities/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2055,3 +2055,93 @@ func (s *ManagerSuite) TestFillMissingCommunityTokens() {
s.Require().NoError(err)
s.Require().Len(community.CommunityTokensMetadata(), 1)
}

func (s *ManagerSuite) TestDetermineChannelsForHRKeysRequest() {
request := &requests.CreateCommunity{
Name: "status",
Description: "token membership description",
Membership: protobuf.CommunityPermissions_AUTO_ACCEPT,
}

community, err := s.manager.CreateCommunity(request, true)
s.Require().NoError(err)
s.Require().NotNil(community)

channel := &protobuf.CommunityChat{
Members: map[string]*protobuf.CommunityMember{
common.PubkeyToHex(&s.manager.identity.PublicKey): {},
},
}

description := community.config.CommunityDescription
description.Chats = map[string]*protobuf.CommunityChat{}
description.Chats["channel-id"] = channel

// Simulate channel encrypted
_, err = community.UpsertTokenPermission(&protobuf.CommunityTokenPermission{
ChatIds: []string{ChatID(community.IDString(), "channel-id")},
})
s.Require().NoError(err)

err = generateBloomFiltersForChannels(description, s.manager.identity)
s.Require().NoError(err)

now := int64(1)
tenMinutes := int64(10 * 60 * 1000)

// Member does not have missing encryption keys
channels, err := s.manager.determineChannelsForHRKeysRequest(community, now)
s.Require().NoError(err)
s.Require().Empty(channels)

// Simulate missing encryption key
channel.Members = map[string]*protobuf.CommunityMember{}

// Channel without prior request should be returned
channels, err = s.manager.determineChannelsForHRKeysRequest(community, now)
s.Require().NoError(err)
s.Require().Len(channels, 1)
s.Require().Equal("channel-id", channels[0])

// Simulate encryption keys request
err = s.manager.updateEncryptionKeysRequests(community.ID(), []string{"channel-id"}, now)
s.Require().NoError(err)

// Channel with prior request should not be returned before backoff interval
channels, err = s.manager.determineChannelsForHRKeysRequest(community, now)
s.Require().NoError(err)
s.Require().Len(channels, 0)

// Channel with prior request should be returned only after backoff interval
channels, err = s.manager.determineChannelsForHRKeysRequest(community, now+tenMinutes)
s.Require().NoError(err)
s.Require().Len(channels, 1)
s.Require().Equal("channel-id", channels[0])

// Simulate multiple encryption keys request
err = s.manager.updateEncryptionKeysRequests(community.ID(), []string{"channel-id"}, now+tenMinutes)
s.Require().NoError(err)
err = s.manager.updateEncryptionKeysRequests(community.ID(), []string{"channel-id"}, now+2*tenMinutes)
s.Require().NoError(err)

// Channel with prior request should not be returned before backoff interval
channels, err = s.manager.determineChannelsForHRKeysRequest(community, now+2*tenMinutes)
s.Require().NoError(err)
s.Require().Len(channels, 0)

// Channel with prior request should be returned only after backoff interval
channels, err = s.manager.determineChannelsForHRKeysRequest(community, now+6*tenMinutes)
s.Require().NoError(err)
s.Require().Len(channels, 1)
s.Require().Equal("channel-id", channels[0])

// Simulate encryption key being received (it will remove request for given channel)
err = s.manager.updateEncryptionKeysRequests(community.ID(), []string{}, now)
s.Require().NoError(err)

// Channel without prior request should be returned
channels, err = s.manager.determineChannelsForHRKeysRequest(community, now)
s.Require().NoError(err)
s.Require().Len(channels, 1)
s.Require().Equal("channel-id", channels[0])
}
105 changes: 105 additions & 0 deletions protocol/communities/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ type CommunityRecordBundle struct {
installationID *string
}

type EncryptionKeysRequestRecord struct {
communityID []byte
channelID string
requestedAt int64
requestedCount uint
}

const OR = " OR "
const communitiesBaseQuery = `
SELECT
Expand Down Expand Up @@ -2092,3 +2099,101 @@ func (p *Persistence) GetCommunityRequestsToJoinRevealedAddresses(communityID []

return accounts, nil
}

func (p *Persistence) GetEncryptionKeyRequests(communityID []byte, channelIDs map[string]struct{}) (map[string]*EncryptionKeysRequestRecord, error) {
result := map[string]*EncryptionKeysRequestRecord{}

query := "SELECT channel_id, requested_at, requested_count FROM community_encryption_keys_requests WHERE community_id = ? AND channel_id IN (?" + strings.Repeat(",?", len(channelIDs)-1) + ")"

args := make([]interface{}, 0, len(channelIDs)+1)
args = append(args, communityID)
for channelID := range channelIDs {
args = append(args, channelID)
}

rows, err := p.db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()

for rows.Next() {
var channelID string
var requestedAt int64
var requestedCount uint
err := rows.Scan(&channelID, &requestedAt, &requestedCount)
if err != nil {
return nil, err
}
result[channelID] = &EncryptionKeysRequestRecord{
communityID: communityID,
channelID: channelID,
requestedAt: requestedAt,
requestedCount: requestedCount,
}
}

err = rows.Err()
if err != nil {
return nil, err
}

return result, nil
}

func (p *Persistence) UpdateAndPruneEncryptionKeyRequests(communityID types.HexBytes, channelIDs []string, requestedAt int64) error {
tx, err := p.db.Begin()
if err != nil {
return err
}

defer func() {
if err == nil {
err = tx.Commit()
return
}
// don't shadow original error
_ = tx.Rollback()
}()

if len(channelIDs) == 0 {
deleteQuery := "DELETE FROM community_encryption_keys_requests WHERE community_id = ?"
_, err = tx.Exec(deleteQuery, communityID)
return err
}

// Delete entries that do not match the channelIDs list
deleteQuery := "DELETE FROM community_encryption_keys_requests WHERE community_id = ? AND channel_id NOT IN (?" + strings.Repeat(",?", len(channelIDs)-1) + ")"
args := make([]interface{}, 0, len(channelIDs)+1)
args = append(args, communityID)
for _, channelID := range channelIDs {
args = append(args, channelID)
}
_, err = tx.Exec(deleteQuery, args...)
if err != nil {
return err
}

stmt, err := tx.Prepare(`
INSERT INTO community_encryption_keys_requests (community_id, channel_id, requested_at, requested_count)
VALUES (?, ?, ?, 1)
ON CONFLICT(community_id, channel_id)
DO UPDATE SET
requested_at = excluded.requested_at,
requested_count = community_encryption_keys_requests.requested_count + 1
WHERE excluded.requested_at > community_encryption_keys_requests.requested_at
`)
if err != nil {
return err
}
defer stmt.Close()

for _, channelID := range channelIDs {
_, err := stmt.Exec(communityID, channelID, requestedAt)
if err != nil {
return err
}
}

return nil
}
1 change: 1 addition & 0 deletions protocol/messenger.go
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,7 @@ func (m *Messenger) Start() (*MessengerResponse, error) {
}
m.startMessageSegmentsCleanupLoop()
m.startHashRatchetEncryptedMessagesCleanupLoop()
m.startRequestMissingCommunityChannelsHRKeysLoop()

if err := m.cleanTopics(); err != nil {
return nil, err
Expand Down
20 changes: 1 addition & 19 deletions protocol/messenger_backup_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,25 +393,7 @@ func (m *Messenger) requestCommunityKeysAndSharedAddresses(state *ReceivedMessag
}

if isEncrypted {
request := &protobuf.CommunityEncryptionKeysRequest{
CommunityId: syncCommunity.Id,
}

payload, err := proto.Marshal(request)
if err != nil {
return err
}

rawMessage := &common.RawMessage{
Payload: payload,
Sender: m.identity,
CommunityID: community.ID(),
SkipEncryptionLayer: true,
MessageType: protobuf.ApplicationMetadataMessage_COMMUNITY_ENCRYPTION_KEYS_REQUEST,
}

_, err = m.SendMessageToControlNode(community, rawMessage)

err = m.requestCommunityEncryptionKeys(community, nil)
if err != nil {
m.logger.Error("failed to request community encryption keys", zap.String("communityId", community.IDString()), zap.Error(err))
return err
Expand Down
Loading

0 comments on commit 9403475

Please sign in to comment.