roborock.devices.device_manager

Module for discovering Roborock devices.

  1"""Module for discovering Roborock devices."""
  2
  3import asyncio
  4import enum
  5import logging
  6from collections.abc import Callable, Mapping
  7from dataclasses import dataclass
  8from typing import Any
  9
 10import aiohttp
 11
 12from roborock.data import (
 13    HomeData,
 14    HomeDataDevice,
 15    HomeDataProduct,
 16    UserData,
 17)
 18from roborock.devices.device import DeviceReadyCallback, RoborockDevice
 19from roborock.diagnostics import Diagnostics
 20from roborock.exceptions import RoborockException
 21from roborock.map.map_parser import MapParserConfig
 22from roborock.mqtt.roborock_session import create_lazy_mqtt_session
 23from roborock.mqtt.session import MqttSession, SessionUnauthorizedHook
 24from roborock.protocol import create_mqtt_params
 25from roborock.web_api import RoborockApiClient, UserWebApiClient
 26
 27from .cache import Cache, DeviceCache, NoCache
 28from .channel import Channel
 29from .mqtt_channel import create_mqtt_channel
 30from .traits import Trait, a01, b01, v1
 31from .v1_channel import create_v1_channel
 32
 33_LOGGER = logging.getLogger(__name__)
 34
 35__all__ = [
 36    "create_device_manager",
 37    "UserParams",
 38    "DeviceManager",
 39]
 40
 41
 42DeviceCreator = Callable[[HomeData, HomeDataDevice, HomeDataProduct], RoborockDevice]
 43
 44
 45class DeviceVersion(enum.StrEnum):
 46    """Enum for device versions."""
 47
 48    V1 = "1.0"
 49    A01 = "A01"
 50    B01 = "B01"
 51    UNKNOWN = "unknown"
 52
 53
 54class UnsupportedDeviceError(RoborockException):
 55    """Exception raised when a device is unsupported."""
 56
 57
 58class DeviceManager:
 59    """Central manager for Roborock device discovery and connections."""
 60
 61    def __init__(
 62        self,
 63        web_api: UserWebApiClient,
 64        device_creator: DeviceCreator,
 65        mqtt_session: MqttSession,
 66        cache: Cache,
 67        diagnostics: Diagnostics,
 68    ) -> None:
 69        """Initialize the DeviceManager with user data and optional cache storage.
 70
 71        This takes ownership of the MQTT session and will close it when the manager is closed.
 72        """
 73        self._web_api = web_api
 74        self._cache = cache
 75        self._device_creator = device_creator
 76        self._devices: dict[str, RoborockDevice] = {}
 77        self._mqtt_session = mqtt_session
 78        self._diagnostics = diagnostics
 79
 80    async def discover_devices(self, prefer_cache: bool = True) -> list[RoborockDevice]:
 81        """Discover all devices for the logged-in user."""
 82        self._diagnostics.increment("discover_devices")
 83        cache_data = await self._cache.get()
 84        if not cache_data.home_data or not prefer_cache:
 85            _LOGGER.debug("Fetching home data (prefer_cache=%s)", prefer_cache)
 86            self._diagnostics.increment("fetch_home_data")
 87            try:
 88                cache_data.home_data = await self._web_api.get_home_data()
 89            except RoborockException as ex:
 90                if not cache_data.home_data:
 91                    raise
 92                _LOGGER.debug("Failed to fetch home data, using cached data: %s", ex)
 93            await self._cache.set(cache_data)
 94        home_data = cache_data.home_data
 95
 96        device_products = home_data.device_products
 97        _LOGGER.debug("Discovered %d devices", len(device_products))
 98
 99        # These are connected serially to avoid overwhelming the MQTT broker
