diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 686a297b9e..004d91c444 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -4,12 +4,14 @@ on: pull_request: branches: - develop - - "feat*" + - "feat/*" + - "feat-*" merge_group: types: [ checks_requested ] branches: - develop - - "feat*" + - "feat/*" + - "feat-*" jobs: run-workflow: diff --git a/appveyor-ubuntu.yml b/appveyor-ubuntu.yml index d290ba0da3..80e6a3689a 100644 --- a/appveyor-ubuntu.yml +++ b/appveyor-ubuntu.yml @@ -99,7 +99,7 @@ install: - sh: "terraform -version" # install Rust - - sh: "curl --proto '=https' --tlsv1.2 --retry 10 --retry-connrefused -fsSL https://sh.rustup.rs | sh -s -- --default-toolchain none -y" + - sh: "curl --proto '=https' --tlsv1.2 --retry 10 --retry-connrefused -fsSL https://sh.rustup.rs | sh -s -- --default-toolchain none -y > /dev/null 2>&1" - sh: "source $HOME/.cargo/env" - sh: "rustup toolchain install stable --profile minimal --no-self-update" - sh: "rustup default stable" @@ -264,7 +264,7 @@ for: test_script: - "pip install -e \".[dev]\"" - - sh: "pytest -vv tests/integration/deploy -n 4 --reruns 4 --json-report --json-report-file=TEST_REPORT-integration-deploy.json" + - sh: "pytest -vv tests/integration/deploy -n 4 --reruns 4 --dist=loadgroup --json-report --json-report-file=TEST_REPORT-integration-deploy.json" # Integ testing package - diff --git a/appveyor-windows.yml b/appveyor-windows.yml index 73133d2693..c29bff14ba 100644 --- a/appveyor-windows.yml +++ b/appveyor-windows.yml @@ -254,7 +254,7 @@ for: - "git --version" - "venv\\Scripts\\activate" - "docker system prune -a -f" - - ps: "pytest -vv tests/integration/deploy -n 4 --reruns 4 --json-report --json-report-file=TEST_REPORT-integration-deploy.json" + - ps: "pytest -vv tests/integration/deploy -n 4 --reruns 4 --dist=loadgroup --json-report --json-report-file=TEST_REPORT-integration-deploy.json" # Integ testing package - matrix: diff --git a/requirements/base.txt b/requirements/base.txt index 6b3a171ca2..5c5f5c2e9a 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -13,7 +13,7 @@ docker~=4.2.0 dateparser~=1.1 requests==2.28.2 serverlessrepo==0.1.10 -aws_lambda_builders==1.28.0 +aws_lambda_builders==1.29.0 tomlkit==0.11.7 watchdog==2.1.2 pyopenssl==23.0.0 diff --git a/requirements/pre-dev.txt b/requirements/pre-dev.txt index 70028040ca..2bc0c6f01f 100644 --- a/requirements/pre-dev.txt +++ b/requirements/pre-dev.txt @@ -1 +1 @@ -ruff==0.0.251 +ruff==0.0.261 diff --git a/requirements/reproducible-linux.txt b/requirements/reproducible-linux.txt index 9f16518c58..baeadccec7 100644 --- a/requirements/reproducible-linux.txt +++ b/requirements/reproducible-linux.txt @@ -15,9 +15,9 @@ attrs==22.2.0 \ # jschema-to-python # jsonschema # sarif-om -aws-lambda-builders==1.28.0 \ - --hash=sha256:6ea2fb607057436f03e2a8a857b5c5cbd18f7b2b907c53c2b461e65f843a4f38 \ - --hash=sha256:bd6566772e7c5d887d05f32cf7e61a57293658388eef4fe8301f65bef432fe39 +aws-lambda-builders==1.29.0 \ + --hash=sha256:292e4a52550a27a80a46f66b3d3256840f252df343ef82c38f5d89f3073e9820 \ + --hash=sha256:dca37c6beb1fc88958a02aea20bc529ccea5694d489541e91199b49fcbd0bc0a # via aws-sam-cli (setup.py) aws-sam-translator==1.64.0 \ --hash=sha256:0cc5b07dd6ef1de3525d887a3b9557168e04cb44327706a43661653bad30687f \ @@ -54,9 +54,9 @@ boto3==1.26.99 \ # aws-sam-cli (setup.py) # aws-sam-translator # serverlessrepo -botocore==1.29.99 \ - --hash=sha256:15c205e4578253da1e8cc247b9d4755042f5f873f68ac6e5fed48f4bd6f008c6 \ - --hash=sha256:d1770b4fe5531870af7a81e9897b2092d2f89e4ba8cb7abbbaf3ab952f6b8a6f +botocore==1.29.109 \ + --hash=sha256:2e449525f0ccedb31fdb962a77caac48b4c486c23515b84c5989a39a1823a024 \ + --hash=sha256:cf43dddb7e2ba5425fe19fad68b10043307b61d9103d06566f1ab6034e38b8db # via # boto3 # s3transfer diff --git a/samcli/__init__.py b/samcli/__init__.py index b9a0387c06..db57f5f6fe 100644 --- a/samcli/__init__.py +++ b/samcli/__init__.py @@ -2,4 +2,4 @@ SAM CLI version """ -__version__ = "1.79.0" +__version__ = "1.80.0" diff --git a/samcli/cli/types.py b/samcli/cli/types.py index bdce397be2..bf3f0169c0 100644 --- a/samcli/cli/types.py +++ b/samcli/cli/types.py @@ -94,9 +94,8 @@ def convert(self, value, param, ctx): value = (value,) if isinstance(value, str) else value for val in value: - val.strip() # Add empty string to start of the string to help match `_pattern2` - val = " " + val + normalized_val = " " + val.strip() try: # NOTE(TheSriram): find the first regex that matched. @@ -105,7 +104,7 @@ def convert(self, value, param, ctx): pattern = next( i for i in filter( - lambda item: re.findall(item, val), self.ordered_pattern_match + lambda item: re.findall(item, normalized_val), self.ordered_pattern_match ) # pylint: disable=cell-var-from-loop ) except StopIteration: @@ -117,7 +116,7 @@ def convert(self, value, param, ctx): ctx, ) - groups = re.findall(pattern, val) + groups = re.findall(pattern, normalized_val) # 'groups' variable is a list of tuples ex: [(key1, value1), (key2, value2)] for key, param_value in groups: @@ -320,11 +319,10 @@ def convert(self, value, param, ctx): value = (value,) if isinstance(value, str) else value for val in value: - val.strip() # Add empty string to start of the string to help match `_pattern2` - val = " " + val + normalized_val = " " + val.strip() - signing_profiles = re.findall(self.pattern, val) + signing_profiles = re.findall(self.pattern, normalized_val) # if no signing profiles found by regex, then fail if not signing_profiles: diff --git a/samcli/commands/_utils/click_mutex.py b/samcli/commands/_utils/click_mutex.py index cf3c0e2566..397e43acba 100644 --- a/samcli/commands/_utils/click_mutex.py +++ b/samcli/commands/_utils/click_mutex.py @@ -5,8 +5,10 @@ import click +from samcli.commands._utils.custom_options.replace_help_option import ReplaceHelpSummaryOption -class ClickMutex(click.Option): + +class ClickMutex(ReplaceHelpSummaryOption): """ Preprocessing checks for mutually exclusive or required parameters as supported by click api. diff --git a/samcli/commands/_utils/custom_options/replace_help_option.py b/samcli/commands/_utils/custom_options/replace_help_option.py index c4d2e30a6c..ff538c7a25 100644 --- a/samcli/commands/_utils/custom_options/replace_help_option.py +++ b/samcli/commands/_utils/custom_options/replace_help_option.py @@ -10,5 +10,5 @@ def __init__(self, *args, **kwargs): super(ReplaceHelpSummaryOption, self).__init__(*args, **kwargs) def get_help_record(self, ctx): - _, help_text = super(ReplaceHelpSummaryOption, self).get_help_record(ctx=ctx) - return self.replace_help_option, help_text + help_record, help_text = super(ReplaceHelpSummaryOption, self).get_help_record(ctx=ctx) + return self.replace_help_option if self.replace_help_option else help_record, help_text diff --git a/samcli/commands/_utils/options.py b/samcli/commands/_utils/options.py index d4699f2f00..e3f620cdef 100644 --- a/samcli/commands/_utils/options.py +++ b/samcli/commands/_utils/options.py @@ -310,7 +310,7 @@ def no_progressbar_click_option(): default=False, required=False, is_flag=True, - help="Does not showcase a progress bar when uploading artifacts to s3 ", + help="Does not showcase a progress bar when uploading artifacts to s3 and pushing docker images to ECR", ) diff --git a/samcli/commands/deploy/core/command.py b/samcli/commands/deploy/core/command.py index 95edb7fdf8..392afaadad 100644 --- a/samcli/commands/deploy/core/command.py +++ b/samcli/commands/deploy/core/command.py @@ -7,7 +7,7 @@ from samcli.commands.deploy.core.formatters import DeployCommandHelpTextFormatter from samcli.commands.deploy.core.options import OPTIONS_INFO -COL_SIZE_MODIFIER = 50 +COL_SIZE_MODIFIER = 38 class DeployCommand(CoreCommand): diff --git a/samcli/commands/deploy/guided_config.py b/samcli/commands/deploy/guided_config.py index 5dc3332b22..b9d8ea59b5 100644 --- a/samcli/commands/deploy/guided_config.py +++ b/samcli/commands/deploy/guided_config.py @@ -56,10 +56,11 @@ def save_config( cmd_names = get_cmd_names(ctx.info_name, ctx) for key, value in kwargs.items(): - if isinstance(value, (list, tuple)): - value = " ".join(val for val in value) - if value: - samconfig.put(cmd_names, self.section, key, value, env=config_env) + v = value + if isinstance(v, (list, tuple)): + v = " ".join(val for val in v) + if v: + samconfig.put(cmd_names, self.section, key, v, env=config_env) self._save_parameter_overrides(cmd_names, config_env, parameter_overrides, samconfig) self._save_image_repositories(cmd_names, config_env, samconfig, image_repositories) diff --git a/samcli/commands/init/command.py b/samcli/commands/init/command.py index 9a79eb6c0e..8702b3f9a9 100644 --- a/samcli/commands/init/command.py +++ b/samcli/commands/init/command.py @@ -10,6 +10,7 @@ from samcli.cli.cli_config_file import TomlProvider, configuration_option from samcli.cli.main import common_options, pass_context, print_cmdline_args from samcli.commands._utils.click_mutex import ClickMutex +from samcli.commands.init.core.command import InitCommand from samcli.commands.init.init_flow_helpers import _get_runtime_from_image, get_architectures, get_sorted_runtimes from samcli.lib.build.constants import DEPRECATED_RUNTIMES from samcli.lib.telemetry.metric import track_command @@ -20,47 +21,16 @@ LOG = logging.getLogger(__name__) -HELP_TEXT = """ \b - Initialize a serverless application with a SAM template, folder - structure for your Lambda functions, connected to an event source such as APIs, - S3 Buckets or DynamoDB Tables. This application includes everything you need to - get started with serverless and eventually grow into a production scale application. - \b - This command can initialize a boilerplate serverless app. If you want to create your own - template as well as use a custom location please take a look at our official documentation. -\b -Common usage: - \b - Starts an interactive prompt process to initialize a new project: - \b - $ sam init - \b - Initializes a new SAM project using project templates without an interactive workflow: - \b - $ sam init --name sam-app --runtime nodejs14.x --dependency-manager npm --app-template hello-world - \b - $ sam init --name sam-app --runtime nodejs14.x --architecture arm64 - \b - $ sam init --name sam-app --package-type image --base-image nodejs14.x-base - \b - Initializes a new SAM project using custom template in a Git/Mercurial repository - \b - # gh being expanded to github url - $ sam init --location gh:aws-samples/cookiecutter-aws-sam-python - \b - $ sam init --location git+ssh://git@github.com/aws-samples/cookiecutter-aws-sam-python.git - \b - $ sam init --location hg+ssh://hg@bitbucket.org/repo/template-name - \b - Initializes a new SAM project using custom template in a Zipfile - \b - $ sam init --location /path/to/template.zip - \b - $ sam init --location https://example.com/path/to/template.zip - \b - Initializes a new SAM project using custom template in a local path - \b - $ sam init --location /path/to/template/folder +HELP_TEXT = "Initialize an AWS SAM application." + +DESCRIPTION = """ \b + Initialize a serverless application with an AWS SAM template, source code and + structure for serverless abstractions which connect to event source(s) such as APIs, + S3 Buckets or DynamoDB Tables. This application includes everything one needs to + get started with serverless and eventually grow into a production scale application. + \b + To explore initializing with your own template and/or using a custom location, + please take a look at our official documentation. """ INCOMPATIBLE_PARAMS_HINT = """You can run 'sam init' without any options for an interactive initialization flow, \ @@ -136,16 +106,18 @@ def wrapped(*args, **kwargs): @click.command( "init", - help=HELP_TEXT, - short_help="Init an AWS SAM application.", - context_settings=dict(help_option_names=["-h", "--help"]), + short_help=HELP_TEXT, + context_settings={"max_content_width": 120}, + cls=InitCommand, + description=DESCRIPTION, + requires_credentials=False, ) @configuration_option(provider=TomlProvider(section="parameters")) @click.option( "--no-interactive", is_flag=True, default=False, - help="Disable interactive prompting for init parameters, and fail if any required values are missing.", + help="Disable interactive prompting for init parameters. (fail if any required values are missing)", cls=ClickMutex, required_param_lists=[ ["name", "location"], @@ -159,13 +131,14 @@ def wrapped(*args, **kwargs): "-a", "--architecture", type=click.Choice([ARM64, X86_64]), - help="Architectures your Lambda function will run on", + replace_help_option="--architecture ARCHITECTURE", + help="Architectures for Lambda functions." + click.style(f"\n\nArchitectures: {[ARM64, X86_64]}", bold=True), cls=ClickMutex, ) @click.option( "-l", "--location", - help="Template location (git, mercurial, http(s), zip, path)", + help="Template location (git, mercurial, http(s), zip, path).", cls=ClickMutex, incompatible_params=["package_type", "runtime", "base_image", "dependency_manager", "app_template"], incompatible_params_hint=INCOMPATIBLE_PARAMS_HINT, @@ -174,7 +147,9 @@ def wrapped(*args, **kwargs): "-r", "--runtime", type=click.Choice(get_sorted_runtimes(INIT_RUNTIMES)), - help="Lambda Runtime of your app", + replace_help_option="--runtime RUNTIME", + help="Lambda runtime for application." + + click.style(f"\n\nRuntimes: {', '.join(get_sorted_runtimes(INIT_RUNTIMES))}", bold=True), cls=ClickMutex, incompatible_params=["location", "base_image"], incompatible_params_hint=INCOMPATIBLE_PARAMS_HINT, @@ -183,7 +158,8 @@ def wrapped(*args, **kwargs): "-p", "--package-type", type=click.Choice([ZIP, IMAGE]), - help="Package type for your app", + help="Lambda deployment package type." + click.style(f"\n\nPackage Types: {', '.join([ZIP, IMAGE])}", bold=True), + replace_help_option="--package-type PACKAGE_TYPE", cls=ClickMutex, callback=PackageType.pt_callback, incompatible_params=["location"], @@ -194,7 +170,9 @@ def wrapped(*args, **kwargs): "--base-image", type=click.Choice(LAMBDA_IMAGES_RUNTIMES), default=None, - help="Lambda Image of your app", + help="Lambda base image for deploying IMAGE based package type." + + click.style(f"\n\nBase images: {', '.join(LAMBDA_IMAGES_RUNTIMES)}", bold=True), + replace_help_option="--base-image BASE_IMAGE", cls=ClickMutex, incompatible_params=["location", "runtime"], incompatible_params_hint=INCOMPATIBLE_PARAMS_HINT, @@ -204,18 +182,20 @@ def wrapped(*args, **kwargs): "--dependency-manager", type=click.Choice(SUPPORTED_DEP_MANAGERS), default=None, - help="Dependency manager of your Lambda runtime", + help="Dependency manager for Lambda runtime." + + click.style(f"\n\nDependency managers: {', '.join(SUPPORTED_DEP_MANAGERS)}", bold=True), required=False, cls=ClickMutex, + replace_help_option="--dependency-manager DEPENDENCY_MANAGER", incompatible_params=["location"], incompatible_params_hint=INCOMPATIBLE_PARAMS_HINT, ) -@click.option("-o", "--output-dir", type=click.Path(), help="Where to output the initialized app into", default=".") -@click.option("-n", "--name", help="Name of your project to be generated as a folder") +@click.option("-o", "--output-dir", type=click.Path(), help="Directory to initialize AWS SAM application.", default=".") +@click.option("-n", "--name", help="Name of AWS SAM Application.") @click.option( "--app-template", - help="Identifier of the managed application template you want to use. " - "If not sure, call 'sam init' without options for an interactive workflow.", + help="Identifier of the managed application template to be used. " + "Alternatively, run '$sam init' without options for an interactive workflow.", cls=ClickMutex, incompatible_params=["location"], incompatible_params_hint=INCOMPATIBLE_PARAMS_HINT, @@ -224,12 +204,12 @@ def wrapped(*args, **kwargs): "--no-input", is_flag=True, default=False, - help="Disable Cookiecutter prompting and accept default values defined template config", + help="Disable Cookiecutter prompting and accept default values defined in the cookiecutter config.", ) @click.option( "--extra-context", default=None, - help="Override any custom parameters in the template's cookiecutter.json configuration e.g. " + help="Override custom parameters in the template's cookiecutter.json configuration e.g. " "" '{"customParam1": "customValue1", "customParam2":"customValue2"}' """ """, @@ -238,12 +218,12 @@ def wrapped(*args, **kwargs): @click.option( "--tracing/--no-tracing", default=None, - help="Enable AWS X-Ray tracing for your lambda functions", + help="Enable AWS X-Ray tracing for application.", ) @click.option( "--application-insights/--no-application-insights", default=None, - help="Enable CloudWatch Application Insights monitoring for your application", + help="Enable CloudWatch Application Insights monitoring for application.", ) @common_options @non_interactive_validation diff --git a/samcli/commands/init/core/__init__.py b/samcli/commands/init/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/samcli/commands/init/core/command.py b/samcli/commands/init/core/command.py new file mode 100644 index 0000000000..87b9f503aa --- /dev/null +++ b/samcli/commands/init/core/command.py @@ -0,0 +1,95 @@ +from click import Context, style + +from samcli.cli.core.command import CoreCommand +from samcli.cli.row_modifiers import RowDefinition, ShowcaseRowModifier +from samcli.commands.init.core.formatters import InitCommandHelpTextFormatter +from samcli.commands.init.core.options import OPTIONS_INFO + + +class InitCommand(CoreCommand): + class CustomFormatterContext(Context): + formatter_class = InitCommandHelpTextFormatter + + context_class = CustomFormatterContext + + @staticmethod + def format_examples(ctx: Context, formatter: InitCommandHelpTextFormatter): + with formatter.indented_section(name="Examples", extra_indents=1): + with formatter.indented_section(name="Interactive Mode", extra_indents=1): + formatter.write_rd( + [ + RowDefinition( + text="\n", + ), + RowDefinition(name=style(f"$ {ctx.command_path}"), extra_row_modifiers=[ShowcaseRowModifier()]), + ] + ) + with formatter.indented_section(name="Customized Interactive Mode", extra_indents=1): + formatter.write_rd( + [ + RowDefinition( + text="\n", + ), + RowDefinition( + name=style( + f"$ {ctx.command_path} --name sam-app --runtime nodejs18.x --architecture arm64" + ), + extra_row_modifiers=[ShowcaseRowModifier()], + ), + RowDefinition( + name=style( + f"$ {ctx.command_path} --name sam-app --runtime nodejs18.x --dependency-manager " + f"npm --app-template hello-world" + ), + extra_row_modifiers=[ShowcaseRowModifier()], + ), + RowDefinition( + name=style( + f"$ {ctx.command_path} --name sam-app --package-type image --architecture arm64" + ), + extra_row_modifiers=[ShowcaseRowModifier()], + ), + ] + ) + with formatter.indented_section(name="Direct Initialization", extra_indents=1): + formatter.write_rd( + [ + RowDefinition( + text="\n", + ), + RowDefinition( + name=style(f"$ {ctx.command_path} --location gh:aws-samples/cookiecutter-aws-sam-python"), + extra_row_modifiers=[ShowcaseRowModifier()], + ), + RowDefinition( + name=style( + f"$ {ctx.command_path} --location " + f"git+ssh://git@github.com/aws-samples/cookiecutter-aws-sam-python.git" + ), + extra_row_modifiers=[ShowcaseRowModifier()], + ), + RowDefinition( + name=style(f"$ {ctx.command_path} --location /path/to/template.zip"), + extra_row_modifiers=[ShowcaseRowModifier()], + ), + RowDefinition( + name=style(f"$ {ctx.command_path} --location /path/to/template/directory"), + extra_row_modifiers=[ShowcaseRowModifier()], + ), + RowDefinition( + name=style(f"$ {ctx.command_path} --location https://example.com/path/to/template.zip"), + extra_row_modifiers=[ShowcaseRowModifier()], + ), + ], + ) + + def format_options(self, ctx: Context, formatter: InitCommandHelpTextFormatter) -> None: # type:ignore + # `ignore` is put in place here for mypy even though it is the correct behavior, + # as the `formatter_class` can be set in subclass of Command. If ignore is not set, + # mypy raises argument needs to be HelpFormatter as super class defines it. + + self.format_description(formatter) + InitCommand.format_examples(ctx, formatter) + CoreCommand._format_options( + ctx=ctx, params=self.get_params(ctx), formatter=formatter, formatting_options=OPTIONS_INFO + ) diff --git a/samcli/commands/init/core/formatters.py b/samcli/commands/init/core/formatters.py new file mode 100644 index 0000000000..2ee8ac2f8c --- /dev/null +++ b/samcli/commands/init/core/formatters.py @@ -0,0 +1,19 @@ +from samcli.cli.formatters import RootCommandHelpTextFormatter +from samcli.cli.row_modifiers import BaseLineRowModifier +from samcli.commands.init.core.options import ALL_OPTIONS + + +class InitCommandHelpTextFormatter(RootCommandHelpTextFormatter): + # Picked an additive constant that gives an aesthetically pleasing look. + ADDITIVE_JUSTIFICATION = 10 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Add Additional space after determining the longest option. + # However, do not justify with padding for more than half the width of + # the terminal to retain aesthetics. + self.left_justification_length = min( + max([len(option) for option in ALL_OPTIONS]) + self.ADDITIVE_JUSTIFICATION, + self.width // 2 - self.indent_increment, + ) + self.modifiers = [BaseLineRowModifier()] diff --git a/samcli/commands/init/core/options.py b/samcli/commands/init/core/options.py new file mode 100644 index 0000000000..5d6cda6c2c --- /dev/null +++ b/samcli/commands/init/core/options.py @@ -0,0 +1,58 @@ +""" +Init Command Options related Datastructures for formatting. +""" +from typing import Dict, List + +from samcli.cli.row_modifiers import RowDefinition + +# The ordering of the option lists matter, they are the order in which options will be displayed. + +APPLICATION_OPTIONS: List[str] = [ + "name", + "architecture", + "runtime", + "dependency_manager", + "location", + "package_type", + "base_image", + "app_template", + "output_dir", +] + +# Can be used instead of the options in the first list +NON_INTERACTIVE_OPTIONS: List[str] = ["no_interactive", "no_input", "extra_context"] + +CONFIGURATION_OPTION_NAMES: List[str] = ["config_env", "config_file"] + +ADDITIONAL_OPTIONS: List[str] = [ + "tracing", + "application_insights", +] + +OTHER_OPTIONS: List[str] = ["debug"] + +ALL_OPTIONS: List[str] = ( + APPLICATION_OPTIONS + NON_INTERACTIVE_OPTIONS + CONFIGURATION_OPTION_NAMES + ADDITIONAL_OPTIONS + OTHER_OPTIONS +) + +OPTIONS_INFO: Dict[str, Dict] = { + "Application Options": { + "option_names": {opt: {"rank": idx} for idx, opt in enumerate(APPLICATION_OPTIONS)}, + "extras": [RowDefinition(name="")], + }, + "Non Interactive Options": { + "option_names": {opt: {"rank": idx} for idx, opt in enumerate(NON_INTERACTIVE_OPTIONS)} + }, + "Configuration Options": { + "option_names": {opt: {"rank": idx} for idx, opt in enumerate(CONFIGURATION_OPTION_NAMES)}, + "extras": [ + RowDefinition(name="Learn more about configuration files at:"), + RowDefinition( + name="https://docs.aws.amazon.com/serverless-application-model/latest/developerguide/serverless-sam-cli" + "-config.html. " + ), + ], + }, + "Additional Options": {"option_names": {opt: {"rank": idx} for idx, opt in enumerate(ADDITIONAL_OPTIONS)}}, + "Other Options": {"option_names": {opt: {"rank": idx} for idx, opt in enumerate(OTHER_OPTIONS)}}, +} diff --git a/samcli/commands/init/init_flow_helpers.py b/samcli/commands/init/init_flow_helpers.py index ccaff48fc2..6ca7352b1b 100644 --- a/samcli/commands/init/init_flow_helpers.py +++ b/samcli/commands/init/init_flow_helpers.py @@ -158,7 +158,7 @@ def _get_runtime_from_image(image: str) -> Optional[str]: if match is None: return None runtime, base = match.groups() - if base != "": + if base: return f"{runtime} ({base})" return runtime diff --git a/samcli/commands/local/lib/swagger/parser.py b/samcli/commands/local/lib/swagger/parser.py index 4e36d1caff..e4ee0d5960 100644 --- a/samcli/commands/local/lib/swagger/parser.py +++ b/samcli/commands/local/lib/swagger/parser.py @@ -1,19 +1,47 @@ """Handles Swagger Parsing""" import logging +from typing import Dict, List, Union from samcli.commands.local.lib.swagger.integration_uri import IntegrationType, LambdaUri -from samcli.local.apigw.local_apigw_service import Route +from samcli.commands.local.lib.validators.identity_source_validator import IdentitySourceValidator +from samcli.local.apigw.authorizers.authorizer import Authorizer +from samcli.local.apigw.authorizers.lambda_authorizer import LambdaAuthorizer +from samcli.local.apigw.exceptions import ( + IncorrectOasWithDefaultAuthorizerException, + InvalidOasVersion, + InvalidSecurityDefinition, + MultipleAuthorizerException, +) +from samcli.local.apigw.route import Route LOG = logging.getLogger(__name__) class SwaggerParser: + _AUTHORIZER_KEY = "x-amazon-apigateway-authorizer" _INTEGRATION_KEY = "x-amazon-apigateway-integration" _ANY_METHOD_EXTENSION_KEY = "x-amazon-apigateway-any-method" _BINARY_MEDIA_TYPES_EXTENSION_KEY = "x-amazon-apigateway-binary-media-types" # pylint: disable=C0103 _ANY_METHOD = "ANY" + _SWAGGER = "swagger" + _OPENAPI = "openapi" + _2_X_VERSION = "2." + _3_X_VERSION = "3." + _SWAGGER_COMPONENTS = "components" + _SWAGGER_SECURITY = "security" + _SWAGGER_SECURITY_SCHEMES = "securitySchemes" + _SWAGGER_SECURITY_DEFINITIONS = "securityDefinitions" + _AUTHORIZER_TYPE = "type" + _AUTHORIZER_PAYLOAD_VERSION = "authorizerPayloadFormatVersion" + _AUTHORIZER_LAMBDA_URI = "authorizerUri" + _AUTHORIZER_LAMBDA_VALIDATION = "identityValidationExpression" + _AUTHORIZER_NAME = "name" + _AUTHORIZER_IN = "in" + _AUTHORIZER_IDENTITY_SOURCE = "identitySource" + _AUTHORIZER_SIMPLE_RESPONSES = "enableSimpleResponses" + def __init__(self, stack_path: str, swagger): """ Constructs an Swagger Parser object @@ -36,7 +64,233 @@ def get_binary_media_types(self): """ return self.swagger.get(self._BINARY_MEDIA_TYPES_EXTENSION_KEY) or [] - def get_routes(self, event_type=Route.API): + def get_authorizers(self, event_type: str = Route.API) -> Dict[str, Authorizer]: + """ + Parse Swagger document and returns a list of Authorizer objects + + Parameters + ---------- + event_type: str + String indicating what type of API Gateway this is + + Returns + ------- + dict[str, Authorizer] + A map of authorizer names and Authorizer objects found in the body definition + """ + authorizers: Dict[str, Authorizer] = {} + + authorizer_dict = {} + document_version = self.swagger.get(SwaggerParser._SWAGGER) or self.swagger.get(SwaggerParser._OPENAPI) or "" + + if document_version.startswith(SwaggerParser._2_X_VERSION): + LOG.debug("Parsing Swagger document using 2.0 specification") + authorizer_dict = self.swagger.get(SwaggerParser._SWAGGER_SECURITY_DEFINITIONS, {}) + elif document_version.startswith(SwaggerParser._3_X_VERSION): + LOG.debug("Parsing Swagger document using 3.0 specification") + authorizer_dict = self.swagger.get(SwaggerParser._SWAGGER_COMPONENTS, {}).get( + SwaggerParser._SWAGGER_SECURITY_SCHEMES, {} + ) + else: + raise InvalidOasVersion( + f"An invalid OpenApi version was detected: '{document_version}', must be one of 2.x or 3.x", + ) + + for auth_name, properties in authorizer_dict.items(): + authorizer_object = properties.get(self._AUTHORIZER_KEY) + + if not authorizer_object: + LOG.warning("Skip parsing unsupported authorizer '%s'", auth_name) + continue + + authorizer_type = authorizer_object.get(SwaggerParser._AUTHORIZER_TYPE, "").lower() + payload_version = authorizer_object.get(SwaggerParser._AUTHORIZER_PAYLOAD_VERSION) + + if event_type == Route.HTTP and payload_version not in LambdaAuthorizer.PAYLOAD_VERSIONS: + raise InvalidSecurityDefinition(f"Authorizer '{auth_name}' contains an invalid payload version") + + if event_type == Route.API: + payload_version = LambdaAuthorizer.PAYLOAD_V1 + + lambda_name = LambdaUri.get_function_name(authorizer_object.get(SwaggerParser._AUTHORIZER_LAMBDA_URI)) + + if not lambda_name: + LOG.warning("Unable to parse authorizerUri '%s' for authorizer '%s', skipping", lambda_name, auth_name) + continue + + # only add authorizer if it is Lambda token or request based (not jwt) + if authorizer_type not in LambdaAuthorizer.VALID_TYPES: + LOG.warning("Lambda authorizer '%s' type '%s' is unsupported, skipping", auth_name, authorizer_type) + continue + + identity_sources = self._get_lambda_identity_sources( + auth_name, authorizer_type, event_type, properties, authorizer_object + ) + + validation_expression = authorizer_object.get(SwaggerParser._AUTHORIZER_LAMBDA_VALIDATION) + if event_type == Route.HTTP and validation_expression: + validation_expression = None + + LOG.warning( + "Validation expressions is only available on REST APIs, ignoring for Lambda authorizer '%s'", + auth_name, + ) + + enable_simple_response = authorizer_object.get(SwaggerParser._AUTHORIZER_SIMPLE_RESPONSES, False) + + if ( + event_type != Route.HTTP + or payload_version != LambdaAuthorizer.PAYLOAD_V2 + or not isinstance(enable_simple_response, bool) + ): + enable_simple_response = False + + if authorizer_object.get(SwaggerParser._AUTHORIZER_SIMPLE_RESPONSES) is not None: + LOG.warning( + "Simple responses are only available on HTTP APIs with payload version " + "2.0, ignoring for Lambda authorizer '%s'", + auth_name, + ) + + if not identity_sources: + LOG.warning( + "Skip parsing Lambda authorizer '%s', must contain at least one valid identity source", + auth_name, + ) + continue + + lambda_authorizer = LambdaAuthorizer( + authorizer_name=auth_name, + type=authorizer_type, + payload_version=payload_version, + lambda_name=lambda_name, + identity_sources=identity_sources, + validation_string=validation_expression, + use_simple_response=enable_simple_response, + ) + + authorizers[auth_name] = lambda_authorizer + + LOG.debug("Parsing Lambda authorizer '%s' type '%s'", auth_name, authorizer_type) + + return authorizers + + @staticmethod + def _get_lambda_identity_sources( + auth_name: str, auth_type: str, event_type: str, properties: dict, authorizer_object: dict + ) -> List[str]: + """ + Parses the properties depending on the Lambda Authorizer type (token or request) and retrieves identity sources + + Parameters + ---------- + auth_name: str + Name of the authorizer used for logging + auth_type: str + Type of authorizer (token, request) + event_type: str + API Gateway type (API, HTTP API) + properties: dict + Swagger Lambda Authorizer properties + authorizer_object: dict + Lambda Authorizer integration properties + Returns + ------- + List[str] + A list of identity sources + """ + identity_sources: List[str] = [] + + if auth_type == LambdaAuthorizer.TOKEN: + header_name = properties.get(SwaggerParser._AUTHORIZER_NAME) + + if not properties.get(SwaggerParser._AUTHORIZER_IN) == "header" or not header_name: + LOG.warning( + "Missing properties for Lambda Authorizer '%s', " + "property 'in' must be set to 'header' and " + "property 'name' must be provided", + auth_name, + ) + elif event_type == Route.HTTP: + LOG.info("Type 'token' for Lambda Authorizer '%s' is unsupported ", auth_name) + else: + identity_sources.append(f"method.request.header.{header_name}") + else: + identity_source_string = authorizer_object.get(SwaggerParser._AUTHORIZER_IDENTITY_SOURCE) + + if not identity_source_string: + LOG.warning( + "Missing property 'identitySource' in the authorizer integration for Lambda Authorizer '%s'", + auth_name, + ) + else: + # split the identity sources, remove any trailing spaces, and validate + split_identity_source: List[str] = identity_source_string.split(",") + + for identity in split_identity_source: + trimmed_identity = identity.strip() + is_valid_format = IdentitySourceValidator.validate_identity_source(trimmed_identity, event_type) + + if not is_valid_format: + raise InvalidSecurityDefinition( + f"Identity source '{trimmed_identity}' for Lambda Authorizer '{auth_name}' " + "is not a valid identity source, check the spelling/format." + ) + + identity_sources.append(trimmed_identity) + + return identity_sources + + def get_default_authorizer(self, event_type: str) -> Union[str, None]: + """ + Parses the body definition to find root level Authorizer definitions + + Parameters + ---------- + event_type: str + String representing the type of API the definition body is defined as + + Returns + ------- + Union[str, None] + Returns the name of the authorizer, if there is one defined, otherwise None + """ + document_version = self.swagger.get(SwaggerParser._SWAGGER) or self.swagger.get(SwaggerParser._OPENAPI) or "" + authorizers = self.swagger.get(SwaggerParser._SWAGGER_SECURITY, []) + + if not authorizers: + return None + + if not document_version.startswith(SwaggerParser._3_X_VERSION) or not event_type == Route.HTTP: + raise IncorrectOasWithDefaultAuthorizerException( + "Root level definition of default authorizers are only supported for OpenApi 3.0" + ) + + if len(authorizers) > 1: + raise MultipleAuthorizerException( + f"There must only be a single authorizer defined for a single route, found '{len(authorizers)}'" + ) + + if len(authorizers) == 1: + # user has authorizer defined + authorizer_object = authorizers[0] + authorizer_object = list(authorizers[0]) + + # make sure that authorizer actually has keys + if len(authorizer_object) != 1: + raise InvalidSecurityDefinition( + "Invalid default security definition found, there must be an authorizer defined." + ) + + authorizer_name = str(authorizer_object[0]) + + LOG.debug("Found default authorizer: %s", authorizer_name) + + return authorizer_name + + return None + + def get_routes(self, event_type=Route.API) -> List[Route]: """ Parses a swagger document and returns a list of APIs configured in the document. @@ -82,20 +336,60 @@ def get_routes(self, event_type=Route.API): ) continue - if method.lower() == self._ANY_METHOD_EXTENSION_KEY: + normalized_method = method + if normalized_method.lower() == self._ANY_METHOD_EXTENSION_KEY: # Convert to a more commonly used method notation - method = self._ANY_METHOD + normalized_method = self._ANY_METHOD payload_format_version = self._get_payload_format_version(method_config) + + authorizers = method_config.get(SwaggerParser._SWAGGER_SECURITY, None) + + authorizer_name = None + use_default_authorizer = True + + if authorizers is not None: + if not isinstance(authorizers, list): + raise InvalidSecurityDefinition( + "Invalid security definition found, authorizers for " + f"path='{full_path}' method='{method}' must be a list" + ) + + if len(authorizers) > 1: + raise MultipleAuthorizerException( + "There must only be a single authorizer defined " + f"for path='{full_path}' method='{method}', found '{len(authorizers)}'" + ) + + if len(authorizers) == 1: + # user has authorizer defined + authorizer_object = authorizers[0] + authorizer_object = list(authorizers[0]) + + # make sure that authorizer actually has keys + if len(authorizer_object) != 1: + raise InvalidSecurityDefinition( + "Invalid security definition found, authorizers for " + f"path='{full_path}' method='{method}' must contain an authorizer" + ) + + authorizer_name = str(authorizer_object[0]) + else: + # customer provided empty list, do not use default authorizer + use_default_authorizer = False + route = Route( function_name, full_path, - methods=[method], + methods=[normalized_method], event_type=event_type, payload_format_version=payload_format_version, operation_name=method_config.get("operationId"), stack_path=self.stack_path, + authorizer_name=authorizer_name, + use_default_authorizer=use_default_authorizer, ) result.append(route) + return result def _get_integration(self, method_config): diff --git a/samcli/commands/local/lib/validators/__init__.py b/samcli/commands/local/lib/validators/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/samcli/commands/local/lib/validators/identity_source_validator.py b/samcli/commands/local/lib/validators/identity_source_validator.py new file mode 100644 index 0000000000..e95cc6a316 --- /dev/null +++ b/samcli/commands/local/lib/validators/identity_source_validator.py @@ -0,0 +1,57 @@ +""" +Handles the validation of identity sources +""" +import re + +from samcli.local.apigw.route import Route + + +class IdentitySourceValidator: + # match lowercase + uppercase + numbers + those 3 symbols, until the end of string + API_GATEWAY_V1_QUERY_REGEX = re.compile(r"method\.request\.querystring\.[a-zA-Z0-9._-]+$") + API_GATEWAY_V1_HEADER_REGEX = re.compile(r"method\.request\.header\.[a-zA-Z0-9._-]+$") + API_GATEWAY_V1_CONTEXT_REGEX = re.compile(r"context\.[a-zA-Z0-9._-]+$") + API_GATEWAY_V1_STAGE_REGEX = re.compile(r"stageVariables\.[a-zA-Z0-9._-]+$") + + API_GATEWAY_V2_QUERY_REGEX = re.compile(r"\$request\.querystring\.[a-zA-Z0-9._-]+$") + API_GATEWAY_V2_HEADER_REGEX = re.compile(r"\$request\.header\.[a-zA-Z0-9._-]+$") + API_GATEWAY_V2_CONTEXT_REGEX = re.compile(r"\$context\.[a-zA-Z0-9._-]+$") + API_GATEWAY_V2_STAGE_REGEX = re.compile(r"\$stageVariables\.[a-zA-Z0-9._-]+$") + + API_GATEWAY_VALIDATION_LIST = { + Route.API: [ + API_GATEWAY_V1_QUERY_REGEX, + API_GATEWAY_V1_HEADER_REGEX, + API_GATEWAY_V1_CONTEXT_REGEX, + API_GATEWAY_V1_STAGE_REGEX, + ], + Route.HTTP: [ + API_GATEWAY_V2_QUERY_REGEX, + API_GATEWAY_V2_HEADER_REGEX, + API_GATEWAY_V2_CONTEXT_REGEX, + API_GATEWAY_V2_STAGE_REGEX, + ], + } + + @staticmethod + def validate_identity_source(identity_source: str, event_type: str = Route.API) -> bool: + """ + Validates if the identity source is valid for the provided event type + + Parameters + ---------- + identity_source: str + The identity source to validate + event_type: str + The type of API Gateway to validate against (API or HTTP) + + Returns + ------- + bool + True if the identity source is valid + """ + for regex in IdentitySourceValidator.API_GATEWAY_VALIDATION_LIST[event_type]: + if regex.match(identity_source): + return True + + return False diff --git a/samcli/commands/local/lib/validators/lambda_auth_props.py b/samcli/commands/local/lib/validators/lambda_auth_props.py new file mode 100644 index 0000000000..96a9ce05b1 --- /dev/null +++ b/samcli/commands/local/lib/validators/lambda_auth_props.py @@ -0,0 +1,266 @@ +""" +Module to help validate Lambda Authorizer properties +""" +import logging +from abc import ABC, abstractmethod + +from samcli.commands.local.cli_common.user_exceptions import InvalidSamTemplateException +from samcli.commands.local.lib.swagger.integration_uri import LambdaUri +from samcli.commands.local.lib.validators.identity_source_validator import IdentitySourceValidator +from samcli.local.apigw.authorizers.lambda_authorizer import LambdaAuthorizer +from samcli.local.apigw.route import Route + +LOG = logging.getLogger(__name__) + + +class BaseLambdaAuthorizerValidator(ABC): + AUTHORIZER_TYPE = "Type" + AUTHORIZER_REST_API = "RestApiId" + AUTHORIZER_NAME = "Name" + AUTHORIZER_IDENTITY_SOURCE = "IdentitySource" + AUTHORIZER_VALIDATION = "IdentityValidationExpression" + AUTHORIZER_AUTHORIZER_URI = "AuthorizerUri" + + @staticmethod + @abstractmethod + def validate(logical_id: str, resource: dict) -> bool: + """ + Validates if all the required properties for a Lambda Authorizer are present and valid. + + Parameters + ---------- + logical_id: str + The logical ID of the authorizer + resource: dict + The resource dictionary for the authorizer containing the `Properties` + + Returns + ------- + bool + True if the `Properties` contains all the required key values + """ + + @staticmethod + def _validate_common_properties(logical_id: str, properties: dict, type_key: str, api_key: str): + """ + Validates if the common required properties are present and valid, will raise an exception + if they are missing or invalid. + + Parameters + ---------- + logical_id: str + The logical ID of the authorizer + properties: dict + The `Properties` dictionary for the authorizer + type_key: str + They authorizer type key to search for + api_key: str + The API Gateway reference key to search for + """ + authorizer_type = properties.get(type_key) + api_id = properties.get(api_key) + name = properties.get(BaseLambdaAuthorizerValidator.AUTHORIZER_NAME) + + if not authorizer_type: + raise InvalidSamTemplateException( + f"Authorizer '{logical_id}' is missing the '{type_key}' " + "property, an Authorizer type must be defined." + ) + + if not api_id: + raise InvalidSamTemplateException( + f"Authorizer '{logical_id}' is missing the '{api_key}' " "property, this must be defined." + ) + + if not name: + raise InvalidSamTemplateException( + f"Authorizer '{logical_id}' is missing the '{BaseLambdaAuthorizerValidator.AUTHORIZER_NAME}' " + "property, the Name must be defined." + ) + + +class LambdaAuthorizerV1Validator(BaseLambdaAuthorizerValidator): + @staticmethod + def validate( + logical_id: str, + resource: dict, + ): + """ + Validates if all the required properties for a Lambda Authorizer V1 are present and valid. + + Parameters + ---------- + logical_id: str + The logical ID of the authorizer + resource: dict + The resource dictionary for the authorizer containing the `Properties` + + Returns + ------- + bool + True if the `Properties` contains all the required key values + """ + properties = resource.get("Properties", {}) + authorizer_type = properties.get(LambdaAuthorizerV1Validator.AUTHORIZER_TYPE, "") + authorizer_uri = properties.get(LambdaAuthorizerV1Validator.AUTHORIZER_AUTHORIZER_URI) + + LambdaAuthorizerV1Validator._validate_common_properties( + logical_id, + properties, + LambdaAuthorizerV1Validator.AUTHORIZER_TYPE, + LambdaAuthorizerV1Validator.AUTHORIZER_REST_API, + ) + + # (lucashuy) AWS SAM CLI keeps references to types as lowercase strings + # while they are defined as uppercase strings in CFN + # this is to just validate that they are provided as upper case strings + if authorizer_type not in [type.upper() for type in LambdaAuthorizer.VALID_TYPES]: + LOG.warning( + "Authorizer '%s' with type '%s' is currently not supported. " + "Only Lambda Authorizers of type TOKEN and REQUEST are supported.", + logical_id, + authorizer_type, + ) + return False + + if not authorizer_uri: + raise InvalidSamTemplateException( + f"Authorizer '{logical_id}' is missing the '{LambdaAuthorizerV1Validator.AUTHORIZER_AUTHORIZER_URI}' " + "property, a valid Lambda ARN must be provided." + ) + + function_name = LambdaUri.get_function_name(authorizer_uri) + if not function_name: + LOG.warning( + "Was not able to resolve Lambda function ARN for Authorizer '%s'. " + "Double check the ARN format, or use more simple intrinsics.", + logical_id, + ) + return False + + identity_source_template = properties.get(LambdaAuthorizerV1Validator.AUTHORIZER_IDENTITY_SOURCE, None) + + if identity_source_template is None and authorizer_type == LambdaAuthorizer.TOKEN.upper(): + raise InvalidSamTemplateException( + f"Lambda Authorizer '{logical_id}' of type TOKEN, must have " + f"'{LambdaAuthorizerV1Validator.AUTHORIZER_IDENTITY_SOURCE}' of type string defined." + ) + + # (lucashuy) (regarding this if statement and the one below this) + # For API Gateway V1, an authorizer of type REQUEST can omit the identity sources + # if caching is enabled. Made the decision to not test this behaviour, and instead + # test if the it is a string. + if identity_source_template is not None and not isinstance(identity_source_template, str): + raise InvalidSamTemplateException( + f"Lambda Authorizer '{logical_id}' contains an invalid " + f"'{LambdaAuthorizerV1Validator.AUTHORIZER_IDENTITY_SOURCE}', " + "it must be a comma-separated string." + ) + + validation_expression = properties.get(LambdaAuthorizerV1Validator.AUTHORIZER_VALIDATION) + + if authorizer_type == LambdaAuthorizer.REQUEST.upper() and validation_expression: + raise InvalidSamTemplateException( + "Lambda Authorizer '%s' has '%s' property defined, but validation is only " + "supported on TOKEN type authorizers." % (logical_id, LambdaAuthorizerV1Validator.AUTHORIZER_VALIDATION) + ) + + return True + + +class LambdaAuthorizerV2Validator(BaseLambdaAuthorizerValidator): + AUTHORIZER_V2_TYPE = "AuthorizerType" + AUTHORIZER_V2_API = "ApiId" + AUTHORIZER_V2_PAYLOAD = "AuthorizerPayloadFormatVersion" + AUTHORIZER_V2_SIMPLE_RESPONSE = "EnableSimpleResponses" + + @staticmethod + def validate( + logical_id: str, + resource: dict, + ): + """ + Validates if all the required properties for a Lambda Authorizer V2 are present and valid. + + Parameters + ---------- + logical_id: str + The logical ID of the authorizer + resource: dict + The resource dictionary for the authorizer containing the `Properties` + + Returns + ------- + bool + True if the `Properties` contains all the required key values + """ + properties = resource.get("Properties", {}) + authorizer_type = properties.get(LambdaAuthorizerV2Validator.AUTHORIZER_V2_TYPE, "") + authorizer_uri = properties.get(LambdaAuthorizerV2Validator.AUTHORIZER_AUTHORIZER_URI) + + LambdaAuthorizerV2Validator._validate_common_properties( + logical_id, + properties, + LambdaAuthorizerV2Validator.AUTHORIZER_V2_TYPE, + LambdaAuthorizerV2Validator.AUTHORIZER_V2_API, + ) + + # (lucashuy) AWS SAM CLI keeps references to types as lowercase strings + # while they are defined as uppercase strings in CFN + # this is to just validate that they are provided as upper case strings + if authorizer_type != LambdaAuthorizer.REQUEST.upper(): + LOG.warning( + "Authorizer '%s' with type '%s' is currently not supported. " + "Only Lambda Authorizers of type REQUEST are supported for API Gateway V2.", + logical_id, + authorizer_type, + ) + return False + + if not authorizer_uri: + raise InvalidSamTemplateException( + f"Authorizer '{logical_id}' is missing the '{LambdaAuthorizerV2Validator.AUTHORIZER_AUTHORIZER_URI}' " + "property, a valid Lambda ARN must be provided." + ) + + function_name = LambdaUri.get_function_name(authorizer_uri) + if not function_name: + LOG.warning( + "Was not able to resolve Lambda function ARN for Authorizer '%s'. " + "Double check the ARN format, or use more simple intrinsics.", + logical_id, + ) + return False + + identity_sources = properties.get(LambdaAuthorizerV2Validator.AUTHORIZER_IDENTITY_SOURCE, None) + + if not isinstance(identity_sources, list): + raise InvalidSamTemplateException( + f"Lambda Authorizer '{logical_id}' must have " + f"'{LambdaAuthorizerV2Validator.AUTHORIZER_IDENTITY_SOURCE}' of type list defined." + ) + + for identity_source in identity_sources: + if not IdentitySourceValidator.validate_identity_source(identity_source, Route.HTTP): + raise InvalidSamTemplateException( + f"Lambda Authorizer {logical_id} does not contain valid identity sources.", Route.HTTP + ) + + payload_version = properties.get(LambdaAuthorizerV2Validator.AUTHORIZER_V2_PAYLOAD) + + if payload_version not in [None, *LambdaAuthorizer.PAYLOAD_VERSIONS]: + raise InvalidSamTemplateException( + f"Lambda Authorizer '{logical_id}' contains an invalid " + f"'{LambdaAuthorizerV2Validator.AUTHORIZER_V2_PAYLOAD}'" + ", it must be set to '1.0' or '2.0'" + ) + + simple_responses = properties.get(LambdaAuthorizerV2Validator.AUTHORIZER_V2_SIMPLE_RESPONSE, False) + + if payload_version == LambdaAuthorizer.PAYLOAD_V1 and simple_responses: + raise InvalidSamTemplateException( + f"'{LambdaAuthorizerV2Validator.AUTHORIZER_V2_SIMPLE_RESPONSE}' is only supported for '2.0' " + f"payload format versions for Lambda Authorizer '{logical_id}'." + ) + + return True diff --git a/samcli/commands/validate/validate.py b/samcli/commands/validate/validate.py index da588e5dbe..8a3663153b 100644 --- a/samcli/commands/validate/validate.py +++ b/samcli/commands/validate/validate.py @@ -104,8 +104,8 @@ def _read_sam_file(template): click.secho("SAM Template Not Found", bg="red") raise SamTemplateNotFoundException("Template at {} is not found".format(template)) - with click.open_file(template, "r", encoding="utf-8") as sam_template: - sam_template = yaml_parse(sam_template.read()) + with click.open_file(template, "r", encoding="utf-8") as sam_file: + sam_template = yaml_parse(sam_file.read()) return sam_template diff --git a/samcli/hook_packages/terraform/hooks/prepare/exceptions.py b/samcli/hook_packages/terraform/hooks/prepare/exceptions.py index b185d3f585..f1c3fbf4fe 100644 --- a/samcli/hook_packages/terraform/hooks/prepare/exceptions.py +++ b/samcli/hook_packages/terraform/hooks/prepare/exceptions.py @@ -53,3 +53,17 @@ def __init__(self, local_variable_reference, function_id): class InvalidSamMetadataPropertiesException(UserException): pass + + +class OpenAPIBodyNotSupportedException(UserException): + fmt = ( + "AWS SAM CLI is unable to process a Terraform project that uses an OpenAPI specification to " + "define the API Gateway resource. AWS SAM CLI does not currently support this " + "functionality. Affected resource: {api_id}." + ) + + def __init__(self, api_id): + msg = self.fmt.format( + api_id=api_id, + ) + UserException.__init__(self, msg) diff --git a/samcli/hook_packages/terraform/hooks/prepare/makefile_generator.py b/samcli/hook_packages/terraform/hooks/prepare/makefile_generator.py index 1c3ff08dd6..f9606ffbd3 100644 --- a/samcli/hook_packages/terraform/hooks/prepare/makefile_generator.py +++ b/samcli/hook_packages/terraform/hooks/prepare/makefile_generator.py @@ -263,7 +263,7 @@ def _get_parent_modules(module_address: Optional[str]) -> List[str]: previous_module = modules[0] full_path_modules = [previous_module] for module in modules[1:]: - module = previous_module + "." + module - previous_module = module - full_path_modules.append(module) + norm_module = previous_module + "." + module + previous_module = norm_module + full_path_modules.append(norm_module) return full_path_modules diff --git a/samcli/hook_packages/terraform/hooks/prepare/property_builder.py b/samcli/hook_packages/terraform/hooks/prepare/property_builder.py index 744ba61987..fc33a9ba02 100644 --- a/samcli/hook_packages/terraform/hooks/prepare/property_builder.py +++ b/samcli/hook_packages/terraform/hooks/prepare/property_builder.py @@ -12,6 +12,9 @@ ) from samcli.lib.hook.exceptions import PrepareHookException from samcli.lib.utils.packagetype import IMAGE, ZIP +from samcli.lib.utils.resources import AWS_APIGATEWAY_RESOURCE as CFN_AWS_APIGATEWAY_RESOURCE +from samcli.lib.utils.resources import AWS_APIGATEWAY_RESTAPI as CFN_AWS_APIGATEWAY_RESTAPI +from samcli.lib.utils.resources import AWS_APIGATEWAY_STAGE as CFN_AWS_APIGATEWAY_STAGE from samcli.lib.utils.resources import AWS_LAMBDA_FUNCTION as CFN_AWS_LAMBDA_FUNCTION from samcli.lib.utils.resources import AWS_LAMBDA_LAYERVERSION as CFN_AWS_LAMBDA_LAYER_VERSION @@ -19,6 +22,10 @@ TF_AWS_LAMBDA_FUNCTION = "aws_lambda_function" TF_AWS_LAMBDA_LAYER_VERSION = "aws_lambda_layer_version" +TF_AWS_API_GATEWAY_RESOURCE = "aws_api_gateway_resource" +TF_AWS_API_GATEWAY_REST_API = "aws_api_gateway_rest_api" +TF_AWS_API_GATEWAY_STAGE = "aws_api_gateway_stage" + def _build_code_property(tf_properties: dict, resource: TFResource) -> Any: """ @@ -215,9 +222,37 @@ def _check_image_config_value(image_config: Any) -> bool: "Content": _build_code_property, } +AWS_API_GATEWAY_REST_API_PROPERTY_BUILDER_MAPPING: PropertyBuilderMapping = { + "Name": _get_property_extractor("name"), + "Body": _get_property_extractor("body"), + "Parameters": _get_property_extractor("parameters"), + "BinaryMediaTypes": _get_property_extractor("binary_media_types"), +} + +AWS_API_GATEWAY_STAGE_PROPERTY_BUILDER_MAPPING: PropertyBuilderMapping = { + "RestApiId": _get_property_extractor("rest_api_id"), + "StageName": _get_property_extractor("stage_name"), + "Variables": _get_property_extractor("variables"), +} + +AWS_API_GATEWAY_RESOURCE_PROPERTY_BUILDER_MAPPING: PropertyBuilderMapping = { + "RestApiId": _get_property_extractor("rest_api_id"), + "ParentId": _get_property_extractor("parent_id"), + "PathPart": _get_property_extractor("path_part"), +} + RESOURCE_TRANSLATOR_MAPPING: Dict[str, ResourceTranslator] = { TF_AWS_LAMBDA_FUNCTION: ResourceTranslator(CFN_AWS_LAMBDA_FUNCTION, AWS_LAMBDA_FUNCTION_PROPERTY_BUILDER_MAPPING), TF_AWS_LAMBDA_LAYER_VERSION: ResourceTranslator( CFN_AWS_LAMBDA_LAYER_VERSION, AWS_LAMBDA_LAYER_VERSION_PROPERTY_BUILDER_MAPPING ), + TF_AWS_API_GATEWAY_REST_API: ResourceTranslator( + CFN_AWS_APIGATEWAY_RESTAPI, AWS_API_GATEWAY_REST_API_PROPERTY_BUILDER_MAPPING + ), + TF_AWS_API_GATEWAY_STAGE: ResourceTranslator( + CFN_AWS_APIGATEWAY_STAGE, AWS_API_GATEWAY_STAGE_PROPERTY_BUILDER_MAPPING + ), + TF_AWS_API_GATEWAY_RESOURCE: ResourceTranslator( + CFN_AWS_APIGATEWAY_RESOURCE, AWS_API_GATEWAY_RESOURCE_PROPERTY_BUILDER_MAPPING + ), } diff --git a/samcli/hook_packages/terraform/hooks/prepare/resources/__init__.py b/samcli/hook_packages/terraform/hooks/prepare/resources/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/samcli/hook_packages/terraform/hooks/prepare/resources/apigw.py b/samcli/hook_packages/terraform/hooks/prepare/resources/apigw.py new file mode 100644 index 0000000000..3d8dfc541b --- /dev/null +++ b/samcli/hook_packages/terraform/hooks/prepare/resources/apigw.py @@ -0,0 +1,52 @@ +""" +Module for API Gateway-related Terraform translation logic +""" + +from typing import Dict + +from samcli.hook_packages.terraform.hooks.prepare.exceptions import OpenAPIBodyNotSupportedException +from samcli.hook_packages.terraform.hooks.prepare.types import References, ResourceTranslationValidator, TFResource + + +class RESTAPITranslationValidator(ResourceTranslationValidator): + def validate(self): + """ + Validation function to check if the API Gateway REST API resource can be + translated and used by AWS SAM CLI + + Raises + ------- + OpenAPIBodyNotSupportedException if the given api_gateway_rest_api resource contains + an OpenAPI spec with a reference to a computed value not parsable by AWS SAM CLI + """ + if _unsupported_reference_field("body", self.resource, self.config_resource): + raise OpenAPIBodyNotSupportedException(self.config_resource.full_address) + + +def _unsupported_reference_field(field: str, resource: Dict, config_resource: TFResource) -> bool: + """ + Check if a field in a resource is a reference to a computed value that is unknown until + apply-time. These fields are not visible to AWS SAM CLI until the Terraform application + is applied, meaning that the field isn't parsable by `sam local` commands and isn't supported + with the current hook implementation. + + Parameters + ---------- + field: str + String representation of the field to looks for + resource: Dict + Dict containing the resource properties to look in + config_resource + The configuration resource that will contain possible references + + Returns + ------- + bool + True if the resource contains an field with a reference not parsable by AWS SAM CLI, + False otherwise + """ + return bool( + not resource.get(field) + and config_resource.attributes.get(field) + and isinstance(config_resource.attributes.get(field), References) + ) diff --git a/samcli/hook_packages/terraform/hooks/prepare/translate.py b/samcli/hook_packages/terraform/hooks/prepare/translate.py index 1c6fc77ab1..3bb5e06e7b 100644 --- a/samcli/hook_packages/terraform/hooks/prepare/translate.py +++ b/samcli/hook_packages/terraform/hooks/prepare/translate.py @@ -3,9 +3,10 @@ This method contains the logic required to translate the `terraform show` JSON output into a Cloudformation template """ +# ruff: noqa: PLR0915 import hashlib import logging -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple, Type, Union from samcli.hook_packages.terraform.hooks.prepare.constants import ( CFN_CODE_PROPERTIES, @@ -15,6 +16,7 @@ from samcli.hook_packages.terraform.hooks.prepare.property_builder import ( REMOTE_DUMMY_VALUE, RESOURCE_TRANSLATOR_MAPPING, + TF_AWS_API_GATEWAY_REST_API, TF_AWS_LAMBDA_FUNCTION, TF_AWS_LAMBDA_LAYER_VERSION, PropertyBuilderMapping, @@ -25,10 +27,12 @@ _link_lambda_function_to_layer, _resolve_resource_attribute, ) +from samcli.hook_packages.terraform.hooks.prepare.resources.apigw import RESTAPITranslationValidator from samcli.hook_packages.terraform.hooks.prepare.types import ( ConstantValue, References, ResolvedReference, + ResourceTranslationValidator, SamMetadataResource, TFModule, TFResource, @@ -50,6 +54,10 @@ LOG = logging.getLogger(__name__) +TRANSLATION_VALIDATORS: Dict[str, Type[ResourceTranslationValidator]] = { + TF_AWS_API_GATEWAY_REST_API: RESTAPITranslationValidator, +} + def translate_to_cfn(tf_json: dict, output_directory_path: str, terraform_application_dir: str) -> dict: """ @@ -176,12 +184,16 @@ def translate_to_cfn(tf_json: dict, output_directory_path: str, terraform_applic translated_properties = _translate_properties( resource_values, resource_translator.property_builder_mapping, config_resource ) - translated_resource = { + translated_resource: Dict = { "Type": resource_translator.cfn_name, "Properties": translated_properties, - "Metadata": {"SamResourceId": resource_full_address, "SkipBuild": True}, + "Metadata": {"SamResourceId": resource_full_address}, } + # Only set the SkipBuild metadata if it's a resource that can be built + if resource_translator.cfn_name in CFN_CODE_PROPERTIES: + translated_resource["Metadata"]["SkipBuild"] = True + # build CFN logical ID from resource address logical_id = build_cfn_logical_id(resource_full_address) @@ -227,6 +239,10 @@ def translate_to_cfn(tf_json: dict, output_directory_path: str, terraform_applic translated_resource, ) + if resource_type in TRANSLATION_VALIDATORS: + validator = TRANSLATION_VALIDATORS[resource_type](resource=resource, config_resource=config_resource) + validator.validate() + # map s3 object sources to corresponding functions LOG.debug("Mapping S3 object sources to corresponding functions") _map_s3_sources_to_functions(s3_hash_to_source, cfn_dict.get("Resources", {}), lambda_resources_to_code_map) diff --git a/samcli/hook_packages/terraform/hooks/prepare/types.py b/samcli/hook_packages/terraform/hooks/prepare/types.py index 0accdff586..baf1f3f5bf 100644 --- a/samcli/hook_packages/terraform/hooks/prepare/types.py +++ b/samcli/hook_packages/terraform/hooks/prepare/types.py @@ -77,3 +77,23 @@ class SamMetadataResource: current_module_address: Optional[str] resource: Dict config_resource: TFResource + + +class ResourceTranslationValidator: + """ + Base class for a validation class to be used when translating Terraform resources to a metadata file + """ + + resource: Dict + config_resource: TFResource + + def __init__(self, resource, config_resource): + self.resource = resource + self.config_resource = config_resource + + def validate(self): + """ + Function to be called for resources of a given type used for validating + the AWS SAM CLI transformation logic for the given resource + """ + raise NotImplementedError diff --git a/samcli/lib/list/endpoints/endpoints_producer.py b/samcli/lib/list/endpoints/endpoints_producer.py index 13efc821ea..b2d32649b0 100644 --- a/samcli/lib/list/endpoints/endpoints_producer.py +++ b/samcli/lib/list/endpoints/endpoints_producer.py @@ -23,13 +23,13 @@ from samcli.lib.utils.boto_utils import get_client_error_code from samcli.lib.utils.resources import ( AWS_APIGATEWAY_BASE_PATH_MAPPING, + AWS_APIGATEWAY_DOMAIN_NAME, AWS_APIGATEWAY_RESTAPI, AWS_APIGATEWAY_V2_API, + AWS_APIGATEWAY_V2_BASE_PATH_MAPPING, AWS_APIGATEWAY_V2_DOMAIN_NAME, - AWS_APIGATWAY_DOMAIN_NAME, AWS_LAMBDA_FUNCTION, AWS_LAMBDA_FUNCTION_URL, - AWS_APIGATEWAY_v2_BASE_PATH_MAPPING, ) ENDPOINT_RESOURCE_TYPES = {AWS_LAMBDA_FUNCTION, AWS_APIGATEWAY_RESTAPI, AWS_APIGATEWAY_V2_API} @@ -425,7 +425,7 @@ def get_custom_domain_substitute_list( custom_domain_substitute_dict[rest_api_id].append(response_domain_dict.get(domain_id, None)) # Collect custom domain data for APIGW V2 resources - elif resource.get(RESOURCE_TYPE, "") == AWS_APIGATEWAY_v2_BASE_PATH_MAPPING: + elif resource.get(RESOURCE_TYPE, "") == AWS_APIGATEWAY_V2_BASE_PATH_MAPPING: local_mapping = local_stack_resources.get(resource.get(LOGICAL_RESOURCE_ID, ""), {}).get(PROPERTIES, {}) rest_api_id = local_mapping.get(API_ID, "") domain_id = local_mapping.get(DOMAIN_NAME, "") @@ -454,7 +454,7 @@ def get_response_domain_dict(response: Dict[Any, Any]) -> Dict[str, str]: response_domain_dict = {} for resource in response.get(STACK_RESOURCES, {}): if ( - resource.get(RESOURCE_TYPE, "") == AWS_APIGATWAY_DOMAIN_NAME + resource.get(RESOURCE_TYPE, "") == AWS_APIGATEWAY_DOMAIN_NAME or resource.get(RESOURCE_TYPE, "") == AWS_APIGATEWAY_V2_DOMAIN_NAME ): response_domain_dict[ diff --git a/samcli/lib/providers/api_collector.py b/samcli/lib/providers/api_collector.py index d8cfd43105..90a9673514 100644 --- a/samcli/lib/providers/api_collector.py +++ b/samcli/lib/providers/api_collector.py @@ -4,11 +4,14 @@ """ import logging +import os from collections import defaultdict from typing import Dict, Iterator, List, Optional, Set, Tuple, Union from samcli.lib.providers.provider import Api, Cors -from samcli.local.apigw.local_apigw_service import Route +from samcli.lib.utils.colors import Colored +from samcli.local.apigw.authorizers.authorizer import Authorizer +from samcli.local.apigw.route import Route LOG = logging.getLogger(__name__) @@ -18,6 +21,10 @@ def __init__(self) -> None: # Route properties stored per resource. self._route_per_resource: Dict[str, List[Route]] = defaultdict(list) + # Authorizer definitions and default authorizers for each api + self._authorizers_per_resources: Dict[str, Dict[str, Authorizer]] = defaultdict(dict) + self._default_authorizer_per_resource: Dict[str, str] = {} + # processed values to be set before creating the api self._routes: List[Route] = [] self.binary_media_types_set: Set[str] = set() @@ -40,6 +47,75 @@ def __iter__(self) -> Iterator[Tuple[str, List[Route]]]: for logical_id, _ in self._route_per_resource.items(): yield logical_id, self._get_routes(logical_id) + def add_authorizers(self, logical_id: str, authorizers: Dict[str, Authorizer]) -> None: + """ + Adds Authorizers to a API Gateway resource + + Parameters + ---------- + logical_id: str + Logical ID of API Gateway resource + authorizers: Dict[str, Authorizer] + Dictionary with key as authorizer name, and value as Authorizer object + """ + self._authorizers_per_resources[logical_id].update(authorizers) + + def set_default_authorizer(self, logical_id: str, authorizer_name: str) -> None: + """ + Sets the default authorizer used for the API Gateway resource + + Parameters + ---------- + logical_id: str + Logical ID of API Gateway resource + authorizer_name: str + Name of the authorizer to reference + """ + self._default_authorizer_per_resource[logical_id] = authorizer_name + + def _link_authorizers(self) -> None: + """ + Links the routes to the correct authorizer object + """ + for apigw_id, routes in self._route_per_resource.items(): + authorizers = self._authorizers_per_resources.get(apigw_id, {}) + + default_authorizer = self._default_authorizer_per_resource.get(apigw_id, None) + + for route in routes: + if route.authorizer_name is None and not route.use_default_authorizer: + LOG.debug( + "Linking authorizer skipped, route '%s' is set to not use any authorizer.", + route.path, + ) + + continue + + # determine the name of the authorizer object we want to search for in our dict + authorizer_name_lookup = route.authorizer_name or default_authorizer or "" + authorizer_object = authorizers.get(authorizer_name_lookup, None) + + if authorizer_object: + route.authorizer_name = authorizer_name_lookup + route.authorizer_object = authorizer_object + + LOG.debug( + "Linking authorizer '%s', for route '%s'", + route.authorizer_name, + route.path, + ) + + continue + + if not authorizer_object and authorizer_name_lookup: + route.authorizer_name = None + + LOG.info( + "Linking authorizer skipped for route '%s', authorizer '%s' is unsupported or not found", + route.path, + route.authorizer_name, + ) + def add_routes(self, logical_id: str, routes: List[Route]) -> None: """ Stores the given routes tagged under the given logicalId @@ -102,13 +178,29 @@ def get_api(self) -> Api: An Api object with all the properties """ api = Api() + + self._link_authorizers() + routes = self.dedupe_function_routes(self.routes) routes = self.normalize_cors_methods(routes, self.cors) + api.routes = routes api.binary_media_types_set = self.binary_media_types_set api.stage_name = self.stage_name api.stage_variables = self.stage_variables api.cors = self.cors + + for authorizers in self._authorizers_per_resources.values(): + if len(authorizers): + message = f"""{os.linesep}AWS SAM CLI does not guarantee 100% fidelity between authorizers locally +and authorizers deployed on AWS. Any application critical behavior should +be validated thoroughly before deploying to production. + +Testing application behaviour against authorizers deployed on AWS can be done using the sam sync command.{os.linesep}""" + LOG.warning(Colored().yellow(message)) + + break + return api @staticmethod @@ -165,6 +257,8 @@ def dedupe_function_routes(routes: List[Route]) -> List[Route]: payload_format_version=route.payload_format_version, operation_name=route.operation_name, stack_path=route.stack_path, + authorizer_name=route.authorizer_name, + authorizer_object=route.authorizer_object, ) return list(grouped_routes.values()) diff --git a/samcli/lib/providers/cfn_api_provider.py b/samcli/lib/providers/cfn_api_provider.py index dc1b866c7b..081ffca2ed 100644 --- a/samcli/lib/providers/cfn_api_provider.py +++ b/samcli/lib/providers/cfn_api_provider.py @@ -4,20 +4,28 @@ from samcli.commands.local.cli_common.user_exceptions import InvalidSamTemplateException from samcli.commands.local.lib.swagger.integration_uri import LambdaUri +from samcli.commands.local.lib.validators.identity_source_validator import IdentitySourceValidator +from samcli.commands.local.lib.validators.lambda_auth_props import ( + LambdaAuthorizerV1Validator, + LambdaAuthorizerV2Validator, +) from samcli.lib.providers.api_collector import ApiCollector from samcli.lib.providers.cfn_base_api_provider import CfnBaseApiProvider from samcli.lib.providers.provider import Stack from samcli.lib.utils.resources import ( + AWS_APIGATEWAY_AUTHORIZER, AWS_APIGATEWAY_METHOD, AWS_APIGATEWAY_RESOURCE, AWS_APIGATEWAY_RESTAPI, AWS_APIGATEWAY_STAGE, AWS_APIGATEWAY_V2_API, + AWS_APIGATEWAY_V2_AUTHORIZER, AWS_APIGATEWAY_V2_INTEGRATION, AWS_APIGATEWAY_V2_ROUTE, AWS_APIGATEWAY_V2_STAGE, ) -from samcli.local.apigw.local_apigw_service import Route +from samcli.local.apigw.authorizers.lambda_authorizer import LambdaAuthorizer +from samcli.local.apigw.route import Route LOG = logging.getLogger(__name__) @@ -30,12 +38,17 @@ class CfnApiProvider(CfnBaseApiProvider): AWS_APIGATEWAY_STAGE, AWS_APIGATEWAY_RESOURCE, AWS_APIGATEWAY_METHOD, + AWS_APIGATEWAY_AUTHORIZER, AWS_APIGATEWAY_V2_API, AWS_APIGATEWAY_V2_INTEGRATION, AWS_APIGATEWAY_V2_ROUTE, AWS_APIGATEWAY_V2_STAGE, + AWS_APIGATEWAY_V2_AUTHORIZER, ] + _METHOD_AUTHORIZER_ID = "AuthorizerId" + _ROUTE_AUTHORIZER_ID = "AuthorizerId" + def extract_resources(self, stacks: List[Stack], collector: ApiCollector, cwd: Optional[str] = None) -> None: """ Extract the Route Object from a given resource and adds it to the RouteCollector. @@ -65,6 +78,9 @@ def extract_resources(self, stacks: List[Stack], collector: ApiCollector, cwd: O if resource_type == AWS_APIGATEWAY_METHOD: self._extract_cloud_formation_method(stack.stack_path, resources, logical_id, resource, collector) + if resource_type == AWS_APIGATEWAY_AUTHORIZER: + self._extract_cloud_formation_authorizer(logical_id, resource, collector) + if resource_type == AWS_APIGATEWAY_V2_API: self._extract_cfn_gateway_v2_api(stack.stack_path, logical_id, resource, collector, cwd=cwd) @@ -74,6 +90,103 @@ def extract_resources(self, stacks: List[Stack], collector: ApiCollector, cwd: O if resource_type == AWS_APIGATEWAY_V2_STAGE: self._extract_cfn_gateway_v2_stage(resources, resource, collector) + if resource_type == AWS_APIGATEWAY_V2_AUTHORIZER: + self._extract_cfn_gateway_v2_authorizer(logical_id, resource, collector) + + @staticmethod + def _extract_cloud_formation_authorizer(logical_id: str, resource: dict, collector: ApiCollector) -> None: + """ + Extract Authorizers from AWS::ApiGateway::Authorizer and add them to the collector. + + Parameters + ---------- + logical_id: str + The logical ID of the Authorizer + resource: dict + The attributes for the Authorizer + collector: ApiCollector + ApiCollector to save Authorizers into + """ + if not LambdaAuthorizerV1Validator.validate(logical_id, resource): + return + + properties = resource.get("Properties", {}) + authorizer_type = properties.get(LambdaAuthorizerV1Validator.AUTHORIZER_TYPE, "").lower() + rest_api_id = properties.get(LambdaAuthorizerV1Validator.AUTHORIZER_REST_API) + name = properties.get(LambdaAuthorizerV1Validator.AUTHORIZER_NAME) + authorizer_uri = properties.get(LambdaAuthorizerV1Validator.AUTHORIZER_AUTHORIZER_URI) + identity_source_template = properties.get(LambdaAuthorizerV1Validator.AUTHORIZER_IDENTITY_SOURCE, []) + + # this will always return a string since we have already validated above + function_name = cast(str, LambdaUri.get_function_name(authorizer_uri)) + + # split and parse out identity sources + identity_source_list = [] + + if identity_source_template: + for identity_source in identity_source_template.split(","): + trimmed_id_source = identity_source.strip() + + if not IdentitySourceValidator.validate_identity_source(trimmed_id_source): + raise InvalidSamTemplateException( + f"Lambda Authorizer {logical_id} does not contain valid identity sources.", Route.API + ) + + identity_source_list.append(trimmed_id_source) + + validation_expression = properties.get(LambdaAuthorizerV1Validator.AUTHORIZER_VALIDATION) + + lambda_authorizer = LambdaAuthorizer( + payload_version="1.0", + authorizer_name=name, + type=authorizer_type, + lambda_name=function_name, + identity_sources=identity_source_list, + validation_string=validation_expression, + ) + + collector.add_authorizers(rest_api_id, {name: lambda_authorizer}) + + @staticmethod + def _extract_cfn_gateway_v2_authorizer(logical_id: str, resource: dict, collector: ApiCollector) -> None: + """ + Extract Authorizers from AWS::ApiGatewayV2::Authorizer and add them to the collector. + + Parameters + ---------- + logical_id: str + The logical ID of the Authorizer + resource: dict + The attributes for the Authorizer + collector: ApiCollector + ApiCollector to save Authorizers into + """ + if not LambdaAuthorizerV2Validator.validate(logical_id, resource): + return + + properties = resource.get("Properties", {}) + authorizer_type = properties.get(LambdaAuthorizerV2Validator.AUTHORIZER_V2_TYPE, "").lower() + api_id = properties.get(LambdaAuthorizerV2Validator.AUTHORIZER_V2_API) + name = properties.get(LambdaAuthorizerV2Validator.AUTHORIZER_NAME) + authorizer_uri = properties.get(LambdaAuthorizerV2Validator.AUTHORIZER_AUTHORIZER_URI) + identity_sources = properties.get(LambdaAuthorizerV2Validator.AUTHORIZER_IDENTITY_SOURCE, []) + payload_version = properties.get(LambdaAuthorizerV2Validator.AUTHORIZER_V2_PAYLOAD, LambdaAuthorizer.PAYLOAD_V2) + simple_responses = properties.get(LambdaAuthorizerV2Validator.AUTHORIZER_V2_SIMPLE_RESPONSE, False) + + # this will always return a string since we have already validated above + function_name = cast(str, LambdaUri.get_function_name(authorizer_uri)) + + lambda_authorizer = LambdaAuthorizer( + payload_version=payload_version, + authorizer_name=name, + type=authorizer_type, + lambda_name=function_name, + identity_sources=identity_sources, + use_simple_response=simple_responses, + ) + + collector.add_authorizers(api_id, {name: lambda_authorizer}) + @staticmethod def _extract_cloud_formation_route( stack_path: str, @@ -213,12 +326,15 @@ def _extract_cloud_formation_method( if content_handling == CfnApiProvider.METHOD_BINARY_TYPE and content_type: collector.add_binary_media_types(logical_id, [content_type]) + authorizer_name = properties.get(CfnApiProvider._METHOD_AUTHORIZER_ID) + routes = Route( methods=[method], function_name=self._get_integration_function_name(integration), path=resource_path, operation_name=operation_name, stack_path=stack_path, + authorizer_name=authorizer_name, ) collector.add_routes(rest_api_id, [routes]) @@ -331,6 +447,8 @@ def _extract_cfn_gateway_v2_route( "The AWS::ApiGatewayV2::Route {} does not have a correct route key {}".format(logical_id, route_key) ) + authorizer_name = properties.get(CfnApiProvider._ROUTE_AUTHORIZER_ID) + routes = Route( methods=[method], path=path, @@ -339,6 +457,7 @@ def _extract_cfn_gateway_v2_route( payload_format_version=payload_format_version, operation_name=operation_name, stack_path=stack_path, + authorizer_name=authorizer_name, ) collector.add_routes(api_id, [routes]) diff --git a/samcli/lib/providers/cfn_base_api_provider.py b/samcli/lib/providers/cfn_base_api_provider.py index 1fec717b5b..a2ac0a2efe 100644 --- a/samcli/lib/providers/cfn_base_api_provider.py +++ b/samcli/lib/providers/cfn_base_api_provider.py @@ -15,7 +15,7 @@ Cors, Stack, ) -from samcli.local.apigw.local_apigw_service import Route +from samcli.local.apigw.route import Route LOG = logging.getLogger(__name__) @@ -62,7 +62,6 @@ def extract_swagger_route( ---------- stack_path : str Path of the stack the resource is located - logical_id : str Logical ID of the resource body : dict @@ -81,10 +80,20 @@ def extract_swagger_route( reader = SwaggerReader(definition_body=body, definition_uri=uri, working_dir=cwd) swagger = reader.read() parser = SwaggerParser(stack_path, swagger) + + authorizers = parser.get_authorizers(event_type) + default_authorizer = parser.get_default_authorizer(event_type) + routes = parser.get_routes(event_type) + LOG.debug("Found '%s' APIs in resource '%s'", len(routes), logical_id) + LOG.debug("Found '%s' authorizers in resource '%s'", len(authorizers), logical_id) collector.add_routes(logical_id, routes) + collector.add_authorizers(logical_id, authorizers) + + if default_authorizer: + collector.set_default_authorizer(logical_id, default_authorizer) collector.add_binary_media_types(logical_id, parser.get_binary_media_types()) # Binary media from swagger collector.add_binary_media_types(logical_id, binary_media) # Binary media specified on resource in template diff --git a/samcli/lib/providers/provider.py b/samcli/lib/providers/provider.py index 11c5fcd1ce..3a2a9039ac 100644 --- a/samcli/lib/providers/provider.py +++ b/samcli/lib/providers/provider.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: # pragma: no cover # avoid circular import, https://docs.python.org/3/library/typing.html#typing.TYPE_CHECKING - from samcli.local.apigw.local_apigw_service import Route + from samcli.local.apigw.route import Route LOG = logging.getLogger(__name__) diff --git a/samcli/lib/providers/sam_api_provider.py b/samcli/lib/providers/sam_api_provider.py index e47688a7cf..9684ffa316 100644 --- a/samcli/lib/providers/sam_api_provider.py +++ b/samcli/lib/providers/sam_api_provider.py @@ -3,13 +3,16 @@ import logging from typing import Dict, List, Optional, Tuple, Union, cast +from samcli.commands.local.lib.swagger.integration_uri import LambdaUri from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException from samcli.lib.providers.api_collector import ApiCollector from samcli.lib.providers.cfn_base_api_provider import CfnBaseApiProvider from samcli.lib.providers.provider import Stack from samcli.lib.utils.colors import Colored from samcli.lib.utils.resources import AWS_SERVERLESS_API, AWS_SERVERLESS_FUNCTION, AWS_SERVERLESS_HTTPAPI -from samcli.local.apigw.local_apigw_service import Route +from samcli.local.apigw.authorizers.authorizer import Authorizer +from samcli.local.apigw.authorizers.lambda_authorizer import LambdaAuthorizer +from samcli.local.apigw.route import Route LOG = logging.getLogger(__name__) @@ -25,6 +28,24 @@ class SamApiProvider(CfnBaseApiProvider): IMPLICIT_API_RESOURCE_ID = "ServerlessRestApi" IMPLICIT_HTTP_API_RESOURCE_ID = "ServerlessHttpApi" + _AUTH = "Auth" + _AUTH_HEADER = "Header" + _AUTH_SIMPLE_RESPONSES = "EnableSimpleResponses" + _AUTHORIZER = "Authorizer" + _AUTHORIZERS = "Authorizers" + _DEFAULT_AUTHORIZER = "DefaultAuthorizer" + _FUNCTION_TYPE = "FunctionPayloadType" + _AUTHORIZER_PAYLOAD = "AuthorizerPayloadFormatVersion" + _FUNCTION_ARN = "FunctionArn" + _VALIDATION_EXPRESSION = "ValidationExpression" + _IDENTITY = "Identity" + _IDENTITY_QUERY = "QueryStrings" + _IDENTITY_HEADERS = "Headers" + _IDENTITY_CONTEXT = "Context" + _IDENTITY_STAGE = "StageVariables" + _API_IDENTITY_SOURCE_PREFIX = "method." + _HTTP_IDENTITY_SOURCE_PREFIX = "$" + def extract_resources(self, stacks: List[Stack], collector: ApiCollector, cwd: Optional[str] = None) -> None: """ Extract the Route Object from a given resource and adds it to the RouteCollector. @@ -98,6 +119,181 @@ def _extract_from_serverless_api( collector.stage_variables = stage_variables collector.cors = cors + auth = properties.get(SamApiProvider._AUTH, {}) + if not auth: + return + + default_authorizer = auth.get(SamApiProvider._DEFAULT_AUTHORIZER) + if default_authorizer: + collector.set_default_authorizer(logical_id, default_authorizer) + + self._extract_authorizers_from_props(logical_id, auth, collector, Route.API) + + @staticmethod + def _extract_request_lambda_authorizer( + auth_name: str, function_name: str, prefix: str, properties: dict, event_type: str + ) -> LambdaAuthorizer: + """ + Generates a request Lambda Authorizer from the given identity object + + Parameters + ---------- + auth_name: str + Name of the authorizer + function_name: str + Name of the Lambda function this authorizer uses + prefix: str + The prefix to prepend to identity sources + properties: dict + The authorizer properties that contains identity sources and authorizer specific properties + event_type: str + The type of API this is (API or HTTP API) + + Returns + ------- + LambdaAuthorizer + The request based Lambda Authorizer object + """ + payload_version = properties.get(SamApiProvider._AUTHORIZER_PAYLOAD) + + if payload_version is not None and not isinstance(payload_version, str): + raise InvalidSamDocumentException( + f"'{SamApiProvider._AUTHORIZER_PAYLOAD}' must be of type string for Lambda Authorizer '{auth_name}'." + ) + + if payload_version not in LambdaAuthorizer.PAYLOAD_VERSIONS and event_type == Route.HTTP: + raise InvalidSamDocumentException( + f"Lambda Authorizer '{auth_name}' must contain a valid " + f"'{SamApiProvider._AUTHORIZER_PAYLOAD}' for HTTP APIs." + ) + + simple_responses = properties.get(SamApiProvider._AUTH_SIMPLE_RESPONSES, False) + if simple_responses and payload_version == LambdaAuthorizer.PAYLOAD_V1: + raise InvalidSamDocumentException( + f"{SamApiProvider._AUTH_SIMPLE_RESPONSES} must be used with the 2.0 " + f"payload format version in Lambda Authorizer '{auth_name}'." + ) + + identity_sources = [] + identity_object = properties.get(SamApiProvider._IDENTITY, {}) + + for query_string in identity_object.get(SamApiProvider._IDENTITY_QUERY, []): + identity_sources.append(f"{prefix}request.querystring.{query_string}") + + for header in identity_object.get(SamApiProvider._IDENTITY_HEADERS, []): + identity_sources.append(f"{prefix}request.header.{header}") + + # context and stageVariables do not have "method." for V1 APIGW + # but the V2 still expects "$" + prefix = SamApiProvider._HTTP_IDENTITY_SOURCE_PREFIX if event_type == Route.HTTP else "" + + for context in identity_object.get(SamApiProvider._IDENTITY_CONTEXT, []): + identity_sources.append(f"{prefix}context.{context}") + + for stage_variable in identity_object.get(SamApiProvider._IDENTITY_STAGE, []): + identity_sources.append(f"{prefix}stageVariables.{stage_variable}") + + return LambdaAuthorizer( + payload_version=payload_version if payload_version else "1.0", + authorizer_name=auth_name, + type=LambdaAuthorizer.REQUEST, + lambda_name=function_name, + identity_sources=identity_sources, + use_simple_response=simple_responses, + ) + + @staticmethod + def _extract_token_lambda_authorizer( + auth_name: str, function_name: str, prefix: str, identity_object: dict + ) -> LambdaAuthorizer: + """ + Generates a token Lambda Authorizer from the given identity object + + Parameters + ---------- + auth_name: str + Name of the authorizer + function_name: str + Name of the Lambda function this authorizer uses + prefix: str + The prefix to prepend to identity sources + identity_object: dict + The identity source object that contains the various identity sources + + Returns + ------- + LambdaAuthorizer + The token based Lambda Authorizer object + """ + validation_expression = identity_object.get(SamApiProvider._VALIDATION_EXPRESSION) + + header = identity_object.get(SamApiProvider._AUTH_HEADER, "Authorization") + header = f"{prefix}request.header.{header}" + + return LambdaAuthorizer( + payload_version=LambdaAuthorizer.PAYLOAD_V1, + authorizer_name=auth_name, + type=LambdaAuthorizer.TOKEN, + lambda_name=function_name, + identity_sources=[header], + validation_string=validation_expression, + ) + + @staticmethod + def _extract_authorizers_from_props(logical_id: str, auth: dict, collector: ApiCollector, event_type: str) -> None: + """ + Extracts Authorizers from the Auth properties section of Serverless resources + + Parameters + ---------- + logical_id: str + The logical ID of the Serverless resource + auth: dict + The Auth property dictionary + collector: ApiCollector + The Api Collector to send the Authorizers to + event_type: str + What kind of API this is (API, HTTP API) + """ + prefix = ( + SamApiProvider._API_IDENTITY_SOURCE_PREFIX + if event_type == Route.API + else SamApiProvider._HTTP_IDENTITY_SOURCE_PREFIX + ) + + authorizers: Dict[str, Authorizer] = {} + + for auth_name, auth_props in auth.get(SamApiProvider._AUTHORIZERS, {}).items(): + authorizer_type = auth_props.get(SamApiProvider._FUNCTION_TYPE, LambdaAuthorizer.TOKEN) + identity_object = auth_props.get(SamApiProvider._IDENTITY, {}) + + function_arn = auth_props.get(SamApiProvider._FUNCTION_ARN) + + if not function_arn: + LOG.debug("Authorizer '%s' is currently unsupported (must be a Lambda Authorizer), skipping", auth_name) + continue + + function_name = LambdaUri.get_function_name(function_arn) + + if not function_name: + LOG.warning("Unable to parse the Lambda ARN for Authorizer '%s', skipping", auth_name) + continue + + if authorizer_type == LambdaAuthorizer.REQUEST.upper() or event_type == Route.HTTP: + authorizers[auth_name] = SamApiProvider._extract_request_lambda_authorizer( + auth_name, function_name, prefix, auth_props, event_type + ) + elif authorizer_type == LambdaAuthorizer.TOKEN.upper(): + authorizers[auth_name] = SamApiProvider._extract_token_lambda_authorizer( + auth_name, function_name, prefix, identity_object + ) + else: + LOG.debug( + "Authorizer '%s' is currently unsupported (not of type TOKEN or REQUEST), skipping", auth_name + ) + + collector.add_authorizers(logical_id, authorizers) + def _extract_from_serverless_http( self, stack_path: str, logical_id: str, api_resource: Dict, collector: ApiCollector, cwd: Optional[str] = None ) -> None: @@ -143,6 +339,16 @@ def _extract_from_serverless_http( collector.stage_variables = stage_variables collector.cors = cors + auth = properties.get(SamApiProvider._AUTH, {}) + if not auth: + return + + default_authorizer = auth.get(SamApiProvider._DEFAULT_AUTHORIZER) + if default_authorizer: + collector.set_default_authorizer(logical_id, default_authorizer) + + self._extract_authorizers_from_props(logical_id, auth, collector, Route.HTTP) + def _extract_routes_from_function( self, stack_path: str, logical_id: str, function_resource: Dict, collector: ApiCollector ) -> None: @@ -241,6 +447,15 @@ def _convert_event_route( "It should either be a LogicalId string or a Ref of a Logical Id string".format(lambda_logical_id) ) + use_default_authorizer = True + + # Find Authorizer + authorizer_name = event_properties.get(SamApiProvider._AUTH, {}).get(SamApiProvider._AUTHORIZER, None) + if authorizer_name == "NONE": + # do not use any authorizers + use_default_authorizer = False + authorizer_name = None + return ( api_resource_id, Route( @@ -250,6 +465,8 @@ def _convert_event_route( event_type=event_type, payload_format_version=payload_format_version, stack_path=stack_path, + authorizer_name=authorizer_name, + use_default_authorizer=use_default_authorizer, ), ) diff --git a/samcli/lib/schemas/schemas_code_manager.py b/samcli/lib/schemas/schemas_code_manager.py index 54a9ca8b3e..112902ed5e 100644 --- a/samcli/lib/schemas/schemas_code_manager.py +++ b/samcli/lib/schemas/schemas_code_manager.py @@ -56,8 +56,8 @@ def do_extract_and_merge_schemas_code(download_location, output_dir, project_nam """ click.echo("Merging code bindings...") cookiecutter_json_path = os.path.join(template_location, "cookiecutter.json") - with open(cookiecutter_json_path, "r") as cookiecutter_json: - cookiecutter_json_data = cookiecutter_json.read() + with open(cookiecutter_json_path, "r") as cookiecutter_file: + cookiecutter_json_data = cookiecutter_file.read() cookiecutter_json = json.loads(cookiecutter_json_data) function_name = cookiecutter_json["function_name"] copy_location = os.path.join(output_dir, project_name, function_name) diff --git a/samcli/lib/telemetry/event.py b/samcli/lib/telemetry/event.py index cff129adaf..ef21a00d1e 100644 --- a/samcli/lib/telemetry/event.py +++ b/samcli/lib/telemetry/event.py @@ -34,6 +34,7 @@ class UsedFeature(Enum): CDK = "CDK" INIT_WITH_APPLICATION_INSIGHTS = "InitWithApplicationInsights" CFNLint = "CFNLint" + INVOKED_CUSTOM_LAMBDA_AUTHORIZERS = "InvokedLambdaAuthorizers" class EventType: @@ -84,22 +85,29 @@ class Event: event_value: str # Validated by EventType.get_accepted_values to never be an arbitrary string thread_id = threading.get_ident() # The thread ID; used to group Events from the same command run time_stamp: str + exception_name: Optional[str] - def __init__(self, event_name: str, event_value: str): + def __init__(self, event_name: str, event_value: str, exception_name: Optional[str] = None): Event._verify_event(event_name, event_value) self.event_name = EventName(event_name) self.event_value = event_value self.time_stamp = str(datetime.utcnow())[:-3] # format microseconds from 6 -> 3 figures to allow SQL casting + self.exception_name = exception_name def __eq__(self, other): - return self.event_name == other.event_name and self.event_value == other.event_value + return ( + self.event_name == other.event_name + and self.event_value == other.event_value + and self.exception_name == other.exception_name + ) def __repr__(self): return ( f"Event(event_name={self.event_name.value}, " f"event_value={self.event_value}, " f"thread_id={self.thread_id}, " - f"time_stamp={self.time_stamp})" + f"time_stamp={self.time_stamp})", + f"exception_name={self.exception_name})", ) def to_json(self): @@ -108,6 +116,7 @@ def to_json(self): "event_value": self.event_value, "thread_id": self.thread_id, "time_stamp": self.time_stamp, + "exception_name": self.exception_name, } @staticmethod @@ -134,7 +143,9 @@ class EventTracker: MAX_EVENTS: int = 50 # Maximum number of events to store before sending @staticmethod - def track_event(event_name: str, event_value: str): + def track_event( + event_name: str, event_value: str, session_id: Optional[str] = None, exception_name: Optional[str] = None + ): """Method to track an event where and when it occurs. Place this method in the codepath of the event that you would @@ -149,6 +160,10 @@ def track_event(event_name: str, event_value: str): event_value: str The value of the Event. Must be a valid EventType value for the passed event_name, or an EventCreationError will be thrown. + session_id: Optional[str] + The session ID to set to link back to the original command run + exception_name: Optional[str] + The name of the exception that this event encountered when tracking a feature Examples -------- @@ -161,18 +176,18 @@ def track_event(event_name: str, event_value: str): EventTracker.track_event("UsedFeature", "FeatureY") return some_value """ + + if session_id: + EventTracker._session_id = session_id + try: should_send: bool = False with EventTracker._event_lock: - EventTracker._events.append(Event(event_name, event_value)) + EventTracker._events.append(Event(event_name, event_value, exception_name=exception_name)) + # Get the session ID (needed for multithreading sending) - if not EventTracker._session_id: - try: - ctx = Context.get_current_context() - if ctx: - EventTracker._session_id = ctx.session_id - except RuntimeError: - LOG.debug("EventTracker: Unable to obtain session ID") + EventTracker._set_session_id() + if len(EventTracker._events) >= EventTracker.MAX_EVENTS: should_send = True if should_send: @@ -199,6 +214,19 @@ def send_events() -> threading.Thread: send_thread.start() return send_thread + @staticmethod + def _set_session_id() -> None: + """ + Get the session ID from click and save it locally. + """ + if not EventTracker._session_id: + try: + ctx = Context.get_current_context() + if ctx: + EventTracker._session_id = ctx.session_id + except RuntimeError: + LOG.debug("EventTracker: Unable to obtain session ID") + @staticmethod def _send_events_in_thread(): """Send the current list of Events via Telemetry.""" diff --git a/samcli/lib/telemetry/metric.py b/samcli/lib/telemetry/metric.py index 1cdb22613d..681d2adc64 100644 --- a/samcli/lib/telemetry/metric.py +++ b/samcli/lib/telemetry/metric.py @@ -27,6 +27,7 @@ from samcli.lib.telemetry.event import EventTracker from samcli.lib.telemetry.project_metadata import get_git_remote_origin_url, get_initial_commit_hash, get_project_name from samcli.lib.telemetry.telemetry import Telemetry +from samcli.lib.telemetry.user_agent import get_user_agent_string from samcli.lib.warnings.sam_cli_warning import TemplateWarningsChecker LOG = logging.getLogger(__name__) @@ -438,6 +439,10 @@ def _add_common_metric_attributes(self): self._data["pyversion"] = platform.python_version() self._data["samcliVersion"] = samcli_version + user_agent = get_user_agent_string() + if user_agent: + self._data["userAgent"] = user_agent + @staticmethod def _default_session_id() -> Optional[str]: """ diff --git a/samcli/lib/telemetry/user_agent.py b/samcli/lib/telemetry/user_agent.py new file mode 100644 index 0000000000..875a4d8788 --- /dev/null +++ b/samcli/lib/telemetry/user_agent.py @@ -0,0 +1,21 @@ +""" +Reads user agent information from environment and returns it for telemetry consumption +""" +import os +import re +from typing import Optional + +USER_AGENT_ENV_VAR = "AWS_TOOLING_USER_AGENT" + +# Should accept format: ${AGENT_NAME}/${AGENT_VERSION} +# AWS_Toolkit-For-VSCode/1.62.0 +# AWS-Toolkit-For-JetBrains/1.60-223 +# AWS-Toolkit-For-JetBrains/1.60.0-223 +ACCEPTED_USER_AGENT_FORMAT = re.compile(r"^[A-Za-z0-9\-_]{1,64}/\d+\.\d+(\.\d+)?(\-[A-Za-z0-9]{0,16})?$") + + +def get_user_agent_string() -> Optional[str]: + user_agent = os.environ.get(USER_AGENT_ENV_VAR, "").strip() + if user_agent and ACCEPTED_USER_AGENT_FORMAT.match(user_agent): + return user_agent + return None diff --git a/samcli/lib/utils/managed_cloudformation_stack.py b/samcli/lib/utils/managed_cloudformation_stack.py index 9410bcef39..064abaf8e0 100644 --- a/samcli/lib/utils/managed_cloudformation_stack.py +++ b/samcli/lib/utils/managed_cloudformation_stack.py @@ -300,8 +300,9 @@ def _generate_stack_parameters( parameters = [] if parameter_overrides: for key, value in parameter_overrides.items(): - if isinstance(value, Collection) and not isinstance(value, str): + norm_value = value + if isinstance(norm_value, Collection) and not isinstance(norm_value, str): # Assumption: values don't include commas or spaces. Need to refactor to handle such a case if needed. - value = ",".join(value) - parameters.append({"ParameterKey": key, "ParameterValue": value}) + norm_value = ",".join(norm_value) + parameters.append({"ParameterKey": key, "ParameterValue": norm_value}) return parameters diff --git a/samcli/lib/utils/preview_runtimes.py b/samcli/lib/utils/preview_runtimes.py index 3d6a5b9662..c17ae95cf8 100644 --- a/samcli/lib/utils/preview_runtimes.py +++ b/samcli/lib/utils/preview_runtimes.py @@ -4,4 +4,4 @@ """ from typing import Set -PREVIEW_RUNTIMES: Set[str] = {"python3.10"} +PREVIEW_RUNTIMES: Set[str] = set() diff --git a/samcli/lib/utils/resources.py b/samcli/lib/utils/resources.py index 9d6f95fca8..875f3bd997 100644 --- a/samcli/lib/utils/resources.py +++ b/samcli/lib/utils/resources.py @@ -21,14 +21,16 @@ AWS_APIGATEWAY_METHOD = "AWS::ApiGateway::Method" AWS_APIGATEWAY_DEPLOYMENT = "AWS::ApiGateway::Deployment" AWS_APIGATEWAY_BASE_PATH_MAPPING = "AWS::ApiGateway::BasePathMapping" -AWS_APIGATWAY_DOMAIN_NAME = "AWS::ApiGateway::DomainName" +AWS_APIGATEWAY_DOMAIN_NAME = "AWS::ApiGateway::DomainName" +AWS_APIGATEWAY_AUTHORIZER = "AWS::ApiGateway::Authorizer" AWS_APIGATEWAY_V2_API = "AWS::ApiGatewayV2::Api" AWS_APIGATEWAY_V2_INTEGRATION = "AWS::ApiGatewayV2::Integration" AWS_APIGATEWAY_V2_ROUTE = "AWS::ApiGatewayV2::Route" AWS_APIGATEWAY_V2_STAGE = "AWS::ApiGatewayV2::Stage" -AWS_APIGATEWAY_v2_BASE_PATH_MAPPING = "AWS::ApiGatewayV2::ApiMapping" +AWS_APIGATEWAY_V2_BASE_PATH_MAPPING = "AWS::ApiGatewayV2::ApiMapping" AWS_APIGATEWAY_V2_DOMAIN_NAME = "AWS::ApiGatewayV2::DomainName" +AWS_APIGATEWAY_V2_AUTHORIZER = "AWS::ApiGatewayV2::Authorizer" # SFN AWS_SERVERLESS_STATEMACHINE = "AWS::Serverless::StateMachine" diff --git a/samcli/local/apigw/authorizers/__init__.py b/samcli/local/apigw/authorizers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/samcli/local/apigw/authorizers/authorizer.py b/samcli/local/apigw/authorizers/authorizer.py new file mode 100644 index 0000000000..17101b06d2 --- /dev/null +++ b/samcli/local/apigw/authorizers/authorizer.py @@ -0,0 +1,11 @@ +""" +Base Authorizer class definition +""" +from dataclasses import dataclass + + +@dataclass +class Authorizer: + payload_version: str + authorizer_name: str + type: str diff --git a/samcli/local/apigw/authorizers/lambda_authorizer.py b/samcli/local/apigw/authorizers/lambda_authorizer.py new file mode 100644 index 0000000000..cc08061b87 --- /dev/null +++ b/samcli/local/apigw/authorizers/lambda_authorizer.py @@ -0,0 +1,536 @@ +""" +Custom Lambda Authorizer class definition +""" +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from json import JSONDecodeError, loads +from typing import Any, Dict, List, Optional, Tuple, Type +from urllib.parse import parse_qsl + +from samcli.commands.local.lib.validators.identity_source_validator import IdentitySourceValidator +from samcli.local.apigw.authorizers.authorizer import Authorizer +from samcli.local.apigw.exceptions import InvalidLambdaAuthorizerResponse, InvalidSecurityDefinition +from samcli.local.apigw.route import Route + +_RESPONSE_PRINCIPAL_ID = "principalId" +_RESPONSE_CONTEXT = "context" +_RESPONSE_POLICY_DOCUMENT = "policyDocument" +_RESPONSE_IAM_STATEMENT = "Statement" +_RESPONSE_IAM_EFFECT = "Effect" +_RESPONSE_IAM_EFFECT_ALLOW = "Allow" +_RESPONSE_IAM_ACTION = "Action" +_RESPONSE_IAM_RESOURCE = "Resource" +_SIMPLE_RESPONSE_IS_AUTH = "isAuthorized" +_IAM_INVOKE_ACTION = "execute-api:Invoke" + + +class IdentitySource(ABC): + def __init__(self, identity_source: str): + """ + Abstract class representing an identity source validator + + Paramters + --------- + identity_source: str + The identity source without any prefix + """ + self.identity_source = identity_source + + def is_valid(self, **kwargs) -> bool: + """ + Validates if the identity source is present + + Parameters + ---------- + kwargs: dict + Key word arguments to search in + + Returns + ------- + bool: + True if the identity source is present + """ + return self.find_identity_value(**kwargs) is not None + + @abstractmethod + def find_identity_value(self, **kwargs) -> Any: + """ + Returns the identity value, if found + """ + + def __eq__(self, other) -> bool: + return ( + isinstance(other, IdentitySource) + and self.identity_source == other.identity_source + and self.__class__ == other.__class__ + ) + + +class HeaderIdentitySource(IdentitySource): + def find_identity_value(self, **kwargs) -> Optional[str]: + """ + Finds the header value that the identity source corresponds to + + Parameters + ---------- + kwargs + Keyword arguments that should contain `headers` + + Returns + ------- + Optional[str] + The string value of the header if it is found, otherwise None + """ + headers = kwargs.get("headers", {}) + value = headers.get(self.identity_source) + + return str(value) if value else None + + def is_valid(self, **kwargs) -> bool: + """ + Validates whether the required header is present and matches the + validation expression, if defined. + + Parameters + ---------- + kwargs: dict + Keyword arugments containing the incoming sources and validation expression + + Returns + ------- + bool + True if present and valid + """ + identity_source = self.find_identity_value(**kwargs) + validation_expression = kwargs.get("validation_expression") + + if validation_expression and identity_source is not None: + return re.match(validation_expression, identity_source) is not None + + return identity_source is not None + + +class QueryIdentitySource(IdentitySource): + def find_identity_value(self, **kwargs) -> Optional[str]: + """ + Finds the query string value that the identity source corresponds to + + Parameters + ---------- + kwargs + Keyword arguments that should contain `querystring` + + Returns + ------- + Optional[str] + The string value of the query parameter if one is found, otherwise None + """ + query_string = kwargs.get("querystring", "") + + if not query_string: + return None + + query_string_list: List[Tuple[str, str]] = parse_qsl(query_string) + + for key, value in query_string_list: + if key == self.identity_source and value: + return value + + return None + + +class ContextIdentitySource(IdentitySource): + def find_identity_value(self, **kwargs) -> Optional[str]: + """ + Finds the context value that the identity source corresponds to + + Parameters + ---------- + kwargs + Keyword arguments that should contain `context` + + Returns + ------- + Optional[str] + The string value of the context variable if it is found, otherwise None + """ + context = kwargs.get("context", {}) + value = context.get(self.identity_source) + + return str(value) if value else None + + +class StageVariableIdentitySource(IdentitySource): + def find_identity_value(self, **kwargs) -> Optional[str]: + """ + Finds the stage variable value that the identity source corresponds to + + Parameters + ---------- + kwargs + Keyword arguments that should contain `stageVariables` + + Returns + ------- + Optional[str] + The stage variable if it is found, otherwise None + """ + stage_variables = kwargs.get("stageVariables", {}) + value = stage_variables.get(self.identity_source) + + return str(value) if value else None + + +@dataclass +class LambdaAuthorizer(Authorizer): + TOKEN = "token" + REQUEST = "request" + VALID_TYPES = [TOKEN, REQUEST] + + PAYLOAD_V1 = "1.0" + PAYLOAD_V2 = "2.0" + PAYLOAD_VERSIONS = [PAYLOAD_V1, PAYLOAD_V2] + + def __init__( + self, + authorizer_name: str, + type: str, + lambda_name: str, + identity_sources: List[str], + payload_version: str, + validation_string: Optional[str] = None, + use_simple_response: bool = False, + ): + """ + Creates a Lambda Authorizer class + + Parameters + ---------- + authorizer_name: str + The name of the Lambda Authorizer + type: str + The type of authorizer this is (token or request) + lambda_name: str + The name of the Lambda function this authorizer invokes + identity_sources: List[str] + A list of strings that this authorizer uses + payload_version: str + The payload format version (1.0 or 2.0) + validation_string: Optional[str] = None + The regular expression that can be used to validate headers + use_simple_responses: bool = False + Boolean representing whether to return a simple response or not + """ + self.authorizer_name = authorizer_name + self.lambda_name = lambda_name + self.type = type + self.validation_string = validation_string + self.payload_version = payload_version + self.use_simple_response = use_simple_response + + self._parse_identity_sources(identity_sources) + + def __eq__(self, other): + return ( + isinstance(other, LambdaAuthorizer) + and self.lambda_name == other.lambda_name + and sorted(self._identity_sources_raw) == sorted(other._identity_sources_raw) + and self.validation_string == other.validation_string + and self.use_simple_response == other.use_simple_response + and self.payload_version == other.payload_version + and self.authorizer_name == other.authorizer_name + and self.type == other.type + ) + + @property + def identity_sources(self) -> List[IdentitySource]: + """ + The list of identity source validation objects + + Returns + ------- + List[IdentitySource] + A list of concrete identity source validation objects + """ + return self._identity_sources + + @identity_sources.setter + def identity_sources(self, identity_sources: List[str]) -> None: + """ + Parses and sets the identity source validation objects + + Parameters + ---------- + identity_sources: List[str] + A list of strings of identity sources + """ + self._parse_identity_sources(identity_sources) + + def _parse_identity_sources(self, identity_sources: List[str]) -> None: + """ + Helper function to create identity source validation objects + + Parameters + ---------- + identity_sources: List[str] + A list of identity sources to parse + """ + + # validate incoming identity sources first + for source in identity_sources: + is_valid = IdentitySourceValidator.validate_identity_source( + source, Route.API + ) or IdentitySourceValidator.validate_identity_source(source, Route.HTTP) + + if not is_valid: + raise InvalidSecurityDefinition( + f"Invalid identity source '{source}' for Lambda authorizer '{self.authorizer_name}" + ) + + identity_source_type = { + "method.request.header.": HeaderIdentitySource, + "$request.header.": HeaderIdentitySource, + "method.request.querystring.": QueryIdentitySource, + "$request.querystring.": QueryIdentitySource, + "context.": ContextIdentitySource, + "$context.": ContextIdentitySource, + "stageVariables.": StageVariableIdentitySource, + "$stageVariables.": StageVariableIdentitySource, + } + + self._identity_sources_raw = identity_sources + self._identity_sources = [] + + for identity_source in self._identity_sources_raw: + for prefix, identity_source_object in identity_source_type.items(): + if identity_source.startswith(prefix): + # get the stuff after the prefix + # and create the corresponding identity source object + property = identity_source[len(prefix) :] + + # NOTE (lucashuy): + # need to ignore the typing here so that mypy doesn't complain + # about instantiating an abstract class + # + # `identity_source_object` (which comes from `identity_source_type`) + # is always a concrete class + identity_source_validator = identity_source_object(identity_source=property) # type: ignore + + self._identity_sources.append(identity_source_validator) + + break + + def is_valid_response(self, response: str, method_arn: str) -> bool: + """ + Validates whether a Lambda authorizer request is authenticated or not. + + Parameters + ---------- + response: str + JSON string containing the output from a Lambda authorizer + method_arn: str + The method ARN of the route that invoked the Lambda authorizer + + Returns + ------- + bool + True if the request is properly authenticated + """ + try: + json_response = loads(response) + except (ValueError, JSONDecodeError): + raise InvalidLambdaAuthorizerResponse( + f"Authorizer {self.authorizer_name} return an invalid response payload" + ) + + if self.payload_version == LambdaAuthorizer.PAYLOAD_V2 and self.use_simple_response: + return self._validate_simple_response(json_response) + + # validate IAM policy document + LambdaAuthorizerIAMPolicyValidator.validate_policy_document(self.authorizer_name, json_response) + LambdaAuthorizerIAMPolicyValidator.validate_statement(self.authorizer_name, json_response) + + return self._is_resource_authorized(json_response, method_arn) + + def _is_resource_authorized(self, response: dict, method_arn: str) -> bool: + """ + Validate the if the current method ARN is actually authorized + + Parameters + ---------- + response: dict + The response output from the Lambda authorizer (should be in IAM format) + method_arn: str + The route's method ARN + + Returns + ------- + bool + True if authorized + """ + policy_document = response.get(_RESPONSE_POLICY_DOCUMENT, {}) + all_statements = policy_document.get(_RESPONSE_IAM_STATEMENT, []) + + for statement in all_statements: + if ( + statement.get(_RESPONSE_IAM_ACTION) != _IAM_INVOKE_ACTION + or statement.get(_RESPONSE_IAM_EFFECT) != _RESPONSE_IAM_EFFECT_ALLOW + ): + continue + + for resource_arn in statement.get(_RESPONSE_IAM_RESOURCE, []): + # form a regular expression from the possible wildcard resource ARN + regex_method_arn = resource_arn.replace("*", ".+").replace("?", ".") + regex_method_arn += "$" + + if re.match(regex_method_arn, method_arn): + return True + + return False + + def _validate_simple_response(self, response: dict) -> bool: + """ + Helper method to validate if a Lambda authorizer response using simple responses is valid and authorized + + Parameters + ---------- + response: dict + JSON object containing required simple response paramters + + Returns + ------- + bool + True if the request is authorized + """ + is_authorized = response.get(_SIMPLE_RESPONSE_IS_AUTH) + + if is_authorized is None or not isinstance(is_authorized, bool): + raise InvalidLambdaAuthorizerResponse( + f"Authorizer {self.authorizer_name} is missing or contains an invalid " f"{_SIMPLE_RESPONSE_IS_AUTH}" + ) + + return is_authorized + + def get_context(self, response: str) -> Dict[str, Any]: + """ + Returns the context (if set) from the authorizer response and appends the principalId to it. + + Parameters + ---------- + response: str + Output from Lambda authorizer + + Returns + ------- + Dict[str, Any] + The built authorizer context object + """ + invalid_message = f"Authorizer {self.authorizer_name} return an invalid response payload" + + try: + json_response = loads(response) + except (ValueError, JSONDecodeError) as ex: + raise InvalidLambdaAuthorizerResponse(invalid_message) from ex + + if not isinstance(json_response, dict): + raise InvalidLambdaAuthorizerResponse(invalid_message) + + built_context = json_response.get(_RESPONSE_CONTEXT, {}) + + if not isinstance(built_context, dict): + raise InvalidLambdaAuthorizerResponse(invalid_message) + + principal_id = json_response.get(_RESPONSE_PRINCIPAL_ID) + if principal_id: + # only V1 response contains this ID in the output + built_context[_RESPONSE_PRINCIPAL_ID] = principal_id + + return built_context + + +@dataclass +class LambdaAuthorizerIAMPolicyPropertyValidator: + property_key: str + property_type: Type + + def is_valid(self, response: dict) -> bool: + """ + Validates whether the property is present and of the correct type + + Parameters + ---------- + response: dict + The response output from the Lambda authorizer (should be in IAM format) + + Returns + ------- + bool + True if present and of correct type + """ + value = response.get(self.property_key) + + return value is not None and isinstance(value, self.property_type) + + +class LambdaAuthorizerIAMPolicyValidator: + @staticmethod + def validate_policy_document(auth_name: str, response: dict) -> None: + """ + Validate the properties of a Lambda authorizer response at the root level + + Parameters + ---------- + auth_name: str + Name of the authorizer + response: dict + The response output from the Lambda authorizer (should be in IAM format) + """ + validators = { + _RESPONSE_PRINCIPAL_ID: LambdaAuthorizerIAMPolicyPropertyValidator(_RESPONSE_PRINCIPAL_ID, str), + _RESPONSE_POLICY_DOCUMENT: LambdaAuthorizerIAMPolicyPropertyValidator(_RESPONSE_POLICY_DOCUMENT, dict), + } + + for prop_name, validator in validators.items(): + if not validator.is_valid(response): + raise InvalidLambdaAuthorizerResponse( + f"Authorizer '{auth_name}' contains an invalid or " f"missing '{prop_name}' from response" + ) + + @staticmethod + def validate_statement(auth_name: str, response: dict) -> None: + """ + Validate the Statement(s) of a Lambda authorizer response's policy document + + Parameters + ---------- + auth_name: str + Name of the authorizer + response: dict + The response output from the Lambda authorizer (should be in IAM format) + """ + policy_document = response.get(_RESPONSE_POLICY_DOCUMENT, {}) + + all_statements = policy_document.get(_RESPONSE_IAM_STATEMENT) + if not all_statements or not isinstance(all_statements, list) or not len(all_statements) > 0: + raise InvalidLambdaAuthorizerResponse( + f"Authorizer '{auth_name}' contains an invalid or " f"missing '{_RESPONSE_IAM_STATEMENT}' from response" + ) + + validators = { + _RESPONSE_IAM_ACTION: LambdaAuthorizerIAMPolicyPropertyValidator(_RESPONSE_IAM_ACTION, str), + _RESPONSE_IAM_EFFECT: LambdaAuthorizerIAMPolicyPropertyValidator(_RESPONSE_IAM_EFFECT, str), + _RESPONSE_IAM_RESOURCE: LambdaAuthorizerIAMPolicyPropertyValidator(_RESPONSE_IAM_RESOURCE, list), + } + + for statement in all_statements: + if not isinstance(statement, dict): + raise InvalidLambdaAuthorizerResponse( + f"Authorizer '{auth_name}' policy document must be a list of objects" + ) + + for prop_name, validator in validators.items(): + if not validator.is_valid(statement): + raise InvalidLambdaAuthorizerResponse( + f"Authorizer '{auth_name}' policy document contains an invalid '{prop_name}'" + ) diff --git a/samcli/local/apigw/event_constructor.py b/samcli/local/apigw/event_constructor.py new file mode 100644 index 0000000000..441d94d20e --- /dev/null +++ b/samcli/local/apigw/event_constructor.py @@ -0,0 +1,327 @@ +""" +Lambda event construction and generation +""" + +import base64 +import logging +from datetime import datetime +from time import time +from typing import Any, Dict + +from samcli.local.apigw.path_converter import PathConverter +from samcli.local.events.api_event import ( + ApiGatewayLambdaEvent, + ApiGatewayV2LambdaEvent, + ContextHTTP, + ContextIdentity, + RequestContext, + RequestContextV2, +) + +LOG = logging.getLogger(__name__) + + +def construct_v1_event( + flask_request, port, binary_types, stage_name=None, stage_variables=None, operation_name=None +) -> Dict[str, Any]: + """ + Helper method that constructs the Event to be passed to Lambda + + :param request flask_request: Flask Request + :param port: the port number + :param binary_types: list of binary types + :param stage_name: Optional, the stage name string + :param stage_variables: Optional, API Gateway Stage Variables + :return: JSON object + """ + + identity = ContextIdentity(source_ip=flask_request.remote_addr) + + endpoint = PathConverter.convert_path_to_api_gateway(flask_request.endpoint) + method = flask_request.method + protocol = flask_request.environ.get("SERVER_PROTOCOL", "HTTP/1.1") + host = flask_request.host + + request_data = flask_request.get_data() + + request_mimetype = flask_request.mimetype + + is_base_64 = _should_base64_encode(binary_types, request_mimetype) + + if is_base_64: + LOG.debug("Incoming Request seems to be binary. Base64 encoding the request data before sending to Lambda.") + request_data = base64.b64encode(request_data) + + if request_data: + # Flask does not parse/decode the request data. We should do it ourselves + # Note(xinhol): here we change request_data's type from bytes to str and confused mypy + # We might want to consider to use a new variable here. + request_data = request_data.decode("utf-8") + + query_string_dict, multi_value_query_string_dict = _query_string_params(flask_request) + + context = RequestContext( + resource_path=endpoint, + http_method=method, + stage=stage_name, + identity=identity, + path=endpoint, + protocol=protocol, + domain_name=host, + operation_name=operation_name, + ) + + headers_dict, multi_value_headers_dict = _event_headers(flask_request, port) + + event = ApiGatewayLambdaEvent( + http_method=method, + body=request_data, + resource=endpoint, + request_context=context, + query_string_params=query_string_dict, + multi_value_query_string_params=multi_value_query_string_dict, + headers=headers_dict, + multi_value_headers=multi_value_headers_dict, + path_parameters=flask_request.view_args, + path=flask_request.path, + is_base_64_encoded=is_base_64, + stage_variables=stage_variables, + ) + + event_dict = event.to_dict() + LOG.debug("Constructed Event 1.0 to invoke Lambda. Event: %s", event_dict) + return event_dict + + +def construct_v2_event_http( + flask_request, + port, + binary_types, + stage_name=None, + stage_variables=None, + route_key=None, + request_time_epoch=int(time()), + request_time=datetime.utcnow().strftime("%d/%b/%Y:%H:%M:%S +0000"), +) -> Dict[str, Any]: + """ + Helper method that constructs the Event 2.0 to be passed to Lambda + + https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html + + :param request flask_request: Flask Request + :param port: the port number + :param binary_types: list of binary types + :param stage_name: Optional, the stage name string + :param stage_variables: Optional, API Gateway Stage Variables + :param route_key: Optional, the route key for the route + :return: JSON object + """ + method = flask_request.method + + request_data = flask_request.get_data() + + request_mimetype = flask_request.mimetype + + is_base_64 = _should_base64_encode(binary_types, request_mimetype) + + if is_base_64: + LOG.debug("Incoming Request seems to be binary. Base64 encoding the request data before sending to Lambda.") + request_data = base64.b64encode(request_data) + + if request_data is not None: + # Flask does not parse/decode the request data. We should do it ourselves + request_data = request_data.decode("utf-8") + + query_string_dict = _query_string_params_v_2_0(flask_request) + + cookies = _event_http_cookies(flask_request) + headers = _event_http_headers(flask_request, port) + context_http = ContextHTTP(method=method, path=flask_request.path, source_ip=flask_request.remote_addr) + context = RequestContextV2( + http=context_http, + route_key=route_key, + stage=stage_name, + request_time_epoch=request_time_epoch, + request_time=request_time, + ) + + event = ApiGatewayV2LambdaEvent( + route_key=route_key, + raw_path=flask_request.path, + raw_query_string=flask_request.query_string.decode("utf-8"), + cookies=cookies, + headers=headers, + query_string_params=query_string_dict, + request_context=context, + body=request_data, + path_parameters=flask_request.view_args, + is_base_64_encoded=is_base_64, + stage_variables=stage_variables, + ) + + event_dict = event.to_dict() + LOG.debug("Constructed Event Version 2.0 to invoke Lambda. Event: %s", event_dict) + return event_dict + + +def _query_string_params(flask_request): + """ + Constructs an APIGW equivalent query string dictionary + + Parameters + ---------- + flask_request request + Request from Flask + + Returns dict (str: str), dict (str: list of str) + ------- + Empty dict if no query params where in the request otherwise returns a dictionary of key to value + + """ + query_string_dict = {} + multi_value_query_string_dict = {} + + # Flask returns an ImmutableMultiDict so convert to a dictionary that becomes + # a dict(str: list) then iterate over + for query_string_key, query_string_list in flask_request.args.lists(): + query_string_value_length = len(query_string_list) + + # if the list is empty, default to empty string + if not query_string_value_length: + query_string_dict[query_string_key] = "" + multi_value_query_string_dict[query_string_key] = [""] + else: + query_string_dict[query_string_key] = query_string_list[-1] + multi_value_query_string_dict[query_string_key] = query_string_list + + return query_string_dict, multi_value_query_string_dict + + +def _query_string_params_v_2_0(flask_request): + """ + Constructs an APIGW equivalent query string dictionary using the 2.0 format + https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html#2.0 + + Parameters + ---------- + flask_request request + Request from Flask + + Returns dict (str: str) + ------- + Empty dict if no query params where in the request otherwise returns a dictionary of key to value + + """ + query_string_dict = {} + + # Flask returns an ImmutableMultiDict so convert to a dictionary that becomes + # a dict(str: list) then iterate over + query_string_dict = { + query_string_key: ",".join(query_string_list) + for query_string_key, query_string_list in flask_request.args.lists() + } + + return query_string_dict + + +def _event_headers(flask_request, port): + """ + Constructs an APIGW equivalent headers dictionary + + Parameters + ---------- + flask_request request + Request from Flask + int port + Forwarded Port + cors_headers dict + Dict of the Cors properties + + Returns dict (str: str), dict (str: list of str) + ------- + Returns a dictionary of key to list of strings + + """ + headers_dict = {} + multi_value_headers_dict = {} + + # Multi-value request headers is not really supported by Flask. + # See https://github.com/pallets/flask/issues/850 + for header_key in flask_request.headers.keys(): + headers_dict[header_key] = flask_request.headers.get(header_key) + multi_value_headers_dict[header_key] = flask_request.headers.getlist(header_key) + + headers_dict["X-Forwarded-Proto"] = flask_request.scheme + multi_value_headers_dict["X-Forwarded-Proto"] = [flask_request.scheme] + + headers_dict["X-Forwarded-Port"] = str(port) + multi_value_headers_dict["X-Forwarded-Port"] = [str(port)] + return headers_dict, multi_value_headers_dict + + +def _event_http_cookies(flask_request): + """ + All cookie headers in the request are combined with commas. + + https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html + + Parameters + ---------- + flask_request request + Request from Flask + + Returns list + ------- + Returns a list of cookies + + """ + cookies = [] + for cookie_key in flask_request.cookies.keys(): + cookies.append(f"{cookie_key}={flask_request.cookies.get(cookie_key)}") + return cookies + + +def _event_http_headers(flask_request, port): + """ + Duplicate headers are combined with commas. + + https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html + + Parameters + ---------- + flask_request request + Request from Flask + + Returns list + ------- + Returns a list of cookies + + """ + headers = {} + # Multi-value request headers is not really supported by Flask. + # See https://github.com/pallets/flask/issues/850 + for header_key in flask_request.headers.keys(): + headers[header_key] = flask_request.headers.get(header_key) + + headers["X-Forwarded-Proto"] = flask_request.scheme + headers["X-Forwarded-Port"] = str(port) + return headers + + +def _should_base64_encode(binary_types, request_mimetype): + """ + Whether or not to encode the data from the request to Base64 + + Parameters + ---------- + binary_types list(basestring) + Corresponds to self.binary_types (aka. what is parsed from SAM Template + request_mimetype str + Mimetype for the request + + Returns + ------- + True if the data should be encoded to Base64 otherwise False + + """ + return request_mimetype in binary_types or "*/*" in binary_types diff --git a/samcli/local/apigw/exceptions.py b/samcli/local/apigw/exceptions.py new file mode 100644 index 0000000000..474e51560b --- /dev/null +++ b/samcli/local/apigw/exceptions.py @@ -0,0 +1,52 @@ +""" +Exceptions used by API Gateway service +""" +from samcli.commands.exceptions import UserException + + +class LambdaResponseParseException(Exception): + """ + An exception raised when we fail to parse the response for Lambda + """ + + +class PayloadFormatVersionValidateException(Exception): + """ + An exception raised when validation of payload format version fails + """ + + +class MultipleAuthorizerException(UserException): + """ + An exception raised when user lists more than one Authorizer + """ + + +class IncorrectOasWithDefaultAuthorizerException(UserException): + """ + An exception raised when the user provides root level Authorizers using the wrong OpenAPI Specification versions + """ + + +class InvalidOasVersion(UserException): + """ + An exception raised when the user provides an invalid OpenAPI Specificaion version + """ + + +class InvalidSecurityDefinition(UserException): + """ + An exception raised when the user provides an invalid security definition + """ + + +class InvalidLambdaAuthorizerResponse(UserException): + """ + An exception raised when a Lambda authorizer returns an invalid response format + """ + + +class AuthorizerUnauthorizedRequest(UserException): + """ + An exception raised when the request is not authorized by the authorizer + """ diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py index d81bfdf980..af3befbc70 100644 --- a/samcli/local/apigw/local_apigw_service.py +++ b/samcli/local/apigw/local_apigw_service.py @@ -1,23 +1,36 @@ """API Gateway Local Service""" + import base64 -import io import json import logging from datetime import datetime +from io import BytesIO from time import time -from typing import List, Optional +from typing import Any, Dict, List, Optional -from flask import Flask, request +from flask import Flask, Request, request from werkzeug.datastructures import Headers from werkzeug.routing import BaseConverter from werkzeug.serving import WSGIRequestHandler from samcli.commands.local.lib.exceptions import UnsupportedInlineCodeError -from samcli.lib.providers.provider import Cors +from samcli.commands.local.lib.local_lambda import LocalLambdaRunner +from samcli.lib.providers.provider import Api, Cors +from samcli.lib.telemetry.event import EventName, EventTracker, UsedFeature from samcli.lib.utils.stream_writer import StreamWriter +from samcli.local.apigw.authorizers.lambda_authorizer import LambdaAuthorizer +from samcli.local.apigw.event_constructor import construct_v1_event, construct_v2_event_http +from samcli.local.apigw.exceptions import ( + AuthorizerUnauthorizedRequest, + InvalidLambdaAuthorizerResponse, + InvalidSecurityDefinition, + LambdaResponseParseException, + PayloadFormatVersionValidateException, +) +from samcli.local.apigw.path_converter import PathConverter +from samcli.local.apigw.route import Route +from samcli.local.apigw.service_error_responses import ServiceErrorResponses from samcli.local.events.api_event import ( - ApiGatewayLambdaEvent, - ApiGatewayV2LambdaEvent, ContextHTTP, ContextIdentity, RequestContext, @@ -26,91 +39,9 @@ from samcli.local.lambdafn.exceptions import FunctionNotFound from samcli.local.services.base_local_service import BaseLocalService, LambdaOutputParser -from .path_converter import PathConverter -from .service_error_responses import ServiceErrorResponses - LOG = logging.getLogger(__name__) -class LambdaResponseParseException(Exception): - """ - An exception raised when we fail to parse the response for Lambda - """ - - -class PayloadFormatVersionValidateException(Exception): - """ - An exception raised when validation of payload format version fails - """ - - -class Route: - API = "Api" - HTTP = "HttpApi" - ANY_HTTP_METHODS = ["GET", "DELETE", "PUT", "POST", "HEAD", "OPTIONS", "PATCH"] - - def __init__( - self, - function_name: Optional[str], - path: str, - methods: List[str], - event_type: str = API, - payload_format_version: Optional[str] = None, - is_default_route: bool = False, - operation_name=None, - stack_path: str = "", - ): - """ - Creates an ApiGatewayRoute - - :param list(str) methods: http method - :param function_name: Name of the Lambda function this API is connected to - :param str path: Path off the base url - :param str event_type: Type of the event. "Api" or "HttpApi" - :param str payload_format_version: version of payload format - :param bool is_default_route: determines if the default route or not - :param string operation_name: Swagger operationId for the route - :param str stack_path: path of the stack the route is located - """ - self.methods = self.normalize_method(methods) - self.function_name = function_name - self.path = path - self.event_type = event_type - self.payload_format_version = payload_format_version - self.is_default_route = is_default_route - self.operation_name = operation_name - self.stack_path = stack_path - - def __eq__(self, other): - return ( - isinstance(other, Route) - and sorted(self.methods) == sorted(other.methods) - and self.function_name == other.function_name - and self.path == other.path - and self.operation_name == other.operation_name - and self.stack_path == other.stack_path - ) - - def __hash__(self): - route_hash = hash(f"{self.stack_path}-{self.function_name}-{self.path}") - for method in sorted(self.methods): - route_hash *= hash(method) - return route_hash - - def normalize_method(self, methods): - """ - Normalizes Http Methods. Api Gateway allows a Http Methods of ANY. This is a special verb to denote all - supported Http Methods on Api Gateway. - - :param list methods: Http methods - :return list: Either the input http_method or one of the _ANY_HTTP_METHODS (normalized Http Methods) - """ - methods = [method.upper() for method in methods] - if "ANY" in methods: - return self.ANY_HTTP_METHODS - return methods - - class CatchAllPathConverter(BaseConverter): regex = ".+" weight = 300 @@ -127,7 +58,15 @@ class LocalApigwService(BaseLocalService): _DEFAULT_PORT = 3000 _DEFAULT_HOST = "127.0.0.1" - def __init__(self, api, lambda_runner, static_dir=None, port=None, host=None, stderr=None): + def __init__( + self, + api: Api, + lambda_runner: LocalLambdaRunner, + static_dir: Optional[str] = None, + port: Optional[int] = None, + host: Optional[str] = None, + stderr: Optional[StreamWriter] = None, + ): """ Creates an ApiGatewayService @@ -152,9 +91,22 @@ def __init__(self, api, lambda_runner, static_dir=None, port=None, host=None, st self.api = api self.lambda_runner = lambda_runner self.static_dir = static_dir - self._dict_of_routes = {} + self._dict_of_routes: Dict[str, Route] = {} self.stderr = stderr + self._click_session_id = None + + try: + # save the session ID for telemetry event sending + from samcli.cli.context import Context + + ctx = Context.get_current_context() + + if ctx: + self._click_session_id = ctx.session_id + except RuntimeError: + LOG.debug("Not able to get click context in APIGW service") + def create(self): """ Creates a Flask Application that can be started. @@ -209,7 +161,7 @@ def create(self): self._construct_error_handling() - def _add_catch_all_path(self, methods, path, route): + def _add_catch_all_path(self, methods: List[str], path: str, route: Route): """ Add the catch all route to the _app and the dictionary of routes. @@ -234,6 +186,9 @@ def _add_catch_all_path(self, methods, path, route): payload_format_version=route.payload_format_version, is_default_route=True, stack_path=route.stack_path, + authorizer_name=route.authorizer_name, + authorizer_object=route.authorizer_object, + use_default_authorizer=route.use_default_authorizer, ) def _generate_route_keys(self, methods, path): @@ -277,6 +232,386 @@ def _construct_error_handling(self): # Something went wrong self._app.register_error_handler(500, ServiceErrorResponses.lambda_failure_response) + def _create_method_arn(self, flask_request: Request, event_type: str) -> str: + """ + Creates a method ARN with fake AWS values + + Parameters + ---------- + flask_request: Request + Flask request object to get method and endpoint + event_type: str + Type of event (API or HTTP) + + Returns + ------- + str + A built method ARN with fake values + """ + context = RequestContext() if event_type == Route.API else RequestContextV2() + method, endpoint = self.get_request_methods_endpoints(flask_request) + + return ( + f"arn:aws:execute-api:us-east-1:{context.account_id}:" # type: ignore + f"{context.api_id}/{self.api.stage_name}/{method}{endpoint}" + ) + + def _generate_lambda_token_authorizer_event( + self, flask_request: Request, route: Route, lambda_authorizer: LambdaAuthorizer + ) -> dict: + """ + Creates a Lambda authorizer token event + + Parameters + ---------- + flask_request: Request + Flask request object to get method and endpoint + route: Route + Route object representing the endpoint to be invoked later + lambda_authorizer: LambdaAuthorizer + The Lambda authorizer the route is using + + Returns + ------- + dict + Basic dictionary containing a type and authorizationToken + """ + method_arn = self._create_method_arn(flask_request, route.event_type) + + headers = {"headers": flask_request.headers} + + # V1 token based authorizers should always have a single identity source + if len(lambda_authorizer.identity_sources) != 1: + raise InvalidSecurityDefinition( + "An invalid token based Lambda Authorizer was found, there should be one header identity source" + ) + + identity_source = lambda_authorizer.identity_sources[0] + authorization_token = identity_source.find_identity_value(**headers) + + return { + "type": LambdaAuthorizer.TOKEN.upper(), + "authorizationToken": str(authorization_token), + "methodArn": method_arn, + } + + def _generate_lambda_request_authorizer_event_http( + self, lambda_authorizer_payload: str, identity_values: list, method_arn: str + ) -> dict: + """ + Helper method to generate part of the event required for different payload versions + for API Gateway V2 + + Parameters + ---------- + lambda_authorizer_payload: str + The payload version of the Lambda authorizer + identity_values: list + A list of string identity values + method_arn: str + The method ARN for the endpoint + + Returns + ------- + dict + Dictionary containing partial Lambda authorizer event + """ + if lambda_authorizer_payload == LambdaAuthorizer.PAYLOAD_V2: + # payload 2.0 expects a list of strings + return {"identitySource": identity_values, "routeArn": method_arn} + else: + # payload 1.0 expects a comma deliminated string that is the same + # for both identitySource and authorizationToken + all_identity_values_string = ",".join(identity_values) + + return { + "identitySource": all_identity_values_string, + "authorizationToken": all_identity_values_string, + "methodArn": method_arn, + } + + def _generate_lambda_request_authorizer_event( + self, flask_request: Request, route: Route, lambda_authorizer: LambdaAuthorizer + ) -> dict: + """ + Creates a Lambda authorizer request event + + Parameters + ---------- + flask_request: Request + Flask request object to get method and endpoint + route: Route + Route object representing the endpoint to be invoked later + lambda_authorizer: LambdaAuthorizer + The Lambda authorizer the route is using + + Returns + ------- + dict + A Lambda authorizer event + """ + method_arn = self._create_method_arn(flask_request, route.event_type) + method, endpoint = self.get_request_methods_endpoints(flask_request) + + # generate base lambda event and load it into a dict + lambda_event = self._generate_lambda_event(flask_request, route, method, endpoint) + lambda_event.update({"type": LambdaAuthorizer.REQUEST.upper()}) + + # build context to form identity values + context = ( + self._build_v1_context(route) + if lambda_authorizer.payload_version == LambdaAuthorizer.PAYLOAD_V1 + else self._build_v2_context(route) + ) + + if route.event_type == Route.API: + # v1 requests only add method ARN + lambda_event.update({"methodArn": method_arn}) + else: + # kwargs to pass into identity value finder + kwargs = { + "headers": flask_request.headers, + "querystring": flask_request.query_string.decode("utf-8"), + "context": context, + "stageVariables": self.api.stage_variables, + } + + # find and build all identity sources + all_identity_values = [] + for identity_source in lambda_authorizer.identity_sources: + value = identity_source.find_identity_value(**kwargs) + + if value: + # all identity values must be a string + all_identity_values.append(str(value)) + + lambda_event.update( + self._generate_lambda_request_authorizer_event_http( + lambda_authorizer.payload_version, all_identity_values, method_arn + ) + ) + + return lambda_event + + def _generate_lambda_authorizer_event( + self, flask_request: Request, route: Route, lambda_authorizer: LambdaAuthorizer + ) -> dict: + """ + Generate a Lambda authorizer event + + Parameters + ---------- + flask_request: Request + Flask request object to get method and endpoint + route: Route + Route object representing the endpoint to be invoked later + lambda_authorizer: LambdaAuthorizer + The Lambda authorizer the route is using + + Returns + ------- + str + A JSON string containing event properties + """ + authorizer_events = { + LambdaAuthorizer.TOKEN: self._generate_lambda_token_authorizer_event, + LambdaAuthorizer.REQUEST: self._generate_lambda_request_authorizer_event, + } + + kwargs: Dict[str, Any] = { + "flask_request": flask_request, + "route": route, + "lambda_authorizer": lambda_authorizer, + } + + return authorizer_events[lambda_authorizer.type](**kwargs) + + def _generate_lambda_event(self, flask_request: Request, route: Route, method: str, endpoint: str) -> dict: + """ + Helper function to generate the correct Lambda event + + Parameters + ---------- + flask_request: Request + The global Flask Request object + route: Route + The Route that was called + method: str + The method of the request (eg. GET, POST) from the Flask request + endpoint: str + The endpoint of the request from the Flask request + + Returns + ------- + str + JSON string of event properties + """ + # TODO: Rewrite the logic below to use version 2.0 when an invalid value is provided + # the Lambda Event 2.0 is only used for the HTTP API gateway with defined payload format version equal 2.0 + # or none, as the default value to be used is 2.0 + # https://docs.aws.amazon.com/apigatewayv2/latest/api-reference/apis-apiid-integrations.html#apis-apiid-integrations-prop-createintegrationinput-payloadformatversion + if route.event_type == Route.HTTP and route.payload_format_version in [None, "2.0"]: + apigw_endpoint = PathConverter.convert_path_to_api_gateway(endpoint) + route_key = self._v2_route_key(method, apigw_endpoint, route.is_default_route) + + return construct_v2_event_http( + flask_request=flask_request, + port=self.port, + binary_types=self.api.binary_media_types, + stage_name=self.api.stage_name, + stage_variables=self.api.stage_variables, + route_key=route_key, + ) + + # For Http Apis with payload version 1.0, API Gateway never sends the OperationName. + route_key = route.operation_name if route.event_type == Route.API else None + + return construct_v1_event( + flask_request=flask_request, + port=self.port, + binary_types=self.api.binary_media_types, + stage_name=self.api.stage_name, + stage_variables=self.api.stage_variables, + operation_name=route_key, + ) + + def _build_v1_context(self, route: Route) -> Dict[str, Any]: + """ + Helper function to a 1.0 request context + + Parameters + ---------- + route: Route + The Route object that was invoked + + Returns + ------- + dict + JSON object containing context variables + """ + identity = ContextIdentity(source_ip=request.remote_addr) + + protocol = request.environ.get("SERVER_PROTOCOL", "HTTP/1.1") + host = request.host + + operation_name = route.operation_name if route.event_type == Route.API else None + + endpoint = PathConverter.convert_path_to_api_gateway(request.endpoint) + method = request.method + + context = RequestContext( + resource_path=endpoint, + http_method=method, + stage=self.api.stage_name, + identity=identity, + path=endpoint, + protocol=protocol, + domain_name=host, + operation_name=operation_name, + ) + + return context.to_dict() + + def _build_v2_context(self, route: Route) -> Dict[str, Any]: + """ + Helper function to a 2.0 request context + + Parameters + ---------- + route: Route + The Route object that was invoked + + Returns + ------- + dict + JSON object containing context variables + """ + endpoint = PathConverter.convert_path_to_api_gateway(request.endpoint) + method = request.method + + apigw_endpoint = PathConverter.convert_path_to_api_gateway(endpoint) + route_key = self._v2_route_key(method, apigw_endpoint, route.is_default_route) + + request_time_epoch = int(time()) + request_time = datetime.utcnow().strftime("%d/%b/%Y:%H:%M:%S +0000") + + context_http = ContextHTTP(method=method, path=request.path, source_ip=request.remote_addr) + context = RequestContextV2( + http=context_http, + route_key=route_key, + stage=self.api.stage_name, + request_time_epoch=request_time_epoch, + request_time=request_time, + ) + + return context.to_dict() + + def _valid_identity_sources(self, route: Route) -> bool: + """ + Validates if the route contains all the valid identity sources defined in the route's Lambda Authorizer + + Parameters + ---------- + route: Route + the Route object that contains the Lambda Authorizer definition + + Returns + ------- + bool + true if all the identity sources are present and valid + """ + lambda_auth = route.authorizer_object + + if not isinstance(lambda_auth, LambdaAuthorizer): + return False + + identity_sources = lambda_auth.identity_sources + + context = ( + self._build_v1_context(route) + if lambda_auth.payload_version == LambdaAuthorizer.PAYLOAD_V1 + else self._build_v2_context(route) + ) + + kwargs = { + "headers": request.headers, + "querystring": request.query_string.decode("utf-8"), + "context": context, + "stageVariables": self.api.stage_variables, + "validation_expression": lambda_auth.validation_string, + } + + for validator in identity_sources: + if not validator.is_valid(**kwargs): + return False + + return True + + def _invoke_lambda_function(self, lambda_function_name: str, event: dict) -> str: + """ + Helper method to invoke a function and setup stdout+stderr + + Parameters + ---------- + lambda_function_name: str + The name of the Lambda function to invoke + event: dict + The event object to pass into the Lambda function + + Returns + ------- + str + A string containing the output from the Lambda function + """ + with BytesIO() as stdout: + event_str = json.dumps(event, sort_keys=True) + stdout_writer = StreamWriter(stdout, auto_flush=True) + + self.lambda_runner.invoke(lambda_function_name, event_str, stdout=stdout_writer, stderr=self.stderr) + lambda_response, _ = LambdaOutputParser.get_lambda_output(stdout) + + return lambda_response + def _request_handler(self, **kwargs): """ We handle all requests to the host:port. The general flow of handling a request is as follows @@ -301,8 +636,9 @@ def _request_handler(self, **kwargs): Response object """ - route = self._get_current_route(request) + route: Route = self._get_current_route(request) cors_headers = Cors.cors_to_headers(self.api.cors) + lambda_authorizer = route.authorizer_object # payloadFormatVersion can only support 2 values: "1.0" and "2.0" # so we want to do strict validation to make sure it has proper value if provided @@ -317,59 +653,60 @@ def _request_handler(self, **kwargs): headers = Headers(cors_headers) return self.service_response("", headers, 200) + # check for LambdaAuthorizer since that is the only authorizer we currently support + if isinstance(lambda_authorizer, LambdaAuthorizer) and not self._valid_identity_sources(route): + return ServiceErrorResponses.missing_lambda_auth_identity_sources() + try: - # TODO: Rewrite the logic below to use version 2.0 when an invalid value is provided - # the Lambda Event 2.0 is only used for the HTTP API gateway with defined payload format version equal 2.0 - # or none, as the default value to be used is 2.0 - # https://docs.aws.amazon.com/apigatewayv2/latest/api-reference/apis-apiid-integrations.html#apis-apiid-integrations-prop-createintegrationinput-payloadformatversion - if route.event_type == Route.HTTP and route.payload_format_version in [None, "2.0"]: - apigw_endpoint = PathConverter.convert_path_to_api_gateway(endpoint) - route_key = self._v2_route_key(method, apigw_endpoint, route.is_default_route) - event = self._construct_v_2_0_event_http( - request, - self.port, - self.api.binary_media_types, - self.api.stage_name, - self.api.stage_variables, - route_key, - ) - elif route.event_type == Route.API: - # The OperationName is only sent to the Lambda Function from API Gateway V1(Rest API). - event = self._construct_v_1_0_event( - request, - self.port, - self.api.binary_media_types, - self.api.stage_name, - self.api.stage_variables, - route.operation_name, - ) - else: - # For Http Apis with payload version 1.0, API Gateway never sends the OperationName. - event = self._construct_v_1_0_event( - request, - self.port, - self.api.binary_media_types, - self.api.stage_name, - self.api.stage_variables, - None, - ) + route_lambda_event = self._generate_lambda_event(request, route, method, endpoint) + auth_lambda_event = None + + if lambda_authorizer: + auth_lambda_event = self._generate_lambda_authorizer_event(request, route, lambda_authorizer) except UnicodeDecodeError as error: LOG.error("UnicodeDecodeError while processing HTTP request: %s", error) return ServiceErrorResponses.lambda_failure_response() - stdout_stream = io.BytesIO() - stdout_stream_writer = StreamWriter(stdout_stream, auto_flush=True) + try: + lambda_authorizer_exception = None + auth_service_error = None + + if lambda_authorizer: + self._invoke_parse_lambda_authorizer(lambda_authorizer, auth_lambda_event, route_lambda_event, route) + except AuthorizerUnauthorizedRequest as ex: + auth_service_error = ServiceErrorResponses.lambda_authorizer_unauthorized() + lambda_authorizer_exception = ex + except InvalidLambdaAuthorizerResponse as ex: + auth_service_error = ServiceErrorResponses.lambda_failure_response() + lambda_authorizer_exception = ex + finally: + exception_name = type(lambda_authorizer_exception).__name__ if lambda_authorizer_exception else None + + EventTracker.track_event( + event_name=EventName.USED_FEATURE.value, + event_value=UsedFeature.INVOKED_CUSTOM_LAMBDA_AUTHORIZERS.value, + session_id=self._click_session_id, + exception_name=exception_name, + ) + + if lambda_authorizer_exception: + LOG.error("Lambda authorizer failed to invoke successfully: %s", lambda_authorizer_exception.message) + return auth_service_error + + endpoint_service_error = None try: - self.lambda_runner.invoke(route.function_name, event, stdout=stdout_stream_writer, stderr=self.stderr) + # invoke the route's Lambda function + lambda_response = self._invoke_lambda_function(route.function_name, route_lambda_event) except FunctionNotFound: - return ServiceErrorResponses.lambda_not_found_response() + endpoint_service_error = ServiceErrorResponses.lambda_not_found_response() except UnsupportedInlineCodeError: - return ServiceErrorResponses.not_implemented_locally( + endpoint_service_error = ServiceErrorResponses.not_implemented_locally( "Inline code is not supported for sam local commands. Please write your code in a separate file." ) - lambda_response, _ = LambdaOutputParser.get_lambda_output(stdout_stream) + if endpoint_service_error: + return endpoint_service_error try: if route.event_type == Route.HTTP and ( @@ -388,6 +725,42 @@ def _request_handler(self, **kwargs): return self.service_response(body, headers, status_code) + def _invoke_parse_lambda_authorizer( + self, lambda_authorizer: LambdaAuthorizer, auth_lambda_event: dict, route_lambda_event: dict, route: Route + ) -> None: + """ + Helper method to invoke and parse the output of a Lambda authorizer + + Parameters + ---------- + lambda_authorizer: LambdaAuthorizer + The route's Lambda authorizer + auth_lambda_event: dict + The event to pass to the Lambda authorizer + route_lambda_event: dict + The event to pass into the route + route: Route + The route that is being called + """ + lambda_auth_response = self._invoke_lambda_function(lambda_authorizer.lambda_name, auth_lambda_event) + method_arn = self._create_method_arn(request, route.event_type) + + if not lambda_authorizer.is_valid_response(lambda_auth_response, method_arn): + raise AuthorizerUnauthorizedRequest(f"Request is not authorized for {method_arn}") + + # update route context to include any context that may have been passed from authorizer + original_context = route_lambda_event.get("requestContext", {}) + + context = lambda_authorizer.get_context(lambda_auth_response) + + # payload V2 responses have the passed context under the "lambda" key + if route.event_type == Route.HTTP and route.payload_format_version in [None, "2.0"]: + original_context.update({"authorizer": {"lambda": context}}) + else: + original_context.update({"authorizer": context}) + + route_lambda_event.update({"requestContext": original_context}) + def _get_current_route(self, flask_request): """ Get the route (Route) based on the current request @@ -671,311 +1044,3 @@ def _merge_response_headers(headers, multi_headers): processed_headers.add(header, headers[header]) return processed_headers - - @staticmethod - def _construct_v_1_0_event( - flask_request, port, binary_types, stage_name=None, stage_variables=None, operation_name=None - ): - """ - Helper method that constructs the Event to be passed to Lambda - - :param request flask_request: Flask Request - :param port: the port number - :param binary_types: list of binary types - :param stage_name: Optional, the stage name string - :param stage_variables: Optional, API Gateway Stage Variables - :return: String representing the event - """ - # pylint: disable-msg=too-many-locals - - identity = ContextIdentity(source_ip=flask_request.remote_addr) - - endpoint = PathConverter.convert_path_to_api_gateway(flask_request.endpoint) - method = flask_request.method - protocol = flask_request.environ.get("SERVER_PROTOCOL", "HTTP/1.1") - host = flask_request.host - - request_data = flask_request.get_data() - - request_mimetype = flask_request.mimetype - - is_base_64 = LocalApigwService._should_base64_encode(binary_types, request_mimetype) - - if is_base_64: - LOG.debug("Incoming Request seems to be binary. Base64 encoding the request data before sending to Lambda.") - request_data = base64.b64encode(request_data) - - if request_data: - # Flask does not parse/decode the request data. We should do it ourselves - # Note(xinhol): here we change request_data's type from bytes to str and confused mypy - # We might want to consider to use a new variable here. - request_data = request_data.decode("utf-8") - - query_string_dict, multi_value_query_string_dict = LocalApigwService._query_string_params(flask_request) - - context = RequestContext( - resource_path=endpoint, - http_method=method, - stage=stage_name, - identity=identity, - path=endpoint, - protocol=protocol, - domain_name=host, - operation_name=operation_name, - ) - - headers_dict, multi_value_headers_dict = LocalApigwService._event_headers(flask_request, port) - - event = ApiGatewayLambdaEvent( - http_method=method, - body=request_data, - resource=endpoint, - request_context=context, - query_string_params=query_string_dict, - multi_value_query_string_params=multi_value_query_string_dict, - headers=headers_dict, - multi_value_headers=multi_value_headers_dict, - path_parameters=flask_request.view_args, - path=flask_request.path, - is_base_64_encoded=is_base_64, - stage_variables=stage_variables, - ) - - event_str = json.dumps(event.to_dict(), sort_keys=True) - LOG.debug("Constructed String representation of Event to invoke Lambda. Event: %s", event_str) - return event_str - - @staticmethod - def _construct_v_2_0_event_http( - flask_request, - port, - binary_types, - stage_name=None, - stage_variables=None, - route_key=None, - request_time_epoch=int(time()), - request_time=datetime.utcnow().strftime("%d/%b/%Y:%H:%M:%S +0000"), - ): - """ - Helper method that constructs the Event 2.0 to be passed to Lambda - - https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html - - :param request flask_request: Flask Request - :param port: the port number - :param binary_types: list of binary types - :param stage_name: Optional, the stage name string - :param stage_variables: Optional, API Gateway Stage Variables - :param route_key: Optional, the route key for the route - :return: String representing the event - """ - # pylint: disable-msg=too-many-locals - method = flask_request.method - - request_data = flask_request.get_data() - - request_mimetype = flask_request.mimetype - - is_base_64 = LocalApigwService._should_base64_encode(binary_types, request_mimetype) - - if is_base_64: - LOG.debug("Incoming Request seems to be binary. Base64 encoding the request data before sending to Lambda.") - request_data = base64.b64encode(request_data) - - if request_data is not None: - # Flask does not parse/decode the request data. We should do it ourselves - request_data = request_data.decode("utf-8") - - query_string_dict = LocalApigwService._query_string_params_v_2_0(flask_request) - - cookies = LocalApigwService._event_http_cookies(flask_request) - headers = LocalApigwService._event_http_headers(flask_request, port) - context_http = ContextHTTP(method=method, path=flask_request.path, source_ip=flask_request.remote_addr) - context = RequestContextV2( - http=context_http, - route_key=route_key, - stage=stage_name, - request_time_epoch=request_time_epoch, - request_time=request_time, - ) - - event = ApiGatewayV2LambdaEvent( - route_key=route_key, - raw_path=flask_request.path, - raw_query_string=flask_request.query_string.decode("utf-8"), - cookies=cookies, - headers=headers, - query_string_params=query_string_dict, - request_context=context, - body=request_data, - path_parameters=flask_request.view_args, - is_base_64_encoded=is_base_64, - stage_variables=stage_variables, - ) - - event_str = json.dumps(event.to_dict()) - LOG.debug("Constructed String representation of Event Version 2.0 to invoke Lambda. Event: %s", event_str) - return event_str - - @staticmethod - def _query_string_params(flask_request): - """ - Constructs an APIGW equivalent query string dictionary - - Parameters - ---------- - flask_request request - Request from Flask - - Returns dict (str: str), dict (str: list of str) - ------- - Empty dict if no query params where in the request otherwise returns a dictionary of key to value - - """ - query_string_dict = {} - multi_value_query_string_dict = {} - - # Flask returns an ImmutableMultiDict so convert to a dictionary that becomes - # a dict(str: list) then iterate over - for query_string_key, query_string_list in flask_request.args.lists(): - query_string_value_length = len(query_string_list) - - # if the list is empty, default to empty string - if not query_string_value_length: - query_string_dict[query_string_key] = "" - multi_value_query_string_dict[query_string_key] = [""] - else: - query_string_dict[query_string_key] = query_string_list[-1] - multi_value_query_string_dict[query_string_key] = query_string_list - - return query_string_dict, multi_value_query_string_dict - - @staticmethod - def _query_string_params_v_2_0(flask_request): - """ - Constructs an APIGW equivalent query string dictionary using the 2.0 format - https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html#2.0 - - Parameters - ---------- - flask_request request - Request from Flask - - Returns dict (str: str) - ------- - Empty dict if no query params where in the request otherwise returns a dictionary of key to value - - """ - query_string_dict = {} - - # Flask returns an ImmutableMultiDict so convert to a dictionary that becomes - # a dict(str: list) then iterate over - query_string_dict = { - query_string_key: ",".join(query_string_list) - for query_string_key, query_string_list in flask_request.args.lists() - } - - return query_string_dict - - @staticmethod - def _event_headers(flask_request, port): - """ - Constructs an APIGW equivalent headers dictionary - - Parameters - ---------- - flask_request request - Request from Flask - int port - Forwarded Port - cors_headers dict - Dict of the Cors properties - - Returns dict (str: str), dict (str: list of str) - ------- - Returns a dictionary of key to list of strings - - """ - headers_dict = {} - multi_value_headers_dict = {} - - # Multi-value request headers is not really supported by Flask. - # See https://github.com/pallets/flask/issues/850 - for header_key in flask_request.headers.keys(): - headers_dict[header_key] = flask_request.headers.get(header_key) - multi_value_headers_dict[header_key] = flask_request.headers.getlist(header_key) - - headers_dict["X-Forwarded-Proto"] = flask_request.scheme - multi_value_headers_dict["X-Forwarded-Proto"] = [flask_request.scheme] - - headers_dict["X-Forwarded-Port"] = str(port) - multi_value_headers_dict["X-Forwarded-Port"] = [str(port)] - return headers_dict, multi_value_headers_dict - - @staticmethod - def _event_http_cookies(flask_request): - """ - All cookie headers in the request are combined with commas. - - https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html - - Parameters - ---------- - flask_request request - Request from Flask - - Returns list - ------- - Returns a list of cookies - - """ - cookies = [] - for cookie_key in flask_request.cookies.keys(): - cookies.append("{}={}".format(cookie_key, flask_request.cookies.get(cookie_key))) - return cookies - - @staticmethod - def _event_http_headers(flask_request, port): - """ - Duplicate headers are combined with commas. - - https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html - - Parameters - ---------- - flask_request request - Request from Flask - - Returns list - ------- - Returns a list of cookies - - """ - headers = {} - # Multi-value request headers is not really supported by Flask. - # See https://github.com/pallets/flask/issues/850 - for header_key in flask_request.headers.keys(): - headers[header_key] = flask_request.headers.get(header_key) - - headers["X-Forwarded-Proto"] = flask_request.scheme - headers["X-Forwarded-Port"] = str(port) - return headers - - @staticmethod - def _should_base64_encode(binary_types, request_mimetype): - """ - Whether or not to encode the data from the request to Base64 - - Parameters - ---------- - binary_types list(basestring) - Corresponds to self.binary_types (aka. what is parsed from SAM Template - request_mimetype str - Mimetype for the request - - Returns - ------- - True if the data should be encoded to Base64 otherwise False - - """ - return request_mimetype in binary_types or "*/*" in binary_types diff --git a/samcli/local/apigw/route.py b/samcli/local/apigw/route.py new file mode 100644 index 0000000000..20858265d6 --- /dev/null +++ b/samcli/local/apigw/route.py @@ -0,0 +1,85 @@ +""" +Route definition for local start-api +""" +from typing import List, Optional + +from samcli.local.apigw.authorizers.authorizer import Authorizer + + +class Route: + API = "Api" + HTTP = "HttpApi" + ANY_HTTP_METHODS = ["GET", "DELETE", "PUT", "POST", "HEAD", "OPTIONS", "PATCH"] + + def __init__( + self, + function_name: Optional[str], + path: str, + methods: List[str], + event_type: str = API, + payload_format_version: Optional[str] = None, + is_default_route: bool = False, + operation_name=None, + stack_path: str = "", + authorizer_name: Optional[str] = None, + authorizer_object: Optional[Authorizer] = None, + use_default_authorizer: bool = True, + ): + """ + Creates an ApiGatewayRoute + + :param list(str) methods: http method + :param function_name: Name of the Lambda function this API is connected to + :param str path: Path off the base url + :param str event_type: Type of the event. "Api" or "HttpApi" + :param str payload_format_version: version of payload format + :param bool is_default_route: determines if the default route or not + :param string operation_name: Swagger operationId for the route + :param str stack_path: path of the stack the route is located + :param str authorizer_name: the authorizer this route is using, if any + :param Authorizer authorizer_object: the authorizer object this route is using, if any + :param bool use_default_authorizer: whether or not to use a default authorizer (if defined) + """ + self.methods = self.normalize_method(methods) + self.function_name = function_name + self.path = path + self.event_type = event_type + self.payload_format_version = payload_format_version + self.is_default_route = is_default_route + self.operation_name = operation_name + self.stack_path = stack_path + self.authorizer_name = authorizer_name + self.authorizer_object = authorizer_object + self.use_default_authorizer = use_default_authorizer + + def __eq__(self, other): + return ( + isinstance(other, Route) + and sorted(self.methods) == sorted(other.methods) + and self.function_name == other.function_name + and self.path == other.path + and self.operation_name == other.operation_name + and self.stack_path == other.stack_path + and self.authorizer_name == other.authorizer_name + and self.authorizer_object == other.authorizer_object + and self.use_default_authorizer == other.use_default_authorizer + ) + + def __hash__(self): + route_hash = hash(f"{self.stack_path}-{self.function_name}-{self.path}") + for method in sorted(self.methods): + route_hash *= hash(method) + return route_hash + + def normalize_method(self, methods): + """ + Normalizes Http Methods. Api Gateway allows a Http Methods of ANY. This is a special verb to denote all + supported Http Methods on Api Gateway. + + :param list methods: Http methods + :return list: Either the input http_method or one of the _ANY_HTTP_METHODS (normalized Http Methods) + """ + methods = [method.upper() for method in methods] + if "ANY" in methods: + return self.ANY_HTTP_METHODS + return methods diff --git a/samcli/local/apigw/service_error_responses.py b/samcli/local/apigw/service_error_responses.py index 9603aeca51..689b9172e9 100644 --- a/samcli/local/apigw/service_error_responses.py +++ b/samcli/local/apigw/service_error_responses.py @@ -1,16 +1,47 @@ """Class container to hold common Service Responses""" -from flask import jsonify, make_response +from flask import Response, jsonify, make_response class ServiceErrorResponses: _NO_LAMBDA_INTEGRATION = {"message": "No function defined for resource method"} _MISSING_AUTHENTICATION = {"message": "Missing Authentication Token"} _LAMBDA_FAILURE = {"message": "Internal server error"} + _MISSING_LAMBDA_AUTH_IDENTITY_SOURCES = {"message": "Unauthorized"} + _LAMBDA_AUTHORIZER_NOT_AUTHORIZED = {"message": "User is not authorized to access this resource"} HTTP_STATUS_CODE_501 = 501 HTTP_STATUS_CODE_502 = 502 HTTP_STATUS_CODE_403 = 403 + HTTP_STATUS_CODE_401 = 401 + + @staticmethod + def lambda_authorizer_unauthorized() -> Response: + """ + Constructs a Flask response for when a route invokes a Lambda Authorizer, but + is the identity sources provided are not authorized for that method + + Returns + ------- + Response + A Flask Response object + """ + response_data = jsonify(ServiceErrorResponses._LAMBDA_AUTHORIZER_NOT_AUTHORIZED) + return make_response(response_data, ServiceErrorResponses.HTTP_STATUS_CODE_403) + + @staticmethod + def missing_lambda_auth_identity_sources() -> Response: + """ + Constructs a Flask response for when a route contains a Lambda Authorizer + but is missing the required identity services + + Returns + ------- + Response + A Flask Response object + """ + response_data = jsonify(ServiceErrorResponses._MISSING_LAMBDA_AUTH_IDENTITY_SOURCES) + return make_response(response_data, ServiceErrorResponses.HTTP_STATUS_CODE_401) @staticmethod def lambda_failure_response(*args): diff --git a/samcli/local/docker/manager.py b/samcli/local/docker/manager.py index 5780852d02..50f7178021 100644 --- a/samcli/local/docker/manager.py +++ b/samcli/local/docker/manager.py @@ -142,7 +142,12 @@ def pull_image(self, image_name, tag=None, stream=None): If the Docker image was not available in the server """ if tag is None: - tag = image_name.split(":")[1] if ":" in image_name else "latest" + _image_name_split = image_name.split(":") + # Separate the image_name from the tag so less forgiving docker clones + # (podman) get the image name as the URL they expect. Official docker seems + # to clean this up internally. + tag = _image_name_split[1] if len(_image_name_split) > 1 else "latest" + image_name = _image_name_split[0] # use a global lock to get the image lock with self._lock: image_lock = self._lock_per_image.get(image_name) @@ -162,7 +167,7 @@ def pull_image(self, image_name, tag=None, stream=None): raise DockerImagePullFailedException(str(ex)) from ex # io streams, especially StringIO, work only with unicode strings - stream_writer.write("\nFetching {} Docker container image...".format(image_name)) + stream_writer.write("\nFetching {}:{} Docker container image...".format(image_name, tag)) # Each line contains information on progress of the pull. Each line is a JSON string for _ in result_itr: diff --git a/samcli/local/events/api_event.py b/samcli/local/events/api_event.py index 8779f6509f..1b82c7caea 100644 --- a/samcli/local/events/api_event.py +++ b/samcli/local/events/api_event.py @@ -2,6 +2,7 @@ import uuid from datetime import datetime from time import time +from typing import Any, Dict class ContextIdentity: @@ -120,7 +121,7 @@ def __init__( self.request_time = request_time self.operation_name = operation_name - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: """ Constructs an dictionary representation of the RequestContext Object to be used in serializing to JSON @@ -218,11 +219,14 @@ def __init__( self.path = path self.is_base_64_encoded = is_base_64_encoded - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: """ Constructs an dictionary representation of the ApiGatewayLambdaEvent Object to be used in serializing to JSON - :return: dict representing the object + Returns + ------- + Dict[str, Any] + Dict representing the object """ request_context_dict = {} if self.request_context: @@ -326,7 +330,7 @@ def __init__( self.domain_name = domain_name self.domain_prefix = domain_prefix - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: """ Constructs an dictionary representation of the RequestContext Version 2 Object to be used in serializing to JSON @@ -427,12 +431,15 @@ def __init__( self.is_base_64_encoded = is_base_64_encoded self.stage_variables = stage_variables - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: """ Constructs an dictionary representation of the ApiGatewayLambdaEvent Version 2 Object to be used in serializing to JSON - :return: dict representing the object + Returns + ------- + Dict[str, Any] + Dict representing the object """ request_context_dict = {} if self.request_context: diff --git a/samcli/local/lambdafn/env_vars.py b/samcli/local/lambdafn/env_vars.py index 0aa9237cf7..4aefb76bef 100644 --- a/samcli/local/lambdafn/env_vars.py +++ b/samcli/local/lambdafn/env_vars.py @@ -94,17 +94,19 @@ def resolve(self): # Default value for the variable gets lowest priority for name, value in self.variables.items(): + override_value = value + # Shell environment values, second priority if name in self.shell_env_values: - value = self.shell_env_values[name] + override_value = self.shell_env_values[name] # Overridden values, highest priority if name in self.override_values: - value = self.override_values[name] + override_value = self.override_values[name] # Any value must be a string when passed to Lambda runtime. # Runtime expects a Map for environment variables - result[name] = self._stringify_value(value) + result[name] = self._stringify_value(override_value) return result diff --git a/samcli/runtime_config.json b/samcli/runtime_config.json index 7920d35dd4..3d0c75ae2b 100644 --- a/samcli/runtime_config.json +++ b/samcli/runtime_config.json @@ -1,3 +1,3 @@ { - "app_template_repo_commit": "8bdd0c3897ada824175a53f7762ce0711a0596a8" + "app_template_repo_commit": "b4a2b2ee5d0dc2d03d1f65385fa8c21bafd097f3" } diff --git a/tests/integration/deploy/test_managed_stack_deploy.py b/tests/integration/deploy/test_managed_stack_deploy.py index 923353c8b3..5a74af639a 100644 --- a/tests/integration/deploy/test_managed_stack_deploy.py +++ b/tests/integration/deploy/test_managed_stack_deploy.py @@ -2,6 +2,7 @@ from unittest import skipIf import boto3 +import pytest from botocore.exceptions import ClientError from parameterized import parameterized @@ -16,7 +17,7 @@ # This is to restrict package tests to run outside of CI/CD, when the branch is not master or tests are not run by Canary SKIP_MANAGED_STACK_TESTS = RUNNING_ON_CI and RUNNING_TEST_FOR_MASTER_ON_CI and not RUN_BY_CANARY # Limits the managed stack tests to be run on a single python version to avoid CI race conditions -IS_TARGETTED_PYTHON_VERSION = PYTHON_VERSION.startswith("3.7") +IS_TARGETTED_PYTHON_VERSION = PYTHON_VERSION.startswith("3.8") CFN_PYTHON_VERSION_SUFFIX = PYTHON_VERSION.replace(".", "-") # Set region for managed stacks to be in a different region than the ones in deploy @@ -24,6 +25,7 @@ @skipIf(SKIP_MANAGED_STACK_TESTS or not IS_TARGETTED_PYTHON_VERSION, "Skip managed stack tests in CI/CD only") +@pytest.mark.xdist_group(name="managed_stack") class TestManagedStackDeploy(DeployIntegBase): def setUp(self): super().setUp() diff --git a/tests/integration/local/common_utils.py b/tests/integration/local/common_utils.py index 1d6fd14547..9ed98ed361 100644 --- a/tests/integration/local/common_utils.py +++ b/tests/integration/local/common_utils.py @@ -28,8 +28,13 @@ def wait_for_local_process(process, port, collect_output=False) -> str: if "Address already in use" in line_as_str: LOG.info(f"Attempted to start port on {port} but it is already in use, restarting on a new port.") raise InvalidAddressException() - if "Press CTRL+C to quit" in line_as_str or "Terraform Support beta feature is not enabled." in line_as_str: + if ( + "Press CTRL+C to quit" in line_as_str + or "Terraform Support beta feature is not enabled." in line_as_str + or "Error: " in line_as_str + ): break + return output diff --git a/tests/integration/local/start_api/lambda_authorizers/__init__.py b/tests/integration/local/start_api/lambda_authorizers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/local/start_api/lambda_authorizers/test_cfn_authorizer_definitions.py b/tests/integration/local/start_api/lambda_authorizers/test_cfn_authorizer_definitions.py new file mode 100644 index 0000000000..e97e27ef1b --- /dev/null +++ b/tests/integration/local/start_api/lambda_authorizers/test_cfn_authorizer_definitions.py @@ -0,0 +1,331 @@ +import pytest +import requests +from tests.integration.local.start_api.start_api_integ_base import ( + StartApiIntegBaseClass, + WritableStartApiIntegBaseClass, +) +from parameterized import parameterized_class + + +@parameterized_class( + ("template_path", "endpoint", "parameter_overrides"), + [ + ("/testdata/start_api/lambda_authorizers/cfn-apigw-v1.yaml", "requestauthorizertoken", {}), + ("/testdata/start_api/lambda_authorizers/cfn-apigw-v1.yaml", "requestauthorizerrequest", {}), + ( + "/testdata/start_api/lambda_authorizers/cfn-apigw-v2.yaml", + "requestauthorizerv2", + {"RoutePayloadFormatVersion": "2.0"}, + ), + ( + "/testdata/start_api/lambda_authorizers/cfn-apigw-v2.yaml", + "requestauthorizerv2", + {"RoutePayloadFormatVersion": "1.0"}, + ), + ( + "/testdata/start_api/lambda_authorizers/cfn-apigw-v2.yaml", + "requestauthorizerv2simple", + {"AuthHandler": "app.simple_handler", "RoutePayloadFormatVersion": "2.0"}, + ), + ( + "/testdata/start_api/lambda_authorizers/cfn-apigw-v2.yaml", + "requestauthorizerv2simple", + {"AuthHandler": "app.simple_handler", "RoutePayloadFormatVersion": "1.0"}, + ), + ( + "/testdata/start_api/lambda_authorizers/cfn-apigw-v2.yaml", + "requestauthorizerv1", + {"RoutePayloadFormatVersion": "2.0"}, + ), + ( + "/testdata/start_api/lambda_authorizers/cfn-apigw-v2.yaml", + "requestauthorizerv1", + {"RoutePayloadFormatVersion": "1.0"}, + ), + ], +) +class TestCfnLambdaAuthorizerResources(StartApiIntegBaseClass): + def setUp(self): + self.url = f"http://127.0.0.1:{self.port}" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=600, method="thread") + def test_invokes_authorizer(self): + headers = {"header": "myheader"} + query_string = {"query": "myquery"} + + response = requests.get(f"{self.url}/{self.endpoint}", headers=headers, params=query_string, timeout=300) + response_json = response.json() + + self.assertEqual(response.status_code, 200) + # check if the authorizer passes along a message + self.assertEqual(response_json, {"message": "from authorizer"}) + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=600, method="thread") + def test_missing_identity_sources(self): + response = requests.get(f"{self.url}/{self.endpoint}", timeout=300) + + response_json = response.json() + self.assertEqual(response.status_code, 401) + self.assertEqual(response_json, {"message": "Unauthorized"}) + + +@parameterized_class( + ("template_path", "endpoint", "parameter_overrides"), + [ + ( + "/testdata/start_api/lambda_authorizers/cfn-apigw-v1.yaml", + "requestauthorizertoken", + {"AuthHandler": "app.unauth"}, + ), + ( + "/testdata/start_api/lambda_authorizers/cfn-apigw-v1.yaml", + "requestauthorizerrequest", + {"AuthHandler": "app.unauth"}, + ), + ( + "/testdata/start_api/lambda_authorizers/cfn-apigw-v2.yaml", + "requestauthorizerv2", + {"AuthHandler": "app.unauth"}, + ), + ( + "/testdata/start_api/lambda_authorizers/cfn-apigw-v2.yaml", + "requestauthorizerv2simple", + {"AuthHandler": "app.unauthv2"}, + ), + ( + "/testdata/start_api/lambda_authorizers/cfn-apigw-v2.yaml", + "requestauthorizerv1", + {"AuthHandler": "app.unauth"}, + ), + ], +) +class TestCfnLambdaAuthorizersUnauthorized(StartApiIntegBaseClass): + def setUp(self): + self.url = f"http://127.0.0.1:{self.port}" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=600, method="thread") + def test_unauthorized_request(self): + headers = {"header": "myheader"} + query_string = {"query": "myquery"} + + response = requests.get(f"{self.url}/{self.endpoint}", headers=headers, params=query_string, timeout=300) + response_json = response.json() + + self.assertEqual(response.status_code, 403) + self.assertEqual(response_json, {"message": "User is not authorized to access this resource"}) + + +@parameterized_class( + ("template_path", "endpoint"), + [ + ( + "/testdata/start_api/lambda_authorizers/cfn-apigw-v1.yaml", + "requestauthorizertoken", + ), + ( + "/testdata/start_api/lambda_authorizers/cfn-apigw-v1.yaml", + "requestauthorizerrequest", + ), + ( + "/testdata/start_api/lambda_authorizers/cfn-apigw-v2.yaml", + "requestauthorizerv2", + ), + ( + "/testdata/start_api/lambda_authorizers/cfn-apigw-v2.yaml", + "requestauthorizerv2simple", + ), + ( + "/testdata/start_api/lambda_authorizers/cfn-apigw-v2.yaml", + "requestauthorizerv1", + ), + ], +) +class TestCfnLambdaAuthorizer500(StartApiIntegBaseClass): + parameter_overrides = {"AuthHandler": "app.throws_exception"} + + def setUp(self): + self.url = f"http://127.0.0.1:{self.port}" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=600, method="thread") + def test_authorizer_raises_exception(self): + headers = {"header": "myheader"} + query_string = {"query": "myquery"} + + response = requests.get(f"{self.url}/{self.endpoint}", headers=headers, params=query_string, timeout=300) + response_json = response.json() + + self.assertEqual(response.status_code, 502) + self.assertEqual(response_json, {"message": "Internal server error"}) + + +class TestInvalidApiTemplateUsingUnsupportedType(WritableStartApiIntegBaseClass): + """ + Test using an invalid Type for an Authorizer + """ + + do_collect_cmd_init_output = True + + template_content = """AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Resources: + RequestAuthorizer: + Type: AWS::ApiGateway::Authorizer + Properties: + AuthorizerUri: arn:aws:apigateway:123:lambda:path/2015-03-31/functions/arn/invocations + Type: notvalid + IdentitySource: "method.request.header.header, method.request.querystring.query" + Name: RequestAuthorizer + RestApiId: !Ref RestApiLambdaAuth +""" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=10, method="thread") + def test_invalid_template(self): + self.assertIn( + "Authorizer 'RequestAuthorizer' with type 'notvalid' is currently not supported. " + "Only Lambda Authorizers of type TOKEN and REQUEST are supported.", + self.start_api_process_output, + ) + + +class TestInvalidHttpTemplateUsingIncorrectPayloadVersion(WritableStartApiIntegBaseClass): + """ + Test using an invalid AuthorizerPayloadFormatVersion for an Authorizer + """ + + do_collect_cmd_init_output = True + + template_content = """AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Resources: + RequestAuthorizerV2Simple: + Type: AWS::ApiGatewayV2::Authorizer + Properties: + AuthorizerPayloadFormatVersion: "3.0" + EnableSimpleResponses: false + AuthorizerType: REQUEST + AuthorizerUri: arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/arn:aws:lambda:us-east-1:123456789012:function:AuthorizerFunction/invocations + IdentitySource: + - "$request.header.header" + - "$request.querystring.query" + Name: RequestAuthorizerV2Simple + ApiId: !Ref HttpLambdaAuth +""" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=10, method="thread") + def test_invalid_template(self): + self.assertIn( + "Error: Lambda Authorizer 'RequestAuthorizerV2Simple' contains an " + "invalid 'AuthorizerPayloadFormatVersion', it must be set to '1.0' or '2.0'", + self.start_api_process_output, + ) + + +class TestInvalidHttpTemplateSimpleResponseWithV1(WritableStartApiIntegBaseClass): + """ + Test using simple responses with V1 format version + """ + + do_collect_cmd_init_output = True + + template_content = """AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Resources: + RequestAuthorizerV2Simple: + Type: AWS::ApiGatewayV2::Authorizer + Properties: + AuthorizerPayloadFormatVersion: "1.0" + EnableSimpleResponses: true + AuthorizerType: REQUEST + AuthorizerUri: arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/arn:aws:lambda:us-east-1:123456789012:function:AuthorizerFunction/invocations + IdentitySource: + - "$request.header.header" + - "$request.querystring.query" + Name: RequestAuthorizerV2Simple + ApiId: !Ref HttpLambdaAuth +""" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=10, method="thread") + def test_invalid_template(self): + self.assertIn( + "Error: 'EnableSimpleResponses' is only supported for '2.0' " + "payload format versions for Lambda Authorizer 'RequestAuthorizerV2Simple'.", + self.start_api_process_output, + ) + + +class TestInvalidHttpTemplateUnsupportedType(WritableStartApiIntegBaseClass): + """ + Test using an invalid Type for HttpApi + """ + + do_collect_cmd_init_output = True + + template_content = """AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Resources: + RequestAuthorizerV2Simple: + Type: AWS::ApiGatewayV2::Authorizer + Properties: + AuthorizerPayloadFormatVersion: "1.0" + EnableSimpleResponses: false + AuthorizerType: unsupportedtype + AuthorizerUri: arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/arn:aws:lambda:us-east-1:123456789012:function:AuthorizerFunction/invocations + IdentitySource: + - "$request.header.header" + - "$request.querystring.query" + Name: RequestAuthorizerV2Simple + ApiId: !Ref HttpLambdaAuth +""" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=10, method="thread") + def test_invalid_template(self): + self.assertIn( + "Authorizer 'RequestAuthorizerV2Simple' with type 'unsupportedtype' is currently " + "not supported. Only Lambda Authorizers of type REQUEST are supported for API Gateway V2.", + self.start_api_process_output, + ) + + +class TestInvalidHttpTemplateInvalidIdentitySources(WritableStartApiIntegBaseClass): + """ + Test using an invalid identity source + """ + + do_collect_cmd_init_output = True + + template_content = """AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Resources: + RequestAuthorizerV2Simple: + Type: AWS::ApiGatewayV2::Authorizer + Properties: + AuthorizerPayloadFormatVersion: "1.0" + EnableSimpleResponses: false + AuthorizerType: REQUEST + AuthorizerUri: arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/arn:aws:lambda:us-east-1:123456789012:function:AuthorizerFunction/invocations + IdentitySource: + - "hello.world.this.is.invalid" + Name: RequestAuthorizerV2Simple + ApiId: !Ref HttpLambdaAuth +""" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=10, method="thread") + def test_invalid_template(self): + self.assertIn( + "Error: Lambda Authorizer RequestAuthorizerV2Simple does not contain valid identity sources.", + self.start_api_process_output, + ) diff --git a/tests/integration/local/start_api/lambda_authorizers/test_sfn_props_lambda_authorizers.py b/tests/integration/local/start_api/lambda_authorizers/test_sfn_props_lambda_authorizers.py new file mode 100644 index 0000000000..7e5c10f3e4 --- /dev/null +++ b/tests/integration/local/start_api/lambda_authorizers/test_sfn_props_lambda_authorizers.py @@ -0,0 +1,269 @@ +import pytest +import requests +from tests.integration.local.start_api.start_api_integ_base import ( + StartApiIntegBaseClass, + WritableStartApiIntegBaseClass, +) +from parameterized import parameterized_class + + +@parameterized_class( + ("parameter_overrides", "template_path"), + [ + ({"AuthOverride": "RequestAuthorizerV2"}, "/testdata/start_api/lambda_authorizers/serverless-http-props.yaml"), + ( + {"AuthOverride": "RequestAuthorizerV2Simple"}, + "/testdata/start_api/lambda_authorizers/serverless-http-props.yaml", + ), + ({"AuthOverride": "RequestAuthorizerV1"}, "/testdata/start_api/lambda_authorizers/serverless-http-props.yaml"), + ({"AuthOverride": "Token"}, "/testdata/start_api/lambda_authorizers/serverless-api-props.yaml"), + ({"AuthOverride": "Request"}, "/testdata/start_api/lambda_authorizers/serverless-api-props.yaml"), + ], +) +class TestSfnPropertiesLambdaAuthorizers(StartApiIntegBaseClass): + def setUp(self): + self.url = f"http://127.0.0.1:{self.port}" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=600, method="thread") + def test_invokes_authorizer(self): + headers = {"header": "myheader"} + query = {"query": "myquery"} + response = requests.get(f"{self.url}/requestauthorizer", headers=headers, params=query, timeout=300) + + response_json = response.json() + self.assertEqual(response.status_code, 200) + # check if the authorizer passes along a message + self.assertEqual(response_json, {"message": "from authorizer"}) + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=600, method="thread") + def test_missing_identity_sources(self): + response = requests.get(f"{self.url}/requestauthorizer", timeout=300) + + response_json = response.json() + self.assertEqual(response.status_code, 401) + self.assertEqual(response_json, {"message": "Unauthorized"}) + + +@parameterized_class( + ("parameter_overrides", "template_path"), + [ + ( + {"AuthHandler": "app.unauth", "AuthOverride": "RequestAuthorizerV1"}, + "/testdata/start_api/lambda_authorizers/serverless-http-props.yaml", + ), + ( + {"AuthSimpleHandler": "app.unauthv2", "AuthOverride": "RequestAuthorizerV2Simple"}, + "/testdata/start_api/lambda_authorizers/serverless-http-props.yaml", + ), + ( + {"AuthHandler": "app.unauth", "AuthOverride": "Token"}, + "/testdata/start_api/lambda_authorizers/serverless-api-props.yaml", + ), + ( + {"AuthHandler": "app.unauth", "AuthOverride": "Request"}, + "/testdata/start_api/lambda_authorizers/serverless-api-props.yaml", + ), + ], +) +class TestSfnPropertiesLambdaAuthorizersUnauthorized(StartApiIntegBaseClass): + def setUp(self): + self.url = f"http://127.0.0.1:{self.port}" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=600, method="thread") + def test_unauthorized_request(self): + headers = {"header": "myheader"} + query = {"query": "myquery"} + + response = requests.get(f"{self.url}/requestauthorizer", headers=headers, params=query, timeout=300) + response_json = response.json() + + self.assertEqual(response.status_code, 403) + self.assertEqual(response_json, {"message": "User is not authorized to access this resource"}) + + +@parameterized_class( + ("parameter_overrides", "template_path"), + [ + ( + {"AuthHandler": "app.throws_exception", "AuthOverride": "RequestAuthorizerV1"}, + "/testdata/start_api/lambda_authorizers/serverless-http-props.yaml", + ), + ( + {"AuthSimpleHandler": "app.throws_exception", "AuthOverride": "RequestAuthorizerV2Simple"}, + "/testdata/start_api/lambda_authorizers/serverless-http-props.yaml", + ), + ( + {"AuthHandler": "app.throws_exception", "AuthOverride": "Token"}, + "/testdata/start_api/lambda_authorizers/serverless-api-props.yaml", + ), + ( + {"AuthHandler": "app.throws_exception", "AuthOverride": "Request"}, + "/testdata/start_api/lambda_authorizers/serverless-api-props.yaml", + ), + ], +) +class TestSfnPropertiesLambdaAuthorizer500(StartApiIntegBaseClass): + def setUp(self): + self.url = f"http://127.0.0.1:{self.port}" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=600, method="thread") + def test_unauthorized_request(self): + headers = {"header": "myheader"} + query = {"query": "myquery"} + + response = requests.get(f"{self.url}/requestauthorizer", headers=headers, params=query, timeout=300) + response_json = response.json() + + self.assertEqual(response.status_code, 502) + self.assertEqual(response_json, {"message": "Internal server error"}) + + +class TestUsingSimpleResponseWithV1HttpApi(WritableStartApiIntegBaseClass): + do_collect_cmd_init_output = True + template_content = """AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Resources: + TestServerlessHttpApi: + Type: AWS::Serverless::HttpApi + Properties: + StageName: http + Auth: + DefaultAuthorizer: RequestAuthorizerV2 + Authorizers: + RequestAuthorizerV2: + AuthorizerPayloadFormatVersion: "1.0" + EnableSimpleResponses: true + FunctionArn: !GetAtt AuthorizerFunction.Arn + Identity: + Headers: + - header + QueryStrings: + - query + HelloWorldFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./ + Handler: app.lambda_handler + Events: + ApiEvent: + Type: HttpApi + Properties: + Path: /requestauthorizer + Method: get + ApiId: !Ref TestServerlessHttpApi +""" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=10, method="thread") + def test_invalid_template(self): + self.assertIn( + "EnableSimpleResponses must be used with the 2.0 payload " + "format version in Lambda Authorizer 'RequestAuthorizerV2'.", + self.start_api_process_output, + ) + + +class TestInvalidInvalidVersionHttpApi(WritableStartApiIntegBaseClass): + """ + Test using an invalid AuthorizerPayloadFormatVersion property value + when defining a Lambda Authorizer in the a Serverless resource properties. + """ + + do_collect_cmd_init_output = True + template_content = """AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Resources: + TestServerlessHttpApi: + Type: AWS::Serverless::HttpApi + Properties: + StageName: http + Auth: + DefaultAuthorizer: RequestAuthorizerV2 + Authorizers: + RequestAuthorizerV2: + AuthorizerPayloadFormatVersion: "3.0" + EnableSimpleResponses: false + FunctionArn: !GetAtt AuthorizerFunction.Arn + Identity: + Headers: + - header + QueryStrings: + - query + HelloWorldFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./ + Handler: app.lambda_handler + Events: + ApiEvent: + Type: HttpApi + Properties: + Path: /requestauthorizer + Method: get + ApiId: !Ref TestServerlessHttpApi +""" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=10, method="thread") + def test_invalid_template(self): + self.assertIn( + "Error: Lambda Authorizer 'RequestAuthorizerV2' must contain " + "a valid 'AuthorizerPayloadFormatVersion' for HTTP APIs.", + self.start_api_process_output, + ) + + +class TestUsingInvalidFunctionArnHttpApi(WritableStartApiIntegBaseClass): + """ + Test using an invalid FunctionArn property value when defining + a Lambda Authorizer in the a Serverless resource properties. + """ + + do_collect_cmd_init_output = True + template_content = """AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Resources: + TestServerlessHttpApi: + Type: AWS::Serverless::HttpApi + Properties: + StageName: http + Auth: + DefaultAuthorizer: RequestAuthorizerV2 + Authorizers: + RequestAuthorizerV2: + AuthorizerPayloadFormatVersion: "2.0" + EnableSimpleResponses: false + FunctionArn: iofaqio'hfw;iqauh + Identity: + Headers: + - header + QueryStrings: + - query + HelloWorldFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./ + Handler: app.lambda_handler + Events: + ApiEvent: + Type: HttpApi + Properties: + Path: /requestauthorizer + Method: get + ApiId: !Ref TestServerlessHttpApi +""" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=10, method="thread") + def test_invalid_template(self): + self.assertIn( + "Unable to parse the Lambda ARN for Authorizer 'RequestAuthorizerV2', skipping", + self.start_api_process_output, + ) diff --git a/tests/integration/local/start_api/lambda_authorizers/test_swagger_authorizer_definitions.py b/tests/integration/local/start_api/lambda_authorizers/test_swagger_authorizer_definitions.py new file mode 100644 index 0000000000..befe814a40 --- /dev/null +++ b/tests/integration/local/start_api/lambda_authorizers/test_swagger_authorizer_definitions.py @@ -0,0 +1,332 @@ +import pytest +import requests +from tests.integration.local.start_api.start_api_integ_base import ( + StartApiIntegBaseClass, + WritableStartApiIntegBaseClass, +) +from parameterized import parameterized_class + + +@parameterized_class( + ("template_path", "endpoint", "parameter_overrides"), + [ + ("/testdata/start_api/lambda_authorizers/swagger-api.yaml", "requestauthorizerswaggertoken", {}), + ( + "/testdata/start_api/lambda_authorizers/swagger-api.yaml", + "requestauthorizerswaggertoken", + {"ValidationString": "^myheader$"}, + ), + ("/testdata/start_api/lambda_authorizers/swagger-api.yaml", "requestauthorizerswaggerrequest", {}), + ("/testdata/start_api/lambda_authorizers/swagger-api.yaml", "requestauthorizeropenapi", {}), + ("/testdata/start_api/lambda_authorizers/swagger-http.yaml", "requestauthorizer", {}), + ( + "/testdata/start_api/lambda_authorizers/swagger-http.yaml", + "requestauthorizersimple", + {"AuthHandler": "app.simple_handler"}, + ), + ], +) +class TestSwaggerLambdaAuthorizerResources(StartApiIntegBaseClass): + def setUp(self): + self.url = f"http://127.0.0.1:{self.port}" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=600, method="thread") + def test_invokes_authorizer(self): + headers = {"header": "myheader"} + query_string = {"query": "myquery"} + + response = requests.get(f"{self.url}/{self.endpoint}", headers=headers, params=query_string, timeout=300) + response_json = response.json() + + self.assertEqual(response.status_code, 200) + # check if the authorizer passes along a message + self.assertEqual(response_json, {"message": "from authorizer"}) + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=600, method="thread") + def test_missing_identity_sources(self): + response = requests.get(f"{self.url}/{self.endpoint}", timeout=300) + + response_json = response.json() + self.assertEqual(response.status_code, 401) + self.assertEqual(response_json, {"message": "Unauthorized"}) + + +@parameterized_class( + ("template_path", "endpoint", "parameter_overrides"), + [ + ( + "/testdata/start_api/lambda_authorizers/swagger-api.yaml", + "requestauthorizerswaggertoken", + {"AuthHandler": "app.unauth"}, + ), + ( + "/testdata/start_api/lambda_authorizers/swagger-api.yaml", + "requestauthorizerswaggerrequest", + {"AuthHandler": "app.unauth"}, + ), + ( + "/testdata/start_api/lambda_authorizers/swagger-api.yaml", + "requestauthorizeropenapi", + {"AuthHandler": "app.unauth"}, + ), + ( + "/testdata/start_api/lambda_authorizers/swagger-http.yaml", + "requestauthorizer", + {"AuthHandler": "app.unauth"}, + ), + ( + "/testdata/start_api/lambda_authorizers/swagger-http.yaml", + "requestauthorizersimple", + {"AuthHandler": "app.unauthv2"}, + ), + ], +) +class TestSwaggerLambdaAuthorizersUnauthorized(StartApiIntegBaseClass): + def setUp(self): + self.url = f"http://127.0.0.1:{self.port}" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=600, method="thread") + def test_unauthorized_request(self): + headers = {"header": "myheader"} + query_string = {"query": "myquery"} + + response = requests.get(f"{self.url}/{self.endpoint}", headers=headers, params=query_string, timeout=300) + response_json = response.json() + + self.assertEqual(response.status_code, 403) + self.assertEqual(response_json, {"message": "User is not authorized to access this resource"}) + + +@parameterized_class( + ("template_path", "endpoint"), + [ + ("/testdata/start_api/lambda_authorizers/swagger-api.yaml", "requestauthorizerswaggertoken"), + ("/testdata/start_api/lambda_authorizers/swagger-api.yaml", "requestauthorizerswaggerrequest"), + ("/testdata/start_api/lambda_authorizers/swagger-api.yaml", "requestauthorizeropenapi"), + ("/testdata/start_api/lambda_authorizers/swagger-http.yaml", "requestauthorizer"), + ("/testdata/start_api/lambda_authorizers/swagger-http.yaml", "requestauthorizersimple"), + ], +) +class TestSwaggerLambdaAuthorizer500(StartApiIntegBaseClass): + parameter_overrides = {"AuthHandler": "app.throws_exception"} + + def setUp(self): + self.url = f"http://127.0.0.1:{self.port}" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=600, method="thread") + def test_authorizer_raises_exception(self): + headers = {"header": "myheader"} + query_string = {"query": "myquery"} + + response = requests.get(f"{self.url}/{self.endpoint}", headers=headers, params=query_string, timeout=300) + response_json = response.json() + + self.assertEqual(response.status_code, 502) + self.assertEqual(response_json, {"message": "Internal server error"}) + + +class TestInvalidSwaggerTemplateUsingUnsupportedType(WritableStartApiIntegBaseClass): + """ + Test using an invalid Lambda authorizer type + """ + + do_collect_cmd_init_output = True + + template_content = """AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Resources: + HttpApiOpenApi: + Type: AWS::ApiGatewayV2::Api + Properties: + Body: + openapi: "3.0" + info: + title: HttpApiOpenApi + components: + securitySchemes: + Authorizer: + type: apiKey + in: header + name: notused + "x-amazon-apigateway-authorizer": + authorizerPayloadFormatVersion: "2.0" + type: "bad type" + identitySource: "$request.header.header, $request.querystring.query" + authorizerUri: arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/arn:aws:lambda:us-east-1:123456789012:function:AuthorizerFunction/invocations +""" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=10, method="thread") + def test_invalid_template(self): + self.assertIn( + "Lambda authorizer 'Authorizer' type 'bad type' is unsupported, skipping", + self.start_api_process_output, + ) + + +class TestInvalidSwaggerTemplateUsingSimpleResponseWithPayloadV1(WritableStartApiIntegBaseClass): + """ + Test using simple response with wrong payload version + """ + + do_collect_cmd_init_output = True + + template_content = """AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Resources: + HttpApiOpenApi: + Type: AWS::ApiGatewayV2::Api + Properties: + Body: + openapi: "3.0" + info: + title: HttpApiOpenApi + components: + securitySchemes: + Authorizer: + type: apiKey + in: header + name: notused + "x-amazon-apigateway-authorizer": + authorizerPayloadFormatVersion: "1.0" + type: "request" + enableSimpleResponses: True + identitySource: "$request.header.header, $request.querystring.query" + authorizerUri: arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/arn:aws:lambda:us-east-1:123456789012:function:AuthorizerFunction/invocations +""" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=10, method="thread") + def test_invalid_template(self): + self.assertIn( + "Simple responses are only available on HTTP APIs with " + "payload version 2.0, ignoring for Lambda authorizer 'Authorizer'", + self.start_api_process_output, + ) + + +class TestInvalidSwaggerTemplateUsingUnsupportedPayloadVersion(WritableStartApiIntegBaseClass): + """ + Test using an incorrect payload format version + """ + + do_collect_cmd_init_output = True + + template_content = """AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Resources: + HttpApiOpenApi: + Type: AWS::ApiGatewayV2::Api + Properties: + Body: + openapi: "3.0" + info: + title: HttpApiOpenApi + components: + securitySchemes: + Authorizer: + type: apiKey + in: header + name: notused + "x-amazon-apigateway-authorizer": + authorizerPayloadFormatVersion: "1.2.3" + type: "request" + identitySource: "$request.header.header, $request.querystring.query" + authorizerUri: arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/arn:aws:lambda:us-east-1:123456789012:function:AuthorizerFunction/invocations +""" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=10, method="thread") + def test_invalid_template(self): + self.assertIn( + "Error: Authorizer 'Authorizer' contains an invalid payload version", + self.start_api_process_output, + ) + + +class TestInvalidSwaggerTemplateUsingInvalidIdentitySources(WritableStartApiIntegBaseClass): + """ + Test using an invalid identity source (a.b.c.d.e) + """ + + do_collect_cmd_init_output = True + + template_content = """AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Resources: + HttpApiOpenApi: + Type: AWS::ApiGatewayV2::Api + Properties: + Body: + openapi: "3.0" + info: + title: HttpApiOpenApi + components: + securitySchemes: + Authorizer: + type: apiKey + in: header + name: notused + "x-amazon-apigateway-authorizer": + authorizerPayloadFormatVersion: "2.0" + type: "request" + identitySource: "a.b.c.d.e" + authorizerUri: arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/arn:aws:lambda:us-east-1:123456789012:function:AuthorizerFunction/invocations +""" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=10, method="thread") + def test_invalid_template(self): + self.assertIn( + "Error: Identity source 'a.b.c.d.e' for Lambda Authorizer " + "'Authorizer' is not a valid identity source, check the spelling/format.", + self.start_api_process_output, + ) + + +class TestInvalidSwaggerTemplateUsingTokenWithHttpApi(WritableStartApiIntegBaseClass): + """ + Test using token authorizer with HTTP API + """ + + do_collect_cmd_init_output = True + + template_content = """AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Resources: + HttpApiOpenApi: + Type: AWS::ApiGatewayV2::Api + Properties: + Body: + openapi: "3.0" + info: + title: HttpApiOpenApi + components: + securitySchemes: + Authorizer: + type: apiKey + in: header + name: notused + "x-amazon-apigateway-authorizer": + authorizerPayloadFormatVersion: "2.0" + type: "token" + identitySource: "$request.header.header" + authorizerUri: arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/arn:aws:lambda:us-east-1:123456789012:function:AuthorizerFunction/invocations +""" + + @pytest.mark.flaky(reruns=3) + @pytest.mark.timeout(timeout=10, method="thread") + def test_invalid_template(self): + self.assertIn( + "Type 'token' for Lambda Authorizer 'Authorizer' is unsupported", + self.start_api_process_output, + ) diff --git a/tests/integration/local/start_api/start_api_integ_base.py b/tests/integration/local/start_api/start_api_integ_base.py index a74de20c15..b1f9a6a785 100644 --- a/tests/integration/local/start_api/start_api_integ_base.py +++ b/tests/integration/local/start_api/start_api_integ_base.py @@ -31,6 +31,8 @@ class StartApiIntegBaseClass(TestCase): build_before_invoke = False build_overrides: Optional[Dict[str, str]] = None + do_collect_cmd_init_output: bool = False + @classmethod def setUpClass(cls): # This is the directory for tests/integration which will be used to file the testdata @@ -97,9 +99,10 @@ def start_api(cls): for image in cls.invoke_image: command_list += ["--invoke-image", image] - cls.start_api_process = Popen(command_list, stderr=PIPE) - - wait_for_local_process(cls.start_api_process, cls.port) + cls.start_api_process = Popen(command_list, stderr=PIPE, stdout=PIPE) + cls.start_api_process_output = wait_for_local_process( + cls.start_api_process, cls.port, collect_output=cls.do_collect_cmd_init_output + ) cls.stop_reading_thread = False @@ -129,12 +132,16 @@ def get_binary_data(filename): return fp.read() -class WatchWarmContainersIntegBaseClass(StartApiIntegBaseClass): +class WritableStartApiIntegBaseClass(StartApiIntegBaseClass): temp_path: Optional[str] = None template_path: Optional[str] = None code_path: Optional[str] = None docker_file_path: Optional[str] = None + template_content: Optional[str] = None + code_content: Optional[str] = None + docker_file_content: Optional[str] = None + @classmethod def setUpClass(cls): cls.temp_path = str(uuid.uuid4()).replace("-", "")[:10] diff --git a/tests/integration/local/start_api/test_start_api.py b/tests/integration/local/start_api/test_start_api.py index 3f0242565a..225cfb96a5 100644 --- a/tests/integration/local/start_api/test_start_api.py +++ b/tests/integration/local/start_api/test_start_api.py @@ -14,8 +14,8 @@ from parameterized import parameterized_class from samcli.commands.local.cli_common.invoke_context import ContainersInitializationMode -from samcli.local.apigw.local_apigw_service import Route -from .start_api_integ_base import StartApiIntegBaseClass, WatchWarmContainersIntegBaseClass +from samcli.local.apigw.route import Route +from .start_api_integ_base import StartApiIntegBaseClass, WritableStartApiIntegBaseClass from ..invoke.layer_utils import LayerUtils @@ -2225,7 +2225,7 @@ def test_can_invoke_lambda_function_successfully(self): self.assertEqual(response.json(), {"hello": "world"}) -class TestWatchingZipWarmContainers(WatchWarmContainersIntegBaseClass): +class TestWatchingZipWarmContainers(WritableStartApiIntegBaseClass): template_content = """AWSTemplateFormatVersion : '2010-09-09' Transform: AWS::Serverless-2016-10-31 Resources: @@ -2274,7 +2274,7 @@ def test_changed_code_got_observed_and_loaded(self): self.assertEqual(response.json(), {"hello": "world2"}) -class TestWatchingTemplateChangesLambdaFunctionHandlerChanged(WatchWarmContainersIntegBaseClass): +class TestWatchingTemplateChangesLambdaFunctionHandlerChanged(WritableStartApiIntegBaseClass): template_content = """AWSTemplateFormatVersion : '2010-09-09' Transform: AWS::Serverless-2016-10-31 Resources: @@ -2340,7 +2340,7 @@ def test_changed_code_got_observed_and_loaded(self): self.assertEqual(response.json(), {"hello": "world2"}) -class TestWatchingTemplateChangesLambdaFunctionCodeUriChanged(WatchWarmContainersIntegBaseClass): +class TestWatchingTemplateChangesLambdaFunctionCodeUriChanged(WritableStartApiIntegBaseClass): template_content = """AWSTemplateFormatVersion : '2010-09-09' Transform: AWS::Serverless-2016-10-31 Resources: @@ -2409,7 +2409,7 @@ def test_changed_code_got_observed_and_loaded(self): self.assertEqual(response.json(), {"hello": "world2"}) -class TestWatchingImageWarmContainers(WatchWarmContainersIntegBaseClass): +class TestWatchingImageWarmContainers(WritableStartApiIntegBaseClass): template_content = """AWSTemplateFormatVersion : '2010-09-09' Transform: AWS::Serverless-2016-10-31 Parameters: @@ -2473,7 +2473,7 @@ def test_changed_code_got_observed_and_loaded(self): self.assertEqual(response.json(), {"hello": "world2"}) -class TestWatchingTemplateChangesImageDockerFileChangedLocation(WatchWarmContainersIntegBaseClass): +class TestWatchingTemplateChangesImageDockerFileChangedLocation(WritableStartApiIntegBaseClass): template_content = """AWSTemplateFormatVersion : '2010-09-09' Transform: AWS::Serverless-2016-10-31 Parameters: @@ -2567,7 +2567,7 @@ def test_changed_code_got_observed_and_loaded(self): self.assertEqual(response.json(), {"hello": "world2"}) -class TestWatchingZipLazyContainers(WatchWarmContainersIntegBaseClass): +class TestWatchingZipLazyContainers(WritableStartApiIntegBaseClass): template_content = """AWSTemplateFormatVersion : '2010-09-09' Transform: AWS::Serverless-2016-10-31 Resources: @@ -2616,7 +2616,7 @@ def test_changed_code_got_observed_and_loaded(self): self.assertEqual(response.json(), {"hello": "world2"}) -class TestWatchingImageLazyContainers(WatchWarmContainersIntegBaseClass): +class TestWatchingImageLazyContainers(WritableStartApiIntegBaseClass): template_content = """AWSTemplateFormatVersion : '2010-09-09' Transform: AWS::Serverless-2016-10-31 Parameters: @@ -2680,7 +2680,7 @@ def test_changed_code_got_observed_and_loaded(self): self.assertEqual(response.json(), {"hello": "world2"}) -class TestWatchingTemplateChangesLambdaFunctionHandlerChangedLazyContainer(WatchWarmContainersIntegBaseClass): +class TestWatchingTemplateChangesLambdaFunctionHandlerChangedLazyContainer(WritableStartApiIntegBaseClass): template_content = """AWSTemplateFormatVersion : '2010-09-09' Transform: AWS::Serverless-2016-10-31 Resources: @@ -2746,7 +2746,7 @@ def test_changed_code_got_observed_and_loaded(self): self.assertEqual(response.json(), {"hello": "world2"}) -class TestWatchingTemplateChangesLambdaFunctionCodeUriChangedLazyContainers(WatchWarmContainersIntegBaseClass): +class TestWatchingTemplateChangesLambdaFunctionCodeUriChangedLazyContainers(WritableStartApiIntegBaseClass): template_content = """AWSTemplateFormatVersion : '2010-09-09' Transform: AWS::Serverless-2016-10-31 Resources: @@ -2815,7 +2815,7 @@ def test_changed_code_got_observed_and_loaded(self): self.assertEqual(response.json(), {"hello": "world2"}) -class TestWatchingTemplateChangesImageDockerFileChangedLocationLazyContainers(WatchWarmContainersIntegBaseClass): +class TestWatchingTemplateChangesImageDockerFileChangedLocationLazyContainers(WritableStartApiIntegBaseClass): template_content = """AWSTemplateFormatVersion : '2010-09-09' Transform: AWS::Serverless-2016-10-31 Parameters: diff --git a/tests/integration/testdata/start_api/lambda_authorizers/app.py b/tests/integration/testdata/start_api/lambda_authorizers/app.py new file mode 100644 index 0000000000..bdd5936302 --- /dev/null +++ b/tests/integration/testdata/start_api/lambda_authorizers/app.py @@ -0,0 +1,265 @@ +# https://github.com/awslabs/aws-apigateway-lambda-authorizer-blueprints/blob/1e79ad02a4dcbbd0fe2951cf9a5de4aff7915823/blueprints/python/api-gateway-authorizer-python.py +# parts of this file is sourced from the above link + +""" +Copyright 2015-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at + + http://aws.amazon.com/apache2.0/ + +or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. +""" + +import json +import re + +def lambda_handler(event, context): + authorizer_context = event.get("requestContext", {}).get("authorizer", {}) + + # assume APIGW V1, search for passed message under "authorizer" + message = authorizer_context.get("passed") + + if not message: + # this may be V2, search under "authorizer" -> "lambda" + message = authorizer_context.get("lambda", {}).get("passed") + + return { + "statusCode": 200, + "body": json.dumps({ + "message": message, + }), + } + +def simple_handler(event, context): + return { + "isAuthorized": True, + "context": { + 'number' : 1, + 'bool' : True, + 'passed': 'from authorizer' + } + } + +def throws_exception(event, context): + raise Exception() + +def unauth(event, context): + principalId = "user|a1b2c3d4" + arn = event.get("methodArn") or event.get("routeArn") + tmp = arn.split(':') + awsAccountId = tmp[4] + + policy = AuthPolicy(principalId, awsAccountId) + policy.denyAllMethods() + + authResponse = policy.build() + return authResponse + +def unauthv2(event, context): + return { + "isAuthorized": False, + "context": { + 'number' : 1, + 'bool' : True + } + } + +def auth_handler(event, context): + principalId = "user|a1b2c3d4" + arn = event.get("methodArn") or event.get("routeArn") + tmp = arn.split(':') + apiGatewayArnTmp = tmp[5].split('/') + awsAccountId = tmp[4] + + policy = AuthPolicy(principalId, awsAccountId) + policy.restApiId = apiGatewayArnTmp[0] + policy.region = tmp[3] + policy.stage = apiGatewayArnTmp[1] + # policy.denyAllMethods() + policy.allowAllMethods() + # policy.allowMethod(HttpVerb.GET, "/hello/world/*") + + # Finally, build the policy + authResponse = policy.build() + + # new! -- add additional key-value pairs associated with the authenticated principal + # these are made available by APIGW like so: $context.authorizer. + # additional context is cached + context = { + # "key": str(auth_source), # $context.authorizer.key -> value + 'number' : 1, + 'bool' : True, + 'passed': 'from authorizer' + } + # context['arr'] = ['foo'] <- this is invalid, APIGW will not accept it + # context['obj'] = {'foo':'bar'} <- also invalid + + authResponse['context'] = context + return authResponse + +class HttpVerb: + GET = "GET" + POST = "POST" + PUT = "PUT" + PATCH = "PATCH" + HEAD = "HEAD" + DELETE = "DELETE" + OPTIONS = "OPTIONS" + ALL = "*" + +class AuthPolicy(object): + awsAccountId = "" + """The AWS account id the policy will be generated for. This is used to create the method ARNs.""" + principalId = "" + """The principal used for the policy, this should be a unique identifier for the end user.""" + version = "2012-10-17" + """The policy version used for the evaluation. This should always be '2012-10-17'""" + pathRegex = "^[/.a-zA-Z0-9-\*]+$" + """The regular expression used to validate resource paths for the policy""" + + """these are the internal lists of allowed and denied methods. These are lists + of objects and each object has 2 properties: A resource ARN and a nullable + conditions statement. + the build method processes these lists and generates the approriate + statements for the final policy""" + allowMethods = [] + denyMethods = [] + + + restApiId = "<>" + """ Replace the placeholder value with a default API Gateway API id to be used in the policy. + Beware of using '*' since it will not simply mean any API Gateway API id, because stars will greedily expand over '/' or other separators. + See https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements_resource.html for more details. """ + + region = "<>" + """ Replace the placeholder value with a default region to be used in the policy. + Beware of using '*' since it will not simply mean any region, because stars will greedily expand over '/' or other separators. + See https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements_resource.html for more details. """ + + stage = "<>" + """ Replace the placeholder value with a default stage to be used in the policy. + Beware of using '*' since it will not simply mean any stage, because stars will greedily expand over '/' or other separators. + See https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements_resource.html for more details. """ + + def __init__(self, principal, awsAccountId): + self.awsAccountId = awsAccountId + self.principalId = principal + self.allowMethods = [] + self.denyMethods = [] + + def _addMethod(self, effect, verb, resource, conditions): + """Adds a method to the internal lists of allowed or denied methods. Each object in + the internal list contains a resource ARN and a condition statement. The condition + statement can be null.""" + if verb != "*" and not hasattr(HttpVerb, verb): + raise NameError("Invalid HTTP verb " + verb + ". Allowed verbs in HttpVerb class") + resourcePattern = re.compile(self.pathRegex) + if not resourcePattern.match(resource): + raise NameError("Invalid resource path: " + resource + ". Path should match " + self.pathRegex) + + if resource[:1] == "/": + resource = resource[1:] + + resourceArn = ("arn:aws:execute-api:" + + self.region + ":" + + self.awsAccountId + ":" + + self.restApiId + "/" + + self.stage + "/" + + verb + "/" + + resource) + + if effect.lower() == "allow": + self.allowMethods.append({ + 'resourceArn' : resourceArn, + 'conditions' : conditions + }) + elif effect.lower() == "deny": + self.denyMethods.append({ + 'resourceArn' : resourceArn, + 'conditions' : conditions + }) + + def _getEmptyStatement(self, effect): + """Returns an empty statement object prepopulated with the correct action and the + desired effect.""" + statement = { + 'Action': 'execute-api:Invoke', + 'Effect': effect[:1].upper() + effect[1:].lower(), + 'Resource': [] + } + + return statement + + def _getStatementForEffect(self, effect, methods): + """This function loops over an array of objects containing a resourceArn and + conditions statement and generates the array of statements for the policy.""" + statements = [] + + if len(methods) > 0: + statement = self._getEmptyStatement(effect) + + for curMethod in methods: + if curMethod['conditions'] is None or len(curMethod['conditions']) == 0: + statement['Resource'].append(curMethod['resourceArn']) + else: + conditionalStatement = self._getEmptyStatement(effect) + conditionalStatement['Resource'].append(curMethod['resourceArn']) + conditionalStatement['Condition'] = curMethod['conditions'] + statements.append(conditionalStatement) + + statements.append(statement) + + return statements + + def allowAllMethods(self): + """Adds a '*' allow to the policy to authorize access to all methods of an API""" + self._addMethod("Allow", HttpVerb.ALL, "*", []) + + def denyAllMethods(self): + """Adds a '*' allow to the policy to deny access to all methods of an API""" + self._addMethod("Deny", HttpVerb.ALL, "*", []) + + def allowMethod(self, verb, resource): + """Adds an API Gateway method (Http verb + Resource path) to the list of allowed + methods for the policy""" + self._addMethod("Allow", verb, resource, []) + + def denyMethod(self, verb, resource): + """Adds an API Gateway method (Http verb + Resource path) to the list of denied + methods for the policy""" + self._addMethod("Deny", verb, resource, []) + + def allowMethodWithConditions(self, verb, resource, conditions): + """Adds an API Gateway method (Http verb + Resource path) to the list of allowed + methods and includes a condition for the policy statement. More on AWS policy + conditions here: http://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition""" + self._addMethod("Allow", verb, resource, conditions) + + def denyMethodWithConditions(self, verb, resource, conditions): + """Adds an API Gateway method (Http verb + Resource path) to the list of denied + methods and includes a condition for the policy statement. More on AWS policy + conditions here: http://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition""" + self._addMethod("Deny", verb, resource, conditions) + + def build(self): + """Generates the policy document based on the internal lists of allowed and denied + conditions. This will generate a policy with two main statements for the effect: + one statement for Allow and one statement for Deny. + Methods that includes conditions will have their own statement in the policy.""" + if ((self.allowMethods is None or len(self.allowMethods) == 0) and + (self.denyMethods is None or len(self.denyMethods) == 0)): + raise NameError("No statements defined for the policy") + + policy = { + 'principalId' : self.principalId, + 'policyDocument' : { + 'Version' : self.version, + 'Statement' : [] + } + } + + policy['policyDocument']['Statement'].extend(self._getStatementForEffect("Allow", self.allowMethods)) + policy['policyDocument']['Statement'].extend(self._getStatementForEffect("Deny", self.denyMethods)) + + return policy diff --git a/tests/integration/testdata/start_api/lambda_authorizers/cfn-apigw-v1.yaml b/tests/integration/testdata/start_api/lambda_authorizers/cfn-apigw-v1.yaml new file mode 100644 index 0000000000..6ac6fcd834 --- /dev/null +++ b/tests/integration/testdata/start_api/lambda_authorizers/cfn-apigw-v1.yaml @@ -0,0 +1,130 @@ +AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Parameters: + AuthHandler: + Type: String + Default: app.auth_handler + +Globals: + Function: + Timeout: 3 + MemorySize: 128 + +Resources: + ## + # APIGW + # + RestApiLambdaAuth: + Type: AWS::ApiGateway::RestApi + Properties: + Name: restapi + ## + # hello world lambda function + # + HelloWorldFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./ + Handler: app.lambda_handler + Runtime: python3.8 + Architectures: + - x86_64 + HelloWorldFunctionPermission: + Type: AWS::Lambda::Permission + Properties: + FunctionName: !Ref HelloWorldFunction + Action: lambda:InvokeFunction + Principal: apigateway.amazonaws.com + ## + # authorizer lambda function + # + AuthorizerFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./ + Handler: !Ref AuthHandler + Runtime: python3.8 + Architectures: + - x86_64 + AuthorizerPermission: + Type: AWS::Lambda::Permission + Properties: + FunctionName: !Ref AuthorizerFunction + Action: lambda:InvokeFunction + Principal: apigateway.amazonaws.com + ## + # token based authorizer definition + # + TokenAuthorizer: + Type: AWS::ApiGateway::Authorizer + Properties: + AuthorizerUri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${AuthorizerFunction.Arn}/invocations + Type: TOKEN + IdentitySource: method.request.header.header + Name: TokenAuthorizer + RestApiId: !Ref RestApiLambdaAuth + ## + # request based authorizer definition + # + RequestAuthorizer: + Type: AWS::ApiGateway::Authorizer + Properties: + AuthorizerUri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${AuthorizerFunction.Arn}/invocations + Type: REQUEST + IdentitySource: "method.request.header.header, method.request.querystring.query" + Name: RequestAuthorizer + RestApiId: !Ref RestApiLambdaAuth + ## + # hello world endpoint using token auth + # + HelloWorldResourceToken: + Type: AWS::ApiGateway::Resource + Properties: + RestApiId: !Ref RestApiLambdaAuth + ParentId: !GetAtt RestApiLambdaAuth.RootResourceId + PathPart: requestauthorizertoken + HelloWorldMethodToken: + Type: AWS::ApiGateway::Method + Properties: + RestApiId: !Ref RestApiLambdaAuth + ResourceId: !Ref HelloWorldResourceToken + HttpMethod: GET + AuthorizationType: CUSTOM + AuthorizerId: !Ref TokenAuthorizer + Integration: + Type: AWS_PROXY + IntegrationHttpMethod: POST + Uri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${HelloWorldFunction.Arn}/invocations + ## + # hello world endpoint using request auth + # + HelloWorldResourceRequest: + Type: AWS::ApiGateway::Resource + Properties: + RestApiId: !Ref RestApiLambdaAuth + ParentId: !GetAtt RestApiLambdaAuth.RootResourceId + PathPart: requestauthorizerrequest + HelloWorldMethodRequest: + Type: AWS::ApiGateway::Method + Properties: + RestApiId: !Ref RestApiLambdaAuth + ResourceId: !Ref HelloWorldResourceRequest + HttpMethod: GET + AuthorizationType: CUSTOM + AuthorizerId: !Ref RequestAuthorizer + Integration: + Type: AWS_PROXY + IntegrationHttpMethod: POST + Uri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${HelloWorldFunction.Arn}/invocations + ## + # APIGW deployment + # + Deployment: + DependsOn: + - HelloWorldMethodToken + - HelloWorldMethodRequest + Type: AWS::ApiGateway::Deployment + Properties: + RestApiId: !Ref RestApiLambdaAuth + StageName: prod \ No newline at end of file diff --git a/tests/integration/testdata/start_api/lambda_authorizers/cfn-apigw-v2.yaml b/tests/integration/testdata/start_api/lambda_authorizers/cfn-apigw-v2.yaml new file mode 100644 index 0000000000..09b1611471 --- /dev/null +++ b/tests/integration/testdata/start_api/lambda_authorizers/cfn-apigw-v2.yaml @@ -0,0 +1,174 @@ +AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Parameters: + AuthHandler: + Type: String + Default: app.auth_handler + RoutePayloadFormatVersion: + Type: String + Default: "2.0" + +Globals: + Function: + Timeout: 3 + MemorySize: 128 + +Resources: + ## + # APIGW + # + HttpLambdaAuth: + Type: AWS::ApiGatewayV2::Api + Properties: + Name: http + ProtocolType: HTTP + ## + # hello world lambda function + # + HelloWorldFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./ + Handler: app.lambda_handler + Runtime: python3.8 + Architectures: + - x86_64 + HelloWorldFunctionPermission: + Type: AWS::Lambda::Permission + Properties: + FunctionName: !Ref HelloWorldFunction + Action: lambda:InvokeFunction + Principal: apigateway.amazonaws.com + ## + # authorizer lambda function + # + AuthorizerFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./ + Handler: !Ref AuthHandler + Runtime: python3.8 + Architectures: + - x86_64 + AuthorizerPermission: + Type: AWS::Lambda::Permission + Properties: + FunctionName: !Ref AuthorizerFunction + Action: lambda:InvokeFunction + Principal: apigateway.amazonaws.com + ## + # request based authorizer definition v2 + # + RequestAuthorizerV2: + Type: AWS::ApiGatewayV2::Authorizer + Properties: + AuthorizerPayloadFormatVersion: "2.0" + EnableSimpleResponses: false + AuthorizerType: REQUEST + AuthorizerUri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${AuthorizerFunction.Arn}/invocations + IdentitySource: + - "$request.header.header" + - "$request.querystring.query" + Name: RequestAuthorizerV2 + ApiId: !Ref HttpLambdaAuth + ## + # request based authorizer definition v2 (simple response) + # + RequestAuthorizerV2Simple: + Type: AWS::ApiGatewayV2::Authorizer + Properties: + AuthorizerPayloadFormatVersion: "2.0" + EnableSimpleResponses: true + AuthorizerType: REQUEST + AuthorizerUri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${AuthorizerFunction.Arn}/invocations + IdentitySource: + - "$request.header.header" + - "$request.querystring.query" + Name: RequestAuthorizerV2Simple + ApiId: !Ref HttpLambdaAuth + ## + # request based authorizer definition v1 + # + RequestAuthorizerV1: + Type: AWS::ApiGatewayV2::Authorizer + Properties: + AuthorizerPayloadFormatVersion: "1.0" + EnableSimpleResponses: false + AuthorizerType: REQUEST + AuthorizerUri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${AuthorizerFunction.Arn}/invocations + IdentitySource: + - "$request.header.header" + - "$request.querystring.query" + Name: RequestAuthorizerV1 + ApiId: !Ref HttpLambdaAuth + ## + # route definition using v2 authorizer + # + HelloWorldRouteV2: + Type: AWS::ApiGatewayV2::Route + Properties: + ApiId: !Ref HttpLambdaAuth + RouteKey: "GET /requestauthorizerv2" + AuthorizationType: CUSTOM + AuthorizerId: !Ref RequestAuthorizerV2 + Target: !Join + - / + - - integrations + - !Ref HelloWorldIntegration + ## + # route definition using v2 simple authorizer + # + HelloWorldRouteV2Simple: + Type: AWS::ApiGatewayV2::Route + Properties: + ApiId: !Ref HttpLambdaAuth + RouteKey: "GET /requestauthorizerv2simple" + AuthorizationType: CUSTOM + AuthorizerId: !Ref RequestAuthorizerV2Simple + Target: !Join + - / + - - integrations + - !Ref HelloWorldIntegration + ## + # route definition using v1 authorizer + # + HelloWorldRouteV1: + Type: AWS::ApiGatewayV2::Route + Properties: + ApiId: !Ref HttpLambdaAuth + RouteKey: "GET /requestauthorizerv1" + AuthorizationType: CUSTOM + AuthorizerId: !Ref RequestAuthorizerV1 + Target: !Join + - / + - - integrations + - !Ref HelloWorldIntegration + ## + # deployment + # + Stage: + Type: AWS::ApiGatewayV2::Stage + Properties: + StageName: prod + DeploymentId: !Ref Deployment + ApiId: !Ref HttpLambdaAuth + Deployment: + Type: AWS::ApiGatewayV2::Deployment + DependsOn: + - HelloWorldRouteV2 + - HelloWorldRouteV2Simple + - HelloWorldRouteV1 + Properties: + ApiId: !Ref HttpLambdaAuth + ## + # lambda integration + # + HelloWorldIntegration: + Type: AWS::ApiGatewayV2::Integration + Properties: + PayloadFormatVersion: !Ref RoutePayloadFormatVersion + ApiId: !Ref HttpLambdaAuth + IntegrationType: AWS_PROXY + IntegrationMethod: POST + IntegrationUri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${HelloWorldFunction.Arn}/invocations \ No newline at end of file diff --git a/tests/integration/testdata/start_api/lambda_authorizers/serverless-api-props.yaml b/tests/integration/testdata/start_api/lambda_authorizers/serverless-api-props.yaml new file mode 100644 index 0000000000..721d53c383 --- /dev/null +++ b/tests/integration/testdata/start_api/lambda_authorizers/serverless-api-props.yaml @@ -0,0 +1,68 @@ +AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Globals: + Function: + Timeout: 3 + MemorySize: 128 + +Parameters: + AuthOverride: + Default: Token + Type: String + AuthHandler: + Default: app.auth_handler + Type: String + +Resources: + TestServerlessRestApi: + Type: AWS::Serverless::Api + Properties: + StageName: api + Auth: + DefaultAuthorizer: Token + Authorizers: + Token: + FunctionPayloadType: TOKEN + FunctionArn: !GetAtt AuthorizerFunction.Arn + Identity: + Header: header + Request: + FunctionPayloadType: REQUEST + FunctionArn: !GetAtt AuthorizerFunction.Arn + Identity: + Headers: + - header + QueryStrings: + - query + HelloWorldFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./ + Handler: app.lambda_handler + Runtime: python3.8 + Architectures: + - x86_64 + Events: + ApiEvent: + Type: Api + Properties: + Auth: + Authorizer: !Ref AuthOverride + Path: /requestauthorizer + Method: get + RestApiId: !Ref TestServerlessRestApi + AuthorizerFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./ + Handler: !Ref AuthHandler + Runtime: python3.8 + Architectures: + - x86_64 + AuthorizerPermission: + Type: AWS::Lambda::Permission + Properties: + FunctionName: !Ref AuthorizerFunction + Action: lambda:InvokeFunction + Principal: apigateway.amazonaws.com \ No newline at end of file diff --git a/tests/integration/testdata/start_api/lambda_authorizers/serverless-http-props.yaml b/tests/integration/testdata/start_api/lambda_authorizers/serverless-http-props.yaml new file mode 100644 index 0000000000..95a0331d49 --- /dev/null +++ b/tests/integration/testdata/start_api/lambda_authorizers/serverless-http-props.yaml @@ -0,0 +1,99 @@ +AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Globals: + Function: + Timeout: 3 + MemorySize: 128 + +Parameters: + AuthOverride: + Default: RequestAuthorizerV2 + Type: String + AuthHandler: + Default: app.auth_handler + Type: String + AuthSimpleHandler: + Default: app.simple_handler + Type: String + +Resources: + TestServerlessHttpApi: + Type: AWS::Serverless::HttpApi + Properties: + StageName: http + Auth: + DefaultAuthorizer: RequestAuthorizerV2 + Authorizers: + RequestAuthorizerV2: + AuthorizerPayloadFormatVersion: "2.0" + EnableSimpleResponses: false + FunctionArn: !GetAtt AuthorizerFunction.Arn + Identity: + Headers: + - header + QueryStrings: + - query + RequestAuthorizerV1: + AuthorizerPayloadFormatVersion: "1.0" + EnableSimpleResponses: false + FunctionArn: !GetAtt AuthorizerFunction.Arn + Identity: + Headers: + - header + QueryStrings: + - query + RequestAuthorizerV2Simple: + AuthorizerPayloadFormatVersion: "2.0" + EnableSimpleResponses: true + FunctionArn: !GetAtt AuthorizerFunctionSimple.Arn + Identity: + Headers: + - header + QueryStrings: + - query + HelloWorldFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./ + Handler: app.lambda_handler + Runtime: python3.8 + Architectures: + - x86_64 + Events: + ApiEvent: + Type: HttpApi + Properties: + Auth: + Authorizer: !Ref AuthOverride + Path: /requestauthorizer + Method: get + ApiId: !Ref TestServerlessHttpApi + AuthorizerFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./ + Handler: !Ref AuthHandler + Runtime: python3.8 + Architectures: + - x86_64 + AuthorizerPermission: + Type: AWS::Lambda::Permission + Properties: + FunctionName: !Ref AuthorizerFunction + Action: lambda:InvokeFunction + Principal: apigateway.amazonaws.com + AuthorizerFunctionSimple: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./ + Handler: !Ref AuthSimpleHandler + Runtime: python3.8 + Architectures: + - x86_64 + AuthorizerSimplePermission: + Type: AWS::Lambda::Permission + Properties: + FunctionName: !Ref AuthorizerFunctionSimple + Action: lambda:InvokeFunction + Principal: apigateway.amazonaws.com \ No newline at end of file diff --git a/tests/integration/testdata/start_api/lambda_authorizers/swagger-api.yaml b/tests/integration/testdata/start_api/lambda_authorizers/swagger-api.yaml new file mode 100644 index 0000000000..b838cc371f --- /dev/null +++ b/tests/integration/testdata/start_api/lambda_authorizers/swagger-api.yaml @@ -0,0 +1,131 @@ +AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Parameters: + AuthHandler: + Type: String + Default: app.auth_handler + ValidationString: + Type: String + Default: "" + +Resources: + ## + # Swagger definition within a REST API + # + RestApiSwagger: + Type: AWS::ApiGateway::RestApi + Properties: + Body: + swagger: "2.0" + info: + title: RestApiSwagger + securityDefinitions: + ## + # request based authorizer + # + ApiKeyAuthRequest: + type: apiKey + in: header + name: notused + "x-amazon-apigateway-authtype": "custom" + "x-amazon-apigateway-authorizer": + type: "request" + identitySource: "method.request.header.header, method.request.querystring.query" + authorizerUri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${AuthorizerFunction.Arn}/invocations + ## + # token based authorizer + # + ApiKeyAuthToken: + type: apiKey + in: header + name: header + "x-amazon-apigateway-authtype": "custom" + "x-amazon-apigateway-authorizer": + type: "token" + identityValidationExpression: !Ref ValidationString + authorizerUri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${AuthorizerFunction.Arn}/invocations + paths: + "/requestauthorizerswaggerrequest": + get: + security: + - ApiKeyAuthRequest: [] + x-amazon-apigateway-integration: + httpMethod: POST + type: aws_proxy + uri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${HelloWorldFunction.Arn}/invocations + "/requestauthorizerswaggertoken": + get: + security: + - ApiKeyAuthToken: [] + x-amazon-apigateway-integration: + httpMethod: POST + type: aws_proxy + uri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${HelloWorldFunction.Arn}/invocations + Name: RestApiSwagger + ## + # OpenAPI definitino with a REST API + # + RestApiOpenApi: + Type: AWS::ApiGateway::RestApi + Properties: + Body: + openapi: "3.0" + info: + title: RestApiOpenApi + components: + securitySchemes: + ApiKeyAuth: + type: apiKey + in: header + name: Auth + "x-amazon-apigateway-authtype": "custom" + "x-amazon-apigateway-authorizer": + type: "request" + identitySource: "method.request.header.header, method.request.querystring.query" + authorizerUri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${AuthorizerFunction.Arn}/invocations + paths: + "/requestauthorizeropenapi": + get: + security: + - ApiKeyAuth: [] + x-amazon-apigateway-integration: + httpMethod: POST + type: aws_proxy + uri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${HelloWorldFunction.Arn}/invocations + Name: RestApiOpenApi + ## + # Hello world function an execute permission + # + HelloWorldFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./ + Handler: app.lambda_handler + Runtime: python3.8 + Architectures: + - x86_64 + HelloWorldFunctionPermission: + Type: AWS::Lambda::Permission + Properties: + FunctionName: !Ref HelloWorldFunction + Action: lambda:InvokeFunction + Principal: apigateway.amazonaws.com + ## + # Authorizer function and execute permission + # + AuthorizerFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./ + Handler: !Ref AuthHandler + # Handler: app.unauth + Runtime: python3.8 + Architectures: + - x86_64 + AuthorizerFunctionPermission: + Type: AWS::Lambda::Permission + Properties: + FunctionName: !Ref AuthorizerFunction + Action: lambda:InvokeFunction + Principal: apigateway.amazonaws.com \ No newline at end of file diff --git a/tests/integration/testdata/start_api/lambda_authorizers/swagger-http.yaml b/tests/integration/testdata/start_api/lambda_authorizers/swagger-http.yaml new file mode 100644 index 0000000000..52dbaee3e2 --- /dev/null +++ b/tests/integration/testdata/start_api/lambda_authorizers/swagger-http.yaml @@ -0,0 +1,93 @@ +AWSTemplateFormatVersion: '2010-09-09' +Transform: AWS::Serverless-2016-10-31 + +Parameters: + AuthHandler: + Type: String + Default: app.auth_handler + +Resources: + ## + # OpenAPI definition with a HTTP API + # + HttpApiOpenApi: + Type: AWS::ApiGatewayV2::Api + Properties: + Body: + openapi: "3.0" + info: + title: HttpApiOpenApi + components: + securitySchemes: + RegularAuth: + type: apiKey + in: header + name: notused + "x-amazon-apigateway-authorizer": + authorizerPayloadFormatVersion: "2.0" + type: "request" + identitySource: "$request.header.header, $request.querystring.query" + authorizerUri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${AuthorizerFunction.Arn}/invocations + SimpleAuth: + type: apiKey + in: header + name: notused + "x-amazon-apigateway-authorizer": + authorizerPayloadFormatVersion: "2.0" + enableSimpleResponses: True + type: "request" + identitySource: "$request.header.header, $request.querystring.query" + authorizerUri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${AuthorizerFunction.Arn}/invocations + paths: + "/requestauthorizer": + get: + security: + - RegularAuth: [] + x-amazon-apigateway-integration: + payloadFormatVersion: "2.0" + httpMethod: POST + type: aws_proxy + uri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${HelloWorldFunction.Arn}/invocations + "/requestauthorizersimple": + get: + security: + - SimpleAuth: [] + x-amazon-apigateway-integration: + payloadFormatVersion: "2.0" + httpMethod: POST + type: aws_proxy + uri: !Sub arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${HelloWorldFunction.Arn}/invocations + ## + # Hello world function an execute permission + # + HelloWorldFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./ + Handler: app.lambda_handler + Runtime: python3.8 + Architectures: + - x86_64 + HelloWorldFunctionPermission: + Type: AWS::Lambda::Permission + Properties: + FunctionName: !Ref HelloWorldFunction + Action: lambda:InvokeFunction + Principal: apigateway.amazonaws.com + ## + # Authorizer function and execute permission + # + AuthorizerFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./ + Handler: !Ref AuthHandler + Runtime: python3.8 + Architectures: + - x86_64 + AuthorizerFunctionPermission: + Type: AWS::Lambda::Permission + Properties: + FunctionName: !Ref AuthorizerFunction + Action: lambda:InvokeFunction + Principal: apigateway.amazonaws.com \ No newline at end of file diff --git a/tests/unit/commands/init/core/__init__.py b/tests/unit/commands/init/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/commands/init/core/test_command.py b/tests/unit/commands/init/core/test_command.py new file mode 100644 index 0000000000..6a65394145 --- /dev/null +++ b/tests/unit/commands/init/core/test_command.py @@ -0,0 +1,69 @@ +import unittest +from unittest.mock import Mock, patch +from samcli.commands.init.core.command import InitCommand +from samcli.commands.init.command import DESCRIPTION +from tests.unit.cli.test_command import MockFormatter + + +class MockParams: + def __init__(self, rv, name): + self.rv = rv + self.name = name + + def get_help_record(self, ctx): + return self.rv + + +class TestInitCommand(unittest.TestCase): + @patch.object(InitCommand, "get_params") + def test_get_options_init_command_text(self, mock_get_params): + ctx = Mock() + ctx.command_path = "sam init" + ctx.parent.command_path = "sam" + formatter = MockFormatter(scrub_text=True) + # NOTE(sriram-mv): One option per option section. + mock_get_params.return_value = [ + MockParams(rv=("--name", "Application"), name="name"), + MockParams(rv=("--no-interactive", ""), name="no_interactive"), + MockParams(rv=("--config-file", ""), name="config_file"), + MockParams(rv=("--tracing", ""), name="tracing"), + MockParams(rv=("--debug", ""), name="debug"), + ] + + cmd = InitCommand(name="init", requires_credentials=False, description=DESCRIPTION) + expected_output = { + "Additional Options": [("", ""), ("--tracing", ""), ("", "")], + "Application Options": [("", ""), ("--name", ""), ("", "")], + "Configuration Options": [("", ""), ("--config-file", ""), ("", "")], + "Customized Interactive Mode": [ + ("", ""), + ("$ sam init --name sam-app --runtime " "nodejs18.x --architecture arm64\x1b[0m", ""), + ( + "$ sam init --name sam-app --runtime " + "nodejs18.x --dependency-manager npm " + "--app-template hello-world\x1b[0m", + "", + ), + ("$ sam init --name sam-app --package-type " "image --architecture arm64\x1b[0m", ""), + ], + "Description": [(cmd.description + cmd.description_addendum, "")], + "Direct Initialization": [ + ("", ""), + ("$ sam init --location " "gh:aws-samples/cookiecutter-aws-sam-python\x1b[0m", ""), + ( + "$ sam init --location " + "git+ssh://git@github.com/aws-samples/cookiecutter-aws-sam-python.git\x1b[0m", + "", + ), + ("$ sam init --location " "/path/to/template.zip\x1b[0m", ""), + ("$ sam init --location " "/path/to/template/directory\x1b[0m", ""), + ("$ sam init --location " "https://example.com/path/to/template.zip\x1b[0m", ""), + ], + "Examples": [], + "Interactive Mode": [("", ""), ("$ sam init\x1b[0m", "")], + "Non Interactive Options": [("", ""), ("--no-interactive", ""), ("", "")], + "Other Options": [("", ""), ("--debug", ""), ("", "")], + } + + cmd.format_options(ctx, formatter) + self.assertEqual(formatter.data, expected_output) diff --git a/tests/unit/commands/init/core/test_formatter.py b/tests/unit/commands/init/core/test_formatter.py new file mode 100644 index 0000000000..c518cdf016 --- /dev/null +++ b/tests/unit/commands/init/core/test_formatter.py @@ -0,0 +1,12 @@ +from shutil import get_terminal_size +from unittest import TestCase + +from samcli.cli.row_modifiers import BaseLineRowModifier +from samcli.commands.init.core.formatters import InitCommandHelpTextFormatter + + +class TestInitCommandHelpTextFormatter(TestCase): + def test_init_formatter(self): + self.formatter = InitCommandHelpTextFormatter() + self.assertTrue(self.formatter.left_justification_length <= get_terminal_size().columns // 2) + self.assertIsInstance(self.formatter.modifiers[0], BaseLineRowModifier) diff --git a/tests/unit/commands/init/core/test_options.py b/tests/unit/commands/init/core/test_options.py new file mode 100644 index 0000000000..69c740f536 --- /dev/null +++ b/tests/unit/commands/init/core/test_options.py @@ -0,0 +1,11 @@ +from unittest import TestCase + +from samcli.commands.init.command import cli +from samcli.commands.init.core.options import ALL_OPTIONS + + +class TestOptions(TestCase): + def test_all_options_formatted(self): + command_options = [param.human_readable_name for param in cli.params] + command_options = [command_option for command_option in command_options if command_option is not None] + self.assertEqual(sorted(ALL_OPTIONS), sorted(command_options)) diff --git a/tests/unit/commands/local/lib/swagger/test_parser.py b/tests/unit/commands/local/lib/swagger/test_parser.py index 7b9ea5758d..84f194923e 100644 --- a/tests/unit/commands/local/lib/swagger/test_parser.py +++ b/tests/unit/commands/local/lib/swagger/test_parser.py @@ -3,11 +3,18 @@ """ from unittest import TestCase -from unittest.mock import patch, Mock +from unittest.mock import ANY, patch, Mock from parameterized import parameterized, param from samcli.commands.local.lib.swagger.parser import SwaggerParser -from samcli.local.apigw.local_apigw_service import Route +from samcli.local.apigw.exceptions import ( + IncorrectOasWithDefaultAuthorizerException, + InvalidOasVersion, + InvalidSecurityDefinition, + MultipleAuthorizerException, +) +from samcli.local.apigw.route import Route +from samcli.local.apigw.authorizers.lambda_authorizer import LambdaAuthorizer class TestSwaggerParser_get_apis(TestCase): @@ -24,7 +31,14 @@ def test_with_one_path_method(self): parser._get_integration_function_name = Mock() parser._get_integration_function_name.return_value = function_name - expected = [Route(path="/path1", methods=["get"], function_name=function_name, stack_path=self.stack_path)] + expected = [ + Route( + path="/path1", + methods=["get"], + function_name=function_name, + stack_path=self.stack_path, + ) + ] result = parser.get_routes() self.assertEqual(expected, result) @@ -49,9 +63,24 @@ def test_with_combination_of_paths_methods(self): parser._get_integration_function_name.return_value = function_name expected = { - Route(path="/path1", methods=["get"], function_name=function_name, stack_path=self.stack_path), - Route(path="/path1", methods=["delete"], function_name=function_name, stack_path=self.stack_path), - Route(path="/path2", methods=["post"], function_name=function_name, stack_path=self.stack_path), + Route( + path="/path1", + methods=["get"], + function_name=function_name, + stack_path=self.stack_path, + ), + Route( + path="/path1", + methods=["delete"], + function_name=function_name, + stack_path=self.stack_path, + ), + Route( + path="/path2", + methods=["post"], + function_name=function_name, + stack_path=self.stack_path, + ), } result = parser.get_routes() @@ -73,7 +102,14 @@ def test_with_any_method(self): parser._get_integration_function_name = Mock() parser._get_integration_function_name.return_value = function_name - expected = [Route(methods=["ANY"], path="/path1", function_name=function_name, stack_path=self.stack_path)] + expected = [ + Route( + methods=["ANY"], + path="/path1", + function_name=function_name, + stack_path=self.stack_path, + ) + ] result = parser.get_routes() self.assertEqual(expected, result) @@ -157,6 +193,129 @@ def test_invalid_swagger(self, test_case_name, swagger): expected = [] self.assertEqual(expected, result) + def test_set_no_authorizer(self): + function_name = "function" + payload_version = "1.0" + + swagger = { + "paths": { + "/path1": { + "get": { + "security": [], + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": "someuri", + "payloadFormatVersion": payload_version, + }, + } + } + } + } + + parser = SwaggerParser(self.stack_path, swagger) + parser._get_integration_function_name = Mock(return_value=function_name) + parser._get_payload_format_version = Mock(return_value=payload_version) + + results = parser.get_routes() + expected_result = [ + Route( + path="/path1", + methods=["get"], + function_name=function_name, + payload_format_version=payload_version, + stack_path=self.stack_path, + authorizer_name=None, + authorizer_object=None, + use_default_authorizer=False, + ), + ] + + self.assertEqual(results, expected_result) + + def test_set_defined_authorizer(self): + function_name = "function" + payload_version = "1.0" + authorizer_name = "auth" + + swagger = { + "paths": { + "/path1": { + "get": { + "security": [{authorizer_name: []}], + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": "someuri", + "payloadFormatVersion": payload_version, + }, + } + } + } + } + + parser = SwaggerParser(self.stack_path, swagger) + parser._get_integration_function_name = Mock(return_value=function_name) + parser._get_payload_format_version = Mock(return_value=payload_version) + + results = parser.get_routes() + expected_result = [ + Route( + path="/path1", + methods=["get"], + function_name=function_name, + payload_format_version=payload_version, + stack_path=self.stack_path, + authorizer_name=authorizer_name, + ), + ] + + self.assertEqual(results, expected_result) + + @parameterized.expand( + [ + ( + { + "paths": { + "/path1": { + "get": { + "security": {}, + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": "someuri", + "payloadFormatVersion": "1.0", + }, + } + } + } + }, + InvalidSecurityDefinition, + ), + ( + { + "paths": { + "/path1": { + "get": { + "security": [{"auth1": []}, {"auth2": []}], + "x-amazon-apigateway-integration": { + "type": "aws_proxy", + "uri": "someuri", + "payloadFormatVersion": "1.0", + }, + } + } + } + }, + MultipleAuthorizerException, + ), + ] + ) + def test_invalid_authorizer_definition(self, swagger, expected_exception): + parser = SwaggerParser(self.stack_path, swagger) + parser._get_integration_function_name = Mock(return_value="function") + parser._get_payload_format_version = Mock(return_value="1.0") + + with self.assertRaises(expected_exception): + parser.get_routes() + class TestSwaggerParser_get_integration_function_name(TestCase): def setUp(self) -> None: @@ -214,3 +373,593 @@ def test_binary_media_type_returned(self, test_case_name, swagger, expected_resu parser = SwaggerParser(self.stack_path, swagger) self.assertEqual(parser.get_binary_media_types(), expected_result) + + +class TestSwaggerParser_get_authorizers(TestCase): + @parameterized.expand( + [ + ( # swagger 2.0 with token + request authorizers + { + "swagger": "2.0", + "securityDefinitions": { + "TokenAuth": { + "type": "apiKey", + "in": "header", + "name": "Auth", + "x-amazon-apigateway-authtype": "custom", + "x-amazon-apigateway-authorizer": { + "type": "token", + "authorizerUri": "arn", + }, + }, + "QueryAuth": { + "type": "apiKey", + "in": "header", + "name": "Auth", + "x-amazon-apigateway-authtype": "custom", + "x-amazon-apigateway-authorizer": { + "type": "request", + "identitySource": "method.request.querystring.Auth", + "authorizerUri": "arn", + }, + }, + }, + }, + { + "TokenAuth": LambdaAuthorizer( + payload_version="1.0", + authorizer_name="TokenAuth", + type="token", + lambda_name="arn", + identity_sources=["method.request.header.Auth"], + validation_string=None, + use_simple_response=False, + ), + "QueryAuth": LambdaAuthorizer( + payload_version="1.0", + authorizer_name="QueryAuth", + type="request", + lambda_name="arn", + identity_sources=["method.request.querystring.Auth"], + validation_string=None, + use_simple_response=False, + ), + }, + Route.API, + ), + ( # openapi 3.0 with token authorizer + { + "openapi": "3.0", + "components": { + "securitySchemes": { + "TokenAuth": { + "type": "apiKey", + "in": "header", + "name": "Auth", + "x-amazon-apigateway-authtype": "custom", + "x-amazon-apigateway-authorizer": { + "authorizerPayloadFormatVersion": "2.0", + "type": "request", + "identitySource": "$request.header.Auth", + "authorizerUri": "arn", + }, + }, + }, + }, + }, + { + "TokenAuth": LambdaAuthorizer( + payload_version="2.0", + authorizer_name="TokenAuth", + type="request", + lambda_name="arn", + identity_sources=["$request.header.Auth"], + validation_string=None, + use_simple_response=False, + ), + }, + Route.HTTP, + ), + ] + ) + @patch("samcli.commands.local.lib.swagger.parser.LambdaUri") + def test_with_valid_lambda_auth_definition(self, swagger_doc, expected_authorizers, api_type, mock_lambda_uri): + mock_lambda_uri.get_function_name.return_value = "arn" + + parser = SwaggerParser(Mock(), swagger_doc) + + self.assertEqual(parser.get_authorizers(event_type=api_type), expected_authorizers) + + @parameterized.expand( + [ + ( # test unsupported type (jwt) + { + "openapi": "3.0", + "components": { + "securitySchemes": { + "TokenAuth": { + "type": "apiKey", + "in": "header", + "name": "Auth", + "x-amazon-apigateway-authtype": "custom", + "x-amazon-apigateway-authorizer": { + "type": "jwt", + "identitySource": "method.request.header.Auth", + "authorizerUri": "arn", + }, + }, + }, + }, + }, + ), + ( # test invalid integration key + { + "openapi": "3.0", + "components": { + "securitySchemes": { + "TokenAuth": { + "type": "apiKey", + "in": "header", + "name": "Auth", + "x-amazon-apigateway-authtype": "custom", + "invalid-key-goes-here": { + "type": "request", + "identitySource": "$request.header.Auth", + "authorizerUri": "arn", + }, + }, + }, + }, + }, + ), + ] + ) + @patch("samcli.commands.local.lib.swagger.parser.LambdaUri") + @patch("samcli.commands.local.lib.swagger.parser.SwaggerParser._get_lambda_identity_sources") + def test_unsupported_lambda_authorizers(self, swagger_doc, get_id_sources_mock, mock_lambda_uri): + parser = SwaggerParser(Mock(), swagger_doc) + + self.assertEqual(parser.get_authorizers(), {}) + + @patch("samcli.commands.local.lib.swagger.parser.LambdaUri") + @patch("samcli.commands.local.lib.swagger.parser.SwaggerParser._get_lambda_identity_sources") + def test_invalid_lambda_auth_arn(self, get_id_sources_mock, mock_lambda_uri): + mock_lambda_uri.get_function_name.return_value = None + + swagger_doc = { + "swagger": "2.0", + "securityDefinitions": { + "TokenAuth": { + "type": "apiKey", + "in": "header", + "name": "Auth", + "x-amazon-apigateway-authtype": "custom", + "x-amazon-apigateway-authorizer": { + "type": "token", + "authorizerUri": "arn", + }, + } + }, + } + + parser = SwaggerParser(Mock(), swagger_doc) + + self.assertEqual(parser.get_authorizers(), {}) + + @parameterized.expand( + [ + ( + { + "swagger": "4.0", + "securityDefinitions": { + "TokenAuth": { + "type": "apiKey", + "in": "header", + "name": "Auth", + "x-amazon-apigateway-authtype": "custom", + "x-amazon-apigateway-authorizer": { + "type": "token", + "authorizerUri": "arn", + }, + } + }, + }, + ), + ( + { + "openapi": "1.0", + "securityDefinitions": { + "TokenAuth": { + "type": "apiKey", + "in": "header", + "name": "Auth", + "x-amazon-apigateway-authtype": "custom", + "x-amazon-apigateway-authorizer": { + "type": "token", + "authorizerUri": "arn", + }, + } + }, + }, + ), + ( + { + "securityDefinitions": { + "TokenAuth": { + "type": "apiKey", + "in": "header", + "name": "Auth", + "x-amazon-apigateway-authtype": "custom", + "x-amazon-apigateway-authorizer": { + "type": "token", + "authorizerUri": "arn", + }, + } + } + }, + ), + ] + ) + def test_invalid_oas_version(self, swagger_doc): + parser = SwaggerParser(Mock(), swagger_doc) + + with self.assertRaises(InvalidOasVersion): + parser.get_authorizers() + + @parameterized.expand( + [ + ( # API event with a defined validation string (123), expect lambda auth obj property populated + { + "swagger": "2.0", + "securityDefinitions": { + "TokenAuth": { + "type": "apiKey", + "in": "header", + "name": "Auth", + "x-amazon-apigateway-authtype": "custom", + "x-amazon-apigateway-authorizer": { + "type": "token", + "identityValidationExpression": "123", + "authorizerUri": "arn", + }, + }, + }, + }, + "123", + Route.API, + ), + ( # HTTP event with a defined validation string (123), expect lambda auth obj property NOT populated + { + "openapi": "3.0", + "components": { + "securitySchemes": { + "TokenAuth": { + "type": "apiKey", + "in": "header", + "name": "unused", + "x-amazon-apigateway-authtype": "custom", + "x-amazon-apigateway-authorizer": { + "authorizerPayloadFormatVersion": "2.0", + "type": "request", + "identityValidationExpression": "123", + "authorizerUri": "arn", + "identitySource": "$request.header.header", + }, + }, + } + }, + }, + None, + Route.HTTP, + ), + ] + ) + @patch("samcli.commands.local.lib.swagger.parser.LambdaUri") + @patch("samcli.commands.local.lib.swagger.parser.SwaggerParser._get_lambda_identity_sources") + def test_defining_validation_expression( + self, swagger_doc, expected_validation_string, event_type, get_id_sources_mock, mock_lambda_uri + ): + mock_lambda_uri.get_function_name.return_value = "arn" + + parser = SwaggerParser(Mock(), swagger_doc) + + lambda_authorizers = parser.get_authorizers(event_type) + + self.assertEqual(lambda_authorizers["TokenAuth"].validation_string, expected_validation_string) + + @parameterized.expand( + [ + ## + # testing API events + # + ( # using 2.0 payload and no simple response, expect it to be set as False + "2.0", + False, + Route.API, + False, + ), + ( # using 1.0 payload and no simple response, expect it to be set as False + "1.0", + False, + Route.API, + False, + ), + ( # using 1.0 payload and simple response IS set, expect it to be set as False + "1.0", + True, + Route.API, + False, + ), + ( # using 2.0 payload and simple response IS set, expect it to be set as False + "2.0", + True, + Route.API, + False, + ), + ## + # testing HTTP events + # + ( # using 2.0 payload and no simple response, expect it to be set as False + "2.0", + False, + Route.HTTP, + False, + ), + ( # using 1.0 payload and no simple response, expect it to be set as False + "1.0", + False, + Route.HTTP, + False, + ), + ( # using 1.0 payload and simple response IS set, expect it to be set as False + "1.0", + True, + Route.HTTP, + False, + ), + ( # using 2.0 payload and simple response IS set, expect it to be set as True + "2.0", + True, + Route.HTTP, + True, + ), + ] + ) + @patch("samcli.commands.local.lib.swagger.parser.LambdaUri") + @patch("samcli.commands.local.lib.swagger.parser.SwaggerParser._get_lambda_identity_sources") + def test_defining_simple_responses( + self, + payload_version, + enabled_simple_response, + event_type, + expected_response, + get_id_sources_mock, + mock_lambda_uri, + ): + mock_lambda_uri.get_function_name.return_value = "arn" + + swagger_doc = { + "openapi": "3.0", + "components": { + "securitySchemes": { + "Authorizer": { + "type": "apiKey", + "in": "header", + "name": "notused", + "x-amazon-apigateway-authorizer": { + "authorizerPayloadFormatVersion": payload_version, + "enableSimpleResponses": enabled_simple_response, + "type": "request", + "authorizerUri": "arn", + "identitySource": "$request.header.header", + }, + }, + }, + }, + } + + parser = SwaggerParser(Mock(), swagger_doc) + + lambda_authorizers = parser.get_authorizers(event_type) + + self.assertEqual(lambda_authorizers["Authorizer"].use_simple_response, expected_response) + + @patch("samcli.commands.local.lib.swagger.parser.LambdaUri") + @patch("samcli.commands.local.lib.swagger.parser.SwaggerParser._get_lambda_identity_sources") + def test_defining_invalid_payload_versions(self, get_id_sources_mock, mock_lambda_uri): + mock_lambda_uri.get_function_name.return_value = "arn" + + swagger_doc = { + "openapi": "3.0", + "components": { + "securitySchemes": { + "TokenAuth": { + "type": "apiKey", + "in": "header", + "name": "Auth", + "x-amazon-apigateway-authtype": "custom", + "x-amazon-apigateway-authorizer": { + "authorizerPayloadFormatVersion": "1.2.3", + "type": "request", + "authorizerUri": "arn", + "identitySource": "$request.header.header", + }, + }, + }, + }, + } + + parser = SwaggerParser(Mock(), swagger_doc) + + with self.assertRaisesRegex( + InvalidSecurityDefinition, "^Authorizer 'TokenAuth' contains an invalid payload version$" + ): + parser.get_authorizers(Route.HTTP) + + @patch("samcli.commands.local.lib.swagger.parser.LambdaUri") + @patch("samcli.commands.local.lib.swagger.parser.SwaggerParser._get_lambda_identity_sources") + def test_undefined_payload_api_event(self, get_id_sources_mock, mock_lambda_uri): + """ + Tests if the payload version is set to 1.0 if it is not defined for API events + """ + mock_lambda_uri.get_function_name.return_value = "arn" + + swagger_doc = { + "swagger": "2.0", + "securityDefinitions": { + "TokenAuth": { + "type": "apiKey", + "in": "header", + "name": "Auth", + "x-amazon-apigateway-authtype": "custom", + "x-amazon-apigateway-authorizer": { + "type": "token", + "authorizerUri": "arn", + }, + }, + }, + } + + parser = SwaggerParser(Mock(), swagger_doc) + + self.assertEqual(parser.get_authorizers(Route.API)["TokenAuth"].payload_version, LambdaAuthorizer.PAYLOAD_V1) + + @patch("samcli.commands.local.lib.swagger.parser.LambdaUri") + @patch("samcli.commands.local.lib.swagger.parser.SwaggerParser._get_lambda_identity_sources") + def test_simple_response_override_using_rest_api(self, get_id_sources_mock, mock_lambda_uri): + """ + Tests the the Lambda authorizer's simple response property is set to False + if it is provided in a Swagger 2.0 document. + """ + mock_lambda_uri.get_function_name.return_value = "arn" + + swagger_doc = { + "swagger": "2.0", + "securityDefinitions": { + "TokenAuth": { + "type": "apiKey", + "in": "header", + "name": "Auth", + "x-amazon-apigateway-authtype": "custom", + "x-amazon-apigateway-authorizer": { + "type": "token", + "authorizerUri": "arn", + "enableSimpleResponses": True, + }, + }, + }, + } + + parser = SwaggerParser(Mock(), swagger_doc) + + self.assertEqual(parser.get_authorizers(Route.API)["TokenAuth"].use_simple_response, False) + + +class TestSwaggerParser_get_default_authorizer(TestCase): + def test_valid_default_authorizers(self): + authorizer_name = "authorizer" + + swagger_doc = {"openapi": "3.0", "security": [{authorizer_name: []}]} + + parser = SwaggerParser(Mock(), swagger_doc) + result = parser.get_default_authorizer(Route.HTTP) + + self.assertEqual(result, authorizer_name) + + @parameterized.expand( + [ + ({"openapi": "3.0", "security": []},), + ({"swagger": "2.0", "security": []},), + ({"openapi": "3.0"},), + ({"swagger": "2.0"},), + ] + ) + def test_no_default_authorizer_defined(self, swagger): + parser = SwaggerParser(Mock(), swagger) + + result = parser.get_default_authorizer(Route.HTTP) + self.assertIsNone(result) + + result = parser.get_default_authorizer(Route.API) + self.assertIsNone(result) + + @parameterized.expand( + [ + ({"swagger": "2.0", "security": [{"auth": []}]}, IncorrectOasWithDefaultAuthorizerException), + ({"openapi": "3.0", "security": [{"auth": []}, {"auth2": []}]}, MultipleAuthorizerException), + ] + ) + def test_invalid_default_authorizer_definition(self, swagger, expected_exception): + parser = SwaggerParser(Mock(), swagger) + + with self.assertRaises(expected_exception): + parser.get_default_authorizer(Route.HTTP) + + +class TestSwaggerParser_get_lambda_identity_sources(TestCase): + @parameterized.expand( + [ + ( + "token", + Route.API, + {"name": "Authentication", "in": "header"}, + {}, + ["method.request.header.Authentication"], + ), + ( + "request", + Route.API, + {"name": "unused", "in": "header"}, + {"identitySource": "method.request.header.Authentication, method.request.header.otherheader"}, + ["method.request.header.Authentication", "method.request.header.otherheader"], + ), + ] + ) + def test_valid_identity_sources(self, type, event_type, properties, authorizer_object, expected_result): + parser = SwaggerParser(Mock(), Mock()) + + result = parser._get_lambda_identity_sources("myauth", type, event_type, properties, authorizer_object) + self.assertEqual(result, expected_result) + + @parameterized.expand( + [ + ( # missing 'in' property + "token", + Route.API, + {"name": "Authentication"}, + {}, + ), + ( # missing 'name' property + "token", + Route.API, + {"in": "header"}, + {}, + ), + ( # token type for HTTP API + "token", + Route.HTTP, + {"name": "auth", "in": "header"}, + {"identitySource": "method.request.header.Authentication, method.request.header.otherheader"}, + ), + ( # missing 'identitySource' for request + "request", + Route.HTTP, + {"name": "unused", "in": "header"}, + {}, + ), + ] + ) + def test_invalid_authorizer_definitions(self, type, event_type, properties, authorizer_object): + parser = SwaggerParser(Mock(), Mock()) + + result = parser._get_lambda_identity_sources("myauth", type, event_type, properties, authorizer_object) + self.assertEqual(result, []) + + def test_invalid_identity_source_throws_exception(self): + parser = SwaggerParser(Mock(), Mock()) + + properties = {"name": "Authentication", "in": "header"} + auth_properties = {"identitySource": "invalid string goes here"} + + with self.assertRaises(InvalidSecurityDefinition): + parser._get_lambda_identity_sources(Mock(), "request", Route.API, properties, auth_properties) diff --git a/tests/unit/commands/local/lib/test_api_collector.py b/tests/unit/commands/local/lib/test_api_collector.py new file mode 100644 index 0000000000..fdda3db3d4 --- /dev/null +++ b/tests/unit/commands/local/lib/test_api_collector.py @@ -0,0 +1,162 @@ +from unittest import TestCase +from parameterized import parameterized + +from samcli.lib.providers.api_collector import ApiCollector +from samcli.local.apigw.route import Route +from samcli.local.apigw.authorizers.authorizer import Authorizer + + +class TestApiCollector_linking_authorizer(TestCase): + def setUp(self): + self.apigw_id = "apigw1" + + self.api_collector = ApiCollector() + + @parameterized.expand( + [ + ( # test link default authorizer + [ + Route( + function_name="func1", + path="path1", + methods=["get"], + stack_path="path1", + authorizer_name=None, + use_default_authorizer=True, + ) + ], + { + "auth1": Authorizer(authorizer_name="auth1", type="token1", payload_version="1.0"), + "auth2": Authorizer(authorizer_name="auth2", type="token2", payload_version="1.0"), + }, + "auth1", + [ + Route( + function_name="func1", + path="path1", + methods=["get"], + stack_path="path1", + authorizer_name="auth1", + authorizer_object=Authorizer(authorizer_name="auth1", type="token1", payload_version="1.0"), + use_default_authorizer=True, + ) + ], + ), + ( # test link non existant default authorizer + [ + Route( + function_name="func1", + path="path1", + methods=["get"], + stack_path="path1", + authorizer_name=None, + use_default_authorizer=True, + ) + ], + { + "auth1": Authorizer(authorizer_name="auth1", type="token1", payload_version="1.0"), + "auth2": Authorizer(authorizer_name="auth2", type="token2", payload_version="1.0"), + }, + None, + [ + Route( + function_name="func1", + path="path1", + methods=["get"], + stack_path="path1", + authorizer_name=None, + authorizer_object=None, + use_default_authorizer=True, + ) + ], + ), + ( # test no authorizer defined in route + [ + Route( + function_name="func1", + path="path1", + methods=["get"], + stack_path="path1", + authorizer_name=None, + use_default_authorizer=False, + ) + ], + { + "auth1": Authorizer(authorizer_name="auth1", type="token1", payload_version="1.0"), + "auth2": Authorizer(authorizer_name="auth2", type="token2", payload_version="1.0"), + }, + "auth1", + [ + Route( + function_name="func1", + path="path1", + methods=["get"], + stack_path="path1", + authorizer_name=None, + authorizer_object=None, + use_default_authorizer=False, + ) + ], + ), + ( # test linking defined authorizer + [ + Route( + function_name="func1", + path="path1", + methods=["get"], + stack_path="path1", + authorizer_name="auth2", + ) + ], + { + "auth1": Authorizer(authorizer_name="auth1", type="token1", payload_version="1.0"), + "auth2": Authorizer(authorizer_name="auth2", type="token2", payload_version="1.0"), + }, + "auth1", + [ + Route( + function_name="func1", + path="path1", + methods=["get"], + stack_path="path1", + authorizer_name="auth2", + authorizer_object=Authorizer(authorizer_name="auth2", type="token2", payload_version="1.0"), + ) + ], + ), + ( # test linking unsupported authorizer + [ + Route( + function_name="func1", + path="path1", + methods=["get"], + stack_path="path1", + authorizer_name="unsupported", + ) + ], + { + "auth1": Authorizer(authorizer_name="auth1", type="token1", payload_version="1.0"), + "auth2": Authorizer(authorizer_name="auth2", type="token2", payload_version="1.0"), + }, + "auth1", + [ + Route( + function_name="func1", + path="path1", + methods=["get"], + stack_path="path1", + authorizer_name=None, + authorizer_object=None, + ) + ], + ), + ] + ) + def test_link_authorizers(self, routes, authorizers, default_authorizer, expected_routes): + self.api_collector._route_per_resource[self.apigw_id] = routes + self.api_collector._authorizers_per_resources[self.apigw_id] = authorizers + self.api_collector._default_authorizer_per_resource[self.apigw_id] = default_authorizer + + self.api_collector._link_authorizers() + + self.assertEqual(self.api_collector._route_per_resource, {self.apigw_id: expected_routes}) diff --git a/tests/unit/commands/local/lib/test_cfn_api_provider.py b/tests/unit/commands/local/lib/test_cfn_api_provider.py index 57ac53927e..b8ac3a5e01 100644 --- a/tests/unit/commands/local/lib/test_cfn_api_provider.py +++ b/tests/unit/commands/local/lib/test_cfn_api_provider.py @@ -4,10 +4,12 @@ from unittest import TestCase from unittest.mock import patch, Mock +from parameterized import parameterized from samcli.lib.providers.api_provider import ApiProvider from samcli.lib.providers.cfn_api_provider import CfnApiProvider -from samcli.local.apigw.local_apigw_service import Route +from samcli.local.apigw.route import Route +from samcli.local.apigw.authorizers.lambda_authorizer import LambdaAuthorizer from tests.unit.commands.local.lib.test_sam_api_provider import make_swagger from samcli.lib.providers.provider import Cors, Stack @@ -169,6 +171,7 @@ def test_provider_parse_stage_name(self): "Type": "AWS::ApiGateway::RestApi", "Properties": { "Body": { + "swagger": "2.0", "paths": { "/path": { "get": { @@ -183,7 +186,7 @@ def test_provider_parse_stage_name(self): } } } - } + }, } }, }, @@ -211,6 +214,7 @@ def test_provider_stage_variables(self): "Type": "AWS::ApiGateway::RestApi", "Properties": { "Body": { + "swagger": "2.0", "paths": { "/path": { "get": { @@ -225,7 +229,7 @@ def test_provider_stage_variables(self): } } } - } + }, } }, }, @@ -244,6 +248,7 @@ def test_multi_stage_get_all(self): "Type": "AWS::ApiGateway::RestApi", "Properties": { "Body": { + "swagger": "2.0", "paths": { "/path": { "get": { @@ -271,7 +276,7 @@ def test_multi_stage_get_all(self): } } }, - } + }, } }, } @@ -964,6 +969,7 @@ def test_provider_parse_stage_name(self): "Type": "AWS::ApiGatewayV2::Api", "Properties": { "Body": { + "openapi": "3.0", "paths": { "/path": { "get": { @@ -978,7 +984,7 @@ def test_provider_parse_stage_name(self): } } } - } + }, } }, }, @@ -1006,6 +1012,7 @@ def test_provider_stage_variables(self): "Type": "AWS::ApiGatewayV2::Api", "Properties": { "Body": { + "openapi": "3.0", "paths": { "/path": { "get": { @@ -1020,7 +1027,7 @@ def test_provider_stage_variables(self): } } } - } + }, } }, }, @@ -1039,6 +1046,7 @@ def test_multi_stage_get_all(self): "Type": "AWS::ApiGatewayV2::Api", "Properties": { "Body": { + "openapi": "3.0", "paths": { "/path": { "get": { @@ -1066,7 +1074,7 @@ def test_multi_stage_get_all(self): } } }, - } + }, } }, } @@ -1177,3 +1185,132 @@ def test_empty_integration_array(self): } provider = ApiProvider(make_mock_stacks_from_template(template)) self.assertIsNone(provider.api.cors) + + +class TestCollectLambdaAuthorizersWithApiGatewayV1Resources(TestCase): + @parameterized.expand( + [ + ( # test token auth WITHOUT validation + { + "Properties": { + "Type": "TOKEN", + "RestApiId": "my-rest-api", + "Name": "my-auth-name", + "AuthorizerUri": "arn", + "IdentitySource": "method.request.header.auth", + } + }, + { + "my-auth-name": LambdaAuthorizer( + payload_version="1.0", + authorizer_name="my-auth-name", + type=LambdaAuthorizer.TOKEN, + lambda_name="my-lambda", + identity_sources=["method.request.header.auth"], + ) + }, + ), + ( # test token auth WITH validation + { + "Properties": { + "Type": "TOKEN", + "RestApiId": "my-rest-api", + "Name": "my-auth-name", + "AuthorizerUri": "arn", + "IdentitySource": "method.request.header.auth", + "IdentityValidationExpression": "*", + } + }, + { + "my-auth-name": LambdaAuthorizer( + payload_version="1.0", + authorizer_name="my-auth-name", + type=LambdaAuthorizer.TOKEN, + lambda_name="my-lambda", + identity_sources=["method.request.header.auth"], + validation_string="*", + ) + }, + ), + ( # test request auth + { + "Properties": { + "Type": "REQUEST", + "RestApiId": "my-rest-api", + "Name": "my-auth-name", + "AuthorizerUri": "arn", + "IdentitySource": "method.request.header.auth, method.request.querystring.abc", + } + }, + { + "my-auth-name": LambdaAuthorizer( + payload_version="1.0", + authorizer_name="my-auth-name", + type=LambdaAuthorizer.REQUEST, + lambda_name="my-lambda", + identity_sources=["method.request.header.auth", "method.request.querystring.abc"], + ) + }, + ), + ] + ) + @patch("samcli.commands.local.lib.swagger.integration_uri.LambdaUri.get_function_name") + @patch("samcli.commands.local.lib.validators.lambda_auth_props.LambdaAuthorizerV1Validator.validate") + def test_collect_v1_lambda_authorizer(self, resource, expected_authorizer, validator_mock, get_func_name_mock): + lambda_auth_logical_id = "my-auth-id" + + # mock ARN resolving function + auth_lambda_func_name = "my-lambda" + get_func_name_mock.return_value = auth_lambda_func_name + + validator_mock.return_value = True + + mock_collector = Mock() + mock_collector.add_authorizers = Mock() + + CfnApiProvider._extract_cloud_formation_authorizer(lambda_auth_logical_id, resource, mock_collector) + + mock_collector.add_authorizers.assert_called_with("my-rest-api", expected_authorizer) + + +class TestCollectLambdaAuthorizersWithApiGatewayV2Resources(TestCase): + @patch("samcli.commands.local.lib.swagger.integration_uri.LambdaUri.get_function_name") + @patch("samcli.commands.local.lib.validators.lambda_auth_props.LambdaAuthorizerV2Validator.validate") + def test_collect_v2_lambda_authorizer(self, validator_mock, get_func_name_mock): + identity_sources = ["$request.header.auth", "$context.something"] + + properties = { + "Properties": { + "AuthorizerType": "REQUEST", + "ApiId": "my-rest-api", + "Name": "my-auth-name", + "AuthorizerUri": "arn", + "IdentitySource": identity_sources, + "AuthorizerPayloadFormatVersion": "2.0", + } + } + + expected_authorizers = { + "my-auth-name": LambdaAuthorizer( + payload_version="2.0", + authorizer_name="my-auth-name", + type=LambdaAuthorizer.REQUEST, + lambda_name="my-lambda", + identity_sources=identity_sources, + ) + } + + lambda_auth_logical_id = "my-auth-id" + + # mock ARN resolving function + auth_lambda_func_name = "my-lambda" + get_func_name_mock.return_value = auth_lambda_func_name + + mock_collector = Mock() + mock_collector.add_authorizers = Mock() + + validator_mock.return_value = True + + CfnApiProvider._extract_cfn_gateway_v2_authorizer(lambda_auth_logical_id, properties, mock_collector) + + mock_collector.add_authorizers.assert_called_with("my-rest-api", expected_authorizers) diff --git a/tests/unit/commands/local/lib/test_identity_source_validator.py b/tests/unit/commands/local/lib/test_identity_source_validator.py new file mode 100644 index 0000000000..dbdc4d13d0 --- /dev/null +++ b/tests/unit/commands/local/lib/test_identity_source_validator.py @@ -0,0 +1,50 @@ +from unittest import TestCase + +from parameterized import parameterized + +from samcli.commands.local.lib.validators.identity_source_validator import IdentitySourceValidator +from samcli.local.apigw.route import Route + + +class TestIdentitySourceValidator(TestCase): + @parameterized.expand( + [ + ("method.request.header.this-is_my.header", Route.API), + ("method.request.querystring.this_is-my_query.string", Route.API), + ("context.this.is.a_cool-context", Route.API), + ("stageVariables.my.stage_vari-ble", Route.API), + ("$request.header.this-is_my.header", Route.HTTP), + ("$request.querystring.this_is-my_query.string", Route.HTTP), + ("$context.this.is.a_cool-context", Route.HTTP), + ("$stageVariables.my.stage_vari-ble", Route.HTTP), + ] + ) + def test_valid_identity_sources(self, identity_source, event_type): + self.assertTrue(IdentitySourceValidator.validate_identity_source(identity_source, event_type)) + + @parameterized.expand( + [ + ("method.request.header.this+is+my~header", Route.API), + ("method.request.querystring.this+is+my~query?string", Route.API), + ("context.this?is~a_cool-context", Route.API), + ("stageVariables.my][stage|vari-ble", Route.API), + ("", Route.API), + ("method.request.querystring", Route.API), + ("method.request.header", Route.API), + ("context", Route.API), + ("stageVariable", Route.API), + ("hello world", Route.API), + ("$request.header.this+is+my~header", Route.HTTP), + ("$request.querystring.this+is+my~query?string", Route.HTTP), + ("$context.this?is~a_cool-context", Route.HTTP), + ("$stageVariables.my][stage|vari-ble", Route.HTTP), + ("", Route.HTTP), + ("$request.querystring", Route.HTTP), + ("$request.header", Route.HTTP), + ("$context", Route.HTTP), + ("$stageVariable", Route.HTTP), + ("hello world", Route.HTTP), + ] + ) + def test_invalid_identity_sources(self, identity_source, event_type): + self.assertFalse(IdentitySourceValidator.validate_identity_source(identity_source, event_type)) diff --git a/tests/unit/commands/local/lib/test_local_api_service.py b/tests/unit/commands/local/lib/test_local_api_service.py index 89d91102d3..b848a61800 100644 --- a/tests/unit/commands/local/lib/test_local_api_service.py +++ b/tests/unit/commands/local/lib/test_local_api_service.py @@ -11,7 +11,7 @@ from samcli.lib.providers.api_provider import ApiProvider from samcli.commands.local.lib.exceptions import NoApisDefined from samcli.commands.local.lib.local_api_service import LocalApiService -from samcli.local.apigw.local_apigw_service import Route +from samcli.local.apigw.route import Route class TestLocalApiService_start(TestCase): diff --git a/tests/unit/commands/local/lib/test_sam_api_provider.py b/tests/unit/commands/local/lib/test_sam_api_provider.py index 705e7bf876..945bdb79dc 100644 --- a/tests/unit/commands/local/lib/test_sam_api_provider.py +++ b/tests/unit/commands/local/lib/test_sam_api_provider.py @@ -3,13 +3,15 @@ from collections import OrderedDict from unittest import TestCase -from unittest.mock import patch, Mock +from unittest.mock import ANY, patch, Mock from parameterized import parameterized from samcli.commands.validate.lib.exceptions import InvalidSamDocumentException from samcli.lib.providers.api_provider import ApiProvider from samcli.lib.providers.provider import Cors, Stack -from samcli.local.apigw.local_apigw_service import Route +from samcli.lib.providers.sam_api_provider import SamApiProvider +from samcli.local.apigw.route import Route +from samcli.local.apigw.authorizers.lambda_authorizer import LambdaAuthorizer def make_mock_stacks_from_template(template): @@ -745,6 +747,7 @@ def test_provider_parse_stage_name(self): "Properties": { "StageName": "dev", "DefinitionBody": { + "swagger": "2.0", "paths": { "/path": { "get": { @@ -759,7 +762,7 @@ def test_provider_parse_stage_name(self): } } } - } + }, }, }, } @@ -781,6 +784,7 @@ def test_provider_stage_variables(self): "StageName": "dev", "Variables": {"vis": "data", "random": "test", "foo": "bar"}, "DefinitionBody": { + "swagger": "2.0", "paths": { "/path": { "get": { @@ -795,7 +799,7 @@ def test_provider_stage_variables(self): } } } - } + }, }, }, } @@ -816,6 +820,7 @@ def test_multi_stage_get_all(self): "StageName": "dev", "Variables": {"vis": "data", "random": "test", "foo": "bar"}, "DefinitionBody": { + "swagger": "2.0", "paths": { "/path2": { "get": { @@ -830,7 +835,7 @@ def test_multi_stage_get_all(self): } } } - } + }, }, }, } @@ -841,6 +846,7 @@ def test_multi_stage_get_all(self): "StageName": "Production", "Variables": {"vis": "prod data", "random": "test", "foo": "bar"}, "DefinitionBody": { + "swagger": "2.0", "paths": { "/path": { "get": { @@ -868,7 +874,7 @@ def test_multi_stage_get_all(self): } } }, - } + }, }, }, } @@ -900,6 +906,7 @@ def test_provider_parse_cors_with_unresolved_intrinsic(self): "StageName": "Prod", "Cors": {"AllowOrigin": {"Fn:Sub": "Some string to sub"}}, "DefinitionBody": { + "swagger": "2.0", "paths": { "/path2": { "post": { @@ -925,7 +932,7 @@ def test_provider_parse_cors_with_unresolved_intrinsic(self): } } }, - } + }, }, }, } @@ -956,6 +963,7 @@ def test_provider_parse_cors_string(self): "StageName": "Prod", "Cors": "'*'", "DefinitionBody": { + "swagger": "2.0", "paths": { "/path2": { "post": { @@ -981,7 +989,7 @@ def test_provider_parse_cors_string(self): } } }, - } + }, }, }, } @@ -1017,6 +1025,7 @@ def test_provider_parse_cors_dict(self): "MaxAge": "'600'", }, "DefinitionBody": { + "swagger": "2.0", "paths": { "/path2": { "post": { @@ -1042,7 +1051,7 @@ def test_provider_parse_cors_dict(self): } } }, - } + }, }, }, } @@ -1080,6 +1089,7 @@ def test_provider_parse_cors_dict_star_allow(self): "MaxAge": "'600'", }, "DefinitionBody": { + "swagger": "2.0", "paths": { "/path2": { "post": { @@ -1105,7 +1115,7 @@ def test_provider_parse_cors_dict_star_allow(self): } } }, - } + }, }, }, } @@ -1283,6 +1293,7 @@ def test_default_cors_dict_prop(self): "StageName": "Prod", "Cors": {"AllowOrigin": "'www.domain.com'"}, "DefinitionBody": { + "swagger": "2.0", "paths": { "/path2": { "get": { @@ -1297,7 +1308,7 @@ def test_default_cors_dict_prop(self): } } } - } + }, }, }, } @@ -1331,6 +1342,7 @@ def test_global_cors(self): "Properties": { "StageName": "Prod", "DefinitionBody": { + "swagger": "2.0", "paths": { "/path2": { "get": { @@ -1356,7 +1368,7 @@ def test_global_cors(self): } } }, - } + }, }, }, } @@ -1391,6 +1403,7 @@ def test_provider_parse_cors_with_unresolved_intrinsic(self): "StageName": "Prod", "CorsConfiguration": {"AllowOrigins": {"Fn:Sub": "Some string to sub"}}, "DefinitionBody": { + "openapi": "3.0", "paths": { "/path2": { "post": { @@ -1416,7 +1429,7 @@ def test_provider_parse_cors_with_unresolved_intrinsic(self): } } }, - } + }, }, }, } @@ -1447,6 +1460,7 @@ def test_provider_parse_cors_true(self): "StageName": "Prod", "CorsConfiguration": True, "DefinitionBody": { + "openapi": "3.0", "paths": { "/path2": { "post": { @@ -1472,7 +1486,7 @@ def test_provider_parse_cors_true(self): } } }, - } + }, }, }, } @@ -1503,6 +1517,7 @@ def test_provider_parse_cors_false(self): "StageName": "Prod", "CorsConfiguration": False, "DefinitionBody": { + "openapi": "3.0", "paths": { "/path2": { "post": { @@ -1528,7 +1543,7 @@ def test_provider_parse_cors_false(self): } } }, - } + }, }, }, } @@ -1561,6 +1576,7 @@ def test_provider_parse_cors_dict(self): "MaxAge": 600, }, "DefinitionBody": { + "openapi": "3.0", "paths": { "/path2": { "post": { @@ -1586,7 +1602,7 @@ def test_provider_parse_cors_dict(self): } } }, - } + }, }, }, } @@ -1624,6 +1640,7 @@ def test_provider_parse_cors_dict_star_allow(self): "MaxAge": 600, }, "DefinitionBody": { + "openapi": "3.0", "paths": { "/path2": { "post": { @@ -1649,7 +1666,7 @@ def test_provider_parse_cors_dict_star_allow(self): } } }, - } + }, }, }, } @@ -1732,6 +1749,7 @@ def test_default_cors_dict_prop(self): "StageName": "Prod", "CorsConfiguration": {"AllowOrigins": ["www.domain.com"]}, "DefinitionBody": { + "openapi": "3.0", "paths": { "/path2": { "get": { @@ -1746,7 +1764,7 @@ def test_default_cors_dict_prop(self): } } } - } + }, }, }, } @@ -1780,6 +1798,7 @@ def test_global_cors(self): "Properties": { "StageName": "Prod", "DefinitionBody": { + "openapi": "3.0", "paths": { "/path2": { "get": { @@ -1805,7 +1824,7 @@ def test_global_cors(self): } } }, - } + }, }, }, } @@ -1830,6 +1849,288 @@ def test_global_cors(self): self.assertEqual(provider.api.cors, cors) +class TestSamApiUsingAuthorizers(TestCase): + @parameterized.expand( + [(SamApiProvider()._extract_from_serverless_api,), (SamApiProvider()._extract_from_serverless_http,)] + ) + @patch("samcli.lib.providers.cfn_base_api_provider.CfnBaseApiProvider.extract_swagger_route") + @patch("samcli.lib.providers.sam_api_provider.SamApiProvider._extract_authorizers_from_props") + def test_extract_serverless_api_extracts_default_authorizer( + self, extraction_method, extract_authorizers_mock, extract_swagger_route_mock + ): + authorizer_name = "myauth" + + properties = { + "Properties": {"DefinitionBody": {"something": "here"}, "Auth": {"DefaultAuthorizer": authorizer_name}} + } + + logical_id_mock = Mock() + api_collector_mock = Mock() + api_collector_mock.set_default_authorizer = Mock() + + extraction_method(Mock(), logical_id_mock, properties, api_collector_mock, Mock()) + + api_collector_mock.set_default_authorizer.assert_called_with(logical_id_mock, authorizer_name) + + @parameterized.expand( + [ + ( # test token + swagger 2.0 + { + "Authorizers": { + "mycoolauthorizer": { + "FunctionPayloadType": "TOKEN", + "Identity": { + "Header": "myheader", + }, + "FunctionArn": "will_be_mocked", + } + } + }, + { + "mycoolauthorizer": LambdaAuthorizer( + payload_version="1.0", + authorizer_name="mycoolauthorizer", + type="token", + lambda_name=ANY, + identity_sources=["method.request.header.myheader"], + ) + }, + Route.API, + ), + ( # test no identity header + token + swagger 2.0 + { + "Authorizers": { + "mycoolauthorizer": { + "FunctionPayloadType": "TOKEN", + "FunctionArn": "will_be_mocked", + } + } + }, + { + "mycoolauthorizer": LambdaAuthorizer( + payload_version="1.0", + authorizer_name="mycoolauthorizer", + type="token", + lambda_name=ANY, + identity_sources=["method.request.header.Authorization"], + ) + }, + Route.API, + ), + ( # test request + swagger 2.0 + { + "Authorizers": { + "mycoolauthorizer": { + "FunctionPayloadType": "REQUEST", + "Identity": { + "QueryStrings": ["query1", "query2"], + "Headers": ["header1", "header2"], + "Context": ["context1", "context2"], + "StageVariables": ["stage1", "stage2"], + }, + "FunctionArn": "will_be_mocked", + "AuthorizerPayloadFormatVersion": "1.0", + } + } + }, + { + "mycoolauthorizer": LambdaAuthorizer( + payload_version="1.0", + authorizer_name="mycoolauthorizer", + type="request", + lambda_name=ANY, + identity_sources=[ + "method.request.header.header1", + "method.request.header.header2", + "method.request.querystring.query1", + "method.request.querystring.query2", + "context.context1", + "context.context2", + "stageVariables.stage1", + "stageVariables.stage2", + ], + ) + }, + Route.API, + ), + ( # test openapi3 (http api event) + { + "Authorizers": { + "mycoolauthorizer": { + "Identity": { + "QueryStrings": ["query1", "query2"], + "Headers": ["header1", "header2"], + "Context": ["context1", "context2"], + "StageVariables": ["stage1", "stage2"], + }, + "AuthorizerPayloadFormatVersion": "2.0", + "EnableSimpleResponses": True, + "FunctionArn": "will_be_mocked", + } + } + }, + { + "mycoolauthorizer": LambdaAuthorizer( + payload_version="2.0", + authorizer_name="mycoolauthorizer", + type="request", + lambda_name=ANY, + use_simple_response=True, + identity_sources=[ + "$request.header.header1", + "$request.header.header2", + "$request.querystring.query1", + "$request.querystring.query2", + "$context.context1", + "$context.context2", + "$stageVariables.stage1", + "$stageVariables.stage2", + ], + ) + }, + Route.HTTP, + ), + ] + ) + @patch("samcli.commands.local.lib.swagger.integration_uri.LambdaUri.get_function_name") + def test_extract_lambda_authorizers_from_properties( + self, properties, expected_authorizers, event_type, function_name_mock + ): + logical_id = Mock() + + function_name_mock.return_value = Mock() + + collector_mock = Mock() + collector_mock.add_authorizers = Mock() + + SamApiProvider._extract_authorizers_from_props(logical_id, properties, collector_mock, event_type) + + collector_mock.add_authorizers.assert_called_with(logical_id, expected_authorizers) + + @parameterized.expand( + [ + ( # missing function arn + { + "Authorizers": { + "mycoolauthorizer": { + "FunctionPayloadType": "TOKEN", + "Identity": { + "Header": "myheader", + }, + } + } + }, + ), + ( # invalid (blank) function arn + { + "Authorizers": { + "mycoolauthorizer": { + "FunctionPayloadType": "TOKEN", + "Identity": { + "Header": "myheader", + }, + "FunctionArn": "", + } + } + }, + ), + ( # not a token or request authorizer + { + "Authorizers": { + "mycoolauthorizer": { + "FunctionPayloadType": "TOKEN", + "Identity": { + "Header": "myheader", + }, + "FunctionArn": "function", + "FunctionPayloadType": "hello world", + } + } + }, + ), + ] + ) + @patch("samcli.commands.local.lib.swagger.integration_uri.LambdaUri.get_function_name") + def test_extract_invalid_authorizers_from_properties(self, properties, function_name_mock): + logical_id = Mock() + + function_name_mock.return_value = Mock() + + collector_mock = Mock() + collector_mock.add_authorizers = Mock() + + SamApiProvider._extract_authorizers_from_props(logical_id, properties, collector_mock, Route.API) + + collector_mock.add_authorizers.assert_called_with(logical_id, {}) + + @parameterized.expand( + [ + ( # wrong payload type + { + "FunctionPayloadType": "REQUEST", + "Identity": { + "Header": "myheader", + }, + "AuthorizerPayloadFormatVersion": True, + }, + "'AuthorizerPayloadFormatVersion' must be of type string for Lambda Authorizer 'auth'.", + ), + ( # missing payload format version + { + "FunctionPayloadType": "REQUEST", + "Identity": { + "Header": "myheader", + }, + }, + "Lambda Authorizer 'auth' must contain a valid 'AuthorizerPayloadFormatVersion' for HTTP APIs.", + ), + ( # invalid payload format version + { + "FunctionPayloadType": "REQUEST", + "Identity": { + "Header": "myheader", + }, + "AuthorizerPayloadFormatVersion": "invalid", + }, + "Lambda Authorizer 'auth' must contain a valid 'AuthorizerPayloadFormatVersion' for HTTP APIs.", + ), + ( # simple responses using wrong format version + { + "FunctionPayloadType": "REQUEST", + "Identity": { + "Header": "myheader", + }, + "AuthorizerPayloadFormatVersion": "1.0", + "EnableSimpleResponses": True, + }, + "EnableSimpleResponses must be used with the 2.0 payload format version in Lambda Authorizer 'auth'.", + ), + ] + ) + def test_extract_invalid_http_authorizer_throws_exception(self, properties, expected_ex): + with self.assertRaisesRegex(InvalidSamDocumentException, expected_ex): + SamApiProvider._extract_request_lambda_authorizer("auth", "lambda", Mock(), properties, Route.HTTP) + + @parameterized.expand( + [ + ({"Auth": {"Authorizer": "myauth"}}, "myauth", True), # defined auth + ({"Auth": {"Authorizer": "NONE"}}, None, False), # explict no authorizers + ({}, None, True), # default auth + ] + ) + def test_add_authorizer_in_serverless_function(self, authorizer_obj, expected_auth_name, use_default): + properties = {"Path": "path", "Method": "method", "RestApiId": "id"} + + if authorizer_obj: + properties.update(authorizer_obj) + + _, route = SamApiProvider._convert_event_route(Mock(), Mock(), properties, Route.API) + + self.assertEqual( + route, Route(ANY, ANY, ["method"], ANY, ANY, ANY, ANY, ANY, expected_auth_name, ANY, use_default) + ) + + def make_swagger(routes, binary_media_types=None): """ Given a list of API configurations named tuples, returns a Swagger document @@ -1845,7 +2146,7 @@ def make_swagger(routes, binary_media_types=None): Swagger document """ - swagger = {"paths": {}} + swagger = {"paths": {}, "swagger": "2.0"} for api in routes: swagger["paths"].setdefault(api.path, {}) diff --git a/tests/unit/commands/local/lib/validators/__init__.py b/tests/unit/commands/local/lib/validators/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/commands/local/lib/validators/test_lambda_auth_props.py b/tests/unit/commands/local/lib/validators/test_lambda_auth_props.py new file mode 100644 index 0000000000..5be500cfc4 --- /dev/null +++ b/tests/unit/commands/local/lib/validators/test_lambda_auth_props.py @@ -0,0 +1,281 @@ +from unittest import TestCase +from unittest.mock import patch +from parameterized import parameterized + +from samcli.commands.local.cli_common.user_exceptions import InvalidSamTemplateException +from samcli.commands.local.lib.validators.lambda_auth_props import ( + LambdaAuthorizerV1Validator, + LambdaAuthorizerV2Validator, +) + + +class TestLambdaAuthorizerV1Validator(TestCase): + @patch("samcli.commands.local.lib.swagger.integration_uri.LambdaUri.get_function_name") + def test_valid_v1_properties(self, function_mock): + logical_id = "id" + properties = { + "Properties": { + "Type": "REQUEST", + "RestApiId": "my-rest-api", + "Name": "my-auth-name", + "AuthorizerUri": "arn", + "IdentitySource": "method.request.header.auth, method.request.querystring.abc", + } + } + + # mock ARN resolving function + auth_lambda_func_name = "my-lambda" + function_mock.return_value = auth_lambda_func_name + + self.assertTrue(LambdaAuthorizerV1Validator.validate(logical_id, properties)) + + @parameterized.expand( + [ + ( # test no type + {"Properties": {}}, + "Authorizer 'my-auth-id' is missing the 'Type' property, an Authorizer type must be defined.", + ), + ( # test no rest api id + {"Properties": {"Type": "TOKEN"}}, + "Authorizer 'my-auth-id' is missing the 'RestApiId' property, this must be defined.", + ), + ( # test no name + {"Properties": {"Type": "TOKEN", "RestApiId": "restapiid"}}, + "Authorizer 'my-auth-id' is missing the 'Name' property, the Name must be defined.", + ), + ( # test no authorizer uri + {"Properties": {"Type": "TOKEN", "RestApiId": "restapiid", "Name": "myauth"}}, + "Authorizer 'my-auth-id' is missing the 'AuthorizerUri' property, a valid Lambda ARN must be provided.", + ), + ( # test invalid identity source (missing) + {"Properties": {"Type": "TOKEN", "RestApiId": "restapiid", "Name": "myauth", "AuthorizerUri": "arn"}}, + "Lambda Authorizer 'my-auth-id' of type TOKEN, must have 'IdentitySource' of type string defined.", + ), + ( # test invalid identity source (must be str) + { + "Properties": { + "Type": "TOKEN", + "RestApiId": "restapiid", + "Name": "myauth", + "AuthorizerUri": "arn", + "IdentitySource": {}, + } + }, + "Lambda Authorizer 'my-auth-id' contains an invalid 'IdentitySource', it must be a comma-separated string.", + ), + ( # test request type using validation + { + "Properties": { + "Type": "REQUEST", + "RestApiId": "restapiid", + "Name": "myauth", + "AuthorizerUri": "arn", + "IdentityValidationExpression": "123", + } + }, + "Lambda Authorizer 'my-auth-id' has 'IdentityValidationExpression' property defined, but validation is only supported on TOKEN type authorizers.", + ), + ] + ) + @patch("samcli.commands.local.lib.swagger.integration_uri.LambdaUri.get_function_name") + def test_invalid_v1_lamabda_authorizers(self, resource, expected_exception_message, get_func_name_mock): + lambda_auth_logical_id = "my-auth-id" + + # mock ARN resolving function + auth_lambda_func_name = "my-lambda" + get_func_name_mock.return_value = auth_lambda_func_name + + with self.assertRaisesRegex(InvalidSamTemplateException, expected_exception_message): + LambdaAuthorizerV1Validator.validate(lambda_auth_logical_id, resource) + + @patch("samcli.commands.local.lib.swagger.integration_uri.LambdaUri.get_function_name") + def test_invalid_v1_skip_invalid_type(self, get_func_name_mock): + properties = {"Properties": {"Type": "_-_-_", "RestApiId": "restapiid", "Name": "myauth"}} + lambda_auth_logical_id = "my-auth-id" + + # mock ARN resolving function + auth_lambda_func_name = "my-lambda" + get_func_name_mock.return_value = auth_lambda_func_name + + self.assertFalse(LambdaAuthorizerV1Validator.validate(lambda_auth_logical_id, properties)) + + @patch("samcli.commands.local.lib.swagger.integration_uri.LambdaUri.get_function_name") + def test_invalid_v1_skip_invalid_arn(self, get_func_name_mock): + properties = { + "Properties": {"Type": "TOKEN", "RestApiId": "restapiid", "Name": "myauth", "AuthorizerUri": "arn"} + } + lambda_auth_logical_id = "my-auth-id" + + # mock ARN resolving function to return None + get_func_name_mock.return_value = None + + self.assertFalse(LambdaAuthorizerV1Validator.validate(lambda_auth_logical_id, properties)) + + +class TestLambdaAuthorizerV2Validator(TestCase): + @parameterized.expand( + [ + ( # authorizer with 2.0 payload and simple responses + { + "Properties": { + "AuthorizerType": "REQUEST", + "ApiId": "my-rest-api", + "Name": "my-auth-name", + "AuthorizerUri": "arn", + "IdentitySource": ["$request.header.auth", "$context.something"], + "AuthorizerPayloadFormatVersion": "2.0", + "EnableSimpleResponses": True, + } + }, + ), + ( # authorizer with 2.0 payload and NO simple responses + { + "Properties": { + "AuthorizerType": "REQUEST", + "ApiId": "my-rest-api", + "Name": "my-auth-name", + "AuthorizerUri": "arn", + "IdentitySource": ["$request.header.auth", "$context.something"], + "AuthorizerPayloadFormatVersion": "2.0", + "EnableSimpleResponses": False, + } + }, + ), + ( # authorizer with 1.0 payload and NO simple responses + { + "Properties": { + "AuthorizerType": "REQUEST", + "ApiId": "my-rest-api", + "Name": "my-auth-name", + "AuthorizerUri": "arn", + "IdentitySource": ["$request.header.auth", "$context.something"], + "AuthorizerPayloadFormatVersion": "1.0", + } + }, + ), + ( # authorizer with missing payload version + { + "Properties": { + "AuthorizerType": "REQUEST", + "ApiId": "my-rest-api", + "Name": "my-auth-name", + "AuthorizerUri": "arn", + "IdentitySource": ["$request.header.auth", "$context.something"], + } + }, + ), + ] + ) + @patch("samcli.commands.local.lib.swagger.integration_uri.LambdaUri.get_function_name") + def test_valid_v2_properties(self, properties, function_mock): + logical_id = "id" + + # mock ARN resolving function + auth_lambda_func_name = "my-lambda" + function_mock.return_value = auth_lambda_func_name + + self.assertTrue(LambdaAuthorizerV2Validator.validate(logical_id, properties)) + + @parameterized.expand( + [ + ( # test no type + {"Properties": {}}, + "Authorizer 'my-auth-id' is missing the 'AuthorizerType' property, an Authorizer type must be defined.", + ), + ( # test no rest api id + {"Properties": {"AuthorizerType": "REQUEST"}}, + "Authorizer 'my-auth-id' is missing the 'ApiId' property, this must be defined.", + ), + ( # test no name + {"Properties": {"AuthorizerType": "REQUEST", "ApiId": "restapiid"}}, + "Authorizer 'my-auth-id' is missing the 'Name' property, the Name must be defined.", + ), + ( # test no authorizer uri + {"Properties": {"AuthorizerType": "REQUEST", "ApiId": "restapiid", "Name": "myauth"}}, + "Authorizer 'my-auth-id' is missing the 'AuthorizerUri' property, a valid Lambda ARN must be provided.", + ), + ( # test invalid identity source (missing) + { + "Properties": { + "AuthorizerType": "REQUEST", + "ApiId": "restapiid", + "Name": "myauth", + "AuthorizerUri": "arn", + } + }, + "Lambda Authorizer 'my-auth-id' must have 'IdentitySource' of type list defined.", + ), + ( # test invalid identity source (must be list) + { + "Properties": { + "AuthorizerType": "REQUEST", + "ApiId": "restapiid", + "Name": "myauth", + "AuthorizerUri": "arn", + "IdentitySource": "hello world, im not a list", + } + }, + "Lambda Authorizer 'my-auth-id' must have 'IdentitySource' of type list defined.", + ), + ( # test invalid payload version + { + "Properties": { + "AuthorizerType": "REQUEST", + "ApiId": "restapiid", + "Name": "myauth", + "AuthorizerUri": "arn", + "IdentitySource": [], + "AuthorizerPayloadFormatVersion": "1.2.3", + } + }, + "Lambda Authorizer 'my-auth-id' contains an invalid 'AuthorizerPayloadFormatVersion', it must be set to '1.0' or '2.0'", + ), + ( # test using simple response but wrong payload version + { + "Properties": { + "AuthorizerType": "REQUEST", + "ApiId": "restapiid", + "Name": "myauth", + "AuthorizerUri": "arn", + "IdentitySource": [], + "AuthorizerPayloadFormatVersion": "1.0", + "EnableSimpleResponses": True, + } + }, + "'EnableSimpleResponses' is only supported for '2.0' payload format versions for Lambda Authorizer 'my-auth-id'.", + ), + ] + ) + @patch("samcli.commands.local.lib.swagger.integration_uri.LambdaUri.get_function_name") + def test_invalid_v2_lamabda_authorizers(self, resource, expected_exception_message, get_func_name_mock): + lambda_auth_logical_id = "my-auth-id" + + # mock ARN resolving function + auth_lambda_func_name = "my-lambda" + get_func_name_mock.return_value = auth_lambda_func_name + + with self.assertRaisesRegex(InvalidSamTemplateException, expected_exception_message): + LambdaAuthorizerV2Validator.validate(lambda_auth_logical_id, resource) + + @patch("samcli.commands.local.lib.swagger.integration_uri.LambdaUri.get_function_name") + def test_invalid_v2_skip_invalid_type(self, get_func_name_mock): + properties = {"Properties": {"AuthorizerType": "TOKEN", "ApiId": "restapiid", "Name": "myauth"}} + lambda_auth_logical_id = "my-auth-id" + + # mock ARN resolving function + auth_lambda_func_name = "my-lambda" + get_func_name_mock.return_value = auth_lambda_func_name + + self.assertFalse(LambdaAuthorizerV2Validator.validate(lambda_auth_logical_id, properties)) + + @patch("samcli.commands.local.lib.swagger.integration_uri.LambdaUri.get_function_name") + def test_invalid_v2_skip_invalid_arn(self, get_func_name_mock): + properties = { + "Properties": {"AuthorizerType": "REQUEST", "ApiId": "restapiid", "Name": "myauth", "AuthorizerUri": "arn"} + } + lambda_auth_logical_id = "my-auth-id" + + # mock ARN resolving function to return None + get_func_name_mock.return_value = None + + self.assertFalse(LambdaAuthorizerV2Validator.validate(lambda_auth_logical_id, properties)) diff --git a/tests/unit/hook_packages/terraform/hooks/prepare/prepare_base.py b/tests/unit/hook_packages/terraform/hooks/prepare/prepare_base.py index 6d6416801b..ee226d532c 100644 --- a/tests/unit/hook_packages/terraform/hooks/prepare/prepare_base.py +++ b/tests/unit/hook_packages/terraform/hooks/prepare/prepare_base.py @@ -7,6 +7,9 @@ from samcli.lib.utils.resources import ( AWS_LAMBDA_FUNCTION as CFN_AWS_LAMBDA_FUNCTION, AWS_LAMBDA_LAYERVERSION, + AWS_APIGATEWAY_RESOURCE, + AWS_APIGATEWAY_RESTAPI, + AWS_APIGATEWAY_STAGE, ) @@ -34,6 +37,10 @@ def setUp(self) -> None: self.image_function_name = "image_func" self.lambda_layer_name = "lambda_layer" + self.apigw_resource_name = "my_resource" + self.apigw_stage_name = "my_stage" + self.apigw_rest_api_name = "my_rest_api" + self.tf_function_common_properties: dict = { "function_name": self.zip_function_name, "architectures": ["x86_64"], @@ -298,6 +305,21 @@ def setUp(self) -> None: "FunctionName": self.zip_function_name_4, } + self.tf_apigw_resource_common_attributes: dict = { + "type": "aws_api_gateway_resource", + "provider_name": AWS_PROVIDER_NAME, + } + + self.tf_apigw_stage_common_attributes: dict = { + "type": "aws_api_gateway_stage", + "provider_name": AWS_PROVIDER_NAME, + } + + self.tf_apigw_rest_api_common_attributes: dict = { + "type": "aws_api_gateway_rest_api", + "provider_name": AWS_PROVIDER_NAME, + } + self.tf_lambda_function_resource_common_attributes: dict = { "type": "aws_lambda_function", "provider_name": AWS_PROVIDER_NAME, @@ -467,6 +489,95 @@ def setUp(self) -> None: "name": "s3_lambda_code_2", } + self.tf_apigw_resource_properties: dict = { + "rest_api_id": "aws_api_gateway_rest_api.MyDemoAPI.id", + "parent_id": "aws_api_gateway_rest_api.MyDemoAPI.root_resource_id", + "path_part": "mydemoresource", + } + + self.expected_cfn_apigw_resource_properties: dict = { + "RestApiId": "aws_api_gateway_rest_api.MyDemoAPI.id", + "ParentId": "aws_api_gateway_rest_api.MyDemoAPI.root_resource_id", + "PathPart": "mydemoresource", + } + + self.tf_apigw_resource_resource: dict = { + **self.tf_apigw_resource_common_attributes, + "values": self.tf_apigw_resource_properties, + "address": f"aws_api_gateway_resource.{self.apigw_resource_name}", + "name": self.apigw_resource_name, + } + + self.expected_cfn_apigw_resource: dict = { + "Type": AWS_APIGATEWAY_RESOURCE, + "Properties": self.expected_cfn_apigw_resource_properties, + "Metadata": {"SamResourceId": f"aws_api_gateway_resource.{self.apigw_resource_name}"}, + } + + self.tf_apigw_stage_properties: dict = { + "rest_api_id": "aws_api_gateway_rest_api.MyDemoAPI.id", + "stage_name": "test", + "variables": {"key1": "value1"}, + } + + self.expected_cfn_apigw_stage_properties: dict = { + "RestApiId": "aws_api_gateway_rest_api.MyDemoAPI.id", + "StageName": "test", + "Variables": {"key1": "value1"}, + } + + self.tf_apigw_stage_resource: dict = { + **self.tf_apigw_stage_common_attributes, + "values": self.tf_apigw_stage_properties, + "address": f"aws_api_gateway_stage.{self.apigw_stage_name}", + "name": self.apigw_stage_name, + } + + self.expected_cfn_apigw_stage_resource: dict = { + "Type": AWS_APIGATEWAY_STAGE, + "Properties": self.expected_cfn_apigw_stage_properties, + "Metadata": {"SamResourceId": f"aws_api_gateway_stage.{self.apigw_stage_name}"}, + } + + self.tf_apigw_rest_api_properties: dict = { + "name": self.apigw_rest_api_name, + "body": { + "openapi": "3.0.1", + "info": { + "title": "example", + "version": "1.0", + }, + }, + "parameters": {"param_a": "value_a"}, + "binary_media_types": ["utf-8"], + } + + self.expected_cfn_apigw_rest_api_properties: dict = { + "Name": self.apigw_rest_api_name, + "Body": { + "openapi": "3.0.1", + "info": { + "title": "example", + "version": "1.0", + }, + }, + "Parameters": {"param_a": "value_a"}, + "BinaryMediaTypes": ["utf-8"], + } + + self.tf_apigw_rest_api_resource: dict = { + **self.tf_apigw_rest_api_common_attributes, + "values": self.tf_apigw_rest_api_properties, + "address": f"aws_api_gateway_rest_api.{self.apigw_rest_api_name}", + "name": self.apigw_rest_api_name, + } + + self.expected_cfn_apigw_rest_api: dict = { + "Type": AWS_APIGATEWAY_RESTAPI, + "Properties": self.expected_cfn_apigw_rest_api_properties, + "Metadata": {"SamResourceId": f"aws_api_gateway_rest_api.{self.apigw_rest_api_name}"}, + } + self.tf_json_with_root_module_only: dict = { "planned_values": { "root_module": { @@ -474,6 +585,9 @@ def setUp(self) -> None: self.tf_lambda_function_resource_zip, self.tf_lambda_function_resource_zip_2, self.tf_image_package_type_lambda_function_resource, + self.tf_apigw_resource_resource, + self.tf_apigw_rest_api_resource, + self.tf_apigw_stage_resource, ] } } @@ -484,6 +598,9 @@ def setUp(self) -> None: f"AwsLambdaFunctionMyfunc{self.mock_logical_id_hash}": self.expected_cfn_lambda_function_resource_zip, f"AwsLambdaFunctionMyfunc2{self.mock_logical_id_hash}": self.expected_cfn_lambda_function_resource_zip_2, f"AwsLambdaFunctionImageFunc{self.mock_logical_id_hash}": self.expected_cfn_image_package_type_lambda_function_resource, + f"AwsApiGatewayResourceMyResource{self.mock_logical_id_hash}": self.expected_cfn_apigw_resource, + f"AwsApiGatewayRestApiMyRestApi{self.mock_logical_id_hash}": self.expected_cfn_apigw_rest_api, + f"AwsApiGatewayStageMyStage{self.mock_logical_id_hash}": self.expected_cfn_apigw_stage_resource, }, } diff --git a/tests/unit/hook_packages/terraform/hooks/prepare/resources/__init__.py b/tests/unit/hook_packages/terraform/hooks/prepare/resources/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/hook_packages/terraform/hooks/prepare/resources/test_apigw.py b/tests/unit/hook_packages/terraform/hooks/prepare/resources/test_apigw.py new file mode 100644 index 0000000000..356c62ea79 --- /dev/null +++ b/tests/unit/hook_packages/terraform/hooks/prepare/resources/test_apigw.py @@ -0,0 +1,42 @@ +from unittest import TestCase +from unittest.mock import Mock, patch + +from parameterized import parameterized + +from samcli.hook_packages.terraform.hooks.prepare.exceptions import OpenAPIBodyNotSupportedException +from samcli.hook_packages.terraform.hooks.prepare.resources.apigw import ( + _unsupported_reference_field, + RESTAPITranslationValidator, +) +from samcli.hook_packages.terraform.hooks.prepare.types import References, TFResource, ConstantValue + + +class TestRESTAPITranslationValidator(TestCase): + @patch("samcli.hook_packages.terraform.hooks.prepare.resources.apigw._unsupported_reference_field") + def test_validate_valid(self, mock_unsupported_reference_field): + mock_unsupported_reference_field.return_value = False + validator = RESTAPITranslationValidator({}, TFResource("address", "", Mock(), {})) + validator.validate() + + @patch("samcli.hook_packages.terraform.hooks.prepare.resources.apigw._unsupported_reference_field") + def test_validate_invalid(self, mock_unsupported_reference_field): + mock_unsupported_reference_field.return_value = True + validator = RESTAPITranslationValidator({}, TFResource("address", "", Mock(), {})) + with self.assertRaises(OpenAPIBodyNotSupportedException) as ex: + validator.validate() + self.assertIn( + "AWS SAM CLI is unable to process a Terraform project that " + "uses an OpenAPI specification to define the API Gateway resource.", + ex.exception.message, + ) + + @parameterized.expand( + [ + ({"field": "a"}, TFResource("address", "", Mock(), {}), False), + ({}, TFResource("address", "", Mock(), {"field": ConstantValue("a")}), False), + ({}, TFResource("address", "", Mock(), {"field": References(["a"])}), True), + ] + ) + def test_unsupported_reference_field(self, resource, config_resource, expected): + result = _unsupported_reference_field("field", resource, config_resource) + self.assertEqual(result, expected) diff --git a/tests/unit/hook_packages/terraform/hooks/prepare/test_translate.py b/tests/unit/hook_packages/terraform/hooks/prepare/test_translate.py index c6253603d1..94f374b3ef 100644 --- a/tests/unit/hook_packages/terraform/hooks/prepare/test_translate.py +++ b/tests/unit/hook_packages/terraform/hooks/prepare/test_translate.py @@ -6,6 +6,10 @@ from samcli.hook_packages.terraform.hooks.prepare.property_builder import ( AWS_LAMBDA_FUNCTION_PROPERTY_BUILDER_MAPPING, REMOTE_DUMMY_VALUE, + AWS_API_GATEWAY_RESOURCE_PROPERTY_BUILDER_MAPPING, + AWS_API_GATEWAY_REST_API_PROPERTY_BUILDER_MAPPING, + AWS_API_GATEWAY_STAGE_PROPERTY_BUILDER_MAPPING, + TF_AWS_API_GATEWAY_REST_API, ) from samcli.hook_packages.terraform.hooks.prepare.types import ( SamMetadataResource, @@ -87,9 +91,16 @@ def test_translate_to_cfn_with_root_module_only( config_resource = Mock() resources_mock.__getitem__.return_value = config_resource resources_mock.__contains__.return_value = True + mock_validator = Mock() mock_build_module.return_value = root_module checksum_mock.return_value = self.mock_logical_id_hash - translated_cfn_dict = translate_to_cfn(self.tf_json_with_root_module_only, self.output_dir, self.project_root) + with patch( + "samcli.hook_packages.terraform.hooks.prepare.translate.TRANSLATION_VALIDATORS", + {TF_AWS_API_GATEWAY_REST_API: mock_validator}, + ): + translated_cfn_dict = translate_to_cfn( + self.tf_json_with_root_module_only, self.output_dir, self.project_root + ) self.assertEqual(translated_cfn_dict, self.expected_cfn_with_root_module_only) mock_enrich_resources_and_generate_makefile.assert_not_called() lambda_functions = dict( @@ -105,6 +116,10 @@ def test_translate_to_cfn_with_root_module_only( ] mock_link_lambda_functions_to_layers.assert_called_once_with(*expected_arguments_in_call) mock_get_configuration_address.assert_called() + mock_validator.assert_called_once_with( + resource=self.tf_apigw_rest_api_resource, config_resource=config_resource + ) + mock_validator.return_value.validate.assert_called_once() @patch("samcli.hook_packages.terraform.hooks.prepare.translate._resolve_resource_attribute") @patch("samcli.hook_packages.terraform.hooks.prepare.translate._check_dummy_remote_values") @@ -1035,3 +1050,21 @@ def test_get_s3_object_hash(self): self.assertNotEqual( _get_s3_object_hash(self.s3_bucket, self.s3_key), _get_s3_object_hash(self.s3_bucket, self.s3_key_2) ) + + def test_translating_apigw_resource(self): + translated_cfn_properties = _translate_properties( + self.tf_apigw_resource_properties, AWS_API_GATEWAY_RESOURCE_PROPERTY_BUILDER_MAPPING, Mock() + ) + self.assertEqual(translated_cfn_properties, self.expected_cfn_apigw_resource_properties) + + def test_translating_apigw_stage_resource(self): + translated_cfn_properties = _translate_properties( + self.tf_apigw_stage_properties, AWS_API_GATEWAY_STAGE_PROPERTY_BUILDER_MAPPING, Mock() + ) + self.assertEqual(translated_cfn_properties, self.expected_cfn_apigw_stage_properties) + + def test_translating_apigw_rest_api(self): + translated_cfn_properties = _translate_properties( + self.tf_apigw_rest_api_properties, AWS_API_GATEWAY_REST_API_PROPERTY_BUILDER_MAPPING, Mock() + ) + self.assertEqual(translated_cfn_properties, self.expected_cfn_apigw_rest_api_properties) diff --git a/tests/unit/lib/telemetry/test_event.py b/tests/unit/lib/telemetry/test_event.py index 853ee63dba..5738743788 100644 --- a/tests/unit/lib/telemetry/test_event.py +++ b/tests/unit/lib/telemetry/test_event.py @@ -7,6 +7,7 @@ from typing import List, Tuple from unittest import TestCase from unittest.mock import ANY, Mock, patch +from samcli.cli.context import Context from samcli.lib.telemetry.event import Event, EventCreationError, EventTracker, track_long_event @@ -64,7 +65,13 @@ def test_event_to_json(self, name_mock, type_mock, verify_mock): self.assertEqual( test_event.to_json(), - {"event_name": "Testing", "event_value": "value1", "thread_id": threading.get_ident(), "time_stamp": ANY}, + { + "event_name": "Testing", + "event_value": "value1", + "thread_id": threading.get_ident(), + "time_stamp": ANY, + "exception_name": None, + }, ) @@ -143,6 +150,7 @@ def test_events_get_sent(self, telemetry_mock): "event_value": "SomeValue", "thread_id": ANY, "time_stamp": ANY, + "exception_name": ANY, } ] }, @@ -177,6 +185,17 @@ def make_mock_event(name, value): send_mock.assert_called() + @patch("samcli.cli.context.Context.get_current_context") + def test_session_id_set(self, context_mock): + mock = Mock() + mock.session_id = "123" + context_mock.return_value = mock + + EventTracker._session_id = None + EventTracker._set_session_id() + + self.assertEqual(EventTracker._session_id, "123") + class TestTrackLongEvent(TestCase): @patch("samcli.lib.telemetry.event.EventTracker.send_events") diff --git a/tests/unit/lib/telemetry/test_metric.py b/tests/unit/lib/telemetry/test_metric.py index 5f14c81ea4..22f695eebe 100644 --- a/tests/unit/lib/telemetry/test_metric.py +++ b/tests/unit/lib/telemetry/test_metric.py @@ -496,20 +496,38 @@ def setUp(self): def tearDown(self): pass - @parameterized.expand([(CICDPlatform.Appveyor, "Appveyor", "ci"), (None, "CLI", False)]) + @parameterized.expand( + [ + (CICDPlatform.Appveyor, "Appveyor", "ci", None), + (None, "CLI", False, None), + (None, "CLI", False, "AWS-Toolkit-For-VSCode/1.62.0"), + ] + ) @patch("samcli.lib.telemetry.metric.CICDDetector.platform") + @patch("samcli.lib.telemetry.metric.get_user_agent_string") @patch("samcli.lib.telemetry.metric.platform") @patch("samcli.lib.telemetry.metric.Context") @patch("samcli.lib.telemetry.metric.GlobalConfig") @patch("samcli.lib.telemetry.metric.uuid") def test_must_add_common_attributes( - self, cicd_platform, execution_env, ci, uuid_mock, gc_mock, context_mock, platform_mock, cicd_platform_mock + self, + cicd_platform, + execution_env, + ci, + user_agent, + uuid_mock, + gc_mock, + context_mock, + platform_mock, + get_user_agent_mock, + cicd_platform_mock, ): request_id = uuid_mock.uuid4.return_value = "fake requestId" installation_id = gc_mock.return_value.installation_id = "fake installation id" session_id = context_mock.get_current_context.return_value.session_id = "fake installation id" python_version = platform_mock.python_version.return_value = "8.8.0" cicd_platform_mock.return_value = cicd_platform + get_user_agent_mock.return_value = user_agent metric = Metric("metric_name") @@ -518,6 +536,8 @@ def test_must_add_common_attributes( assert metric.get_data()["sessionId"] == session_id assert metric.get_data()["executionEnvironment"] == execution_env assert metric.get_data()["ci"] == bool(ci) + if user_agent: + assert metric.get_data()["userAgent"] == user_agent assert metric.get_data()["pyversion"] == python_version assert metric.get_data()["samcliVersion"] == samcli.__version__ diff --git a/tests/unit/lib/telemetry/test_user_agent.py b/tests/unit/lib/telemetry/test_user_agent.py new file mode 100644 index 0000000000..6fb14a83dd --- /dev/null +++ b/tests/unit/lib/telemetry/test_user_agent.py @@ -0,0 +1,45 @@ +from unittest import TestCase +from unittest.mock import patch + +from parameterized import parameterized + +from samcli.lib.telemetry.user_agent import USER_AGENT_ENV_VAR, get_user_agent_string + + +class TestUserAgent(TestCase): + @parameterized.expand( + [ + ("AWS_Toolkit-For-VSCode/1.62.0",), + ("AWS-Toolkit-For-JetBrains/1.60-223",), + ("AWS-Toolkit-For-JetBrains/1.60.0-223",), + ("AWS-Toolkit-For-JetBrains0/1.60.0-223",), + ("AWS-Toolkit-For-JetBrains/1.60.0-2230",), + ] + ) + def test_user_agent(self, agent_value): + with patch("samcli.lib.telemetry.user_agent.os.environ", {USER_AGENT_ENV_VAR: agent_value}): + self.assertEqual(get_user_agent_string(), agent_value) + + @parameterized.expand( + [ + ("invalid_value",), # not matching the format at all + ("AWS_Toolkit-For-VSCode/1",), # not matching semver version + ("AWS_Toolkit-For-V$Code/1.1.0",), # invalid char in the name + ("AWS_Toolkit-For-VSCode/1.1.0-patch$",), # invalid char in the version + # too long product name (> 64) + ("AWS_Toolkit-For-VSCodeAWS_Toolkit-For-VSCodeAWS_Toolkit-For-VSCode/1.1.0-patch$",), + # too long version extension (> 16) + ("AWS_Toolkit-For-VSCode/1.1.0-patchpatchpatchpatch",), + ] + ) + def test_user_agent_with_invalid_value(self, agent_value): + with patch("samcli.lib.telemetry.user_agent.os.environ", {USER_AGENT_ENV_VAR: agent_value}): + self.assertEqual(get_user_agent_string(), None) + + @patch("samcli.lib.telemetry.user_agent.os.environ", {}) + def test_user_agent_without_env_var(self): + self.assertEqual(get_user_agent_string(), None) + + @patch("samcli.lib.telemetry.user_agent.os.environ", {USER_AGENT_ENV_VAR: ""}) + def test_user_agent_with_empty_env_var(self): + self.assertEqual(get_user_agent_string(), None) diff --git a/tests/unit/local/apigw/test_event_constructor.py b/tests/unit/local/apigw/test_event_constructor.py new file mode 100644 index 0000000000..6035a6efc9 --- /dev/null +++ b/tests/unit/local/apigw/test_event_constructor.py @@ -0,0 +1,357 @@ +import base64 +from datetime import datetime +import json +from time import time +from unittest import TestCase +from unittest.mock import Mock, patch +from parameterized import parameterized, param + +from samcli.local.apigw.event_constructor import ( + _event_headers, + _event_http_headers, + _query_string_params, + _query_string_params_v_2_0, + _should_base64_encode, + construct_v1_event, + construct_v2_event_http, +) +from samcli.local.apigw.local_apigw_service import LocalApigwService + + +class TestService_construct_event(TestCase): + def setUp(self): + self.request_mock = Mock() + self.request_mock.endpoint = "endpoint" + self.request_mock.path = "path" + self.request_mock.method = "GET" + self.request_mock.remote_addr = "190.0.0.0" + self.request_mock.host = "190.0.0.1" + self.request_mock.get_data.return_value = b"DATA!!!!" + query_param_args_mock = Mock() + query_param_args_mock.lists.return_value = {"query": ["params"]}.items() + self.request_mock.args = query_param_args_mock + headers_mock = Mock() + headers_mock.keys.return_value = ["Content-Type", "X-Test"] + headers_mock.get.side_effect = ["application/json", "Value"] + headers_mock.getlist.side_effect = [["application/json"], ["Value"]] + self.request_mock.headers = headers_mock + self.request_mock.view_args = {"path": "params"} + self.request_mock.scheme = "http" + environ_dict = {"SERVER_PROTOCOL": "HTTP/1.1"} + self.request_mock.environ = environ_dict + + expected = ( + '{"body": "DATA!!!!", "httpMethod": "GET", ' + '"multiValueQueryStringParameters": {"query": ["params"]}, ' + '"queryStringParameters": {"query": "params"}, "resource": ' + '"endpoint", "requestContext": {"httpMethod": "GET", "requestId": ' + '"c6af9ac6-7b61-11e6-9a41-93e8deadbeef", "path": "endpoint", "extendedRequestId": null, ' + '"resourceId": "123456", "apiId": "1234567890", "stage": null, "resourcePath": "endpoint", ' + '"identity": {"accountId": null, "apiKey": null, "userArn": null, ' + '"cognitoAuthenticationProvider": null, "cognitoIdentityPoolId": null, "userAgent": ' + '"Custom User Agent String", "caller": null, "cognitoAuthenticationType": null, "sourceIp": ' + '"190.0.0.0", "user": null}, "accountId": "123456789012", "domainName": "190.0.0.1", ' + '"protocol": "HTTP/1.1"}, "headers": {"Content-Type": ' + '"application/json", "X-Test": "Value", "X-Forwarded-Port": "3000", "X-Forwarded-Proto": "http"}, ' + '"multiValueHeaders": {"Content-Type": ["application/json"], "X-Test": ["Value"], ' + '"X-Forwarded-Port": ["3000"], "X-Forwarded-Proto": ["http"]}, ' + '"stageVariables": null, "path": "path", "pathParameters": {"path": "params"}, ' + '"isBase64Encoded": false}' + ) + + self.expected_dict = json.loads(expected) + + def validate_request_context_and_remove_request_time_data(self, event_json): + request_time = event_json["requestContext"].pop("requestTime", None) + request_time_epoch = event_json["requestContext"].pop("requestTimeEpoch", None) + + self.assertIsInstance(request_time, str) + parsed_request_time = datetime.strptime(request_time, "%d/%b/%Y:%H:%M:%S +0000") + self.assertIsInstance(parsed_request_time, datetime) + + self.assertIsInstance(request_time_epoch, int) + + def test_construct_event_with_data(self): + actual_event_json = construct_v1_event(self.request_mock, 3000, binary_types=[]) + self.validate_request_context_and_remove_request_time_data(actual_event_json) + + self.assertEqual(actual_event_json["body"], self.expected_dict["body"]) + + def test_construct_event_no_data(self): + self.request_mock.get_data.return_value = None + + actual_event_json = construct_v1_event(self.request_mock, 3000, binary_types=[]) + self.validate_request_context_and_remove_request_time_data(actual_event_json) + + self.assertEqual(actual_event_json["body"], None) + + @patch("samcli.local.apigw.event_constructor._should_base64_encode") + def test_construct_event_with_binary_data(self, should_base64_encode_patch): + should_base64_encode_patch.return_value = True + + binary_body = b"011000100110100101101110011000010111001001111001" # binary in binary + base64_body = base64.b64encode(binary_body).decode("utf-8") + + self.request_mock.get_data.return_value = binary_body + + actual_event_json = construct_v1_event(self.request_mock, 3000, binary_types=[]) + self.validate_request_context_and_remove_request_time_data(actual_event_json) + + self.assertEqual(actual_event_json["body"], base64_body) + self.assertEqual(actual_event_json["isBase64Encoded"], True) + + def test_event_headers_with_empty_list(self): + request_mock = Mock() + headers_mock = Mock() + headers_mock.keys.return_value = [] + request_mock.headers = headers_mock + request_mock.scheme = "http" + + actual_query_string = _event_headers(request_mock, "3000") + self.assertEqual( + actual_query_string, + ( + {"X-Forwarded-Proto": "http", "X-Forwarded-Port": "3000"}, + {"X-Forwarded-Proto": ["http"], "X-Forwarded-Port": ["3000"]}, + ), + ) + + def test_event_headers_with_non_empty_list(self): + request_mock = Mock() + headers_mock = Mock() + headers_mock.keys.return_value = ["Content-Type", "X-Test"] + headers_mock.get.side_effect = ["application/json", "Value"] + headers_mock.getlist.side_effect = [["application/json"], ["Value"]] + request_mock.headers = headers_mock + request_mock.scheme = "http" + + actual_query_string = _event_headers(request_mock, "3000") + self.assertEqual( + actual_query_string, + ( + { + "Content-Type": "application/json", + "X-Test": "Value", + "X-Forwarded-Proto": "http", + "X-Forwarded-Port": "3000", + }, + { + "Content-Type": ["application/json"], + "X-Test": ["Value"], + "X-Forwarded-Proto": ["http"], + "X-Forwarded-Port": ["3000"], + }, + ), + ) + + def test_query_string_params_with_empty_params(self): + request_mock = Mock() + query_param_args_mock = Mock() + query_param_args_mock.lists.return_value = {}.items() + request_mock.args = query_param_args_mock + + actual_query_string = _query_string_params(request_mock) + self.assertEqual(actual_query_string, ({}, {})) + + def test_query_string_params_with_param_value_being_empty_list(self): + request_mock = Mock() + query_param_args_mock = Mock() + query_param_args_mock.lists.return_value = {"param": []}.items() + request_mock.args = query_param_args_mock + + actual_query_string = _query_string_params(request_mock) + self.assertEqual(actual_query_string, ({"param": ""}, {"param": [""]})) + + def test_query_string_params_with_param_value_being_non_empty_list(self): + request_mock = Mock() + query_param_args_mock = Mock() + query_param_args_mock.lists.return_value = {"param": ["a", "b"]}.items() + request_mock.args = query_param_args_mock + + actual_query_string = _query_string_params(request_mock) + self.assertEqual(actual_query_string, ({"param": "b"}, {"param": ["a", "b"]})) + + def test_query_string_params_v_2_0_with_param_value_being_non_empty_list(self): + request_mock = Mock() + query_param_args_mock = Mock() + query_param_args_mock.lists.return_value = {"param": ["a", "b"]}.items() + request_mock.args = query_param_args_mock + + actual_query_string = _query_string_params_v_2_0(request_mock) + self.assertEqual(actual_query_string, {"param": "a,b"}) + + +class TestService_construct_event_http(TestCase): + def setUp(self): + self.request_mock = Mock() + self.request_mock.endpoint = "endpoint" + self.request_mock.method = "GET" + self.request_mock.path = "/endpoint" + self.request_mock.get_data.return_value = b"DATA!!!!" + self.request_mock.mimetype = "application/json" + query_param_args_mock = Mock() + query_param_args_mock.lists.return_value = {"query": ["param1", "param2"]}.items() + self.request_mock.args = query_param_args_mock + self.request_mock.query_string = b"query=params" + headers_mock = Mock() + headers_mock.keys.return_value = ["Content-Type", "X-Test"] + headers_mock.get.side_effect = ["application/json", "Value"] + headers_mock.getlist.side_effect = [["application/json"], ["Value"]] + self.request_mock.headers = headers_mock + self.request_mock.remote_addr = "190.0.0.0" + self.request_mock.view_args = {"path": "params"} + self.request_mock.scheme = "http" + cookies_mock = Mock() + cookies_mock.keys.return_value = ["cookie1", "cookie2"] + cookies_mock.get.side_effect = ["test", "test"] + self.request_mock.cookies = cookies_mock + self.request_time_epoch = int(time()) + self.request_time = datetime.utcnow().strftime("%d/%b/%Y:%H:%M:%S +0000") + + expected = f""" + {{ + "version": "2.0", + "routeKey": "GET /endpoint", + "rawPath": "/endpoint", + "rawQueryString": "query=params", + "cookies": ["cookie1=test", "cookie2=test"], + "headers": {{ + "Content-Type": "application/json", + "X-Test": "Value", + "X-Forwarded-Proto": "http", + "X-Forwarded-Port": "3000" + }}, + "queryStringParameters": {{"query": "param1,param2"}}, + "requestContext": {{ + "accountId": "123456789012", + "apiId": "1234567890", + "domainName": "localhost", + "domainPrefix": "localhost", + "http": {{ + "method": "GET", + "path": "/endpoint", + "protocol": "HTTP/1.1", + "sourceIp": "190.0.0.0", + "userAgent": "Custom User Agent String" + }}, + "requestId": "", + "routeKey": "GET /endpoint", + "stage": "$default", + "time": \"{self.request_time}\", + "timeEpoch": {self.request_time_epoch} + }}, + "body": "DATA!!!!", + "pathParameters": {{"path": "params"}}, + "stageVariables": null, + "isBase64Encoded": false + }} + """ + + self.expected_dict = json.loads(expected) + + def test_construct_event_with_data(self): + actual_event_dict = construct_v2_event_http( + self.request_mock, + 3000, + binary_types=[], + route_key="GET /endpoint", + request_time_epoch=self.request_time_epoch, + request_time=self.request_time, + ) + self.assertEqual(len(actual_event_dict["requestContext"]["requestId"]), 36) + actual_event_dict["requestContext"]["requestId"] = "" + self.assertEqual(actual_event_dict, self.expected_dict) + + def test_construct_event_no_data(self): + self.request_mock.get_data.return_value = None + self.expected_dict["body"] = None + + actual_event_dict = construct_v2_event_http( + self.request_mock, + 3000, + binary_types=[], + route_key="GET /endpoint", + request_time_epoch=self.request_time_epoch, + request_time=self.request_time, + ) + self.assertEqual(len(actual_event_dict["requestContext"]["requestId"]), 36) + actual_event_dict["requestContext"]["requestId"] = "" + self.assertEqual(actual_event_dict, self.expected_dict) + + def test_v2_route_key(self): + route_key = LocalApigwService._v2_route_key("GET", "/path", False) + self.assertEqual(route_key, "GET /path") + + def test_v2_default_route_key(self): + route_key = LocalApigwService._v2_route_key("GET", "/path", True) + self.assertEqual(route_key, "$default") + + @patch("samcli.local.apigw.event_constructor._should_base64_encode") + def test_construct_event_with_binary_data(self, should_base64_encode_patch): + should_base64_encode_patch.return_value = True + + binary_body = b"011000100110100101101110011000010111001001111001" # binary in binary + base64_body = base64.b64encode(binary_body).decode("utf-8") + + self.request_mock.get_data.return_value = binary_body + self.expected_dict["body"] = base64_body + self.expected_dict["isBase64Encoded"] = True + self.maxDiff = None + + actual_event_dict = construct_v2_event_http( + self.request_mock, + 3000, + binary_types=[], + route_key="GET /endpoint", + request_time_epoch=self.request_time_epoch, + request_time=self.request_time, + ) + self.assertEqual(len(actual_event_dict["requestContext"]["requestId"]), 36) + actual_event_dict["requestContext"]["requestId"] = "" + self.assertEqual(actual_event_dict, self.expected_dict) + + def test_event_headers_with_empty_list(self): + request_mock = Mock() + headers_mock = Mock() + headers_mock.keys.return_value = [] + request_mock.headers = headers_mock + request_mock.scheme = "http" + + actual_query_string = _event_http_headers(request_mock, "3000") + self.assertEqual(actual_query_string, {"X-Forwarded-Proto": "http", "X-Forwarded-Port": "3000"}) + + def test_event_headers_with_non_empty_list(self): + request_mock = Mock() + headers_mock = Mock() + headers_mock.keys.return_value = ["Content-Type", "X-Test"] + headers_mock.get.side_effect = ["application/json", "Value"] + headers_mock.getlist.side_effect = [["application/json"], ["Value"]] + request_mock.headers = headers_mock + request_mock.scheme = "http" + + actual_query_string = _event_http_headers(request_mock, "3000") + self.assertEqual( + actual_query_string, + { + "Content-Type": "application/json", + "X-Test": "Value", + "X-Forwarded-Proto": "http", + "X-Forwarded-Port": "3000", + }, + ) + + +class TestService_should_base64_encode(TestCase): + @parameterized.expand( + [ + param("Mimeyype is in binary types", ["image/gif"], "image/gif"), + param("Mimetype defined and binary types has */*", ["*/*"], "image/gif"), + param("*/* is in binary types with no mimetype defined", ["*/*"], None), + ] + ) + def test_should_base64_encode_returns_true(self, test_case_name, binary_types, mimetype): + self.assertTrue(_should_base64_encode(binary_types, mimetype)) + + @parameterized.expand([param("Mimetype is not in binary types", ["image/gif"], "application/octet-stream")]) + def test_should_base64_encode_returns_false(self, test_case_name, binary_types, mimetype): + self.assertFalse(_should_base64_encode(binary_types, mimetype)) diff --git a/tests/unit/local/apigw/test_lambda_authorizer.py b/tests/unit/local/apigw/test_lambda_authorizer.py new file mode 100644 index 0000000000..804a151521 --- /dev/null +++ b/tests/unit/local/apigw/test_lambda_authorizer.py @@ -0,0 +1,474 @@ +import json +from unittest import TestCase +from unittest.mock import Mock, patch +from parameterized import parameterized +from werkzeug.datastructures import Headers +from samcli.local.apigw.authorizers.lambda_authorizer import ( + ContextIdentitySource, + HeaderIdentitySource, + LambdaAuthorizer, + LambdaAuthorizerIAMPolicyValidator, + QueryIdentitySource, + StageVariableIdentitySource, +) +from samcli.local.apigw.exceptions import InvalidLambdaAuthorizerResponse, InvalidSecurityDefinition + + +class TestHeaderIdentitySource(TestCase): + def test_valid_header_identity_source(self): + id_source = "test" + header_id_source = HeaderIdentitySource(id_source) + + self.assertTrue(header_id_source.is_valid(**{"headers": Headers({id_source: 123})})) + + @parameterized.expand( + [ + ({"headers": Headers({})},), # test empty headers + ({},), # test no headers + ({"headers": Headers({"not here": 123})},), # test missing headers + ({"validation_expression": "^123$"},), # test no headers, but provided validation + ] + ) + def test_invalid_header_identity_source(self, sources_dict): + header_id_source = HeaderIdentitySource("test") + + self.assertFalse(header_id_source.is_valid(**sources_dict)) + + def test_validation_expression_passes(self): + id_source = "myheader" + args = {"headers": Headers({id_source: "123"}), "validation_expression": "^123$"} + + header_id_source = HeaderIdentitySource(id_source) + + self.assertTrue(header_id_source.is_valid(**args)) + + +class TestQueryIdentitySource(TestCase): + @parameterized.expand( + [ + ({"querystring": "foo=bar"}, "foo"), # test single pair + ({"querystring": "foo=bar&hello=world"}, "foo"), # test single pair + ] + ) + def test_valid_query_identity_source(self, sources_dict, id_source): + query_id_source = QueryIdentitySource(id_source) + + self.assertTrue(query_id_source.is_valid(**sources_dict)) + + @parameterized.expand( + [ + ({"querystring": ""}, "foo"), # test empty string + ({}, "foo"), # test missing string + ({"querystring": "hello=world"}, "foo"), # test nonexistant pair + ] + ) + def test_invalid_query_identity_source(self, sources_dict, id_source): + query_id_source = QueryIdentitySource(id_source) + + self.assertFalse(query_id_source.is_valid(**sources_dict)) + + +class TestContextIdentitySource(TestCase): + def test_valid_context_identity_source(self): + id_source = "test" + context_id_source = ContextIdentitySource(id_source) + + self.assertTrue(context_id_source.is_valid(**{"context": {id_source: 123}})) + + @parameterized.expand( + [ + ({"context": {}}, "test"), # test empty context + ({}, "test"), # test no context + ({"headers": {"not here": 123}}, "test"), # test missing context + ] + ) + def test_invalid_context_identity_source(self, sources_dict, id_source): + context_id_source = ContextIdentitySource(id_source) + + self.assertFalse(context_id_source.is_valid(**sources_dict)) + + +class TestStageVariableIdentitySource(TestCase): + def test_valid_stage_identity_source(self): + id_source = "test" + stage_id_source = StageVariableIdentitySource(id_source) + + self.assertTrue(stage_id_source.is_valid(**{"stageVariables": {id_source: 123}})) + + @parameterized.expand( + [ + ({"stageVariables": {}}, "test"), # test empty stageVariables + ({}, "test"), # test no stageVariables + ({"stageVariables": {"not here": 123}}, "test"), # test missing stageVariables + ] + ) + def test_invalid_stage_identity_source(self, sources_dict, id_source): + stage_id_source = StageVariableIdentitySource(id_source) + + self.assertFalse(stage_id_source.is_valid(**sources_dict)) + + +class TestLambdaAuthorizer(TestCase): + def test_parse_identity_sources(self): + identity_sources = [ + "method.request.header.v1header", + "$request.header.v2header", + "method.request.querystring.v1query", + "$request.querystring.v2query", + "context.v1context", + "$context.v2context", + "stageVariables.v1stage", + "$stageVariables.v2stage", + ] + + expected_sources = [ + HeaderIdentitySource("v1header"), + HeaderIdentitySource("v2header"), + QueryIdentitySource("v1query"), + QueryIdentitySource("v2query"), + ContextIdentitySource("v1context"), + ContextIdentitySource("v2context"), + StageVariableIdentitySource("v1stage"), + StageVariableIdentitySource("v2stage"), + ] + + lambda_auth = LambdaAuthorizer( + authorizer_name="auth_name", + type="type", + lambda_name="lambda_name", + identity_sources=identity_sources, + payload_version="version", + validation_string="string", + use_simple_response=True, + ) + + self.assertEqual(sorted(lambda_auth._identity_sources_raw), sorted(identity_sources)) + self.assertEqual(lambda_auth.identity_sources[0], expected_sources[0]) + + def test_parse_invalid_identity_sources_raises(self): + identity_sources = ["this is invalid"] + + with self.assertRaises(InvalidSecurityDefinition): + LambdaAuthorizer( + authorizer_name="auth_name", + type="type", + lambda_name="lambda_name", + identity_sources=identity_sources, + payload_version="version", + validation_string="string", + use_simple_response=True, + ) + + def test_response_validator_raises_exception(self): + auth_name = "my auth" + + with self.assertRaises(InvalidLambdaAuthorizerResponse): + LambdaAuthorizer( + auth_name, + Mock(), + Mock(), + [], + Mock(), + Mock(), + Mock(), + ).is_valid_response("not a valid json string", Mock()) + + @patch.object(LambdaAuthorizer, "_validate_simple_response") + @patch.object(LambdaAuthorizer, "_is_resource_authorized") + def test_response_validator_calls_simple_response(self, resource_mock, simple_mock): + LambdaAuthorizer( + "my auth", + Mock(), + Mock(), + [], + LambdaAuthorizer.PAYLOAD_V2, + Mock(), + True, + ).is_valid_response("{}", Mock()) + + resource_mock.assert_not_called() + simple_mock.assert_called_once() + + @parameterized.expand( + [ + ( # authorizer v2, but not using simple response + LambdaAuthorizer( + "my auth", + Mock(), + Mock(), + [], + LambdaAuthorizer.PAYLOAD_V2, + Mock(), + False, + ), + ), + ( # authorizer v1 + LambdaAuthorizer( + "my auth", + Mock(), + Mock(), + [], + LambdaAuthorizer.PAYLOAD_V1, + Mock(), + False, + ), + ), + ] + ) + @patch.object(LambdaAuthorizer, "_validate_simple_response") + @patch.object(LambdaAuthorizer, "_is_resource_authorized") + @patch.object(LambdaAuthorizerIAMPolicyValidator, "validate_policy_document") + @patch.object(LambdaAuthorizerIAMPolicyValidator, "validate_statement") + def test_response_validator_calls_is_resource_authorized( + self, validate_policy_mock, validate_statement_mock, lambda_auth, resource_mock, simple_mock + ): + LambdaAuthorizer( + "my auth", + Mock(), + Mock(), + [], + LambdaAuthorizer.PAYLOAD_V1, + Mock(), + False, + ).is_valid_response("{}", Mock()) + + resource_mock.assert_called_once() + simple_mock.assert_not_called() + + @parameterized.expand([({"missing": "key"},), ({"isAuthorized": "suppose to be bool"},)]) + def test_validate_simple_response_raises(self, input): + with self.assertRaises(InvalidLambdaAuthorizerResponse): + LambdaAuthorizer( + "my auth", + Mock(), + Mock(), + [], + Mock(), + Mock(), + Mock(), + )._validate_simple_response(input) + + def test_validate_simple_response(self): + result = LambdaAuthorizer( + "my auth", + Mock(), + Mock(), + [], + Mock(), + Mock(), + Mock(), + )._validate_simple_response({"isAuthorized": True}) + + self.assertTrue(result) + + def test_get_context(self): + context = {"key": "value"} + principal_id = "123" + + input = {"context": context, "principalId": principal_id} + + expected = context.copy() + expected["principalId"] = principal_id + + result = LambdaAuthorizer(Mock(), Mock(), Mock(), [], Mock()).get_context(json.dumps(input)) + + self.assertEqual(result, expected) + + @parameterized.expand( + [ + (json.dumps([]),), + ("not valid json",), + (json.dumps({"context": "not dict"}),), + ] + ) + def test_get_context_raises_exception(self, input): + with self.assertRaises(InvalidLambdaAuthorizerResponse): + LambdaAuthorizer("myauth", Mock(), Mock(), [], Mock()).get_context(json.dumps(input)) + + @parameterized.expand( + [ + ( # deny effect + { + "principalId": "123", + "policyDocument": { + "Statement": [{"Action": "execute-api:Invoke", "Effect": "Deny", "Resource": [""]}] + }, + }, + False, + ), + ( # wrong action + { + "principalId": "123", + "policyDocument": {"Statement": [{"Action": "hello world", "Effect": "Deny", "Resource": [""]}]}, + }, + False, + ), + ( # missing arn resource match + { + "principalId": "123", + "policyDocument": { + "Statement": [{"Action": "execute-api:Invoke", "Effect": "Allow", "Resource": ["not the arn"]}] + }, + }, + False, + ), + ( # match wildcard same part + { + "principalId": "123", + "policyDocument": { + "Statement": [ + { + "Action": "execute-api:Invoke", + "Effect": "Allow", + "Resource": ["arn:aws:execute-api:us-east-1:123456789012:1234567890/prod/GET/hel*"], + } + ] + }, + }, + True, + ), + ( # match wildcard any + { + "principalId": "123", + "policyDocument": { + "Statement": [ + { + "Action": "execute-api:Invoke", + "Effect": "Allow", + "Resource": ["arn:aws:execute-api:us-east-1:123456789012:1234567890/prod/GET/*"], + } + ] + }, + }, + True, + ), + ( # match wildcard any path, any method + { + "principalId": "123", + "policyDocument": { + "Statement": [ + { + "Action": "execute-api:Invoke", + "Effect": "Allow", + "Resource": ["arn:aws:execute-api:us-east-1:123456789012:1234567890/prod/*/*"], + } + ] + }, + }, + True, + ), + ( # fail match wildcard second part + { + "principalId": "123", + "policyDocument": { + "Statement": [ + { + "Action": "execute-api:Invoke", + "Effect": "Allow", + "Resource": ["arn:aws:execute-api:us-east-1:123456789012:1234567890/prod/GET/hello/*"], + } + ] + }, + }, + False, + ), + ( # fail match single random character + { + "principalId": "123", + "policyDocument": { + "Statement": [ + { + "Action": "execute-api:Invoke", + "Effect": "Allow", + "Resource": ["arn:aws:execute-api:us-east-1:123456789012:1234567890/prod/GET/he?lo"], + } + ] + }, + }, + True, + ), + ] + ) + def test_validate_is_resource_authorized(self, response, expected_result): + method_arn = "arn:aws:execute-api:us-east-1:123456789012:1234567890/prod/GET/hello" + + auth = LambdaAuthorizer( + "my auth", + Mock(), + Mock(), + [], + Mock(), + Mock(), + Mock(), + ) + + result = auth._is_resource_authorized(response, method_arn) + + self.assertEqual(result, expected_result) + + +class TestLambdaAuthorizerIamPolicyValidator(TestCase): + @parameterized.expand( + [ + ( # missing principalId + {}, + "Authorizer 'my auth' contains an invalid or missing 'principalId' from response", + ), + ( # missing policyDocument + {"principalId": "123"}, + "Authorizer 'my auth' contains an invalid or missing 'policyDocument' from response", + ), + ( # policyDocument not dict + {"principalId": "123", "policyDocument": "not list"}, + "Authorizer 'my auth' contains an invalid or missing 'policyDocument' from response", + ), + ] + ) + def test_validate_validate_policy_document_raises(self, response, message): + with self.assertRaisesRegex(InvalidLambdaAuthorizerResponse, message): + LambdaAuthorizerIAMPolicyValidator.validate_policy_document("my auth", response) + + @parameterized.expand( + [ + ( # policyDocument empty + {"principalId": "123", "policyDocument": {}}, + "Authorizer 'my auth' contains an invalid or missing 'Statement' from response", + ), + ( # missing statement + {"principalId": "123", "policyDocument": {"missing": "statement"}}, + "Authorizer 'my auth' contains an invalid or missing 'Statement'", + ), + ( # statement not list + {"principalId": "123", "policyDocument": {"Statement": "statement"}}, + "Authorizer 'my auth' contains an invalid or missing 'Statement'", + ), + ( # statement empty + {"principalId": "123", "policyDocument": {"Statement": []}}, + "Authorizer 'my auth' contains an invalid or missing 'Statement'", + ), + ( # statement not an object + {"principalId": "123", "policyDocument": {"Statement": ["string"]}}, + "Authorizer 'my auth' policy document must be a list of object", + ), + ( # statement missing action + {"principalId": "123", "policyDocument": {"Statement": [{"no action": "123"}]}}, + "Authorizer 'my auth' policy document contains an invalid 'Action'", + ), + ( # statement missing effect + {"principalId": "123", "policyDocument": {"Statement": [{"Action": "execute-api:Invoke"}]}}, + "Authorizer 'my auth' policy document contains an invalid 'Effect'", + ), + ( # statement resource not a list + { + "principalId": "123", + "policyDocument": { + "Statement": [{"Action": "execute-api:Invoke", "Effect": "Allow", "Resource": "not list"}] + }, + }, + "Authorizer 'my auth' policy document contains an invalid 'Resource'", + ), + ] + ) + def test_validate_validate_statement_raises(self, response, message): + with self.assertRaisesRegex(InvalidLambdaAuthorizerResponse, message): + LambdaAuthorizerIAMPolicyValidator.validate_statement("my auth", response) diff --git a/tests/unit/local/apigw/test_local_apigw_service.py b/tests/unit/local/apigw/test_local_apigw_service.py index 89a62fa58f..1278c4bb30 100644 --- a/tests/unit/local/apigw/test_local_apigw_service.py +++ b/tests/unit/local/apigw/test_local_apigw_service.py @@ -1,8 +1,7 @@ import base64 import copy import json -from time import time -from datetime import datetime +import flask from unittest import TestCase from unittest.mock import Mock, patch, ANY, MagicMock @@ -11,12 +10,19 @@ from samcli.lib.providers.provider import Api from samcli.lib.providers.provider import Cors +from samcli.lib.telemetry.event import EventName, EventTracker, UsedFeature +from samcli.local.apigw.event_constructor import construct_v1_event, construct_v2_event_http +from samcli.local.apigw.authorizers.lambda_authorizer import LambdaAuthorizer +from samcli.local.apigw.route import Route from samcli.local.apigw.local_apigw_service import ( LocalApigwService, - Route, + CatchAllPathConverter, +) +from samcli.local.apigw.exceptions import ( + AuthorizerUnauthorizedRequest, + InvalidSecurityDefinition, LambdaResponseParseException, PayloadFormatVersionValidateException, - CatchAllPathConverter, ) from samcli.local.lambdafn.exceptions import FunctionNotFound from samcli.commands.local.lib.exceptions import UnsupportedInlineCodeError @@ -72,15 +78,19 @@ def setUp(self): self.stderr = Mock() self.api = Api(routes=self.api_list_of_routes) self.http = Api(routes=self.http_list_of_routes) + self.api_service = LocalApigwService( self.api, self.lambda_runner, port=3000, host="127.0.0.1", stderr=self.stderr ) + self.http_service = LocalApigwService( self.http, self.lambda_runner, port=3000, host="127.0.0.1", stderr=self.stderr ) @patch.object(LocalApigwService, "get_request_methods_endpoints") - def test_api_request_must_invoke_lambda(self, request_mock): + @patch("samcli.local.apigw.local_apigw_service.construct_v1_event") + @patch("samcli.local.apigw.local_apigw_service.construct_v2_event_http") + def test_api_request_must_invoke_lambda(self, v2_event_mock, v1_event_mock, request_mock): make_response_mock = Mock() self.api_service.service_response = make_response_mock @@ -88,7 +98,7 @@ def test_api_request_must_invoke_lambda(self, request_mock): self.api_service._get_current_route.return_value = self.api_gateway_route self.api_service._get_current_route.methods = [] self.api_service._get_current_route.return_value.payload_format_version = "2.0" - self.api_service._construct_v_1_0_event = Mock() + v1_event_mock.return_value = {} parse_output_mock = Mock() parse_output_mock.return_value = ("status_code", Headers({"headers": "headers"}), "body") @@ -104,19 +114,27 @@ def test_api_request_must_invoke_lambda(self, request_mock): self.assertEqual(result, make_response_mock) self.lambda_runner.invoke.assert_called_with(ANY, ANY, stdout=ANY, stderr=self.stderr) - self.api_service._construct_v_1_0_event.assert_called_with(ANY, ANY, ANY, ANY, ANY, "getRestApi") + v1_event_mock.assert_called_with( + flask_request=ANY, + port=ANY, + binary_types=ANY, + stage_name=ANY, + stage_variables=ANY, + operation_name="getRestApi", + ) @patch.object(LocalApigwService, "get_request_methods_endpoints") - def test_http_request_must_invoke_lambda(self, request_mock): + @patch("samcli.local.apigw.local_apigw_service.construct_v1_event") + @patch("samcli.local.apigw.local_apigw_service.construct_v2_event_http") + def test_http_request_must_invoke_lambda(self, v2_event_mock, v1_event_mock, request_mock): make_response_mock = Mock() self.http_service.service_response = make_response_mock self.http_service._get_current_route = Mock() self.http_service._get_current_route.return_value = self.http_gateway_route self.http_service._get_current_route.methods = [] - self.http_service._construct_v_1_0_event = Mock() - self.http_service._construct_v_2_0_event_http = MagicMock() + v2_event_mock.return_value = {} parse_output_mock = Mock() parse_output_mock.return_value = ("status_code", Headers({"headers": "headers"}), "body") @@ -132,19 +150,22 @@ def test_http_request_must_invoke_lambda(self, request_mock): self.assertEqual(result, make_response_mock) self.lambda_runner.invoke.assert_called_with(ANY, ANY, stdout=ANY, stderr=self.stderr) - self.http_service._construct_v_2_0_event_http.assert_called_with(ANY, ANY, ANY, ANY, ANY, ANY) + v2_event_mock.assert_called_with( + flask_request=ANY, port=ANY, binary_types=ANY, stage_name=ANY, stage_variables=ANY, route_key="test test" + ) @patch.object(LocalApigwService, "get_request_methods_endpoints") - def test_http_v1_payload_request_must_invoke_lambda(self, request_mock): + @patch("samcli.local.apigw.local_apigw_service.construct_v1_event") + @patch("samcli.local.apigw.local_apigw_service.construct_v2_event_http") + def test_http_v1_payload_request_must_invoke_lambda(self, v2_event_mock, v1_event_mock, request_mock): make_response_mock = Mock() self.http_service.service_response = make_response_mock self.http_service._get_current_route = Mock() self.http_service._get_current_route.return_value = self.http_v1_payload_route self.http_service._get_current_route.methods = [] - self.http_service._construct_v_1_0_event = Mock() - self.http_service._construct_v_2_0_event_http = MagicMock() + v1_event_mock.return_value = {} parse_output_mock = Mock() parse_output_mock.return_value = ("status_code", Headers({"headers": "headers"}), "body") @@ -160,19 +181,27 @@ def test_http_v1_payload_request_must_invoke_lambda(self, request_mock): self.assertEqual(result, make_response_mock) self.lambda_runner.invoke.assert_called_with(ANY, ANY, stdout=ANY, stderr=self.stderr) - self.http_service._construct_v_1_0_event.assert_called_with(ANY, ANY, ANY, ANY, ANY, None) + v1_event_mock.assert_called_with( + flask_request=ANY, + port=ANY, + binary_types=ANY, + stage_name=ANY, + stage_variables=ANY, + operation_name=None, + ) @patch.object(LocalApigwService, "get_request_methods_endpoints") - def test_http_v2_payload_request_must_invoke_lambda(self, request_mock): + @patch("samcli.local.apigw.local_apigw_service.construct_v1_event") + @patch("samcli.local.apigw.local_apigw_service.construct_v2_event_http") + def test_http_v2_payload_request_must_invoke_lambda(self, v2_event_mock, v1_event_mock, request_mock): make_response_mock = Mock() self.http_service.service_response = make_response_mock self.http_service._get_current_route = Mock() self.http_service._get_current_route.return_value = self.http_v2_payload_route self.http_service._get_current_route.methods = [] - self.http_service._construct_v_1_0_event = Mock() - self.http_service._construct_v_2_0_event_http = MagicMock() + v2_event_mock.return_value = {} parse_output_mock = Mock() parse_output_mock.return_value = ("status_code", Headers({"headers": "headers"}), "body") @@ -188,17 +217,21 @@ def test_http_v2_payload_request_must_invoke_lambda(self, request_mock): self.assertEqual(result, make_response_mock) self.lambda_runner.invoke.assert_called_with(ANY, ANY, stdout=ANY, stderr=self.stderr) - self.http_service._construct_v_2_0_event_http.assert_called_with(ANY, ANY, ANY, ANY, ANY, ANY) + v2_event_mock.assert_called_with( + flask_request=ANY, port=ANY, binary_types=ANY, stage_name=ANY, stage_variables=ANY, route_key="test test" + ) @patch.object(LocalApigwService, "get_request_methods_endpoints") - def test_api_options_request_must_invoke_lambda(self, request_mock): + @patch("samcli.local.apigw.local_apigw_service.LocalApigwService._generate_lambda_event") + def test_api_options_request_must_invoke_lambda(self, generate_mock, request_mock): + generate_mock.return_value = {} make_response_mock = Mock() self.api_service.service_response = make_response_mock self.api_service._get_current_route = MagicMock() self.api_service._get_current_route.return_value.methods = ["OPTIONS"] self.api_service._get_current_route.return_value.payload_format_version = "1.0" - self.api_service._construct_v_1_0_event = Mock() + self.api_service._get_current_route.return_value.authorizer_object = None parse_output_mock = Mock() parse_output_mock.return_value = ("status_code", Headers({"headers": "headers"}), "body") @@ -216,14 +249,16 @@ def test_api_options_request_must_invoke_lambda(self, request_mock): self.lambda_runner.invoke.assert_called_with(ANY, ANY, stdout=ANY, stderr=self.stderr) @patch.object(LocalApigwService, "get_request_methods_endpoints") - def test_http_options_request_must_invoke_lambda(self, request_mock): + @patch("samcli.local.apigw.local_apigw_service.LocalApigwService._generate_lambda_event") + def test_http_options_request_must_invoke_lambda(self, generate_mock, request_mock): + generate_mock.return_value = {} make_response_mock = Mock() self.http_service.service_response = make_response_mock self.http_service._get_current_route = MagicMock() self.http_service._get_current_route.return_value.methods = ["OPTIONS"] self.http_service._get_current_route.return_value.payload_format_version = "1.0" - self.http_service._construct_v_1_0_event = Mock() + self.http_service._get_current_route.return_value.authorizer_object = None parse_output_mock = Mock() parse_output_mock.return_value = ("status_code", Headers({"headers": "headers"}), "body") @@ -242,19 +277,22 @@ def test_http_options_request_must_invoke_lambda(self, request_mock): @patch.object(LocalApigwService, "get_request_methods_endpoints") @patch("samcli.local.apigw.local_apigw_service.LambdaOutputParser") - def test_request_handler_returns_process_stdout_when_making_response(self, lambda_output_parser_mock, request_mock): + @patch("samcli.local.apigw.local_apigw_service.LocalApigwService._generate_lambda_event") + def test_request_handler_returns_process_stdout_when_making_response( + self, generate_mock, lambda_output_parser_mock, request_mock + ): + generate_mock.return_value = {} make_response_mock = Mock() request_mock.return_value = ("test", "test") self.api_service.service_response = make_response_mock current_route = Mock() current_route.payload_format_version = "2.0" + current_route.authorizer_object = None self.api_service._get_current_route = MagicMock() self.api_service._get_current_route.return_value = current_route current_route.methods = [] current_route.event_type = Route.API - self.api_service._construct_v_1_0_event = Mock() - parse_output_mock = Mock() parse_output_mock.return_value = ("status_code", Headers({"headers": "headers"}), "body") self.api_service._parse_v1_payload_format_lambda_output = parse_output_mock @@ -275,14 +313,16 @@ def test_request_handler_returns_process_stdout_when_making_response(self, lambd parse_output_mock.assert_called_with(lambda_response, ANY, ANY, Route.API) @patch.object(LocalApigwService, "get_request_methods_endpoints") - def test_request_handler_returns_make_response(self, request_mock): + @patch("samcli.local.apigw.local_apigw_service.LocalApigwService._generate_lambda_event") + def test_request_handler_returns_make_response(self, generate_mock, request_mock): + generate_mock.return_value = {} make_response_mock = Mock() self.api_service.service_response = make_response_mock self.api_service._get_current_route = MagicMock() - self.api_service._construct_v_1_0_event = Mock() self.api_service._get_current_route.methods = [] self.api_service._get_current_route.return_value.payload_format_version = "1.0" + self.api_service._get_current_route.return_value.authorizer_object = None parse_output_mock = Mock() parse_output_mock.return_value = ("status_code", Headers({"headers": "headers"}), "body") @@ -373,11 +413,15 @@ def test_initalize_with_values(self): @patch.object(LocalApigwService, "get_request_methods_endpoints") @patch("samcli.local.apigw.local_apigw_service.ServiceErrorResponses") - def test_request_handles_error_when_invoke_cant_find_function(self, service_error_responses_patch, request_mock): + @patch("samcli.local.apigw.local_apigw_service.LocalApigwService._generate_lambda_event") + def test_request_handles_error_when_invoke_cant_find_function( + self, generate_mock, service_error_responses_patch, request_mock + ): + generate_mock.return_value = {} not_found_response_mock = Mock() - self.api_service._construct_v_1_0_event = Mock() self.api_service._get_current_route = MagicMock() self.api_service._get_current_route.return_value.payload_format_version = "2.0" + self.api_service._get_current_route.return_value.authorizer_object = None self.api_service._get_current_route.methods = [] service_error_responses_patch.lambda_not_found_response.return_value = not_found_response_mock @@ -390,13 +434,15 @@ def test_request_handles_error_when_invoke_cant_find_function(self, service_erro @patch.object(LocalApigwService, "get_request_methods_endpoints") @patch("samcli.local.apigw.local_apigw_service.ServiceErrorResponses") + @patch("samcli.local.apigw.local_apigw_service.LocalApigwService._generate_lambda_event") def test_request_handles_error_when_invoke_function_with_inline_code( - self, service_error_responses_patch, request_mock + self, generate_mock, service_error_responses_patch, request_mock ): + generate_mock.return_value = {} not_implemented_response_mock = Mock() - self.api_service._construct_v_1_0_event = Mock() self.api_service._get_current_route = MagicMock() self.api_service._get_current_route.return_value.payload_format_version = "2.0" + self.api_service._get_current_route.return_value.authorizer_object = None self.api_service._get_current_route.methods = [] service_error_responses_patch.not_implemented_locally.return_value = not_implemented_response_mock @@ -411,7 +457,6 @@ def test_request_handles_error_when_invoke_function_with_inline_code( def test_request_throws_when_invoke_fails(self, request_mock): self.lambda_runner.invoke.side_effect = Exception() - self.api_service._construct_v_1_0_event = Mock() self.api_service._get_current_route = Mock() request_mock.return_value = ("test", "test") @@ -420,9 +465,11 @@ def test_request_throws_when_invoke_fails(self, request_mock): @patch.object(LocalApigwService, "get_request_methods_endpoints") @patch("samcli.local.apigw.local_apigw_service.ServiceErrorResponses") + @patch("samcli.local.apigw.local_apigw_service.LocalApigwService._generate_lambda_event") def test_request_handler_errors_when_parse_lambda_output_raises_keyerror( - self, service_error_responses_patch, request_mock + self, generate_mock, service_error_responses_patch, request_mock ): + generate_mock.return_value = {} parse_output_mock = Mock() parse_output_mock.side_effect = LambdaResponseParseException() self.api_service._parse_v1_payload_format_lambda_output = parse_output_mock @@ -431,10 +478,10 @@ def test_request_handler_errors_when_parse_lambda_output_raises_keyerror( service_error_responses_patch.lambda_failure_response.return_value = failure_response_mock - self.api_service._construct_v_1_0_event = Mock() self.api_service._get_current_route = MagicMock() self.api_service._get_current_route.methods = [] self.api_service._get_current_route.return_value.payload_format_version = "1.0" + self.api_service._get_current_route.return_value.authorizer_object = None request_mock.return_value = ("test", "test") result = self.api_service._request_handler() @@ -452,14 +499,17 @@ def test_request_handler_errors_when_get_current_route_fails(self, service_error @patch.object(LocalApigwService, "get_request_methods_endpoints") @patch("samcli.local.apigw.local_apigw_service.ServiceErrorResponses") - def test_request_handler_errors_when_unable_to_read_binary_data(self, service_error_responses_patch, request_mock): + @patch("samcli.local.apigw.local_apigw_service.LocalApigwService._generate_lambda_event") + def test_request_handler_errors_when_unable_to_read_binary_data( + self, generate_mock, service_error_responses_patch, request_mock + ): + generate_mock.return_value = {} _construct_event = Mock() _construct_event.side_effect = UnicodeDecodeError("utf8", b"obj", 1, 2, "reason") self.api_service._get_current_route = MagicMock() self.api_service._get_current_route.methods = [] self.api_service._get_current_route.return_value.payload_format_version = "1.0" - - self.api_service._construct_v_1_0_event = _construct_event + self.api_service._get_current_route.return_value.authorizer_object = None failure_mock = Mock() service_error_responses_patch.lambda_failure_response.return_value = failure_mock @@ -507,6 +557,410 @@ def test_get_current_route_keyerror(self): with self.assertRaises(KeyError): self.api_service._get_current_route(request_mock) + @patch.object(LocalApigwService, "get_request_methods_endpoints") + @patch("samcli.local.apigw.local_apigw_service.ServiceErrorResponses") + @patch("samcli.local.apigw.local_apigw_service.LocalApigwService._valid_identity_sources") + def test_request_contains_lambda_auth_missing_identity_sources( + self, validate_id_mock, service_error_mock, request_mock + ): + route = self.api_gateway_route + route.authorizer_object = LambdaAuthorizer("", "", "", [], "") + + self.api_service._get_current_route = MagicMock() + self.api_service._get_current_route.return_value = route + self.api_service._get_current_route.methods = [] + self.api_service._get_current_route.return_value.payload_format_version = "2.0" + + mocked_missing_lambda_auth_id = Mock() + service_error_mock.missing_lambda_auth_identity_sources.return_value = mocked_missing_lambda_auth_id + + request_mock.return_value = ("test", "test") + + validate_id_mock.return_value = False + + result = self.api_service._request_handler() + + self.assertEqual(result, mocked_missing_lambda_auth_id) + + def test_valid_identity_sources_not_lambda_auth(self): + route = self.api_gateway_route + route.authorizer_object = None + + self.assertFalse(self.api_service._valid_identity_sources(route)) + + @parameterized.expand( + [ + (True,), + (False,), + ] + ) + @patch("samcli.local.apigw.authorizers.lambda_authorizer.LambdaAuthorizer._parse_identity_sources") + @patch("samcli.local.apigw.authorizers.lambda_authorizer.LambdaAuthorizer.identity_sources") + @patch("samcli.local.apigw.path_converter.PathConverter.convert_path_to_api_gateway") + def test_valid_identity_sources_id_source( + self, is_valid, path_convert_mock, id_source_prop_mock, lambda_auth_parse_mock + ): + route = self.api_gateway_route + route.authorizer_object = LambdaAuthorizer("", "", "", [], "") + + mocked_id_source_obj = Mock() + mocked_id_source_obj.is_valid = Mock(return_value=is_valid) + route.authorizer_object.identity_sources = [mocked_id_source_obj] + + # create a dummy Flask app to populate the request object with testing data + # using Flask's dummy values for request is fine in this context since + # the variables are being passed and not validated + with flask.Flask(__name__).test_request_context(): + self.assertEqual(self.api_service._valid_identity_sources(route), is_valid) + + @patch.object(LocalApigwService, "get_request_methods_endpoints") + def test_create_method_arn(self, method_endpoint_mock): + method_endpoint_mock.return_value = ("method", "/endpoint") + + expected_method_arn = "arn:aws:execute-api:us-east-1:123456789012:1234567890/None/method/endpoint" + + self.assertEqual(self.api_service._create_method_arn(Mock(), Route.API), expected_method_arn) + + @patch.object(LocalApigwService, "_create_method_arn") + def test_generate_lambda_token_authorizer_event_invalid_identity_source(self, method_arn_mock): + method_arn_mock.return_value = "arn" + + authorizer_object = LambdaAuthorizer("", "", "", [], "") + authorizer_object.identity_sources = [] + + with self.assertRaises(InvalidSecurityDefinition): + self.api_service._generate_lambda_token_authorizer_event(Mock(), self.api_gateway_route, authorizer_object) + + @patch.object(LocalApigwService, "_create_method_arn") + def test_generate_lambda_token_authorizer_event(self, method_arn_mock): + method_arn_mock.return_value = "arn" + + authorizer_object = LambdaAuthorizer("", "", "", [], "") + mocked_id_source_obj = Mock() + mocked_id_source_obj.find_identity_value = Mock(return_value="123") + authorizer_object._identity_sources = [mocked_id_source_obj] + + result = self.api_service._generate_lambda_token_authorizer_event( + Mock(), self.api_gateway_route, authorizer_object + ) + + self.assertEqual( + result, + { + "type": "TOKEN", + "authorizationToken": "123", + "methodArn": "arn", + }, + ) + + @parameterized.expand( + [ + ( + LambdaAuthorizer.PAYLOAD_V2, + ["value1", "value2"], + "arn", + {"identitySource": ["value1", "value2"], "routeArn": "arn"}, + ), + ( + LambdaAuthorizer.PAYLOAD_V1, + ["value1", "value2"], + "arn", + { + "identitySource": "value1,value2", + "authorizationToken": "value1,value2", + "methodArn": "arn", + }, + ), + ] + ) + def test_generate_lambda_request_authorizer_event_http(self, payload, id_values, arn, expected_output): + result = self.api_service._generate_lambda_request_authorizer_event_http(payload, id_values, arn) + + self.assertEqual(result, expected_output) + + @patch.object(LocalApigwService, "get_request_methods_endpoints") + @patch.object(LocalApigwService, "_create_method_arn") + @patch.object(LocalApigwService, "_generate_lambda_event") + @patch.object(LocalApigwService, "_build_v1_context") + @patch.object(LocalApigwService, "_build_v2_context") + @patch.object(LocalApigwService, "_generate_lambda_request_authorizer_event_http") + def test_generate_lambda_request_authorizer_event_http_request( + self, + generate_lambda_auth_http_mock, + build_v2_mock, + build_v1_mock, + generate_lambda_mock, + method_arn_mock, + method_endpoints_mock, + ): + original = {"existing": "value"} + payload_version = "2.0" + method_arn = "arn" + + method_arn_mock.return_value = method_arn + method_endpoints_mock.return_value = ("method", "endpoint") + generate_lambda_mock.return_value = original + build_v2_mock.return_value = {} + build_v1_mock.return_value = {} + + authorizer_object = LambdaAuthorizer("", "", "", [], payload_version) + mocked_id_source_obj = Mock() + mocked_id_source_obj.find_identity_value = Mock(return_value="123") + mocked_id_source_obj2 = Mock() + mocked_id_source_obj2.find_identity_value = Mock(return_value="abc") + authorizer_object._identity_sources = [mocked_id_source_obj, mocked_id_source_obj2] + + self.api_service._generate_lambda_request_authorizer_event(Mock(), self.http_gateway_route, authorizer_object) + + generate_lambda_auth_http_mock.assert_called_with(payload_version, ["123", "abc"], method_arn) + + @patch.object(LocalApigwService, "get_request_methods_endpoints") + @patch.object(LocalApigwService, "_create_method_arn") + @patch.object(LocalApigwService, "_generate_lambda_event") + @patch.object(LocalApigwService, "_build_v1_context") + @patch.object(LocalApigwService, "_build_v2_context") + @patch.object(LocalApigwService, "_generate_lambda_request_authorizer_event_http") + def test_generate_lambda_request_authorizer_event_api( + self, + generate_lambda_auth_http_mock, + build_v2_mock, + build_v1_mock, + generate_lambda_mock, + method_arn_mock, + method_endpoints_mock, + ): + payload_version = "1.0" + method_arn = "arn" + original = {"existing": "value"} + + method_arn_mock.return_value = method_arn + method_endpoints_mock.return_value = ("method", "endpoint") + generate_lambda_mock.return_value = original + build_v2_mock.return_value = {} + build_v1_mock.return_value = {} + + authorizer_object = LambdaAuthorizer("", "", "", [], payload_version) + + result = self.api_service._generate_lambda_request_authorizer_event( + Mock(), self.api_gateway_route, authorizer_object + ) + + original.update({"methodArn": method_arn, "type": "REQUEST"}) + + self.assertEqual(result, original) + generate_lambda_auth_http_mock.assert_not_called() + + @patch.object(LocalApigwService, "_generate_lambda_token_authorizer_event") + @patch.object(LocalApigwService, "_generate_lambda_request_authorizer_event") + def test_generate_lambda_authorizer_event_token(self, request_mock, token_mock): + token_auth = LambdaAuthorizer(Mock(), LambdaAuthorizer.TOKEN, Mock(), [], Mock()) + + token_mock.return_value = {} + request_mock.return_value = {} + + self.api_service._generate_lambda_authorizer_event(Mock(), Mock(), token_auth) + token_mock.assert_called() + request_mock.assert_not_called() + + @patch.object(LocalApigwService, "_generate_lambda_token_authorizer_event") + @patch.object(LocalApigwService, "_generate_lambda_request_authorizer_event") + def test_generate_lambda_authorizer_event_request(self, request_mock, token_mock): + request_auth = LambdaAuthorizer(Mock(), LambdaAuthorizer.REQUEST, Mock(), [], Mock()) + + token_mock.return_value = {} + request_mock.return_value = {} + + self.api_service._generate_lambda_authorizer_event(Mock(), Mock(), request_auth) + token_mock.assert_not_called() + request_mock.assert_called() + + @patch.object(LocalApigwService, "get_request_methods_endpoints") + @patch.object(LocalApigwService, "_generate_lambda_authorizer_event") + @patch.object(LocalApigwService, "_valid_identity_sources") + @patch.object(LocalApigwService, "_invoke_lambda_function") + @patch.object(LocalApigwService, "_invoke_parse_lambda_authorizer") + @patch.object(EventTracker, "track_event") + @patch("samcli.local.apigw.local_apigw_service.construct_v1_event") + @patch("samcli.local.apigw.local_apigw_service.construct_v2_event_http") + def test_lambda_auth_called( + self, + v2_event_mock, + v1_event_mock, + track_mock, + lambda_invoke_mock, + invoke_mock, + validate_id_mock, + gen_auth_event_mock, + request_mock, + ): + make_response_mock = Mock() + validate_id_mock.return_value = True + + # create mock authorizer + auth = LambdaAuthorizer(Mock(), Mock(), "auth_lambda", [], Mock(), Mock(), Mock()) + auth.is_valid_response = Mock(return_value=True) + auth.get_context = Mock(return_value={}) + self.api_gateway_route.authorizer_object = auth + + # get api service to return mocked route containing authorizer + self.api_service.service_response = make_response_mock + self.api_service._get_current_route = MagicMock() + self.api_service._get_current_route.return_value = self.api_gateway_route + self.api_service._get_current_route.methods = [] + self.api_service._get_current_route.return_value.payload_format_version = "2.0" + v1_event_mock.return_value = {} + + parse_output_mock = Mock(return_value=("status_code", Headers({"headers": "headers"}), "body")) + self.api_service._parse_v1_payload_format_lambda_output = parse_output_mock + + service_response_mock = Mock(return_value=make_response_mock) + self.api_service.service_response = service_response_mock + + request_mock.return_value = ("test", "test") + + self.api_service._request_handler() + + # successful invoke + self.api_service._invoke_parse_lambda_authorizer.assert_called_with(auth, ANY, ANY, self.api_gateway_route) + + @patch.object(LocalApigwService, "get_request_methods_endpoints") + @patch.object(LocalApigwService, "_generate_lambda_authorizer_event") + @patch.object(LocalApigwService, "_valid_identity_sources") + @patch.object(LocalApigwService, "_invoke_lambda_function") + @patch.object(LocalApigwService, "_invoke_parse_lambda_authorizer") + @patch.object(EventTracker, "track_event") + @patch("samcli.local.apigw.local_apigw_service.construct_v1_event") + @patch("samcli.local.apigw.local_apigw_service.construct_v2_event_http") + @patch("samcli.local.apigw.local_apigw_service.ServiceErrorResponses") + def test_lambda_invoke_track_event_exception( + self, + service_mock, + v2_event_mock, + v1_event_mock, + track_mock, + lambda_invoke_mock, + invoke_mock, + validate_id_mock, + gen_auth_event_mock, + request_mock, + ): + make_response_mock = Mock() + validate_id_mock.return_value = True + + # create mock authorizer + auth = LambdaAuthorizer(Mock(), Mock(), "auth_lambda", [], Mock(), Mock(), Mock()) + auth.is_valid_response = Mock(return_value=True) + auth.get_context = Mock(return_value={}) + self.api_gateway_route.authorizer_object = auth + + # get api service to return mocked route containing authorizer + self.api_service.service_response = make_response_mock + self.api_service._get_current_route = MagicMock() + self.api_service._get_current_route.return_value = self.api_gateway_route + self.api_service._get_current_route.methods = [] + self.api_service._get_current_route.return_value.payload_format_version = "2.0" + v1_event_mock.return_value = {} + + parse_output_mock = Mock(return_value=("status_code", Headers({"headers": "headers"}), "body")) + self.api_service._parse_v1_payload_format_lambda_output = parse_output_mock + + service_response_mock = Mock(return_value=make_response_mock) + self.api_service.service_response = service_response_mock + + request_mock.return_value = ("test", "test") + + lambda_invoke_mock.side_effect = AuthorizerUnauthorizedRequest("msg") + service_mock.lambda_authorizer_unauthorized = Mock() + + self.api_service._request_handler() + + track_mock.assert_called_with( + event_name=EventName.USED_FEATURE.value, + event_value=UsedFeature.INVOKED_CUSTOM_LAMBDA_AUTHORIZERS.value, + session_id=ANY, + exception_name=AuthorizerUnauthorizedRequest.__name__, + ) + + @patch.object(LocalApigwService, "get_request_methods_endpoints") + @patch.object(LocalApigwService, "_generate_lambda_authorizer_event") + @patch.object(LocalApigwService, "_valid_identity_sources") + @patch.object(LocalApigwService, "_invoke_lambda_function") + @patch("samcli.local.apigw.local_apigw_service.ServiceErrorResponses") + @patch("samcli.local.apigw.local_apigw_service.construct_v1_event") + @patch("samcli.local.apigw.local_apigw_service.construct_v2_event_http") + def test_lambda_auth_unauthorized_response( + self, + v2_event_mock, + v1_event_mock, + service_err_mock, + invoke_mock, + validate_id_mock, + gen_auth_event_mock, + request_mock, + ): + make_response_mock = Mock() + validate_id_mock.return_value = True + + # create mock authorizer + auth = LambdaAuthorizer(Mock(), Mock(), "auth_lambda", [], Mock(), Mock(), Mock()) + auth.is_valid_response = Mock(return_value=False) + self.api_gateway_route.authorizer_object = auth + + # get api service to return mocked route containing authorizer + self.api_service.service_response = make_response_mock + self.api_service._get_current_route = MagicMock() + self.api_service._get_current_route.return_value = self.api_gateway_route + self.api_service._get_current_route.methods = [] + self.api_service._get_current_route.return_value.payload_format_version = "2.0" + v1_event_mock.return_value = {} + + parse_output_mock = Mock(return_value=("status_code", Headers({"headers": "headers"}), "body")) + self.api_service._parse_v1_payload_format_lambda_output = parse_output_mock + + service_response_mock = Mock(return_value=make_response_mock) + self.api_service.service_response = service_response_mock + + request_mock.return_value = ("test", "test") + + mock_context = {"key": "value"} + invoke_mock.side_effect = [{"context": mock_context}, Mock()] + + unauth_mock = Mock() + service_err_mock.lambda_authorizer_unauthorized.return_value = unauth_mock + + result = self.api_service._request_handler() + self.assertEqual(result, unauth_mock) + + @patch.object(LocalApigwService, "_invoke_lambda_function") + @patch.object(LocalApigwService, "_create_method_arn") + @patch.object(EventTracker, "track_event") + def test_lambda_authorizer_pass_context_http(self, event_mock, method_arn_mock, mock_invoke): + mock_get_context = Mock() + route_event = {} + + auth = LambdaAuthorizer(Mock(), Mock(), "auth_lambda", [], Mock(), Mock(), Mock()) + auth.is_valid_response = Mock(return_value=True) + auth.get_context = Mock(return_value=mock_get_context) + self.http_v2_payload_route.authorizer_object = auth + + self.http_service._invoke_parse_lambda_authorizer(auth, {}, route_event, self.http_v2_payload_route) + self.assertEqual(route_event, {"requestContext": {"authorizer": {"lambda": mock_get_context}}}) + + @patch.object(LocalApigwService, "_invoke_lambda_function") + @patch.object(LocalApigwService, "_create_method_arn") + @patch.object(EventTracker, "track_event") + def test_lambda_authorizer_pass_context_api(self, event_mock, method_arn_mock, mock_invoke): + mock_get_context = Mock() + route_event = {} + + auth = LambdaAuthorizer(Mock(), Mock(), "auth_lambda", [], Mock(), Mock(), Mock()) + auth.is_valid_response = Mock(return_value=True) + auth.get_context = Mock(return_value=mock_get_context) + self.api_gateway_route.authorizer_object = auth + + self.api_service._invoke_parse_lambda_authorizer(auth, {}, route_event, self.api_gateway_route) + self.assertEqual(route_event, {"requestContext": {"authorizer": mock_get_context}}) + class TestApiGatewayModel(TestCase): def setUp(self): @@ -1332,354 +1786,6 @@ def test_lambda_output_json_object_no_status_code(self): self.assertEqual(body, lambda_output) -class TestService_construct_event(TestCase): - def setUp(self): - self.request_mock = Mock() - self.request_mock.endpoint = "endpoint" - self.request_mock.path = "path" - self.request_mock.method = "GET" - self.request_mock.remote_addr = "190.0.0.0" - self.request_mock.host = "190.0.0.1" - self.request_mock.get_data.return_value = b"DATA!!!!" - query_param_args_mock = Mock() - query_param_args_mock.lists.return_value = {"query": ["params"]}.items() - self.request_mock.args = query_param_args_mock - headers_mock = Mock() - headers_mock.keys.return_value = ["Content-Type", "X-Test"] - headers_mock.get.side_effect = ["application/json", "Value"] - headers_mock.getlist.side_effect = [["application/json"], ["Value"]] - self.request_mock.headers = headers_mock - self.request_mock.view_args = {"path": "params"} - self.request_mock.scheme = "http" - environ_dict = {"SERVER_PROTOCOL": "HTTP/1.1"} - self.request_mock.environ = environ_dict - - expected = ( - '{"body": "DATA!!!!", "httpMethod": "GET", ' - '"multiValueQueryStringParameters": {"query": ["params"]}, ' - '"queryStringParameters": {"query": "params"}, "resource": ' - '"endpoint", "requestContext": {"httpMethod": "GET", "requestId": ' - '"c6af9ac6-7b61-11e6-9a41-93e8deadbeef", "path": "endpoint", "extendedRequestId": null, ' - '"resourceId": "123456", "apiId": "1234567890", "stage": null, "resourcePath": "endpoint", ' - '"identity": {"accountId": null, "apiKey": null, "userArn": null, ' - '"cognitoAuthenticationProvider": null, "cognitoIdentityPoolId": null, "userAgent": ' - '"Custom User Agent String", "caller": null, "cognitoAuthenticationType": null, "sourceIp": ' - '"190.0.0.0", "user": null}, "accountId": "123456789012", "domainName": "190.0.0.1", ' - '"protocol": "HTTP/1.1"}, "headers": {"Content-Type": ' - '"application/json", "X-Test": "Value", "X-Forwarded-Port": "3000", "X-Forwarded-Proto": "http"}, ' - '"multiValueHeaders": {"Content-Type": ["application/json"], "X-Test": ["Value"], ' - '"X-Forwarded-Port": ["3000"], "X-Forwarded-Proto": ["http"]}, ' - '"stageVariables": null, "path": "path", "pathParameters": {"path": "params"}, ' - '"isBase64Encoded": false}' - ) - - self.expected_dict = json.loads(expected) - - def validate_request_context_and_remove_request_time_data(self, event_json): - request_time = event_json["requestContext"].pop("requestTime", None) - request_time_epoch = event_json["requestContext"].pop("requestTimeEpoch", None) - - self.assertIsInstance(request_time, str) - parsed_request_time = datetime.strptime(request_time, "%d/%b/%Y:%H:%M:%S +0000") - self.assertIsInstance(parsed_request_time, datetime) - - self.assertIsInstance(request_time_epoch, int) - - def test_construct_event_with_data(self): - actual_event_str = LocalApigwService._construct_v_1_0_event(self.request_mock, 3000, binary_types=[]) - - actual_event_json = json.loads(actual_event_str) - self.validate_request_context_and_remove_request_time_data(actual_event_json) - - self.assertEqual(actual_event_json["body"], self.expected_dict["body"]) - - def test_construct_event_no_data(self): - self.request_mock.get_data.return_value = None - - actual_event_str = LocalApigwService._construct_v_1_0_event(self.request_mock, 3000, binary_types=[]) - actual_event_json = json.loads(actual_event_str) - self.validate_request_context_and_remove_request_time_data(actual_event_json) - - self.assertEqual(actual_event_json["body"], None) - - @patch("samcli.local.apigw.local_apigw_service.LocalApigwService._should_base64_encode") - def test_construct_event_with_binary_data(self, should_base64_encode_patch): - should_base64_encode_patch.return_value = True - - binary_body = b"011000100110100101101110011000010111001001111001" # binary in binary - base64_body = base64.b64encode(binary_body).decode("utf-8") - - self.request_mock.get_data.return_value = binary_body - - actual_event_str = LocalApigwService._construct_v_1_0_event(self.request_mock, 3000, binary_types=[]) - actual_event_json = json.loads(actual_event_str) - self.validate_request_context_and_remove_request_time_data(actual_event_json) - - self.assertEqual(actual_event_json["body"], base64_body) - self.assertEqual(actual_event_json["isBase64Encoded"], True) - - def test_event_headers_with_empty_list(self): - request_mock = Mock() - headers_mock = Mock() - headers_mock.keys.return_value = [] - request_mock.headers = headers_mock - request_mock.scheme = "http" - - actual_query_string = LocalApigwService._event_headers(request_mock, "3000") - self.assertEqual( - actual_query_string, - ( - {"X-Forwarded-Proto": "http", "X-Forwarded-Port": "3000"}, - {"X-Forwarded-Proto": ["http"], "X-Forwarded-Port": ["3000"]}, - ), - ) - - def test_event_headers_with_non_empty_list(self): - request_mock = Mock() - headers_mock = Mock() - headers_mock.keys.return_value = ["Content-Type", "X-Test"] - headers_mock.get.side_effect = ["application/json", "Value"] - headers_mock.getlist.side_effect = [["application/json"], ["Value"]] - request_mock.headers = headers_mock - request_mock.scheme = "http" - - actual_query_string = LocalApigwService._event_headers(request_mock, "3000") - self.assertEqual( - actual_query_string, - ( - { - "Content-Type": "application/json", - "X-Test": "Value", - "X-Forwarded-Proto": "http", - "X-Forwarded-Port": "3000", - }, - { - "Content-Type": ["application/json"], - "X-Test": ["Value"], - "X-Forwarded-Proto": ["http"], - "X-Forwarded-Port": ["3000"], - }, - ), - ) - - def test_query_string_params_with_empty_params(self): - request_mock = Mock() - query_param_args_mock = Mock() - query_param_args_mock.lists.return_value = {}.items() - request_mock.args = query_param_args_mock - - actual_query_string = LocalApigwService._query_string_params(request_mock) - self.assertEqual(actual_query_string, ({}, {})) - - def test_query_string_params_with_param_value_being_empty_list(self): - request_mock = Mock() - query_param_args_mock = Mock() - query_param_args_mock.lists.return_value = {"param": []}.items() - request_mock.args = query_param_args_mock - - actual_query_string = LocalApigwService._query_string_params(request_mock) - self.assertEqual(actual_query_string, ({"param": ""}, {"param": [""]})) - - def test_query_string_params_with_param_value_being_non_empty_list(self): - request_mock = Mock() - query_param_args_mock = Mock() - query_param_args_mock.lists.return_value = {"param": ["a", "b"]}.items() - request_mock.args = query_param_args_mock - - actual_query_string = LocalApigwService._query_string_params(request_mock) - self.assertEqual(actual_query_string, ({"param": "b"}, {"param": ["a", "b"]})) - - def test_query_string_params_v_2_0_with_param_value_being_non_empty_list(self): - request_mock = Mock() - query_param_args_mock = Mock() - query_param_args_mock.lists.return_value = {"param": ["a", "b"]}.items() - request_mock.args = query_param_args_mock - - actual_query_string = LocalApigwService._query_string_params_v_2_0(request_mock) - self.assertEqual(actual_query_string, {"param": "a,b"}) - - -class TestService_construct_event_http(TestCase): - def setUp(self): - self.request_mock = Mock() - self.request_mock.endpoint = "endpoint" - self.request_mock.method = "GET" - self.request_mock.path = "/endpoint" - self.request_mock.get_data.return_value = b"DATA!!!!" - self.request_mock.mimetype = "application/json" - query_param_args_mock = Mock() - query_param_args_mock.lists.return_value = {"query": ["param1", "param2"]}.items() - self.request_mock.args = query_param_args_mock - self.request_mock.query_string = b"query=params" - headers_mock = Mock() - headers_mock.keys.return_value = ["Content-Type", "X-Test"] - headers_mock.get.side_effect = ["application/json", "Value"] - headers_mock.getlist.side_effect = [["application/json"], ["Value"]] - self.request_mock.headers = headers_mock - self.request_mock.remote_addr = "190.0.0.0" - self.request_mock.view_args = {"path": "params"} - self.request_mock.scheme = "http" - cookies_mock = Mock() - cookies_mock.keys.return_value = ["cookie1", "cookie2"] - cookies_mock.get.side_effect = ["test", "test"] - self.request_mock.cookies = cookies_mock - self.request_time_epoch = int(time()) - self.request_time = datetime.utcnow().strftime("%d/%b/%Y:%H:%M:%S +0000") - - expected = f""" - {{ - "version": "2.0", - "routeKey": "GET /endpoint", - "rawPath": "/endpoint", - "rawQueryString": "query=params", - "cookies": ["cookie1=test", "cookie2=test"], - "headers": {{ - "Content-Type": "application/json", - "X-Test": "Value", - "X-Forwarded-Proto": "http", - "X-Forwarded-Port": "3000" - }}, - "queryStringParameters": {{"query": "param1,param2"}}, - "requestContext": {{ - "accountId": "123456789012", - "apiId": "1234567890", - "domainName": "localhost", - "domainPrefix": "localhost", - "http": {{ - "method": "GET", - "path": "/endpoint", - "protocol": "HTTP/1.1", - "sourceIp": "190.0.0.0", - "userAgent": "Custom User Agent String" - }}, - "requestId": "", - "routeKey": "GET /endpoint", - "stage": "$default", - "time": \"{self.request_time}\", - "timeEpoch": {self.request_time_epoch} - }}, - "body": "DATA!!!!", - "pathParameters": {{"path": "params"}}, - "stageVariables": null, - "isBase64Encoded": false - }} - """ - - self.expected_dict = json.loads(expected) - - def test_construct_event_with_data(self): - actual_event_str = LocalApigwService._construct_v_2_0_event_http( - self.request_mock, - 3000, - binary_types=[], - route_key="GET /endpoint", - request_time_epoch=self.request_time_epoch, - request_time=self.request_time, - ) - print("DEBUG: json.loads(actual_event_str)", json.loads(actual_event_str)) - print("DEBUG: self.expected_dict", self.expected_dict) - actual_event_dict = json.loads(actual_event_str) - self.assertEqual(len(actual_event_dict["requestContext"]["requestId"]), 36) - actual_event_dict["requestContext"]["requestId"] = "" - self.assertEqual(actual_event_dict, self.expected_dict) - - def test_construct_event_no_data(self): - self.request_mock.get_data.return_value = None - self.expected_dict["body"] = None - - actual_event_str = LocalApigwService._construct_v_2_0_event_http( - self.request_mock, - 3000, - binary_types=[], - route_key="GET /endpoint", - request_time_epoch=self.request_time_epoch, - request_time=self.request_time, - ) - actual_event_dict = json.loads(actual_event_str) - self.assertEqual(len(actual_event_dict["requestContext"]["requestId"]), 36) - actual_event_dict["requestContext"]["requestId"] = "" - self.assertEqual(actual_event_dict, self.expected_dict) - - def test_v2_route_key(self): - route_key = LocalApigwService._v2_route_key("GET", "/path", False) - self.assertEqual(route_key, "GET /path") - - def test_v2_default_route_key(self): - route_key = LocalApigwService._v2_route_key("GET", "/path", True) - self.assertEqual(route_key, "$default") - - @patch("samcli.local.apigw.local_apigw_service.LocalApigwService._should_base64_encode") - def test_construct_event_with_binary_data(self, should_base64_encode_patch): - should_base64_encode_patch.return_value = True - - binary_body = b"011000100110100101101110011000010111001001111001" # binary in binary - base64_body = base64.b64encode(binary_body).decode("utf-8") - - self.request_mock.get_data.return_value = binary_body - self.expected_dict["body"] = base64_body - self.expected_dict["isBase64Encoded"] = True - self.maxDiff = None - - actual_event_str = LocalApigwService._construct_v_2_0_event_http( - self.request_mock, - 3000, - binary_types=[], - route_key="GET /endpoint", - request_time_epoch=self.request_time_epoch, - request_time=self.request_time, - ) - actual_event_dict = json.loads(actual_event_str) - self.assertEqual(len(actual_event_dict["requestContext"]["requestId"]), 36) - actual_event_dict["requestContext"]["requestId"] = "" - self.assertEqual(actual_event_dict, self.expected_dict) - - def test_event_headers_with_empty_list(self): - request_mock = Mock() - headers_mock = Mock() - headers_mock.keys.return_value = [] - request_mock.headers = headers_mock - request_mock.scheme = "http" - - actual_query_string = LocalApigwService._event_http_headers(request_mock, "3000") - self.assertEqual(actual_query_string, {"X-Forwarded-Proto": "http", "X-Forwarded-Port": "3000"}) - - def test_event_headers_with_non_empty_list(self): - request_mock = Mock() - headers_mock = Mock() - headers_mock.keys.return_value = ["Content-Type", "X-Test"] - headers_mock.get.side_effect = ["application/json", "Value"] - headers_mock.getlist.side_effect = [["application/json"], ["Value"]] - request_mock.headers = headers_mock - request_mock.scheme = "http" - - actual_query_string = LocalApigwService._event_http_headers(request_mock, "3000") - self.assertEqual( - actual_query_string, - { - "Content-Type": "application/json", - "X-Test": "Value", - "X-Forwarded-Proto": "http", - "X-Forwarded-Port": "3000", - }, - ) - - -class TestService_should_base64_encode(TestCase): - @parameterized.expand( - [ - param("Mimeyype is in binary types", ["image/gif"], "image/gif"), - param("Mimetype defined and binary types has */*", ["*/*"], "image/gif"), - param("*/* is in binary types with no mimetype defined", ["*/*"], None), - ] - ) - def test_should_base64_encode_returns_true(self, test_case_name, binary_types, mimetype): - self.assertTrue(LocalApigwService._should_base64_encode(binary_types, mimetype)) - - @parameterized.expand([param("Mimetype is not in binary types", ["image/gif"], "application/octet-stream")]) - def test_should_base64_encode_returns_false(self, test_case_name, binary_types, mimetype): - self.assertFalse(LocalApigwService._should_base64_encode(binary_types, mimetype)) - - class TestServiceCorsToHeaders(TestCase): def test_basic_conversion(self): cors = Cors( diff --git a/tests/unit/local/docker/test_manager.py b/tests/unit/local/docker/test_manager.py index 2778916634..ada69903ea 100644 --- a/tests/unit/local/docker/test_manager.py +++ b/tests/unit/local/docker/test_manager.py @@ -221,7 +221,7 @@ def test_must_pull_and_print_progress_dots(self): stream = io.StringIO() pull_result = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0] self.mock_docker_client.api.pull.return_value = pull_result - expected_stream_output = "\nFetching {} Docker container image...{}\n".format( + expected_stream_output = "\nFetching {}:latest Docker container image...{}\n".format( self.image_name, "." * len(pull_result) # Progress bar will print one dot per response from pull API )