Skip to content

Commit

Permalink
Add support for additional cloud credentials based on region (#181)
Browse files Browse the repository at this point in the history
* Add additional cloud credentials and default to US

* Update unit tests and use a constant for default region
  • Loading branch information
mill1000 authored Dec 11, 2024
1 parent 45828d2 commit 7d7aaba
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 49 deletions.
35 changes: 21 additions & 14 deletions msmart/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from msmart import __version__
from msmart.cloud import Cloud, CloudError
from msmart.const import OPEN_MIDEA_APP_ACCOUNT, OPEN_MIDEA_APP_PASSWORD
from msmart.const import CLOUD_CREDENTIALS, DEFAULT_CLOUD_REGION
from msmart.device import AirConditioner as AC
from msmart.discover import Discover
from msmart.lan import AuthenticationError
Expand All @@ -15,16 +15,19 @@
_LOGGER = logging.getLogger(__name__)


DEFAULT_CLOUD_ACCOUNT, DEFAULT_CLOUD_PASSWORD = CLOUD_CREDENTIALS[DEFAULT_CLOUD_REGION]


async def _discover(args) -> None:
"""Discover Midea devices and print configuration information."""

devices = []
if args.host is None:
_LOGGER.info("Discovering all devices on local network.")
devices = await Discover.discover(account=args.account, password=args.password, discovery_packets=args.count)
devices = await Discover.discover(region=args.region, account=args.account, password=args.password, discovery_packets=args.count)
else:
_LOGGER.info("Discovering %s on local network.", args.host)
dev = await Discover.discover_single(args.host, account=args.account, password=args.password, discovery_packets=args.count)
dev = await Discover.discover_single(args.host, region=args.region, account=args.account, password=args.password, discovery_packets=args.count)
if dev:
devices.append(dev)

Expand Down Expand Up @@ -52,7 +55,7 @@ async def _connect(args) -> AC:
if args.auto:
# Use discovery to automatically connect and authenticate with device
_LOGGER.info("Discovering %s on local network.", args.host)
device = await Discover.discover_single(args.host, account=args.account, password=args.password)
device = await Discover.discover_single(args.host, region=args.region, account=args.account, password=args.password)

if device is None:
_LOGGER.error("Device not found.")
Expand Down Expand Up @@ -216,7 +219,7 @@ async def _download(args) -> None:

# Use discovery to to find device information
_LOGGER.info("Discovering %s on local network.", args.host)
device = await Discover.discover_single(args.host, account=args.account, password=args.password, auto_connect=False)
device = await Discover.discover_single(args.host, region=args.region, account=args.account, password=args.password, auto_connect=False)

if device is None:
_LOGGER.error("Device not found.")
Expand All @@ -232,7 +235,7 @@ async def _download(args) -> None:
exit(1)

# Get cloud connection
cloud = Cloud(args.account, args.password)
cloud = Cloud(args.region, account=args.account, password=args.password)
try:
await cloud.login()
except CloudError as e:
Expand Down Expand Up @@ -270,7 +273,7 @@ def _run(args) -> NoReturn:
logging.getLogger("httpcore").setLevel(logging.WARNING)

# Validate common arguments
if args.china and (args.account == OPEN_MIDEA_APP_ACCOUNT or args.password == OPEN_MIDEA_APP_PASSWORD):
if args.china and (args.account is None or args.password is None):
_LOGGER.error(
"Account (phone number) and password of 美的美居 is required to use --china option.")
exit(1)
Expand Down Expand Up @@ -299,14 +302,18 @@ def main() -> NoReturn:
common_parser = argparse.ArgumentParser(add_help=False)
common_parser.add_argument("-d", "--debug",
help="Enable debug logging.", action="store_true")
common_parser.add_argument("--region",
help="Country/region for built-in cloud credential selection.",
choices=CLOUD_CREDENTIALS.keys(),
default=DEFAULT_CLOUD_REGION)
common_parser.add_argument("--account",
help="MSmartHome or 美的美居 username for discovery and automatic authentication",
default=OPEN_MIDEA_APP_ACCOUNT)
help="Manually specify a MSmart username for cloud authentication.",
default=None)
common_parser.add_argument("--password",
help="MSmartHome or 美的美居 password for discovery and automatic authentication.",
default=OPEN_MIDEA_APP_PASSWORD)
help="Manually specify a MSmart password for cloud authentication.",
default=None)
common_parser.add_argument("--china",
help="Use China server for discovery and automatic authentication.",
help="Use China server for discovery and authentication. Username and password must be specified.",
action="store_true")

