Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Port Forward Support - without SSH #1769

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions aws_gate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from aws_gate.ssh import ssh
from aws_gate.ssh_config import ssh_config
from aws_gate.ssh_proxy import ssh_proxy
from aws_gate.port_forward import port_forward
from aws_gate.utils import get_default_region

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -88,6 +89,25 @@ def get_argument_parser(*args, **kwargs):
"instance_name", help="Instance we wish to open session to"
)

# 'port-forward' subcommand
port_forward_parser = subparsers.add_parser(
"port-forward", help="Open new session on instance and forward to a port locally or remotely"
)
port_forward_parser.add_argument("-p", "--profile", help="AWS profile to use")
port_forward_parser.add_argument("-r", "--region", help="AWS region to use")
port_forward_parser.add_argument(
"instance_name", help="Instance we wish to open session to"
)
port_forward_parser.add_argument(
"target_port", help="Port to forward to", type=int
)
port_forward_parser.add_argument(
"--target_host", help="Host to forward into", default=None
)
port_forward_parser.add_argument(
"--local_port", help="Local port to forward to", type=int, default=7000
)

# 'ssh' subcommand
ssh_parser = subparsers.add_parser(
"ssh", help="Open SSH session on instance and connect to it"
Expand Down Expand Up @@ -284,6 +304,16 @@ def main(args=None, argument_parser=None):
region_name=region,
profile_name=profile,
)
elif args.subcommand == "port-forward":
port_forward(
config=config,
instance_name=args.instance_name,
target_host=args.target_host,
region_name=region,
profile_name=profile,
target_port=args.target_port,
local_port=args.local_port,
)
elif args.subcommand == "ssh":
ssh(
config=config,
Expand Down
111 changes: 111 additions & 0 deletions aws_gate/port_forward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import logging

from aws_gate.constants import AWS_DEFAULT_PROFILE, AWS_DEFAULT_REGION
from aws_gate.decorators import (
plugin_version,
plugin_required,
valid_aws_profile,
valid_aws_region,
)
from aws_gate.query import query_instance
from aws_gate.session_common import BaseSession
from aws_gate.utils import (
get_aws_client,
get_aws_resource,
fetch_instance_details_from_config,
)

logger = logging.getLogger(__name__)


class SSMPortForwardSession(BaseSession):
def __init__(
self,
instance_id,
target_port: int,
target_host=None,
region_name=AWS_DEFAULT_REGION,
profile_name=AWS_DEFAULT_PROFILE,
local_port: int = 7000,
ssm=None,
):
self._instance_id = instance_id
self._region_name = region_name
self._profile_name = profile_name if profile_name is not None else ""
self._ssm = ssm
self._target_host = target_host
self._target_port = target_port
self._local_port = local_port

forward_parameters = {
"portNumber": [str(self._target_port)],
"localPortNumber": [str(self._local_port)],
}

# local forward or remote forward
if self._target_host is None:
document_name = "AWS-StartPortForwardingSession"
else:
document_name = "AWS-StartPortForwardingSessionToRemoteHost"
forward_parameters.update({"host": [self._target_host]})

start_session_kwargs = dict(
Target=self._instance_id,
DocumentName=document_name,
Parameters=forward_parameters,
)

self._session_parameters = start_session_kwargs


@plugin_required
@plugin_version("1.1.23.0")
@valid_aws_profile
@valid_aws_region
def port_forward(
config,
instance_name,
target_host,
target_port,
local_port=7000,
profile_name=AWS_DEFAULT_PROFILE,
region_name=AWS_DEFAULT_REGION,
):
instance, profile, region = fetch_instance_details_from_config(
config, instance_name, profile_name, region_name
)

ssm = get_aws_client("ssm", region_name=region, profile_name=profile)
ec2 = get_aws_resource("ec2", region_name=region, profile_name=profile)

instance_id = query_instance(name=instance, ec2=ec2)
if instance_id is None:
raise ValueError(f"No instance could be found for name: {instance}")

if target_host is None:
logger.info(
"Opening SSM Port Forwarding Session listening on %s in instance %s (%s) via profile %s to %s:%s",
target_port,
instance_id,
region,
profile,
)
else:
logger.info(
"Opening SSM Port Forwarding Session to %s:%s via instance %s (%s) via profile %s to %s:%s",
target_host,
target_port,
instance_id,
region,
profile,
)
with SSMPortForwardSession(
instance_id,
region_name=region,
profile_name=profile,
ssm=ssm,
target_host=target_host,
target_port=target_port,
local_port=local_port,
) as sess:
sess.open()
159 changes: 159 additions & 0 deletions tests/unit/test_port_forward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import pytest

from aws_gate.port_forward import port_forward, SSMPortForwardSession


def test_create_ssm_forward_session(ssm_mock, instance_id):
sess = SSMPortForwardSession(
instance_id=instance_id, ssm=ssm_mock, target_host="localhost", target_port=1234
)
sess.create()

assert ssm_mock.start_session.called


def test_terminate_ssm_forward_session(ssm_mock, instance_id):
sess = SSMPortForwardSession(
instance_id=instance_id, ssm=ssm_mock, target_host="localhost", target_port=1234
)

sess.create()
sess.terminate()

assert ssm_mock.terminate_session.called


@pytest.mark.parametrize(
"target_host",
[
None,
"my-fun-host",
],
ids=["Target is None, Local Forward", "Target is not None, Remote Forward"],
)
def test_open_ssm_forward_session(mocker, instance_id, ssm_mock, target_host):
m = mocker.patch("aws_gate.session_common.execute_plugin", return_value="output")

sess = SSMPortForwardSession(
instance_id=instance_id, ssm=ssm_mock, target_host=target_host, target_port=1234
)
sess.open()

if target_host:
expected_doc_name = "AWS-StartPortForwardingSessionToRemoteHost"
else:
expected_doc_name = "AWS-StartPortForwardingSession"

assert sess._session_parameters.get("DocumentName") == expected_doc_name
assert m.called


def test_ssm_forward_session_context_manager(ssm_mock, instance_id):
with SSMPortForwardSession(
instance_id=instance_id, ssm=ssm_mock, target_host="localhost", target_port=1234
):
pass

assert ssm_mock.start_session.called
assert ssm_mock.terminate_session.called


def test_port_forward(mocker, instance_id, config):
mocker.patch("aws_gate.port_forward.get_aws_client")
mocker.patch("aws_gate.port_forward.get_aws_resource")
mocker.patch("aws_gate.port_forward.query_instance", return_value=instance_id)
port_forward_mock = mocker.patch(
"aws_gate.port_forward.SSMPortForwardSession", return_value=mocker.MagicMock()
)
mocker.patch("aws_gate.decorators.is_existing_region", return_value=True)
mocker.patch("aws_gate.decorators._plugin_exists", return_value=True)
mocker.patch("aws_gate.decorators.execute_plugin", return_value="1.1.23.0")

port_forward(
config=config,
instance_name="instance_name",
target_host="target_host",
target_port=22,
profile_name="default",
region_name="eu-west-1",
)

assert port_forward_mock.called


def test_port_forward_exception_invalid_profile(mocker, instance_id, config):
mocker.patch("aws_gate.port_forward.get_aws_client")
mocker.patch("aws_gate.port_forward.get_aws_resource")
mocker.patch("aws_gate.port_forward.query_instance", return_value=instance_id)
mocker.patch("aws_gate.decorators.is_existing_region", return_value=True)
mocker.patch("aws_gate.decorators._plugin_exists", return_value=True)
mocker.patch("aws_gate.decorators.execute_plugin", return_value="1.1.23.0")

with pytest.raises(ValueError):
port_forward(
config=config,
instance_name="instance_name",
target_host="target_host",
target_port=22,
profile_name="invalid-default",
region_name="eu-west-1",
)


def test_port_forward_exception_invalid_region(mocker, instance_id, config):
mocker.patch("aws_gate.port_forward.get_aws_client")
mocker.patch("aws_gate.port_forward.get_aws_resource")
mocker.patch("aws_gate.port_forward.query_instance", return_value=instance_id)
mocker.patch("aws_gate.decorators.is_existing_profile", return_value=True)
mocker.patch("aws_gate.decorators._plugin_exists", return_value=True)
mocker.patch("aws_gate.decorators.execute_plugin", return_value="1.1.23.0")
mocker.patch(
"aws_gate.port_forward.SSMPortForwardSession", return_value=mocker.MagicMock()
)
with pytest.raises(ValueError):
port_forward(
config=config,
region_name="not-a-region",
instance_name="instance_name",
target_port=22,
profile_name="default",
target_host="target_host",
)


def test_port_forward_exception_unknown_instance_id(mocker, instance_id, config):
mocker.patch("aws_gate.port_forward.get_aws_client")
mocker.patch("aws_gate.port_forward.get_aws_resource")
mocker.patch("aws_gate.port_forward.query_instance", return_value=None)
mocker.patch("aws_gate.decorators.is_existing_profile", return_value=True)
mocker.patch("aws_gate.decorators.is_existing_region", return_value=True)
mocker.patch("aws_gate.decorators._plugin_exists", return_value=True)
mocker.patch("aws_gate.decorators.execute_plugin", return_value="1.1.23.0")
with pytest.raises(ValueError):
port_forward(
config=config,
region_name="ap-southeast-2",
instance_name=instance_id,
target_port=22,
profile_name="default",
target_host="target_host",
)


def test_port_forward_exception_without_config(mocker, instance_id, empty_config):
mocker.patch("aws_gate.port_forward.get_aws_client")
mocker.patch("aws_gate.port_forward.get_aws_resource")
mocker.patch("aws_gate.port_forward.query_instance", return_value=None)
mocker.patch("aws_gate.decorators.is_existing_profile", return_value=True)
mocker.patch("aws_gate.decorators.is_existing_region", return_value=True)
mocker.patch("aws_gate.decorators._plugin_exists", return_value=True)
mocker.patch("aws_gate.decorators.execute_plugin", return_value="1.1.23.0")
with pytest.raises(ValueError):
port_forward(
config=empty_config,
region_name="ap-southeast-2",
instance_name=instance_id,
target_port=22,
profile_name="default",
target_host="target_host",
)