From b5dffff0d5830037435ee62412c7659040a8f60c Mon Sep 17 00:00:00 2001 From: "jarryd.took" <7352293+troxil@users.noreply.github.com> Date: Thu, 30 May 2024 14:34:06 +1000 Subject: [PATCH 1/5] feat: add port forwarding - AWS-StartPortForwardingSessionToRemoteHost --- aws_gate/cli.py | 30 +++++++ aws_gate/port_forward.py | 108 +++++++++++++++++++++++++ tests/unit/test_port_forward.py | 137 ++++++++++++++++++++++++++++++++ 3 files changed, 275 insertions(+) create mode 100644 aws_gate/port_forward.py create mode 100644 tests/unit/test_port_forward.py diff --git a/aws_gate/cli.py b/aws_gate/cli.py index ac37c2fd..880902e0 100644 --- a/aws_gate/cli.py +++ b/aws_gate/cli.py @@ -28,6 +28,7 @@ from aws_gate.ssh import ssh from aws_gate.ssh_config import ssh_config from aws_gate.ssh_proxy import ssh_proxy +from aws_gate.port_forward import port_forward from aws_gate.utils import get_default_region logger = logging.getLogger(__name__) @@ -88,6 +89,25 @@ def get_argument_parser(*args, **kwargs): "instance_name", help="Instance we wish to open session to" ) + # 'port-forward' subcommand + port_forward_parser = subparsers.add_parser( + "port-forward", help="Open new session on instance and forward to target host" + ) + port_forward_parser.add_argument("-p", "--profile", help="AWS profile to use") + port_forward_parser.add_argument("-r", "--region", help="AWS region to use") + port_forward_parser.add_argument( + "instance_name", help="Instance we wish to open session to" + ) + port_forward_parser.add_argument( + "target_host", help="Host to forward into" + ) + port_forward_parser.add_argument( + "target_port", help="Port to forward to", type=int, + ) + port_forward_parser.add_argument( + "--local_port", help="Local port to forward to", type=int, default=7000 + ) + # 'ssh' subcommand ssh_parser = subparsers.add_parser( "ssh", help="Open SSH session on instance and connect to it" @@ -284,6 +304,16 @@ def main(args=None, argument_parser=None): region_name=region, profile_name=profile, ) + elif args.subcommand == "port-forward": + port_forward( + config=config, + instance_name=args.instance_name, + target_host=args.target_host, + region_name=region, + profile_name=profile, + target_port=args.target_port, + local_port=args.local_port, + ) elif args.subcommand == "ssh": ssh( config=config, diff --git a/aws_gate/port_forward.py b/aws_gate/port_forward.py new file mode 100644 index 00000000..25d3e041 --- /dev/null +++ b/aws_gate/port_forward.py @@ -0,0 +1,108 @@ +import logging + +from aws_gate.constants import AWS_DEFAULT_PROFILE, AWS_DEFAULT_REGION +from aws_gate.decorators import ( + plugin_version, + plugin_required, + valid_aws_profile, + valid_aws_region, +) +from aws_gate.query import query_instance +from aws_gate.session_common import BaseSession +from aws_gate.utils import ( + get_aws_client, + get_aws_resource, + fetch_instance_details_from_config, +) + +logger = logging.getLogger(__name__) + + +class SSMPortForwardSession(BaseSession): + """ + SSM Port Forward Session to Remote via instance. + + Refer to SSM Document: AWS-StartPortForwardingSessionToRemoteHost + + :param instance_id: The instance ID to connect to + :param target_host: The target host to forward to + :param region_name: The region name + :param profile_name: The profile name + :param target_port: The target port + :param local_port: The local port + :param ssm: The SSM client + """ + + def __init__( + self, + instance_id, + target_host: str, + target_port: int, + region_name=AWS_DEFAULT_REGION, + profile_name=AWS_DEFAULT_PROFILE, + local_port: int = 7000, + ssm=None, + ): + self._instance_id = instance_id + self._region_name = region_name + self._profile_name = profile_name if profile_name is not None else "" + self._ssm = ssm + self._target_host = target_host + self._target_port = target_port + self._local_port = local_port + + start_session_kwargs = dict( + Target=self._instance_id, + DocumentName="AWS-StartPortForwardingSessionToRemoteHost", + Parameters={ + "portNumber": [str(self._target_port)], + "localPortNumber": [str(self._local_port)], + "host": [self._target_host], + }, + ) + + self._session_parameters = start_session_kwargs + + +@plugin_required +@plugin_version("1.1.23.0") +@valid_aws_profile +@valid_aws_region +def port_forward( + config, + instance_name, + target_host, + target_port, + local_port=7000, + profile_name=AWS_DEFAULT_PROFILE, + region_name=AWS_DEFAULT_REGION, +): + instance, profile, region = fetch_instance_details_from_config( + config, instance_name, profile_name, region_name + ) + + ssm = get_aws_client("ssm", region_name=region, profile_name=profile) + ec2 = get_aws_resource("ec2", region_name=region, profile_name=profile) + + instance_id = query_instance(name=instance, ec2=ec2) + if instance_id is None: + raise ValueError(f"No instance could be found for name: {instance}") + + logger.info( + "Opening SSM Port Forwarding Session to %s:%s via instance %s (%s) via profile %s to %s:%s", + target_host, + target_port, + instance_id, + region, + profile, + ) + with SSMPortForwardSession( + instance_id, + region_name=region, + profile_name=profile, + ssm=ssm, + target_host=target_host, + target_port=target_port, + local_port=local_port, + ) as sess: + sess.open() diff --git a/tests/unit/test_port_forward.py b/tests/unit/test_port_forward.py new file mode 100644 index 00000000..50dcad48 --- /dev/null +++ b/tests/unit/test_port_forward.py @@ -0,0 +1,137 @@ +import pytest + +from aws_gate.port_forward import port_forward, SSMPortForwardSession + + +def test_create_ssm_forward_session(ssm_mock, instance_id): + sess = SSMPortForwardSession(instance_id=instance_id, ssm=ssm_mock, target_host="localhost", target_port=1234) + sess.create() + + assert ssm_mock.start_session.called + + +def test_terminate_ssm_forward_session(ssm_mock, instance_id): + sess = SSMPortForwardSession(instance_id=instance_id, ssm=ssm_mock, target_host="localhost", target_port=1234) + + sess.create() + sess.terminate() + + assert ssm_mock.terminate_session.called + + +def test_open_ssm_forward_session(mocker, instance_id, ssm_mock): + m = mocker.patch("aws_gate.session_common.execute_plugin", return_value="output") + + sess = SSMPortForwardSession(instance_id=instance_id, ssm=ssm_mock, target_host="localhost", target_port=1234) + sess.open() + + assert m.called + + +def test_ssm_forward_session_context_manager(ssm_mock, instance_id): + with SSMPortForwardSession(instance_id=instance_id, ssm=ssm_mock, target_host="localhost", target_port=1234): + pass + + assert ssm_mock.start_session.called + assert ssm_mock.terminate_session.called + + +def test_port_forward(mocker, instance_id, config): + mocker.patch("aws_gate.port_forward.get_aws_client") + mocker.patch("aws_gate.port_forward.get_aws_resource") + mocker.patch("aws_gate.port_forward.query_instance", return_value=instance_id) + port_forward_mock = mocker.patch( + "aws_gate.port_forward.SSMPortForwardSession", return_value=mocker.MagicMock() + ) + mocker.patch("aws_gate.decorators.is_existing_region", return_value=True) + mocker.patch("aws_gate.decorators._plugin_exists", return_value=True) + mocker.patch("aws_gate.decorators.execute_plugin", return_value="1.1.23.0") + + port_forward( + config=config, + instance_name="instance_name", + target_host="target_host", + target_port=22, + profile_name="default", + region_name="eu-west-1", + ) + + assert port_forward_mock.called + + +def test_port_forward_exception_invalid_profile(mocker, instance_id, config): + mocker.patch("aws_gate.port_forward.get_aws_client") + mocker.patch("aws_gate.port_forward.get_aws_resource") + mocker.patch("aws_gate.port_forward.query_instance", return_value=instance_id) + mocker.patch("aws_gate.decorators.is_existing_region", return_value=True) + mocker.patch("aws_gate.decorators._plugin_exists", return_value=True) + mocker.patch("aws_gate.decorators.execute_plugin", return_value="1.1.23.0") + + with pytest.raises(ValueError): + port_forward( + config=config, + instance_name="instance_name", + target_host="target_host", + target_port=22, + profile_name="invalid-default", + region_name="eu-west-1", + ) + + +def test_port_forward_exception_invalid_region(mocker, instance_id, config): + mocker.patch("aws_gate.port_forward.get_aws_client") + mocker.patch("aws_gate.port_forward.get_aws_resource") + mocker.patch("aws_gate.port_forward.query_instance", return_value=instance_id) + mocker.patch("aws_gate.decorators.is_existing_profile", return_value=True) + mocker.patch("aws_gate.decorators._plugin_exists", return_value=True) + mocker.patch("aws_gate.decorators.execute_plugin", return_value="1.1.23.0") + mocker.patch( + "aws_gate.port_forward.SSMPortForwardSession", return_value=mocker.MagicMock() + ) + with pytest.raises(ValueError): + port_forward( + config=config, + region_name="not-a-region", + instance_name="instance_name", + target_port=22, + profile_name="default", + target_host="target_host", + ) + + +def test_port_forward_exception_unknown_instance_id(mocker, instance_id, config): + mocker.patch("aws_gate.port_forward.get_aws_client") + mocker.patch("aws_gate.port_forward.get_aws_resource") + mocker.patch("aws_gate.port_forward.query_instance", return_value=None) + mocker.patch("aws_gate.decorators.is_existing_profile", return_value=True) + mocker.patch("aws_gate.decorators.is_existing_region", return_value=True) + mocker.patch("aws_gate.decorators._plugin_exists", return_value=True) + mocker.patch("aws_gate.decorators.execute_plugin", return_value="1.1.23.0") + with pytest.raises(ValueError): + port_forward( + config=config, + region_name="ap-southeast-2", + instance_name=instance_id, + target_port=22, + profile_name="default", + target_host="target_host", + ) + + +def test_port_forward_exception_without_config(mocker, instance_id, empty_config): + mocker.patch("aws_gate.port_forward.get_aws_client") + mocker.patch("aws_gate.port_forward.get_aws_resource") + mocker.patch("aws_gate.port_forward.query_instance", return_value=None) + mocker.patch("aws_gate.decorators.is_existing_profile", return_value=True) + mocker.patch("aws_gate.decorators.is_existing_region", return_value=True) + mocker.patch("aws_gate.decorators._plugin_exists", return_value=True) + mocker.patch("aws_gate.decorators.execute_plugin", return_value="1.1.23.0") + with pytest.raises(ValueError): + port_forward( + config=empty_config, + region_name="ap-southeast-2", + instance_name=instance_id, + target_port=22, + profile_name="default", + target_host="target_host", + ) From 5b9699100cf2bddd96e911c41d08536c8d1d1a9a Mon Sep 17 00:00:00 2001 From: "jarryd.took" <7352293+troxil@users.noreply.github.com> Date: Mon, 3 Jun 2024 10:30:06 +1000 Subject: [PATCH 2/5] feat: add local port forwarding: AWS-StartPortForwardingSession --- aws_gate/cli.py | 6 ++--- aws_gate/port_forward.py | 56 +++++++++++++++++++++++++++------------- 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/aws_gate/cli.py b/aws_gate/cli.py index 880902e0..745edfd3 100644 --- a/aws_gate/cli.py +++ b/aws_gate/cli.py @@ -91,7 +91,7 @@ def get_argument_parser(*args, **kwargs): # 'port-forward' subcommand port_forward_parser = subparsers.add_parser( - "port-forward", help="Open new session on instance and forward to target host" + "port-forward", help="Open new session on instance and forward to a port locally or remotely" ) port_forward_parser.add_argument("-p", "--profile", help="AWS profile to use") port_forward_parser.add_argument("-r", "--region", help="AWS region to use") @@ -99,10 +99,10 @@ def get_argument_parser(*args, **kwargs): "instance_name", help="Instance we wish to open session to" ) port_forward_parser.add_argument( - "target_host", help="Host to forward into" + "target_port", help="Port to forward to", type=int ) port_forward_parser.add_argument( - "target_port", help="Port to forward to", type=int, + "--target_host", help="Host to forward into", default=None ) port_forward_parser.add_argument( "--local_port", help="Local port to forward to", type=int, default=7000 diff --git a/aws_gate/port_forward.py b/aws_gate/port_forward.py index 25d3e041..ba6a4cbd 100644 --- a/aws_gate/port_forward.py +++ b/aws_gate/port_forward.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from aws_gate.constants import AWS_DEFAULT_PROFILE, AWS_DEFAULT_REGION from aws_gate.decorators import ( @@ -20,15 +21,17 @@ class SSMPortForwardSession(BaseSession): """ - SSM Port Forward Session to Remote via instance. + SSM Port Forward Session to local or remote via instance - Refer to SSM Document: AWS-StartPortForwardingSessionToRemoteHost + Refer to SSM Documents: + * AWS-StartPortForwardingSession + * AWS-StartPortForwardingSessionToRemoteHost :param instance_id: The instance ID to connect to + :param target_port: The target port to forward to :param target_host: The target host to forward to :param region_name: The region name :param profile_name: The profile name - :param target_port: The target port :param local_port: The local port :param ssm: The SSM client """ @@ -36,8 +39,8 @@ class SSMPortForwardSession(BaseSession): def __init__( self, instance_id, - target_host: str, target_port: int, + target_host: Optional[str] = None, region_name=AWS_DEFAULT_REGION, profile_name=AWS_DEFAULT_PROFILE, local_port: int = 7000, @@ -51,14 +54,22 @@ def __init__( self._target_port = target_port self._local_port = local_port + forward_parameters = { + "portNumber": [str(self._target_port)], + "localPortNumber": [str(self._local_port)], + } + + # remote forward or local forward + if self._target_host is None: + document_name = "AWS-StartPortForwardingSession" + else: + document_name = "AWS-StartPortForwardingSessionToRemoteHost" + forward_parameters.update({"host": [self._target_host]}) + start_session_kwargs = dict( Target=self._instance_id, - DocumentName="AWS-StartPortForwardingSessionToRemoteHost", - Parameters={ - "portNumber": [str(self._target_port)], - "localPortNumber": [str(self._local_port)], - "host": [self._target_host], - }, + DocumentName=document_name, + Parameters=forward_parameters, ) self._session_parameters = start_session_kwargs @@ -88,14 +99,23 @@ def port_forward( if instance_id is None: raise ValueError(f"No instance could be found for name: {instance}") - logger.info( - "Opening SSM Port Forwarding Session to %s:%s via instance %s (%s) via profile %s to %s:%s", - target_host, - target_port, - instance_id, - region, - profile, - ) + if target_host is None: + logger.info( + "Opening SSM Port Forwarding Session listening on %s in instance %s (%s) via profile %s to %s:%s", + target_port, + instance_id, + region, + profile, + ) + else: + logger.info( + "Opening SSM Port Forwarding Session to %s:%s via instance %s (%s) via profile %s to %s:%s", + target_host, + target_port, + instance_id, + region, + profile, + ) with SSMPortForwardSession( instance_id, region_name=region, From a2ddb65e1dcdab49a1d4e0db64237f02bc9b8cdb Mon Sep 17 00:00:00 2001 From: "jarryd.took" <7352293+troxil@users.noreply.github.com> Date: Mon, 3 Jun 2024 10:39:15 +1000 Subject: [PATCH 3/5] lint: make same style for optional in port_forward --- aws_gate/port_forward.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/aws_gate/port_forward.py b/aws_gate/port_forward.py index ba6a4cbd..357cf478 100644 --- a/aws_gate/port_forward.py +++ b/aws_gate/port_forward.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from aws_gate.constants import AWS_DEFAULT_PROFILE, AWS_DEFAULT_REGION from aws_gate.decorators import ( @@ -40,7 +39,7 @@ def __init__( self, instance_id, target_port: int, - target_host: Optional[str] = None, + target_host=None, region_name=AWS_DEFAULT_REGION, profile_name=AWS_DEFAULT_PROFILE, local_port: int = 7000, @@ -59,7 +58,7 @@ def __init__( "localPortNumber": [str(self._local_port)], } - # remote forward or local forward + # local forward or remote forward if self._target_host is None: document_name = "AWS-StartPortForwardingSession" else: From e1900217efb6e884c1c30c0ba1f6e49142cbe1d4 Mon Sep 17 00:00:00 2001 From: "jarryd.took" <7352293+troxil@users.noreply.github.com> Date: Mon, 3 Jun 2024 13:31:01 +1000 Subject: [PATCH 4/5] fix: add local vs remote forwarding --- tests/unit/test_port_forward.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/tests/unit/test_port_forward.py b/tests/unit/test_port_forward.py index 50dcad48..1738dd47 100644 --- a/tests/unit/test_port_forward.py +++ b/tests/unit/test_port_forward.py @@ -4,14 +4,18 @@ def test_create_ssm_forward_session(ssm_mock, instance_id): - sess = SSMPortForwardSession(instance_id=instance_id, ssm=ssm_mock, target_host="localhost", target_port=1234) + sess = SSMPortForwardSession( + instance_id=instance_id, ssm=ssm_mock, target_host="localhost", target_port=1234 + ) sess.create() assert ssm_mock.start_session.called def test_terminate_ssm_forward_session(ssm_mock, instance_id): - sess = SSMPortForwardSession(instance_id=instance_id, ssm=ssm_mock, target_host="localhost", target_port=1234) + sess = SSMPortForwardSession( + instance_id=instance_id, ssm=ssm_mock, target_host="localhost", target_port=1234 + ) sess.create() sess.terminate() @@ -19,17 +23,35 @@ def test_terminate_ssm_forward_session(ssm_mock, instance_id): assert ssm_mock.terminate_session.called -def test_open_ssm_forward_session(mocker, instance_id, ssm_mock): +@pytest.mark.parametrize( + "target_host", + [ + None, + "my-fun-host", + ], + ids=["Target is None, Local Forward", "Target is not None, Remote Forward"], +) +def test_open_ssm_forward_session(mocker, instance_id, ssm_mock, target_host): m = mocker.patch("aws_gate.session_common.execute_plugin", return_value="output") - sess = SSMPortForwardSession(instance_id=instance_id, ssm=ssm_mock, target_host="localhost", target_port=1234) + sess = SSMPortForwardSession( + instance_id=instance_id, ssm=ssm_mock, target_host=target_host, target_port=1234 + ) sess.open() + if target_host: + expected_doc_name = "AWS-StartPortForwardingSessionToRemoteHost" + else: + expected_doc_name = "AWS-StartPortForwardingSession" + + assert sess._session_parameters.get("DocumentName") == expected_doc_name assert m.called def test_ssm_forward_session_context_manager(ssm_mock, instance_id): - with SSMPortForwardSession(instance_id=instance_id, ssm=ssm_mock, target_host="localhost", target_port=1234): + with SSMPortForwardSession( + instance_id=instance_id, ssm=ssm_mock, target_host="localhost", target_port=1234 + ): pass assert ssm_mock.start_session.called From b747653cf308e39ff58c341b326a155897c7b3ac Mon Sep 17 00:00:00 2001 From: "jarryd.took" <7352293+troxil@users.noreply.github.com> Date: Mon, 3 Jun 2024 13:34:54 +1000 Subject: [PATCH 5/5] lint: remove doc - consistent with rest of repo --- aws_gate/port_forward.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/aws_gate/port_forward.py b/aws_gate/port_forward.py index 357cf478..1135dd29 100644 --- a/aws_gate/port_forward.py +++ b/aws_gate/port_forward.py @@ -19,22 +19,6 @@ class SSMPortForwardSession(BaseSession): - """ - SSM Port Forward Session to local or remote via instance - - Refer to SSM Documents: - * AWS-StartPortForwardingSession - * AWS-StartPortForwardingSessionToRemoteHost - - :param instance_id: The instance ID to connect to - :param target_port: The target port to forward to - :param target_host: The target host to forward to - :param region_name: The region name - :param profile_name: The profile name - :param local_port: The local port - :param ssm: The SSM client - """ - def __init__( self, instance_id,