diff --git a/app/utils/security.py b/app/utils/security.py index 098f59499..f73679150 100644 --- a/app/utils/security.py +++ b/app/utils/security.py @@ -40,7 +40,7 @@ def is_safe_path(base_path: Path, user_path: Path, return False @staticmethod - def is_safe_url(url: str, allowed_domains: Union[Set[str], List[str]], strict: bool = True) -> bool: + def is_safe_url(url: str, allowed_domains: Union[Set[str], List[str]], strict: bool = False) -> bool: """ 验证URL是否在允许的域名列表中,包括带有端口的域名。 @@ -53,7 +53,7 @@ def is_safe_url(url: str, allowed_domains: Union[Set[str], List[str]], strict: b # 解析URL parsed_url = urlparse(url) - # 检查URL的scheme和netloc + # 如果 URL 没有包含有效的 scheme,或者无法从中提取到有效的 netloc,则认为该 URL 是无效的 if not parsed_url.scheme or not parsed_url.netloc: return False @@ -63,22 +63,22 @@ def is_safe_url(url: str, allowed_domains: Union[Set[str], List[str]], strict: b # 获取完整的 netloc(包括 IP 和端口)并转换为小写 netloc = parsed_url.netloc.lower() - allowed_domains = {d.lower() for d in allowed_domains} - if not netloc: return False - if strict: - # 严格匹配一级域名,要求完全匹配或者子域名精确匹配 - domain_parts = netloc.split(".") - for allowed_domain in allowed_domains: - allowed_parts = allowed_domain.split(".") - if domain_parts[-len(allowed_parts):] == allowed_parts: + # 检查每个允许的域名 + allowed_domains = {d.lower() for d in allowed_domains} + for domain in allowed_domains: + parsed_allowed_url = urlparse(domain) + allowed_netloc = parsed_allowed_url.netloc or parsed_allowed_url.path + + if strict: + # 严格模式下,要求完全匹配域名和端口 + if netloc == allowed_netloc: return True - else: - # 允许匹配多级域名,或者完全匹配的 netloc(包括 IP:port) - for allowed_domain in allowed_domains: - if netloc == allowed_domain or netloc.endswith(f".{allowed_domain}"): + else: + # 非严格模式下,允许子域名匹配 + if netloc == allowed_netloc or netloc.endswith('.' + allowed_netloc): return True return False