Skip to content

Commit

Permalink
feat: application layer heartbeat (rapiz1#136)
Browse files Browse the repository at this point in the history
* feat: application layer heartbeat

* feat: make heartbeat configurable

* fix: update keepalive params

* docs: update about heartbeat
  • Loading branch information
rapiz1 authored Mar 8, 2022
1 parent 1ef7747 commit 2746a0e
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 41 deletions.
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,16 @@ Here is the full configuration specification:
[client]
remote_addr = "example.com:2333" # Necessary. The address of the server
default_token = "default_token_if_not_specify" # Optional. The default token of services, if they don't define their own ones
heartbeat_timeout = 40 # Optional. Set to 0 to disable the application-layer heartbeat test. The value must be greater than `server.heartbeat_interval`. Default: 40 secs

[client.transport] # The whole block is optional. Specify which transport to use
type = "tcp" # Optional. Possible values: ["tcp", "tls", "noise"]. Default: "tcp"

[client.transport.tcp] # Optional
proxy = "socks5://user:[email protected]:1080" # Optional. Use the proxy to connect to the server
nodelay = false # Optional. Determine whether to enable TCP_NODELAY, if applicable, to improve the latency but decrease the bandwidth. Default: false
keepalive_secs = 10 # Optional. Specify `tcp_keepalive_time` in `tcp(7)`, if applicable. Default: 10 seconds
keepalive_interval = 5 # Optional. Specify `tcp_keepalive_intvl` in `tcp(7)`, if applicable. Default: 5 seconds
keepalive_secs = 20 # Optional. Specify `tcp_keepalive_time` in `tcp(7)`, if applicable. Default: 20 seconds
keepalive_interval = 8 # Optional. Specify `tcp_keepalive_intvl` in `tcp(7)`, if applicable. Default: 8 seconds

[client.transport.tls] # Necessary if `type` is "tls"
trusted_root = "ca.pem" # Necessary. The certificate of CA that signed the server's certificate
Expand All @@ -136,12 +137,13 @@ local_addr = "127.0.0.1:1082"
[server]
bind_addr = "0.0.0.0:2333" # Necessary. The address that the server listens for clients. Generally only the port needs to be change.
default_token = "default_token_if_not_specify" # Optional
heartbeat_interval = 30 # Optional. The interval between two application-layer heartbeat. Set to 0 to disable sending heartbeat. Default: 30 secs

[server.transport] # Same as `[client.transport]`
type = "tcp"
nodelay = false
keepalive_secs = 10
keepalive_interval = 5
keepalive_secs = 20
keepalive_interval = 8

[server.transport.tls] # Necessary if `type` is "tls"
pkcs12 = "identify.pfx" # Necessary. pkcs12 file of server's certificate and private key
Expand Down
30 changes: 20 additions & 10 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ use crate::constants::{run_control_chan_backoff, UDP_BUFFER_SIZE, UDP_SENDQ_SIZE

// The entrypoint of running a client
pub async fn run_client(
config: &Config,
config: Config,
shutdown_rx: broadcast::Receiver<bool>,
service_rx: mpsc::Receiver<ServiceChange>,
) -> Result<()> {
let config = config.client.as_ref().ok_or(anyhow!(
let config = config.client.ok_or(anyhow!(
"Try to run as a client, but the configuration is missing. Please add the `[client]` block"
))?;

Expand Down Expand Up @@ -67,21 +67,21 @@ type ServiceDigest = protocol::Digest;
type Nonce = protocol::Digest;

// Holds the state of a client
struct Client<'a, T: Transport> {
config: &'a ClientConfig,
struct Client<T: Transport> {
config: ClientConfig,
service_handles: HashMap<String, ControlChannelHandle>,
transport: Arc<T>,
}

impl<'a, T: 'static + Transport> Client<'a, T> {
impl<T: 'static + Transport> Client<T> {
// Create a Client from `[client]` config block
async fn from(config: &'a ClientConfig) -> Result<Client<'a, T>> {
async fn from(config: ClientConfig) -> Result<Client<T>> {
let transport =
Arc::new(T::new(&config.transport).with_context(|| "Failed to create the transport")?);
Ok(Client {
config,
service_handles: HashMap::new(),
transport: Arc::new(
T::new(&config.transport).with_context(|| "Failed to create the transport")?,
),
transport,
})
}

Expand All @@ -97,6 +97,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
(*config).clone(),
self.config.remote_addr.clone(),
self.transport.clone(),
self.config.heartbeat_timeout,
);
self.service_handles.insert(name.clone(), handle);
}
Expand All @@ -122,6 +123,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
s,
self.config.remote_addr.clone(),
self.transport.clone(),
self.config.heartbeat_timeout
);
let _ = self.service_handles.insert(name, handle);
},
Expand Down Expand Up @@ -369,6 +371,7 @@ struct ControlChannel<T: Transport> {
shutdown_rx: oneshot::Receiver<u8>, // Receives the shutdown signal
remote_addr: String, // `client.remote_addr`
transport: Arc<T>, // Wrapper around the transport layer
heartbeat_timeout: u64, // Application layer heartbeat timeout in secs
}

// Handle of a control channel
Expand Down Expand Up @@ -451,9 +454,14 @@ impl<T: 'static + Transport> ControlChannel<T> {
warn!("{:#}", e);
}
}.instrument(Span::current()));
}
},
ControlChannelCmd::HeartBeat => ()
}
},
_ = time::sleep(Duration::from_secs(self.heartbeat_timeout)), if self.heartbeat_timeout != 0 => {
warn!("Heartbeat timed out");
break;
}
_ = &mut self.shutdown_rx => {
break;
}
Expand All @@ -471,6 +479,7 @@ impl ControlChannelHandle {
service: ClientServiceConfig,
remote_addr: String,
transport: Arc<T>,
heartbeat_timeout: u64,
) -> ControlChannelHandle {
let digest = protocol::digest(service.name.as_bytes());

Expand All @@ -482,6 +491,7 @@ impl ControlChannelHandle {
shutdown_rx,
remote_addr,
transport,
heartbeat_timeout,
};

tokio::spawn(
Expand Down
16 changes: 16 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ use url::Url;

use crate::transport::{DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_KEEPALIVE_SECS, DEFAULT_NODELAY};

/// Application-layer heartbeat interval in secs
const DEFAULT_HEARTBEAT_INTERVAL_SECS: u64 = 30;
const DEFAULT_HEARTBEAT_TIMEOUT_SECS: u64 = 40;

/// String with Debug implementation that emits "MASKED"
/// Used to mask sensitive strings when logging
#[derive(Serialize, Deserialize, Default, PartialEq, Clone)]
Expand Down Expand Up @@ -177,6 +181,10 @@ pub struct TransportConfig {
pub noise: Option<NoiseConfig>,
}

fn default_heartbeat_timeout() -> u64 {
DEFAULT_HEARTBEAT_TIMEOUT_SECS
}

#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
#[serde(deny_unknown_fields)]
pub struct ClientConfig {
Expand All @@ -185,6 +193,12 @@ pub struct ClientConfig {
pub services: HashMap<String, ClientServiceConfig>,
#[serde(default)]
pub transport: TransportConfig,
#[serde(default = "default_heartbeat_timeout")]
pub heartbeat_timeout: u64,
}

fn default_heartbeat_interval() -> u64 {
DEFAULT_HEARTBEAT_INTERVAL_SECS
}

#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
Expand All @@ -195,6 +209,8 @@ pub struct ServerConfig {
pub services: HashMap<String, ServerServiceConfig>,
#[serde(default)]
pub transport: TransportConfig,
#[serde(default = "default_heartbeat_interval")]
pub heartbeat_interval: u64,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()

last_instance = Some((
tokio::spawn(run_instance(
*(config.clone()),
*config,
args.clone(),
shutdown_tx.subscribe(),
service_update_rx,
Expand Down Expand Up @@ -127,13 +127,13 @@ async fn run_instance(
#[cfg(not(feature = "client"))]
crate::helper::feature_not_compile("client");
#[cfg(feature = "client")]
run_client(&config, shutdown_rx, service_update).await
run_client(config, shutdown_rx, service_update).await
}
RunMode::Server => {
#[cfg(not(feature = "server"))]
crate::helper::feature_not_compile("server");
#[cfg(feature = "server")]
run_server(&config, shutdown_rx, service_update).await
run_server(config, shutdown_rx, service_update).await
}
};
ret.unwrap();
Expand Down
6 changes: 4 additions & 2 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::trace;

type ProtocolVersion = u8;
const PROTO_V0: u8 = 0u8;
const _PROTO_V0: u8 = 0u8;
const PROTO_V1: u8 = 1u8;

pub const CURRENT_PROTO_VERSION: ProtocolVersion = PROTO_V0;
pub const CURRENT_PROTO_VERSION: ProtocolVersion = PROTO_V1;

pub type Digest = [u8; HASH_WIDTH_IN_BYTES];

Expand Down Expand Up @@ -48,6 +49,7 @@ impl std::fmt::Display for Ack {
#[derive(Deserialize, Serialize, Debug)]
pub enum ControlChannelCmd {
CreateDataChannel,
HeartBeat,
}

#[derive(Deserialize, Serialize, Debug)]
Expand Down
73 changes: 54 additions & 19 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ const HANDSHAKE_TIMEOUT: u64 = 5; // Timeout for transport handshake

// The entrypoint of running a server
pub async fn run_server(
config: &Config,
config: Config,
shutdown_rx: broadcast::Receiver<bool>,
service_rx: mpsc::Receiver<ServiceChange>,
) -> Result<()> {
let config = match &config.server {
let config = match config.server {
Some(config) => config,
None => {
return Err(anyhow!("Try to run as a server, but the configuration is missing. Please add the `[server]` block"))
Expand Down Expand Up @@ -82,9 +82,9 @@ pub async fn run_server(
type ControlChannelMap<T> = MultiMap<ServiceDigest, Nonce, ControlChannelHandle<T>>;

// Server holds all states of running a server
struct Server<'a, T: Transport> {
struct Server<T: Transport> {
// `[server]` config
config: &'a ServerConfig,
config: Arc<ServerConfig>,

// `[server.services]` config, indexed by ServiceDigest
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
Expand All @@ -105,14 +105,18 @@ fn generate_service_hashmap(
ret
}

impl<'a, T: 'static + Transport> Server<'a, T> {
impl<T: 'static + Transport> Server<T> {
// Create a server from `[server]`
pub async fn from(config: &'a ServerConfig) -> Result<Server<'a, T>> {
pub async fn from(config: ServerConfig) -> Result<Server<T>> {
let config = Arc::new(config);
let services = Arc::new(RwLock::new(generate_service_hashmap(&config)));
let control_channels = Arc::new(RwLock::new(ControlChannelMap::new()));
let transport = Arc::new(T::new(&config.transport)?);
Ok(Server {
config,
services: Arc::new(RwLock::new(generate_service_hashmap(config))),
control_channels: Arc::new(RwLock::new(ControlChannelMap::new())),
transport: Arc::new(T::new(&config.transport)?),
services,
control_channels,
transport,
})
}

Expand Down Expand Up @@ -171,8 +175,9 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
Ok(conn) => {
let services = self.services.clone();
let control_channels = self.control_channels.clone();
let server_config = self.config.clone();
tokio::spawn(async move {
if let Err(err) = handle_connection(conn, services, control_channels).await {
if let Err(err) = handle_connection(conn, services, control_channels, server_config).await {
error!("{:#}", err);
}
}.instrument(info_span!("connection", %addr)));
Expand Down Expand Up @@ -233,12 +238,20 @@ async fn handle_connection<T: 'static + Transport>(
mut conn: T::Stream,
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
control_channels: Arc<RwLock<ControlChannelMap<T>>>,
server_config: Arc<ServerConfig>,
) -> Result<()> {
// Read hello
let hello = read_hello(&mut conn).await?;
match hello {
ControlChannelHello(_, service_digest) => {
do_control_channel_handshake(conn, services, control_channels, service_digest).await?;
do_control_channel_handshake(
conn,
services,
control_channels,
service_digest,
server_config,
)
.await?;
}
DataChannelHello(_, nonce) => {
do_data_channel_handshake(conn, control_channels, nonce).await?;
Expand All @@ -252,6 +265,7 @@ async fn do_control_channel_handshake<T: 'static + Transport>(
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
control_channels: Arc<RwLock<ControlChannelMap<T>>>,
service_digest: ServiceDigest,
server_config: Arc<ServerConfig>,
) -> Result<()> {
info!("Try to handshake a control channel");

Expand Down Expand Up @@ -321,7 +335,8 @@ async fn do_control_channel_handshake<T: 'static + Transport>(
conn.flush().await?;

info!(service = %service_config.name, "Control channel established");
let handle = ControlChannelHandle::new(conn, service_config);
let handle =
ControlChannelHandle::new(conn, service_config, server_config.heartbeat_interval);

// Insert the new handle
let _ = h.insert(service_digest, session_key, handle);
Expand Down Expand Up @@ -371,7 +386,11 @@ where
// Create a control channel handle, where the control channel handling task
// and the connection pool task are created.
#[instrument(name = "handle", skip_all, fields(service = %service.name))]
fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
fn new(
conn: T::Stream,
service: ServerServiceConfig,
heartbeat_interval: u64,
) -> ControlChannelHandle<T> {
// Create a shutdown channel
let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);

Expand Down Expand Up @@ -435,6 +454,7 @@ where
conn,
shutdown_rx,
data_ch_req_rx,
heartbeat_interval,
};

// Run the control channel
Expand All @@ -460,25 +480,34 @@ struct ControlChannel<T: Transport> {
conn: T::Stream, // The connection of control channel
shutdown_rx: broadcast::Receiver<bool>, // Receives the shutdown signal
data_ch_req_rx: mpsc::UnboundedReceiver<bool>, // Receives visitor connections
heartbeat_interval: u64, // Application-layer heartbeat interval in secs
}

impl<T: Transport> ControlChannel<T> {
async fn write_and_flush(&mut self, data: &[u8]) -> Result<()> {
self.conn
.write_all(data)
.await
.with_context(|| "Failed to write control cmds")?;
self.conn
.flush()
.await
.with_context(|| "Failed to flush control cmds")?;
Ok(())
}
// Run a control channel
#[instrument(skip_all)]
async fn run(mut self) -> Result<()> {
let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
let create_ch_cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
let heartbeat = bincode::serialize(&ControlChannelCmd::HeartBeat).unwrap();

// Wait for data channel requests and the shutdown signal
loop {
tokio::select! {
val = self.data_ch_req_rx.recv() => {
match val {
Some(_) => {
if let Err(e) = self.conn.write_all(&cmd).await.with_context(||"Failed to write control cmds") {
error!("{:#}", e);
break;
}
if let Err(e) = self.conn.flush().await.with_context(|| "Failed to flush control cmds") {
if let Err(e) = self.write_and_flush(&create_ch_cmd).await {
error!("{:#}", e);
break;
}
Expand All @@ -488,6 +517,12 @@ impl<T: Transport> ControlChannel<T> {
}
}
},
_ = time::sleep(Duration::from_secs(self.heartbeat_interval)), if self.heartbeat_interval != 0 => {
if let Err(e) = self.write_and_flush(&heartbeat).await {
error!("{:#}", e);
break;
}
}
// Wait for the shutdown signal
_ = self.shutdown_rx.recv() => {
break;
Expand Down
Loading

0 comments on commit 2746a0e

Please sign in to comment.