Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(use any mls ciphersuite, add mls support on established client) #WPB-14877 #36

Merged
merged 2 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>com.wire.bots</groupId>
<artifactId>hold</artifactId>
<version>1.2.1</version>
<version>1.2.2</version>

<name>Legal Hold</name>
<description>Legal Hold Service For Wire</description>
Expand Down Expand Up @@ -67,7 +67,7 @@
<dependency>
<groupId>com.wire</groupId>
<artifactId>helium</artifactId>
<version>1.6.0</version>
<version>1.6.1</version>
</dependency>
<dependency>
<groupId>com.github.smoketurner</groupId>
Expand Down
10 changes: 6 additions & 4 deletions src/main/java/com/wire/bots/hold/DAO/AccessDAO.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@
import java.util.UUID;

public interface AccessDAO {
@SqlUpdate("INSERT INTO Access (userId, userDomain, clientId, cookie, updated, created, enabled) " +
"VALUES (:userId, :userDomain, :clientId, :cookie, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 1) " +
@SqlUpdate("INSERT INTO Access (userId, userDomain, clientId, cookie, mlsClientCreated, mlsCiphersuite, updated, created, enabled) " +
"VALUES (:userId, :userDomain, :clientId, :cookie, :mlsClientCreated, :mlsCiphersuite, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 1) " +
"ON CONFLICT (userId, userDomain) DO UPDATE SET cookie = EXCLUDED.cookie, clientId = EXCLUDED.clientId, " +
"updated = EXCLUDED.updated, enabled = EXCLUDED.enabled")
"updated = EXCLUDED.updated, enabled = EXCLUDED.enabled, mlsClientCreated = EXCLUDED.mlsClientCreated, mlsCiphersuite = EXCLUDED.mlsCiphersuite")
int insert(@Bind("userId") UUID userId,
@Bind("userDomain") String userDomain,
@Bind("clientId") String clientId,
@Bind("cookie") String cookie);
@Bind("cookie") String cookie,
@Bind("mlsClientCreated") Boolean mlsClientCreated,
@Bind("mlsCiphersuite") Integer mlsCiphersuite);