100        new_devices = {}
101        start_tasks = []
102        supported_devices_counter = self._diagnostics.subkey("supported_devices")
103        unsupported_devices_counter = self._diagnostics.subkey("unsupported_devices")
104        for duid, (device, product) in device_products.items():
105            _LOGGER.debug("[%s] Discovered device %s %s", duid, product.summary_info(), device.summary_info())
106            if duid in self._devices:
107                continue
108            try:
109                new_device = self._device_creator(home_data, device, product)
110            except UnsupportedDeviceError:
111                _LOGGER.info("Skipping unsupported device %s %s", product.summary_info(), device.summary_info())
112                unsupported_devices_counter.increment(device.pv or "unknown")
113                continue
114            supported_devices_counter.increment(device.pv or "unknown")
115            start_tasks.append(new_device.start_connect())
116            new_devices[duid] = new_device
117
118        self._devices.update(new_devices)
119        await asyncio.gather(*start_tasks)
120        return list(self._devices.values())
121
122    async def get_device(self, duid: str) -> RoborockDevice | None:
123        """Get a specific device by DUID."""
124        return self._devices.get(duid)
125
126    async def get_devices(self) -> list[RoborockDevice]:
127        """Get all discovered devices."""
128        return list(self._devices.values())
129
130    async def close(self) -> None:
131        """Close all MQTT connections and clean up resources."""
132        tasks = [device.close() for device in self._devices.values()]
133        self._devices.clear()
134        tasks.append(self._mqtt_session.close())
135        await asyncio.gather(*tasks)
136
137    def diagnostic_data(self) -> Mapping[str, Any]:
138        """Return diagnostics information about the device manager."""
139        return self._diagnostics.as_dict()
140
141
@dataclass()
class UserParams:
    """Everything needed to open a session with a user's Roborock devices.

    Holds the login username, the authenticated user data and, optionally,
    the Roborock API base URL. ``user_data`` and ``base_url`` are the values
    ``RoborockApiClient`` yields during login.
    """

    username: str
    """Email address used as the login username."""

    user_data: UserData
    """Authenticated user data (carries the authentication information)."""

    base_url: str | None = None
    """Roborock API base URL, if already known.

    Supplying it shortens connection setup because the API base URL need not
    be discovered again. When it is None the API client discovers the URL on
    its own, which can take several requests.
    """
164
165
def create_web_api_wrapper(
    user_params: UserParams,
    *,
    cache: Cache | None = None,
    session: aiohttp.ClientSession | None = None,
) -> UserWebApiClient:
    """Create a user-scoped web API client from the given user parameters.

    A new RoborockApiClient is constructed for ``user_params.username`` and
    wrapped together with ``user_params.user_data`` in a UserWebApiClient.

    Args:
        user_params: Username, authenticated user data and optional API base URL.
        cache: Currently unused; accepted so callers can pass it without error.
        session: Optional aiohttp ClientSession to use for HTTP requests.

    Returns:
        A UserWebApiClient bound to ``user_params.user_data``.
    """
    # If user_params.base_url is None the client discovers the API base URL on
    # its own, which may take extra requests; supplying base_url avoids that.
    client = RoborockApiClient(username=user_params.username, base_url=user_params.base_url, session=session)

    return UserWebApiClient(client, user_params.user_data)
