Refactor telegram_bot polling/webhooks platforms and add tests (#66433)
Co-authored-by: Pär Berge <paer.berge@gmail.com>pull/69171/head
parent
55c6112a28
commit
d7375f1a9c
|
@ -1,20 +1,27 @@
|
|||
"""Support to send and receive Telegram messages."""
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
import importlib
|
||||
import io
|
||||
from ipaddress import ip_network
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from requests.auth import HTTPBasicAuth, HTTPDigestAuth
|
||||
from telegram import (
|
||||
Bot,
|
||||
CallbackQuery,
|
||||
InlineKeyboardButton,
|
||||
InlineKeyboardMarkup,
|
||||
Message,
|
||||
ReplyKeyboardMarkup,
|
||||
ReplyKeyboardRemove,
|
||||
Update,
|
||||
)
|
||||
from telegram.error import TelegramError
|
||||
from telegram.ext import CallbackContext, Filters
|
||||
from telegram.parsemode import ParseMode
|
||||
from telegram.utils.request import Request
|
||||
import voluptuous as vol
|
||||
|
@ -311,14 +318,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
return False
|
||||
|
||||
for p_config in config[DOMAIN]:
|
||||
|
||||
# Each platform config gets its own bot
|
||||
bot = initialize_bot(p_config)
|
||||
p_type = p_config.get(CONF_PLATFORM)
|
||||
|
||||
platform = importlib.import_module(f".{p_config[CONF_PLATFORM]}", __name__)
|
||||
|
||||
_LOGGER.info("Setting up %s.%s", DOMAIN, p_type)
|
||||
try:
|
||||
receiver_service = await platform.async_setup_platform(hass, p_config)
|
||||
receiver_service = await platform.async_setup_platform(hass, bot, p_config)
|
||||
if receiver_service is False:
|
||||
_LOGGER.error("Failed to initialize Telegram bot %s", p_type)
|
||||
return False
|
||||
|
@ -327,7 +335,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
_LOGGER.exception("Error setting up platform %s", p_type)
|
||||
return False
|
||||
|
||||
bot = initialize_bot(p_config)
|
||||
notify_service = TelegramNotificationService(
|
||||
hass, bot, p_config.get(CONF_ALLOWED_CHAT_IDS), p_config.get(ATTR_PARSER)
|
||||
)
|
||||
|
@ -416,7 +423,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
|
||||
def initialize_bot(p_config):
|
||||
"""Initialize telegram bot with proxy support."""
|
||||
|
||||
api_key = p_config.get(CONF_API_KEY)
|
||||
proxy_url = p_config.get(CONF_PROXY_URL)
|
||||
proxy_params = p_config.get(CONF_PROXY_PARAMS)
|
||||
|
@ -435,7 +441,6 @@ class TelegramNotificationService:
|
|||
|
||||
def __init__(self, hass, bot, allowed_chat_ids, parser):
|
||||
"""Initialize the service."""
|
||||
|
||||
self.allowed_chat_ids = allowed_chat_ids
|
||||
self._default_user = self.allowed_chat_ids[0]
|
||||
self._last_message_id = {user: None for user in self.allowed_chat_ids}
|
||||
|
@ -495,7 +500,6 @@ class TelegramNotificationService:
|
|||
- a string like: `/cmd1, /cmd2, /cmd3`
|
||||
- or a string like: `text_b1:/cmd1, text_b2:/cmd2`
|
||||
"""
|
||||
|
||||
buttons = []
|
||||
if isinstance(row_keyboard, str):
|
||||
for key in row_keyboard.split(","):
|
||||
|
@ -566,7 +570,6 @@ class TelegramNotificationService:
|
|||
|
||||
def _send_msg(self, func_send, msg_error, message_tag, *args_msg, **kwargs_msg):
|
||||
"""Send one message."""
|
||||
|
||||
try:
|
||||
out = func_send(*args_msg, **kwargs_msg)
|
||||
if not isinstance(out, bool) and hasattr(out, ATTR_MESSAGEID):
|
||||
|
@ -857,131 +860,99 @@ class TelegramNotificationService:
|
|||
class BaseTelegramBotEntity:
|
||||
"""The base class for the telegram bot."""
|
||||
|
||||
def __init__(self, hass, allowed_chat_ids):
|
||||
def __init__(self, hass, config):
|
||||
"""Initialize the bot base class."""
|
||||
self.allowed_chat_ids = allowed_chat_ids
|
||||
self.allowed_chat_ids = config[CONF_ALLOWED_CHAT_IDS]
|
||||
self.hass = hass
|
||||
|
||||
def _get_message_data(self, msg_data):
|
||||
"""Return boolean msg_data_is_ok and dict msg_data."""
|
||||
if not msg_data:
|
||||
return False, None
|
||||
bad_fields = (
|
||||
"text" not in msg_data and "data" not in msg_data and "chat" not in msg_data
|
||||
)
|
||||
if bad_fields or "from" not in msg_data:
|
||||
# Message is not correct.
|
||||
_LOGGER.error("Incoming message does not have required data (%s)", msg_data)
|
||||
return False, None
|
||||
def handle_update(self, update: Update, context: CallbackContext) -> bool:
|
||||
"""Handle updates from bot dispatcher set up by the respective platform."""
|
||||
_LOGGER.debug("Handling update %s", update)
|
||||
if not self.authorize_update(update):
|
||||
return False
|
||||
|
||||
if (
|
||||
msg_data["from"].get("id") not in self.allowed_chat_ids
|
||||
and msg_data["message"]["chat"].get("id") not in self.allowed_chat_ids
|
||||
):
|
||||
# Neither from id nor chat id was in allowed_chat_ids,
|
||||
# origin is not allowed.
|
||||
_LOGGER.error("Incoming message is not allowed (%s)", msg_data)
|
||||
return True, None
|
||||
|
||||
data = {
|
||||
ATTR_USER_ID: msg_data["from"]["id"],
|
||||
ATTR_FROM_FIRST: msg_data["from"]["first_name"],
|
||||
}
|
||||
if "message_id" in msg_data:
|
||||
data[ATTR_MSGID] = msg_data["message_id"]
|
||||
if "last_name" in msg_data["from"]:
|
||||
data[ATTR_FROM_LAST] = msg_data["from"]["last_name"]
|
||||
if "chat" in msg_data:
|
||||
data[ATTR_CHAT_ID] = msg_data["chat"]["id"]
|
||||
elif ATTR_MESSAGE in msg_data and "chat" in msg_data[ATTR_MESSAGE]:
|
||||
data[ATTR_CHAT_ID] = msg_data[ATTR_MESSAGE]["chat"]["id"]
|
||||
|
||||
return True, data
|
||||
|
||||
def _get_channel_post_data(self, msg_data):
|
||||
"""Return boolean msg_data_is_ok and dict msg_data."""
|
||||
if not msg_data:
|
||||
return False, None
|
||||
|
||||
if "sender_chat" in msg_data and "chat" in msg_data and "text" in msg_data:
|
||||
if (
|
||||
msg_data["sender_chat"].get("id") not in self.allowed_chat_ids
|
||||
and msg_data["chat"].get("id") not in self.allowed_chat_ids
|
||||
):
|
||||
# Neither sender_chat id nor chat id was in allowed_chat_ids,
|
||||
# origin is not allowed.
|
||||
_LOGGER.error("Incoming message is not allowed (%s)", msg_data)
|
||||
return True, None
|
||||
|
||||
data = {
|
||||
ATTR_MSGID: msg_data["message_id"],
|
||||
ATTR_CHAT_ID: msg_data["chat"]["id"],
|
||||
ATTR_TEXT: msg_data["text"],
|
||||
}
|
||||
return True, data
|
||||
|
||||
_LOGGER.error("Incoming message does not have required data (%s)", msg_data)
|
||||
return False, None
|
||||
|
||||
def process_message(self, data):
|
||||
"""Check for basic message rules and fire an event if message is ok."""
|
||||
if ATTR_MSG in data or ATTR_EDITED_MSG in data:
|
||||
event = EVENT_TELEGRAM_COMMAND
|
||||
if ATTR_MSG in data:
|
||||
data = data.get(ATTR_MSG)
|
||||
else:
|
||||
data = data.get(ATTR_EDITED_MSG)
|
||||
message_ok, event_data = self._get_message_data(data)
|
||||
if event_data is None:
|
||||
return message_ok
|
||||
|
||||
if ATTR_MSGID in data:
|
||||
event_data[ATTR_MSGID] = data[ATTR_MSGID]
|
||||
|
||||
if "text" in data:
|
||||
if data["text"][0] == "/":
|
||||
pieces = data["text"].split(" ")
|
||||
event_data[ATTR_COMMAND] = pieces[0]
|
||||
event_data[ATTR_ARGS] = pieces[1:]
|
||||
else:
|
||||
event_data[ATTR_TEXT] = data["text"]
|
||||
event = EVENT_TELEGRAM_TEXT
|
||||
else:
|
||||
_LOGGER.warning("Message without text data received: %s", data)
|
||||
event_data[ATTR_TEXT] = str(data)
|
||||
event = EVENT_TELEGRAM_TEXT
|
||||
|
||||
self.hass.bus.async_fire(event, event_data)
|
||||
return True
|
||||
if ATTR_CALLBACK_QUERY in data:
|
||||
event = EVENT_TELEGRAM_CALLBACK
|
||||
data = data.get(ATTR_CALLBACK_QUERY)
|
||||
message_ok, event_data = self._get_message_data(data)
|
||||
if event_data is None:
|
||||
return message_ok
|
||||
|
||||
query_data = event_data[ATTR_DATA] = data[ATTR_DATA]
|
||||
|
||||
if query_data[0] == "/":
|
||||
pieces = query_data.split(" ")
|
||||
event_data[ATTR_COMMAND] = pieces[0]
|
||||
event_data[ATTR_ARGS] = pieces[1:]
|
||||
|
||||
event_data[ATTR_MSG] = data[ATTR_MSG]
|
||||
event_data[ATTR_CHAT_INSTANCE] = data[ATTR_CHAT_INSTANCE]
|
||||
event_data[ATTR_MSGID] = data[ATTR_MSGID]
|
||||
|
||||
self.hass.bus.async_fire(event, event_data)
|
||||
return True
|
||||
if ATTR_CHANNEL_POST in data:
|
||||
event = EVENT_TELEGRAM_TEXT
|
||||
data = data.get(ATTR_CHANNEL_POST)
|
||||
message_ok, event_data = self._get_channel_post_data(data)
|
||||
if event_data is None:
|
||||
return message_ok
|
||||
|
||||
self.hass.bus.async_fire(event, event_data)
|
||||
# establish event type: text, command or callback_query
|
||||
if update.callback_query:
|
||||
# NOTE: Check for callback query first since effective message will be populated with the message
|
||||
# in .callback_query (python-telegram-bot docs are wrong)
|
||||
event_type, event_data = self._get_callback_query_event_data(
|
||||
update.callback_query
|
||||
)
|
||||
elif update.effective_message:
|
||||
event_type, event_data = self._get_message_event_data(
|
||||
update.effective_message
|
||||
)
|
||||
else:
|
||||
_LOGGER.warning("Unhandled update: %s", update)
|
||||
return True
|
||||
|
||||
_LOGGER.warning("Message with unknown data received: %s", data)
|
||||
_LOGGER.debug("Firing event %s: %s", event_type, event_data)
|
||||
self.hass.bus.fire(event_type, event_data)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _get_command_event_data(command_text: str) -> dict[str, str | list]:
|
||||
if not command_text.startswith("/"):
|
||||
return {}
|
||||
command_parts = command_text.split()
|
||||
command = command_parts[0]
|
||||
args = command_parts[1:]
|
||||
return {ATTR_COMMAND: command, ATTR_ARGS: args}
|
||||
|
||||
def _get_message_event_data(self, message: Message) -> tuple[str, dict[str, Any]]:
|
||||
event_data: dict[str, Any] = {
|
||||
ATTR_MSGID: message.message_id,
|
||||
ATTR_CHAT_ID: message.chat.id,
|
||||
}
|
||||
if Filters.command.filter(message):
|
||||
# This is a command message - set event type to command and split data into command and args
|
||||
event_type = EVENT_TELEGRAM_COMMAND
|
||||
event_data.update(self._get_command_event_data(message.text))
|
||||
else:
|
||||
event_type = EVENT_TELEGRAM_TEXT
|
||||
event_data[ATTR_TEXT] = message.text
|
||||
|
||||
if message.from_user:
|
||||
event_data.update(
|
||||
{
|
||||
ATTR_USER_ID: message.from_user.id,
|
||||
ATTR_FROM_FIRST: message.from_user.first_name,
|
||||
ATTR_FROM_LAST: message.from_user.last_name,
|
||||
}
|
||||
)
|
||||
|
||||
return event_type, event_data
|
||||
|
||||
def _get_callback_query_event_data(
|
||||
self, callback_query: CallbackQuery
|
||||
) -> tuple[str, dict[str, Any]]:
|
||||
event_type = EVENT_TELEGRAM_CALLBACK
|
||||
event_data: dict[str, Any] = {
|
||||
ATTR_MSGID: callback_query.id,
|
||||
ATTR_CHAT_INSTANCE: callback_query.chat_instance,
|
||||
ATTR_DATA: callback_query.data,
|
||||
ATTR_MSG: None,
|
||||
ATTR_CHAT_ID: None,
|
||||
}
|
||||
if callback_query.message:
|
||||
event_data[ATTR_MSG] = callback_query.message.to_dict()
|
||||
event_data[ATTR_CHAT_ID] = callback_query.message.chat.id
|
||||
|
||||
# Split data into command and args if possible
|
||||
event_data.update(self._get_command_event_data(callback_query.data))
|
||||
|
||||
return event_type, event_data
|
||||
|
||||
def authorize_update(self, update: Update) -> bool:
|
||||
"""Make sure either user or chat is in allowed_chat_ids."""
|
||||
from_user = update.effective_user.id if update.effective_user else None
|
||||
from_chat = update.effective_chat.id if update.effective_chat else None
|
||||
if from_user in self.allowed_chat_ids or from_chat in self.allowed_chat_ids:
|
||||
return True
|
||||
_LOGGER.error(
|
||||
"Unauthorized update - neither user id %s nor chat id %s is in allowed chats: %s",
|
||||
from_user,
|
||||
from_chat,
|
||||
self.allowed_chat_ids,
|
||||
)
|
||||
return False
|
||||
|
|
|
@ -3,31 +3,21 @@ import logging
|
|||
|
||||
from telegram import Update
|
||||
from telegram.error import NetworkError, RetryAfter, TelegramError, TimedOut
|
||||
from telegram.ext import CallbackContext, Dispatcher, Handler, Updater
|
||||
from telegram.utils.types import HandlerArg
|
||||
from telegram.ext import CallbackContext, TypeHandler, Updater
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP
|
||||
|
||||
from . import CONF_ALLOWED_CHAT_IDS, BaseTelegramBotEntity, initialize_bot
|
||||
from . import BaseTelegramBotEntity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_platform(hass, config):
|
||||
async def async_setup_platform(hass, bot, config):
|
||||
"""Set up the Telegram polling platform."""
|
||||
bot = initialize_bot(config)
|
||||
pol = TelegramPoll(bot, hass, config[CONF_ALLOWED_CHAT_IDS])
|
||||
pollbot = PollBot(hass, bot, config)
|
||||
|
||||
def _start_bot(_event):
|
||||
"""Start the bot."""
|
||||
pol.start_polling()
|
||||
|
||||
def _stop_bot(_event):
|
||||
"""Stop the bot."""
|
||||
pol.stop_polling()
|
||||
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, _start_bot)
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _stop_bot)
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, pollbot.start_polling)
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, pollbot.stop_polling)
|
||||
|
||||
return True
|
||||
|
||||
|
@ -43,57 +33,28 @@ def process_error(update: Update, context: CallbackContext):
|
|||
_LOGGER.error('Update "%s" caused error: "%s"', update, context.error)
|
||||
|
||||
|
||||
def message_handler(handler):
|
||||
"""Create messages handler."""
|
||||
class PollBot(BaseTelegramBotEntity):
|
||||
"""
|
||||
Controls the Updater object that holds the bot and a dispatcher.
|
||||
|
||||
class MessageHandler(Handler):
|
||||
"""Telegram bot message handler."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the messages handler instance."""
|
||||
super().__init__(handler)
|
||||
|
||||
def check_update(self, update):
|
||||
"""Check is update valid."""
|
||||
return isinstance(update, Update)
|
||||
|
||||
def handle_update(
|
||||
self,
|
||||
update: HandlerArg,
|
||||
dispatcher: Dispatcher,
|
||||
check_result: object,
|
||||
context: CallbackContext = None,
|
||||
):
|
||||
"""Handle update."""
|
||||
optional_args = self.collect_optional_args(dispatcher, update)
|
||||
context.args = optional_args
|
||||
return self.callback(update, context)
|
||||
|
||||
return MessageHandler()
|
||||
|
||||
|
||||
class TelegramPoll(BaseTelegramBotEntity):
|
||||
"""Asyncio telegram incoming message handler."""
|
||||
|
||||
def __init__(self, bot, hass, allowed_chat_ids):
|
||||
"""Initialize the polling instance."""
|
||||
|
||||
BaseTelegramBotEntity.__init__(self, hass, allowed_chat_ids)
|
||||
The dispatcher is set up by the super class to pass telegram updates to `self.handle_update`
|
||||
"""
|
||||
|
||||
def __init__(self, hass, bot, config):
|
||||
"""Create Updater and Dispatcher before calling super()."""
|
||||
self.bot = bot
|
||||
self.updater = Updater(bot=bot, workers=4)
|
||||
self.dispatcher = self.updater.dispatcher
|
||||
|
||||
self.dispatcher.add_handler(message_handler(self.process_update))
|
||||
self.dispatcher.add_handler(TypeHandler(Update, self.handle_update))
|
||||
self.dispatcher.add_error_handler(process_error)
|
||||
super().__init__(hass, config)
|
||||
|
||||
def start_polling(self):
|
||||
def start_polling(self, event=None):
|
||||
"""Start the polling task."""
|
||||
_LOGGER.debug("Starting polling")
|
||||
self.updater.start_polling()
|
||||
|
||||
def stop_polling(self):
|
||||
def stop_polling(self, event=None):
|
||||
"""Stop the polling task."""
|
||||
_LOGGER.debug("Stopping polling")
|
||||
self.updater.stop()
|
||||
|
||||
def process_update(self, update: HandlerArg, context: CallbackContext):
|
||||
"""Process incoming message."""
|
||||
self.process_message(update.to_dict())
|
||||
|
|
|
@ -4,90 +4,115 @@ from http import HTTPStatus
|
|||
from ipaddress import ip_address
|
||||
import logging
|
||||
|
||||
from telegram import Update
|
||||
from telegram.error import TimedOut
|
||||
from telegram.ext import Dispatcher, TypeHandler
|
||||
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.helpers.network import get_url
|
||||
|
||||
from . import (
|
||||
CONF_ALLOWED_CHAT_IDS,
|
||||
CONF_TRUSTED_NETWORKS,
|
||||
CONF_URL,
|
||||
BaseTelegramBotEntity,
|
||||
initialize_bot,
|
||||
)
|
||||
from . import CONF_TRUSTED_NETWORKS, CONF_URL, BaseTelegramBotEntity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
TELEGRAM_HANDLER_URL = "/api/telegram_webhooks"
|
||||
REMOVE_HANDLER_URL = ""
|
||||
TELEGRAM_WEBHOOK_URL = "/api/telegram_webhooks"
|
||||
REMOVE_WEBHOOK_URL = ""
|
||||
|
||||
|
||||
async def async_setup_platform(hass, config):
|
||||
async def async_setup_platform(hass, bot, config):
|
||||
"""Set up the Telegram webhooks platform."""
|
||||
pushbot = PushBot(hass, bot, config)
|
||||
|
||||
bot = initialize_bot(config)
|
||||
|
||||
current_status = await hass.async_add_executor_job(bot.getWebhookInfo)
|
||||
if not (base_url := config.get(CONF_URL)):
|
||||
base_url = get_url(hass, require_ssl=True, allow_internal=False)
|
||||
|
||||
# Some logging of Bot current status:
|
||||
last_error_date = getattr(current_status, "last_error_date", None)
|
||||
if (last_error_date is not None) and (isinstance(last_error_date, int)):
|
||||
last_error_date = dt.datetime.fromtimestamp(last_error_date)
|
||||
_LOGGER.info(
|
||||
"Telegram webhook last_error_date: %s. Status: %s",
|
||||
last_error_date,
|
||||
current_status,
|
||||
)
|
||||
else:
|
||||
_LOGGER.debug("telegram webhook Status: %s", current_status)
|
||||
|
||||
handler_url = f"{base_url}{TELEGRAM_HANDLER_URL}"
|
||||
if not handler_url.startswith("https"):
|
||||
_LOGGER.error("Invalid telegram webhook %s must be https", handler_url)
|
||||
if not pushbot.webhook_url.startswith("https"):
|
||||
_LOGGER.error("Invalid telegram webhook %s must be https", pushbot.webhook_url)
|
||||
return False
|
||||
|
||||
def _try_to_set_webhook():
|
||||
retry_num = 0
|
||||
while retry_num < 3:
|
||||
try:
|
||||
return bot.setWebhook(handler_url, timeout=5)
|
||||
except TimedOut:
|
||||
retry_num += 1
|
||||
_LOGGER.warning("Timeout trying to set webhook (retry #%d)", retry_num)
|
||||
webhook_registered = await pushbot.register_webhook()
|
||||
if not webhook_registered:
|
||||
return False
|
||||
|
||||
if current_status and current_status["url"] != handler_url:
|
||||
result = await hass.async_add_executor_job(_try_to_set_webhook)
|
||||
if result:
|
||||
_LOGGER.info("Set new telegram webhook %s", handler_url)
|
||||
else:
|
||||
_LOGGER.error("Set telegram webhook failed %s", handler_url)
|
||||
return False
|
||||
|
||||
hass.bus.async_listen_once(
|
||||
EVENT_HOMEASSISTANT_STOP, lambda event: bot.setWebhook(REMOVE_HANDLER_URL)
|
||||
)
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, pushbot.deregister_webhook)
|
||||
hass.http.register_view(
|
||||
BotPushReceiver(
|
||||
hass, config[CONF_ALLOWED_CHAT_IDS], config[CONF_TRUSTED_NETWORKS]
|
||||
)
|
||||
PushBotView(hass, bot, pushbot.dispatcher, config[CONF_TRUSTED_NETWORKS])
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
class BotPushReceiver(HomeAssistantView, BaseTelegramBotEntity):
|
||||
"""Handle pushes from Telegram."""
|
||||
class PushBot(BaseTelegramBotEntity):
|
||||
"""Handles all the push/webhook logic and passes telegram updates to `self.handle_update`."""
|
||||
|
||||
def __init__(self, hass, bot, config):
|
||||
"""Create Dispatcher before calling super()."""
|
||||
self.bot = bot
|
||||
self.trusted_networks = config[CONF_TRUSTED_NETWORKS]
|
||||
# Dumb dispatcher that just gets our updates to our handler callback (self.handle_update)
|
||||
self.dispatcher = Dispatcher(bot, None)
|
||||
self.dispatcher.add_handler(TypeHandler(Update, self.handle_update))
|
||||
super().__init__(hass, config)
|
||||
|
||||
self.base_url = config.get(CONF_URL) or get_url(
|
||||
hass, require_ssl=True, allow_internal=False
|
||||
)
|
||||
self.webhook_url = f"{self.base_url}{TELEGRAM_WEBHOOK_URL}"
|
||||
|
||||
def _try_to_set_webhook(self):
|
||||
_LOGGER.debug("Registering webhook URL: %s", self.webhook_url)
|
||||
retry_num = 0
|
||||
while retry_num < 3:
|
||||
try:
|
||||
return self.bot.set_webhook(self.webhook_url, timeout=5)
|
||||
except TimedOut:
|
||||
retry_num += 1
|
||||
_LOGGER.warning("Timeout trying to set webhook (retry #%d)", retry_num)
|
||||
|
||||
return False
|
||||
|
||||
async def register_webhook(self):
|
||||
"""Query telegram and register the URL for our webhook."""
|
||||
current_status = await self.hass.async_add_executor_job(
|
||||
self.bot.get_webhook_info
|
||||
)
|
||||
# Some logging of Bot current status:
|
||||
last_error_date = getattr(current_status, "last_error_date", None)
|
||||
if (last_error_date is not None) and (isinstance(last_error_date, int)):
|
||||
last_error_date = dt.datetime.fromtimestamp(last_error_date)
|
||||
_LOGGER.debug(
|
||||
"Telegram webhook last_error_date: %s. Status: %s",
|
||||
last_error_date,
|
||||
current_status,
|
||||
)
|
||||
else:
|
||||
_LOGGER.debug("telegram webhook status: %s", current_status)
|
||||
|
||||
if current_status and current_status["url"] != self.webhook_url:
|
||||
result = await self.hass.async_add_executor_job(self._try_to_set_webhook)
|
||||
if result:
|
||||
_LOGGER.info("Set new telegram webhook %s", self.webhook_url)
|
||||
else:
|
||||
_LOGGER.error("Set telegram webhook failed %s", self.webhook_url)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def deregister_webhook(self, event=None):
|
||||
"""Query telegram and deregister the URL for our webhook."""
|
||||
_LOGGER.debug("Deregistering webhook URL")
|
||||
return self.bot.delete_webhook()
|
||||
|
||||
|
||||
class PushBotView(HomeAssistantView):
|
||||
"""View for handling webhook calls from Telegram."""
|
||||
|
||||
requires_auth = False
|
||||
url = TELEGRAM_HANDLER_URL
|
||||
url = TELEGRAM_WEBHOOK_URL
|
||||
name = "telegram_webhooks"
|
||||
|
||||
def __init__(self, hass, allowed_chat_ids, trusted_networks):
|
||||
"""Initialize the class."""
|
||||
BaseTelegramBotEntity.__init__(self, hass, allowed_chat_ids)
|
||||
def __init__(self, hass, bot, dispatcher, trusted_networks):
|
||||
"""Initialize by storing stuff needed for setting up our webhook endpoint."""
|
||||
self.hass = hass
|
||||
self.bot = bot
|
||||
self.dispatcher = dispatcher
|
||||
self.trusted_networks = trusted_networks
|
||||
|
||||
async def post(self, request):
|
||||
|
@ -98,10 +123,12 @@ class BotPushReceiver(HomeAssistantView, BaseTelegramBotEntity):
|
|||
return self.json_message("Access denied", HTTPStatus.UNAUTHORIZED)
|
||||
|
||||
try:
|
||||
data = await request.json()
|
||||
update_data = await request.json()
|
||||
except ValueError:
|
||||
return self.json_message("Invalid JSON", HTTPStatus.BAD_REQUEST)
|
||||
|
||||
if not self.process_message(data):
|
||||
return self.json_message("Invalid message", HTTPStatus.BAD_REQUEST)
|
||||
update = Update.de_json(update_data, self.bot)
|
||||
_LOGGER.debug("Received Update on %s: %s", self.url, update)
|
||||
await self.hass.async_add_executor_job(self.dispatcher.process_update, update)
|
||||
|
||||
return None
|
||||
|
|
|
@ -26,6 +26,9 @@ PyQRCode==1.2.1
|
|||
# homeassistant.components.rmvtransport
|
||||
PyRMVtransport==0.3.3
|
||||
|
||||
# homeassistant.components.telegram_bot
|
||||
PySocks==1.7.1
|
||||
|
||||
# homeassistant.components.switchbot
|
||||
# PySwitchbot==0.13.3
|
||||
|
||||
|
@ -1259,6 +1262,9 @@ python-songpal==0.14.1
|
|||
# homeassistant.components.tado
|
||||
python-tado==0.12.0
|
||||
|
||||
# homeassistant.components.telegram_bot
|
||||
python-telegram-bot==13.1
|
||||
|
||||
# homeassistant.components.awair
|
||||
python_awair==0.2.3
|
||||
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
"""Tests for telegram_bot integration."""
|
|
@ -0,0 +1,187 @@
|
|||
"""Tests for the telegram_bot integration."""
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from telegram.ext.dispatcher import Dispatcher
|
||||
|
||||
from homeassistant.components.telegram_bot import (
|
||||
CONF_ALLOWED_CHAT_IDS,
|
||||
CONF_TRUSTED_NETWORKS,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.const import CONF_API_KEY, CONF_PLATFORM, CONF_URL
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_webhooks():
|
||||
"""Fixture for a webhooks platform configuration."""
|
||||
return {
|
||||
DOMAIN: [
|
||||
{
|
||||
CONF_PLATFORM: "webhooks",
|
||||
CONF_URL: "https://test",
|
||||
CONF_TRUSTED_NETWORKS: ["127.0.0.1"],
|
||||
CONF_API_KEY: "1234567890:ABC",
|
||||
CONF_ALLOWED_CHAT_IDS: [
|
||||
# "me"
|
||||
12345678,
|
||||
# Some chat
|
||||
-123456789,
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_polling():
|
||||
"""Fixture for a polling platform configuration."""
|
||||
return {
|
||||
DOMAIN: [
|
||||
{
|
||||
CONF_PLATFORM: "polling",
|
||||
CONF_API_KEY: "1234567890:ABC",
|
||||
CONF_ALLOWED_CHAT_IDS: [
|
||||
# "me"
|
||||
12345678,
|
||||
# Some chat
|
||||
-123456789,
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_register_webhook():
|
||||
"""Mock calls made by telegram_bot when (de)registering webhook."""
|
||||
with patch(
|
||||
"homeassistant.components.telegram_bot.webhooks.PushBot.register_webhook",
|
||||
return_value=True,
|
||||
), patch(
|
||||
"homeassistant.components.telegram_bot.webhooks.PushBot.deregister_webhook",
|
||||
return_value=True,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def update_message_command():
|
||||
"""Fixture for mocking an incoming update of type message/command."""
|
||||
return {
|
||||
"update_id": 1,
|
||||
"message": {
|
||||
"message_id": 1,
|
||||
"from": {
|
||||
"id": 12345678,
|
||||
"is_bot": False,
|
||||
"first_name": "Firstname",
|
||||
"username": "some_username",
|
||||
"language_code": "en",
|
||||
},
|
||||
"chat": {
|
||||
"id": -123456789,
|
||||
"title": "SomeChat",
|
||||
"type": "group",
|
||||
"all_members_are_administrators": True,
|
||||
},
|
||||
"date": 1644518189,
|
||||
"text": "/command",
|
||||
"entities": [
|
||||
{
|
||||
"type": "bot_command",
|
||||
"offset": 0,
|
||||
"length": 7,
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def update_message_text():
|
||||
"""Fixture for mocking an incoming update of type message/text."""
|
||||
return {
|
||||
"update_id": 1,
|
||||
"message": {
|
||||
"message_id": 1,
|
||||
"date": 1441645532,
|
||||
"from": {
|
||||
"id": 12345678,
|
||||
"is_bot": False,
|
||||
"last_name": "Test Lastname",
|
||||
"first_name": "Test Firstname",
|
||||
"username": "Testusername",
|
||||
},
|
||||
"chat": {
|
||||
"last_name": "Test Lastname",
|
||||
"id": 1111111,
|
||||
"type": "private",
|
||||
"first_name": "Test Firstname",
|
||||
"username": "Testusername",
|
||||
},
|
||||
"text": "HELLO",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unauthorized_update_message_text(update_message_text):
|
||||
"""Fixture for mocking an incoming update of type message/text that is not in our `allowed_chat_ids`."""
|
||||
update_message_text["message"]["from"]["id"] = 1234
|
||||
update_message_text["message"]["chat"]["id"] = 1234
|
||||
return update_message_text
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def update_callback_query():
|
||||
"""Fixture for mocking an incoming update of type callback_query."""
|
||||
return {
|
||||
"update_id": 1,
|
||||
"callback_query": {
|
||||
"id": "4382bfdwdsb323b2d9",
|
||||
"from": {
|
||||
"id": 12345678,
|
||||
"type": "private",
|
||||
"is_bot": False,
|
||||
"last_name": "Test Lastname",
|
||||
"first_name": "Test Firstname",
|
||||
"username": "Testusername",
|
||||
},
|
||||
"chat_instance": "aaa111",
|
||||
"data": "Data from button callback",
|
||||
"inline_message_id": "1234csdbsk4839",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def webhook_platform(hass, config_webhooks, mock_register_webhook):
|
||||
"""Fixture for setting up the webhooks platform using appropriate config and mocks."""
|
||||
await async_setup_component(
|
||||
hass,
|
||||
DOMAIN,
|
||||
config_webhooks,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def polling_platform(hass, config_polling):
|
||||
"""Fixture for setting up the polling platform using appropriate config and mocks."""
|
||||
await async_setup_component(
|
||||
hass,
|
||||
DOMAIN,
|
||||
config_polling,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_dispatcher():
|
||||
"""Clear the singleton that telegram.ext.dispatcher.Dispatcher sets on itself."""
|
||||
yield
|
||||
Dispatcher._set_singleton(None)
|
||||
# This is how python-telegram-bot resets the dispatcher in their test suite
|
||||
Dispatcher._Dispatcher__singleton_semaphore.release()
|
|
@ -0,0 +1,112 @@
|
|||
"""Tests for the telegram_bot component."""
|
||||
from telegram import Update
|
||||
from telegram.ext.dispatcher import Dispatcher
|
||||
|
||||
from homeassistant.components.telegram_bot import DOMAIN, SERVICE_SEND_MESSAGE
|
||||
from homeassistant.components.telegram_bot.webhooks import TELEGRAM_WEBHOOK_URL
|
||||
|
||||
from tests.common import async_capture_events
|
||||
|
||||
|
||||
async def test_webhook_platform_init(hass, webhook_platform):
|
||||
"""Test initialization of the webhooks platform."""
|
||||
assert hass.services.has_service(DOMAIN, SERVICE_SEND_MESSAGE) is True
|
||||
|
||||
|
||||
async def test_polling_platform_init(hass, polling_platform):
|
||||
"""Test initialization of the polling platform."""
|
||||
assert hass.services.has_service(DOMAIN, SERVICE_SEND_MESSAGE) is True
|
||||
|
||||
|
||||
async def test_webhook_endpoint_generates_telegram_text_event(
|
||||
hass, webhook_platform, hass_client, update_message_text
|
||||
):
|
||||
"""POST to the configured webhook endpoint and assert fired `telegram_text` event."""
|
||||
client = await hass_client()
|
||||
events = async_capture_events(hass, "telegram_text")
|
||||
|
||||
response = await client.post(TELEGRAM_WEBHOOK_URL, json=update_message_text)
|
||||
assert response.status == 200
|
||||
assert (await response.read()).decode("utf-8") == ""
|
||||
|
||||
# Make sure event has fired
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].data["text"] == update_message_text["message"]["text"]
|
||||
|
||||
|
||||
async def test_webhook_endpoint_generates_telegram_command_event(
|
||||
hass, webhook_platform, hass_client, update_message_command
|
||||
):
|
||||
"""POST to the configured webhook endpoint and assert fired `telegram_command` event."""
|
||||
client = await hass_client()
|
||||
events = async_capture_events(hass, "telegram_command")
|
||||
|
||||
response = await client.post(TELEGRAM_WEBHOOK_URL, json=update_message_command)
|
||||
assert response.status == 200
|
||||
assert (await response.read()).decode("utf-8") == ""
|
||||
|
||||
# Make sure event has fired
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].data["command"] == update_message_command["message"]["text"]
|
||||
|
||||
|
||||
async def test_webhook_endpoint_generates_telegram_callback_event(
|
||||
hass, webhook_platform, hass_client, update_callback_query
|
||||
):
|
||||
"""POST to the configured webhook endpoint and assert fired `telegram_callback` event."""
|
||||
client = await hass_client()
|
||||
events = async_capture_events(hass, "telegram_callback")
|
||||
|
||||
response = await client.post(TELEGRAM_WEBHOOK_URL, json=update_callback_query)
|
||||
assert response.status == 200
|
||||
assert (await response.read()).decode("utf-8") == ""
|
||||
|
||||
# Make sure event has fired
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].data["data"] == update_callback_query["callback_query"]["data"]
|
||||
|
||||
|
||||
async def test_polling_platform_message_text_update(
|
||||
hass, polling_platform, update_message_text
|
||||
):
|
||||
"""Provide the `PollBot`s `Dispatcher` with an `Update` and assert fired `telegram_text` event."""
|
||||
events = async_capture_events(hass, "telegram_text")
|
||||
|
||||
def telegram_dispatcher_callback():
|
||||
dispatcher = Dispatcher.get_instance()
|
||||
update = Update.de_json(update_message_text, dispatcher.bot)
|
||||
dispatcher.process_update(update)
|
||||
|
||||
# python-telegram-bots `Updater` uses threading, so we need to schedule its callback in a sync context.
|
||||
await hass.async_add_executor_job(telegram_dispatcher_callback)
|
||||
|
||||
# Make sure event has fired
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].data["text"] == update_message_text["message"]["text"]
|
||||
|
||||
|
||||
async def test_webhook_endpoint_unauthorized_update_doesnt_generate_telegram_text_event(
|
||||
hass, webhook_platform, hass_client, unauthorized_update_message_text
|
||||
):
|
||||
"""Update with unauthorized user/chat should not trigger event."""
|
||||
client = await hass_client()
|
||||
events = async_capture_events(hass, "telegram_text")
|
||||
|
||||
response = await client.post(
|
||||
TELEGRAM_WEBHOOK_URL, json=unauthorized_update_message_text
|
||||
)
|
||||
assert response.status == 200
|
||||
assert (await response.read()).decode("utf-8") == ""
|
||||
|
||||
# Make sure any events would have fired
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(events) == 0
|
Loading…
Reference in New Issue