diff --git a/mirai/misc.py b/mirai/misc.py index 843852c..2aae500 100644 --- a/mirai/misc.py +++ b/mirai/misc.py @@ -4,6 +4,8 @@ import typing as T import random from .logger import Protocol +import os +import re def assertOperatorSuccess(result, raise_exception=False, return_as_is=False): if "code" in result: @@ -46,6 +48,21 @@ class ImageType(Enum): "friend": r"(?<=/)([0-9a-z]{8})\-([0-9a-z]{4})-([0-9a-z]{4})-([0-9a-z]{4})-([0-9a-z]{12})" } +_windows_device_files = ( + "CON", + "AUX", + "COM1", + "COM2", + "COM3", + "COM4", + "LPT1", + "LPT2", + "LPT3", + "PRN", + "NUL", +) +_filename_ascii_strip_re = re.compile(r"[^A-Za-z0-9_.-]") + def getMatchedString(regex_result): if regex_result: return regex_result.string[slice(*regex_result.span())] @@ -84,4 +101,26 @@ async def warpper(*args, **kwargs): except Exception as e: Protocol.error(f"protocol method {func.__name__} raised a error: {e.__class__.__name__}") raise e - return warpper \ No newline at end of file + return warpper + +def secure_filename(filename): + if isinstance(filename, str): + from unicodedata import normalize + + filename = normalize("NFKD", filename).encode("ascii", "ignore") + filename = filename.decode("ascii") + + for sep in os.path.sep, os.path.altsep: + if sep: + filename = filename.replace(sep, " ") + + filename = \ + str(_filename_ascii_strip_re.sub("", "_".join(filename.split()))).strip("._") + + if ( + os.name == "nt" and filename and \ + filename.split(".")[0].upper() in _windows_device_files + ): + filename = "_" + filename + + return filename \ No newline at end of file diff --git a/mirai/session.py b/mirai/session.py index d56bf6d..763d1ee 100644 --- a/mirai/session.py +++ b/mirai/session.py @@ -288,10 +288,10 @@ async def main_entrance(self, run_body, event_context, queue): try: if isinstance(run_body, dict): middlewares = run_body.get("middlewares") - - async_middlewares = [] - normal_middlewares = [] if middlewares: + async_middlewares = [] + normal_middlewares = [] + for middleware in middlewares: if all([ hasattr(middleware, "__aenter__"),