179
180
181async def create_device_manager(
182    user_params: UserParams,
183    *,
184    cache: Cache | None = None,
185    map_parser_config: MapParserConfig | None = None,
186    session: aiohttp.ClientSession | None = None,
187    ready_callback: DeviceReadyCallback | None = None,
188    mqtt_session_unauthorized_hook: SessionUnauthorizedHook | None = None,
189) -> DeviceManager:
190    """Convenience function to create and initialize a DeviceManager.
191
192    Args:
193        user_params: Parameters for creating the user session.
194        cache: Optional cache implementation to use for caching device data.
195        map_parser_config: Optional configuration for parsing maps.
196        session: Optional aiohttp ClientSession to use for HTTP requests.
197        ready_callback: Optional callback to be notified when a device is ready.
198        mqtt_session_unauthorized_hook: Optional hook for MQTT session unauthorized
199          events which may indicate rate limiting or revoked credentials. The
200          caller may use this to refresh authentication tokens as needed.
201
202    Returns:
203        An initialized DeviceManager with discovered devices.
204    """
205    if cache is None:
206        cache = NoCache()
207
208    web_api = create_web_api_wrapper(user_params, session=session, cache=cache)
209    user_data = user_params.user_data
210
211    diagnostics = Diagnostics()
212
213    mqtt_params = create_mqtt_params(user_data.rriot)
214    mqtt_params.diagnostics = diagnostics.subkey("mqtt_session")
215    mqtt_params.unauthorized_hook = mqtt_session_unauthorized_hook
216    mqtt_session = await create_lazy_mqtt_session(mqtt_params)
217
218    def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
219        channel: Channel
220        trait: Trait
221        device_cache: DeviceCache = DeviceCache(device.duid, cache)
222        match device.pv:
223            case DeviceVersion.V1:
224                channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device, device_cache)
225                trait = v1.create(
226                    device.duid,
227                    product,
228                    home_data,
229                    channel.rpc_channel,
230                    channel.mqtt_rpc_channel,
231                    channel.map_rpc_channel,
232                    web_api,
233                    device_cache=device_cache,
234                    map_parser_config=map_parser_config,
235                )
236            case DeviceVersion.A01:
237                channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
238                trait = a01.create(product, channel)
239            case DeviceVersion.B01:
240                channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
241                model_part = product.model.split(".")[-1]
242                if "ss" in model_part:
243                    raise UnsupportedDeviceError(
244                        f"Device {device.name} has unsupported version B01 product model {product.model}"
245                    )
246                elif "sc" in model_part:
247                    # Q7 devices start with 'sc' in their model naming.
248                    trait = b01.q7.create(channel)
249                else:
250                    raise UnsupportedDeviceError(f"Device {device.name} has unsupported B01 model: {product.model}")
251            case _:
252                raise UnsupportedDeviceError(
253                    f"Device {device.name} has unsupported version {device.pv} {product.model}"
254                )
255
256        dev = RoborockDevice(device, product, channel, trait)
257        if ready_callback:
258            dev.add_ready_callback(ready_callback)
259        return dev
260
261    manager = DeviceManager(web_api, device_creator, mqtt_session=mqtt_session, cache=cache, diagnostics=diagnostics)
262    await manager.discover_devices()
263    return manager
async def create_device_manager( user_params: UserParams, *, cache: roborock.devices.cache.Cache | None = None, map_parser_config: roborock.map.MapParserConfig | None = None, session: aiohttp.client.ClientSession | None = None, ready_callback: Callable[[roborock.devices.device.RoborockDevice], None] | None = None, mqtt_session_unauthorized_hook: Callable[[], None] | None = None) -> DeviceManager:
182async def create_device_manager(
183    user_params: UserParams,
184    *,
185    cache: Cache | None = None,
186    map_parser_config: MapParserConfig | None = None,
187    session: aiohttp.ClientSession | None = None,
188    ready_callback: DeviceReadyCallback | None = None,
189    mqtt_session_unauthorized_hook: SessionUnauthorizedHook | None = None,
190) -> DeviceManager:
191    """Convenience function to create and initialize a DeviceManager.
192
193    Args:
194        user_params: Parameters for creating the user session.
195        cache: Optional cache implementation to use for caching device data.
196        map_parser_config: Optional configuration for parsing maps.
197        session: Optional aiohttp ClientSession to use for HTTP requests.
198        ready_callback: Optional callback to be notified when a device is ready.
199        mqtt_session_unauthorized_hook: Optional hook for MQTT session unauthorized
200          events which may indicate rate limiting or revoked credentials. The
201          caller may use this to refresh authentication tokens as needed.
202
203    Returns:
204        An initialized DeviceManager with discovered devices.
205    """
206    if cache is None:
207        cache = NoCache()
208
209    web_api = create_web_api_wrapper(user_params, session=session, cache=cache)
210    user_data = user_params.user_data
211
212    diagnostics = Diagnostics()
213
214    mqtt_params = create_mqtt_params(user_data.rriot)
215    mqtt_params.diagnostics = diagnostics.subkey("mqtt_session")
216    mqtt_params.unauthorized_hook = mqtt_session_unauthorized_hook
217    mqtt_session = await create_lazy_mqtt_session(mqtt_params)
218
219    def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
220        channel: Channel
221        trait: Trait
222        device_cache: DeviceCache = DeviceCache(device.duid, cache)
223        match device.pv:
224            case DeviceVersion.V1:
225                channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device, device_cache)
226                trait = v1.create(
227                    device.duid,
228                    product,
229                    home_data,
230                    channel.rpc_channel,
231                    channel.mqtt_rpc_channel,
232                    channel.map_rpc_channel,
233                    web_api,
234                    device_cache=device_cache,
235                    map_parser_config=map_parser_config,
236                )
237            case DeviceVersion.A01:
238                channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
239                trait = a01.create(product, channel)
240            case DeviceVersion.B01:
241                channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
242                model_part = product.model.split(".")[-1]
243                if "ss" in model_part:
244                    raise UnsupportedDeviceError(
245                        f"Device {device.name} has unsupported version B01 product model {product.model}"
246                    )
247                elif "sc" in model_part:
248                    # Q7 devices start with 'sc' in their model naming.
249                    trait = b01.q7.create(channel)
250                else:
251                    raise UnsupportedDeviceError(f"Device {device.name} has unsupported B01 model: {product.model}")
252            case _:
253                raise UnsupportedDeviceError(
254                    f"Device {device.name} has unsupported version {device.pv} {product.model}"
255                )
256
257        dev = RoborockDevice(device, product, channel, trait)
258        if ready_callback:
259            dev.add_ready_callback(ready_callback)
260        return dev
261
262    manager = DeviceManager(web_api, device_creator, mqtt_session=mqtt_session, cache=cache, diagnostics=diagnostics)
263    await manager.discover_devices()
264    return manager

