Skip to content

Commit

Permalink
fix: auto tunnel bugs
Browse files Browse the repository at this point in the history
Fixes auto tunnel bug that can happen on startup
Fixes auto tunnel bug that didn't allow manual toggle override of tunnel while auto tunnel is active
Add basic foreground persistent notifications (optionally turned off via android settings)
  • Loading branch information
zaneschepke committed Nov 8, 2024
1 parent 7d810c7 commit cab2945
Show file tree
Hide file tree
Showing 12 changed files with 100 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,11 @@ annotation class Kernel
@Qualifier
@Retention(AnnotationRetention.BINARY)
annotation class Userspace

@Qualifier
@Retention(AnnotationRetention.BINARY)
annotation class TunnelShell

@Qualifier
@Retention(AnnotationRetention.BINARY)
annotation class AppShell
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,18 @@ import javax.inject.Singleton
@Module
@InstallIn(SingletonComponent::class)
class TunnelModule {

@Provides
@Singleton
@TunnelShell
fun provideTunnelRootShell(@ApplicationContext context: Context): RootShell {
return RootShell(context)
}

@Provides
@Singleton
fun provideRootShell(@ApplicationContext context: Context): RootShell {
@AppShell
fun provideAppRootShell(@ApplicationContext context: Context): RootShell {
return RootShell(context)
}

Expand All @@ -40,14 +49,14 @@ class TunnelModule {
@Provides
@Singleton
@Userspace
fun provideUserspaceBackend(@ApplicationContext context: Context, rootShell: RootShell): Backend {
fun provideUserspaceBackend(@ApplicationContext context: Context, @TunnelShell rootShell: RootShell): Backend {
return GoBackend(context, RootTunnelActionHandler(rootShell))
}

@Provides
@Singleton
@Kernel
fun provideKernelBackend(@ApplicationContext context: Context, rootShell: RootShell): Backend {
fun provideKernelBackend(@ApplicationContext context: Context, @TunnelShell rootShell: RootShell): Backend {
return WgQuickBackend(context, rootShell, ToolsInstaller(context, rootShell), RootTunnelActionHandler(rootShell))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import com.zaneschepke.wireguardautotunnel.R
import com.zaneschepke.wireguardautotunnel.data.domain.Settings
import com.zaneschepke.wireguardautotunnel.data.domain.TunnelConfig
import com.zaneschepke.wireguardautotunnel.data.repository.AppDataRepository
import com.zaneschepke.wireguardautotunnel.module.AppShell
import com.zaneschepke.wireguardautotunnel.module.IoDispatcher
import com.zaneschepke.wireguardautotunnel.module.MainImmediateDispatcher
import com.zaneschepke.wireguardautotunnel.service.network.EthernetService
Expand All @@ -25,7 +26,6 @@ import com.zaneschepke.wireguardautotunnel.service.tunnel.TunnelState
import com.zaneschepke.wireguardautotunnel.util.Constants
import com.zaneschepke.wireguardautotunnel.util.extensions.cancelWithMessage
import com.zaneschepke.wireguardautotunnel.util.extensions.getCurrentWifiName
import com.zaneschepke.wireguardautotunnel.util.extensions.isDown
import com.zaneschepke.wireguardautotunnel.util.extensions.isReachable
import com.zaneschepke.wireguardautotunnel.util.extensions.onNotRunning
import dagger.hilt.android.AndroidEntryPoint
Expand All @@ -34,7 +34,6 @@ import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.collectLatest
import kotlinx.coroutines.flow.combine
import kotlinx.coroutines.flow.distinctUntilChanged
import kotlinx.coroutines.flow.update
Expand All @@ -50,6 +49,7 @@ class AutoTunnelService : LifecycleService() {
private val foregroundId = 122

@Inject
@AppShell
lateinit var rootShell: Provider<RootShell>

@Inject
Expand Down Expand Up @@ -143,7 +143,7 @@ class AutoTunnelService : LifecycleService() {
super.onDestroy()
}

private fun launchWatcherNotification(description: String = getString(R.string.watcher_notification_text_active)) {
private fun launchWatcherNotification(description: String = getString(R.string.monitoring_state_changes)) {
val notification =
notificationService.createNotification(
channelId = getString(R.string.watcher_channel_id),
Expand Down Expand Up @@ -209,28 +209,16 @@ class AutoTunnelService : LifecycleService() {
when (status) {
is NetworkStatus.Available -> {
Timber.i("Gained Mobile data connection")
autoTunnelStateFlow.update {
it.copy(
isMobileDataConnected = true,
)
}
emitMobileDataConnected(true)
}

is NetworkStatus.CapabilitiesChanged -> {
autoTunnelStateFlow.update {
it.copy(
isMobileDataConnected = true,
)
}
emitMobileDataConnected(true)
Timber.i("Mobile data capabilities changed")
}

is NetworkStatus.Unavailable -> {
autoTunnelStateFlow.update {
it.copy(
isMobileDataConnected = false,
)
}
emitMobileDataConnected(false)
Timber.i("Lost mobile data connection")
}
}
Expand Down Expand Up @@ -283,6 +271,7 @@ class AutoTunnelService : LifecycleService() {
old.map { it.isActive } != new.map { it.isActive }
},
) { settings, tunnels ->
Timber.d("Tunnels or settings changed!")
autoTunnelStateFlow.value.copy(
settings = settings,
tunnels = tunnels,
Expand All @@ -299,7 +288,7 @@ class AutoTunnelService : LifecycleService() {
Timber.i("Starting vpn state watcher")
withContext(ioDispatcher) {
tunnelService.get().vpnState.distinctUntilChanged { old, new ->
old.tunnelConfig == new.tunnelConfig && old.status == new.status
old.tunnelConfig?.id == new.tunnelConfig?.id
}.collect { state ->
autoTunnelStateFlow.update {
it.copy(vpnState = state)
Expand Down Expand Up @@ -367,39 +356,55 @@ class AutoTunnelService : LifecycleService() {
mobileDataJob = null
}

private fun updateEthernet(connected: Boolean) {
private fun emitEthernetConnected(connected: Boolean) {
autoTunnelStateFlow.update {
it.copy(
isEthernetConnected = connected,
)
}
}

private fun updateWifi(connected: Boolean) {
private fun emitWifiConnected(connected: Boolean) {
autoTunnelStateFlow.update {
it.copy(
isWifiConnected = connected,
)
}
}

private fun emitWifiSSID(ssid: String) {
autoTunnelStateFlow.update {
it.copy(
currentNetworkSSID = ssid,
)
}
}

private fun emitMobileDataConnected(connected: Boolean) {
autoTunnelStateFlow.update {
it.copy(
isMobileDataConnected = connected,
)
}
}

private suspend fun watchForEthernetConnectivityChanges() {
withContext(ioDispatcher) {
Timber.i("Starting ethernet data watcher")
ethernetService.networkStatus.collect { status ->
when (status) {
is NetworkStatus.Available -> {
Timber.i("Gained Ethernet connection")
updateEthernet(true)
emitEthernetConnected(true)
}

is NetworkStatus.CapabilitiesChanged -> {
Timber.i("Ethernet capabilities changed")
updateEthernet(true)
emitEthernetConnected(true)
}

is NetworkStatus.Unavailable -> {
updateEthernet(false)
emitEthernetConnected(false)
Timber.i("Lost Ethernet connection")
}
}
Expand All @@ -414,12 +419,12 @@ class AutoTunnelService : LifecycleService() {
when (status) {
is NetworkStatus.Available -> {
Timber.i("Gained Wi-Fi connection")
updateWifi(true)
emitWifiConnected(true)
}

is NetworkStatus.CapabilitiesChanged -> {
Timber.i("Wifi capabilities changed")
updateWifi(true)
emitWifiConnected(true)
val ssid = getWifiSSID(status.networkCapabilities)
ssid?.let { name ->
if (name.contains(Constants.UNREADABLE_SSID)) {
Expand All @@ -428,16 +433,12 @@ class AutoTunnelService : LifecycleService() {
Timber.i("Detected valid SSID")
}
appDataRepository.appState.setCurrentSsid(name)
autoTunnelStateFlow.update {
it.copy(
currentNetworkSSID = name,
)
}
emitWifiSSID(name)
} ?: Timber.w("Failed to read ssid")
}

is NetworkStatus.Unavailable -> {
updateWifi(false)
emitWifiConnected(false)
Timber.i("Lost Wi-Fi connection")
}
}
Expand All @@ -461,17 +462,17 @@ class AutoTunnelService : LifecycleService() {
private suspend fun handleNetworkEventChanges() {
withContext(ioDispatcher) {
Timber.i("Starting network event watcher")
autoTunnelStateFlow.collectLatest { watcherState ->
autoTunnelStateFlow.collect { watcherState ->
val autoTunnel = "Auto-tunnel watcher"
Timber.d("New watcher state!")
// delay for rapid network state changes and then collect latest
delay(Constants.WATCHER_COLLECTION_DELAY)
val activeTunnel = watcherState.vpnState.tunnelConfig
val defaultTunnel = appDataRepository.getPrimaryOrFirstTunnel()
val isTunnelDown = tunnelService.get().getState() == TunnelState.DOWN
when {
watcherState.isEthernetConditionMet() -> {
Timber.i("$autoTunnel - tunnel on on ethernet condition met")
if (watcherState.vpnState.isDown()) {
if (isTunnelDown) {
defaultTunnel?.let {
tunnelService.get().startTunnel(it)
}
Expand All @@ -483,7 +484,7 @@ class AutoTunnelService : LifecycleService() {
val mobileDataTunnel = getMobileDataTunnel()
val tunnel =
mobileDataTunnel ?: defaultTunnel
if (watcherState.vpnState.isDown() || activeTunnel?.isMobileDataTunnel == false) {
if (isTunnelDown || activeTunnel?.isMobileDataTunnel == false) {
tunnel?.let {
tunnelService.get().startTunnel(it)
}
Expand All @@ -492,7 +493,7 @@ class AutoTunnelService : LifecycleService() {

watcherState.isTunnelOffOnMobileDataConditionMet() -> {
Timber.i("$autoTunnel - tunnel off on mobile data met, turning vpn off")
if (!watcherState.vpnState.isDown()) {
if (!isTunnelDown) {
activeTunnel?.let {
tunnelService.get().stopTunnel(it)
}
Expand All @@ -502,20 +503,20 @@ class AutoTunnelService : LifecycleService() {
watcherState.isUntrustedWifiConditionMet() -> {
Timber.i("Untrusted wifi condition met")
if (activeTunnel == null || watcherState.isCurrentSSIDActiveTunnelNetwork() == false ||
watcherState.vpnState.isDown()
isTunnelDown
) {
Timber.i(
"$autoTunnel - tunnel on ssid not associated with current tunnel condition met",
)
watcherState.getTunnelWithMatchingTunnelNetwork()?.let {
Timber.i("Found tunnel associated with this SSID, bringing tunnel up: ${it.name}")
if (watcherState.vpnState.isDown() || activeTunnel?.id != it.id) {
if (isTunnelDown || activeTunnel?.id != it.id) {
tunnelService.get().startTunnel(it)
}
} ?: suspend {
Timber.i("No tunnel associated with this SSID, using defaults")
val default = appDataRepository.getPrimaryOrFirstTunnel()
if (default?.name != tunnelService.get().name || watcherState.vpnState.isDown()) {
if (default?.name != tunnelService.get().name || isTunnelDown) {
default?.let {
tunnelService.get().startTunnel(it)
}
Expand All @@ -528,22 +529,22 @@ class AutoTunnelService : LifecycleService() {
Timber.i(
"$autoTunnel - tunnel off on trusted wifi condition met, turning vpn off",
)
if (!watcherState.vpnState.isDown()) activeTunnel?.let { tunnelService.get().stopTunnel(it) }
if (!isTunnelDown) activeTunnel?.let { tunnelService.get().stopTunnel(it) }
}

watcherState.isTunnelOffOnWifiConditionMet() -> {
Timber.i(
"$autoTunnel - tunnel off on wifi condition met, turning vpn off",
)
if (!watcherState.vpnState.isDown()) activeTunnel?.let { tunnelService.get().stopTunnel(it) }
}

watcherState.isTunnelOffOnNoConnectivityMet() -> {
Timber.i(
"$autoTunnel - tunnel off on no connectivity met, turning vpn off",
)
if (!watcherState.vpnState.isDown()) activeTunnel?.let { tunnelService.get().stopTunnel(it) }
if (!isTunnelDown) activeTunnel?.let { tunnelService.get().stopTunnel(it) }
}
// TODO disable for this now
// watcherState.isTunnelOffOnNoConnectivityMet() -> {
// Timber.i(
// "$autoTunnel - tunnel off on no connectivity met, turning vpn off",
// )
// if (!isTunnelDown) activeTunnel?.let { tunnelService.get().stopTunnel(it) }
// }

else -> {
Timber.i("$autoTunnel - no condition met")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ class ServiceManager
companion object : SingletonHolder<ServiceManager, Context>(::ServiceManager)

private fun <T : Service> startService(cls: Class<T>, background: Boolean) {
val intent = Intent(context, cls)
if (background) {
context.startForegroundService(intent)
} else {
context.startService(intent)
}
runCatching {
val intent = Intent(context, cls)
if (background) {
context.startForegroundService(intent)
} else {
context.startService(intent)
}
}.onFailure { Timber.e(it) }
}

suspend fun startAutoTunnel(background: Boolean) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class TunnelBackgroundService : LifecycleService() {
return notificationService.createNotification(
getString(R.string.vpn_channel_id),
getString(R.string.vpn_channel_name),
getString(R.string.tunnel_start_text),
getString(R.string.tunnel_running),
description = "",
)
}
Expand Down
Loading

0 comments on commit cab2945

Please sign in to comment.