@SqlUpdate("UPDATE Access SET enabled = 0, updated = CURRENT_TIMESTAMP WHERE userId = :userId " +
"AND (( :userDomain IS NULL AND userDomain IS null ) or ( :userDomain IS NOT NULL AND userDomain = :userDomain ))")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ public LHAccess map(ResultSet rs, int columnNumber, StatementContext ctx) throws
LHAccess.clientId = rs.getString("clientId");
LHAccess.token = rs.getString("token");
LHAccess.cookie = rs.getString("cookie");
LHAccess.mlsClientCreated = rs.getBoolean("mlsClientCreated");
LHAccess.mlsCiphersuite = rs.getInt("mlsCiphersuite");
LHAccess.updated = rs.getString("updated");
LHAccess.created = rs.getString("created");
LHAccess.enabled = rs.getInt("enabled") == 1;
Expand Down
18 changes: 15 additions & 3 deletions src/main/java/com/wire/bots/hold/NotificationProcessor.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.wire.bots.hold.DAO.AccessDAO;
import com.wire.bots.hold.model.database.LHAccess;
import com.wire.bots.hold.service.DeviceManagementService;
import com.wire.bots.hold.utils.LoginClientExtension;
import com.wire.helium.API;
import com.wire.helium.models.Access;
Expand All @@ -26,11 +27,18 @@ public class NotificationProcessor implements Runnable {
private final Client client;
private final AccessDAO accessDAO;
private final HoldMessageResource messageResource;

NotificationProcessor(Client client, AccessDAO accessDAO, HoldMessageResource messageResource) {
private final DeviceManagementService deviceManagementService;

NotificationProcessor(
Client client,
AccessDAO accessDAO,
HoldMessageResource messageResource,
DeviceManagementService deviceManagementService)
{
this.client = client;
this.accessDAO = accessDAO;
this.messageResource = messageResource;
this.deviceManagementService = deviceManagementService;
}

@Override
Expand All @@ -53,8 +61,12 @@ private void process(LHAccess device) {

final Access access = LoginClientExtension.refreshToken(client, device.clientId, device.cookie);
accessDAO.update(device.userId.id, device.userId.domain, access.getAccessToken(), access.getCookie().value);

final API api = new API(client, null, access.getAccessToken());

if (!device.mlsClientCreated) {
deviceManagementService.configureMlsClient(device.userId, device.clientId, access.getCookie().value, api);
}

NotificationList notificationList = api.retrieveNotifications(
device.clientId,
device.last,
Expand Down
30 changes: 15 additions & 15 deletions src/main/java/com/wire/bots/hold/Service.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public static void main(String[] args) throws Exception {
@Override
public void initialize(Bootstrap<Config> bootstrap) {
bootstrap.setConfigurationSourceProvider(new SubstitutingSourceProvider(
bootstrap.getConfigurationSourceProvider(), new EnvironmentVariableSubstitutor(false)));
bootstrap.getConfigurationSourceProvider(), new EnvironmentVariableSubstitutor(false)));

bootstrap.addBundle(new SwaggerBundle<>() {
@Override
Expand Down Expand Up @@ -173,12 +173,12 @@ public void run(Config config, Environment environment) throws ExecutionExceptio
final HoldClientRepo repo = new HoldClientRepo(jdbi, cf, httpClient, config.coreCryptoPassword);

final HoldMessageResource holdMessageResource = new HoldMessageResource(new MessageHandler(jdbi), repo);
final NotificationProcessor notificationProcessor = new NotificationProcessor(httpClient, accessDAO, holdMessageResource);
final NotificationProcessor notificationProcessor = new NotificationProcessor(httpClient, accessDAO, holdMessageResource, deviceManagementService);

environment.lifecycle()
.scheduledExecutorService("notifications")
.build()
.scheduleWithFixedDelay(notificationProcessor, 10, config.sleep.toSeconds(), TimeUnit.SECONDS);
.scheduledExecutorService("notifications")
.build()
.scheduleWithFixedDelay(notificationProcessor, 10, config.sleep.toSeconds(), TimeUnit.SECONDS);

CollectorRegistry.defaultRegistry.register(new DropwizardExports(metrics));

Expand All @@ -199,24 +199,24 @@ protected void addResource(Object component) {

private Client createHttpClient(Config config, Environment env) {
return new JerseyClientBuilder(env)
.using(config.getJerseyClient())
.withProvider(MultiPartFeature.class)
.withProvider(JacksonJsonProvider.class)
.build(getName());
.using(config.getJerseyClient())
.withProvider(MultiPartFeature.class)
.withProvider(JacksonJsonProvider.class)
.build(getName());
}

protected Jdbi buildJdbi(Config.Database database, Environment env) {
return Jdbi
.create(database.build(env.metrics(), getName()))
.installPlugin(new SqlObjectPlugin());
.create(database.build(env.metrics(), getName()))
.installPlugin(new SqlObjectPlugin());
}

protected void setupDatabase(Config.Database database) {
Flyway flyway = Flyway
.configure()
.dataSource(database.getUrl(), database.getUser(), database.getPassword())
.baselineOnMigrate(database.baseline)
.load();
.configure()
.dataSource(database.getUrl(), database.getUser(), database.getPassword())
.baselineOnMigrate(database.baseline)
.load();
flyway.migrate();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ public class LHAccess {
public String clientId;
public String token;
public String cookie;
public boolean mlsClientCreated;
public Integer mlsCiphersuite;
public String updated;
public String created;
public boolean enabled;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ public Response confirm(@ApiParam @Valid @NotNull ConfirmPayloadV0 payload) {
try {
deviceManagementService.confirmDevice(
new QualifiedId(payload.userId, null),
payload.teamId,
payload.clientId,
payload.refreshToken
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ public RemoveResourceV0(DeviceManagementService deviceManagementService) {
public Response remove(@ApiParam @Valid InitPayloadV0 payload) {
try {
deviceManagementService.removeDevice(
new QualifiedId(payload.userId, null),
payload.teamId
new QualifiedId(payload.userId, null)
);

return Response.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ public Response confirm(@ApiParam @Valid @NotNull ConfirmPayloadV1 payload) {
try {
deviceManagementService.confirmDevice(
payload.userId,
payload.teamId,
payload.clientId,
payload.refreshToken
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public RemoveResourceV1(DeviceManagementService deviceManagementService) {
@ApiResponse(code = 200, message = "Legal Hold Device was removed")})
public Response remove(@ApiParam @Valid InitPayloadV1 payload) {
try {
deviceManagementService.removeDevice(payload.userId, payload.teamId);
deviceManagementService.removeDevice(payload.userId);

return Response
.ok()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.wire.helium.models.Access;
import com.wire.xenon.WireClientBase;
import com.wire.xenon.backend.models.Conversation;
import com.wire.xenon.backend.models.FeatureConfig;
import com.wire.xenon.backend.models.QualifiedId;
import com.wire.xenon.crypto.Crypto;
import com.wire.xenon.crypto.mls.CryptoMlsClient;
Expand Down Expand Up @@ -74,7 +75,24 @@ public InitializedDeviceDTO initiateLegalHoldDevice(QualifiedId userId, UUID tea
}

/**
* <p>Confirm a user's device under legal hold.</p>
* Confirm a user's device under legal hold. Client authentication performed using refreshToken.
* @param userId user setup to be put under legal hold
* @param clientId user's device
* @param refreshToken cookie token from the database, used to get new access token and cookie
* @throws RuntimeException when any of the (parallel or not) MLS tasks fails. Or if inserting refreshToken to Database fails.
*/
public void confirmDevice(QualifiedId userId, String clientId, String refreshToken) throws RuntimeException {
final Access access = LoginClientExtension.refreshToken(client, clientId, refreshToken);
API api = new API(client, null, access.accessToken);
boolean mlsClientCreated = configureMlsClient(userId, clientId, access.getCookie().value, api);

if (!mlsClientCreated) {
// If MLS client was added storing the client with MLS data is done already, else store only Proteus data
storeProteusOnlyDevice(userId, clientId, access.getCookie().value);
}
}
/**
* Confirm a user's device under legal hold. Client authentication is done beforehand
spoonman01 marked this conversation as resolved.
Show resolved Hide resolved
* <p>
* If MLS is enabled, then we initialize CryptoMlsClient and WireClientBase, then fetch and upload in parallel:
* - PublicKeys
Expand All @@ -87,17 +105,18 @@ public InitializedDeviceDTO initiateLegalHoldDevice(QualifiedId userId, UUID tea
* Stores the refreshToken in order to fetch user notifications while under legal hold
* </p>
* @param userId user setup to be put under legal hold
* @param teamId user's own team
* @param clientId user's device
* @param refreshToken token used to get expiring api tokens
* @param cookie new cookie token used to get access tokens
* @param api object to interact with Wire API, setup for the client with a valid access token
* @throws RuntimeException when any of the (parallel or not) MLS tasks fails. Or if inserting refreshToken to Database fails.
* @return true if MLS client was added, false otherwise
*/
public void confirmDevice(QualifiedId userId, UUID teamId, String clientId, String refreshToken) throws RuntimeException {
final Access access = LoginClientExtension.refreshToken(client, clientId, refreshToken);
API api = new API(client, null, access.accessToken);
public boolean configureMlsClient(QualifiedId userId, String clientId, String cookie, API api) throws RuntimeException {
final FeatureConfig mlsConfig = api.getFeatureConfig();

if (api.isMlsEnabled()) {
try (CryptoMlsClient cryptoMlsClient = new CryptoMlsClient(clientId, userId, coreCryptoPassword)) {
if (mlsConfig.mls.isMlsStatusEnabled()) {
Logger.info("MLS is enabled for user %s, configuring client and joining conversations", userId);
try (CryptoMlsClient cryptoMlsClient = new CryptoMlsClient(clientId, userId, mlsConfig.mls.config.defaultCipherSuite, coreCryptoPassword)) {
// CryptoMlsClient will be closed from `try` with resource so there is no issue passing
// Crypto as null, as we will not be calling wireClientBase.close()
WireClientBase wireClientBase = new WireClientBase(api, null, cryptoMlsClient, null);
Expand Down Expand Up @@ -133,47 +152,54 @@ public void confirmDevice(QualifiedId userId, UUID teamId, String clientId, Stri
Logger.info("Conversation ID: %s, Name: %s, GroupId: %s", conversation.id, conversation.name, conversation.mlsGroupId);
wireClientBase.joinMlsConversation(conversation.id, conversation.mlsGroupId);
}
storeProteusAndMlsDevice(userId, clientId, cookie, mlsConfig.mls.config.defaultCipherSuite);
return true;
} catch (ExecutionException exception) {
throw new RuntimeException("ExecutionException: " + exception.getCause().getMessage());
} catch (InterruptedException exception) {
throw new RuntimeException("InterruptedException: " + exception.getMessage());
}
}
return false;
}

// Proteus
private void storeProteusAndMlsDevice(QualifiedId userId, String clientId, String cookie, Integer mlsCiphersuite) {
storeDevice(userId, clientId, cookie, true, mlsCiphersuite);
}

private void storeProteusOnlyDevice(QualifiedId userId, String clientId, String cookie) {
storeDevice(userId, clientId, cookie, false, null);
}

private void storeDevice(QualifiedId userId, String clientId, String cookie, Boolean mlsClientAdded, Integer mlsCiphersuite) {
int insert = accessDAO.insert(userId.id,
userId.domain,
clientId,
access.getCookie().value);
cookie,
mlsClientAdded,
mlsCiphersuite);

if (0 == insert) {
Logger.error("ConfirmResource: Failed to insert Access %s:%s",
userId,
clientId);

Logger.error("ConfirmResource: Failed to insert Access %s:%s", userId, clientId);
throw new RuntimeException("Cannot insert new device");
}

Logger.info("ConfirmResource: team: %s, user:%s, client: %s",
teamId,
userId,
clientId);
Logger.info("ConfirmResource: user:%s, client: %s", userId, clientId);
}

/**
* Remove a user from legal hold.
* <p>Before soft-deleting on the database, cleans up Proteus data and MLS data if it was ever created</p>
* @param userId user setup to be put under legal hold
* @param teamId user's own team
* @throws IOException
* @throws CryptoException
*/
public void removeDevice(QualifiedId userId, UUID teamId) throws IOException, CryptoException {
public void removeDevice(QualifiedId userId) throws IOException, CryptoException {
// MLS
LHAccess userAccess = accessDAO.get(userId.id, userId.domain);
if (userAccess != null) {
API api = new API(client, null, userAccess.token);
if (api.isMlsEnabled()) {
try (CryptoMlsClient cryptoMlsClient = new CryptoMlsClient(userAccess.clientId, userId, coreCryptoPassword)) {
if (userAccess.mlsClientCreated) {
try (CryptoMlsClient cryptoMlsClient = new CryptoMlsClient(userAccess.clientId, userId, userAccess.mlsCiphersuite, coreCryptoPassword)) {
cryptoMlsClient.wipe();
}
}
Expand All @@ -182,18 +208,12 @@ public void removeDevice(QualifiedId userId, UUID teamId) throws IOException, Cr
// Proteus
try (Crypto crypto = cf.create(userId)) {
crypto.purge();

int removeAccess = accessDAO.disable(userId.id, userId.domain);

Logger.info(
"RemoveResource: team: %s, user: %s, removed: %s",
teamId,
userId,
removeAccess
);
} catch (Exception e) {
Logger.exception(e, "RemoveLegalHoldDevice: %s", e.getMessage());
throw e;
}

int removeAccess = accessDAO.disable(userId.id, userId.domain);
Logger.info("RemoveResource: user: %s, removed: %s", userId, removeAccess);
}
}
4 changes: 2 additions & 2 deletions src/main/java/com/wire/bots/hold/utils/HoldClientRepo.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ public WireClient getClient(QualifiedId userId, String deviceId, QualifiedId con
final LHAccess single = jdbi.onDemand(AccessDAO.class).get(userId.id, userId.domain);
final API api = new API(httpClient, conversationId, single.token);

// for receiving notifications
final CryptoMlsClient cryptoMlsClient = new CryptoMlsClient(deviceId, userId, coreCryptoPassword);
// Create MLS client only if device was already initialized for MLS
final CryptoMlsClient cryptoMlsClient = single.mlsClientCreated ? new CryptoMlsClient(deviceId, userId, single.mlsCiphersuite, coreCryptoPassword) : null;
return new HoldWireClient(userId, deviceId, conversationId, cryptoMlsClient, crypto, api);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
ALTER TABLE Access
ADD COLUMN mlsClientCreated BOOLEAN NOT NULL DEFAULT false;

ALTER TABLE Access
ADD COLUMN mlsCiphersuite SMALLINT DEFAULT null;
8 changes: 6 additions & 2 deletions src/test/java/com/wire/bots/hold/DatabaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import java.util.*;

import static com.wire.bots.hold.utils.Constant.DEFAULT_CIPHERSUITE_IDENTIFIER;

public class DatabaseTest {
private static final DropwizardTestSupport<Config> SUPPORT = new DropwizardTestSupport<>(
Service.class, "hold.yaml",
Expand Down Expand Up @@ -104,7 +106,7 @@ public void accessTests() {
final String cookie2 = "cookie2";
final String token = "token";

final int insert = accessDAO.insert(userId.id, userId.domain, clientId, cookie);
final int insert = accessDAO.insert(userId.id, userId.domain, clientId, cookie, false, null);
accessDAO.updateLast(userId.id, userId.domain, last);
accessDAO.update(userId.id, userId.domain, token, cookie2);

Expand All @@ -114,10 +116,12 @@ public void accessTests() {

accessDAO.disable(userId.id, userId.domain);

final int insert2 = accessDAO.insert(userId.id, userId.domain, clientId, cookie);
final int insert2 = accessDAO.insert(userId.id, userId.domain, clientId, cookie, true, DEFAULT_CIPHERSUITE_IDENTIFIER);
final LHAccess lhAccess2 = accessDAO.get(userId.id, userId.domain);
assert lhAccess2 != null;
assert lhAccess2.created.equals(lhAccess.created);
assert lhAccess2.mlsClientCreated;
assert lhAccess2.mlsCiphersuite.equals(DEFAULT_CIPHERSUITE_IDENTIFIER);
}

@Test
Expand Down
Loading
Loading