Skip to content

Commit

Permalink
add: secure_filename for handle path safely
Browse files Browse the repository at this point in the history
  • Loading branch information
Chenwe_i_lin committed Mar 2, 2020
1 parent 069ad7b commit 51bbe48
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
41 changes: 40 additions & 1 deletion mirai/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import typing as T
import random
from .logger import Protocol
import os
import re

def assertOperatorSuccess(result, raise_exception=False, return_as_is=False):
if "code" in result:
Expand Down Expand Up @@ -46,6 +48,21 @@ class ImageType(Enum):
"friend": r"(?<=/)([0-9a-z]{8})\-([0-9a-z]{4})-([0-9a-z]{4})-([0-9a-z]{4})-([0-9a-z]{12})"
}

_windows_device_files = (
"CON",
"AUX",
"COM1",
"COM2",
"COM3",
"COM4",
"LPT1",
"LPT2",
"LPT3",
"PRN",
"NUL",
)
_filename_ascii_strip_re = re.compile(r"[^A-Za-z0-9_.-]")

def getMatchedString(regex_result):
if regex_result:
return regex_result.string[slice(*regex_result.span())]
Expand Down Expand Up @@ -84,4 +101,26 @@ async def warpper(*args, **kwargs):
except Exception as e:
Protocol.error(f"protocol method {func.__name__} raised a error: {e.__class__.__name__}")
raise e
return warpper
return warpper

def secure_filename(filename):
if isinstance(filename, str):
from unicodedata import normalize

filename = normalize("NFKD", filename).encode("ascii", "ignore")
filename = filename.decode("ascii")

for sep in os.path.sep, os.path.altsep:
if sep:
filename = filename.replace(sep, " ")

filename = \
str(_filename_ascii_strip_re.sub("", "_".join(filename.split()))).strip("._")

if (
os.name == "nt" and filename and \
filename.split(".")[0].upper() in _windows_device_files
):
filename = "_" + filename

return filename
6 changes: 3 additions & 3 deletions mirai/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,10 @@ async def main_entrance(self, run_body, event_context, queue):
try:
if isinstance(run_body, dict):
middlewares = run_body.get("middlewares")

async_middlewares = []
normal_middlewares = []
if middlewares:
async_middlewares = []
normal_middlewares = []

for middleware in middlewares:
if all([
hasattr(middleware, "__aenter__"),
Expand Down

0 comments on commit 51bbe48

Please sign in to comment.