Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Check if AWS S3 bucket is public before generate signed url (M2-8020) #1640

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ name = "pypi"
[packages]
aio-pika = "==9.4.3"
aiofiles = "==24.1.0"
aiohttp = "==3.10.9"
aiohttp = "==3.10.10"
alembic = "==1.13.3"
asyncpg = "==0.29.0"
azure-storage-blob = "==12.22.0"
asyncpg = "==0.30.0"
azure-storage-blob = "==12.23.1"
bcrypt = "==4.2.0"
boto3 = "==1.26.10"
fastapi = "==0.110.3"
boto3 = "==1.35.47"
fastapi = "==0.115.3"
# The latest version of fastapi is not used because of a dependency conflict:
# the starlette version for those deps is ==0.21.0,
# and fastapi-mail requires 0.21 < starlette < 0.22
Expand Down Expand Up @@ -45,7 +45,7 @@ pyOpenSSL = "==24.2.1"
pydantic = { extras = ["email"], version = "==1.10.18" }
pymongo = "*"
python-multipart = "==0.0.12"
redis = "==5.1.0"
redis = "==5.1.1"
sentry-sdk = "~=2.13"
sqlalchemy = { extras = ["asyncio"], version = "==1.4.53" }
sqlalchemy-utils = "==0.41.2"
Expand All @@ -54,14 +54,14 @@ taskiq-aio-pika = "==0.4.1"
taskiq-fastapi = "==0.3.2"
taskiq-redis = "==1.0.2"
typer = "==0.12.5"
uvicorn = { extras = ["standard"], version = "==0.30.6" }
uvicorn = { extras = ["standard"], version = "==0.32.0" }
pyjwt = "==2.9.0"

[dev-packages]
ipdb = "==0.13.13"
pudb = "==2024.1.2"
pre-commit = "==3.8.0"
ruff = "==0.6.8"
pudb = "==2024.1.3"
pre-commit = "==4.0.1"
ruff = "==0.7.0"
allure-pytest = "==2.13.5"
pydantic-factories = "==1.17.3"
pytest = "==8.3.3"
Expand All @@ -72,11 +72,11 @@ pytest-lazy-fixtures = "==1.1.1"
pytest-mock = "==3.14.0"
nest-asyncio = "==1.6.0"
gevent = "==24.2.1"
mypy = "==1.11.2"
types-python-dateutil = "==2.9.0.20240906"
mypy = "==1.13.0"
types-python-dateutil = "==2.9.0.20241003"
typing-extensions = "==4.12.2"
types-requests = "==2.32.0.20240914"
types-pytz = "==2024.2.0.20240913"
types-requests = "==2.32.0.20241016"
types-pytz = "==2024.2.0.20241003"
types-aiofiles = "==24.1.0.20240626"
types-cachetools = "==5.5.0.20240820"
greenlet = "==3.1.0"
Expand Down
1,539 changes: 763 additions & 776 deletions Pipfile.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ env = [
"DATABASE__USER=postgres",
"DATABASE__DB=test",
"ARBITRARY_DB=test_arbitrary",
"TASK_ANSWER_ENCRYPTION__BATCH_LIMIT=1"
"TASK_ANSWER_ENCRYPTION__BATCH_LIMIT=1",
"CDN__LEGACY_REGION=us-east-1",
"CDN__LEGACY_BUCKET=testing"
]

