From abbf8e2c4c5ac9e878451c0589d7f4223761ab64 Mon Sep 17 00:00:00 2001 From: Michael Ketchel Date: Thu, 7 Nov 2024 16:13:30 -0500 Subject: [PATCH] More cleanup, and armored up the network API against command injection via interface name --- wlanpi_core/__main__.py | 7 +- .../api/api_v1/endpoints/network_api.py | 91 ++++++++++++++----- wlanpi_core/models/network/wlan/wlan_dbus.py | 9 +- .../services/network_ethernet_service.py | 6 +- wlanpi_core/services/network_info_service.py | 2 +- wlanpi_core/services/system_service.py | 6 +- wlanpi_core/utils/general.py | 40 ++++++-- wlanpi_core/utils/network.py | 15 ++- 8 files changed, 129 insertions(+), 47 deletions(-) diff --git a/wlanpi_core/__main__.py b/wlanpi_core/__main__.py index 3e16003..1cf69f3 100644 --- a/wlanpi_core/__main__.py +++ b/wlanpi_core/__main__.py @@ -18,6 +18,7 @@ import os import platform import sys +from typing import Union # third party imports import uvicorn @@ -26,7 +27,7 @@ from .__version__ import __version__ -def port(port) -> int: +def check_port(port: Union[int, str]) -> int: """Check if the provided port is valid""" try: # make sure port is an int @@ -52,7 +53,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--reload", dest="livereload", action="store_true", default=False ) - parser.add_argument("--port", "-p", dest="port", type=port, default=8000) + parser.add_argument("--port", "-p", dest="port", type=check_port, default=8000) parser.add_argument( "--version", "-V", "-v", action="version", version=f"{__version__}" @@ -111,7 +112,7 @@ def init() -> None: ) if __name__ == "__main__": - return sys.exit(main()) + sys.exit(main()) init() diff --git a/wlanpi_core/api/api_v1/endpoints/network_api.py b/wlanpi_core/api/api_v1/endpoints/network_api.py index b1ded7f..ae532bf 100644 --- a/wlanpi_core/api/api_v1/endpoints/network_api.py +++ b/wlanpi_core/api/api_v1/endpoints/network_api.py @@ -14,25 +14,49 @@ SupplicantNetwork, ) from wlanpi_core.services import network_ethernet_service, network_service +from wlanpi_core.utils.network import list_ethernet_interfaces, list_wlan_interfaces router = APIRouter() log = logging.getLogger("uvicorn") +def validate_wlan_interface(interface: Optional[str], required: bool = True) -> None: + if (required or interface is not None) and interface not in list_wlan_interfaces(): + raise ValidationError( + f"Invalid/unavailable interface specified: #{interface}", status_code=400 + ) + + +def validate_ethernet_interface( + interface: Optional[str], required: bool = True +) -> None: + if ( + required or interface is not None + ) and interface not in list_ethernet_interfaces(): + raise ValidationError( + f"Invalid/unavailable interface specified: #{interface}", status_code=400 + ) + + ################################ # General Network Management # ################################ @router.get("/interfaces", response_model=dict[str, list[IPInterface]]) @router.get("/interfaces/{interface}", response_model=dict[str, list[IPInterface]]) -async def show_all_interfaces(interface: Optional[str] = None): +async def show_all_interfaces( + interface: Optional[str] = None, +): """ Returns all network interfaces. """ - if interface and interface.lower() == "all": - interface = None try: + if interface and interface.lower() == "all": + interface = None + else: + validate_ethernet_interface(interface, required=False) + return await network_ethernet_service.get_interfaces(interface=interface) except ValidationError as ve: return Response(content=ve.error_msg, status_code=ve.status_code) @@ -52,10 +76,12 @@ async def show_all_ethernet_interfaces(interface: Optional[str] = None): """ Returns all ethernet interfaces. """ - if interface and interface.lower() == "all": - interface = None try: + if interface and interface.lower() == "all": + interface = None + else: + validate_ethernet_interface(interface, required=False) def filterfunc(i): iface_obj = i.model_dump() @@ -96,24 +122,29 @@ async def show_all_ethernet_vlans( """ Returns all VLANS for a given ethernet interface. """ - custom_filter = lambda i: True - if not interface or interface.lower() == "all": - interface = None - if vlan and vlan.lower() == "all": - vlan = None - if vlan and vlan.lower() != "all": - - def filterfunc(i): - return i.model_dump().get("linkinfo", {}).get( - "info_kind" - ) == "vlan" and i.model_dump().get("linkinfo", {}).get("info_data", {}).get( - "id" - ) == int( - vlan - ) - - custom_filter = filterfunc try: + custom_filter = lambda i: True + if not interface or interface.lower() == "all": + interface = None + else: + validate_ethernet_interface(interface, required=False) + if vlan and vlan.lower() == "all": + vlan = None + if vlan and vlan.lower() != "all": + + def filterfunc(i): + return i.model_dump().get("linkinfo", {}).get( + "info_kind" + ) == "vlan" and i.model_dump().get("linkinfo", {}).get( + "info_data", {} + ).get( + "id" + ) == int( + vlan + ) + + custom_filter = filterfunc + return await network_ethernet_service.get_vlans( interface=interface, custom_filter=custom_filter ) @@ -147,6 +178,7 @@ async def create_ethernet_vlan( return Response(content=ve.error_msg, status_code=ve.status_code) try: + validate_ethernet_interface(interface, required=True) await network_ethernet_service.remove_vlan( interface=interface, vlan_id=vlan, allow_missing=True ) @@ -187,6 +219,7 @@ async def delete_ethernet_vlan( return Response(content=ve.error_msg, status_code=ve.status_code) try: + validate_ethernet_interface(interface, required=True) await network_ethernet_service.remove_vlan( interface=interface, vlan_id=vlan, allow_missing=allow_missing ) @@ -210,7 +243,7 @@ async def delete_ethernet_vlan( @router.get("/wlan/interfaces", response_model=network.Interfaces) -async def get_wireless_interfaces(timeout: int = API_DEFAULT_TIMEOUT): +async def get_all_wireless_interfaces(timeout: int = API_DEFAULT_TIMEOUT): """ Queries wpa_supplicant via dbus to get all interfaces known to the supplicant. """ @@ -237,6 +270,7 @@ async def do_wireless_network_scan( """ try: + validate_wlan_interface(interface) return await network_service.get_wireless_network_scan_async( scan_type, interface, timeout ) @@ -258,6 +292,7 @@ async def add_wireless_network( """ try: + validate_wlan_interface(interface) return await network_service.add_wireless_network( interface, setup.netConfig, setup.removeAllFirst, timeout ) @@ -281,6 +316,7 @@ async def get_current_wireless_network_details( """ try: + validate_wlan_interface(interface) return await network_service.get_current_wireless_network_details( interface, timeout ) @@ -302,6 +338,7 @@ async def get_all_wireless_networks(interface: str, timeout: int = API_DEFAULT_T """ try: + validate_wlan_interface(interface) return await network_service.networks(interface) except ValidationError as ve: return Response(content=ve.error_msg, status_code=ve.status_code) @@ -321,6 +358,7 @@ async def get_current_network(interface: str, timeout: int = API_DEFAULT_TIMEOUT """ try: + validate_wlan_interface(interface) return await network_service.current_network(interface) except ValidationError as ve: return Response(content=ve.error_msg, status_code=ve.status_code) @@ -340,6 +378,7 @@ async def get_wireless_network(interface: str, network_id: int): """ try: + validate_wlan_interface(interface) return await network_service.get_network(interface, network_id) except ValidationError as ve: return Response(content=ve.error_msg, status_code=ve.status_code) @@ -359,6 +398,7 @@ async def remove_all_wireless_networks(interface: str): """ try: + validate_wlan_interface(interface) return await network_service.remove_all_networks(interface) except ValidationError as ve: return Response(content=ve.error_msg, status_code=ve.status_code) @@ -380,6 +420,7 @@ async def disconnect_wireless_network( """ try: + validate_wlan_interface(interface) return await network_service.disconnect_wireless_network(interface, timeout) except ValidationError as ve: return Response(content=ve.error_msg, status_code=ve.status_code) @@ -399,6 +440,7 @@ async def remove_wireless_network(interface: str, network_id: int): """ try: + validate_wlan_interface(interface) return await network_service.remove_network( interface, network_id, @@ -426,7 +468,10 @@ async def get_interface_details(interface: Optional[str] = None): """ Gets interface details via iw. """ + try: + validate_wlan_interface(interface, required=False) + return await network_service.interface_details(interface) except ValidationError as ve: return Response(content=ve.error_msg, status_code=ve.status_code) diff --git a/wlanpi_core/models/network/wlan/wlan_dbus.py b/wlanpi_core/models/network/wlan/wlan_dbus.py index f63f5e0..27c99db 100644 --- a/wlanpi_core/models/network/wlan/wlan_dbus.py +++ b/wlanpi_core/models/network/wlan/wlan_dbus.py @@ -12,6 +12,7 @@ from wlanpi_core.models.network.wlan.exceptions import WlanDBUSInterfaceException from wlanpi_core.models.network.wlan.wlan_dbus_interface import WlanDBUSInterface from wlanpi_core.utils.general import run_command +from wlanpi_core.utils.network import list_wlan_interfaces class WlanDBUS: @@ -48,15 +49,9 @@ def get_interface(self, interface) -> WlanDBUSInterface: self.interfaces[interface] = new_interface return self.interfaces[interface] - @staticmethod - def _fetch_system_interfaces() -> list[str]: - return run_command( - "ls /sys/class/ieee80211/*/device/net/", shell=True - ).grep_stdout_for_string("/", negate=True, split=True) - def fetch_interfaces(self, wpas_obj): available_interfaces = [] - for system_interface in self._fetch_system_interfaces(): + for system_interface in list_wlan_interfaces(): try: self.get_interface(system_interface) except WlanDBUSInterfaceException as e: diff --git a/wlanpi_core/services/network_ethernet_service.py b/wlanpi_core/services/network_ethernet_service.py index 9e84f45..0f13771 100644 --- a/wlanpi_core/services/network_ethernet_service.py +++ b/wlanpi_core/services/network_ethernet_service.py @@ -3,7 +3,7 @@ from ..models.network import common from ..models.network.vlan import LiveVLANs from ..models.network.vlan.vlan_file import VLANFile -from ..schemas.network.network import IPInterfaceAddress +from ..schemas.network.network import IPInterface, IPInterfaceAddress from ..schemas.network.types import CustomIPInterfaceFilter # https://man.cx/interfaces(5) @@ -53,10 +53,10 @@ async def remove_vlan(interface: str, vlan_id: Union[str, int], allow_missing=Fa async def get_interfaces( - interface: str, + interface: Optional[str], allow_missing=False, custom_filter: Optional[CustomIPInterfaceFilter] = None, -): +) -> dict[str, list[IPInterface]]: """ Returns definitions for all network interfaces known by the `ip` command. """ diff --git a/wlanpi_core/services/network_info_service.py b/wlanpi_core/services/network_info_service.py index 2f7418a..54e05e1 100644 --- a/wlanpi_core/services/network_info_service.py +++ b/wlanpi_core/services/network_info_service.py @@ -123,7 +123,7 @@ def show_wlan_interfaces(): try: interfaces = run_command( - f"{IW_FILE} dev 2>&1", shell=True + f"{IW_FILE} dev 2>&1", shell=True, use_shlex=False ).grep_stdout_for_pattern(r"interface", flags=re.I, split=True) interfaces = map(lambda x: x.strip().split(" ")[1], interfaces) except Exception as e: diff --git a/wlanpi_core/services/system_service.py b/wlanpi_core/services/system_service.py index 7daeec4..5175265 100644 --- a/wlanpi_core/services/system_service.py +++ b/wlanpi_core/services/system_service.py @@ -176,14 +176,14 @@ def get_stats(): # determine mem useage cmd = "free -m | awk 'NR==2{printf \"%s/%sMB %.2f%%\", $3,$2,$3*100/$2 }'" try: - MemUsage = run_command(cmd, shell=True).stdout.strip() + MemUsage = run_command(cmd, shell=True, use_shlex=False).stdout.strip() except Exception: MemUsage = "unknown" # determine disk util cmd = 'df -h | awk \'$NF=="/"{printf "%d/%dGB %s", $3,$2,$5}\'' try: - Disk = run_command(cmd, shell=True).stdout.strip() + Disk = run_command(cmd, shell=True, use_shlex=False).stdout.strip() except Exception: Disk = "unknown" @@ -200,7 +200,7 @@ def get_stats(): # determine uptime cmd = "uptime -p | sed -r 's/up|,//g' | sed -r 's/\s*week[s]?/w/g' | sed -r 's/\s*day[s]?/d/g' | sed -r 's/\s*hour[s]?/h/g' | sed -r 's/\s*minute[s]?/m/g'" try: - uptime = run_command(cmd, shell=True).stdout.strip() + uptime = run_command(cmd, shell=True, use_shlex=False).stdout.strip() except Exception: uptime = "unknown" diff --git a/wlanpi_core/utils/general.py b/wlanpi_core/utils/general.py index 9f01ff1..e40fc78 100644 --- a/wlanpi_core/utils/general.py +++ b/wlanpi_core/utils/general.py @@ -18,6 +18,7 @@ def run_command( stdin: Optional[TextIO] = None, shell=False, raise_on_fail=True, + use_shlex=True, ) -> CommandResult: """Run a single CLI command with subprocess and returns the output""" """ @@ -35,6 +36,8 @@ def run_command( If True, then the entire command string will be executed in a shell. Otherwise, the command and its arguments are executed separately. raise_on_fail: Whether to raise an error if the command fails or not. Default is True. + shlex: If shlex should be used to protect input. Set to false if you need support + for some shell features like wildcards. Returns: A CommandResult object containing the output of the command, along with a boolean indicating @@ -50,13 +53,15 @@ def run_command( error_msg="You cannot use both 'input' and 'stdin' on the same call.", return_code=-1, ) - - # Todo: explore using shlex to always split to protect against injections + if not use_shlex: + logging.getLogger().warning( + f"shlex protection disabled for command--make sure this command is otherwise protected from injections:\n {cmd}" + ) if shell: # If a list was passed in shell mode, safely join using shlex to protect against injection. if isinstance(cmd, list): cmd: list - cmd: str = shlex.join(cmd) + cmd: str = shlex.join(cmd) if use_shlex else " ".join(cmd) cmd: str logging.getLogger().warning( f"Command {cmd} being run as a shell script. This could present " @@ -66,7 +71,7 @@ def run_command( # If a string was passed in non-shell mode, safely split it using shlex to protect against injection. if isinstance(cmd, str): cmd: str - cmd: list[str] = shlex.split(cmd) + cmd: list[str] = shlex.split(cmd) if use_shlex else cmd.split() cmd: list[str] with subprocess.Popen( cmd, @@ -94,6 +99,7 @@ async def run_command_async( stdin: Optional[TextIO] = None, shell=False, raise_on_fail=True, + use_shlex=True, ) -> CommandResult: """Run a single CLI command with subprocess and returns the output""" """ @@ -111,6 +117,8 @@ async def run_command_async( If True, then the entire command string will be executed in a shell. Otherwise, the command and its arguments are executed separately. raise_on_fail: Whether to raise an error if the command fails or not. Default is True. + shlex: If shlex should be used to protect input. Set to false if you need support + for some shell features like wildcards. Returns: A CommandResult object containing the output of the command, along with a boolean indicating @@ -126,6 +134,26 @@ async def run_command_async( error_msg="You cannot use both 'input' and 'stdin' on the same call.", return_code=-1, ) + if not use_shlex: + logging.getLogger().warning( + f"shlex protection disabled for command--make sure this command is otherwise protected from injections:\n {cmd}" + ) + if shell: + # If a list was passed in shell mode, safely join using shlex to protect against injection. + if isinstance(cmd, list): + cmd: list + cmd: str = shlex.join(cmd) if use_shlex else " ".join(cmd) + cmd: str + logging.getLogger().warning( + f"Command {cmd} being run as a shell script. This could present " + f"an injection vulnerability. Consider whether you really need to do this." + ) + else: + # If a string was passed in non-shell mode, safely split it using shlex to protect against injection. + if isinstance(cmd, str): + cmd: str + cmd: list[str] = shlex.split(cmd) if use_shlex else cmd.split() + cmd: list[str] # Prepare input data for communicate if input: @@ -143,7 +171,7 @@ async def run_command_async( # If a list was passed in shell mode, safely join using shlex to protect against injection. if isinstance(cmd, list): cmd: list - cmd: str = shlex.join(cmd) + cmd: str = use_shlex.join(cmd) cmd: str logging.getLogger().warning( f"Command {cmd} being run as a shell script. This could present " @@ -162,7 +190,7 @@ async def run_command_async( # If a string was passed in non-shell mode, safely split it using shlex to protect against injection. if isinstance(cmd, str): cmd: str - cmd: list[str] = shlex.split(cmd) + cmd: list[str] = use_shlex.split(cmd) cmd: list[str] proc = await asyncio.subprocess.create_subprocess_exec( cmd[0], diff --git a/wlanpi_core/utils/network.py b/wlanpi_core/utils/network.py index 7b9a2df..ab13ccd 100644 --- a/wlanpi_core/utils/network.py +++ b/wlanpi_core/utils/network.py @@ -186,6 +186,19 @@ def get_phy_interface_name(phy_num: int) -> Optional[str]: return None +def list_wlan_interfaces() -> list[str]: + return run_command( # type: ignore + ["ls", "-1", "/sys/class/ieee80211/*/device/net/"], use_shlex=False, shell=True + ).grep_stdout_for_pattern(r"^$|/", negate=True, split=True) + + +def list_ethernet_interfaces() -> list[str]: + res = run_command( # type: ignore + ["ls", "-1", "/sys/class/net/*/device/net/"], use_shlex=False, shell=True + ).grep_stdout_for_pattern(r"^$|/", negate=True, split=True) + return [x for x in res if "eth" in x] + + def get_wlan_channels(interface: str) -> list[WlanChannelInfo]: phy = get_interface_phy_num(interface) if phy is None: @@ -316,4 +329,4 @@ def get_interface_mac(interface: str) -> str: if __name__ == "__main__": - print(json.dumps(get_interface_details())) + print(list_wlan_interfaces())