# Setup discover parser
Expand Down Expand Up @@ -404,9 +411,9 @@ async def _wrap_discover(args) -> None:
parser.add_argument(
"-d", "--debug", help="Enable debug logging.", action="store_true")
parser.add_argument(
"-a", "--account", help="MSmartHome or 美的美居 account username.", default=OPEN_MIDEA_APP_ACCOUNT)
"-a", "--account", help="MSmartHome or 美的美居 account username.", default=DEFAULT_CLOUD_ACCOUNT)
parser.add_argument(
"-p", "--password", help="MSmartHome or 美的美居 account password.", default=OPEN_MIDEA_APP_PASSWORD)
"-p", "--password", help="MSmartHome or 美的美居 account password.", default=DEFAULT_CLOUD_PASSWORD)
parser.add_argument(
"-i", "--ip", help="IP address of a device. Useful if broadcasts don't work, or to query a single device.")
parser.add_argument(
Expand Down
24 changes: 19 additions & 5 deletions msmart/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from Crypto.Cipher import AES
from Crypto.Util import Padding

from msmart.const import DeviceType
from msmart.const import CLOUD_CREDENTIALS, DEFAULT_CLOUD_REGION, DeviceType

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,14 +54,28 @@ class Cloud:
# Default number of request retries
RETRIES = 3

def __init__(self, account: str, password: str,
use_china_server: bool = False) -> None:
def __init__(self,
region: str = DEFAULT_CLOUD_REGION,
*,
account: Optional[str] = None,
password: Optional[str] = None,
use_china_server: bool = False
) -> None:
# Allow override Chia server from environment
if os.getenv("MIDEA_CHINA_SERVER", "0") == "1":
use_china_server = True

self._account = account
self._password = password
# Validate incoming credentials and region
if account and password:
self._account = account
self._password = password
elif account or password:
raise ValueError("Account and password must be specified.")
else:
try:
self._account, self._password = CLOUD_CREDENTIALS[region]
except KeyError:
raise ValueError(f"Unknown cloud region '{region}'.")

# Attributes that holds the login information of the current user
self._login_id = None
Expand Down
9 changes: 6 additions & 3 deletions msmart/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
0xb7, 0xe4, 0x2d, 0x53, 0x49, 0x47, 0x62, 0xbe
])


OPEN_MIDEA_APP_ACCOUNT = "[email protected]"
OPEN_MIDEA_APP_PASSWORD = "this_is_a_password1"
DEFAULT_CLOUD_REGION = "US"
CLOUD_CREDENTIALS = {
"DE": ("[email protected]", "das_ist_passwort1"),
"KR": ("[email protected]", "password_for_sea1"),
"US": ("[email protected]", "this_is_a_password1")
}


class DeviceType(IntEnum):
Expand Down
36 changes: 15 additions & 21 deletions msmart/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from typing import Any, Optional, Type, cast

