From 46b89fff2ee72554c64ead687f02f2228d5e923a Mon Sep 17 00:00:00 2001 From: Joe Testa Date: Sun, 21 Apr 2024 17:05:57 -0400 Subject: [PATCH] Sockets now time out after 30 seconds during connection rate testing. --- src/ssh_audit/dheat.py | 53 ++++++++++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/src/ssh_audit/dheat.py b/src/ssh_audit/dheat.py index 75727df6..97be799b 100644 --- a/src/ssh_audit/dheat.py +++ b/src/ssh_audit/dheat.py @@ -309,14 +309,14 @@ def dh_rate_test(out: 'OutputBuffer', aconf: 'AuditConf', kex: 'SSH2_Kex', max_t def _dh_rate_test(out: 'OutputBuffer', aconf: 'AuditConf', kex: 'SSH2_Kex', max_time: float, max_connections: int, concurrent_sockets: int) -> str: '''Attempts to quickly create many sockets to the target server. This simulates the DHEat attack without causing an actual DoS condition. If a rate greater than MAX_SAFE_RATE is allowed, then a warning string is returned.''' - def _close_socket(socket_list: List[socket.socket], s: socket.socket) -> None: + def _close_socket(socket_dict: Dict[socket.socket, float], s: socket.socket) -> None: try: s.shutdown(socket.SHUT_RDWR) s.close() except OSError: pass - socket_list.remove(s) + del socket_dict[s] if sys.platform == "win32": DHEat.YELLOWB = "\033[1;93m" @@ -361,20 +361,21 @@ def _close_socket(socket_list: List[socket.socket], s: socket.socket) -> None: num_attempted_connections = 0 num_opened_connections = 0 - socket_list: List[socket.socket] = [] + socket_dict: Dict[socket.socket, float] = {} start_timer = time.time() + now = start_timer last_update = start_timer while True: + now = time.time() # During non-interactive tests, limit based on time and number of connections. Otherwise, we loop indefinitely until the user presses CTRL-C. - if (interactive is False) and ((time.time() - start_timer) >= max_time) or (num_opened_connections >= max_connections): + if (interactive is False) and ((now - start_timer) >= max_time) or (num_opened_connections >= max_connections): break # out.d("interactive: %r; time.time() - start_timer: %f; max_time: %f; num_opened_connections: %u; max_connections: %u" % (interactive, time.time() - start_timer, max_time, num_opened_connections, max_connections), write_now=True) # Give the user some interactive feedback. if interactive: - now = time.time() if (now - last_update) >= 1.0: seconds_running = now - start_timer print("%s%s%s Run time: %s%.1f%s; TCP SYNs: %s%u%s; Compl. conns: %s%u%s; TCP SYNs/sec: %s%.1f%s; Compl. conns/sec: %s%.1f%s \r" % (DHEat.WHITEB, spinner[spinner_index], DHEat.CLEAR, DHEat.WHITEB, seconds_running, DHEat.CLEAR, DHEat.WHITEB, num_attempted_connections, DHEat.CLEAR, DHEat.WHITEB, num_opened_connections, DHEat.CLEAR, DHEat.BLUEB, num_attempted_connections / seconds_running, DHEat.CLEAR, DHEat.BLUEB, num_opened_connections / seconds_running, DHEat.CLEAR), end="") @@ -390,17 +391,36 @@ def _close_socket(socket_list: List[socket.socket], s: socket.socket) -> None: if sleep_time > 0.0: time.sleep(sleep_time) - while (len(socket_list) < concurrent_sockets) and (len(socket_list) + num_opened_connections < max_connections): + # Check our sockets to see if they've existed for more than 30 seconds. If so, close them so new ones can be re-opened in their place. + timedout_sockets = [] + for s, create_time in socket_dict.items(): + if (now - create_time) > 30: + timedout_sockets.append(s) # We shouldn't modify the dictionary while iterating over it, so add it to a separate list. + + # Now we can safely close the timed-out sockets. + while True: + if len(timedout_sockets) == 0: # Ensure that len() is called in every iteration by putting it here instead of the while clause. + break + + out.d("Closing timed-out socket.", write_now=True) + _close_socket(socket_dict, timedout_sockets[0]) + del timedout_sockets[0] + + # Open new sockets until we've hit the number of concurrent sockets, or if we exceeded the number of maximum connections. + while (len(socket_dict) < concurrent_sockets) and (len(socket_dict) + num_opened_connections < max_connections): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.setblocking(False) - # out.d("Creating socket (%u of %u already exist)..." % (len(socket_list), concurrent_sockets), write_now=True) + # out.d("Creating socket (%u of %u already exist)..." % (len(socket_dict), concurrent_sockets), write_now=True) ret = s.connect_ex((aconf.host, aconf.port)) num_attempted_connections += 1 if ret in [0, 115]: # Check if connection is successful or EINPROGRESS. - socket_list.append(s) + socket_dict[s] = now + else: + out.d("connect_ex() returned: %d" % ret, write_now=True) - # out.d("Calling select() on %u sockets..." % len(socket_list), write_now=True) + # out.d("Calling select() on %u sockets..." % len(socket_dict), write_now=True) + socket_list: List[socket.socket] = [*socket_dict] # Get a list of sockets from the dictionary. rlist, _, elist = select.select(socket_list, [], socket_list, 0.1) # For each socket that has something for us to read... @@ -410,9 +430,9 @@ def _close_socket(socket_list: List[socket.socket], s: socket.socket) -> None: try: buf = s.recv(8) # out.d("Banner: %r" % buf, write_now=True) - except (ConnectionResetError, BrokenPipeError): + except (ConnectionRefusedError, ConnectionResetError, BrokenPipeError, TimeoutError): out.d("Socket error.", write_now=True) - _close_socket(socket_list, s) + _close_socket(socket_dict, s) continue # If we received the SSH header, we'll count this as an opened connection. @@ -420,7 +440,7 @@ def _close_socket(socket_list: List[socket.socket], s: socket.socket) -> None: num_opened_connections += 1 out.d("Number of opened connections: %u (max: %u)." % (num_opened_connections, max_connections)) - _close_socket(socket_list, s) + _close_socket(socket_dict, s) # Since we just closed the socket, ensure its not in the exception list. if s in elist: @@ -429,11 +449,14 @@ def _close_socket(socket_list: List[socket.socket], s: socket.socket) -> None: # Close all sockets that are in the exception state. for s in elist: # out.d("Socket in exception list.", write_now=True) - _close_socket(socket_list, s) + _close_socket(socket_dict, s) # Close any remaining sockets. - while len(socket_list) > 0: - _close_socket(socket_list, socket_list[0]) + while True: + if len(socket_dict) == 0: # Ensure that len() is called in every iteration by putting it here instead of the while clause. + break + + _close_socket(socket_dict, [*socket_dict][0]) # Close & remove the first socket we find. time_elapsed = time.time() - start_timer out.d("DHEat.dh_rate_test() results: time elapsed: %f; connections created: %u" % (time_elapsed, num_opened_connections), write_now=True)