Convenience function to create and initialize a DeviceManager.

Args:
  user_params: Parameters for creating the user session.
  cache: Optional cache implementation to use for caching device data.
  map_parser_config: Optional configuration for parsing maps.
  session: Optional aiohttp ClientSession to use for HTTP requests.
  ready_callback: Optional callback to be notified when a device is ready.
  mqtt_session_unauthorized_hook: Optional hook for MQTT session unauthorized events, which may indicate rate limiting or revoked credentials. The caller may use this to refresh authentication tokens as needed.

Returns: An initialized DeviceManager with discovered devices.

@dataclass
class UserParams:
@dataclass()
class UserParams:
    """Everything needed to open a session with a user's Roborock devices.

    Holds the login username, the authenticated user data and, optionally,
    the Roborock API base URL. ``user_data`` and ``base_url`` are the values
    ``RoborockApiClient`` yields during login.
    """

    username: str
    """Email address used as the login username."""

    user_data: UserData
    """Authenticated user data (carries the authentication information)."""

    base_url: str | None = None
    """Roborock API base URL, if already known.

    Supplying it shortens connection setup because the API base URL need not
    be discovered again. When it is None the API client discovers the URL on
    its own, which can take several requests.
    """

Parameters for creating a new session with Roborock devices.

These parameters include the username, user data for authentication, and an optional base URL for the Roborock API. The user_data and base_url parameters are obtained from RoborockApiClient during the login process.

UserParams( username: str, user_data: roborock.data.containers.UserData, base_url: str | None = None)
username: str

The username (email) used for logging in.

user_data: roborock.data.containers.UserData

This is the user data containing authentication information.

base_url: str | None = None

Optional base URL for the Roborock API.

This is used to speed up connection times by avoiding the need to discover the API base URL each time. If not provided, the API client will attempt to discover it automatically which may take multiple requests.