from msmart.cloud import Cloud, CloudError
from msmart.const import (DEVICE_INFO_MSG, DISCOVERY_MSG,
OPEN_MIDEA_APP_ACCOUNT, OPEN_MIDEA_APP_PASSWORD,
from msmart.const import (DEFAULT_CLOUD_REGION, DEVICE_INFO_MSG, DISCOVERY_MSG,
DeviceType)
from msmart.device import AirConditioner, Device
from msmart.lan import AuthenticationError, Security
Expand Down Expand Up @@ -135,8 +134,9 @@ def connection_lost(self, exc) -> None:
class Discover:
"""Discover Midea smart devices on the local network."""

_account = OPEN_MIDEA_APP_ACCOUNT
_password = OPEN_MIDEA_APP_PASSWORD
_region = DEFAULT_CLOUD_REGION
_account = None
_password = None
_lock = None
_cloud = None
_auto_connect = False
Expand All @@ -147,11 +147,12 @@ async def discover(
*,
target=_IPV4_BROADCAST,
timeout=5,
discovery_packets=3,
discovery_packets: int = 3,
interface=None,
account=None,
password=None,
auto_connect=True
region: str = DEFAULT_CLOUD_REGION,
account: Optional[str] = None,
password: Optional[str] = None,
auto_connect: bool = True
) -> list[Device]:
"""Discover devices via broadcast."""

Expand All @@ -162,8 +163,10 @@ async def discover(
# Always use a new cloud connection
cls._cloud = None

# Save cloud credentials
Discover._set_cloud_credentials(account, password)
# Save cloud region and credentials
cls._region = region
cls._account = account
cls._password = password

# Save auto connect arg
cls._auto_connect = auto_connect
Expand Down Expand Up @@ -212,16 +215,6 @@ async def discover_single(

return None

@classmethod
def _set_cloud_credentials(cls, account, password) -> None:
"""Set credentials for cloud access."""

if account and password:
cls._account = account
cls._password = password
elif account or password:
raise ValueError("Both account and password must be specified.")

@classmethod
async def _get_cloud(cls) -> Optional[Cloud]:
"""Return a cloud connection, creating it if necessary."""
Expand All @@ -232,7 +225,8 @@ async def _get_cloud(cls) -> Optional[Cloud]:
async with cls._lock:
# Create cloud connection if nonexistent
if cls._cloud is None:
cloud = Cloud(cls._account, cls._password)
cloud = Cloud(cls._region, account=cls._account,
password=cls._password)
try:
await cloud.login()
cls._cloud = cloud
Expand Down
33 changes: 27 additions & 6 deletions msmart/tests/test_cloud.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import unittest
from typing import Any, Optional

from msmart.cloud import ApiError, Cloud, CloudError
from msmart.const import OPEN_MIDEA_APP_ACCOUNT, OPEN_MIDEA_APP_PASSWORD
from msmart.const import DEFAULT_CLOUD_REGION


class TestCloud(unittest.IsolatedAsyncioTestCase):
# pylint: disable=protected-access

async def _login(self, account: str = OPEN_MIDEA_APP_ACCOUNT,
password: str = OPEN_MIDEA_APP_PASSWORD) -> Cloud:
client = Cloud(account, password)
async def _login(self,
region: str = DEFAULT_CLOUD_REGION,
*,
account: Optional[str] = None,
password: Optional[str] = None
) -> Cloud:
client = Cloud(region, account=account, password=password)
await client.login()

return client
Expand All @@ -23,11 +28,27 @@ async def test_login(self) -> None:
self.assertIsNotNone(client._access_token)

async def test_login_exception(self) -> None:
"""Test that we can login to the cloud."""
"""Test that bad credentials raise an exception."""

with self.assertRaises(ApiError):
await self._login(account="[email protected]", password="not_a_password")

async def test_invalid_region(self) -> None:
"""Test that an invalid region raise an exception."""

with self.assertRaises(ValueError):
await self._login("NOT_A_REGION")

async def test_invalid_credentials(self) -> None:
"""Test that invalid credentials raise an exception."""

# Check that specifying only an account or password raises an error
with self.assertRaises(ValueError):
await self._login(account=None, password="some_password")

with self.assertRaises(ValueError):
await self._login(account="some_account", password=None)

async def test_get_token(self) -> None:
"""Test that a token and key can be obtained from the cloud."""

Expand All @@ -53,7 +74,7 @@ async def test_get_token_exception(self) -> None:
async def test_connect_exception(self) -> None:
"""Test that an exception is thrown when the cloud connection fails."""

client = Cloud(OPEN_MIDEA_APP_ACCOUNT, OPEN_MIDEA_APP_PASSWORD)
client = Cloud(DEFAULT_CLOUD_REGION)

# Override URL to an invalid domain
client._base_url = "https://fake_server.invalid."
Expand Down

0 comments on commit 7d7aaba

Please sign in to comment.