From 595c9a2e014683f06eeb5b9f575e95b47d0ff9ac Mon Sep 17 00:00:00 2001
From: Matrix <justin@yosmart.com>
Date: Mon, 3 Jun 2024 21:56:42 +0800
Subject: [PATCH] Fixing device model compatibility issues. (#118686)

---
 homeassistant/components/yolink/const.py  |  1 +
 homeassistant/components/yolink/switch.py | 36 +++++++++++++++--------
 2 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/homeassistant/components/yolink/const.py b/homeassistant/components/yolink/const.py
index 110b9cb9810..e829fe08d32 100644
--- a/homeassistant/components/yolink/const.py
+++ b/homeassistant/components/yolink/const.py
@@ -16,3 +16,4 @@ YOLINK_EVENT = f"{DOMAIN}_event"
 YOLINK_OFFLINE_TIME = 32400
 
 DEV_MODEL_WATER_METER_YS5007 = "YS5007"
+DEV_MODEL_MULTI_OUTLET_YS6801 = "YS6801"
diff --git a/homeassistant/components/yolink/switch.py b/homeassistant/components/yolink/switch.py
index 7a24ec1bd13..2e31100bf3c 100644
--- a/homeassistant/components/yolink/switch.py
+++ b/homeassistant/components/yolink/switch.py
@@ -25,7 +25,7 @@ from homeassistant.config_entries import ConfigEntry
 from homeassistant.core import HomeAssistant, callback
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
 
-from .const import DOMAIN
+from .const import DEV_MODEL_MULTI_OUTLET_YS6801, DOMAIN
 from .coordinator import YoLinkCoordinator
 from .entity import YoLinkEntity
 
@@ -35,7 +35,7 @@ class YoLinkSwitchEntityDescription(SwitchEntityDescription):
     """YoLink SwitchEntityDescription."""
 
     exists_fn: Callable[[YoLinkDevice], bool] = lambda _: True
-    plug_index: int | None = None
+    plug_index_fn: Callable[[YoLinkDevice], int | None] = lambda _: None
 
 
 DEVICE_TYPES: tuple[YoLinkSwitchEntityDescription, ...] = (
@@ -61,36 +61,43 @@ DEVICE_TYPES: tuple[YoLinkSwitchEntityDescription, ...] = (
         key="multi_outlet_usb_ports",
         translation_key="usb_ports",
         device_class=SwitchDeviceClass.OUTLET,
-        exists_fn=lambda device: device.device_type == ATTR_DEVICE_MULTI_OUTLET,
-        plug_index=0,
+        exists_fn=lambda device: device.device_type == ATTR_DEVICE_MULTI_OUTLET
+        and device.device_model_name.startswith(DEV_MODEL_MULTI_OUTLET_YS6801),
+        plug_index_fn=lambda _: 0,
     ),
     YoLinkSwitchEntityDescription(
         key="multi_outlet_plug_1",
         translation_key="plug_1",
         device_class=SwitchDeviceClass.OUTLET,
         exists_fn=lambda device: device.device_type == ATTR_DEVICE_MULTI_OUTLET,
-        plug_index=1,
+        plug_index_fn=lambda device: 1
+        if device.device_model_name.startswith(DEV_MODEL_MULTI_OUTLET_YS6801)
+        else 0,
     ),
     YoLinkSwitchEntityDescription(
         key="multi_outlet_plug_2",
         translation_key="plug_2",
         device_class=SwitchDeviceClass.OUTLET,
         exists_fn=lambda device: device.device_type == ATTR_DEVICE_MULTI_OUTLET,
-        plug_index=2,
+        plug_index_fn=lambda device: 2
+        if device.device_model_name.startswith(DEV_MODEL_MULTI_OUTLET_YS6801)
+        else 1,
     ),
     YoLinkSwitchEntityDescription(
         key="multi_outlet_plug_3",
         translation_key="plug_3",
         device_class=SwitchDeviceClass.OUTLET,
-        exists_fn=lambda device: device.device_type == ATTR_DEVICE_MULTI_OUTLET,
-        plug_index=3,
+        exists_fn=lambda device: device.device_type == ATTR_DEVICE_MULTI_OUTLET
+        and device.device_model_name.startswith(DEV_MODEL_MULTI_OUTLET_YS6801),
+        plug_index_fn=lambda _: 3,
     ),
     YoLinkSwitchEntityDescription(
         key="multi_outlet_plug_4",
         translation_key="plug_4",
         device_class=SwitchDeviceClass.OUTLET,
-        exists_fn=lambda device: device.device_type == ATTR_DEVICE_MULTI_OUTLET,
-        plug_index=4,
+        exists_fn=lambda device: device.device_type == ATTR_DEVICE_MULTI_OUTLET
+        and device.device_model_name.startswith(DEV_MODEL_MULTI_OUTLET_YS6801),
+        plug_index_fn=lambda _: 4,
     ),
 )
 
@@ -152,7 +159,8 @@ class YoLinkSwitchEntity(YoLinkEntity, SwitchEntity):
     def update_entity_state(self, state: dict[str, str | list[str]]) -> None:
         """Update HA Entity State."""
         self._attr_is_on = self._get_state(
-            state.get("state"), self.entity_description.plug_index
+            state.get("state"),
+            self.entity_description.plug_index_fn(self.coordinator.device),
         )
         self.async_write_ha_state()
 
@@ -164,12 +172,14 @@ class YoLinkSwitchEntity(YoLinkEntity, SwitchEntity):
             ATTR_DEVICE_MULTI_OUTLET,
         ]:
             client_request = OutletRequestBuilder.set_state_request(
-                state, self.entity_description.plug_index
+                state, self.entity_description.plug_index_fn(self.coordinator.device)
             )
         else:
             client_request = ClientRequest("setState", {"state": state})
         await self.call_device(client_request)
-        self._attr_is_on = self._get_state(state, self.entity_description.plug_index)
+        self._attr_is_on = self._get_state(
+            state, self.entity_description.plug_index_fn(self.coordinator.device)
+        )
         self.async_write_ha_state()
 
     async def async_turn_on(self, **kwargs: Any) -> None: