From 3a078844db60ca128109d3d27c86c96e2808dfbf Mon Sep 17 00:00:00 2001 From: lihenggui Date: Sat, 15 Jun 2024 18:32:40 -0700 Subject: [PATCH 1/9] Add database definition for traffic data Change-Id: Ia835f6de9abd2bb266f54c27dd40b015b0fc0c34 --- .../traffic/LocalTrafficDataRepository.kt | 44 +++++++++++++++ .../traffic/TrafficDataRepository.kt | 26 +++++++++ .../respository/traffic/TrafficDataSource.kt | 27 +++++++++ .../blocker/core/database/DaosModule.kt | 8 +++ .../blocker/core/database/DatabaseModule.kt | 9 +++ .../core/database/traffic/TrafficDataDao.kt | 38 +++++++++++++ .../database/traffic/TrafficDataDatabase.kt | 26 +++++++++ .../database/traffic/TrafficDataEntity.kt | 55 +++++++++++++++++++ .../blocker/core/model/data/TrafficData.kt | 28 ++++++++++ core/vpn/.gitignore | 1 + core/vpn/build.gradle.kts | 34 ++++++++++++ core/vpn/src/main/AndroidManifest.xml | 30 ++++++++++ .../blocker/core/vpn/BlockerVpnService.kt | 21 +++++++ settings.gradle.kts | 1 + 14 files changed, 348 insertions(+) create mode 100644 core/data/src/main/kotlin/com/merxury/blocker/core/data/respository/traffic/LocalTrafficDataRepository.kt create mode 100644 core/data/src/main/kotlin/com/merxury/blocker/core/data/respository/traffic/TrafficDataRepository.kt create mode 100644 core/data/src/main/kotlin/com/merxury/blocker/core/data/respository/traffic/TrafficDataSource.kt create mode 100644 core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataDao.kt create mode 100644 core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataDatabase.kt create mode 100644 core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataEntity.kt create mode 100644 core/model/src/main/kotlin/com/merxury/blocker/core/model/data/TrafficData.kt create mode 100644 core/vpn/.gitignore create mode 100644 core/vpn/build.gradle.kts create mode 100644 core/vpn/src/main/AndroidManifest.xml create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/BlockerVpnService.kt diff --git a/core/data/src/main/kotlin/com/merxury/blocker/core/data/respository/traffic/LocalTrafficDataRepository.kt b/core/data/src/main/kotlin/com/merxury/blocker/core/data/respository/traffic/LocalTrafficDataRepository.kt new file mode 100644 index 0000000000..5b8fca62e7 --- /dev/null +++ b/core/data/src/main/kotlin/com/merxury/blocker/core/data/respository/traffic/LocalTrafficDataRepository.kt @@ -0,0 +1,44 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.data.respository.traffic + +import com.merxury.blocker.core.database.traffic.TrafficDataDao +import com.merxury.blocker.core.database.traffic.asExternalModel +import com.merxury.blocker.core.database.traffic.fromExternalModel +import com.merxury.blocker.core.model.data.TrafficData +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import javax.inject.Inject + +class LocalTrafficDataRepository @Inject constructor( + private val trafficDataDao: TrafficDataDao, +) : TrafficDataRepository { + override fun insertTrafficData(trafficData: TrafficData) { + trafficDataDao.insert(trafficData.fromExternalModel()) + } + + override fun getTrafficData(packageName: String, keyword: String): Flow> { + return trafficDataDao.getTrafficData(packageName, keyword) + .map { trafficDataList -> + trafficDataList.map { it.asExternalModel() } + } + } + + override fun deleteTrafficData() { + trafficDataDao.deleteAll() + } +} diff --git a/core/data/src/main/kotlin/com/merxury/blocker/core/data/respository/traffic/TrafficDataRepository.kt b/core/data/src/main/kotlin/com/merxury/blocker/core/data/respository/traffic/TrafficDataRepository.kt new file mode 100644 index 0000000000..d7b59fe00a --- /dev/null +++ b/core/data/src/main/kotlin/com/merxury/blocker/core/data/respository/traffic/TrafficDataRepository.kt @@ -0,0 +1,26 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.data.respository.traffic + +import com.merxury.blocker.core.model.data.TrafficData +import kotlinx.coroutines.flow.Flow + +interface TrafficDataRepository { + fun insertTrafficData(trafficData: TrafficData) + fun getTrafficData(packageName: String, keyword: String): Flow> + fun deleteTrafficData() +} diff --git a/core/data/src/main/kotlin/com/merxury/blocker/core/data/respository/traffic/TrafficDataSource.kt b/core/data/src/main/kotlin/com/merxury/blocker/core/data/respository/traffic/TrafficDataSource.kt new file mode 100644 index 0000000000..36a053f79e --- /dev/null +++ b/core/data/src/main/kotlin/com/merxury/blocker/core/data/respository/traffic/TrafficDataSource.kt @@ -0,0 +1,27 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.data.respository.traffic + +import com.merxury.blocker.core.model.data.TrafficData +import kotlinx.coroutines.flow.Flow + +interface TrafficDataSource { + fun getTrafficData(packageName: String, keyword: String): Flow> + fun insertTrafficData(trafficData: TrafficData) + fun insertTrafficData(trafficData: List) + fun deleteTrafficData() +} diff --git a/core/database/src/main/kotlin/com/merxury/blocker/core/database/DaosModule.kt b/core/database/src/main/kotlin/com/merxury/blocker/core/database/DaosModule.kt index bb9938f27e..db0980f160 100644 --- a/core/database/src/main/kotlin/com/merxury/blocker/core/database/DaosModule.kt +++ b/core/database/src/main/kotlin/com/merxury/blocker/core/database/DaosModule.kt @@ -21,6 +21,8 @@ import com.merxury.blocker.core.database.app.InstalledAppDao import com.merxury.blocker.core.database.app.InstalledAppDatabase import com.merxury.blocker.core.database.generalrule.GeneralRuleDao import com.merxury.blocker.core.database.generalrule.GeneralRuleDatabase +import com.merxury.blocker.core.database.traffic.TrafficDataDao +import com.merxury.blocker.core.database.traffic.TrafficDataDatabase import dagger.Module import dagger.Provides import dagger.hilt.InstallIn @@ -45,4 +47,10 @@ internal object DaosModule { fun provideGeneralRuleDao(database: GeneralRuleDatabase): GeneralRuleDao { return database.generalRuleDao() } + + @Provides + @Singleton + fun provideTrafficDataDao(database: TrafficDataDatabase): TrafficDataDao { + return database.trafficDataDao() + } } diff --git a/core/database/src/main/kotlin/com/merxury/blocker/core/database/DatabaseModule.kt b/core/database/src/main/kotlin/com/merxury/blocker/core/database/DatabaseModule.kt index 05bfae2497..3fd685ea55 100644 --- a/core/database/src/main/kotlin/com/merxury/blocker/core/database/DatabaseModule.kt +++ b/core/database/src/main/kotlin/com/merxury/blocker/core/database/DatabaseModule.kt @@ -20,6 +20,7 @@ import android.content.Context import androidx.room.Room import com.merxury.blocker.core.database.app.InstalledAppDatabase import com.merxury.blocker.core.database.generalrule.GeneralRuleDatabase +import com.merxury.blocker.core.database.traffic.TrafficDataDatabase import dagger.Module import dagger.Provides import dagger.hilt.InstallIn @@ -51,4 +52,12 @@ internal object DatabaseModule { .fallbackToDestructiveMigration() .build() } + + @Provides + @Singleton + fun provideTrafficDataDatabase(@ApplicationContext context: Context): TrafficDataDatabase { + return Room.databaseBuilder(context, TrafficDataDatabase::class.java, "traffic_data") + .fallbackToDestructiveMigration() + .build() + } } diff --git a/core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataDao.kt b/core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataDao.kt new file mode 100644 index 0000000000..70fc38cfd9 --- /dev/null +++ b/core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataDao.kt @@ -0,0 +1,38 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.database.traffic + +import androidx.room.Dao +import androidx.room.Insert +import androidx.room.OnConflictStrategy +import androidx.room.Query +import kotlinx.coroutines.flow.Flow + +@Dao +interface TrafficDataDao { + @Insert(onConflict = OnConflictStrategy.REPLACE) + fun insert(trafficData: TrafficDataEntity) + + @Insert(onConflict = OnConflictStrategy.REPLACE) + fun insertAll(trafficData: List) + + @Query("SELECT * FROM traffic_data WHERE packageName = :packageName AND (domain LIKE '%' || :keyword || '%' OR path LIKE '%' || :keyword || '%') ORDER BY timestamp DESC") + fun getTrafficData(packageName: String, keyword: String): Flow> + + @Query("DELETE FROM traffic_data") + fun deleteAll() +} diff --git a/core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataDatabase.kt b/core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataDatabase.kt new file mode 100644 index 0000000000..c6abb8a6e6 --- /dev/null +++ b/core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataDatabase.kt @@ -0,0 +1,26 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.database.traffic + +import androidx.room.Database +import androidx.room.RoomDatabase +import com.merxury.blocker.core.model.data.TrafficData + +@Database(entities = [TrafficData::class], version = 1) +abstract class TrafficDataDatabase : RoomDatabase() { + abstract fun trafficDataDao(): TrafficDataDao +} diff --git a/core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataEntity.kt b/core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataEntity.kt new file mode 100644 index 0000000000..424dc148a2 --- /dev/null +++ b/core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataEntity.kt @@ -0,0 +1,55 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.database.traffic + +import androidx.room.Entity +import androidx.room.PrimaryKey +import com.merxury.blocker.core.model.data.TrafficData + +@Entity(tableName = "traffic_data") +data class TrafficDataEntity( + @PrimaryKey(autoGenerate = true) val id: Long = 0, + val timestamp: Long, + val packageName: String, + val ipAddress: String, + val domain: String? = null, + val port: Int, + val path: String? = null, + val blocked: Boolean = false, +) + +fun TrafficDataEntity.asExternalModel() = TrafficData( + id = id, + timestamp = timestamp, + packageName = packageName, + ipAddress = ipAddress, + domain = domain, + port = port, + path = path, + blocked = blocked, +) + +fun TrafficData.fromExternalModel() = TrafficDataEntity( + id = id, + timestamp = timestamp, + packageName = packageName, + ipAddress = ipAddress, + domain = domain, + port = port, + path = path, + blocked = blocked, +) diff --git a/core/model/src/main/kotlin/com/merxury/blocker/core/model/data/TrafficData.kt b/core/model/src/main/kotlin/com/merxury/blocker/core/model/data/TrafficData.kt new file mode 100644 index 0000000000..807499e0d9 --- /dev/null +++ b/core/model/src/main/kotlin/com/merxury/blocker/core/model/data/TrafficData.kt @@ -0,0 +1,28 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.model.data + +data class TrafficData( + val id: Long = 0, + val timestamp: Long, + val packageName: String, + val ipAddress: String, + val domain: String?, + val port: Int, + val path: String?, + val blocked: Boolean = false, +) diff --git a/core/vpn/.gitignore b/core/vpn/.gitignore new file mode 100644 index 0000000000..42afabfd2a --- /dev/null +++ b/core/vpn/.gitignore @@ -0,0 +1 @@ +/build \ No newline at end of file diff --git a/core/vpn/build.gradle.kts b/core/vpn/build.gradle.kts new file mode 100644 index 0000000000..06e5ceb325 --- /dev/null +++ b/core/vpn/build.gradle.kts @@ -0,0 +1,34 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +plugins { + alias(libs.plugins.blocker.android.library) + alias(libs.plugins.blocker.android.library.jacoco) + alias(libs.plugins.blocker.android.hilt) + id("kotlin-parcelize") +} + +android { + namespace = "com.merxury.blocker.core.vpn" +} + +dependencies { + api(libs.timber) + implementation(projects.core.common) + implementation(projects.core.data) + + testImplementation(libs.kotlinx.coroutines.test) + testImplementation(libs.turbine) +} diff --git a/core/vpn/src/main/AndroidManifest.xml b/core/vpn/src/main/AndroidManifest.xml new file mode 100644 index 0000000000..0e3da64312 --- /dev/null +++ b/core/vpn/src/main/AndroidManifest.xml @@ -0,0 +1,30 @@ + + + + + + + + + + + + + \ No newline at end of file diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/BlockerVpnService.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/BlockerVpnService.kt new file mode 100644 index 0000000000..5f19b596f5 --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/BlockerVpnService.kt @@ -0,0 +1,21 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn + +import android.net.VpnService + +class BlockerVpnService : VpnService() diff --git a/settings.gradle.kts b/settings.gradle.kts index 86478b2132..9cf7becddc 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -57,6 +57,7 @@ include(":core:rule") include(":core:screenshot-testing") include(":core:testing") include(":core:ui") +include(":core:vpn") include(":feature:applist") include(":feature:appdetail") include(":feature:generalrule") From 17c11cae75ac42c4e6af7db3c5194cc31c963913 Mon Sep 17 00:00:00 2001 From: lihenggui Date: Tue, 18 Jun 2024 15:48:49 -0700 Subject: [PATCH 2/9] Use correct entity class Change-Id: Ib7855b44c1ddcc349aafc7e5f14e870fc95d474e --- .../1.json | 76 +++++++++++++++++++ .../database/traffic/TrafficDataDatabase.kt | 2 +- 2 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 core/database/schemas/com.merxury.blocker.core.database.traffic.TrafficDataDatabase/1.json diff --git a/core/database/schemas/com.merxury.blocker.core.database.traffic.TrafficDataDatabase/1.json b/core/database/schemas/com.merxury.blocker.core.database.traffic.TrafficDataDatabase/1.json new file mode 100644 index 0000000000..6c9f702241 --- /dev/null +++ b/core/database/schemas/com.merxury.blocker.core.database.traffic.TrafficDataDatabase/1.json @@ -0,0 +1,76 @@ +{ + "formatVersion": 1, + "database": { + "version": 1, + "identityHash": "4e6d693b5a2004c66c250ff874b42300", + "entities": [ + { + "tableName": "traffic_data", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `timestamp` INTEGER NOT NULL, `packageName` TEXT NOT NULL, `ipAddress` TEXT NOT NULL, `domain` TEXT, `port` INTEGER NOT NULL, `path` TEXT, `blocked` INTEGER NOT NULL)", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "timestamp", + "columnName": "timestamp", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "packageName", + "columnName": "packageName", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "ipAddress", + "columnName": "ipAddress", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "domain", + "columnName": "domain", + "affinity": "TEXT", + "notNull": false + }, + { + "fieldPath": "port", + "columnName": "port", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "path", + "columnName": "path", + "affinity": "TEXT", + "notNull": false + }, + { + "fieldPath": "blocked", + "columnName": "blocked", + "affinity": "INTEGER", + "notNull": true + } + ], + "primaryKey": { + "autoGenerate": true, + "columnNames": [ + "id" + ] + }, + "indices": [], + "foreignKeys": [] + } + ], + "views": [], + "setupQueries": [ + "CREATE TABLE IF NOT EXISTS room_master_table (id INTEGER PRIMARY KEY,identity_hash TEXT)", + "INSERT OR REPLACE INTO room_master_table (id,identity_hash) VALUES(42, '4e6d693b5a2004c66c250ff874b42300')" + ] + } +} \ No newline at end of file diff --git a/core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataDatabase.kt b/core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataDatabase.kt index c6abb8a6e6..6940524eb0 100644 --- a/core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataDatabase.kt +++ b/core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataDatabase.kt @@ -20,7 +20,7 @@ import androidx.room.Database import androidx.room.RoomDatabase import com.merxury.blocker.core.model.data.TrafficData -@Database(entities = [TrafficData::class], version = 1) +@Database(entities = [TrafficDataEntity::class], version = 1) abstract class TrafficDataDatabase : RoomDatabase() { abstract fun trafficDataDao(): TrafficDataDao } From 7d9d4222333686c28e1a447d83e92d9e4642ea69 Mon Sep 17 00:00:00 2001 From: lihenggui <350699171@qq.com> Date: Sat, 29 Jun 2024 20:57:10 -0700 Subject: [PATCH 3/9] Spotless --- .../respository/traffic/LocalTrafficDataRepository.kt | 10 ++++------ .../com/merxury/blocker/core/database/DaosModule.kt | 4 +--- .../merxury/blocker/core/database/DatabaseModule.kt | 8 +++----- 3 files changed, 8 insertions(+), 14 deletions(-) diff --git a/core/data/src/main/kotlin/com/merxury/blocker/core/data/respository/traffic/LocalTrafficDataRepository.kt b/core/data/src/main/kotlin/com/merxury/blocker/core/data/respository/traffic/LocalTrafficDataRepository.kt index 5b8fca62e7..2d7a2869f4 100644 --- a/core/data/src/main/kotlin/com/merxury/blocker/core/data/respository/traffic/LocalTrafficDataRepository.kt +++ b/core/data/src/main/kotlin/com/merxury/blocker/core/data/respository/traffic/LocalTrafficDataRepository.kt @@ -31,12 +31,10 @@ class LocalTrafficDataRepository @Inject constructor( trafficDataDao.insert(trafficData.fromExternalModel()) } - override fun getTrafficData(packageName: String, keyword: String): Flow> { - return trafficDataDao.getTrafficData(packageName, keyword) - .map { trafficDataList -> - trafficDataList.map { it.asExternalModel() } - } - } + override fun getTrafficData(packageName: String, keyword: String): Flow> = trafficDataDao.getTrafficData(packageName, keyword) + .map { trafficDataList -> + trafficDataList.map { it.asExternalModel() } + } override fun deleteTrafficData() { trafficDataDao.deleteAll() diff --git a/core/database/src/main/kotlin/com/merxury/blocker/core/database/DaosModule.kt b/core/database/src/main/kotlin/com/merxury/blocker/core/database/DaosModule.kt index 661394a606..4909232c42 100644 --- a/core/database/src/main/kotlin/com/merxury/blocker/core/database/DaosModule.kt +++ b/core/database/src/main/kotlin/com/merxury/blocker/core/database/DaosModule.kt @@ -44,7 +44,5 @@ internal object DaosModule { @Provides @Singleton - fun provideTrafficDataDao(database: TrafficDataDatabase): TrafficDataDao { - return database.trafficDataDao() - } + fun provideTrafficDataDao(database: TrafficDataDatabase): TrafficDataDao = database.trafficDataDao() } diff --git a/core/database/src/main/kotlin/com/merxury/blocker/core/database/DatabaseModule.kt b/core/database/src/main/kotlin/com/merxury/blocker/core/database/DatabaseModule.kt index b163323eb8..7edb6228b1 100644 --- a/core/database/src/main/kotlin/com/merxury/blocker/core/database/DatabaseModule.kt +++ b/core/database/src/main/kotlin/com/merxury/blocker/core/database/DatabaseModule.kt @@ -51,9 +51,7 @@ internal object DatabaseModule { @Provides @Singleton - fun provideTrafficDataDatabase(@ApplicationContext context: Context): TrafficDataDatabase { - return Room.databaseBuilder(context, TrafficDataDatabase::class.java, "traffic_data") - .fallbackToDestructiveMigration() - .build() - } + fun provideTrafficDataDatabase(@ApplicationContext context: Context): TrafficDataDatabase = Room.databaseBuilder(context, TrafficDataDatabase::class.java, "traffic_data") + .fallbackToDestructiveMigration() + .build() } From b85513d5fc2fa5fd9a62b666d62c1c9cc0f33db5 Mon Sep 17 00:00:00 2001 From: lihenggui Date: Mon, 1 Jul 2024 14:19:05 -0700 Subject: [PATCH 4/9] Add basic implementation for the VpnService Change-Id: I20c19e70550066aa5c103c027393a0792f0d212a --- .../database/traffic/TrafficDataDatabase.kt | 1 - .../blocker/core/vpn/BlockerVpnService.kt | 74 +- .../com/merxury/blocker/core/vpn/VpnQueue.kt | 892 ++++++++++++++++++ .../blocker/core/vpn/protocol/IpUtil.kt | 102 ++ .../blocker/core/vpn/protocol/Packet.kt | 499 ++++++++++ .../blocker/core/vpn/protocol/TcbStatus.kt | 26 + 6 files changed, 1592 insertions(+), 2 deletions(-) create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/VpnQueue.kt create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/IpUtil.kt create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/TcbStatus.kt diff --git a/core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataDatabase.kt b/core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataDatabase.kt index 6940524eb0..5c4996181d 100644 --- a/core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataDatabase.kt +++ b/core/database/src/main/kotlin/com/merxury/blocker/core/database/traffic/TrafficDataDatabase.kt @@ -18,7 +18,6 @@ package com.merxury.blocker.core.database.traffic import androidx.room.Database import androidx.room.RoomDatabase -import com.merxury.blocker.core.model.data.TrafficData @Database(entities = [TrafficDataEntity::class], version = 1) abstract class TrafficDataDatabase : RoomDatabase() { diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/BlockerVpnService.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/BlockerVpnService.kt index 5f19b596f5..0d3d976908 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/BlockerVpnService.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/BlockerVpnService.kt @@ -17,5 +17,77 @@ package com.merxury.blocker.core.vpn import android.net.VpnService +import android.os.Build +import android.os.ParcelFileDescriptor +import com.merxury.blocker.core.di.ApplicationScope +import com.merxury.blocker.core.dispatchers.BlockerDispatchers.IO +import com.merxury.blocker.core.dispatchers.Dispatcher +import dagger.hilt.android.AndroidEntryPoint +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.CoroutineScope +import javax.inject.Inject -class BlockerVpnService : VpnService() +@AndroidEntryPoint +class BlockerVpnService : VpnService() { + + @Inject + @ApplicationScope + lateinit var applicationScope: CoroutineScope + + @Inject + @Dispatcher(IO) + lateinit var ioDispatcher: CoroutineDispatcher + + private var vpnInterface: ParcelFileDescriptor? = null + + override fun onCreate() { + super.onCreate() + UdpSendWorker.start(this) + UdpReceiveWorker.start(this) + UdpSocketCleanWorker.start() + TcpWorker.start(this) + startVpn() + } + + override fun onDestroy() { + super.onDestroy() + disconnect() + UdpSendWorker.stop() + UdpReceiveWorker.stop() + UdpSocketCleanWorker.stop() + TcpWorker.stop() + vpnInterface?.close() + vpnInterface = null + } + + private fun startVpn() { + val builder = Builder() + builder.addAddress("10.0.0.2", 24) + builder.addRoute("0.0.0.0", 0) + builder.addDnsServer("8.8.8.8") + builder.addDnsServer("8.8.4.4") + vpnInterface = builder.establish() + + vpnInterface?.let { + runVpn(it) + } + } + + private fun runVpn(vpnInterface: ParcelFileDescriptor) { + val fileDescriptor = vpnInterface.fileDescriptor + ToNetworkQueueWorker.start(fileDescriptor) + ToDeviceQueueWorker.start(fileDescriptor) + } + + private fun disconnect() { + ToNetworkQueueWorker.stop() + ToDeviceQueueWorker.stop() + vpnInterface?.close() + vpnInterface = null + if (Build.VERSION.SDK_INT <= Build.VERSION_CODES.N) { + stopForeground(true) + } else { + stopForeground(STOP_FOREGROUND_REMOVE) + } + } +} diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/VpnQueue.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/VpnQueue.kt new file mode 100644 index 0000000000..7b4e902c34 --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/VpnQueue.kt @@ -0,0 +1,892 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn + +import android.annotation.SuppressLint +import android.net.VpnService +import android.os.Build +import android.util.Base64 +import com.merxury.blocker.core.vpn.protocol.IpUtil +import com.merxury.blocker.core.vpn.protocol.Packet +import com.merxury.blocker.core.vpn.protocol.Packet.TCPHeader +import com.merxury.blocker.core.vpn.protocol.TcbStatus +import timber.log.Timber +import java.io.FileDescriptor +import java.io.FileInputStream +import java.io.FileOutputStream +import java.io.IOException +import java.net.ConnectException +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.ClosedByInterruptException +import java.nio.channels.DatagramChannel +import java.nio.channels.FileChannel +import java.nio.channels.SelectionKey +import java.nio.channels.Selector +import java.nio.channels.SocketChannel +import java.util.concurrent.ArrayBlockingQueue +import java.util.concurrent.atomic.AtomicInteger +import kotlin.experimental.and +import kotlin.experimental.or + +/** + * Queue for UDP packets sent from device to network + */ +internal val deviceToNetworkUDPQueue = ArrayBlockingQueue(1024) + +/** + * Queue for TCP packets sent from device to network + */ +internal val deviceToNetworkTCPQueue = ArrayBlockingQueue(1024) + +/** + * Queue for packets sent from network to device + */ +internal val networkToDeviceQueue = ArrayBlockingQueue(1024) + +/** + * TCP forwarding network selector + */ +internal val tcpNioSelector: Selector = Selector.open() + +/** + * Queue for UDP forwarding channels + */ +internal val udpTunnelQueue = ArrayBlockingQueue(1024) + +/** + * UDP forwarding network selector + */ +internal val udpNioSelector: Selector = Selector.open() + +/** + * Existing UDP socket map + * key is target host address:target port:request port + */ +internal val udpSocketMap = HashMap() + +const val UDP_SOCKET_IDLE_TIMEOUT = 60 + +/** + * Worker thread to handle packets sent from device to network + */ +object ToNetworkQueueWorker : Runnable { + private const val TAG = "ToNetworkQueueWorker" + + /** + * Self thread + */ + private lateinit var thread: Thread + + /** + * Channel to read data from the device + */ + private lateinit var vpnInput: FileChannel + + /** + * Total bytes read count + */ + var totalInputCount = 0L + + fun start(vpnFileDescriptor: FileDescriptor) { + if (this::thread.isInitialized && thread.isAlive) throw IllegalStateException("Already running") + vpnInput = FileInputStream(vpnFileDescriptor).channel + thread = Thread(this).apply { + name = TAG + start() + } + } + + fun stop() { + if (this::thread.isInitialized) { + thread.interrupt() + } + } + + override fun run() { + val readBuffer = ByteBuffer.allocate(16384) + while (!thread.isInterrupted) { + var readCount = 0 + try { + readCount = vpnInput.read(readBuffer) + } catch (e: IOException) { + e.printStackTrace() + continue + } + if (readCount > 0) { + readBuffer.flip() + val byteArray = ByteArray(readCount) + readBuffer.get(byteArray) + + val byteBuffer = ByteBuffer.wrap(byteArray) + totalInputCount += readCount + + val packet = Packet(byteBuffer) + if (packet.isUDP) { + deviceToNetworkUDPQueue.offer(packet) + } else if (packet.isTCP) { + deviceToNetworkTCPQueue.offer(packet) + } else { + Timber.d("Unknown packet protocol type ${packet.ip4Header?.protocolNum}") + } + } else if (readCount < 0) { + break + } + readBuffer.clear() + } + Timber.i("ToNetworkQueueWorker finished running") + } +} + +/** + * Worker thread to handle packets sent from network to device + */ +object ToDeviceQueueWorker : Runnable { + private const val TAG = "ToDeviceQueueWorker" + + /** + * Self thread + */ + private lateinit var thread: Thread + + /** + * Total bytes written count + */ + var totalOutputCount = 0L + + /** + * Channel to write data to the device + */ + private lateinit var vpnOutput: FileChannel + + fun start(vpnFileDescriptor: FileDescriptor) { + if (this::thread.isInitialized && thread.isAlive) throw IllegalStateException("Already running") + vpnOutput = FileOutputStream(vpnFileDescriptor).channel + thread = Thread(this).apply { + name = TAG + start() + } + } + + fun stop() { + if (this::thread.isInitialized) { + thread.interrupt() + } + } + + override fun run() { + try { + while (!thread.isInterrupted) { + val byteBuffer = networkToDeviceQueue.take() + byteBuffer.flip() + while (byteBuffer.hasRemaining()) { + val count = vpnOutput.write(byteBuffer) + if (count > 0) { + totalOutputCount += count + } + } + } + } catch (_: InterruptedException) { + } catch (_: ClosedByInterruptException) { + } + } +} + +/** + * UDP forwarding channel data + */ +data class UdpTunnel( + val id: String, + val local: InetSocketAddress, + val remote: InetSocketAddress, + val channel: DatagramChannel, +) + +data class ManagedDatagramChannel( + val id: String, + val channel: DatagramChannel, + var lastTime: Long = System.currentTimeMillis(), +) + +/** + * Worker thread to send UDP packets + */ +@SuppressLint("StaticFieldLeak") +object UdpSendWorker : Runnable { + private const val TAG = "UdpSendWorker" + + /** + * Self thread + */ + private lateinit var thread: Thread + + private var vpnService: VpnService? = null + + fun start(vpnService: VpnService) { + this.vpnService = vpnService + udpTunnelQueue.clear() + thread = Thread(this).apply { + name = TAG + start() + } + } + + fun stop() { + if (this::thread.isInitialized) { + thread.interrupt() + } + vpnService = null + } + + override fun run() { + while (!thread.isInterrupted) { + val packet = deviceToNetworkUDPQueue.take() + + val destinationAddress = packet.ip4Header?.destinationAddress + val udpHeader = packet.udpHeader + + val destinationPort = udpHeader?.destinationPort ?: 0 + val sourcePort = udpHeader?.sourcePort + val ipAndPort = ( + destinationAddress?.hostAddress?.plus(":") + ?: "unknownHostAddress" + ) + destinationPort + ":" + sourcePort + + // Create new socket + val managedChannel = if (!udpSocketMap.containsKey(ipAndPort)) { + val channel = DatagramChannel.open() + var channelConnectSuccess = false + channel.apply { + val socket = socket() + vpnService?.protect(socket) + try { + connect(InetSocketAddress(destinationAddress, destinationPort)) + channelConnectSuccess = true + } catch (_: ConnectException) { + } + configureBlocking(false) + } + if (!channelConnectSuccess) { + continue + } + + val tunnel = UdpTunnel( + ipAndPort, + InetSocketAddress(packet.ip4Header?.sourceAddress, udpHeader?.sourcePort ?: 0), + InetSocketAddress( + packet.ip4Header?.destinationAddress, + udpHeader?.destinationPort ?: 0, + ), + channel, + ) + udpTunnelQueue.offer(tunnel) + udpNioSelector.wakeup() + + val managedDatagramChannel = ManagedDatagramChannel(ipAndPort, channel) + synchronized(udpSocketMap) { + udpSocketMap[ipAndPort] = managedDatagramChannel + } + managedDatagramChannel + } else { + synchronized(udpSocketMap) { + udpSocketMap[ipAndPort] + ?: throw IllegalStateException("udp:udpSocketMap[$ipAndPort] should not be null") + } + } + managedChannel.lastTime = System.currentTimeMillis() + val buffer = packet.backingBuffer + kotlin.runCatching { + while (!thread.isInterrupted && buffer?.hasRemaining() == true) { + managedChannel.channel.write(buffer) + } + }.exceptionOrNull()?.let { + Timber.e("Error sending UDP packet", it) + managedChannel.channel.close() + synchronized(udpSocketMap) { + udpSocketMap.remove(ipAndPort) + } + } + } + } +} + +/** + * Worker thread to receive UDP packets + */ +@SuppressLint("StaticFieldLeak") +object UdpReceiveWorker : Runnable { + + private const val TAG = "UdpReceiveWorker" + + /** + * Self thread + */ + private lateinit var thread: Thread + + private var vpnService: VpnService? = null + + private var ipId = AtomicInteger() + + private const val UDP_HEADER_FULL_SIZE = Packet.IP4_HEADER_SIZE + Packet.UDP_HEADER_SIZE + + fun start(vpnService: VpnService) { + this.vpnService = vpnService + thread = Thread(this).apply { + name = TAG + start() + } + } + + fun stop() { + thread.interrupt() + } + + private fun sendUdpPacket(tunnel: UdpTunnel, source: InetSocketAddress, data: ByteArray) { + val packet = IpUtil.buildUdpPacket(tunnel.remote, tunnel.local, ipId.addAndGet(1)) + + val byteBuffer = ByteBuffer.allocate(UDP_HEADER_FULL_SIZE + data.size) + byteBuffer.apply { + position(UDP_HEADER_FULL_SIZE) + put(data) + } + packet.updateUDPBuffer(byteBuffer, data.size) + byteBuffer.position(UDP_HEADER_FULL_SIZE + data.size) + networkToDeviceQueue.offer(byteBuffer) + } + + override fun run() { + val receiveBuffer = ByteBuffer.allocate(16384) + while (!thread.isInterrupted) { + val readyChannels = udpNioSelector.select() + while (!thread.isInterrupted) { + val tunnel = udpTunnelQueue.poll() ?: break + kotlin.runCatching { + val key = tunnel.channel.register(udpNioSelector, SelectionKey.OP_READ, tunnel) + key.interestOps(SelectionKey.OP_READ) + }.exceptionOrNull()?.printStackTrace() + } + if (readyChannels == 0) { + udpNioSelector.selectedKeys().clear() + continue + } + val keys = udpNioSelector.selectedKeys() + val iterator = keys.iterator() + while (!thread.isInterrupted && iterator.hasNext()) { + val key = iterator.next() + iterator.remove() + if (key.isValid && key.isReadable) { + val tunnel = key.attachment() as UdpTunnel + kotlin.runCatching { + val inputChannel = key.channel() as DatagramChannel + receiveBuffer.clear() + inputChannel.read(receiveBuffer) + receiveBuffer.flip() + val data = ByteArray(receiveBuffer.remaining()) + receiveBuffer.get(data) + sendUdpPacket( + tunnel, + inputChannel.socket().localSocketAddress as InetSocketAddress, + data, + ) // todo api 21->24 + }.exceptionOrNull()?.let { + it.printStackTrace() + synchronized(udpSocketMap) { + udpSocketMap.remove(tunnel.id) + } + } + } + } + } + } +} + +/** + * Worker thread to clean up expired UDP sockets + */ +object UdpSocketCleanWorker : Runnable { + + private const val TAG = "UdpSocketCleanWorker" + + /** + * Self thread + */ + private lateinit var thread: Thread + + /** + * Check interval in seconds + */ + private const val INTERVAL_TIME = 5L + + fun start() { + thread = Thread(this).apply { + name = TAG + start() + } + } + + fun stop() { + thread.interrupt() + } + + override fun run() { + while (!thread.isInterrupted) { + synchronized(udpSocketMap) { + val iterator = udpSocketMap.iterator() + var removeCount = 0 + while (!thread.isInterrupted && iterator.hasNext()) { + val managedDatagramChannel = iterator.next() + if (System.currentTimeMillis() - managedDatagramChannel.value.lastTime > UDP_SOCKET_IDLE_TIMEOUT * 1000) { + kotlin.runCatching { + managedDatagramChannel.value.channel.close() + }.exceptionOrNull()?.printStackTrace() + iterator.remove() + removeCount++ + } + } + if (removeCount > 0) { + Timber.d("Removed $removeCount expired inactive UDP sockets, currently active ${udpSocketMap.size}") + } + } + Thread.sleep(INTERVAL_TIME * 1000) + } + } +} + +internal class TcpPipe(val tunnelKey: String, packet: Packet) { + var mySequenceNum: Long = 0 + var theirSequenceNum: Long = 0 + var myAcknowledgementNum: Long = 0 + var theirAcknowledgementNum: Long = 0 + val tunnelId = tunnelIds++ + + val sourceAddress: InetSocketAddress = + InetSocketAddress(packet.ip4Header?.sourceAddress, packet.tcpHeader?.sourcePort ?: 0) + val destinationAddress: InetSocketAddress = InetSocketAddress( + packet.ip4Header?.destinationAddress, + packet.tcpHeader?.destinationPort ?: 0, + ) + val remoteSocketChannel: SocketChannel = + SocketChannel.open().also { it.configureBlocking(false) } + val remoteSocketChannelKey: SelectionKey = + remoteSocketChannel.register(tcpNioSelector, SelectionKey.OP_CONNECT) + .also { it.attach(this) } + + var tcbStatus: TcbStatus = TcbStatus.SYN_SENT + var remoteOutBuffer: ByteBuffer? = null + + var upActive = true + var downActive = true + var packId = 1 + var timestamp = System.currentTimeMillis() + var synCount = 0 + + fun tryConnect(vpnService: VpnService): Result { + val result = kotlin.runCatching { + vpnService.protect(remoteSocketChannel.socket()) + remoteSocketChannel.connect(destinationAddress) + } + return result + } + + companion object { + const val TAG = "TcpPipe" + var tunnelIds = 0 + } +} + +/** + * TCP packet worker thread + * NIO + */ +@SuppressLint("StaticFieldLeak") +object TcpWorker : Runnable { + private const val TAG = "TcpSendWorker" + + private const val TCP_HEADER_SIZE = Packet.IP4_HEADER_SIZE + Packet.TCP_HEADER_SIZE + + private lateinit var thread: Thread + + private val pipeMap = HashMap() + + private var vpnService: VpnService? = null + + fun start(vpnService: VpnService) { + this.vpnService = vpnService + thread = Thread(this).apply { + name = TAG + start() + } + } + + fun stop() { + thread.interrupt() + vpnService = null + } + + override fun run() { + while (!thread.isInterrupted) { + if (vpnService == null) { + throw IllegalStateException("VpnService should not be null") + } + handleReadFromVpn() + handleSockets() + + Thread.sleep(1) + } + } + + private fun handleReadFromVpn() { + while (!thread.isInterrupted) { + val vpnService = this.vpnService ?: return + val packet = deviceToNetworkTCPQueue.poll() ?: return + val destinationAddress = packet.ip4Header?.destinationAddress + val tcpHeader = packet.tcpHeader + val destinationPort = tcpHeader?.destinationPort + val sourcePort = tcpHeader?.sourcePort + + val ipAndPort = ( + destinationAddress?.hostAddress?.plus(":") + ?: "unknown-host-address" + ) + destinationPort + ":" + sourcePort + + val tcpPipe = if (!pipeMap.containsKey(ipAndPort)) { + val pipe = TcpPipe(ipAndPort, packet) + pipe.tryConnect(vpnService) + pipeMap[ipAndPort] = pipe + pipe + } else { + pipeMap[ipAndPort] + ?: throw IllegalStateException("pipeMap should not contain null key: $ipAndPort") + } + handlePacket(packet, tcpPipe) + } + } + + private fun handleSockets() { + while (!thread.isInterrupted && tcpNioSelector.selectNow() > 0) { + val keys = tcpNioSelector.selectedKeys() + val iterator = keys.iterator() + while (!thread.isInterrupted && iterator.hasNext()) { + val key = iterator.next() + iterator.remove() + val tcpPipe: TcpPipe? = key?.attachment() as? TcpPipe + if (key.isValid) { + kotlin.runCatching { + if (key.isAcceptable) { + throw RuntimeException("key.isAcceptable") + } else if (key.isReadable) { + tcpPipe?.doRead() + } else if (key.isConnectable) { + tcpPipe?.doConnect() + } else if (key.isWritable) { + tcpPipe?.doWrite() + } else { + tcpPipe?.closeRst() + } + null + }.exceptionOrNull()?.let { + Timber.d( + "Error communicating with target: ${ + Base64.encodeToString( + tcpPipe?.destinationAddress.toString().toByteArray(), + Base64.DEFAULT, + ) + }", + ) + it.printStackTrace() + tcpPipe?.closeRst() + } + } + } + } + } + + private fun handlePacket(packet: Packet, tcpPipe: TcpPipe) { + val tcpHeader = packet.tcpHeader ?: return + when { + tcpHeader.isSYN -> { + handleSyn(packet, tcpPipe) + } + + tcpHeader.isRST -> { + handleRst(tcpPipe) + } + + tcpHeader.isFIN -> { + handleFin(packet, tcpPipe) + } + + tcpHeader.isACK -> { + handleAck(packet, tcpPipe) + } + } + } + + private fun handleSyn(packet: Packet, tcpPipe: TcpPipe) { + if (tcpPipe.tcbStatus == TcbStatus.SYN_SENT) { + tcpPipe.tcbStatus = TcbStatus.SYN_RECEIVED + } + val tcpHeader = packet.tcpHeader + tcpPipe.apply { + if (synCount == 0) { + mySequenceNum = 1 + theirSequenceNum = tcpHeader?.sequenceNumber ?: 0 + myAcknowledgementNum = tcpHeader?.sequenceNumber?.plus(1) ?: 0 + theirAcknowledgementNum = tcpHeader?.acknowledgementNumber ?: 0 + sendTcpPack(this, TCPHeader.SYN.toByte() or TCPHeader.ACK.toByte()) + } else { + myAcknowledgementNum = tcpHeader?.sequenceNumber?.plus(1) ?: 0 + } + synCount++ + } + } + + private fun handleRst(tcpPipe: TcpPipe) { + tcpPipe.apply { + upActive = false + downActive = false + clean() + tcbStatus = TcbStatus.CLOSE_WAIT + } + } + + private fun handleFin(packet: Packet, tcpPipe: TcpPipe) { + tcpPipe.myAcknowledgementNum = packet.tcpHeader?.sequenceNumber?.plus(1) ?: 0 + tcpPipe.theirAcknowledgementNum = packet.tcpHeader?.acknowledgementNumber?.plus(1) ?: 0 + sendTcpPack(tcpPipe, TCPHeader.ACK.toByte()) + tcpPipe.closeUpStream() + tcpPipe.tcbStatus = TcbStatus.CLOSE_WAIT + } + + private fun handleAck(packet: Packet, tcpPipe: TcpPipe) { + if (tcpPipe.tcbStatus == TcbStatus.SYN_RECEIVED) { + tcpPipe.tcbStatus = TcbStatus.ESTABLISHED + } + + val tcpHeader = packet.tcpHeader + val payloadSize = packet.backingBuffer?.remaining() ?: 0 + + if (payloadSize == 0) { + return + } + + val newAck = tcpHeader?.sequenceNumber?.plus(payloadSize) ?: 0 + if (newAck <= tcpPipe.myAcknowledgementNum) { + return + } + + tcpPipe.apply { + myAcknowledgementNum = tcpHeader?.sequenceNumber?.plus(payloadSize) ?: 0 + theirAcknowledgementNum = tcpHeader?.acknowledgementNumber ?: 0 + remoteOutBuffer = packet.backingBuffer + tryFlushWrite(this) + sendTcpPack(this, TCPHeader.ACK.toByte()) + } + } + + /** + * Send TCP packet + */ + private fun sendTcpPack(tcpPipe: TcpPipe, flag: Byte, data: ByteArray? = null) { + val dataSize = data?.size ?: 0 + + val packet = IpUtil.buildTcpPacket( + tcpPipe.destinationAddress, + tcpPipe.sourceAddress, + flag, + tcpPipe.myAcknowledgementNum, + tcpPipe.mySequenceNum, + tcpPipe.packId, + ) + tcpPipe.packId++ + + val byteBuffer = ByteBuffer.allocate(TCP_HEADER_SIZE + dataSize) + byteBuffer.position(TCP_HEADER_SIZE) + + data?.let { + byteBuffer.put(it) + } + + packet.updateTCPBuffer( + byteBuffer, + flag, + tcpPipe.mySequenceNum, + tcpPipe.myAcknowledgementNum, + dataSize, + ) + packet.release() + + byteBuffer.position(TCP_HEADER_SIZE + dataSize) + + networkToDeviceQueue.offer(byteBuffer) + + if ((flag and TCPHeader.SYN.toByte()) != 0.toByte()) { + tcpPipe.mySequenceNum++ + } + if ((flag and TCPHeader.FIN.toByte()) != 0.toByte()) { + tcpPipe.mySequenceNum++ + } + if ((flag and TCPHeader.ACK.toByte()) != 0.toByte()) { + tcpPipe.mySequenceNum += dataSize + } + } + + /** + * Write data to the remote + */ + private fun tryFlushWrite(tcpPipe: TcpPipe): Boolean { + val channel: SocketChannel = tcpPipe.remoteSocketChannel + val buffer = tcpPipe.remoteOutBuffer + + if (tcpPipe.remoteSocketChannel.socket().isOutputShutdown && buffer?.remaining() != 0) { + sendTcpPack(tcpPipe, TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte()) + buffer?.compact() + return false + } + + if (!channel.isConnected) { + val key = tcpPipe.remoteSocketChannelKey + val ops = key.interestOps() or SelectionKey.OP_WRITE + key.interestOps(ops) + buffer?.compact() + return false + } + + while (!thread.isInterrupted && buffer?.hasRemaining() == true) { + val n = kotlin.runCatching { + channel.write(buffer) + } + if (n.isFailure) return false + if (n.getOrThrow() <= 0) { + val key = tcpPipe.remoteSocketChannelKey + val ops = key.interestOps() or SelectionKey.OP_WRITE + key.interestOps(ops) + buffer.compact() + return false + } + } + buffer?.clear() + if (!tcpPipe.upActive) { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { + tcpPipe.remoteSocketChannel.shutdownOutput() + } else { + // todo The following line will cause the socket to be incorrectly handled, but what if we don't handle it here? + // tcpPipe.remoteSocketChannel.close() + } + } + return true + } + + private fun TcpPipe.closeRst() { + Timber.d("closeRst $tunnelId") + clean() + sendTcpPack(this, TCPHeader.RST.toByte()) + upActive = false + downActive = false + } + + private fun TcpPipe.doRead() { + val buffer = ByteBuffer.allocate(4096) + var isQuitType = false + + while (!thread.isInterrupted) { + buffer.clear() + val length = remoteSocketChannel.read(buffer) + if (length == -1) { + isQuitType = true + break + } else if (length == 0) { + break + } else { + if (tcbStatus != TcbStatus.CLOSE_WAIT) { + buffer.flip() + val dataByteArray = ByteArray(buffer.remaining()) + buffer.get(dataByteArray) + sendTcpPack(this, TCPHeader.ACK.toByte(), dataByteArray) + } + } + } + + if (isQuitType) { + closeDownStream() + } + } + + private fun TcpPipe.doConnect() { + remoteSocketChannel.finishConnect() + timestamp = System.currentTimeMillis() + remoteOutBuffer?.flip() + remoteSocketChannelKey.interestOps(SelectionKey.OP_READ or SelectionKey.OP_WRITE) + } + + private fun TcpPipe.doWrite() { + if (tryFlushWrite(this)) { + remoteSocketChannelKey.interestOps(SelectionKey.OP_READ) + } + } + + private fun TcpPipe.clean() { + kotlin.runCatching { + if (remoteSocketChannel.isOpen) { + remoteSocketChannel.close() + } + remoteOutBuffer = null + pipeMap.remove(tunnelKey) + }.exceptionOrNull()?.printStackTrace() + } + + private fun TcpPipe.closeUpStream() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { + kotlin.runCatching { + if (remoteSocketChannel.isOpen && remoteSocketChannel.isConnected) { + remoteSocketChannel.shutdownOutput() + } + }.exceptionOrNull()?.printStackTrace() + upActive = false + + if (!downActive) { + clean() + } + } else { + upActive = false + downActive = false + clean() + } + } + + private fun TcpPipe.closeDownStream() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { + kotlin.runCatching { + if (remoteSocketChannel.isConnected) { + remoteSocketChannel.shutdownInput() + val ops = remoteSocketChannelKey.interestOps() and SelectionKey.OP_READ.inv() + remoteSocketChannelKey.interestOps(ops) + } + sendTcpPack(this, (TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte())) + downActive = false + if (!upActive) { + clean() + } + } + } else { + sendTcpPack(this, (TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte())) + upActive = false + downActive = false + clean() + } + } +} diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/IpUtil.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/IpUtil.kt new file mode 100644 index 0000000000..8f01845bfd --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/IpUtil.kt @@ -0,0 +1,102 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn.protocol + +import java.net.InetSocketAddress + +internal object IpUtil { + fun buildUdpPacket(source: InetSocketAddress, dest: InetSocketAddress, ipId: Int): Packet { + val packet = Packet().apply { + isTCP = false + isUDP = true + } + + val ip4Header = Packet.IP4Header().apply { + version = 4 + ihl = 5 + destinationAddress = dest.address + headerChecksum = 0 + headerLength = 20 + identificationAndFlagsAndFragmentOffset = ipId shl 16 or (0x40 shl 8) or 0 + optionsAndPadding = 0 + protocol = Packet.IP4Header.TransportProtocol.UDP + protocolNum = 17 + sourceAddress = source.address + totalLength = 60 + typeOfService = 0 + ttl = 64 + } + + val udpHeader = Packet.UDPHeader().apply { + sourcePort = source.port + destinationPort = dest.port + length = 0 + } + + packet.ip4Header = ip4Header + packet.udpHeader = udpHeader + return packet + } + + fun buildTcpPacket( + source: InetSocketAddress, + dest: InetSocketAddress, + flag: Byte, + ack: Long, + seq: Long, + ipId: Int, + ): Packet { + val packet = Packet().apply { + isTCP = true + isUDP = false + } + + val ip4Header = Packet.IP4Header().apply { + version = 4 + ihl = 5 + destinationAddress = dest.address + headerChecksum = 0 + headerLength = 20 + identificationAndFlagsAndFragmentOffset = ipId shl 16 or (0x40 shl 8) or 0 + optionsAndPadding = 0 + protocol = Packet.IP4Header.TransportProtocol.TCP + protocolNum = 6 + sourceAddress = source.address + totalLength = 60 + typeOfService = 0 + ttl = 64 + } + + val tcpHeader = Packet.TCPHeader().apply { + acknowledgementNumber = ack + checksum = 0 + dataOffsetAndReserved = -96 + destinationPort = dest.port + flags = flag + headerLength = 40 + optionsAndPadding = null + sequenceNumber = seq + sourcePort = source.port + urgentPointer = 0 + window = 65535 + } + + packet.ip4Header = ip4Header + packet.tcpHeader = tcpHeader + return packet + } +} diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt new file mode 100644 index 0000000000..4b7b982135 --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt @@ -0,0 +1,499 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn.protocol + +import com.merxury.blocker.core.vpn.protocol.Packet.IP4Header.TransportProtocol.TCP +import com.merxury.blocker.core.vpn.protocol.Packet.IP4Header.TransportProtocol.UDP +import java.net.InetAddress +import java.net.UnknownHostException +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicLong + +/** + * Representation of an IP Packet + */ +internal class Packet { + companion object { + const val IP4_HEADER_SIZE = 20 + const val TCP_HEADER_SIZE = 20 + const val UDP_HEADER_SIZE = 8 + + val globalPackId = AtomicLong() + } + + var ip4Header: IP4Header? = null + var tcpHeader: TCPHeader? = null + var udpHeader: UDPHeader? = null + var backingBuffer: ByteBuffer? = null + + var isTCP = false + + var isUDP = false + + init { + globalPackId.incrementAndGet() + } + + constructor() + + @Throws(UnknownHostException::class) + constructor(buffer: ByteBuffer) : this() { + ip4Header = IP4Header(buffer) + when (ip4Header?.protocol) { + TCP -> { + tcpHeader = TCPHeader(buffer) + isTCP = true + } + + UDP -> { + udpHeader = UDPHeader(buffer) + isUDP = true + } + + else -> {} + } + backingBuffer = buffer + } + + fun release() { + ip4Header = null + tcpHeader = null + udpHeader = null + backingBuffer = null + } + + override fun toString(): String { + return buildString { + append("Packet{") + append("ip4Header=").append(ip4Header) + if (isTCP) { + append(", tcpHeader=").append(tcpHeader) + } else if (isUDP) { + append(", udpHeader=").append(udpHeader) + } + append(", payloadSize=").append( + backingBuffer?.limit()?.minus(backingBuffer?.position() ?: 0), + ) + append('}') + } + } + + fun updateTCPBuffer( + buffer: ByteBuffer, + flags: Byte, + sequenceNum: Long, + ackNum: Long, + payloadSize: Int, + ) { + buffer.position(0) + fillHeader(buffer) + backingBuffer = buffer + + tcpHeader?.apply { + this.flags = flags + backingBuffer?.put(IP4_HEADER_SIZE + 13, flags) + + this.sequenceNumber = sequenceNum + backingBuffer?.putInt(IP4_HEADER_SIZE + 4, sequenceNum.toInt()) + + this.acknowledgementNumber = ackNum + backingBuffer?.putInt(IP4_HEADER_SIZE + 8, ackNum.toInt()) + + // Reset header size, since we don't need options + val dataOffset = (TCP_HEADER_SIZE shl 2).toByte() + this.dataOffsetAndReserved = dataOffset + backingBuffer?.put(IP4_HEADER_SIZE + 12, dataOffset) + + updateTCPChecksum(payloadSize) + + val ip4TotalLength = IP4_HEADER_SIZE + TCP_HEADER_SIZE + payloadSize + backingBuffer?.putShort(2, ip4TotalLength.toShort()) + ip4Header?.totalLength = ip4TotalLength + + updateIP4Checksum() + } + } + + fun updateUDPBuffer(buffer: ByteBuffer, payloadSize: Int) { + buffer.position(0) + fillHeader(buffer) + backingBuffer = buffer + + udpHeader?.apply { + val udpTotalLength = UDP_HEADER_SIZE + payloadSize + backingBuffer?.putShort(IP4_HEADER_SIZE + 4, udpTotalLength.toShort()) + this.length = udpTotalLength + + // Disable UDP checksum validation + backingBuffer?.putShort(IP4_HEADER_SIZE + 6, 0.toShort()) + this.checksum = 0 + + val ip4TotalLength = IP4_HEADER_SIZE + udpTotalLength + backingBuffer?.putShort(2, ip4TotalLength.toShort()) + ip4Header?.totalLength = ip4TotalLength + + updateIP4Checksum() + } + } + + private fun updateIP4Checksum() { + val buffer = backingBuffer?.duplicate() ?: return + buffer.position(0) + + // Clear previous checksum + buffer.putShort(10, 0.toShort()) + + var ipLength = ip4Header?.headerLength ?: return + var sum = 0 + while (ipLength > 0) { + sum += BitUtils.getUnsignedShort(buffer.short) + ipLength -= 2 + } + while (sum shr 16 > 0) { + sum = (sum and 0xFFFF) + (sum shr 16) + } + + sum = sum.inv() + ip4Header?.headerChecksum = sum + backingBuffer?.putShort(10, sum.toShort()) + } + + private fun updateTCPChecksum(payloadSize: Int) { + var sum = 0 + var tcpLength = TCP_HEADER_SIZE + payloadSize + + // Calculate pseudo-header checksum + ip4Header?.sourceAddress?.address?.let { sourceAddress -> + val buffer = ByteBuffer.wrap(sourceAddress) + sum = BitUtils.getUnsignedShort(buffer.short) + BitUtils.getUnsignedShort(buffer.short) + } + + ip4Header?.destinationAddress?.address?.let { destinationAddress -> + val buffer = ByteBuffer.wrap(destinationAddress) + sum += BitUtils.getUnsignedShort(buffer.short) + BitUtils.getUnsignedShort(buffer.short) + } + + sum += TCP.number + tcpLength + + val buffer = backingBuffer?.duplicate() ?: return + // Clear previous checksum + buffer.putShort(IP4_HEADER_SIZE + 16, 0.toShort()) + + // Calculate TCP segment checksum + buffer.position(IP4_HEADER_SIZE) + while (tcpLength > 1) { + sum += BitUtils.getUnsignedShort(buffer.short) + tcpLength -= 2 + } + if (tcpLength > 0) { + sum += BitUtils.getUnsignedByte(buffer.get()).toInt() shl 8 + } + + while (sum shr 16 > 0) { + sum = (sum and 0xFFFF) + (sum shr 16) + } + + sum = sum.inv() + tcpHeader?.checksum = sum + backingBuffer?.putShort(IP4_HEADER_SIZE + 16, sum.toShort()) + } + + private fun fillHeader(buffer: ByteBuffer) { + ip4Header?.fillHeader(buffer) + if (isUDP) { + udpHeader?.fillHeader(buffer) + } else if (isTCP) { + tcpHeader?.fillHeader(buffer) + } + } + + class IP4Header { + var version: Byte = 0 + var ihl: Byte = 0 + var headerLength: Int = 0 + var typeOfService: Short = 0 + var totalLength: Int = 0 + + var identificationAndFlagsAndFragmentOffset: Int = 0 + + var ttl: Short = 0 + var protocolNum: Short = 0 + var protocol: TransportProtocol? = null + var headerChecksum: Int = 0 + + var sourceAddress: InetAddress? = null + var destinationAddress: InetAddress? = null + + var optionsAndPadding: Int = 0 + + enum class TransportProtocol(val number: Int) { + TCP(6), + UDP(17), + OTHER(0xFF), + ; + + companion object { + fun numberToEnum(protocolNumber: Int): TransportProtocol { + return when (protocolNumber) { + 6 -> TCP + 17 -> UDP + else -> OTHER + } + } + } + } + + constructor() + + @Throws(UnknownHostException::class) + constructor(buffer: ByteBuffer) { + val versionAndIHL = buffer.get() + version = (versionAndIHL.toInt() shr 4).toByte() + ihl = (versionAndIHL.toInt() and 0x0F).toByte() + headerLength = ihl.toInt() shl 2 + + typeOfService = BitUtils.getUnsignedByte(buffer.get()) + totalLength = BitUtils.getUnsignedShort(buffer.short) + + identificationAndFlagsAndFragmentOffset = buffer.int + + ttl = BitUtils.getUnsignedByte(buffer.get()) + protocolNum = BitUtils.getUnsignedByte(buffer.get()) + protocol = + com.merxury.blocker.core.vpn.protocol.Packet.IP4Header.TransportProtocol.numberToEnum( + protocolNum.toInt(), + ) + headerChecksum = BitUtils.getUnsignedShort(buffer.short) + + val addressBytes = ByteArray(4) + buffer.get(addressBytes, 0, 4) + sourceAddress = InetAddress.getByAddress(addressBytes) + + buffer.get(addressBytes, 0, 4) + destinationAddress = InetAddress.getByAddress(addressBytes) + } + + fun fillHeader(buffer: ByteBuffer) { + buffer.put((version.toInt() shl 4 or ihl.toInt()).toByte()) + buffer.put(typeOfService.toByte()) + buffer.putShort(totalLength.toShort()) + + buffer.putInt(identificationAndFlagsAndFragmentOffset) + + buffer.put(ttl.toByte()) + buffer.put(protocol?.number?.toByte() ?: 0) + buffer.putShort(headerChecksum.toShort()) + + sourceAddress?.address?.let { buffer.put(it) } + destinationAddress?.address?.let { buffer.put(it) } + } + + override fun toString(): String { + return buildString { + append("IP4Header{") + append("version=").append(version) + append(", IHL=").append(ihl) + append(", typeOfService=").append(typeOfService) + append(", totalLength=").append(totalLength) + append(", identificationAndFlagsAndFragmentOffset=").append( + identificationAndFlagsAndFragmentOffset, + ) + append(", TTL=").append(ttl) + append(", protocol=").append(protocolNum).append(":").append(protocol) + append(", headerChecksum=").append(headerChecksum) + append(", sourceAddress=").append(sourceAddress?.hostAddress) + append(", destinationAddress=").append(destinationAddress?.hostAddress) + append('}') + } + } + } + + class TCPHeader { + companion object { + const val FIN = 0x01 + const val SYN = 0x02 + const val RST = 0x04 + const val PSH = 0x08 + const val ACK = 0x10 + const val URG = 0x20 + } + + var sourcePort: Int = 0 + var destinationPort: Int = 0 + + var sequenceNumber: Long = 0 + var acknowledgementNumber: Long = 0 + + var dataOffsetAndReserved: Byte = 0 + var headerLength: Int = 0 + var flags: Byte = 0 + var window: Int = 0 + + var checksum: Int = 0 + var urgentPointer: Int = 0 + + var optionsAndPadding: ByteArray? = null + + constructor(buffer: ByteBuffer) { + sourcePort = BitUtils.getUnsignedShort(buffer.short) + destinationPort = BitUtils.getUnsignedShort(buffer.short) + + sequenceNumber = BitUtils.getUnsignedInt(buffer.int) + acknowledgementNumber = BitUtils.getUnsignedInt(buffer.int) + + dataOffsetAndReserved = buffer.get() + headerLength = (dataOffsetAndReserved.toInt() and 0xF0) shr 2 + flags = buffer.get() + window = BitUtils.getUnsignedShort(buffer.short) + + checksum = BitUtils.getUnsignedShort(buffer.short) + urgentPointer = BitUtils.getUnsignedShort(buffer.short) + + val optionsLength = headerLength - TCP_HEADER_SIZE + if (optionsLength > 0) { + optionsAndPadding = ByteArray(optionsLength) + optionsAndPadding?.let { + buffer.get(it, 0, optionsLength) + } + } + } + + constructor() + + val isFIN: Boolean + get() = (flags.toInt() and FIN) == FIN + + val isSYN: Boolean + get() = (flags.toInt() and SYN) == SYN + + val isRST: Boolean + get() = (flags.toInt() and RST) == RST + + val isPSH: Boolean + get() = (flags.toInt() and PSH) == PSH + + val isACK: Boolean + get() = (flags.toInt() and ACK) == ACK + + val isURG: Boolean + get() = (flags.toInt() and URG) == URG + + fun fillHeader(buffer: ByteBuffer) { + buffer.putShort(sourcePort.toShort()) + buffer.putShort(destinationPort.toShort()) + + buffer.putInt(sequenceNumber.toInt()) + buffer.putInt(acknowledgementNumber.toInt()) + + buffer.put(dataOffsetAndReserved) + buffer.put(flags) + buffer.putShort(window.toShort()) + + buffer.putShort(checksum.toShort()) + buffer.putShort(urgentPointer.toShort()) + + optionsAndPadding?.let { + buffer.put(it) + } + } + + fun printSimple(): String { + return buildString { + if (isFIN) append("FIN ") + if (isSYN) append("SYN ") + if (isRST) append("RST ") + if (isPSH) append("PSH ") + if (isACK) append("ACK ") + if (isURG) append("URG ") + append("seq $sequenceNumber ") + append("ack $acknowledgementNumber ") + } + } + + override fun toString(): String { + return buildString { + append("TCPHeader{") + append("sourcePort=").append(sourcePort) + append(", destinationPort=").append(destinationPort) + append(", sequenceNumber=").append(sequenceNumber) + append(", acknowledgementNumber=").append(acknowledgementNumber) + append(", headerLength=").append(headerLength) + append(", window=").append(window) + append(", checksum=").append(checksum) + append(", flags=") + if (isFIN) append(" FIN") + if (isSYN) append(" SYN") + if (isRST) append(" RST") + if (isPSH) append(" PSH") + if (isACK) append(" ACK") + if (isURG) append(" URG") + append('}') + } + } + } + + class UDPHeader { + var sourcePort: Int = 0 + var destinationPort: Int = 0 + + var length: Int = 0 + var checksum: Int = 0 + + constructor() + + constructor(buffer: ByteBuffer) { + sourcePort = BitUtils.getUnsignedShort(buffer.short) + destinationPort = BitUtils.getUnsignedShort(buffer.short) + + length = BitUtils.getUnsignedShort(buffer.short) + checksum = BitUtils.getUnsignedShort(buffer.short) + } + + fun fillHeader(buffer: ByteBuffer) { + buffer.putShort(sourcePort.toShort()) + buffer.putShort(destinationPort.toShort()) + + buffer.putShort(length.toShort()) + buffer.putShort(checksum.toShort()) + } + + override fun toString(): String { + return buildString { + append("UDPHeader{") + append("sourcePort=").append(sourcePort) + append(", destinationPort=").append(destinationPort) + append(", length=").append(length) + append(", checksum=").append(checksum) + append('}') + } + } + } + + private object BitUtils { + fun getUnsignedByte(value: Byte): Short { + return (value.toInt() and 0xFF).toShort() + } + + fun getUnsignedShort(value: Short): Int { + return value.toInt() and 0xFFFF + } + + fun getUnsignedInt(value: Int): Long { + return value.toLong() and 0xFFFFFFFFL + } + } +} diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/TcbStatus.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/TcbStatus.kt new file mode 100644 index 0000000000..e9e4123fdc --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/TcbStatus.kt @@ -0,0 +1,26 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn.protocol + +internal enum class TcbStatus { + SYN_SENT, + SYN_RECEIVED, + ESTABLISHED, + CLOSE_WAIT, + LAST_ACK, + CLOSED, +} From 42d3b6f9e4d9f0bde0e75be116fac91881b52aab Mon Sep 17 00:00:00 2001 From: lihenggui Date: Mon, 1 Jul 2024 14:21:02 -0700 Subject: [PATCH 5/9] Suppress deprecation Change-Id: I23247e396200748025b8a936e2db272ee7d9529e --- .../kotlin/com/merxury/blocker/core/vpn/BlockerVpnService.kt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/BlockerVpnService.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/BlockerVpnService.kt index 0d3d976908..2636acd46b 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/BlockerVpnService.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/BlockerVpnService.kt @@ -84,7 +84,9 @@ class BlockerVpnService : VpnService() { ToDeviceQueueWorker.stop() vpnInterface?.close() vpnInterface = null + if (Build.VERSION.SDK_INT <= Build.VERSION_CODES.N) { + @Suppress("DEPRECATION") stopForeground(true) } else { stopForeground(STOP_FOREGROUND_REMOVE) From 7e1ae3ea38352245c9fb426c8980187265b224c7 Mon Sep 17 00:00:00 2001 From: lihenggui <350699171@qq.com> Date: Mon, 1 Jul 2024 18:04:40 -0700 Subject: [PATCH 6/9] Rewrite VpnQueue.kt --- .../blocker/core/vpn/BlockerVpnService.kt | 43 +- .../com/merxury/blocker/core/vpn/VpnQueue.kt | 834 +----------------- .../core/vpn/model/ManagedDatagramChannel.kt | 25 + .../blocker/core/vpn/model/UdpTunnel.kt | 30 + .../blocker/core/vpn/protocol/Packet.kt | 150 ++-- .../blocker/core/vpn/protocol/TcpPipe.kt | 458 ++++++++++ .../blocker/core/vpn/worker/TcpWorker.kt | 416 +++++++++ .../core/vpn/worker/ToDeviceQueueWorker.kt | 74 ++ .../core/vpn/worker/ToNetworkQueueWorker.kt | 93 ++ .../core/vpn/worker/UdpReceiveWorker.kt | 120 +++ .../blocker/core/vpn/worker/UdpSendWorker.kt | 130 +++ .../core/vpn/worker/UdpSocketCleanWorker.kt | 74 ++ 12 files changed, 1519 insertions(+), 928 deletions(-) create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/ManagedDatagramChannel.kt create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/UdpTunnel.kt create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/TcpPipe.kt create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/TcpWorker.kt create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToDeviceQueueWorker.kt create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToNetworkQueueWorker.kt create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpReceiveWorker.kt create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSendWorker.kt create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSocketCleanWorker.kt diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/BlockerVpnService.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/BlockerVpnService.kt index 2636acd46b..1dee7bd329 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/BlockerVpnService.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/BlockerVpnService.kt @@ -22,6 +22,12 @@ import android.os.ParcelFileDescriptor import com.merxury.blocker.core.di.ApplicationScope import com.merxury.blocker.core.dispatchers.BlockerDispatchers.IO import com.merxury.blocker.core.dispatchers.Dispatcher +import com.merxury.blocker.core.vpn.worker.TcpWorker +import com.merxury.blocker.core.vpn.worker.ToDeviceQueueWorker +import com.merxury.blocker.core.vpn.worker.ToNetworkQueueWorker +import com.merxury.blocker.core.vpn.worker.UdpReceiveWorker +import com.merxury.blocker.core.vpn.worker.UdpSendWorker +import com.merxury.blocker.core.vpn.worker.UdpSocketCleanWorker import dagger.hilt.android.AndroidEntryPoint import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineScope @@ -39,23 +45,36 @@ class BlockerVpnService : VpnService() { lateinit var ioDispatcher: CoroutineDispatcher private var vpnInterface: ParcelFileDescriptor? = null + private lateinit var toNetworkQueueWorker: ToNetworkQueueWorker + private lateinit var toDeviceQueueWorker: ToDeviceQueueWorker + private lateinit var udpSendWorker: UdpSendWorker + private lateinit var udpReceiveWorker: UdpReceiveWorker + private lateinit var udpSocketCleanWorker: UdpSocketCleanWorker + private lateinit var tcpWorker: TcpWorker override fun onCreate() { super.onCreate() - UdpSendWorker.start(this) - UdpReceiveWorker.start(this) - UdpSocketCleanWorker.start() - TcpWorker.start(this) + toNetworkQueueWorker = ToNetworkQueueWorker(ioDispatcher) + toDeviceQueueWorker = ToDeviceQueueWorker(ioDispatcher) + udpSendWorker = UdpSendWorker(ioDispatcher) + udpReceiveWorker = UdpReceiveWorker(ioDispatcher) + udpSocketCleanWorker = UdpSocketCleanWorker(ioDispatcher) + tcpWorker = TcpWorker(ioDispatcher) + + udpSendWorker.start(this) + udpReceiveWorker.start() + udpSocketCleanWorker.start() + tcpWorker.start(this) startVpn() } override fun onDestroy() { super.onDestroy() disconnect() - UdpSendWorker.stop() - UdpReceiveWorker.stop() - UdpSocketCleanWorker.stop() - TcpWorker.stop() + udpSendWorker.stop() + udpReceiveWorker.stop() + udpSocketCleanWorker.stop() + tcpWorker.stop() vpnInterface?.close() vpnInterface = null } @@ -75,13 +94,13 @@ class BlockerVpnService : VpnService() { private fun runVpn(vpnInterface: ParcelFileDescriptor) { val fileDescriptor = vpnInterface.fileDescriptor - ToNetworkQueueWorker.start(fileDescriptor) - ToDeviceQueueWorker.start(fileDescriptor) + toNetworkQueueWorker.start(fileDescriptor) + toDeviceQueueWorker.start(fileDescriptor) } private fun disconnect() { - ToNetworkQueueWorker.stop() - ToDeviceQueueWorker.stop() + toNetworkQueueWorker.stop() + toDeviceQueueWorker.stop() vpnInterface?.close() vpnInterface = null diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/VpnQueue.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/VpnQueue.kt index 7b4e902c34..6bcfb116d4 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/VpnQueue.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/VpnQueue.kt @@ -16,32 +16,12 @@ package com.merxury.blocker.core.vpn -import android.annotation.SuppressLint -import android.net.VpnService -import android.os.Build -import android.util.Base64 -import com.merxury.blocker.core.vpn.protocol.IpUtil +import com.merxury.blocker.core.vpn.model.ManagedDatagramChannel +import com.merxury.blocker.core.vpn.model.UdpTunnel import com.merxury.blocker.core.vpn.protocol.Packet -import com.merxury.blocker.core.vpn.protocol.Packet.TCPHeader -import com.merxury.blocker.core.vpn.protocol.TcbStatus -import timber.log.Timber -import java.io.FileDescriptor -import java.io.FileInputStream -import java.io.FileOutputStream -import java.io.IOException -import java.net.ConnectException -import java.net.InetSocketAddress import java.nio.ByteBuffer -import java.nio.channels.ClosedByInterruptException -import java.nio.channels.DatagramChannel -import java.nio.channels.FileChannel -import java.nio.channels.SelectionKey import java.nio.channels.Selector -import java.nio.channels.SocketChannel import java.util.concurrent.ArrayBlockingQueue -import java.util.concurrent.atomic.AtomicInteger -import kotlin.experimental.and -import kotlin.experimental.or /** * Queue for UDP packets sent from device to network @@ -80,813 +60,3 @@ internal val udpNioSelector: Selector = Selector.open() internal val udpSocketMap = HashMap() const val UDP_SOCKET_IDLE_TIMEOUT = 60 - -/** - * Worker thread to handle packets sent from device to network - */ -object ToNetworkQueueWorker : Runnable { - private const val TAG = "ToNetworkQueueWorker" - - /** - * Self thread - */ - private lateinit var thread: Thread - - /** - * Channel to read data from the device - */ - private lateinit var vpnInput: FileChannel - - /** - * Total bytes read count - */ - var totalInputCount = 0L - - fun start(vpnFileDescriptor: FileDescriptor) { - if (this::thread.isInitialized && thread.isAlive) throw IllegalStateException("Already running") - vpnInput = FileInputStream(vpnFileDescriptor).channel - thread = Thread(this).apply { - name = TAG - start() - } - } - - fun stop() { - if (this::thread.isInitialized) { - thread.interrupt() - } - } - - override fun run() { - val readBuffer = ByteBuffer.allocate(16384) - while (!thread.isInterrupted) { - var readCount = 0 - try { - readCount = vpnInput.read(readBuffer) - } catch (e: IOException) { - e.printStackTrace() - continue - } - if (readCount > 0) { - readBuffer.flip() - val byteArray = ByteArray(readCount) - readBuffer.get(byteArray) - - val byteBuffer = ByteBuffer.wrap(byteArray) - totalInputCount += readCount - - val packet = Packet(byteBuffer) - if (packet.isUDP) { - deviceToNetworkUDPQueue.offer(packet) - } else if (packet.isTCP) { - deviceToNetworkTCPQueue.offer(packet) - } else { - Timber.d("Unknown packet protocol type ${packet.ip4Header?.protocolNum}") - } - } else if (readCount < 0) { - break - } - readBuffer.clear() - } - Timber.i("ToNetworkQueueWorker finished running") - } -} - -/** - * Worker thread to handle packets sent from network to device - */ -object ToDeviceQueueWorker : Runnable { - private const val TAG = "ToDeviceQueueWorker" - - /** - * Self thread - */ - private lateinit var thread: Thread - - /** - * Total bytes written count - */ - var totalOutputCount = 0L - - /** - * Channel to write data to the device - */ - private lateinit var vpnOutput: FileChannel - - fun start(vpnFileDescriptor: FileDescriptor) { - if (this::thread.isInitialized && thread.isAlive) throw IllegalStateException("Already running") - vpnOutput = FileOutputStream(vpnFileDescriptor).channel - thread = Thread(this).apply { - name = TAG - start() - } - } - - fun stop() { - if (this::thread.isInitialized) { - thread.interrupt() - } - } - - override fun run() { - try { - while (!thread.isInterrupted) { - val byteBuffer = networkToDeviceQueue.take() - byteBuffer.flip() - while (byteBuffer.hasRemaining()) { - val count = vpnOutput.write(byteBuffer) - if (count > 0) { - totalOutputCount += count - } - } - } - } catch (_: InterruptedException) { - } catch (_: ClosedByInterruptException) { - } - } -} - -/** - * UDP forwarding channel data - */ -data class UdpTunnel( - val id: String, - val local: InetSocketAddress, - val remote: InetSocketAddress, - val channel: DatagramChannel, -) - -data class ManagedDatagramChannel( - val id: String, - val channel: DatagramChannel, - var lastTime: Long = System.currentTimeMillis(), -) - -/** - * Worker thread to send UDP packets - */ -@SuppressLint("StaticFieldLeak") -object UdpSendWorker : Runnable { - private const val TAG = "UdpSendWorker" - - /** - * Self thread - */ - private lateinit var thread: Thread - - private var vpnService: VpnService? = null - - fun start(vpnService: VpnService) { - this.vpnService = vpnService - udpTunnelQueue.clear() - thread = Thread(this).apply { - name = TAG - start() - } - } - - fun stop() { - if (this::thread.isInitialized) { - thread.interrupt() - } - vpnService = null - } - - override fun run() { - while (!thread.isInterrupted) { - val packet = deviceToNetworkUDPQueue.take() - - val destinationAddress = packet.ip4Header?.destinationAddress - val udpHeader = packet.udpHeader - - val destinationPort = udpHeader?.destinationPort ?: 0 - val sourcePort = udpHeader?.sourcePort - val ipAndPort = ( - destinationAddress?.hostAddress?.plus(":") - ?: "unknownHostAddress" - ) + destinationPort + ":" + sourcePort - - // Create new socket - val managedChannel = if (!udpSocketMap.containsKey(ipAndPort)) { - val channel = DatagramChannel.open() - var channelConnectSuccess = false - channel.apply { - val socket = socket() - vpnService?.protect(socket) - try { - connect(InetSocketAddress(destinationAddress, destinationPort)) - channelConnectSuccess = true - } catch (_: ConnectException) { - } - configureBlocking(false) - } - if (!channelConnectSuccess) { - continue - } - - val tunnel = UdpTunnel( - ipAndPort, - InetSocketAddress(packet.ip4Header?.sourceAddress, udpHeader?.sourcePort ?: 0), - InetSocketAddress( - packet.ip4Header?.destinationAddress, - udpHeader?.destinationPort ?: 0, - ), - channel, - ) - udpTunnelQueue.offer(tunnel) - udpNioSelector.wakeup() - - val managedDatagramChannel = ManagedDatagramChannel(ipAndPort, channel) - synchronized(udpSocketMap) { - udpSocketMap[ipAndPort] = managedDatagramChannel - } - managedDatagramChannel - } else { - synchronized(udpSocketMap) { - udpSocketMap[ipAndPort] - ?: throw IllegalStateException("udp:udpSocketMap[$ipAndPort] should not be null") - } - } - managedChannel.lastTime = System.currentTimeMillis() - val buffer = packet.backingBuffer - kotlin.runCatching { - while (!thread.isInterrupted && buffer?.hasRemaining() == true) { - managedChannel.channel.write(buffer) - } - }.exceptionOrNull()?.let { - Timber.e("Error sending UDP packet", it) - managedChannel.channel.close() - synchronized(udpSocketMap) { - udpSocketMap.remove(ipAndPort) - } - } - } - } -} - -/** - * Worker thread to receive UDP packets - */ -@SuppressLint("StaticFieldLeak") -object UdpReceiveWorker : Runnable { - - private const val TAG = "UdpReceiveWorker" - - /** - * Self thread - */ - private lateinit var thread: Thread - - private var vpnService: VpnService? = null - - private var ipId = AtomicInteger() - - private const val UDP_HEADER_FULL_SIZE = Packet.IP4_HEADER_SIZE + Packet.UDP_HEADER_SIZE - - fun start(vpnService: VpnService) { - this.vpnService = vpnService - thread = Thread(this).apply { - name = TAG - start() - } - } - - fun stop() { - thread.interrupt() - } - - private fun sendUdpPacket(tunnel: UdpTunnel, source: InetSocketAddress, data: ByteArray) { - val packet = IpUtil.buildUdpPacket(tunnel.remote, tunnel.local, ipId.addAndGet(1)) - - val byteBuffer = ByteBuffer.allocate(UDP_HEADER_FULL_SIZE + data.size) - byteBuffer.apply { - position(UDP_HEADER_FULL_SIZE) - put(data) - } - packet.updateUDPBuffer(byteBuffer, data.size) - byteBuffer.position(UDP_HEADER_FULL_SIZE + data.size) - networkToDeviceQueue.offer(byteBuffer) - } - - override fun run() { - val receiveBuffer = ByteBuffer.allocate(16384) - while (!thread.isInterrupted) { - val readyChannels = udpNioSelector.select() - while (!thread.isInterrupted) { - val tunnel = udpTunnelQueue.poll() ?: break - kotlin.runCatching { - val key = tunnel.channel.register(udpNioSelector, SelectionKey.OP_READ, tunnel) - key.interestOps(SelectionKey.OP_READ) - }.exceptionOrNull()?.printStackTrace() - } - if (readyChannels == 0) { - udpNioSelector.selectedKeys().clear() - continue - } - val keys = udpNioSelector.selectedKeys() - val iterator = keys.iterator() - while (!thread.isInterrupted && iterator.hasNext()) { - val key = iterator.next() - iterator.remove() - if (key.isValid && key.isReadable) { - val tunnel = key.attachment() as UdpTunnel - kotlin.runCatching { - val inputChannel = key.channel() as DatagramChannel - receiveBuffer.clear() - inputChannel.read(receiveBuffer) - receiveBuffer.flip() - val data = ByteArray(receiveBuffer.remaining()) - receiveBuffer.get(data) - sendUdpPacket( - tunnel, - inputChannel.socket().localSocketAddress as InetSocketAddress, - data, - ) // todo api 21->24 - }.exceptionOrNull()?.let { - it.printStackTrace() - synchronized(udpSocketMap) { - udpSocketMap.remove(tunnel.id) - } - } - } - } - } - } -} - -/** - * Worker thread to clean up expired UDP sockets - */ -object UdpSocketCleanWorker : Runnable { - - private const val TAG = "UdpSocketCleanWorker" - - /** - * Self thread - */ - private lateinit var thread: Thread - - /** - * Check interval in seconds - */ - private const val INTERVAL_TIME = 5L - - fun start() { - thread = Thread(this).apply { - name = TAG - start() - } - } - - fun stop() { - thread.interrupt() - } - - override fun run() { - while (!thread.isInterrupted) { - synchronized(udpSocketMap) { - val iterator = udpSocketMap.iterator() - var removeCount = 0 - while (!thread.isInterrupted && iterator.hasNext()) { - val managedDatagramChannel = iterator.next() - if (System.currentTimeMillis() - managedDatagramChannel.value.lastTime > UDP_SOCKET_IDLE_TIMEOUT * 1000) { - kotlin.runCatching { - managedDatagramChannel.value.channel.close() - }.exceptionOrNull()?.printStackTrace() - iterator.remove() - removeCount++ - } - } - if (removeCount > 0) { - Timber.d("Removed $removeCount expired inactive UDP sockets, currently active ${udpSocketMap.size}") - } - } - Thread.sleep(INTERVAL_TIME * 1000) - } - } -} - -internal class TcpPipe(val tunnelKey: String, packet: Packet) { - var mySequenceNum: Long = 0 - var theirSequenceNum: Long = 0 - var myAcknowledgementNum: Long = 0 - var theirAcknowledgementNum: Long = 0 - val tunnelId = tunnelIds++ - - val sourceAddress: InetSocketAddress = - InetSocketAddress(packet.ip4Header?.sourceAddress, packet.tcpHeader?.sourcePort ?: 0) - val destinationAddress: InetSocketAddress = InetSocketAddress( - packet.ip4Header?.destinationAddress, - packet.tcpHeader?.destinationPort ?: 0, - ) - val remoteSocketChannel: SocketChannel = - SocketChannel.open().also { it.configureBlocking(false) } - val remoteSocketChannelKey: SelectionKey = - remoteSocketChannel.register(tcpNioSelector, SelectionKey.OP_CONNECT) - .also { it.attach(this) } - - var tcbStatus: TcbStatus = TcbStatus.SYN_SENT - var remoteOutBuffer: ByteBuffer? = null - - var upActive = true - var downActive = true - var packId = 1 - var timestamp = System.currentTimeMillis() - var synCount = 0 - - fun tryConnect(vpnService: VpnService): Result { - val result = kotlin.runCatching { - vpnService.protect(remoteSocketChannel.socket()) - remoteSocketChannel.connect(destinationAddress) - } - return result - } - - companion object { - const val TAG = "TcpPipe" - var tunnelIds = 0 - } -} - -/** - * TCP packet worker thread - * NIO - */ -@SuppressLint("StaticFieldLeak") -object TcpWorker : Runnable { - private const val TAG = "TcpSendWorker" - - private const val TCP_HEADER_SIZE = Packet.IP4_HEADER_SIZE + Packet.TCP_HEADER_SIZE - - private lateinit var thread: Thread - - private val pipeMap = HashMap() - - private var vpnService: VpnService? = null - - fun start(vpnService: VpnService) { - this.vpnService = vpnService - thread = Thread(this).apply { - name = TAG - start() - } - } - - fun stop() { - thread.interrupt() - vpnService = null - } - - override fun run() { - while (!thread.isInterrupted) { - if (vpnService == null) { - throw IllegalStateException("VpnService should not be null") - } - handleReadFromVpn() - handleSockets() - - Thread.sleep(1) - } - } - - private fun handleReadFromVpn() { - while (!thread.isInterrupted) { - val vpnService = this.vpnService ?: return - val packet = deviceToNetworkTCPQueue.poll() ?: return - val destinationAddress = packet.ip4Header?.destinationAddress - val tcpHeader = packet.tcpHeader - val destinationPort = tcpHeader?.destinationPort - val sourcePort = tcpHeader?.sourcePort - - val ipAndPort = ( - destinationAddress?.hostAddress?.plus(":") - ?: "unknown-host-address" - ) + destinationPort + ":" + sourcePort - - val tcpPipe = if (!pipeMap.containsKey(ipAndPort)) { - val pipe = TcpPipe(ipAndPort, packet) - pipe.tryConnect(vpnService) - pipeMap[ipAndPort] = pipe - pipe - } else { - pipeMap[ipAndPort] - ?: throw IllegalStateException("pipeMap should not contain null key: $ipAndPort") - } - handlePacket(packet, tcpPipe) - } - } - - private fun handleSockets() { - while (!thread.isInterrupted && tcpNioSelector.selectNow() > 0) { - val keys = tcpNioSelector.selectedKeys() - val iterator = keys.iterator() - while (!thread.isInterrupted && iterator.hasNext()) { - val key = iterator.next() - iterator.remove() - val tcpPipe: TcpPipe? = key?.attachment() as? TcpPipe - if (key.isValid) { - kotlin.runCatching { - if (key.isAcceptable) { - throw RuntimeException("key.isAcceptable") - } else if (key.isReadable) { - tcpPipe?.doRead() - } else if (key.isConnectable) { - tcpPipe?.doConnect() - } else if (key.isWritable) { - tcpPipe?.doWrite() - } else { - tcpPipe?.closeRst() - } - null - }.exceptionOrNull()?.let { - Timber.d( - "Error communicating with target: ${ - Base64.encodeToString( - tcpPipe?.destinationAddress.toString().toByteArray(), - Base64.DEFAULT, - ) - }", - ) - it.printStackTrace() - tcpPipe?.closeRst() - } - } - } - } - } - - private fun handlePacket(packet: Packet, tcpPipe: TcpPipe) { - val tcpHeader = packet.tcpHeader ?: return - when { - tcpHeader.isSYN -> { - handleSyn(packet, tcpPipe) - } - - tcpHeader.isRST -> { - handleRst(tcpPipe) - } - - tcpHeader.isFIN -> { - handleFin(packet, tcpPipe) - } - - tcpHeader.isACK -> { - handleAck(packet, tcpPipe) - } - } - } - - private fun handleSyn(packet: Packet, tcpPipe: TcpPipe) { - if (tcpPipe.tcbStatus == TcbStatus.SYN_SENT) { - tcpPipe.tcbStatus = TcbStatus.SYN_RECEIVED - } - val tcpHeader = packet.tcpHeader - tcpPipe.apply { - if (synCount == 0) { - mySequenceNum = 1 - theirSequenceNum = tcpHeader?.sequenceNumber ?: 0 - myAcknowledgementNum = tcpHeader?.sequenceNumber?.plus(1) ?: 0 - theirAcknowledgementNum = tcpHeader?.acknowledgementNumber ?: 0 - sendTcpPack(this, TCPHeader.SYN.toByte() or TCPHeader.ACK.toByte()) - } else { - myAcknowledgementNum = tcpHeader?.sequenceNumber?.plus(1) ?: 0 - } - synCount++ - } - } - - private fun handleRst(tcpPipe: TcpPipe) { - tcpPipe.apply { - upActive = false - downActive = false - clean() - tcbStatus = TcbStatus.CLOSE_WAIT - } - } - - private fun handleFin(packet: Packet, tcpPipe: TcpPipe) { - tcpPipe.myAcknowledgementNum = packet.tcpHeader?.sequenceNumber?.plus(1) ?: 0 - tcpPipe.theirAcknowledgementNum = packet.tcpHeader?.acknowledgementNumber?.plus(1) ?: 0 - sendTcpPack(tcpPipe, TCPHeader.ACK.toByte()) - tcpPipe.closeUpStream() - tcpPipe.tcbStatus = TcbStatus.CLOSE_WAIT - } - - private fun handleAck(packet: Packet, tcpPipe: TcpPipe) { - if (tcpPipe.tcbStatus == TcbStatus.SYN_RECEIVED) { - tcpPipe.tcbStatus = TcbStatus.ESTABLISHED - } - - val tcpHeader = packet.tcpHeader - val payloadSize = packet.backingBuffer?.remaining() ?: 0 - - if (payloadSize == 0) { - return - } - - val newAck = tcpHeader?.sequenceNumber?.plus(payloadSize) ?: 0 - if (newAck <= tcpPipe.myAcknowledgementNum) { - return - } - - tcpPipe.apply { - myAcknowledgementNum = tcpHeader?.sequenceNumber?.plus(payloadSize) ?: 0 - theirAcknowledgementNum = tcpHeader?.acknowledgementNumber ?: 0 - remoteOutBuffer = packet.backingBuffer - tryFlushWrite(this) - sendTcpPack(this, TCPHeader.ACK.toByte()) - } - } - - /** - * Send TCP packet - */ - private fun sendTcpPack(tcpPipe: TcpPipe, flag: Byte, data: ByteArray? = null) { - val dataSize = data?.size ?: 0 - - val packet = IpUtil.buildTcpPacket( - tcpPipe.destinationAddress, - tcpPipe.sourceAddress, - flag, - tcpPipe.myAcknowledgementNum, - tcpPipe.mySequenceNum, - tcpPipe.packId, - ) - tcpPipe.packId++ - - val byteBuffer = ByteBuffer.allocate(TCP_HEADER_SIZE + dataSize) - byteBuffer.position(TCP_HEADER_SIZE) - - data?.let { - byteBuffer.put(it) - } - - packet.updateTCPBuffer( - byteBuffer, - flag, - tcpPipe.mySequenceNum, - tcpPipe.myAcknowledgementNum, - dataSize, - ) - packet.release() - - byteBuffer.position(TCP_HEADER_SIZE + dataSize) - - networkToDeviceQueue.offer(byteBuffer) - - if ((flag and TCPHeader.SYN.toByte()) != 0.toByte()) { - tcpPipe.mySequenceNum++ - } - if ((flag and TCPHeader.FIN.toByte()) != 0.toByte()) { - tcpPipe.mySequenceNum++ - } - if ((flag and TCPHeader.ACK.toByte()) != 0.toByte()) { - tcpPipe.mySequenceNum += dataSize - } - } - - /** - * Write data to the remote - */ - private fun tryFlushWrite(tcpPipe: TcpPipe): Boolean { - val channel: SocketChannel = tcpPipe.remoteSocketChannel - val buffer = tcpPipe.remoteOutBuffer - - if (tcpPipe.remoteSocketChannel.socket().isOutputShutdown && buffer?.remaining() != 0) { - sendTcpPack(tcpPipe, TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte()) - buffer?.compact() - return false - } - - if (!channel.isConnected) { - val key = tcpPipe.remoteSocketChannelKey - val ops = key.interestOps() or SelectionKey.OP_WRITE - key.interestOps(ops) - buffer?.compact() - return false - } - - while (!thread.isInterrupted && buffer?.hasRemaining() == true) { - val n = kotlin.runCatching { - channel.write(buffer) - } - if (n.isFailure) return false - if (n.getOrThrow() <= 0) { - val key = tcpPipe.remoteSocketChannelKey - val ops = key.interestOps() or SelectionKey.OP_WRITE - key.interestOps(ops) - buffer.compact() - return false - } - } - buffer?.clear() - if (!tcpPipe.upActive) { - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { - tcpPipe.remoteSocketChannel.shutdownOutput() - } else { - // todo The following line will cause the socket to be incorrectly handled, but what if we don't handle it here? - // tcpPipe.remoteSocketChannel.close() - } - } - return true - } - - private fun TcpPipe.closeRst() { - Timber.d("closeRst $tunnelId") - clean() - sendTcpPack(this, TCPHeader.RST.toByte()) - upActive = false - downActive = false - } - - private fun TcpPipe.doRead() { - val buffer = ByteBuffer.allocate(4096) - var isQuitType = false - - while (!thread.isInterrupted) { - buffer.clear() - val length = remoteSocketChannel.read(buffer) - if (length == -1) { - isQuitType = true - break - } else if (length == 0) { - break - } else { - if (tcbStatus != TcbStatus.CLOSE_WAIT) { - buffer.flip() - val dataByteArray = ByteArray(buffer.remaining()) - buffer.get(dataByteArray) - sendTcpPack(this, TCPHeader.ACK.toByte(), dataByteArray) - } - } - } - - if (isQuitType) { - closeDownStream() - } - } - - private fun TcpPipe.doConnect() { - remoteSocketChannel.finishConnect() - timestamp = System.currentTimeMillis() - remoteOutBuffer?.flip() - remoteSocketChannelKey.interestOps(SelectionKey.OP_READ or SelectionKey.OP_WRITE) - } - - private fun TcpPipe.doWrite() { - if (tryFlushWrite(this)) { - remoteSocketChannelKey.interestOps(SelectionKey.OP_READ) - } - } - - private fun TcpPipe.clean() { - kotlin.runCatching { - if (remoteSocketChannel.isOpen) { - remoteSocketChannel.close() - } - remoteOutBuffer = null - pipeMap.remove(tunnelKey) - }.exceptionOrNull()?.printStackTrace() - } - - private fun TcpPipe.closeUpStream() { - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { - kotlin.runCatching { - if (remoteSocketChannel.isOpen && remoteSocketChannel.isConnected) { - remoteSocketChannel.shutdownOutput() - } - }.exceptionOrNull()?.printStackTrace() - upActive = false - - if (!downActive) { - clean() - } - } else { - upActive = false - downActive = false - clean() - } - } - - private fun TcpPipe.closeDownStream() { - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { - kotlin.runCatching { - if (remoteSocketChannel.isConnected) { - remoteSocketChannel.shutdownInput() - val ops = remoteSocketChannelKey.interestOps() and SelectionKey.OP_READ.inv() - remoteSocketChannelKey.interestOps(ops) - } - sendTcpPack(this, (TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte())) - downActive = false - if (!upActive) { - clean() - } - } - } else { - sendTcpPack(this, (TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte())) - upActive = false - downActive = false - clean() - } - } -} diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/ManagedDatagramChannel.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/ManagedDatagramChannel.kt new file mode 100644 index 0000000000..63f60fdb21 --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/ManagedDatagramChannel.kt @@ -0,0 +1,25 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn.model + +import java.nio.channels.DatagramChannel + +data class ManagedDatagramChannel( + val id: String, + val channel: DatagramChannel, + var lastTime: Long = System.currentTimeMillis(), +) diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/UdpTunnel.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/UdpTunnel.kt new file mode 100644 index 0000000000..fd636abb26 --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/UdpTunnel.kt @@ -0,0 +1,30 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn.model + +import java.net.InetSocketAddress +import java.nio.channels.DatagramChannel + +/** + * UDP forwarding channel data + */ +data class UdpTunnel( + val id: String, + val local: InetSocketAddress, + val remote: InetSocketAddress, + val channel: DatagramChannel, +) diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt index 4b7b982135..bbc7224b74 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt @@ -76,20 +76,18 @@ internal class Packet { backingBuffer = null } - override fun toString(): String { - return buildString { - append("Packet{") - append("ip4Header=").append(ip4Header) - if (isTCP) { - append(", tcpHeader=").append(tcpHeader) - } else if (isUDP) { - append(", udpHeader=").append(udpHeader) - } - append(", payloadSize=").append( - backingBuffer?.limit()?.minus(backingBuffer?.position() ?: 0), - ) - append('}') + override fun toString(): String = buildString { + append("Packet{") + append("ip4Header=").append(ip4Header) + if (isTCP) { + append(", tcpHeader=").append(tcpHeader) + } else if (isUDP) { + append(", udpHeader=").append(udpHeader) } + append(", payloadSize=").append( + backingBuffer?.limit()?.minus(backingBuffer?.position() ?: 0), + ) + append('}') } fun updateTCPBuffer( @@ -247,12 +245,10 @@ internal class Packet { ; companion object { - fun numberToEnum(protocolNumber: Int): TransportProtocol { - return when (protocolNumber) { - 6 -> TCP - 17 -> UDP - else -> OTHER - } + fun numberToEnum(protocolNumber: Int): TransportProtocol = when (protocolNumber) { + 6 -> TCP + 17 -> UDP + else -> OTHER } } } @@ -302,23 +298,21 @@ internal class Packet { destinationAddress?.address?.let { buffer.put(it) } } - override fun toString(): String { - return buildString { - append("IP4Header{") - append("version=").append(version) - append(", IHL=").append(ihl) - append(", typeOfService=").append(typeOfService) - append(", totalLength=").append(totalLength) - append(", identificationAndFlagsAndFragmentOffset=").append( - identificationAndFlagsAndFragmentOffset, - ) - append(", TTL=").append(ttl) - append(", protocol=").append(protocolNum).append(":").append(protocol) - append(", headerChecksum=").append(headerChecksum) - append(", sourceAddress=").append(sourceAddress?.hostAddress) - append(", destinationAddress=").append(destinationAddress?.hostAddress) - append('}') - } + override fun toString(): String = buildString { + append("IP4Header{") + append("version=").append(version) + append(", IHL=").append(ihl) + append(", typeOfService=").append(typeOfService) + append(", totalLength=").append(totalLength) + append(", identificationAndFlagsAndFragmentOffset=").append( + identificationAndFlagsAndFragmentOffset, + ) + append(", TTL=").append(ttl) + append(", protocol=").append(protocolNum).append(":").append(protocol) + append(", headerChecksum=").append(headerChecksum) + append(", sourceAddress=").append(sourceAddress?.hostAddress) + append(", destinationAddress=").append(destinationAddress?.hostAddress) + append('}') } } @@ -411,38 +405,34 @@ internal class Packet { } } - fun printSimple(): String { - return buildString { - if (isFIN) append("FIN ") - if (isSYN) append("SYN ") - if (isRST) append("RST ") - if (isPSH) append("PSH ") - if (isACK) append("ACK ") - if (isURG) append("URG ") - append("seq $sequenceNumber ") - append("ack $acknowledgementNumber ") - } + fun printSimple(): String = buildString { + if (isFIN) append("FIN ") + if (isSYN) append("SYN ") + if (isRST) append("RST ") + if (isPSH) append("PSH ") + if (isACK) append("ACK ") + if (isURG) append("URG ") + append("seq $sequenceNumber ") + append("ack $acknowledgementNumber ") } - override fun toString(): String { - return buildString { - append("TCPHeader{") - append("sourcePort=").append(sourcePort) - append(", destinationPort=").append(destinationPort) - append(", sequenceNumber=").append(sequenceNumber) - append(", acknowledgementNumber=").append(acknowledgementNumber) - append(", headerLength=").append(headerLength) - append(", window=").append(window) - append(", checksum=").append(checksum) - append(", flags=") - if (isFIN) append(" FIN") - if (isSYN) append(" SYN") - if (isRST) append(" RST") - if (isPSH) append(" PSH") - if (isACK) append(" ACK") - if (isURG) append(" URG") - append('}') - } + override fun toString(): String = buildString { + append("TCPHeader{") + append("sourcePort=").append(sourcePort) + append(", destinationPort=").append(destinationPort) + append(", sequenceNumber=").append(sequenceNumber) + append(", acknowledgementNumber=").append(acknowledgementNumber) + append(", headerLength=").append(headerLength) + append(", window=").append(window) + append(", checksum=").append(checksum) + append(", flags=") + if (isFIN) append(" FIN") + if (isSYN) append(" SYN") + if (isRST) append(" RST") + if (isPSH) append(" PSH") + if (isACK) append(" ACK") + if (isURG) append(" URG") + append('}') } } @@ -471,29 +461,21 @@ internal class Packet { buffer.putShort(checksum.toShort()) } - override fun toString(): String { - return buildString { - append("UDPHeader{") - append("sourcePort=").append(sourcePort) - append(", destinationPort=").append(destinationPort) - append(", length=").append(length) - append(", checksum=").append(checksum) - append('}') - } + override fun toString(): String = buildString { + append("UDPHeader{") + append("sourcePort=").append(sourcePort) + append(", destinationPort=").append(destinationPort) + append(", length=").append(length) + append(", checksum=").append(checksum) + append('}') } } private object BitUtils { - fun getUnsignedByte(value: Byte): Short { - return (value.toInt() and 0xFF).toShort() - } + fun getUnsignedByte(value: Byte): Short = (value.toInt() and 0xFF).toShort() - fun getUnsignedShort(value: Short): Int { - return value.toInt() and 0xFFFF - } + fun getUnsignedShort(value: Short): Int = value.toInt() and 0xFFFF - fun getUnsignedInt(value: Int): Long { - return value.toLong() and 0xFFFFFFFFL - } + fun getUnsignedInt(value: Int): Long = value.toLong() and 0xFFFFFFFFL } } diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/TcpPipe.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/TcpPipe.kt new file mode 100644 index 0000000000..05071289fa --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/TcpPipe.kt @@ -0,0 +1,458 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn.protocol + +import android.annotation.SuppressLint +import android.net.VpnService +import android.os.Build +import android.util.Base64 +import com.merxury.blocker.core.vpn.deviceToNetworkTCPQueue +import com.merxury.blocker.core.vpn.networkToDeviceQueue +import com.merxury.blocker.core.vpn.protocol.Packet.TCPHeader +import com.merxury.blocker.core.vpn.tcpNioSelector +import timber.log.Timber +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.SelectionKey +import java.nio.channels.SocketChannel +import kotlin.experimental.and +import kotlin.experimental.or + +internal class TcpPipe(val tunnelKey: String, packet: Packet) { + var mySequenceNum: Long = 0 + var theirSequenceNum: Long = 0 + var myAcknowledgementNum: Long = 0 + var theirAcknowledgementNum: Long = 0 + val tunnelId = tunnelIds++ + + val sourceAddress: InetSocketAddress = + InetSocketAddress(packet.ip4Header?.sourceAddress, packet.tcpHeader?.sourcePort ?: 0) + val destinationAddress: InetSocketAddress = InetSocketAddress( + packet.ip4Header?.destinationAddress, + packet.tcpHeader?.destinationPort ?: 0, + ) + val remoteSocketChannel: SocketChannel = + SocketChannel.open().also { it.configureBlocking(false) } + val remoteSocketChannelKey: SelectionKey = + remoteSocketChannel.register(tcpNioSelector, SelectionKey.OP_CONNECT) + .also { it.attach(this) } + + var tcbStatus: TcbStatus = TcbStatus.SYN_SENT + var remoteOutBuffer: ByteBuffer? = null + + var upActive = true + var downActive = true + var packId = 1 + var timestamp = System.currentTimeMillis() + var synCount = 0 + + fun tryConnect(vpnService: VpnService): Result { + val result = kotlin.runCatching { + vpnService.protect(remoteSocketChannel.socket()) + remoteSocketChannel.connect(destinationAddress) + } + return result + } + + companion object { + const val TAG = "TcpPipe" + var tunnelIds = 0 + } +} + +/** + * TCP packet worker thread + * NIO + */ +@SuppressLint("StaticFieldLeak") +object TcpWorker : Runnable { + private const val TAG = "TcpSendWorker" + + private const val TCP_HEADER_SIZE = Packet.IP4_HEADER_SIZE + Packet.TCP_HEADER_SIZE + + private lateinit var thread: Thread + + private val pipeMap = HashMap() + + private var vpnService: VpnService? = null + + fun start(vpnService: VpnService) { + this.vpnService = vpnService + thread = Thread(this).apply { + name = TAG + start() + } + } + + fun stop() { + thread.interrupt() + vpnService = null + } + + override fun run() { + while (!thread.isInterrupted) { + if (vpnService == null) { + throw IllegalStateException("VpnService should not be null") + } + handleReadFromVpn() + handleSockets() + + Thread.sleep(1) + } + } + + private fun handleReadFromVpn() { + while (!thread.isInterrupted) { + val vpnService = this.vpnService ?: return + val packet = deviceToNetworkTCPQueue.poll() ?: return + val destinationAddress = packet.ip4Header?.destinationAddress + val tcpHeader = packet.tcpHeader + val destinationPort = tcpHeader?.destinationPort + val sourcePort = tcpHeader?.sourcePort + + val ipAndPort = ( + destinationAddress?.hostAddress?.plus(":") + ?: "unknown-host-address" + ) + destinationPort + ":" + sourcePort + + val tcpPipe = if (!pipeMap.containsKey(ipAndPort)) { + val pipe = TcpPipe(ipAndPort, packet) + pipe.tryConnect(vpnService) + pipeMap[ipAndPort] = pipe + pipe + } else { + pipeMap[ipAndPort] + ?: throw IllegalStateException("pipeMap should not contain null key: $ipAndPort") + } + handlePacket(packet, tcpPipe) + } + } + + private fun handleSockets() { + while (!thread.isInterrupted && tcpNioSelector.selectNow() > 0) { + val keys = tcpNioSelector.selectedKeys() + val iterator = keys.iterator() + while (!thread.isInterrupted && iterator.hasNext()) { + val key = iterator.next() + iterator.remove() + val tcpPipe: TcpPipe? = key?.attachment() as? TcpPipe + if (key.isValid) { + kotlin.runCatching { + if (key.isAcceptable) { + throw RuntimeException("key.isAcceptable") + } else if (key.isReadable) { + tcpPipe?.doRead() + } else if (key.isConnectable) { + tcpPipe?.doConnect() + } else if (key.isWritable) { + tcpPipe?.doWrite() + } else { + tcpPipe?.closeRst() + } + null + }.exceptionOrNull()?.let { + Timber.d( + "Error communicating with target: ${ + Base64.encodeToString( + tcpPipe?.destinationAddress.toString().toByteArray(), + Base64.DEFAULT, + ) + }", + ) + it.printStackTrace() + tcpPipe?.closeRst() + } + } + } + } + } + + private fun handlePacket(packet: Packet, tcpPipe: TcpPipe) { + val tcpHeader = packet.tcpHeader ?: return + when { + tcpHeader.isSYN -> { + handleSyn(packet, tcpPipe) + } + + tcpHeader.isRST -> { + handleRst(tcpPipe) + } + + tcpHeader.isFIN -> { + handleFin(packet, tcpPipe) + } + + tcpHeader.isACK -> { + handleAck(packet, tcpPipe) + } + } + } + + private fun handleSyn(packet: Packet, tcpPipe: TcpPipe) { + if (tcpPipe.tcbStatus == TcbStatus.SYN_SENT) { + tcpPipe.tcbStatus = TcbStatus.SYN_RECEIVED + } + val tcpHeader = packet.tcpHeader + tcpPipe.apply { + if (synCount == 0) { + mySequenceNum = 1 + theirSequenceNum = tcpHeader?.sequenceNumber ?: 0 + myAcknowledgementNum = tcpHeader?.sequenceNumber?.plus(1) ?: 0 + theirAcknowledgementNum = tcpHeader?.acknowledgementNumber ?: 0 + sendTcpPack(this, TCPHeader.SYN.toByte() or TCPHeader.ACK.toByte()) + } else { + myAcknowledgementNum = tcpHeader?.sequenceNumber?.plus(1) ?: 0 + } + synCount++ + } + } + + private fun handleRst(tcpPipe: TcpPipe) { + tcpPipe.apply { + upActive = false + downActive = false + clean() + tcbStatus = TcbStatus.CLOSE_WAIT + } + } + + private fun handleFin(packet: Packet, tcpPipe: TcpPipe) { + tcpPipe.myAcknowledgementNum = packet.tcpHeader?.sequenceNumber?.plus(1) ?: 0 + tcpPipe.theirAcknowledgementNum = packet.tcpHeader?.acknowledgementNumber?.plus(1) ?: 0 + sendTcpPack(tcpPipe, TCPHeader.ACK.toByte()) + tcpPipe.closeUpStream() + tcpPipe.tcbStatus = TcbStatus.CLOSE_WAIT + } + + private fun handleAck(packet: Packet, tcpPipe: TcpPipe) { + if (tcpPipe.tcbStatus == TcbStatus.SYN_RECEIVED) { + tcpPipe.tcbStatus = TcbStatus.ESTABLISHED + } + + val tcpHeader = packet.tcpHeader + val payloadSize = packet.backingBuffer?.remaining() ?: 0 + + if (payloadSize == 0) { + return + } + + val newAck = tcpHeader?.sequenceNumber?.plus(payloadSize) ?: 0 + if (newAck <= tcpPipe.myAcknowledgementNum) { + return + } + + tcpPipe.apply { + myAcknowledgementNum = tcpHeader?.sequenceNumber?.plus(payloadSize) ?: 0 + theirAcknowledgementNum = tcpHeader?.acknowledgementNumber ?: 0 + remoteOutBuffer = packet.backingBuffer + tryFlushWrite(this) + sendTcpPack(this, TCPHeader.ACK.toByte()) + } + } + + /** + * Send TCP packet + */ + private fun sendTcpPack(tcpPipe: TcpPipe, flag: Byte, data: ByteArray? = null) { + val dataSize = data?.size ?: 0 + + val packet = IpUtil.buildTcpPacket( + tcpPipe.destinationAddress, + tcpPipe.sourceAddress, + flag, + tcpPipe.myAcknowledgementNum, + tcpPipe.mySequenceNum, + tcpPipe.packId, + ) + tcpPipe.packId++ + + val byteBuffer = ByteBuffer.allocate(TCP_HEADER_SIZE + dataSize) + byteBuffer.position(TCP_HEADER_SIZE) + + data?.let { + byteBuffer.put(it) + } + + packet.updateTCPBuffer( + byteBuffer, + flag, + tcpPipe.mySequenceNum, + tcpPipe.myAcknowledgementNum, + dataSize, + ) + packet.release() + + byteBuffer.position(TCP_HEADER_SIZE + dataSize) + + networkToDeviceQueue.offer(byteBuffer) + + if ((flag and TCPHeader.SYN.toByte()) != 0.toByte()) { + tcpPipe.mySequenceNum++ + } + if ((flag and TCPHeader.FIN.toByte()) != 0.toByte()) { + tcpPipe.mySequenceNum++ + } + if ((flag and TCPHeader.ACK.toByte()) != 0.toByte()) { + tcpPipe.mySequenceNum += dataSize + } + } + + /** + * Write data to the remote + */ + private fun tryFlushWrite(tcpPipe: TcpPipe): Boolean { + val channel: SocketChannel = tcpPipe.remoteSocketChannel + val buffer = tcpPipe.remoteOutBuffer + + if (tcpPipe.remoteSocketChannel.socket().isOutputShutdown && buffer?.remaining() != 0) { + sendTcpPack(tcpPipe, TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte()) + buffer?.compact() + return false + } + + if (!channel.isConnected) { + val key = tcpPipe.remoteSocketChannelKey + val ops = key.interestOps() or SelectionKey.OP_WRITE + key.interestOps(ops) + buffer?.compact() + return false + } + + while (!thread.isInterrupted && buffer?.hasRemaining() == true) { + val n = kotlin.runCatching { + channel.write(buffer) + } + if (n.isFailure) return false + if (n.getOrThrow() <= 0) { + val key = tcpPipe.remoteSocketChannelKey + val ops = key.interestOps() or SelectionKey.OP_WRITE + key.interestOps(ops) + buffer.compact() + return false + } + } + buffer?.clear() + if (!tcpPipe.upActive) { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { + tcpPipe.remoteSocketChannel.shutdownOutput() + } else { + // todo The following line will cause the socket to be incorrectly handled, but what if we don't handle it here? + // tcpPipe.remoteSocketChannel.close() + } + } + return true + } + + private fun TcpPipe.closeRst() { + Timber.d("closeRst $tunnelId") + clean() + sendTcpPack(this, TCPHeader.RST.toByte()) + upActive = false + downActive = false + } + + private fun TcpPipe.doRead() { + val buffer = ByteBuffer.allocate(4096) + var isQuitType = false + + while (!thread.isInterrupted) { + buffer.clear() + val length = remoteSocketChannel.read(buffer) + if (length == -1) { + isQuitType = true + break + } else if (length == 0) { + break + } else { + if (tcbStatus != TcbStatus.CLOSE_WAIT) { + buffer.flip() + val dataByteArray = ByteArray(buffer.remaining()) + buffer.get(dataByteArray) + sendTcpPack(this, TCPHeader.ACK.toByte(), dataByteArray) + } + } + } + + if (isQuitType) { + closeDownStream() + } + } + + private fun TcpPipe.doConnect() { + remoteSocketChannel.finishConnect() + timestamp = System.currentTimeMillis() + remoteOutBuffer?.flip() + remoteSocketChannelKey.interestOps(SelectionKey.OP_READ or SelectionKey.OP_WRITE) + } + + private fun TcpPipe.doWrite() { + if (tryFlushWrite(this)) { + remoteSocketChannelKey.interestOps(SelectionKey.OP_READ) + } + } + + private fun TcpPipe.clean() { + kotlin.runCatching { + if (remoteSocketChannel.isOpen) { + remoteSocketChannel.close() + } + remoteOutBuffer = null + pipeMap.remove(tunnelKey) + }.exceptionOrNull()?.printStackTrace() + } + + private fun TcpPipe.closeUpStream() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { + kotlin.runCatching { + if (remoteSocketChannel.isOpen && remoteSocketChannel.isConnected) { + remoteSocketChannel.shutdownOutput() + } + }.exceptionOrNull()?.printStackTrace() + upActive = false + + if (!downActive) { + clean() + } + } else { + upActive = false + downActive = false + clean() + } + } + + private fun TcpPipe.closeDownStream() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { + kotlin.runCatching { + if (remoteSocketChannel.isConnected) { + remoteSocketChannel.shutdownInput() + val ops = remoteSocketChannelKey.interestOps() and SelectionKey.OP_READ.inv() + remoteSocketChannelKey.interestOps(ops) + } + sendTcpPack(this, (TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte())) + downActive = false + if (!upActive) { + clean() + } + } + } else { + sendTcpPack(this, (TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte())) + upActive = false + downActive = false + clean() + } + } +} diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/TcpWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/TcpWorker.kt new file mode 100644 index 0000000000..a5ffd5bd7a --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/TcpWorker.kt @@ -0,0 +1,416 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn.worker + +import android.net.VpnService +import android.os.Build +import android.util.Base64 +import com.merxury.blocker.core.dispatchers.BlockerDispatchers.IO +import com.merxury.blocker.core.dispatchers.Dispatcher +import com.merxury.blocker.core.vpn.deviceToNetworkTCPQueue +import com.merxury.blocker.core.vpn.networkToDeviceQueue +import com.merxury.blocker.core.vpn.protocol.IpUtil +import com.merxury.blocker.core.vpn.protocol.Packet +import com.merxury.blocker.core.vpn.protocol.Packet.TCPHeader +import com.merxury.blocker.core.vpn.protocol.TcbStatus +import com.merxury.blocker.core.vpn.protocol.TcpPipe +import com.merxury.blocker.core.vpn.tcpNioSelector +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.delay +import kotlinx.coroutines.isActive +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import timber.log.Timber +import java.nio.ByteBuffer +import java.nio.channels.SelectionKey +import java.nio.channels.SocketChannel +import javax.inject.Inject +import kotlin.experimental.and +import kotlin.experimental.or + +private const val TCP_HEADER_SIZE = Packet.IP4_HEADER_SIZE + Packet.TCP_HEADER_SIZE +class TcpWorker @Inject constructor( + @Dispatcher(IO) private val dispatcher: CoroutineDispatcher, +) { + + private val pipeMap = HashMap() + private var vpnService: VpnService? = null + + private val scope = CoroutineScope(dispatcher + SupervisorJob()) + + fun start(vpnService: VpnService) { + this.vpnService = vpnService + scope.launch { + runWorker() + } + } + + fun stop() { + scope.cancel() + vpnService = null + } + + private suspend fun runWorker() = withContext(dispatcher) { + while (scope.isActive) { + if (vpnService == null) { + throw IllegalStateException("VpnService should not be null") + } + handleReadFromVpn() + handleSockets() + + delay(1) + } + } + + private suspend fun handleReadFromVpn() = withContext(dispatcher) { + while (isActive) { + val vpnService = this@TcpWorker.vpnService ?: return@withContext + val packet = deviceToNetworkTCPQueue.poll() ?: return@withContext + val destinationAddress = packet.ip4Header?.destinationAddress + val tcpHeader = packet.tcpHeader + val destinationPort = tcpHeader?.destinationPort + val sourcePort = tcpHeader?.sourcePort + + val ipAndPort = ( + destinationAddress?.hostAddress?.plus(":") + ?: "unknown-host-address" + ) + destinationPort + ":" + sourcePort + + val tcpPipe = if (!pipeMap.containsKey(ipAndPort)) { + val pipe = TcpPipe(ipAndPort, packet) + pipe.tryConnect(vpnService) + pipeMap[ipAndPort] = pipe + pipe + } else { + pipeMap[ipAndPort] + ?: throw IllegalStateException("pipeMap should not contain null key: $ipAndPort") + } + handlePacket(packet, tcpPipe) + } + } + + private suspend fun handleSockets() = withContext(dispatcher) { + while (isActive && tcpNioSelector.selectNow() > 0) { + val keys = tcpNioSelector.selectedKeys() + val iterator = keys.iterator() + while (isActive && iterator.hasNext()) { + val key = iterator.next() + iterator.remove() + val tcpPipe: TcpPipe? = key?.attachment() as? TcpPipe + if (key.isValid) { + kotlin.runCatching { + if (key.isAcceptable) { + throw RuntimeException("key.isAcceptable") + } else if (key.isReadable) { + tcpPipe?.doRead() + } else if (key.isConnectable) { + tcpPipe?.doConnect() + } else if (key.isWritable) { + tcpPipe?.doWrite() + } else { + tcpPipe?.closeRst() + } + null + }.exceptionOrNull()?.let { + Timber.d( + "Error communicating with target: ${ + Base64.encodeToString( + tcpPipe?.destinationAddress.toString().toByteArray(), + Base64.DEFAULT, + ) + }", + ) + it.printStackTrace() + tcpPipe?.closeRst() + } + } + } + } + } + + private fun handlePacket(packet: Packet, tcpPipe: TcpPipe) { + val tcpHeader = packet.tcpHeader ?: return + when { + tcpHeader.isSYN -> { + handleSyn(packet, tcpPipe) + } + + tcpHeader.isRST -> { + handleRst(tcpPipe) + } + + tcpHeader.isFIN -> { + handleFin(packet, tcpPipe) + } + + tcpHeader.isACK -> { + handleAck(packet, tcpPipe) + } + } + } + + private fun handleSyn(packet: Packet, tcpPipe: TcpPipe) { + if (tcpPipe.tcbStatus == TcbStatus.SYN_SENT) { + tcpPipe.tcbStatus = TcbStatus.SYN_RECEIVED + } + val tcpHeader = packet.tcpHeader + tcpPipe.apply { + if (synCount == 0) { + mySequenceNum = 1 + theirSequenceNum = tcpHeader?.sequenceNumber ?: 0 + myAcknowledgementNum = tcpHeader?.sequenceNumber?.plus(1) ?: 0 + theirAcknowledgementNum = tcpHeader?.acknowledgementNumber ?: 0 + sendTcpPack(this, TCPHeader.SYN.toByte() or TCPHeader.ACK.toByte()) + } else { + myAcknowledgementNum = tcpHeader?.sequenceNumber?.plus(1) ?: 0 + } + synCount++ + } + } + + private fun handleRst(tcpPipe: TcpPipe) { + tcpPipe.apply { + upActive = false + downActive = false + clean() + tcbStatus = TcbStatus.CLOSE_WAIT + } + } + + private fun handleFin(packet: Packet, tcpPipe: TcpPipe) { + tcpPipe.myAcknowledgementNum = packet.tcpHeader?.sequenceNumber?.plus(1) ?: 0 + tcpPipe.theirAcknowledgementNum = packet.tcpHeader?.acknowledgementNumber?.plus(1) ?: 0 + sendTcpPack(tcpPipe, TCPHeader.ACK.toByte()) + tcpPipe.closeUpStream() + tcpPipe.tcbStatus = TcbStatus.CLOSE_WAIT + } + + private fun handleAck(packet: Packet, tcpPipe: TcpPipe) { + if (tcpPipe.tcbStatus == TcbStatus.SYN_RECEIVED) { + tcpPipe.tcbStatus = TcbStatus.ESTABLISHED + } + + val tcpHeader = packet.tcpHeader + val payloadSize = packet.backingBuffer?.remaining() ?: 0 + + if (payloadSize == 0) { + return + } + + val newAck = tcpHeader?.sequenceNumber?.plus(payloadSize) ?: 0 + if (newAck <= tcpPipe.myAcknowledgementNum) { + return + } + + tcpPipe.apply { + myAcknowledgementNum = tcpHeader?.sequenceNumber?.plus(payloadSize) ?: 0 + theirAcknowledgementNum = tcpHeader?.acknowledgementNumber ?: 0 + remoteOutBuffer = packet.backingBuffer + tryFlushWrite(this) + sendTcpPack(this, TCPHeader.ACK.toByte()) + } + } + + private fun sendTcpPack(tcpPipe: TcpPipe, flag: Byte, data: ByteArray? = null) { + val dataSize = data?.size ?: 0 + + val packet = IpUtil.buildTcpPacket( + tcpPipe.destinationAddress, + tcpPipe.sourceAddress, + flag, + tcpPipe.myAcknowledgementNum, + tcpPipe.mySequenceNum, + tcpPipe.packId, + ) + tcpPipe.packId++ + + val byteBuffer = ByteBuffer.allocate(TCP_HEADER_SIZE + dataSize) + byteBuffer.position(TCP_HEADER_SIZE) + + data?.let { + byteBuffer.put(it) + } + + packet.updateTCPBuffer( + byteBuffer, + flag, + tcpPipe.mySequenceNum, + tcpPipe.myAcknowledgementNum, + dataSize, + ) + packet.release() + + byteBuffer.position(TCP_HEADER_SIZE + dataSize) + + networkToDeviceQueue.offer(byteBuffer) + + if ((flag and TCPHeader.SYN.toByte()) != 0.toByte()) { + tcpPipe.mySequenceNum++ + } + if ((flag and TCPHeader.FIN.toByte()) != 0.toByte()) { + tcpPipe.mySequenceNum++ + } + if ((flag and TCPHeader.ACK.toByte()) != 0.toByte()) { + tcpPipe.mySequenceNum += dataSize + } + } + + private fun tryFlushWrite(tcpPipe: TcpPipe): Boolean { + val channel: SocketChannel = tcpPipe.remoteSocketChannel + val buffer = tcpPipe.remoteOutBuffer + + if (tcpPipe.remoteSocketChannel.socket().isOutputShutdown && buffer?.remaining() != 0) { + sendTcpPack(tcpPipe, TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte()) + buffer?.compact() + return false + } + + if (!channel.isConnected) { + val key = tcpPipe.remoteSocketChannelKey + val ops = key.interestOps() or SelectionKey.OP_WRITE + key.interestOps(ops) + buffer?.compact() + return false + } + + while (scope.isActive && buffer?.hasRemaining() == true) { + val n = kotlin.runCatching { + channel.write(buffer) + } + if (n.isFailure) return false + if (n.getOrThrow() <= 0) { + val key = tcpPipe.remoteSocketChannelKey + val ops = key.interestOps() or SelectionKey.OP_WRITE + key.interestOps(ops) + buffer.compact() + return false + } + } + buffer?.clear() + if (!tcpPipe.upActive) { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { + tcpPipe.remoteSocketChannel.shutdownOutput() + } else { + // todo The following line will cause the socket to be incorrectly handled, but what if we don't handle it here? + // tcpPipe.remoteSocketChannel.close() + } + } + return true + } + + private fun TcpPipe.closeRst() { + Timber.d("closeRst $tunnelId") + clean() + sendTcpPack(this, TCPHeader.RST.toByte()) + upActive = false + downActive = false + } + + private fun TcpPipe.doRead() { + val buffer = ByteBuffer.allocate(4096) + var isQuitType = false + + while (scope.isActive) { + buffer.clear() + val length = remoteSocketChannel.read(buffer) + if (length == -1) { + isQuitType = true + break + } else if (length == 0) { + break + } else { + if (tcbStatus != TcbStatus.CLOSE_WAIT) { + buffer.flip() + val dataByteArray = ByteArray(buffer.remaining()) + buffer.get(dataByteArray) + sendTcpPack(this, TCPHeader.ACK.toByte(), dataByteArray) + } + } + } + + if (isQuitType) { + closeDownStream() + } + } + + private fun TcpPipe.doConnect() { + remoteSocketChannel.finishConnect() + timestamp = System.currentTimeMillis() + remoteOutBuffer?.flip() + remoteSocketChannelKey.interestOps(SelectionKey.OP_READ or SelectionKey.OP_WRITE) + } + + private fun TcpPipe.doWrite() { + if (tryFlushWrite(this)) { + remoteSocketChannelKey.interestOps(SelectionKey.OP_READ) + } + } + + private fun TcpPipe.clean() { + kotlin.runCatching { + if (remoteSocketChannel.isOpen) { + remoteSocketChannel.close() + } + remoteOutBuffer = null + pipeMap.remove(tunnelKey) + }.exceptionOrNull()?.printStackTrace() + } + + private fun TcpPipe.closeUpStream() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { + kotlin.runCatching { + if (remoteSocketChannel.isOpen && remoteSocketChannel.isConnected) { + remoteSocketChannel.shutdownOutput() + } + }.exceptionOrNull()?.printStackTrace() + upActive = false + + if (!downActive) { + clean() + } + } else { + upActive = false + downActive = false + clean() + } + } + + private fun TcpPipe.closeDownStream() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { + kotlin.runCatching { + if (remoteSocketChannel.isConnected) { + remoteSocketChannel.shutdownInput() + val ops = remoteSocketChannelKey.interestOps() and SelectionKey.OP_READ.inv() + remoteSocketChannelKey.interestOps(ops) + } + sendTcpPack(this, (TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte())) + downActive = false + if (!upActive) { + clean() + } + } + } else { + sendTcpPack(this, (TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte())) + upActive = false + downActive = false + clean() + } + } +} diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToDeviceQueueWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToDeviceQueueWorker.kt new file mode 100644 index 0000000000..82b91e2b8c --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToDeviceQueueWorker.kt @@ -0,0 +1,74 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn.worker + +import com.merxury.blocker.core.dispatchers.BlockerDispatchers.IO +import com.merxury.blocker.core.dispatchers.Dispatcher +import com.merxury.blocker.core.vpn.networkToDeviceQueue +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.isActive +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import timber.log.Timber +import java.io.FileDescriptor +import java.io.FileOutputStream +import java.nio.channels.ClosedByInterruptException +import java.nio.channels.FileChannel +import javax.inject.Inject + +class ToDeviceQueueWorker @Inject constructor( + @Dispatcher(IO) private val dispatcher: CoroutineDispatcher, +) { + + private lateinit var vpnOutput: FileChannel + var totalOutputCount = 0L + + private val scope = CoroutineScope(dispatcher + SupervisorJob()) + + fun start(vpnFileDescriptor: FileDescriptor) { + vpnOutput = FileOutputStream(vpnFileDescriptor).channel + scope.launch { + runWorker() + } + } + + fun stop() { + scope.cancel() + } + + private suspend fun runWorker() = withContext(dispatcher) { + try { + while (scope.isActive) { + val byteBuffer = networkToDeviceQueue.take() + byteBuffer.flip() + while (byteBuffer.hasRemaining()) { + val count = vpnOutput.write(byteBuffer) + if (count > 0) { + totalOutputCount += count + } + } + } + } catch (e: InterruptedException) { + Timber.e(e.message) + } catch (e: ClosedByInterruptException) { + Timber.e(e.message) + } + } +} diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToNetworkQueueWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToNetworkQueueWorker.kt new file mode 100644 index 0000000000..a0fbbfe095 --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToNetworkQueueWorker.kt @@ -0,0 +1,93 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn.worker + +import com.merxury.blocker.core.dispatchers.BlockerDispatchers.IO +import com.merxury.blocker.core.dispatchers.Dispatcher +import com.merxury.blocker.core.vpn.deviceToNetworkTCPQueue +import com.merxury.blocker.core.vpn.deviceToNetworkUDPQueue +import com.merxury.blocker.core.vpn.protocol.Packet +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.isActive +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import timber.log.Timber +import java.io.FileDescriptor +import java.io.FileInputStream +import java.io.IOException +import java.nio.ByteBuffer +import java.nio.channels.FileChannel +import javax.inject.Inject + +class ToNetworkQueueWorker @Inject constructor( + @Dispatcher(IO) private val dispatcher: CoroutineDispatcher, +) { + + private lateinit var vpnInput: FileChannel + var totalInputCount = 0L + + private val scope = CoroutineScope(dispatcher + SupervisorJob()) + + fun start(vpnFileDescriptor: FileDescriptor) { + vpnInput = FileInputStream(vpnFileDescriptor).channel + scope.launch { + runWorker() + } + } + + fun stop() { + scope.cancel() + } + + private suspend fun runWorker() = withContext(dispatcher) { + val readBuffer = ByteBuffer.allocate(16384) + while (scope.isActive) { + var readCount = 0 + try { + readCount = vpnInput.read(readBuffer) + } catch (e: IOException) { + e.printStackTrace() + continue + } + if (readCount > 0) { + readBuffer.flip() + val byteArray = ByteArray(readCount) + readBuffer.get(byteArray) + + val byteBuffer = ByteBuffer.wrap(byteArray) + totalInputCount += readCount + + val packet = Packet(byteBuffer) + if (packet.isUDP) { + deviceToNetworkUDPQueue.offer(packet) + } else if (packet.isTCP) { + deviceToNetworkTCPQueue.offer(packet) + } else { + Timber.d("Unknown packet protocol type ${packet.ip4Header?.protocolNum}") + } + } else if (readCount < 0) { + break + } + readBuffer.clear() + } + + Timber.i("ToNetworkQueueWorker finished running") + } +} diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpReceiveWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpReceiveWorker.kt new file mode 100644 index 0000000000..f2cd018438 --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpReceiveWorker.kt @@ -0,0 +1,120 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn.worker + +import com.merxury.blocker.core.dispatchers.BlockerDispatchers.IO +import com.merxury.blocker.core.dispatchers.Dispatcher +import com.merxury.blocker.core.vpn.model.UdpTunnel +import com.merxury.blocker.core.vpn.networkToDeviceQueue +import com.merxury.blocker.core.vpn.protocol.IpUtil +import com.merxury.blocker.core.vpn.protocol.Packet +import com.merxury.blocker.core.vpn.udpNioSelector +import com.merxury.blocker.core.vpn.udpSocketMap +import com.merxury.blocker.core.vpn.udpTunnelQueue +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.isActive +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import timber.log.Timber +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.DatagramChannel +import java.nio.channels.SelectionKey +import java.util.concurrent.atomic.AtomicInteger +import javax.inject.Inject + +private const val UDP_HEADER_FULL_SIZE = Packet.IP4_HEADER_SIZE + Packet.UDP_HEADER_SIZE + +class UdpReceiveWorker @Inject constructor( + @Dispatcher(IO) private val dispatcher: CoroutineDispatcher, +) { + + private var ipId = AtomicInteger() + + private val scope = CoroutineScope(dispatcher + SupervisorJob()) + + fun start() { + scope.launch { + runWorker() + } + } + + fun stop() { + scope.cancel() + } + + private fun sendUdpPacket(tunnel: UdpTunnel, source: InetSocketAddress, data: ByteArray) { + val packet = IpUtil.buildUdpPacket(tunnel.remote, tunnel.local, ipId.addAndGet(1)) + + val byteBuffer = ByteBuffer.allocate(UDP_HEADER_FULL_SIZE + data.size) + byteBuffer.apply { + position(UDP_HEADER_FULL_SIZE) + put(data) + } + packet.updateUDPBuffer(byteBuffer, data.size) + byteBuffer.position(UDP_HEADER_FULL_SIZE + data.size) + networkToDeviceQueue.offer(byteBuffer) + } + + private suspend fun runWorker() = withContext(dispatcher) { + val receiveBuffer = ByteBuffer.allocate(16384) + while (scope.isActive) { + val readyChannels = udpNioSelector.select() + while (scope.isActive) { + val tunnel = udpTunnelQueue.poll() ?: break + kotlin.runCatching { + val key = tunnel.channel.register(udpNioSelector, SelectionKey.OP_READ, tunnel) + key.interestOps(SelectionKey.OP_READ) + }.exceptionOrNull()?.printStackTrace() + } + if (readyChannels == 0) { + udpNioSelector.selectedKeys().clear() + continue + } + val keys = udpNioSelector.selectedKeys() + val iterator = keys.iterator() + while (isActive && iterator.hasNext()) { + val key = iterator.next() + iterator.remove() + if (key.isValid && key.isReadable) { + val tunnel = key.attachment() as UdpTunnel + kotlin.runCatching { + val inputChannel = key.channel() as DatagramChannel + receiveBuffer.clear() + inputChannel.read(receiveBuffer) + receiveBuffer.flip() + val data = ByteArray(receiveBuffer.remaining()) + receiveBuffer.get(data) + sendUdpPacket( + tunnel, + inputChannel.socket().localSocketAddress as InetSocketAddress, + data, + ) + }.exceptionOrNull()?.let { + Timber.e(it) + synchronized(udpSocketMap) { + udpSocketMap.remove(tunnel.id) + } + } + } + } + } + } +} diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSendWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSendWorker.kt new file mode 100644 index 0000000000..81dc16ae69 --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSendWorker.kt @@ -0,0 +1,130 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn.worker + +import android.net.VpnService +import com.merxury.blocker.core.dispatchers.BlockerDispatchers.IO +import com.merxury.blocker.core.dispatchers.Dispatcher +import com.merxury.blocker.core.vpn.deviceToNetworkUDPQueue +import com.merxury.blocker.core.vpn.model.ManagedDatagramChannel +import com.merxury.blocker.core.vpn.model.UdpTunnel +import com.merxury.blocker.core.vpn.udpNioSelector +import com.merxury.blocker.core.vpn.udpSocketMap +import com.merxury.blocker.core.vpn.udpTunnelQueue +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.isActive +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import timber.log.Timber +import java.net.ConnectException +import java.net.InetSocketAddress +import java.nio.channels.DatagramChannel +import javax.inject.Inject + +class UdpSendWorker @Inject constructor( + @Dispatcher(IO) private val dispatcher: CoroutineDispatcher, +) { + private var vpnService: VpnService? = null + + private val scope = CoroutineScope(dispatcher + SupervisorJob()) + + fun start(vpnService: VpnService) { + this.vpnService = vpnService + udpTunnelQueue.clear() + scope.launch { + runWorker() + } + } + + fun stop() { + scope.cancel() + vpnService = null + } + + private suspend fun runWorker() = withContext(dispatcher) { + while (scope.isActive) { + val packet = deviceToNetworkUDPQueue.take() + + val destinationAddress = packet.ip4Header?.destinationAddress + val udpHeader = packet.udpHeader + + val destinationPort = udpHeader?.destinationPort ?: 0 + val sourcePort = udpHeader?.sourcePort + val ipAndPort = ( + destinationAddress?.hostAddress?.plus(":") + ?: "unknownHostAddress" + ) + destinationPort + ":" + sourcePort + + val managedChannel = if (!udpSocketMap.containsKey(ipAndPort)) { + val channel = DatagramChannel.open() + var channelConnectSuccess = false + channel.apply { + val socket = socket() + vpnService?.protect(socket) + try { + connect(InetSocketAddress(destinationAddress, destinationPort)) + channelConnectSuccess = true + } catch (_: ConnectException) { + } + configureBlocking(false) + } + if (!channelConnectSuccess) { + continue + } + + val tunnel = UdpTunnel( + ipAndPort, + InetSocketAddress(packet.ip4Header?.sourceAddress, udpHeader?.sourcePort ?: 0), + InetSocketAddress( + packet.ip4Header?.destinationAddress, + udpHeader?.destinationPort ?: 0, + ), + channel, + ) + udpTunnelQueue.offer(tunnel) + udpNioSelector.wakeup() + + val managedDatagramChannel = ManagedDatagramChannel(ipAndPort, channel) + synchronized(udpSocketMap) { + udpSocketMap[ipAndPort] = managedDatagramChannel + } + managedDatagramChannel + } else { + synchronized(udpSocketMap) { + udpSocketMap[ipAndPort] + ?: throw IllegalStateException("udp:udpSocketMap[$ipAndPort] should not be null") + } + } + managedChannel.lastTime = System.currentTimeMillis() + val buffer = packet.backingBuffer + kotlin.runCatching { + while (isActive && buffer?.hasRemaining() == true) { + managedChannel.channel.write(buffer) + } + }.exceptionOrNull()?.let { + Timber.e("Error sending UDP packet", it) + managedChannel.channel.close() + synchronized(udpSocketMap) { + udpSocketMap.remove(ipAndPort) + } + } + } + } +} diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSocketCleanWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSocketCleanWorker.kt new file mode 100644 index 0000000000..4433e6e36d --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSocketCleanWorker.kt @@ -0,0 +1,74 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn.worker + +import com.merxury.blocker.core.dispatchers.BlockerDispatchers.IO +import com.merxury.blocker.core.dispatchers.Dispatcher +import com.merxury.blocker.core.vpn.UDP_SOCKET_IDLE_TIMEOUT +import com.merxury.blocker.core.vpn.udpSocketMap +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.delay +import kotlinx.coroutines.isActive +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import timber.log.Timber +import javax.inject.Inject + +private const val INTERVAL_TIME = 5L + +class UdpSocketCleanWorker @Inject constructor( + @Dispatcher(IO)private val dispatcher: CoroutineDispatcher, +) { + + private val scope = CoroutineScope(dispatcher + SupervisorJob()) + + fun start() { + scope.launch { + runWorker() + } + } + + fun stop() { + scope.cancel() + } + + private suspend fun runWorker() = withContext(dispatcher) { + while (scope.isActive) { + synchronized(udpSocketMap) { + val iterator = udpSocketMap.iterator() + var removeCount = 0 + while (isActive && iterator.hasNext()) { + val managedDatagramChannel = iterator.next() + if (System.currentTimeMillis() - managedDatagramChannel.value.lastTime > UDP_SOCKET_IDLE_TIMEOUT * 1000) { + kotlin.runCatching { + managedDatagramChannel.value.channel.close() + }.exceptionOrNull()?.printStackTrace() + iterator.remove() + removeCount++ + } + } + if (removeCount > 0) { + Timber.d("Removed $removeCount expired inactive UDP sockets, currently active ${udpSocketMap.size}") + } + } + delay(INTERVAL_TIME * 1000) + } + } +} From a22474f21e39268cda7ee832481b130dd05862a7 Mon Sep 17 00:00:00 2001 From: lihenggui <350699171@qq.com> Date: Mon, 1 Jul 2024 19:11:46 -0700 Subject: [PATCH 7/9] Remove unused code --- .../blocker/core/vpn/protocol/Packet.kt | 2 +- .../blocker/core/vpn/protocol/TcpPipe.kt | 395 +----------------- .../blocker/core/vpn/worker/TcpWorker.kt | 10 +- .../core/vpn/worker/ToDeviceQueueWorker.kt | 2 +- .../core/vpn/worker/ToNetworkQueueWorker.kt | 2 +- .../core/vpn/worker/UdpReceiveWorker.kt | 4 +- .../blocker/core/vpn/worker/UdpSendWorker.kt | 2 +- .../core/vpn/worker/UdpSocketCleanWorker.kt | 2 +- 8 files changed, 13 insertions(+), 406 deletions(-) diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt index bbc7224b74..004806dee4 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt @@ -270,7 +270,7 @@ internal class Packet { ttl = BitUtils.getUnsignedByte(buffer.get()) protocolNum = BitUtils.getUnsignedByte(buffer.get()) protocol = - com.merxury.blocker.core.vpn.protocol.Packet.IP4Header.TransportProtocol.numberToEnum( + TransportProtocol.numberToEnum( protocolNum.toInt(), ) headerChecksum = BitUtils.getUnsignedShort(buffer.short) diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/TcpPipe.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/TcpPipe.kt index 05071289fa..32e601d669 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/TcpPipe.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/TcpPipe.kt @@ -16,21 +16,12 @@ package com.merxury.blocker.core.vpn.protocol -import android.annotation.SuppressLint import android.net.VpnService -import android.os.Build -import android.util.Base64 -import com.merxury.blocker.core.vpn.deviceToNetworkTCPQueue -import com.merxury.blocker.core.vpn.networkToDeviceQueue -import com.merxury.blocker.core.vpn.protocol.Packet.TCPHeader import com.merxury.blocker.core.vpn.tcpNioSelector -import timber.log.Timber import java.net.InetSocketAddress import java.nio.ByteBuffer import java.nio.channels.SelectionKey import java.nio.channels.SocketChannel -import kotlin.experimental.and -import kotlin.experimental.or internal class TcpPipe(val tunnelKey: String, packet: Packet) { var mySequenceNum: Long = 0 @@ -61,7 +52,7 @@ internal class TcpPipe(val tunnelKey: String, packet: Packet) { var synCount = 0 fun tryConnect(vpnService: VpnService): Result { - val result = kotlin.runCatching { + val result = runCatching { vpnService.protect(remoteSocketChannel.socket()) remoteSocketChannel.connect(destinationAddress) } @@ -69,390 +60,6 @@ internal class TcpPipe(val tunnelKey: String, packet: Packet) { } companion object { - const val TAG = "TcpPipe" var tunnelIds = 0 } } - -/** - * TCP packet worker thread - * NIO - */ -@SuppressLint("StaticFieldLeak") -object TcpWorker : Runnable { - private const val TAG = "TcpSendWorker" - - private const val TCP_HEADER_SIZE = Packet.IP4_HEADER_SIZE + Packet.TCP_HEADER_SIZE - - private lateinit var thread: Thread - - private val pipeMap = HashMap() - - private var vpnService: VpnService? = null - - fun start(vpnService: VpnService) { - this.vpnService = vpnService - thread = Thread(this).apply { - name = TAG - start() - } - } - - fun stop() { - thread.interrupt() - vpnService = null - } - - override fun run() { - while (!thread.isInterrupted) { - if (vpnService == null) { - throw IllegalStateException("VpnService should not be null") - } - handleReadFromVpn() - handleSockets() - - Thread.sleep(1) - } - } - - private fun handleReadFromVpn() { - while (!thread.isInterrupted) { - val vpnService = this.vpnService ?: return - val packet = deviceToNetworkTCPQueue.poll() ?: return - val destinationAddress = packet.ip4Header?.destinationAddress - val tcpHeader = packet.tcpHeader - val destinationPort = tcpHeader?.destinationPort - val sourcePort = tcpHeader?.sourcePort - - val ipAndPort = ( - destinationAddress?.hostAddress?.plus(":") - ?: "unknown-host-address" - ) + destinationPort + ":" + sourcePort - - val tcpPipe = if (!pipeMap.containsKey(ipAndPort)) { - val pipe = TcpPipe(ipAndPort, packet) - pipe.tryConnect(vpnService) - pipeMap[ipAndPort] = pipe - pipe - } else { - pipeMap[ipAndPort] - ?: throw IllegalStateException("pipeMap should not contain null key: $ipAndPort") - } - handlePacket(packet, tcpPipe) - } - } - - private fun handleSockets() { - while (!thread.isInterrupted && tcpNioSelector.selectNow() > 0) { - val keys = tcpNioSelector.selectedKeys() - val iterator = keys.iterator() - while (!thread.isInterrupted && iterator.hasNext()) { - val key = iterator.next() - iterator.remove() - val tcpPipe: TcpPipe? = key?.attachment() as? TcpPipe - if (key.isValid) { - kotlin.runCatching { - if (key.isAcceptable) { - throw RuntimeException("key.isAcceptable") - } else if (key.isReadable) { - tcpPipe?.doRead() - } else if (key.isConnectable) { - tcpPipe?.doConnect() - } else if (key.isWritable) { - tcpPipe?.doWrite() - } else { - tcpPipe?.closeRst() - } - null - }.exceptionOrNull()?.let { - Timber.d( - "Error communicating with target: ${ - Base64.encodeToString( - tcpPipe?.destinationAddress.toString().toByteArray(), - Base64.DEFAULT, - ) - }", - ) - it.printStackTrace() - tcpPipe?.closeRst() - } - } - } - } - } - - private fun handlePacket(packet: Packet, tcpPipe: TcpPipe) { - val tcpHeader = packet.tcpHeader ?: return - when { - tcpHeader.isSYN -> { - handleSyn(packet, tcpPipe) - } - - tcpHeader.isRST -> { - handleRst(tcpPipe) - } - - tcpHeader.isFIN -> { - handleFin(packet, tcpPipe) - } - - tcpHeader.isACK -> { - handleAck(packet, tcpPipe) - } - } - } - - private fun handleSyn(packet: Packet, tcpPipe: TcpPipe) { - if (tcpPipe.tcbStatus == TcbStatus.SYN_SENT) { - tcpPipe.tcbStatus = TcbStatus.SYN_RECEIVED - } - val tcpHeader = packet.tcpHeader - tcpPipe.apply { - if (synCount == 0) { - mySequenceNum = 1 - theirSequenceNum = tcpHeader?.sequenceNumber ?: 0 - myAcknowledgementNum = tcpHeader?.sequenceNumber?.plus(1) ?: 0 - theirAcknowledgementNum = tcpHeader?.acknowledgementNumber ?: 0 - sendTcpPack(this, TCPHeader.SYN.toByte() or TCPHeader.ACK.toByte()) - } else { - myAcknowledgementNum = tcpHeader?.sequenceNumber?.plus(1) ?: 0 - } - synCount++ - } - } - - private fun handleRst(tcpPipe: TcpPipe) { - tcpPipe.apply { - upActive = false - downActive = false - clean() - tcbStatus = TcbStatus.CLOSE_WAIT - } - } - - private fun handleFin(packet: Packet, tcpPipe: TcpPipe) { - tcpPipe.myAcknowledgementNum = packet.tcpHeader?.sequenceNumber?.plus(1) ?: 0 - tcpPipe.theirAcknowledgementNum = packet.tcpHeader?.acknowledgementNumber?.plus(1) ?: 0 - sendTcpPack(tcpPipe, TCPHeader.ACK.toByte()) - tcpPipe.closeUpStream() - tcpPipe.tcbStatus = TcbStatus.CLOSE_WAIT - } - - private fun handleAck(packet: Packet, tcpPipe: TcpPipe) { - if (tcpPipe.tcbStatus == TcbStatus.SYN_RECEIVED) { - tcpPipe.tcbStatus = TcbStatus.ESTABLISHED - } - - val tcpHeader = packet.tcpHeader - val payloadSize = packet.backingBuffer?.remaining() ?: 0 - - if (payloadSize == 0) { - return - } - - val newAck = tcpHeader?.sequenceNumber?.plus(payloadSize) ?: 0 - if (newAck <= tcpPipe.myAcknowledgementNum) { - return - } - - tcpPipe.apply { - myAcknowledgementNum = tcpHeader?.sequenceNumber?.plus(payloadSize) ?: 0 - theirAcknowledgementNum = tcpHeader?.acknowledgementNumber ?: 0 - remoteOutBuffer = packet.backingBuffer - tryFlushWrite(this) - sendTcpPack(this, TCPHeader.ACK.toByte()) - } - } - - /** - * Send TCP packet - */ - private fun sendTcpPack(tcpPipe: TcpPipe, flag: Byte, data: ByteArray? = null) { - val dataSize = data?.size ?: 0 - - val packet = IpUtil.buildTcpPacket( - tcpPipe.destinationAddress, - tcpPipe.sourceAddress, - flag, - tcpPipe.myAcknowledgementNum, - tcpPipe.mySequenceNum, - tcpPipe.packId, - ) - tcpPipe.packId++ - - val byteBuffer = ByteBuffer.allocate(TCP_HEADER_SIZE + dataSize) - byteBuffer.position(TCP_HEADER_SIZE) - - data?.let { - byteBuffer.put(it) - } - - packet.updateTCPBuffer( - byteBuffer, - flag, - tcpPipe.mySequenceNum, - tcpPipe.myAcknowledgementNum, - dataSize, - ) - packet.release() - - byteBuffer.position(TCP_HEADER_SIZE + dataSize) - - networkToDeviceQueue.offer(byteBuffer) - - if ((flag and TCPHeader.SYN.toByte()) != 0.toByte()) { - tcpPipe.mySequenceNum++ - } - if ((flag and TCPHeader.FIN.toByte()) != 0.toByte()) { - tcpPipe.mySequenceNum++ - } - if ((flag and TCPHeader.ACK.toByte()) != 0.toByte()) { - tcpPipe.mySequenceNum += dataSize - } - } - - /** - * Write data to the remote - */ - private fun tryFlushWrite(tcpPipe: TcpPipe): Boolean { - val channel: SocketChannel = tcpPipe.remoteSocketChannel - val buffer = tcpPipe.remoteOutBuffer - - if (tcpPipe.remoteSocketChannel.socket().isOutputShutdown && buffer?.remaining() != 0) { - sendTcpPack(tcpPipe, TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte()) - buffer?.compact() - return false - } - - if (!channel.isConnected) { - val key = tcpPipe.remoteSocketChannelKey - val ops = key.interestOps() or SelectionKey.OP_WRITE - key.interestOps(ops) - buffer?.compact() - return false - } - - while (!thread.isInterrupted && buffer?.hasRemaining() == true) { - val n = kotlin.runCatching { - channel.write(buffer) - } - if (n.isFailure) return false - if (n.getOrThrow() <= 0) { - val key = tcpPipe.remoteSocketChannelKey - val ops = key.interestOps() or SelectionKey.OP_WRITE - key.interestOps(ops) - buffer.compact() - return false - } - } - buffer?.clear() - if (!tcpPipe.upActive) { - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { - tcpPipe.remoteSocketChannel.shutdownOutput() - } else { - // todo The following line will cause the socket to be incorrectly handled, but what if we don't handle it here? - // tcpPipe.remoteSocketChannel.close() - } - } - return true - } - - private fun TcpPipe.closeRst() { - Timber.d("closeRst $tunnelId") - clean() - sendTcpPack(this, TCPHeader.RST.toByte()) - upActive = false - downActive = false - } - - private fun TcpPipe.doRead() { - val buffer = ByteBuffer.allocate(4096) - var isQuitType = false - - while (!thread.isInterrupted) { - buffer.clear() - val length = remoteSocketChannel.read(buffer) - if (length == -1) { - isQuitType = true - break - } else if (length == 0) { - break - } else { - if (tcbStatus != TcbStatus.CLOSE_WAIT) { - buffer.flip() - val dataByteArray = ByteArray(buffer.remaining()) - buffer.get(dataByteArray) - sendTcpPack(this, TCPHeader.ACK.toByte(), dataByteArray) - } - } - } - - if (isQuitType) { - closeDownStream() - } - } - - private fun TcpPipe.doConnect() { - remoteSocketChannel.finishConnect() - timestamp = System.currentTimeMillis() - remoteOutBuffer?.flip() - remoteSocketChannelKey.interestOps(SelectionKey.OP_READ or SelectionKey.OP_WRITE) - } - - private fun TcpPipe.doWrite() { - if (tryFlushWrite(this)) { - remoteSocketChannelKey.interestOps(SelectionKey.OP_READ) - } - } - - private fun TcpPipe.clean() { - kotlin.runCatching { - if (remoteSocketChannel.isOpen) { - remoteSocketChannel.close() - } - remoteOutBuffer = null - pipeMap.remove(tunnelKey) - }.exceptionOrNull()?.printStackTrace() - } - - private fun TcpPipe.closeUpStream() { - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { - kotlin.runCatching { - if (remoteSocketChannel.isOpen && remoteSocketChannel.isConnected) { - remoteSocketChannel.shutdownOutput() - } - }.exceptionOrNull()?.printStackTrace() - upActive = false - - if (!downActive) { - clean() - } - } else { - upActive = false - downActive = false - clean() - } - } - - private fun TcpPipe.closeDownStream() { - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { - kotlin.runCatching { - if (remoteSocketChannel.isConnected) { - remoteSocketChannel.shutdownInput() - val ops = remoteSocketChannelKey.interestOps() and SelectionKey.OP_READ.inv() - remoteSocketChannelKey.interestOps(ops) - } - sendTcpPack(this, (TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte())) - downActive = false - if (!upActive) { - clean() - } - } - } else { - sendTcpPack(this, (TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte())) - upActive = false - downActive = false - clean() - } - } -} diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/TcpWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/TcpWorker.kt index a5ffd5bd7a..f25536efd6 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/TcpWorker.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/TcpWorker.kt @@ -115,7 +115,7 @@ class TcpWorker @Inject constructor( iterator.remove() val tcpPipe: TcpPipe? = key?.attachment() as? TcpPipe if (key.isValid) { - kotlin.runCatching { + runCatching { if (key.isAcceptable) { throw RuntimeException("key.isAcceptable") } else if (key.isReadable) { @@ -291,7 +291,7 @@ class TcpWorker @Inject constructor( } while (scope.isActive && buffer?.hasRemaining() == true) { - val n = kotlin.runCatching { + val n = runCatching { channel.write(buffer) } if (n.isFailure) return false @@ -364,7 +364,7 @@ class TcpWorker @Inject constructor( } private fun TcpPipe.clean() { - kotlin.runCatching { + runCatching { if (remoteSocketChannel.isOpen) { remoteSocketChannel.close() } @@ -375,7 +375,7 @@ class TcpWorker @Inject constructor( private fun TcpPipe.closeUpStream() { if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { - kotlin.runCatching { + runCatching { if (remoteSocketChannel.isOpen && remoteSocketChannel.isConnected) { remoteSocketChannel.shutdownOutput() } @@ -394,7 +394,7 @@ class TcpWorker @Inject constructor( private fun TcpPipe.closeDownStream() { if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { - kotlin.runCatching { + runCatching { if (remoteSocketChannel.isConnected) { remoteSocketChannel.shutdownInput() val ops = remoteSocketChannelKey.interestOps() and SelectionKey.OP_READ.inv() diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToDeviceQueueWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToDeviceQueueWorker.kt index 82b91e2b8c..3a4570c531 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToDeviceQueueWorker.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToDeviceQueueWorker.kt @@ -38,7 +38,7 @@ class ToDeviceQueueWorker @Inject constructor( ) { private lateinit var vpnOutput: FileChannel - var totalOutputCount = 0L + private var totalOutputCount = 0L private val scope = CoroutineScope(dispatcher + SupervisorJob()) diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToNetworkQueueWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToNetworkQueueWorker.kt index a0fbbfe095..36e8b4a28d 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToNetworkQueueWorker.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToNetworkQueueWorker.kt @@ -59,7 +59,7 @@ class ToNetworkQueueWorker @Inject constructor( private suspend fun runWorker() = withContext(dispatcher) { val readBuffer = ByteBuffer.allocate(16384) while (scope.isActive) { - var readCount = 0 + var readCount: Int try { readCount = vpnInput.read(readBuffer) } catch (e: IOException) { diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpReceiveWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpReceiveWorker.kt index f2cd018438..22d71aa7a8 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpReceiveWorker.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpReceiveWorker.kt @@ -79,7 +79,7 @@ class UdpReceiveWorker @Inject constructor( val readyChannels = udpNioSelector.select() while (scope.isActive) { val tunnel = udpTunnelQueue.poll() ?: break - kotlin.runCatching { + runCatching { val key = tunnel.channel.register(udpNioSelector, SelectionKey.OP_READ, tunnel) key.interestOps(SelectionKey.OP_READ) }.exceptionOrNull()?.printStackTrace() @@ -95,7 +95,7 @@ class UdpReceiveWorker @Inject constructor( iterator.remove() if (key.isValid && key.isReadable) { val tunnel = key.attachment() as UdpTunnel - kotlin.runCatching { + runCatching { val inputChannel = key.channel() as DatagramChannel receiveBuffer.clear() inputChannel.read(receiveBuffer) diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSendWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSendWorker.kt index 81dc16ae69..43c0a6b77e 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSendWorker.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSendWorker.kt @@ -114,7 +114,7 @@ class UdpSendWorker @Inject constructor( } managedChannel.lastTime = System.currentTimeMillis() val buffer = packet.backingBuffer - kotlin.runCatching { + runCatching { while (isActive && buffer?.hasRemaining() == true) { managedChannel.channel.write(buffer) } diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSocketCleanWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSocketCleanWorker.kt index 4433e6e36d..eff3790200 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSocketCleanWorker.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSocketCleanWorker.kt @@ -57,7 +57,7 @@ class UdpSocketCleanWorker @Inject constructor( while (isActive && iterator.hasNext()) { val managedDatagramChannel = iterator.next() if (System.currentTimeMillis() - managedDatagramChannel.value.lastTime > UDP_SOCKET_IDLE_TIMEOUT * 1000) { - kotlin.runCatching { + runCatching { managedDatagramChannel.value.channel.close() }.exceptionOrNull()?.printStackTrace() iterator.remove() From 6d34ce504248456f9b8fe6603129e4a84953a9b7 Mon Sep 17 00:00:00 2001 From: lihenggui Date: Tue, 2 Jul 2024 14:53:38 -0700 Subject: [PATCH 8/9] Kotlinfy code Change-Id: I546223f8e899cdde6cd2643defc4293444691da8 --- .../core/vpn/extension/UnsignedExtensions.kt | 23 ++ .../blocker/core/vpn/model/Ip4Header.kt | 78 ++++ .../blocker/core/vpn/model/TcpHeader.kt | 175 +++++++++ .../core/vpn/model/TransportProtocol.kt | 32 ++ .../blocker/core/vpn/model/UdpHeader.kt | 41 +++ .../blocker/core/vpn/protocol/IpUtil.kt | 102 ------ .../blocker/core/vpn/protocol/Packet.kt | 337 ++---------------- .../blocker/core/vpn/worker/TcpWorker.kt | 78 +++- .../core/vpn/worker/ToNetworkQueueWorker.kt | 4 +- .../core/vpn/worker/UdpReceiveWorker.kt | 38 +- 10 files changed, 483 insertions(+), 425 deletions(-) create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/extension/UnsignedExtensions.kt create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/Ip4Header.kt create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/TcpHeader.kt create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/TransportProtocol.kt create mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/UdpHeader.kt delete mode 100644 core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/IpUtil.kt diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/extension/UnsignedExtensions.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/extension/UnsignedExtensions.kt new file mode 100644 index 0000000000..a02f7fe5e1 --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/extension/UnsignedExtensions.kt @@ -0,0 +1,23 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn.extension + +internal fun Byte.toUnsignedByte(): Short = (this.toInt() and 0xFF).toShort() + +internal fun Short.toUnsignedShort(): Int = this.toInt() and 0xFFFF + +internal fun Int.toUnsignedInt(): Long = this.toLong() and 0xFFFFFFFFL diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/Ip4Header.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/Ip4Header.kt new file mode 100644 index 0000000000..fc8cc97f11 --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/Ip4Header.kt @@ -0,0 +1,78 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn.model +import com.merxury.blocker.core.vpn.extension.toUnsignedByte +import com.merxury.blocker.core.vpn.extension.toUnsignedShort +import java.net.InetAddress +import java.net.UnknownHostException +import java.nio.ByteBuffer + +internal data class Ip4Header( + var version: Byte = 0, + var ihl: Byte = 0, + var headerLength: Int = 0, + var typeOfService: Short = 0, + var totalLength: Int = 0, + var identificationAndFlagsAndFragmentOffset: Int = 0, + var ttl: Short = 0, + var protocolNum: Short = 0, + var protocol: TransportProtocol? = null, + var headerChecksum: Int = 0, + var sourceAddress: InetAddress? = null, + var destinationAddress: InetAddress? = null, + var optionsAndPadding: Int = 0, +) { + @Throws(UnknownHostException::class) + constructor(buffer: ByteBuffer) : this() { + val versionAndIHL = buffer.get() + version = (versionAndIHL.toInt() shr 4).toByte() + ihl = (versionAndIHL.toInt() and 0x0F).toByte() + headerLength = ihl.toInt() shl 2 + + typeOfService = buffer.get().toUnsignedByte() + totalLength = buffer.short.toUnsignedShort() + + identificationAndFlagsAndFragmentOffset = buffer.int + + ttl = buffer.get().toUnsignedByte() + protocolNum = buffer.get().toUnsignedByte() + protocol = TransportProtocol.numberToEnum(protocolNum.toInt()) + headerChecksum = buffer.short.toUnsignedShort() + + val addressBytes = ByteArray(4) + buffer.get(addressBytes, 0, 4) + sourceAddress = InetAddress.getByAddress(addressBytes) + + buffer.get(addressBytes, 0, 4) + destinationAddress = InetAddress.getByAddress(addressBytes) + } + + fun fillHeader(buffer: ByteBuffer) { + buffer.put((version.toInt() shl 4 or ihl.toInt()).toByte()) + buffer.put(typeOfService.toByte()) + buffer.putShort(totalLength.toShort()) + + buffer.putInt(identificationAndFlagsAndFragmentOffset) + + buffer.put(ttl.toByte()) + buffer.put(protocol?.number?.toByte() ?: 0) + buffer.putShort(headerChecksum.toShort()) + + sourceAddress?.address?.let { buffer.put(it) } + destinationAddress?.address?.let { buffer.put(it) } + } +} diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/TcpHeader.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/TcpHeader.kt new file mode 100644 index 0000000000..f64348d381 --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/TcpHeader.kt @@ -0,0 +1,175 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn.model + +import com.merxury.blocker.core.vpn.extension.toUnsignedInt +import com.merxury.blocker.core.vpn.extension.toUnsignedShort +import com.merxury.blocker.core.vpn.protocol.Packet.Companion.TCP_HEADER_SIZE +import java.nio.ByteBuffer + +data class TcpHeader( + var sourcePort: Int = 0, + var destinationPort: Int = 0, + var sequenceNumber: Long = 0, + var acknowledgementNumber: Long = 0, + var dataOffsetAndReserved: Byte = 0, + var headerLength: Int = 0, + var flags: Byte = 0, + var window: Int = 0, + var checksum: Int = 0, + var urgentPointer: Int = 0, + var optionsAndPadding: ByteArray? = null, +) { + companion object { + const val FIN = 0x01 + const val SYN = 0x02 + const val RST = 0x04 + const val PSH = 0x08 + const val ACK = 0x10 + const val URG = 0x20 + } + + constructor(buffer: ByteBuffer) : this() { + sourcePort = buffer.short.toUnsignedShort() + destinationPort = buffer.short.toUnsignedShort() + + sequenceNumber = buffer.int.toUnsignedInt() + acknowledgementNumber = buffer.int.toUnsignedInt() + + dataOffsetAndReserved = buffer.get() + headerLength = (dataOffsetAndReserved.toInt() and 0xF0) shr 2 + flags = buffer.get() + window = buffer.short.toUnsignedShort() + + checksum = buffer.short.toUnsignedShort() + urgentPointer = buffer.short.toUnsignedShort() + + val optionsLength = headerLength - TCP_HEADER_SIZE + if (optionsLength > 0) { + optionsAndPadding = ByteArray(optionsLength) + optionsAndPadding?.let { + buffer.get(it, 0, optionsLength) + } + } + } + + val isFIN: Boolean + get() = (flags.toInt() and FIN) == FIN + + val isSYN: Boolean + get() = (flags.toInt() and SYN) == SYN + + val isRST: Boolean + get() = (flags.toInt() and RST) == RST + + val isPSH: Boolean + get() = (flags.toInt() and PSH) == PSH + + val isACK: Boolean + get() = (flags.toInt() and ACK) == ACK + + val isURG: Boolean + get() = (flags.toInt() and URG) == URG + + fun fillHeader(buffer: ByteBuffer) { + buffer.putShort(sourcePort.toShort()) + buffer.putShort(destinationPort.toShort()) + + buffer.putInt(sequenceNumber.toInt()) + buffer.putInt(acknowledgementNumber.toInt()) + + buffer.put(dataOffsetAndReserved) + buffer.put(flags) + buffer.putShort(window.toShort()) + + buffer.putShort(checksum.toShort()) + buffer.putShort(urgentPointer.toShort()) + + optionsAndPadding?.let { + buffer.put(it) + } + } + + fun printSimple(): String = buildString { + if (isFIN) append("FIN ") + if (isSYN) append("SYN ") + if (isRST) append("RST ") + if (isPSH) append("PSH ") + if (isACK) append("ACK ") + if (isURG) append("URG ") + append("seq $sequenceNumber ") + append("ack $acknowledgementNumber ") + } + + override fun toString(): String = buildString { + append("TcpHeader{") + append("sourcePort=").append(sourcePort) + append(", destinationPort=").append(destinationPort) + append(", sequenceNumber=").append(sequenceNumber) + append(", acknowledgementNumber=").append(acknowledgementNumber) + append(", headerLength=").append(headerLength) + append(", window=").append(window) + append(", checksum=").append(checksum) + append(", flags=") + if (isFIN) append(" FIN") + if (isSYN) append(" SYN") + if (isRST) append(" RST") + if (isPSH) append(" PSH") + if (isACK) append(" ACK") + if (isURG) append(" URG") + append('}') + } + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is TcpHeader) return false + + if (sourcePort != other.sourcePort) return false + if (destinationPort != other.destinationPort) return false + if (sequenceNumber != other.sequenceNumber) return false + if (acknowledgementNumber != other.acknowledgementNumber) return false + if (dataOffsetAndReserved != other.dataOffsetAndReserved) return false + if (headerLength != other.headerLength) return false + if (flags != other.flags) return false + if (window != other.window) return false + if (checksum != other.checksum) return false + if (urgentPointer != other.urgentPointer) return false + if (optionsAndPadding != null) { + if (other.optionsAndPadding == null) return false + if (!optionsAndPadding.contentEquals(other.optionsAndPadding)) return false + } else if (other.optionsAndPadding != null) { + return false + } + + return true + } + + override fun hashCode(): Int { + var result = sourcePort + result = 31 * result + destinationPort + result = 31 * result + sequenceNumber.hashCode() + result = 31 * result + acknowledgementNumber.hashCode() + result = 31 * result + dataOffsetAndReserved + result = 31 * result + headerLength + result = 31 * result + flags + result = 31 * result + window + result = 31 * result + checksum + result = 31 * result + urgentPointer + result = 31 * result + (optionsAndPadding?.contentHashCode() ?: 0) + return result + } +} diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/TransportProtocol.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/TransportProtocol.kt new file mode 100644 index 0000000000..032f750c5f --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/TransportProtocol.kt @@ -0,0 +1,32 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn.model + +enum class TransportProtocol(val number: Int) { + TCP(6), + UDP(17), + OTHER(0xFF), + ; + + companion object { + fun numberToEnum(protocolNumber: Int): TransportProtocol = when (protocolNumber) { + 6 -> TCP + 17 -> UDP + else -> OTHER + } + } +} diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/UdpHeader.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/UdpHeader.kt new file mode 100644 index 0000000000..aeee5bc111 --- /dev/null +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/UdpHeader.kt @@ -0,0 +1,41 @@ +/* + * Copyright 2024 Blocker + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.merxury.blocker.core.vpn.model + +import com.merxury.blocker.core.vpn.extension.toUnsignedShort +import java.nio.ByteBuffer + +data class UdpHeader( + var sourcePort: Int = 0, + var destinationPort: Int = 0, + var length: Int = 0, + var checksum: Int = 0, +) { + constructor(buffer: ByteBuffer) : this( + sourcePort = buffer.short.toUnsignedShort(), + destinationPort = buffer.short.toUnsignedShort(), + length = buffer.short.toUnsignedShort(), + checksum = buffer.short.toUnsignedShort(), + ) + + fun fillHeader(buffer: ByteBuffer) { + buffer.putShort(sourcePort.toShort()) + buffer.putShort(destinationPort.toShort()) + buffer.putShort(length.toShort()) + buffer.putShort(checksum.toShort()) + } +} diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/IpUtil.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/IpUtil.kt deleted file mode 100644 index 8f01845bfd..0000000000 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/IpUtil.kt +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright 2024 Blocker - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.merxury.blocker.core.vpn.protocol - -import java.net.InetSocketAddress - -internal object IpUtil { - fun buildUdpPacket(source: InetSocketAddress, dest: InetSocketAddress, ipId: Int): Packet { - val packet = Packet().apply { - isTCP = false - isUDP = true - } - - val ip4Header = Packet.IP4Header().apply { - version = 4 - ihl = 5 - destinationAddress = dest.address - headerChecksum = 0 - headerLength = 20 - identificationAndFlagsAndFragmentOffset = ipId shl 16 or (0x40 shl 8) or 0 - optionsAndPadding = 0 - protocol = Packet.IP4Header.TransportProtocol.UDP - protocolNum = 17 - sourceAddress = source.address - totalLength = 60 - typeOfService = 0 - ttl = 64 - } - - val udpHeader = Packet.UDPHeader().apply { - sourcePort = source.port - destinationPort = dest.port - length = 0 - } - - packet.ip4Header = ip4Header - packet.udpHeader = udpHeader - return packet - } - - fun buildTcpPacket( - source: InetSocketAddress, - dest: InetSocketAddress, - flag: Byte, - ack: Long, - seq: Long, - ipId: Int, - ): Packet { - val packet = Packet().apply { - isTCP = true - isUDP = false - } - - val ip4Header = Packet.IP4Header().apply { - version = 4 - ihl = 5 - destinationAddress = dest.address - headerChecksum = 0 - headerLength = 20 - identificationAndFlagsAndFragmentOffset = ipId shl 16 or (0x40 shl 8) or 0 - optionsAndPadding = 0 - protocol = Packet.IP4Header.TransportProtocol.TCP - protocolNum = 6 - sourceAddress = source.address - totalLength = 60 - typeOfService = 0 - ttl = 64 - } - - val tcpHeader = Packet.TCPHeader().apply { - acknowledgementNumber = ack - checksum = 0 - dataOffsetAndReserved = -96 - destinationPort = dest.port - flags = flag - headerLength = 40 - optionsAndPadding = null - sequenceNumber = seq - sourcePort = source.port - urgentPointer = 0 - window = 65535 - } - - packet.ip4Header = ip4Header - packet.tcpHeader = tcpHeader - return packet - } -} diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt index 004806dee4..4fe61c4a64 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt @@ -16,54 +16,45 @@ package com.merxury.blocker.core.vpn.protocol -import com.merxury.blocker.core.vpn.protocol.Packet.IP4Header.TransportProtocol.TCP -import com.merxury.blocker.core.vpn.protocol.Packet.IP4Header.TransportProtocol.UDP -import java.net.InetAddress +import com.merxury.blocker.core.vpn.extension.toUnsignedByte +import com.merxury.blocker.core.vpn.extension.toUnsignedShort +import com.merxury.blocker.core.vpn.model.Ip4Header +import com.merxury.blocker.core.vpn.model.TcpHeader +import com.merxury.blocker.core.vpn.model.TransportProtocol.TCP +import com.merxury.blocker.core.vpn.model.TransportProtocol.UDP +import com.merxury.blocker.core.vpn.model.UdpHeader import java.net.UnknownHostException import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicLong /** * Representation of an IP Packet */ -internal class Packet { +internal data class Packet( + var ip4Header: Ip4Header? = null, + var tcpHeader: TcpHeader? = null, + var udpHeader: UdpHeader? = null, + var backingBuffer: ByteBuffer? = null, + var isTcp: Boolean = false, + var isUdp: Boolean = false, +) { companion object { const val IP4_HEADER_SIZE = 20 const val TCP_HEADER_SIZE = 20 const val UDP_HEADER_SIZE = 8 - - val globalPackId = AtomicLong() - } - - var ip4Header: IP4Header? = null - var tcpHeader: TCPHeader? = null - var udpHeader: UDPHeader? = null - var backingBuffer: ByteBuffer? = null - - var isTCP = false - - var isUDP = false - - init { - globalPackId.incrementAndGet() } - constructor() - @Throws(UnknownHostException::class) constructor(buffer: ByteBuffer) : this() { - ip4Header = IP4Header(buffer) + ip4Header = Ip4Header(buffer) when (ip4Header?.protocol) { TCP -> { - tcpHeader = TCPHeader(buffer) - isTCP = true + tcpHeader = TcpHeader(buffer) + isTcp = true } - UDP -> { - udpHeader = UDPHeader(buffer) - isUDP = true + udpHeader = UdpHeader(buffer) + isUdp = true } - else -> {} } backingBuffer = buffer @@ -79,9 +70,9 @@ internal class Packet { override fun toString(): String = buildString { append("Packet{") append("ip4Header=").append(ip4Header) - if (isTCP) { + if (isTcp) { append(", tcpHeader=").append(tcpHeader) - } else if (isUDP) { + } else if (isUdp) { append(", udpHeader=").append(udpHeader) } append(", payloadSize=").append( @@ -90,7 +81,7 @@ internal class Packet { append('}') } - fun updateTCPBuffer( + fun updateTcpBuffer( buffer: ByteBuffer, flags: Byte, sequenceNum: Long, @@ -122,11 +113,11 @@ internal class Packet { backingBuffer?.putShort(2, ip4TotalLength.toShort()) ip4Header?.totalLength = ip4TotalLength - updateIP4Checksum() + updateIp4Checksum() } } - fun updateUDPBuffer(buffer: ByteBuffer, payloadSize: Int) { + fun updateUdpBuffer(buffer: ByteBuffer, payloadSize: Int) { buffer.position(0) fillHeader(buffer) backingBuffer = buffer @@ -144,11 +135,11 @@ internal class Packet { backingBuffer?.putShort(2, ip4TotalLength.toShort()) ip4Header?.totalLength = ip4TotalLength - updateIP4Checksum() + updateIp4Checksum() } } - private fun updateIP4Checksum() { + private fun updateIp4Checksum() { val buffer = backingBuffer?.duplicate() ?: return buffer.position(0) @@ -158,7 +149,7 @@ internal class Packet { var ipLength = ip4Header?.headerLength ?: return var sum = 0 while (ipLength > 0) { - sum += BitUtils.getUnsignedShort(buffer.short) + sum += buffer.short.toUnsignedShort() ipLength -= 2 } while (sum shr 16 > 0) { @@ -177,12 +168,12 @@ internal class Packet { // Calculate pseudo-header checksum ip4Header?.sourceAddress?.address?.let { sourceAddress -> val buffer = ByteBuffer.wrap(sourceAddress) - sum = BitUtils.getUnsignedShort(buffer.short) + BitUtils.getUnsignedShort(buffer.short) + sum = buffer.short.toUnsignedShort() + buffer.short.toUnsignedShort() } ip4Header?.destinationAddress?.address?.let { destinationAddress -> val buffer = ByteBuffer.wrap(destinationAddress) - sum += BitUtils.getUnsignedShort(buffer.short) + BitUtils.getUnsignedShort(buffer.short) + sum += buffer.short.toUnsignedShort() + buffer.short.toUnsignedShort() } sum += TCP.number + tcpLength @@ -194,11 +185,11 @@ internal class Packet { // Calculate TCP segment checksum buffer.position(IP4_HEADER_SIZE) while (tcpLength > 1) { - sum += BitUtils.getUnsignedShort(buffer.short) + sum += buffer.short.toUnsignedShort() tcpLength -= 2 } if (tcpLength > 0) { - sum += BitUtils.getUnsignedByte(buffer.get()).toInt() shl 8 + sum += buffer.get().toUnsignedByte().toInt() shl 8 } while (sum shr 16 > 0) { @@ -212,270 +203,10 @@ internal class Packet { private fun fillHeader(buffer: ByteBuffer) { ip4Header?.fillHeader(buffer) - if (isUDP) { + if (isUdp) { udpHeader?.fillHeader(buffer) - } else if (isTCP) { + } else if (isTcp) { tcpHeader?.fillHeader(buffer) } } - - class IP4Header { - var version: Byte = 0 - var ihl: Byte = 0 - var headerLength: Int = 0 - var typeOfService: Short = 0 - var totalLength: Int = 0 - - var identificationAndFlagsAndFragmentOffset: Int = 0 - - var ttl: Short = 0 - var protocolNum: Short = 0 - var protocol: TransportProtocol? = null - var headerChecksum: Int = 0 - - var sourceAddress: InetAddress? = null - var destinationAddress: InetAddress? = null - - var optionsAndPadding: Int = 0 - - enum class TransportProtocol(val number: Int) { - TCP(6), - UDP(17), - OTHER(0xFF), - ; - - companion object { - fun numberToEnum(protocolNumber: Int): TransportProtocol = when (protocolNumber) { - 6 -> TCP - 17 -> UDP - else -> OTHER - } - } - } - - constructor() - - @Throws(UnknownHostException::class) - constructor(buffer: ByteBuffer) { - val versionAndIHL = buffer.get() - version = (versionAndIHL.toInt() shr 4).toByte() - ihl = (versionAndIHL.toInt() and 0x0F).toByte() - headerLength = ihl.toInt() shl 2 - - typeOfService = BitUtils.getUnsignedByte(buffer.get()) - totalLength = BitUtils.getUnsignedShort(buffer.short) - - identificationAndFlagsAndFragmentOffset = buffer.int - - ttl = BitUtils.getUnsignedByte(buffer.get()) - protocolNum = BitUtils.getUnsignedByte(buffer.get()) - protocol = - TransportProtocol.numberToEnum( - protocolNum.toInt(), - ) - headerChecksum = BitUtils.getUnsignedShort(buffer.short) - - val addressBytes = ByteArray(4) - buffer.get(addressBytes, 0, 4) - sourceAddress = InetAddress.getByAddress(addressBytes) - - buffer.get(addressBytes, 0, 4) - destinationAddress = InetAddress.getByAddress(addressBytes) - } - - fun fillHeader(buffer: ByteBuffer) { - buffer.put((version.toInt() shl 4 or ihl.toInt()).toByte()) - buffer.put(typeOfService.toByte()) - buffer.putShort(totalLength.toShort()) - - buffer.putInt(identificationAndFlagsAndFragmentOffset) - - buffer.put(ttl.toByte()) - buffer.put(protocol?.number?.toByte() ?: 0) - buffer.putShort(headerChecksum.toShort()) - - sourceAddress?.address?.let { buffer.put(it) } - destinationAddress?.address?.let { buffer.put(it) } - } - - override fun toString(): String = buildString { - append("IP4Header{") - append("version=").append(version) - append(", IHL=").append(ihl) - append(", typeOfService=").append(typeOfService) - append(", totalLength=").append(totalLength) - append(", identificationAndFlagsAndFragmentOffset=").append( - identificationAndFlagsAndFragmentOffset, - ) - append(", TTL=").append(ttl) - append(", protocol=").append(protocolNum).append(":").append(protocol) - append(", headerChecksum=").append(headerChecksum) - append(", sourceAddress=").append(sourceAddress?.hostAddress) - append(", destinationAddress=").append(destinationAddress?.hostAddress) - append('}') - } - } - - class TCPHeader { - companion object { - const val FIN = 0x01 - const val SYN = 0x02 - const val RST = 0x04 - const val PSH = 0x08 - const val ACK = 0x10 - const val URG = 0x20 - } - - var sourcePort: Int = 0 - var destinationPort: Int = 0 - - var sequenceNumber: Long = 0 - var acknowledgementNumber: Long = 0 - - var dataOffsetAndReserved: Byte = 0 - var headerLength: Int = 0 - var flags: Byte = 0 - var window: Int = 0 - - var checksum: Int = 0 - var urgentPointer: Int = 0 - - var optionsAndPadding: ByteArray? = null - - constructor(buffer: ByteBuffer) { - sourcePort = BitUtils.getUnsignedShort(buffer.short) - destinationPort = BitUtils.getUnsignedShort(buffer.short) - - sequenceNumber = BitUtils.getUnsignedInt(buffer.int) - acknowledgementNumber = BitUtils.getUnsignedInt(buffer.int) - - dataOffsetAndReserved = buffer.get() - headerLength = (dataOffsetAndReserved.toInt() and 0xF0) shr 2 - flags = buffer.get() - window = BitUtils.getUnsignedShort(buffer.short) - - checksum = BitUtils.getUnsignedShort(buffer.short) - urgentPointer = BitUtils.getUnsignedShort(buffer.short) - - val optionsLength = headerLength - TCP_HEADER_SIZE - if (optionsLength > 0) { - optionsAndPadding = ByteArray(optionsLength) - optionsAndPadding?.let { - buffer.get(it, 0, optionsLength) - } - } - } - - constructor() - - val isFIN: Boolean - get() = (flags.toInt() and FIN) == FIN - - val isSYN: Boolean - get() = (flags.toInt() and SYN) == SYN - - val isRST: Boolean - get() = (flags.toInt() and RST) == RST - - val isPSH: Boolean - get() = (flags.toInt() and PSH) == PSH - - val isACK: Boolean - get() = (flags.toInt() and ACK) == ACK - - val isURG: Boolean - get() = (flags.toInt() and URG) == URG - - fun fillHeader(buffer: ByteBuffer) { - buffer.putShort(sourcePort.toShort()) - buffer.putShort(destinationPort.toShort()) - - buffer.putInt(sequenceNumber.toInt()) - buffer.putInt(acknowledgementNumber.toInt()) - - buffer.put(dataOffsetAndReserved) - buffer.put(flags) - buffer.putShort(window.toShort()) - - buffer.putShort(checksum.toShort()) - buffer.putShort(urgentPointer.toShort()) - - optionsAndPadding?.let { - buffer.put(it) - } - } - - fun printSimple(): String = buildString { - if (isFIN) append("FIN ") - if (isSYN) append("SYN ") - if (isRST) append("RST ") - if (isPSH) append("PSH ") - if (isACK) append("ACK ") - if (isURG) append("URG ") - append("seq $sequenceNumber ") - append("ack $acknowledgementNumber ") - } - - override fun toString(): String = buildString { - append("TCPHeader{") - append("sourcePort=").append(sourcePort) - append(", destinationPort=").append(destinationPort) - append(", sequenceNumber=").append(sequenceNumber) - append(", acknowledgementNumber=").append(acknowledgementNumber) - append(", headerLength=").append(headerLength) - append(", window=").append(window) - append(", checksum=").append(checksum) - append(", flags=") - if (isFIN) append(" FIN") - if (isSYN) append(" SYN") - if (isRST) append(" RST") - if (isPSH) append(" PSH") - if (isACK) append(" ACK") - if (isURG) append(" URG") - append('}') - } - } - - class UDPHeader { - var sourcePort: Int = 0 - var destinationPort: Int = 0 - - var length: Int = 0 - var checksum: Int = 0 - - constructor() - - constructor(buffer: ByteBuffer) { - sourcePort = BitUtils.getUnsignedShort(buffer.short) - destinationPort = BitUtils.getUnsignedShort(buffer.short) - - length = BitUtils.getUnsignedShort(buffer.short) - checksum = BitUtils.getUnsignedShort(buffer.short) - } - - fun fillHeader(buffer: ByteBuffer) { - buffer.putShort(sourcePort.toShort()) - buffer.putShort(destinationPort.toShort()) - - buffer.putShort(length.toShort()) - buffer.putShort(checksum.toShort()) - } - - override fun toString(): String = buildString { - append("UDPHeader{") - append("sourcePort=").append(sourcePort) - append(", destinationPort=").append(destinationPort) - append(", length=").append(length) - append(", checksum=").append(checksum) - append('}') - } - } - - private object BitUtils { - fun getUnsignedByte(value: Byte): Short = (value.toInt() and 0xFF).toShort() - - fun getUnsignedShort(value: Short): Int = value.toInt() and 0xFFFF - - fun getUnsignedInt(value: Int): Long = value.toLong() and 0xFFFFFFFFL - } } diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/TcpWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/TcpWorker.kt index f25536efd6..73570799dd 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/TcpWorker.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/TcpWorker.kt @@ -22,10 +22,11 @@ import android.util.Base64 import com.merxury.blocker.core.dispatchers.BlockerDispatchers.IO import com.merxury.blocker.core.dispatchers.Dispatcher import com.merxury.blocker.core.vpn.deviceToNetworkTCPQueue +import com.merxury.blocker.core.vpn.model.Ip4Header +import com.merxury.blocker.core.vpn.model.TcpHeader +import com.merxury.blocker.core.vpn.model.TransportProtocol import com.merxury.blocker.core.vpn.networkToDeviceQueue -import com.merxury.blocker.core.vpn.protocol.IpUtil import com.merxury.blocker.core.vpn.protocol.Packet -import com.merxury.blocker.core.vpn.protocol.Packet.TCPHeader import com.merxury.blocker.core.vpn.protocol.TcbStatus import com.merxury.blocker.core.vpn.protocol.TcpPipe import com.merxury.blocker.core.vpn.tcpNioSelector @@ -38,6 +39,7 @@ import kotlinx.coroutines.isActive import kotlinx.coroutines.launch import kotlinx.coroutines.withContext import timber.log.Timber +import java.net.InetSocketAddress import java.nio.ByteBuffer import java.nio.channels.SelectionKey import java.nio.channels.SocketChannel @@ -177,7 +179,7 @@ class TcpWorker @Inject constructor( theirSequenceNum = tcpHeader?.sequenceNumber ?: 0 myAcknowledgementNum = tcpHeader?.sequenceNumber?.plus(1) ?: 0 theirAcknowledgementNum = tcpHeader?.acknowledgementNumber ?: 0 - sendTcpPack(this, TCPHeader.SYN.toByte() or TCPHeader.ACK.toByte()) + sendTcpPack(this, TcpHeader.SYN.toByte() or TcpHeader.ACK.toByte()) } else { myAcknowledgementNum = tcpHeader?.sequenceNumber?.plus(1) ?: 0 } @@ -197,7 +199,7 @@ class TcpWorker @Inject constructor( private fun handleFin(packet: Packet, tcpPipe: TcpPipe) { tcpPipe.myAcknowledgementNum = packet.tcpHeader?.sequenceNumber?.plus(1) ?: 0 tcpPipe.theirAcknowledgementNum = packet.tcpHeader?.acknowledgementNumber?.plus(1) ?: 0 - sendTcpPack(tcpPipe, TCPHeader.ACK.toByte()) + sendTcpPack(tcpPipe, TcpHeader.ACK.toByte()) tcpPipe.closeUpStream() tcpPipe.tcbStatus = TcbStatus.CLOSE_WAIT } @@ -224,14 +226,14 @@ class TcpWorker @Inject constructor( theirAcknowledgementNum = tcpHeader?.acknowledgementNumber ?: 0 remoteOutBuffer = packet.backingBuffer tryFlushWrite(this) - sendTcpPack(this, TCPHeader.ACK.toByte()) + sendTcpPack(this, TcpHeader.ACK.toByte()) } } private fun sendTcpPack(tcpPipe: TcpPipe, flag: Byte, data: ByteArray? = null) { val dataSize = data?.size ?: 0 - val packet = IpUtil.buildTcpPacket( + val packet = buildTcpPacket( tcpPipe.destinationAddress, tcpPipe.sourceAddress, flag, @@ -248,7 +250,7 @@ class TcpWorker @Inject constructor( byteBuffer.put(it) } - packet.updateTCPBuffer( + packet.updateTcpBuffer( byteBuffer, flag, tcpPipe.mySequenceNum, @@ -261,23 +263,69 @@ class TcpWorker @Inject constructor( networkToDeviceQueue.offer(byteBuffer) - if ((flag and TCPHeader.SYN.toByte()) != 0.toByte()) { + if ((flag and TcpHeader.SYN.toByte()) != 0.toByte()) { tcpPipe.mySequenceNum++ } - if ((flag and TCPHeader.FIN.toByte()) != 0.toByte()) { + if ((flag and TcpHeader.FIN.toByte()) != 0.toByte()) { tcpPipe.mySequenceNum++ } - if ((flag and TCPHeader.ACK.toByte()) != 0.toByte()) { + if ((flag and TcpHeader.ACK.toByte()) != 0.toByte()) { tcpPipe.mySequenceNum += dataSize } } + private fun buildTcpPacket( + source: InetSocketAddress, + dest: InetSocketAddress, + flag: Byte, + ack: Long, + seq: Long, + ipId: Int, + ): Packet { + val ip4Header = Ip4Header( + version = 4, + ihl = 5, + destinationAddress = dest.address, + headerChecksum = 0, + headerLength = 20, + identificationAndFlagsAndFragmentOffset = ipId shl 16 or (0x40 shl 8) or 0, + optionsAndPadding = 0, + protocol = TransportProtocol.TCP, + protocolNum = 6, + sourceAddress = source.address, + totalLength = 60, + typeOfService = 0, + ttl = 64, + ) + + val tcpHeader = TcpHeader( + acknowledgementNumber = ack, + checksum = 0, + dataOffsetAndReserved = -96, + destinationPort = dest.port, + flags = flag, + headerLength = 40, + optionsAndPadding = null, + sequenceNumber = seq, + sourcePort = source.port, + urgentPointer = 0, + window = 65535, + ) + + return Packet( + isTcp = true, + isUdp = false, + ip4Header = ip4Header, + tcpHeader = tcpHeader, + ) + } + private fun tryFlushWrite(tcpPipe: TcpPipe): Boolean { val channel: SocketChannel = tcpPipe.remoteSocketChannel val buffer = tcpPipe.remoteOutBuffer if (tcpPipe.remoteSocketChannel.socket().isOutputShutdown && buffer?.remaining() != 0) { - sendTcpPack(tcpPipe, TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte()) + sendTcpPack(tcpPipe, TcpHeader.FIN.toByte() or TcpHeader.ACK.toByte()) buffer?.compact() return false } @@ -318,7 +366,7 @@ class TcpWorker @Inject constructor( private fun TcpPipe.closeRst() { Timber.d("closeRst $tunnelId") clean() - sendTcpPack(this, TCPHeader.RST.toByte()) + sendTcpPack(this, TcpHeader.RST.toByte()) upActive = false downActive = false } @@ -340,7 +388,7 @@ class TcpWorker @Inject constructor( buffer.flip() val dataByteArray = ByteArray(buffer.remaining()) buffer.get(dataByteArray) - sendTcpPack(this, TCPHeader.ACK.toByte(), dataByteArray) + sendTcpPack(this, TcpHeader.ACK.toByte(), dataByteArray) } } } @@ -400,14 +448,14 @@ class TcpWorker @Inject constructor( val ops = remoteSocketChannelKey.interestOps() and SelectionKey.OP_READ.inv() remoteSocketChannelKey.interestOps(ops) } - sendTcpPack(this, (TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte())) + sendTcpPack(this, (TcpHeader.FIN.toByte() or TcpHeader.ACK.toByte())) downActive = false if (!upActive) { clean() } } } else { - sendTcpPack(this, (TCPHeader.FIN.toByte() or TCPHeader.ACK.toByte())) + sendTcpPack(this, (TcpHeader.FIN.toByte() or TcpHeader.ACK.toByte())) upActive = false downActive = false clean() diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToNetworkQueueWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToNetworkQueueWorker.kt index 36e8b4a28d..21b0acb34e 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToNetworkQueueWorker.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToNetworkQueueWorker.kt @@ -75,9 +75,9 @@ class ToNetworkQueueWorker @Inject constructor( totalInputCount += readCount val packet = Packet(byteBuffer) - if (packet.isUDP) { + if (packet.isUdp) { deviceToNetworkUDPQueue.offer(packet) - } else if (packet.isTCP) { + } else if (packet.isTcp) { deviceToNetworkTCPQueue.offer(packet) } else { Timber.d("Unknown packet protocol type ${packet.ip4Header?.protocolNum}") diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpReceiveWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpReceiveWorker.kt index 22d71aa7a8..2d153bcbfb 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpReceiveWorker.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpReceiveWorker.kt @@ -18,9 +18,11 @@ package com.merxury.blocker.core.vpn.worker import com.merxury.blocker.core.dispatchers.BlockerDispatchers.IO import com.merxury.blocker.core.dispatchers.Dispatcher +import com.merxury.blocker.core.vpn.model.Ip4Header +import com.merxury.blocker.core.vpn.model.TransportProtocol +import com.merxury.blocker.core.vpn.model.UdpHeader import com.merxury.blocker.core.vpn.model.UdpTunnel import com.merxury.blocker.core.vpn.networkToDeviceQueue -import com.merxury.blocker.core.vpn.protocol.IpUtil import com.merxury.blocker.core.vpn.protocol.Packet import com.merxury.blocker.core.vpn.udpNioSelector import com.merxury.blocker.core.vpn.udpSocketMap @@ -61,18 +63,48 @@ class UdpReceiveWorker @Inject constructor( } private fun sendUdpPacket(tunnel: UdpTunnel, source: InetSocketAddress, data: ByteArray) { - val packet = IpUtil.buildUdpPacket(tunnel.remote, tunnel.local, ipId.addAndGet(1)) + val packet = buildUdpPacket(tunnel.remote, tunnel.local, ipId.addAndGet(1)) val byteBuffer = ByteBuffer.allocate(UDP_HEADER_FULL_SIZE + data.size) byteBuffer.apply { position(UDP_HEADER_FULL_SIZE) put(data) } - packet.updateUDPBuffer(byteBuffer, data.size) + packet.updateUdpBuffer(byteBuffer, data.size) byteBuffer.position(UDP_HEADER_FULL_SIZE + data.size) networkToDeviceQueue.offer(byteBuffer) } + private fun buildUdpPacket(source: InetSocketAddress, dest: InetSocketAddress, ipId: Int): Packet { + val ip4Header = Ip4Header( + version = 4, + ihl = 5, + destinationAddress = dest.address, + headerChecksum = 0, + headerLength = 20, + identificationAndFlagsAndFragmentOffset = ipId shl 16 or (0x40 shl 8) or 0, + optionsAndPadding = 0, + protocol = TransportProtocol.UDP, + protocolNum = 17, + sourceAddress = source.address, + totalLength = 60, + typeOfService = 0, + ttl = 64, + ) + + val udpHeader = UdpHeader( + sourcePort = source.port, + destinationPort = dest.port, + length = 0, + ) + return Packet( + isTcp = false, + isUdp = true, + ip4Header = ip4Header, + udpHeader = udpHeader, + ) + } + private suspend fun runWorker() = withContext(dispatcher) { val receiveBuffer = ByteBuffer.allocate(16384) while (scope.isActive) { From a836ae7fc727613cd61d8352f700414d9cf5d3a5 Mon Sep 17 00:00:00 2001 From: lihenggui <350699171@qq.com> Date: Tue, 2 Jul 2024 18:44:08 -0700 Subject: [PATCH 9/9] Kotlinfy code --- .../com/merxury/blocker/core/vpn/VpnQueue.kt | 4 +-- .../blocker/core/vpn/model/Ip4Header.kt | 7 ++-- .../blocker/core/vpn/model/TcpHeader.kt | 36 +++++++++---------- .../blocker/core/vpn/protocol/Packet.kt | 4 +-- .../blocker/core/vpn/worker/TcpWorker.kt | 13 +++---- .../core/vpn/worker/ToNetworkQueueWorker.kt | 8 ++--- .../blocker/core/vpn/worker/UdpSendWorker.kt | 4 +-- 7 files changed, 39 insertions(+), 37 deletions(-) diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/VpnQueue.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/VpnQueue.kt index 6bcfb116d4..e1bdc74e32 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/VpnQueue.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/VpnQueue.kt @@ -26,12 +26,12 @@ import java.util.concurrent.ArrayBlockingQueue /** * Queue for UDP packets sent from device to network */ -internal val deviceToNetworkUDPQueue = ArrayBlockingQueue(1024) +internal val deviceToNetworkUdpQueue = ArrayBlockingQueue(1024) /** * Queue for TCP packets sent from device to network */ -internal val deviceToNetworkTCPQueue = ArrayBlockingQueue(1024) +internal val deviceToNetworkTcpQueue = ArrayBlockingQueue(1024) /** * Queue for packets sent from network to device diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/Ip4Header.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/Ip4Header.kt index fc8cc97f11..c2ab82cff2 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/Ip4Header.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/Ip4Header.kt @@ -23,6 +23,7 @@ import java.nio.ByteBuffer internal data class Ip4Header( var version: Byte = 0, + // Internet Header Length var ihl: Byte = 0, var headerLength: Int = 0, var typeOfService: Short = 0, @@ -38,9 +39,9 @@ internal data class Ip4Header( ) { @Throws(UnknownHostException::class) constructor(buffer: ByteBuffer) : this() { - val versionAndIHL = buffer.get() - version = (versionAndIHL.toInt() shr 4).toByte() - ihl = (versionAndIHL.toInt() and 0x0F).toByte() + val versionAndIhl = buffer.get() + version = (versionAndIhl.toInt() shr 4).toByte() + ihl = (versionAndIhl.toInt() and 0x0F).toByte() headerLength = ihl.toInt() shl 2 typeOfService = buffer.get().toUnsignedByte() diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/TcpHeader.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/TcpHeader.kt index f64348d381..d11de21da3 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/TcpHeader.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/model/TcpHeader.kt @@ -67,22 +67,22 @@ data class TcpHeader( } } - val isFIN: Boolean + val isFin: Boolean get() = (flags.toInt() and FIN) == FIN - val isSYN: Boolean + val isSyn: Boolean get() = (flags.toInt() and SYN) == SYN - val isRST: Boolean + val isRst: Boolean get() = (flags.toInt() and RST) == RST - val isPSH: Boolean + val isPsh: Boolean get() = (flags.toInt() and PSH) == PSH - val isACK: Boolean + val isAck: Boolean get() = (flags.toInt() and ACK) == ACK - val isURG: Boolean + val isUrg: Boolean get() = (flags.toInt() and URG) == URG fun fillHeader(buffer: ByteBuffer) { @@ -105,12 +105,12 @@ data class TcpHeader( } fun printSimple(): String = buildString { - if (isFIN) append("FIN ") - if (isSYN) append("SYN ") - if (isRST) append("RST ") - if (isPSH) append("PSH ") - if (isACK) append("ACK ") - if (isURG) append("URG ") + if (isFin) append("FIN ") + if (isSyn) append("SYN ") + if (isRst) append("RST ") + if (isPsh) append("PSH ") + if (isAck) append("ACK ") + if (isUrg) append("URG ") append("seq $sequenceNumber ") append("ack $acknowledgementNumber ") } @@ -125,12 +125,12 @@ data class TcpHeader( append(", window=").append(window) append(", checksum=").append(checksum) append(", flags=") - if (isFIN) append(" FIN") - if (isSYN) append(" SYN") - if (isRST) append(" RST") - if (isPSH) append(" PSH") - if (isACK) append(" ACK") - if (isURG) append(" URG") + if (isFin) append(" FIN") + if (isSyn) append(" SYN") + if (isRst) append(" RST") + if (isPsh) append(" PSH") + if (isAck) append(" ACK") + if (isUrg) append(" URG") append('}') } diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt index 4fe61c4a64..188c229d62 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/protocol/Packet.kt @@ -107,7 +107,7 @@ internal data class Packet( this.dataOffsetAndReserved = dataOffset backingBuffer?.put(IP4_HEADER_SIZE + 12, dataOffset) - updateTCPChecksum(payloadSize) + updateTcpChecksum(payloadSize) val ip4TotalLength = IP4_HEADER_SIZE + TCP_HEADER_SIZE + payloadSize backingBuffer?.putShort(2, ip4TotalLength.toShort()) @@ -161,7 +161,7 @@ internal data class Packet( backingBuffer?.putShort(10, sum.toShort()) } - private fun updateTCPChecksum(payloadSize: Int) { + private fun updateTcpChecksum(payloadSize: Int) { var sum = 0 var tcpLength = TCP_HEADER_SIZE + payloadSize diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/TcpWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/TcpWorker.kt index 73570799dd..2c317de5b5 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/TcpWorker.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/TcpWorker.kt @@ -21,7 +21,7 @@ import android.os.Build import android.util.Base64 import com.merxury.blocker.core.dispatchers.BlockerDispatchers.IO import com.merxury.blocker.core.dispatchers.Dispatcher -import com.merxury.blocker.core.vpn.deviceToNetworkTCPQueue +import com.merxury.blocker.core.vpn.deviceToNetworkTcpQueue import com.merxury.blocker.core.vpn.model.Ip4Header import com.merxury.blocker.core.vpn.model.TcpHeader import com.merxury.blocker.core.vpn.model.TransportProtocol @@ -48,6 +48,7 @@ import kotlin.experimental.and import kotlin.experimental.or private const val TCP_HEADER_SIZE = Packet.IP4_HEADER_SIZE + Packet.TCP_HEADER_SIZE + class TcpWorker @Inject constructor( @Dispatcher(IO) private val dispatcher: CoroutineDispatcher, ) { @@ -84,7 +85,7 @@ class TcpWorker @Inject constructor( private suspend fun handleReadFromVpn() = withContext(dispatcher) { while (isActive) { val vpnService = this@TcpWorker.vpnService ?: return@withContext - val packet = deviceToNetworkTCPQueue.poll() ?: return@withContext + val packet = deviceToNetworkTcpQueue.poll() ?: return@withContext val destinationAddress = packet.ip4Header?.destinationAddress val tcpHeader = packet.tcpHeader val destinationPort = tcpHeader?.destinationPort @@ -150,19 +151,19 @@ class TcpWorker @Inject constructor( private fun handlePacket(packet: Packet, tcpPipe: TcpPipe) { val tcpHeader = packet.tcpHeader ?: return when { - tcpHeader.isSYN -> { + tcpHeader.isSyn -> { handleSyn(packet, tcpPipe) } - tcpHeader.isRST -> { + tcpHeader.isRst -> { handleRst(tcpPipe) } - tcpHeader.isFIN -> { + tcpHeader.isFin -> { handleFin(packet, tcpPipe) } - tcpHeader.isACK -> { + tcpHeader.isAck -> { handleAck(packet, tcpPipe) } } diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToNetworkQueueWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToNetworkQueueWorker.kt index 21b0acb34e..6dc5d95dd5 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToNetworkQueueWorker.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/ToNetworkQueueWorker.kt @@ -18,8 +18,8 @@ package com.merxury.blocker.core.vpn.worker import com.merxury.blocker.core.dispatchers.BlockerDispatchers.IO import com.merxury.blocker.core.dispatchers.Dispatcher -import com.merxury.blocker.core.vpn.deviceToNetworkTCPQueue -import com.merxury.blocker.core.vpn.deviceToNetworkUDPQueue +import com.merxury.blocker.core.vpn.deviceToNetworkTcpQueue +import com.merxury.blocker.core.vpn.deviceToNetworkUdpQueue import com.merxury.blocker.core.vpn.protocol.Packet import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineScope @@ -76,9 +76,9 @@ class ToNetworkQueueWorker @Inject constructor( val packet = Packet(byteBuffer) if (packet.isUdp) { - deviceToNetworkUDPQueue.offer(packet) + deviceToNetworkUdpQueue.offer(packet) } else if (packet.isTcp) { - deviceToNetworkTCPQueue.offer(packet) + deviceToNetworkTcpQueue.offer(packet) } else { Timber.d("Unknown packet protocol type ${packet.ip4Header?.protocolNum}") } diff --git a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSendWorker.kt b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSendWorker.kt index 43c0a6b77e..cdff46549a 100644 --- a/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSendWorker.kt +++ b/core/vpn/src/main/kotlin/com/merxury/blocker/core/vpn/worker/UdpSendWorker.kt @@ -19,7 +19,7 @@ package com.merxury.blocker.core.vpn.worker import android.net.VpnService import com.merxury.blocker.core.dispatchers.BlockerDispatchers.IO import com.merxury.blocker.core.dispatchers.Dispatcher -import com.merxury.blocker.core.vpn.deviceToNetworkUDPQueue +import com.merxury.blocker.core.vpn.deviceToNetworkUdpQueue import com.merxury.blocker.core.vpn.model.ManagedDatagramChannel import com.merxury.blocker.core.vpn.model.UdpTunnel import com.merxury.blocker.core.vpn.udpNioSelector @@ -60,7 +60,7 @@ class UdpSendWorker @Inject constructor( private suspend fun runWorker() = withContext(dispatcher) { while (scope.isActive) { - val packet = deviceToNetworkUDPQueue.take() + val packet = deviceToNetworkUdpQueue.take() val destinationAddress = packet.ip4Header?.destinationAddress val udpHeader = packet.udpHeader