[tool.coverage.run]
Expand Down
4 changes: 2 additions & 2 deletions src/apps/activity_assignments/tests/test_assignments.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,13 +819,13 @@ async def test_assignment_list_by_applet_with_delete_subject(
assert response.json()["result"]["appletId"] == str(applet_one_with_flow.id)
assignments = response.json()["result"]["assignments"]
assert len(assignments) == 2
assignment = assignments[0]
assignment = [a for a in assignments if a["activityId"] is not None][0]
assert assignment["activityId"] == str(applet_one_with_flow.activities[0].id)
assert assignment["respondentSubjectId"] == str(tom_applet_one_subject.id)
assert assignment["targetSubjectId"] == str(tom_applet_one_subject.id)
assert assignment["activityFlowId"] is None
assert assignment["id"] == assignment_activity["id"]
assignment = assignments[1]
assignment = [a for a in assignments if a["activityFlowId"] is not None][0]
assert assignment["activityId"] is None
assert assignment["respondentSubjectId"] == str(tom_applet_one_subject.id)
assert assignment["targetSubjectId"] == str(lucy_applet_one_subject.id)
Expand Down
2 changes: 1 addition & 1 deletion src/apps/answers/crud/answers.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ async def get_completed_answers_data_list(

return result_list

async def get_latest_applet_version(self, applet_id: uuid.UUID) -> str:
async def get_latest_applet_version(self, applet_id: uuid.UUID) -> str | None:
query: Query = select(AnswerSchema.applet_history_id)
query = query.where(AnswerSchema.applet_id == applet_id)
query = query.order_by(AnswerSchema.version.desc())
Expand Down
19 changes: 11 additions & 8 deletions src/apps/file/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,31 +59,34 @@ async def get_legacy_client(self, info: WorkspaceArbitrary | None) -> CDNClient:
else:
return await self.get_regular_client()

async def _presign(self, url: str | None, legacy_cdn_client: CDNClient, regular_cdn_client: CDNClient):
    """Turn one stored file URL into a URL the caller can fetch.

    The CDN clients are passed in by ``presign`` so they are resolved once
    per batch instead of once per URL.

    Args:
        url: Stored file URL, or ``None``/empty.
        legacy_cdn_client: Client used for URLs in the legacy format.
        regular_cdn_client: Client used for URLs in the regular format.

    Returns:
        ``None`` for an empty ``url``; the original ``url`` when its format is
        not recognised or the access check fails; otherwise a public or
        presigned URL generated by the matching client.
    """
    if not url:
        return None

    if self._is_legacy_file_url_format(url):
        if not await self._check_access_to_legacy_url(url):
            return url
        key = self._get_key(url)
        # Public buckets/objects get a plain URL; everything else is presigned.
        # NOTE(review): is_bucket_public() is a synchronous call made from a
        # coroutine -- confirm that its first (uncached) lookup is acceptable here.
        if legacy_cdn_client.is_bucket_public() or await legacy_cdn_client.is_object_public(key):
            return await legacy_cdn_client.generate_public_url(key)
        return await legacy_cdn_client.generate_presigned_url(key)

    if self._is_regular_file_url_format(url):
        if not await self._check_access_to_regular_url(url):
            return url
        key = self._get_key(url)
        return await regular_cdn_client.generate_presigned_url(key)

    return url


async def presign(self, urls: List[str | None]) -> List[str]:
    """Resolve every URL in ``urls`` concurrently via ``_presign``.

    The legacy and regular CDN clients are looked up once and shared by all
    per-URL coroutines. Results are returned in the same order as ``urls``.
    """
    wsp_service = workspace.WorkspaceService(self.session, self.user_id)
    arbitrary_info = await wsp_service.get_arbitrary_info_if_use_arbitrary(self.applet_id)
    legacy_cdn_client = await self.get_legacy_client(arbitrary_info)
    regular_cdn_client = await self.get_regular_client()

    pending = [self._presign(url, legacy_cdn_client, regular_cdn_client) for url in urls]
    return await asyncio.gather(*pending)

Expand Down Expand Up @@ -136,7 +139,7 @@ class GCPPresignService(S3PresignService):
r"gs:\/\/[a-zA-Z0-9.-]+\/[a-zA-Z0-9-]+\/[a-zA-Z0-9-]+\/[a-f0-9-]+\/[a-f0-9-]+\/[a-zA-Z0-9-]+" # noqa
)

async def _presign(self, url: str | None):
async def _presign(self, url: str | None, *kwargs):
regular_cdn_client = await select_storage(applet_id=self.applet_id, session=self.session)

if self._is_legacy_file_url_format(url):
Expand Down
2 changes: 1 addition & 1 deletion src/apps/workspaces/crud/workspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def get_arbitraries_map_by_applet_ids(self, applet_ids: list[uuid.UUID]) -
db_result = await self._execute(query)
res = db_result.scalars().all()

user_arb_uri_map: dict[uuid.UUID, str] = dict()
user_arb_uri_map: dict[uuid.UUID, str | None] = dict()
for user_workspace in res:
user_arb_uri_map[user_workspace.user_id] = (
user_workspace.database_uri if user_workspace.use_arbitrary else None
Expand Down
12 changes: 6 additions & 6 deletions src/infrastructure/utility/cdn_arbitrary.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
from typing import BinaryIO

import boto3
import botocore
from azure.storage.blob import BlobSasPermissions, BlobServiceClient, generate_blob_sas
from botocore.config import Config

from infrastructure.utility.cdn_client import CDNClient
from infrastructure.utility.cdn_config import CdnConfig


class ArbitraryS3CdnClient(CDNClient):
def configure_client(self, config: CdnConfig):
client_config = botocore.config.Config(
def configure_client(self, config: CdnConfig, signature_version=None):
client_config = Config(
max_pool_connections=25,
)
return boto3.client(
Expand All @@ -32,8 +32,8 @@ def __init__(self, config: CdnConfig, endpoint_url: str, env: str, *, max_concur
def generate_private_url(self, key):
    """Return the ``gs://`` URI for ``key`` in the configured bucket."""
    return f"gs://{self.config.bucket}/{key}"

def configure_client(self, config):
client_config = botocore.config.Config(
def configure_client(self, config, signature_version=None):
client_config = Config(
max_pool_connections=25,
)
return boto3.client(
Expand All @@ -58,7 +58,7 @@ def generate_key(cls, scope, unique, filename):
def generate_private_url(self, key):
    """Return the Azure Blob URL for ``key`` under the ``mindlogger`` path.

    ``self.config.bucket`` is used as the leading host label of the URL.
    """
    return f"https://{self.config.bucket}.blob.core.windows.net/mindlogger/{key}"  # noqa

def configure_client(self, _):
def configure_client(self, _, **kwargs):
blob_service_client = BlobServiceClient.from_connection_string(self.sec_key)
with suppress(Exception):
blob_service_client.create_container(self.default_container_name)
Expand Down
74 changes: 71 additions & 3 deletions src/infrastructure/utility/cdn_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import asyncio
import http
import io
import json
import mimetypes
from concurrent.futures import ThreadPoolExecutor
from typing import BinaryIO

import boto3
import botocore
import httpx
from botocore import UNSIGNED
from botocore.config import Config
from botocore.exceptions import ClientError, EndpointConnectionError

from apps.file.errors import FileNotFoundError
Expand Down Expand Up @@ -35,17 +37,20 @@ def __init__(self, config: CdnConfig, env: str, *, max_concurrent_tasks: int = 1
# semaphore for concurrent calls of urlib3 in boto3
self.semaphore = asyncio.Semaphore(max_concurrent_tasks)

self._is_bucket_public: bool | None = None

@classmethod
def generate_key(cls, scope, unique, filename):
    """Build a storage key of the form ``<container>/<scope>/<unique>/<filename>``.

    The container segment comes from ``cls.default_container_name``.
    """
    return f"{cls.default_container_name}/{scope}/{unique}/{filename}"

def generate_private_url(self, key):
    """Return the ``s3://`` URI for ``key`` in the configured bucket."""
    return f"s3://{self.config.bucket}/{key}"

def configure_client(self, config):
def configure_client(self, config, signature_version=None):
assert config, "set CDN"
client_config = botocore.config.Config(
client_config = Config(
max_pool_connections=25,
signature_version=signature_version,
)

if config.access_key and config.secret_key:
Expand Down Expand Up @@ -109,6 +114,18 @@ def download(self, key, file: BinaryIO | None = None):
media_type = mimetypes.guess_type(key)[0] if mimetypes.guess_type(key)[0] else "application/octet-stream"
return file, media_type

def _generate_public_url(self, key):
    """Return a ``get_object`` URL for ``key`` produced by an unsigned client.

    The unsigned client is built with ``configure_client(...,
    signature_version=UNSIGNED)`` on first use and then reused, so a batch of
    URLs does not construct a new client per object.
    Presumably the resulting URL carries no signature query parameters --
    confirm against botocore's UNSIGNED behaviour.
    """
    client = getattr(self, "_unsigned_client", None)
    if client is None:
        client = self.configure_client(config=self.config, signature_version=UNSIGNED)
        # Cached lazily; a concurrent first call may build a second client,
        # which is harmless (the last one assigned wins).
        self._unsigned_client = client

    return client.generate_presigned_url(
        "get_object",
        Params={
            "Bucket": self.config.bucket,
            "Key": key,
        },
        ExpiresIn=0,
    )

def _generate_presigned_url(self, key):
url = self.client.generate_presigned_url(
"get_object",
Expand All @@ -126,6 +143,12 @@ async def generate_presigned_url(self, key):
url = await asyncio.wrap_future(future)
return url

async def generate_public_url(self, key):
    """Run ``_generate_public_url(key)`` in a worker thread and await its result."""
    with ThreadPoolExecutor() as executor:
        future = executor.submit(self._generate_public_url, key)
        url = await asyncio.wrap_future(future)
    return url

async def delete_object(self, key: str | None):
async with self.semaphore:
with ThreadPoolExecutor() as executor:
Expand Down Expand Up @@ -192,3 +215,48 @@ async def check(self):
except httpx.HTTPError as e:
logger.info("File upload error")
raise e

def _check_is_bucket_public(self) -> bool:
    """Report whether the bucket policy has an ``Allow`` statement for everyone.

    A statement counts as public when its ``Effect`` is ``Allow`` and its
    ``Principal`` is the wildcard, written as ``"*"``, ``{"AWS": "*"}`` or
    ``{"AWS": [..., "*", ...]}``. ``Action``, ``Resource`` and ``Condition``
    are not inspected.

    Returns:
        ``True`` if such a statement exists. ``False`` if there is none, if the
        bucket has no policy, or if the lookup fails (failures other than
        ``NoSuchBucketPolicy`` are logged).
    """
    try:
        bucket_policy = self.client.get_bucket_policy(Bucket=self.config.bucket)
        policy = bucket_policy.get("Policy")
        if not policy:
            return False

        statements = json.loads(policy)["Statement"]
        # The policy grammar allows a single statement object instead of a list.
        if isinstance(statements, dict):
            statements = [statements]

        for statement in statements:
            if statement.get("Effect") != "Allow":
                continue
            principal = statement.get("Principal")
            if isinstance(principal, dict):
                principal = principal.get("AWS")
            if principal == "*" or (isinstance(principal, list) and "*" in principal):
                return True  # Bucket policy allows public access
    except ClientError as e:
        # A bucket without any policy is an expected, non-public case.
        if e.response["Error"]["Code"] != "NoSuchBucketPolicy":
            logger.error(f"Error getting bucket policy: {e}")
    except Exception as e:
        logger.error(f"Error getting bucket policy: {e}")

    return False  # No public access found

def is_bucket_public(self) -> bool:
    """Return the bucket's public flag, computing it on first use.

    The result of ``_check_is_bucket_public()`` is stored in
    ``self._is_bucket_public`` and reused on later calls; a ``False`` result
    is cached as well, because only ``None`` triggers a new check.
    """
    if self._is_bucket_public is None:
        self._is_bucket_public = self._check_is_bucket_public()

    return self._is_bucket_public

def _is_object_public(self, key) -> bool:
    """Report whether the object's ACL contains a grant to the AllUsers group.

    Returns:
        ``True`` if any grant has a ``Group`` grantee whose URI is the global
        AllUsers group; ``False`` otherwise, including when the ACL lookup
        fails (the failure is logged).
    """
    all_users_uri = "http://acs.amazonaws.com/groups/global/AllUsers"
    try:
        acl = self.client.get_object_acl(Bucket=self.config.bucket, Key=key)
        return any(
            grant["Grantee"].get("Type") == "Group" and grant["Grantee"].get("URI") == all_users_uri
            for grant in acl["Grants"]
        )
    except Exception as e:  # ClientError is already covered by Exception
        logger.error(f"Error getting object ACL: {e}")
        return False

async def is_object_public(self, key) -> bool:
    """Run ``_is_object_public(key)`` in a worker thread and await its result."""
    with ThreadPoolExecutor() as executor:
        future = executor.submit(self._is_object_public, key)
        return await asyncio.wrap_future(future)
Loading