class DeviceManager:
 59class DeviceManager:
 60    """Central manager for Roborock device discovery and connections."""
 61
 62    def __init__(
 63        self,
 64        web_api: UserWebApiClient,
 65        device_creator: DeviceCreator,
 66        mqtt_session: MqttSession,
 67        cache: Cache,
 68        diagnostics: Diagnostics,
 69    ) -> None:
 70        """Initialize the DeviceManager with user data and optional cache storage.
 71
 72        This takes ownership of the MQTT session and will close it when the manager is closed.
 73        """
 74        self._web_api = web_api
 75        self._cache = cache
 76        self._device_creator = device_creator
 77        self._devices: dict[str, RoborockDevice] = {}
 78        self._mqtt_session = mqtt_session
 79        self._diagnostics = diagnostics
 80
 81    async def discover_devices(self, prefer_cache: bool = True) -> list[RoborockDevice]:
 82        """Discover all devices for the logged-in user."""
 83        self._diagnostics.increment("discover_devices")
 84        cache_data = await self._cache.get()
 85        if not cache_data.home_data or not prefer_cache:
 86            _LOGGER.debug("Fetching home data (prefer_cache=%s)", prefer_cache)
 87            self._diagnostics.increment("fetch_home_data")
 88            try:
 89                cache_data.home_data = await self._web_api.get_home_data()
 90            except RoborockException as ex:
 91                if not cache_data.home_data:
 92                    raise
 93                _LOGGER.debug("Failed to fetch home data, using cached data: %s", ex)
 94            await self._cache.set(cache_data)
 95        home_data = cache_data.home_data
 96
 97        device_products = home_data.device_products
 98        _LOGGER.debug("Discovered %d devices", len(device_products))
 99
100        # These are connected serially to avoid overwhelming the MQTT broker
101        new_devices = {}
102        start_tasks = []
103        supported_devices_counter = self._diagnostics.subkey("supported_devices")
104        unsupported_devices_counter = self._diagnostics.subkey("unsupported_devices")
105        for duid, (device, product) in device_products.items():
106            _LOGGER.debug("[%s] Discovered device %s %s", duid, product.summary_info(), device.summary_info())
107            if duid in self._devices:
108                continue
109            try:
110                new_device = self._device_creator(home_data, device, product)
111            except UnsupportedDeviceError:
112                _LOGGER.info("Skipping unsupported device %s %s", product.summary_info(), device.summary_info())
113                unsupported_devices_counter.increment(device.pv or "unknown")
114                continue
115            supported_devices_counter.increment(device.pv or "unknown")
116            start_tasks.append(new_device.start_connect())
117            new_devices[duid] = new_device
118
119        self._devices.update(new_devices)
120        await asyncio.gather(*start_tasks)
121        return list(self._devices.values())
122
123    async def get_device(self, duid: str) -> RoborockDevice | None:
124        """Get a specific device by DUID."""
125        return self._devices.get(duid)
126
127    async def get_devices(self) -> list[RoborockDevice]:
128        """Get all discovered devices."""
129        return list(self._devices.values())
130
131    async def close(self) -> None:
132        """Close all MQTT connections and clean up resources."""
133        tasks = [device.close() for device in self._devices.values()]
134        self._devices.clear()
135        tasks.append(self._mqtt_session.close())
136        await asyncio.gather(*tasks)
137
138    def diagnostic_data(self) -> Mapping[str, Any]:
139        """Return diagnostics information about the device manager."""
140        return self._diagnostics.as_dict()

Central manager for Roborock device discovery and connections.

