Skip to content

Commit

Permalink
Use from_dict to sanitize XML inputs to Dataclasses (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
shbatm authored Jun 18, 2023
1 parent 0628ec6 commit f7c65c8
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 30 deletions.
2 changes: 1 addition & 1 deletion pyisyox/events/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def parse_message(self, msg: str) -> None:
return
if event := xml_dict.get("event", {}):
try:
self.route_message(EventData(**event))
self.route_message(EventData.from_dict(event))
except (KeyError, ValueError, NameError):
_LOGGER.error("Could not validate event", exc_info=True)

Expand Down
48 changes: 48 additions & 0 deletions pyisyox/helpers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from contextlib import suppress
from dataclasses import InitVar, dataclass, field
from datetime import datetime
import inspect
from typing import Any, Generic, TypeVar, cast

from pyisyox.constants import DEFAULT_PRECISION, DEFAULT_UNIT_OF_MEASURE
Expand All @@ -20,6 +21,18 @@
class EntityDetail:
"""Dataclass to hold entity detail info."""

@classmethod
def from_dict(cls: type[EntityDetailT], props: dict) -> EntityDetailT:
"""Create a dataclass from a dictionary.
Class method is used instead of keyword unpacking (**props) to prevent
breaking changes by new parameters being added in the future to the
API XML model.
"""
return cls(
**{k: v for k, v in props.items() if k in inspect.signature(cls).parameters}
)

parent: str | dict[str, str] | None = None


Expand All @@ -38,6 +51,13 @@ class EntityStatus(Generic[StatusT, EntityDetailT]):
class EventData:
"""Dataclass to represent the event data returned from the stream."""

@classmethod
def from_dict(cls: type[EventData], props: dict) -> EventData:
"""Create a dataclass from a dictionary."""
return cls(
**{k: v for k, v in props.items() if k in inspect.signature(cls).parameters}
)

seqnum: str = ""
sid: str = ""
control: str = ""
Expand All @@ -61,6 +81,13 @@ class NodeChangedEvent:
class NodeNotes:
"""Dataclass for holding node notes information."""

@classmethod
def from_dict(cls, props: dict) -> NodeNotes:
"""Create a dataclass from a dictionary."""
return cls(
**{k: v for k, v in props.items() if k in inspect.signature(cls).parameters}
)

spoken: str = ""
is_load: bool = False
description: str = ""
Expand All @@ -71,6 +98,13 @@ class NodeNotes:
class NodeProperty:
"""Class to hold result of a control event or node aux property."""

@classmethod
def from_dict(cls: type[NodeProperty], props: dict) -> NodeProperty:
"""Create a dataclass from a dictionary."""
return cls(
**{k: v for k, v in props.items() if k in inspect.signature(cls).parameters}
)

id: InitVar[str | None] = ""
control: str = ""
value: OptionalIntT | float = None
Expand All @@ -96,6 +130,13 @@ def __post_init__(self, id: str | None) -> None:
class ZWaveParameter:
"""Class to hold Z-Wave Parameter from a Z-Wave Node."""

@classmethod
def from_dict(cls: type[ZWaveParameter], props: dict) -> ZWaveParameter:
"""Create a dataclass from a dictionary."""
return cls(
**{k: v for k, v in props.items() if k in inspect.signature(cls).parameters}
)

param_num: int
size: int
value: int | str
Expand All @@ -112,6 +153,13 @@ def __post_init__(self) -> None:
class ZWaveProperties:
"""Class to hold Z-Wave Product Details from a Z-Wave Node."""

@classmethod
def from_dict(cls: type[ZWaveProperties], props: dict) -> ZWaveProperties:
"""Create a dataclass from a dictionary."""
return cls(
**{k: v for k, v in props.items() if k in inspect.signature(cls).parameters}
)

category: str = "0"
mfg: str = "0.0.0"
gen: str = "0.0.0"
Expand Down
3 changes: 3 additions & 0 deletions pyisyox/helpers/xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def post_processor(path: str, key: str, value: Any) -> tuple[str, Any]:
elif key == ATTR_FLAG:
with suppress(ValueError):
value = int(cast(str, value))
elif key == "step":
with suppress(ValueError):
value = int(cast(str, value))
# Convert known dates
if (key.endswith("_time") or key == "ts") and value is not None:
with suppress(ValueError):
Expand Down
14 changes: 11 additions & 3 deletions pyisyox/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
import inspect
from typing import TYPE_CHECKING, Any

from pyisyox.constants import TAG_ID, TAG_NAME, URL_NETWORK, URL_RESOURCES, Protocol
Expand All @@ -21,6 +22,13 @@
class NetworkCommandDetail(EntityDetail):
"""Dataclass to hold entity detail info."""

@classmethod
def from_dict(cls: type[NetworkCommandDetail], props: dict) -> NetworkCommandDetail:
"""Create a dataclass from a dictionary."""
return cls(
**{k: v for k, v in props.items() if k in inspect.signature(cls).parameters}
)

control_info: dict[str, str | bool] = field(default_factory=dict)
id: str = ""
is_modified: bool = False
Expand All @@ -40,8 +48,8 @@ def __init__(self, isy: ISY) -> None:

def parse(self, xml_dict: dict[str, Any]) -> None:
"""Parse the results from the ISY."""
if not (net_config := xml_dict["net_config"]) or not (
features := net_config["net_rule"]
if not (net_config := xml_dict.get("net_config")) or not (
features := net_config.get("net_rule")
):
return
for feature in features:
Expand All @@ -55,7 +63,7 @@ def parse_entity(self, feature: dict[str, Any]) -> None:
address = feature[TAG_ID]
name = feature[TAG_NAME]
_LOGGER.debug("Parsing %s: %s (%s)", PLATFORM, name, address)
detail = NetworkCommandDetail(**feature)
detail = NetworkCommandDetail.from_dict(feature)
entity = NetworkCommand(self, address, name, detail)
self.add_or_update_entity(address, name, entity)
except (TypeError, KeyError, ValueError) as exc:
Expand Down
69 changes: 54 additions & 15 deletions pyisyox/node_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import asyncio
from dataclasses import InitVar, asdict, dataclass, field
import inspect
import json
import re
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -45,9 +46,17 @@
class EditorRange:
"""Node Server Editor Range definition."""

@classmethod
def from_dict(cls, props: dict) -> EditorRange:
"""Create a dataclass from a dictionary."""
return cls(
**{k: v for k, v in props.items() if k in inspect.signature(cls).parameters}
)

uom: str = ""
min: str = ""
max: str = ""
step: int = 0
precision: int = 0
subset: str = ""
nls: str = ""
Expand All @@ -57,6 +66,13 @@ class EditorRange:
class NodeEditor:
"""Node Server Editor definition."""

@classmethod
def from_dict(cls, props: dict) -> NodeEditor:
"""Create a dataclass from a dictionary."""
return cls(
**{k: v for k, v in props.items() if k in inspect.signature(cls).parameters}
)

editor_id: str = ""
# Ranges are a dict with UoM as the key
ranges: dict[str, EditorRange] = field(default_factory=dict)
Expand All @@ -69,6 +85,13 @@ class NodeEditor:
class NodeServerConnection:
"""Node Server Connection details."""

@classmethod
def from_dict(cls, props: dict) -> NodeServerConnection:
"""Create a dataclass from a dictionary."""
return cls(
**{k: v for k, v in props.items() if k in inspect.signature(cls).parameters}
)

profile: str = ""
type_: str = ""
enabled: bool = False
Expand All @@ -92,8 +115,15 @@ def configuration_url(self) -> str:
class NodeDef:
"""Node Server Node Definition parsed from the ISY/IoX."""

@classmethod
def from_dict(cls, props: dict) -> NodeDef:
"""Create a dataclass from a dictionary."""
return cls(
**{k: v for k, v in props.items() if k in inspect.signature(cls).parameters}
)

sts: InitVar[dict[str, list | dict]]
cmds: InitVar[dict[str, Any]]
cmds: InitVar[dict[str, Any]] | None
id: str = ""
node_type: str = ""
name: str = ""
Expand All @@ -106,7 +136,9 @@ class NodeDef:
sends: dict[str, Any] = field(init=False, default_factory=dict)
accepts: dict[str, Any] = field(init=False, default_factory=dict)

def __post_init__(self, sts: dict[str, list | dict], cmds: dict[str, Any]) -> None:
def __post_init__(
self, sts: dict[str, list | dict], cmds: dict[str, Any] | None
) -> None:
"""Post-process node server definition."""
statuses = {}
if sts:
Expand All @@ -116,13 +148,15 @@ def __post_init__(self, sts: dict[str, list | dict], cmds: dict[str, Any]) -> No
statuses.update({st[ATTR_ID]: st[ATTR_EDITOR]})
self.statuses = statuses

if cmds_sends := cmds[ATTR_SENDS]:
if isinstance((cmd_list := cmds_sends[ATTR_CMD]), dict):
if cmds is None:
return
if cmds_sends := cmds.get(ATTR_SENDS):
if isinstance((cmd_list := cmds_sends.get(ATTR_CMD)), dict):
cmd_list = [cmd_list]
self.sends = {i[ATTR_ID]: i for i in cmd_list}

if cmds_accepts := cmds[ATTR_ACCEPTS]:
if isinstance((cmd_list := cmds_accepts[ATTR_CMD]), dict):
if cmds_accepts := cmds.get(ATTR_ACCEPTS):
if isinstance((cmd_list := cmds_accepts.get(ATTR_CMD)), dict):
cmd_list = [cmd_list]
self.accepts = {i[ATTR_ID]: i for i in cmd_list}

Expand Down Expand Up @@ -204,7 +238,7 @@ async def get_connection_info(self) -> None:
def parse_connection(self, conn: dict) -> None:
"""Parse the node server connection files from the ISY."""
try:
self._connections.append(NodeServerConnection(**conn))
self._connections.append(NodeServerConnection.from_dict(conn))
except (ValueError, KeyError, NameError) as exc:
_LOGGER.error("Could not parse node server connection: %s", exc)
return
Expand Down Expand Up @@ -316,8 +350,13 @@ async def parse_node_server_file(self, path: str, file_content: str) -> None:
if not line.startswith("#") and line != ""
]
if nls_list:
nls_lookup = dict(re.split(r"\s?=\s?", line) for line in nls_list)
self._node_server_nls[slot] = nls_lookup
try:
nls_lookup = dict(re.split(r"\s+=\s+", line) for line in nls_list)
self._node_server_nls[slot] = nls_lookup
except ValueError:
_LOGGER.error(
"Error parsing language file for node server slot %s, invalid format"
)

if self.isy.args and self.isy.args.file:
filename = "-".join(path.split("/")[-2:]).replace(".txt", ".yaml")
Expand All @@ -336,12 +375,12 @@ async def parse_node_server_file(self, path: str, file_content: str) -> None:
def parse_node_server_defs(self, slot: str, node_def: dict) -> None:
"""Retrieve and parse the node server definitions."""
try:
self._node_server_node_definitions[slot][node_def[ATTR_ID]] = NodeDef(
**node_def
)
self._node_server_node_definitions[slot][
node_def[ATTR_ID]
] = NodeDef.from_dict(node_def)

except (ValueError, KeyError, NameError) as exc:
_LOGGER.error("Could not parse node server connection: %s", exc)
_LOGGER.error("Could not parse node server definition: %s", exc)
return

def parse_node_server_editor(self, slot: str, editor: dict) -> None:
Expand All @@ -351,7 +390,7 @@ def parse_node_server_editor(self, slot: str, editor: dict) -> None:
ranges = [ranges]
editor_ranges = {}
for rng in ranges:
editor_ranges[rng["uom"]] = EditorRange(**rng)
editor_ranges[rng["uom"]] = EditorRange.from_dict(rng)

self._node_server_node_editors[slot][editor_id] = NodeEditor(
editor_id=editor_id,
Expand All @@ -375,7 +414,7 @@ def parse_nls_info_for_slot(self, slot: str) -> None:
editor.values = {
int(k.replace(f"{index_range.nls}-", "")): v
for k, v in nls.items()
if k.startswith(index_range.nls)
if k.startswith(f"{index_range.nls}-")
}

if not (node_defs := self._node_server_node_definitions.get(slot)):
Expand Down
10 changes: 6 additions & 4 deletions pyisyox/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ def parse_folder_entity(self, feature: dict[str, Any]) -> None:
address = feature[TAG_ADDRESS]
name = feature[TAG_NAME]
_LOGGER.log(LOG_VERBOSE, "Parsing %s: %s (%s)", PLATFORM, name, address)
entity = NodeFolder(self, address, name, NodeFolderDetail(**feature))
entity = NodeFolder(
self, address, name, NodeFolderDetail.from_dict(feature)
)
self.add_or_update_entity(address, name, entity)
except (TypeError, KeyError, ValueError) as exc:
_LOGGER.exception("Error loading %s: %s", PLATFORM, exc)
Expand All @@ -133,7 +135,7 @@ def parse_node_entity(self, feature: dict[str, Any]) -> None:
feature["node_server"] = family.get("instance", "")
feature["protocol"] = self.get_protocol_from_family(family)

entity = Node(self, address, name, NodeDetail(**feature))
entity = Node(self, address, name, NodeDetail.from_dict(feature))
self.add_or_update_entity(address, name, entity)
except (TypeError, KeyError, ValueError) as exc:
_LOGGER.exception("Error loading %s: %s", PLATFORM, exc)
Expand All @@ -147,7 +149,7 @@ def parse_group_entity(self, feature: dict[str, Any]) -> None:
if (flag := feature["flag"]) & NodeFlag.ROOT:
_LOGGER.debug("Skipping root group flag=%s %s", flag, address)
return
entity = Group(self, address, name, GroupDetail(**feature))
entity = Group(self, address, name, GroupDetail.from_dict(feature))
self.add_or_update_entity(address, name, entity)
except (TypeError, KeyError, ValueError) as exc:
_LOGGER.exception("Error loading %s: %s", PLATFORM, exc)
Expand Down Expand Up @@ -226,7 +228,7 @@ def parse_node_status(self, status: dict[str, Any]) -> None:

def parse_node_properties(self, prop: dict[str, Any], entity: Node) -> None:
"""Parse the node node property from the ISY."""
result = NodeProperty(**prop)
result = NodeProperty.from_dict(prop)
if result.control == PROP_STATUS:
entity.update_state(result)
if result.control == PROP_BATTERY_LEVEL and not entity.state_set:
Expand Down
6 changes: 3 additions & 3 deletions pyisyox/nodes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class NodeDetail(NodeBaseDetail):
def __post_init__(self) -> None:
"""Post-initialization of Node detail dataclass."""
if self.devtype:
self.zwave_props = ZWaveProperties(**self.devtype)
self.zwave_props = ZWaveProperties.from_dict(self.devtype)


class Node(NodeBase, Entity[NodeDetail, StatusT]):
Expand Down Expand Up @@ -120,7 +120,7 @@ def __init__(
if detail.property and PROP_STATUS in detail.property:
self.state_set = True
self._is_battery_node = False
self.update_state(NodeProperty(**detail.property))
self.update_state(NodeProperty.from_dict(detail.property))

@property
def formatted(self) -> str:
Expand Down Expand Up @@ -332,7 +332,7 @@ async def get_zwave_parameter(self, parameter: int) -> ZWaveParameter | None:
_LOGGER.warning("Error fetching parameter from ISY")
return None

result = ZWaveParameter(**config)
result = ZWaveParameter.from_dict(config)

# Add/update the aux_properties to include the parameter.
node_prop = NodeProperty(
Expand Down
2 changes: 1 addition & 1 deletion pyisyox/nodes/nodebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ async def get_notes(self) -> None:
if not (notes := notes_dict.get("node_properties")):
return

self.notes = NodeNotes(**cast(dict, notes))
self.notes = NodeNotes.from_dict(cast(dict, notes))

async def send_cmd(
self,
Expand Down
Loading

0 comments on commit f7c65c8

Please sign in to comment.