Skip to content

Commit

Permalink
add a bunch of code to help with threading
Browse files Browse the repository at this point in the history
  • Loading branch information
nplasterer committed Sep 20, 2023
1 parent e4fcf71 commit c20814a
Showing 1 changed file with 88 additions and 46 deletions.
134 changes: 88 additions & 46 deletions ios/XMTPModule.swift
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,52 @@ extension Conversation {
}

public class XMTPModule: Module {
@MainActor var clients: [String: XMTP.Client] = [:]
var signer: ReactNativeSigner?
@MainActor var conversations: [String: Conversation] = [:]
@MainActor var subscriptions: [String: Task<Void, Never>] = [:]
let clientsManager = ClientsManager()
let conversationsManager = ConversationsManager()
let subscriptionsManager = SubscriptionsManager()

actor ClientsManager {
private var clients: [String: XMTP.Client] = [:]

// A method to update the conversations
func updateClient(key: String, client: XMTP.Client?) {
clients[key] = client
}

// A method to retrieve a conversation
func getClient(key: String) -> XMTP.Client? {
return clients[key]
}
}

actor ConversationsManager {
private var conversations: [String: Conversation] = [:]

// A method to update the conversations
func updateConversation(key: String, conversation: Conversation?) {
conversations[key] = conversation
}

// A method to retrieve a conversation
func getConversation(key: String) -> Conversation? {
return conversations[key]
}
}

actor SubscriptionsManager {
private var subscriptions: [String: Task<Void, Never>] = [:]

// A method to update the subscriptions
func updateSubscription(key: String, task: Task<Void, Never>?) {
subscriptions[key] = task
}

// A method to retrieve a subscription
func getSubscription(key: String) -> Task<Void, Never>? {
return subscriptions[key]
}
}

enum Error: Swift.Error {
case noClient, conversationNotFound(String), noMessage, invalidKeyBundle, invalidDigest, badPreparation(String)
Expand All @@ -86,7 +128,7 @@ public class XMTPModule: Module {
Events("sign", "authed", "conversation", "message")

AsyncFunction("address") { (clientAddress: String) -> String in
if let client = await clients[clientAddress] {
if let client = await clientsManager.getClient(key: clientAddress) {
return client.address
} else {
return "No Client."
Expand All @@ -100,7 +142,7 @@ public class XMTPModule: Module {
let signer = ReactNativeSigner(module: self, address: address)
self.signer = signer
let options = createClientConfig(env: environment, appVersion: appVersion)
await self.clients[address] = try await XMTP.Client.create(account: signer, options: options)
await clientsManager.updateClient(key: address, client: try await XMTP.Client.create(account: signer, options: options))
self.signer = nil
sendEvent("authed")
}
Expand All @@ -115,7 +157,7 @@ public class XMTPModule: Module {
let options = createClientConfig(env: environment, appVersion: appVersion)
let client = try await Client.create(account: privateKey, options: options)

await self.clients[client.address] = client
await clientsManager.updateClient(key: client.address, client: client)
return client.address
}

Expand All @@ -129,7 +171,7 @@ public class XMTPModule: Module {

let options = createClientConfig(env: environment, appVersion: appVersion)
let client = try await Client.from(bundle: bundle, options: options)
await self.clients[client.address] = client
await clientsManager.updateClient(key: client.address, client: client)
return client.address
} catch {
print("ERRO! Failed to create client: \(error)")
Expand All @@ -139,7 +181,7 @@ public class XMTPModule: Module {

// Export the client's serialized key bundle.
AsyncFunction("exportKeyBundle") { (clientAddress: String) -> String in
guard let client = await clients[clientAddress] else {
guard let client = await clientsManager.getClient(key: clientAddress) else {
throw Error.noClient
}
let bundle = try client.privateKeyBundle.serializedData().base64EncodedString()
Expand All @@ -156,29 +198,29 @@ public class XMTPModule: Module {

// Import a conversation from its serialized topic data.
AsyncFunction("importConversationTopicData") { (clientAddress: String, topicData: String) -> String in
guard let client = await clients[clientAddress] else {
guard let client = await clientsManager.getClient(key: clientAddress) else {
throw Error.noClient
}
let data = try Xmtp_KeystoreApi_V1_TopicMap.TopicData(
serializedData: Data(base64Encoded: Data(topicData.utf8))!
)
let conversation = client.conversations.importTopicData(data: data)
await conversations[conversation.cacheKey(clientAddress)] = conversation
let conversation = try await client.conversations.importTopicData(data: data)
await conversationsManager.updateConversation(key: conversation.cacheKey(clientAddress), conversation: conversation)
return try ConversationWrapper.encode(conversation, client: client)
}

//
// Client API
AsyncFunction("canMessage") { (clientAddress: String, peerAddress: String) -> Bool in
guard let client = await clients[clientAddress] else {
guard let client = await clientsManager.getClient(key: clientAddress) else {
throw Error.noClient
}

return try await client.canMessage(peerAddress)
}

AsyncFunction("encryptAttachment") { (clientAddress: String, fileJson: String) -> String in
if await clients[clientAddress] == nil {
if await clientsManager.getClient(key: clientAddress) == nil {
throw Error.noClient
}
let file = try DecryptedLocalAttachment.fromJson(fileJson)
Expand All @@ -204,7 +246,7 @@ public class XMTPModule: Module {
}

AsyncFunction("decryptAttachment") { (clientAddress: String, encryptedFileJson: String) -> String in
if await clients[clientAddress] == nil {
if await clientsManager.getClient(key: clientAddress) == nil {
throw Error.noClient
}
let encryptedFile = try EncryptedLocalAttachment.fromJson(encryptedFileJson)
Expand All @@ -229,14 +271,14 @@ public class XMTPModule: Module {
}

AsyncFunction("listConversations") { (clientAddress: String) -> [String] in
guard let client = await clients[clientAddress] else {
guard let client = await clientsManager.getClient(key: clientAddress) else {
throw Error.noClient
}

let conversations = try await client.conversations.list()

return try conversations.map { conversation in
await self.conversations[conversation.cacheKey(clientAddress)] = conversation
await conversationsManager.updateConversation(key: conversation.cacheKey(clientAddress), conversation: conversation)

return try ConversationWrapper.encode(conversation, client: client)
}
Expand Down Expand Up @@ -270,7 +312,7 @@ public class XMTPModule: Module {
}

AsyncFunction("loadBatchMessages") { (clientAddress: String, topics: [String]) -> [String] in
guard let client = await clients[clientAddress] else {
guard let client = await clientsManager.getClient(key: clientAddress) else {
throw Error.noClient
}

Expand Down Expand Up @@ -363,7 +405,7 @@ public class XMTPModule: Module {
}

AsyncFunction("sendPreparedMessage") { (clientAddress: String, preparedLocalMessageJson: String) -> String in
guard let client = await clients[clientAddress] else {
guard let client = await clientsManager.getClient(key: clientAddress) else {
throw Error.noClient
}
guard let local = try? PreparedLocalMessage.fromJson(preparedLocalMessageJson) else {
Expand All @@ -386,7 +428,7 @@ public class XMTPModule: Module {
}

AsyncFunction("createConversation") { (clientAddress: String, peerAddress: String, contextJson: String) -> String in
guard let client = await clients[clientAddress] else {
guard let client = await clientsManager.getClient(key: clientAddress) else {
throw Error.noClient
}

Expand All @@ -405,24 +447,24 @@ public class XMTPModule: Module {
}
}

Function("subscribeToConversations") { (clientAddress: String) in
subscribeToConversations(clientAddress: clientAddress)
AsyncFunction("subscribeToConversations") { (clientAddress: String) in
try await subscribeToConversations(clientAddress: clientAddress)
}

Function("subscribeToAllMessages") { (clientAddress: String) in
subscribeToAllMessages(clientAddress: clientAddress)
AsyncFunction("subscribeToAllMessages") { (clientAddress: String) in
try await subscribeToAllMessages(clientAddress: clientAddress)
}

AsyncFunction("subscribeToMessages") { (clientAddress: String, topic: String) in
try await subscribeToMessages(clientAddress: clientAddress, topic: topic)
}

AsyncFunction("unsubscribeFromConversations") {
await subscriptions["conversations"]?.cancel()
await subscriptionsManager.getSubscription(key: "conversations")?.cancel()
}

AsyncFunction("unsubscribeFromAllMessages") {
await subscriptions["messages"]?.cancel()
await subscriptionsManager.getSubscription(key: "messages")?.cancel()
}

AsyncFunction("unsubscribeFromMessages") { (clientAddress: String, topic: String) in
Expand Down Expand Up @@ -494,49 +536,49 @@ public class XMTPModule: Module {
}

func findConversation(clientAddress: String, topic: String) async throws -> Conversation? {
guard let client = await clients[clientAddress] else {
guard let client = await clientsManager.getClient(key: clientAddress) else {
throw Error.noClient
}

let cacheKey = Conversation.cacheKeyForTopic(clientAddress: clientAddress, topic: topic)
if let conversation = await conversations[cacheKey] {
if let conversation = await conversationsManager.getConversation(key: cacheKey) {
return conversation
} else if let conversation = try await client.conversations.list().first(where: { $0.topic == topic }) {
await conversations[cacheKey] = conversation
await conversationsManager.updateConversation(key: cacheKey, conversation: conversation)
return conversation
}

return nil
}

func subscribeToConversations(clientAddress: String) {
guard let client = await clients[clientAddress] else {
func subscribeToConversations(clientAddress: String) async throws {
guard let client = await clientsManager.getClient(key: clientAddress) else {
return
}

await subscriptions["conversations"]?.cancel()
await subscriptions["conversations"] = Task {
await subscriptionsManager.getSubscription(key: "conversations")?.cancel()
await subscriptionsManager.updateSubscription(key: "conversations", task: Task {
do {
for try await conversation in client.conversations.stream() {
for try await conversation in try await client.conversations.stream() {
sendEvent("conversation", [
"clientAddress": clientAddress,
"conversation": try ConversationWrapper.encodeToObj(conversation, client: client)
])
}
} catch {
print("Error in conversations subscription: \(error)")
await subscriptions["conversations"]?.cancel()
await subscriptionsManager.getSubscription(key: "conversations")?.cancel()
}
}
})
}

func subscribeToAllMessages(clientAddress: String) {
guard let client = await clients[clientAddress] else {
func subscribeToAllMessages(clientAddress: String) async throws {
guard let client = await clientsManager.getClient(key: clientAddress) else {
return
}

await subscriptions["messages"]?.cancel()
await subscriptions["messages"] = Task {
await subscriptionsManager.getSubscription(key: "messages")?.cancel()
await subscriptionsManager.updateSubscription(key: "messages", task: Task {
do {
for try await message in try await client.conversations.streamAllMessages() {
do {
Expand All @@ -550,18 +592,18 @@ public class XMTPModule: Module {
}
} catch {
print("Error in all messages subscription: \(error)")
await subscriptions["messages"]?.cancel()
await subscriptionsManager.getSubscription(key: "messages")?.cancel()
}
}
})
}

func subscribeToMessages(clientAddress: String, topic: String) async throws {
guard let conversation = try await findConversation(clientAddress: clientAddress, topic: topic) else {
return
}

await subscriptions[conversation.cacheKey(clientAddress)]?.cancel()
await subscriptions[conversation.cacheKey(clientAddress)] = Task {
await subscriptionsManager.getSubscription(key: conversation.cacheKey(clientAddress))?.cancel()
await subscriptionsManager.updateSubscription(key: conversation.cacheKey(clientAddress), task: Task {
do {
for try await message in conversation.streamMessages() {
do {
Expand All @@ -575,16 +617,16 @@ public class XMTPModule: Module {
}
} catch {
print("Error in messages subscription: \(error)")
await subscriptions[conversation.cacheKey(clientAddress)]?.cancel()
await subscriptionsManager.getSubscription(key: conversation.cacheKey(clientAddress))?.cancel()
}
}
})
}

func unsubscribeFromMessages(clientAddress: String, topic: String) async throws {
guard let conversation = try await findConversation(clientAddress: clientAddress, topic: topic) else {
return
}

await subscriptions[conversation.cacheKey(clientAddress)]?.cancel()
await subscriptionsManager.getSubscription(key: conversation.cacheKey(clientAddress))?.cancel()
}
}

0 comments on commit c20814a

Please sign in to comment.