DeviceManager( web_api: roborock.web_api.UserWebApiClient, device_creator: Callable[[roborock.data.containers.HomeData, roborock.data.containers.HomeDataDevice, roborock.data.containers.HomeDataProduct], roborock.devices.device.RoborockDevice], mqtt_session: roborock.mqtt.session.MqttSession, cache: roborock.devices.cache.Cache, diagnostics: roborock.diagnostics.Diagnostics)
62    def __init__(
63        self,
64        web_api: UserWebApiClient,
65        device_creator: DeviceCreator,
66        mqtt_session: MqttSession,
67        cache: Cache,
68        diagnostics: Diagnostics,
69    ) -> None:
70        """Initialize the DeviceManager with user data and optional cache storage.
71
72        This takes ownership of the MQTT session and will close it when the manager is closed.
73        """
74        self._web_api = web_api
75        self._cache = cache
76        self._device_creator = device_creator
77        self._devices: dict[str, RoborockDevice] = {}
78        self._mqtt_session = mqtt_session
79        self._diagnostics = diagnostics

Initialize the DeviceManager with its web API client, device factory, MQTT session, cache, and diagnostics collector.

This takes ownership of the MQTT session and will close it when the manager is closed.

async def discover_devices( self, prefer_cache: bool = True) -> list[roborock.devices.device.RoborockDevice]:
 81    async def discover_devices(self, prefer_cache: bool = True) -> list[RoborockDevice]:
 82        """Discover all devices for the logged-in user."""
 83        self._diagnostics.increment("discover_devices")
 84        cache_data = await self._cache.get()
 85        if not cache_data.home_data or not prefer_cache:
 86            _LOGGER.debug("Fetching home data (prefer_cache=%s)", prefer_cache)
 87            self._diagnostics.increment("fetch_home_data")
 88            try:
 89                cache_data.home_data = await self._web_api.get_home_data()
 90            except RoborockException as ex:
 91                if not cache_data.home_data:
 92                    raise
 93                _LOGGER.debug("Failed to fetch home data, using cached data: %s", ex)
 94            await self._cache.set(cache_data)
 95        home_data = cache_data.home_data
 96
 97        device_products = home_data.device_products
 98        _LOGGER.debug("Discovered %d devices", len(device_products))
 99
100        # These are connected serially to avoid overwhelming the MQTT broker
101        new_devices = {}
102        start_tasks = []
103        supported_devices_counter = self._diagnostics.subkey("supported_devices")
104        unsupported_devices_counter = self._diagnostics.subkey("unsupported_devices")
105        for duid, (device, product) in device_products.items():
106            _LOGGER.debug("[%s] Discovered device %s %s", duid, product.summary_info(), device.summary_info())
107            if duid in self._devices:
108                continue
109            try:
110                new_device = self._device_creator(home_data, device, product)
111            except UnsupportedDeviceError:
112                _LOGGER.info("Skipping unsupported device %s %s", product.summary_info(), device.summary_info())
113                unsupported_devices_counter.increment(device.pv or "unknown")
114                continue
115            supported_devices_counter.increment(device.pv or "unknown")
116            start_tasks.append(new_device.start_connect())
117            new_devices[duid] = new_device
118
119        self._devices.update(new_devices)
120        await asyncio.gather(*start_tasks)
121        return list(self._devices.values())

Discover all devices for the logged-in user.

async def get_device(self, duid: str) -> roborock.devices.device.RoborockDevice | None:
123    async def get_device(self, duid: str) -> RoborockDevice | None:
124        """Get a specific device by DUID."""
125        return self._devices.get(duid)

Get a specific device by DUID.

async def get_devices(self) -> list[roborock.devices.device.RoborockDevice]:
127    async def get_devices(self) -> list[RoborockDevice]:
128        """Get all discovered devices."""
129        return list(self._devices.values())

Get all discovered devices.

async def close(self) -> None:
131    async def close(self) -> None:
132        """Close all MQTT connections and clean up resources."""
133        tasks = [device.close() for device in self._devices.values()]
134        self._devices.clear()
135        tasks.append(self._mqtt_session.close())
136        await asyncio.gather(*tasks)

Close all MQTT connections and clean up resources.

def diagnostic_data(self) -> Mapping[str, typing.Any]:
138    def diagnostic_data(self) -> Mapping[str, Any]:
139        """Return diagnostics information about the device manager."""
140        return self._diagnostics.as_dict()

Return diagnostics information about the device manager.