From 473af79f38929b97f1b2c027f1922e8f1440bed2 Mon Sep 17 00:00:00 2001 From: Yaroslav Maslennikov Date: Tue, 3 Sep 2024 15:59:31 +0200 Subject: [PATCH] Allow channel handler to control adjust_window message sending The channel handler callback module can implement the get_adjust/0 function returning either 'immediate' or 'delayed' values. In the latter case the channel handler module is responsible for invoking ssh_connection:adjust_window/3 to send ssh_msg_adjust_window to the peer. --- lib/ssh/src/ssh_client_channel.erl | 36 +++++++++++++++++++++----- lib/ssh/src/ssh_connection_handler.erl | 15 +++++++++++ 2 files changed, 45 insertions(+), 6 deletions(-) diff --git a/lib/ssh/src/ssh_client_channel.erl b/lib/ssh/src/ssh_client_channel.erl index 01f265c1103a..5fd48e6f9ff5 100644 --- a/lib/ssh/src/ssh_client_channel.erl +++ b/lib/ssh/src/ssh_client_channel.erl @@ -199,6 +199,7 @@ The following message is taken care of by the `ssh_client_channel` behavior. channel_cb, channel_state, channel_id, + channel_adjust_fun, % :: fun/2 close_sent = false }). @@ -351,6 +352,24 @@ The user is responsible for any initialization of the process and must call enter_loop(State) -> gen_server:enter_loop(?MODULE, [], State). +check_adjust_fun(Cb, ChState) -> + case catch Cb:get_adjust(ChState) of + Val when Val == immediate orelse Val == delayed -> + %% The existence of the get_adjust should not change in runtime + %% So it should be safe to use it here + fun(Msg, ChannelState) -> + Adjust = Cb:get_adjust(ChannelState), + if Adjust == immediate -> + adjust_window(Msg); + true -> % delayed + ok + end + end; + _ -> + %% If the channel handler is not aware that it can manage adjustments + %% then OTP SSH function is used + fun(Msg, _) -> adjust_window(Msg) end + end. %%==================================================================== %% gen_server callbacks %%==================================================================== @@ -400,17 +419,21 @@ init([Options]) -> process_flag(trap_exit, true), try Cb:init(channel_cb_init_args(Options)) of {ok, ChannelState} -> + ChannelAdjustFun = check_adjust_fun(Cb, ChannelState), State = #state{cm = ConnectionManager, channel_cb = Cb, channel_id = ChannelId, - channel_state = ChannelState}, + channel_state = ChannelState, + channel_adjust_fun = ChannelAdjustFun}, self() ! {ssh_channel_up, ChannelId, ConnectionManager}, {ok, State}; {ok, ChannelState, Timeout} -> + ChannelAdjustFun = check_adjust_fun(Cb, ChannelState), State = #state{cm = ConnectionManager, channel_cb = Cb, channel_id = ChannelId, - channel_state = ChannelState}, + channel_state = ChannelState, + channel_adjust_fun = ChannelAdjustFun}, self() ! {ssh_channel_up, ChannelId, ConnectionManager}, {ok, State, Timeout}; {stop, Why} -> @@ -498,14 +521,15 @@ handle_info({ssh_cm, ConnectionManager, {closed, ChannelId}}, (catch ssh_connection:close(ConnectionManager, ChannelId)), {stop, normal, State#state{close_sent = true}}; -handle_info({ssh_cm, _, _} = Msg, #state{channel_cb = Module, - channel_state = ChannelState0} = State) -> +handle_info({ssh_cm, _, _} = Msg, #state{channel_cb = Module, + channel_adjust_fun = AdjustFun, + channel_state = ChannelState0} = State) -> try Module:handle_ssh_msg(Msg, ChannelState0) of {ok, ChannelState} -> - adjust_window(Msg), + AdjustFun(Msg, ChannelState), {noreply, State#state{channel_state = ChannelState}}; {ok, ChannelState, Timeout} -> - adjust_window(Msg), + AdjustFun(Msg, ChannelState), {noreply, State#state{channel_state = ChannelState}, Timeout}; {stop, ChannelId, ChannelState} -> do_the_close(Msg, ChannelId, State#state{channel_state = ChannelState}) diff --git a/lib/ssh/src/ssh_connection_handler.erl b/lib/ssh/src/ssh_connection_handler.erl index 12db85d015fb..7cd9df270cd0 100644 --- a/lib/ssh/src/ssh_connection_handler.erl +++ b/lib/ssh/src/ssh_connection_handler.erl @@ -95,6 +95,10 @@ -define(call_disconnectfun_and_log_cond(LogMsg, DetailedText, StateName, D), call_disconnectfun_and_log_cond(LogMsg, DetailedText, ?MODULE, ?LINE, StateName, D)). +%% Minimum number of bytes reported by the "upper layer" that cause +%% #ssh_msg_channel_adjust_window to be sent to the SSH peer +-define(MIN_ADJUST, 64). + %%==================================================================== %% Start / stop %%==================================================================== @@ -834,6 +838,17 @@ handle_event(cast, {adjust_window,ChannelId,Bytes}, StateName, D) when ?CONNECTE Channel#channel{recv_window_pending = Pending + Bytes}), keep_state_and_data; + #channel{recv_window_size = WinSize, + recv_window_pending = Pending, + recv_packet_size = _PktSize} = Channel + when ((Bytes + Pending) < ?MIN_ADJUST andalso (WinSize > 0)) -> + %% It does not make sense to send updates of e.g. 1 byte + %% if we are still able to receive something + ssh_client_channel:cache_update(cache(D), + Channel#channel{recv_window_pending = + Pending + Bytes}), + keep_state_and_data; + #channel{recv_window_size = WinSize, recv_window_pending = Pending, remote_id = Id} = Channel ->