Refactor telegram_bot polling/webhooks platforms and add tests (#66433)

Co-authored-by: Pär Berge <paer.berge@gmail.com>
pull/69171/head
Wictor 2022-04-03 05:39:14 +02:00 committed by GitHub
parent 55c6112a28
commit d7375f1a9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 514 additions and 249 deletions

View File

@ -1,20 +1,27 @@
"""Support to send and receive Telegram messages."""
from __future__ import annotations
from functools import partial
import importlib
import io
from ipaddress import ip_network
import logging
from typing import Any
import requests
from requests.auth import HTTPBasicAuth, HTTPDigestAuth
from telegram import (
Bot,
CallbackQuery,
InlineKeyboardButton,
InlineKeyboardMarkup,
Message,
ReplyKeyboardMarkup,
ReplyKeyboardRemove,
Update,
)
from telegram.error import TelegramError
from telegram.ext import CallbackContext, Filters
from telegram.parsemode import ParseMode
from telegram.utils.request import Request
import voluptuous as vol
@ -311,14 +318,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return False
for p_config in config[DOMAIN]:
# Each platform config gets its own bot
bot = initialize_bot(p_config)
p_type = p_config.get(CONF_PLATFORM)
platform = importlib.import_module(f".{p_config[CONF_PLATFORM]}", __name__)
_LOGGER.info("Setting up %s.%s", DOMAIN, p_type)
try:
receiver_service = await platform.async_setup_platform(hass, p_config)
receiver_service = await platform.async_setup_platform(hass, bot, p_config)
if receiver_service is False:
_LOGGER.error("Failed to initialize Telegram bot %s", p_type)
return False
@ -327,7 +335,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
_LOGGER.exception("Error setting up platform %s", p_type)
return False
bot = initialize_bot(p_config)
notify_service = TelegramNotificationService(
hass, bot, p_config.get(CONF_ALLOWED_CHAT_IDS), p_config.get(ATTR_PARSER)
)
@ -416,7 +423,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
def initialize_bot(p_config):
"""Initialize telegram bot with proxy support."""
api_key = p_config.get(CONF_API_KEY)
proxy_url = p_config.get(CONF_PROXY_URL)
proxy_params = p_config.get(CONF_PROXY_PARAMS)
@ -435,7 +441,6 @@ class TelegramNotificationService:
def __init__(self, hass, bot, allowed_chat_ids, parser):
"""Initialize the service."""
self.allowed_chat_ids = allowed_chat_ids
self._default_user = self.allowed_chat_ids[0]
self._last_message_id = {user: None for user in self.allowed_chat_ids}
@ -495,7 +500,6 @@ class TelegramNotificationService:
- a string like: `/cmd1, /cmd2, /cmd3`
- or a string like: `text_b1:/cmd1, text_b2:/cmd2`
"""
buttons = []
if isinstance(row_keyboard, str):
for key in row_keyboard.split(","):
@ -566,7 +570,6 @@ class TelegramNotificationService:
def _send_msg(self, func_send, msg_error, message_tag, *args_msg, **kwargs_msg):
"""Send one message."""
try:
out = func_send(*args_msg, **kwargs_msg)
if not isinstance(out, bool) and hasattr(out, ATTR_MESSAGEID):
@ -857,131 +860,99 @@ class TelegramNotificationService:
class BaseTelegramBotEntity:
"""The base class for the telegram bot."""
def __init__(self, hass, allowed_chat_ids):
def __init__(self, hass, config):
"""Initialize the bot base class."""
self.allowed_chat_ids = allowed_chat_ids
self.allowed_chat_ids = config[CONF_ALLOWED_CHAT_IDS]
self.hass = hass
def _get_message_data(self, msg_data):
"""Return boolean msg_data_is_ok and dict msg_data."""
if not msg_data:
return False, None
bad_fields = (
"text" not in msg_data and "data" not in msg_data and "chat" not in msg_data
)
if bad_fields or "from" not in msg_data:
# Message is not correct.
_LOGGER.error("Incoming message does not have required data (%s)", msg_data)
return False, None
def handle_update(self, update: Update, context: CallbackContext) -> bool:
"""Handle updates from bot dispatcher set up by the respective platform."""
_LOGGER.debug("Handling update %s", update)
if not self.authorize_update(update):
return False
if (
msg_data["from"].get("id") not in self.allowed_chat_ids
and msg_data["message"]["chat"].get("id") not in self.allowed_chat_ids
):
# Neither from id nor chat id was in allowed_chat_ids,
# origin is not allowed.
_LOGGER.error("Incoming message is not allowed (%s)", msg_data)
return True, None
data = {
ATTR_USER_ID: msg_data["from"]["id"],
ATTR_FROM_FIRST: msg_data["from"]["first_name"],
}
if "message_id" in msg_data:
data[ATTR_MSGID] = msg_data["message_id"]
if "last_name" in msg_data["from"]:
data[ATTR_FROM_LAST] = msg_data["from"]["last_name"]
if "chat" in msg_data:
data[ATTR_CHAT_ID] = msg_data["chat"]["id"]
elif ATTR_MESSAGE in msg_data and "chat" in msg_data[ATTR_MESSAGE]:
data[ATTR_CHAT_ID] = msg_data[ATTR_MESSAGE]["chat"]["id"]
return True, data
def _get_channel_post_data(self, msg_data):
"""Return boolean msg_data_is_ok and dict msg_data."""
if not msg_data:
return False, None
if "sender_chat" in msg_data and "chat" in msg_data and "text" in msg_data:
if (
msg_data["sender_chat"].get("id") not in self.allowed_chat_ids
and msg_data["chat"].get("id") not in self.allowed_chat_ids
):
# Neither sender_chat id nor chat id was in allowed_chat_ids,
# origin is not allowed.
_LOGGER.error("Incoming message is not allowed (%s)", msg_data)
return True, None
data = {
ATTR_MSGID: msg_data["message_id"],
ATTR_CHAT_ID: msg_data["chat"]["id"],
ATTR_TEXT: msg_data["text"],
}
return True, data
_LOGGER.error("Incoming message does not have required data (%s)", msg_data)
return False, None
def process_message(self, data):
"""Check for basic message rules and fire an event if message is ok."""
if ATTR_MSG in data or ATTR_EDITED_MSG in data:
event = EVENT_TELEGRAM_COMMAND
if ATTR_MSG in data:
data = data.get(ATTR_MSG)
else:
data = data.get(ATTR_EDITED_MSG)
message_ok, event_data = self._get_message_data(data)
if event_data is None:
return message_ok
if ATTR_MSGID in data:
event_data[ATTR_MSGID] = data[ATTR_MSGID]
if "text" in data:
if data["text"][0] == "/":
pieces = data["text"].split(" ")
event_data[ATTR_COMMAND] = pieces[0]
event_data[ATTR_ARGS] = pieces[1:]
else:
event_data[ATTR_TEXT] = data["text"]
event = EVENT_TELEGRAM_TEXT
else:
_LOGGER.warning("Message without text data received: %s", data)
event_data[ATTR_TEXT] = str(data)
event = EVENT_TELEGRAM_TEXT
self.hass.bus.async_fire(event, event_data)
return True
if ATTR_CALLBACK_QUERY in data:
event = EVENT_TELEGRAM_CALLBACK
data = data.get(ATTR_CALLBACK_QUERY)
message_ok, event_data = self._get_message_data(data)
if event_data is None:
return message_ok
query_data = event_data[ATTR_DATA] = data[ATTR_DATA]
if query_data[0] == "/":
pieces = query_data.split(" ")
event_data[ATTR_COMMAND] = pieces[0]
event_data[ATTR_ARGS] = pieces[1:]
event_data[ATTR_MSG] = data[ATTR_MSG]
event_data[ATTR_CHAT_INSTANCE] = data[ATTR_CHAT_INSTANCE]
event_data[ATTR_MSGID] = data[ATTR_MSGID]
self.hass.bus.async_fire(event, event_data)
return True
if ATTR_CHANNEL_POST in data:
event = EVENT_TELEGRAM_TEXT
data = data.get(ATTR_CHANNEL_POST)
message_ok, event_data = self._get_channel_post_data(data)
if event_data is None:
return message_ok
self.hass.bus.async_fire(event, event_data)
# establish event type: text, command or callback_query
if update.callback_query:
# NOTE: Check for callback query first since effective message will be populated with the message
# in .callback_query (python-telegram-bot docs are wrong)
event_type, event_data = self._get_callback_query_event_data(
update.callback_query
)
elif update.effective_message:
event_type, event_data = self._get_message_event_data(
update.effective_message
)
else:
_LOGGER.warning("Unhandled update: %s", update)
return True
_LOGGER.warning("Message with unknown data received: %s", data)
_LOGGER.debug("Firing event %s: %s", event_type, event_data)
self.hass.bus.fire(event_type, event_data)
return True
@staticmethod
def _get_command_event_data(command_text: str) -> dict[str, str | list]:
if not command_text.startswith("/"):
return {}
command_parts = command_text.split()
command = command_parts[0]
args = command_parts[1:]
return {ATTR_COMMAND: command, ATTR_ARGS: args}
def _get_message_event_data(self, message: Message) -> tuple[str, dict[str, Any]]:
event_data: dict[str, Any] = {
ATTR_MSGID: message.message_id,
ATTR_CHAT_ID: message.chat.id,
}
if Filters.command.filter(message):
# This is a command message - set event type to command and split data into command and args
event_type = EVENT_TELEGRAM_COMMAND
event_data.update(self._get_command_event_data(message.text))
else:
event_type = EVENT_TELEGRAM_TEXT
event_data[ATTR_TEXT] = message.text
if message.from_user:
event_data.update(
{
ATTR_USER_ID: message.from_user.id,
ATTR_FROM_FIRST: message.from_user.first_name,
ATTR_FROM_LAST: message.from_user.last_name,
}
)
return event_type, event_data
def _get_callback_query_event_data(
self, callback_query: CallbackQuery
) -> tuple[str, dict[str, Any]]:
event_type = EVENT_TELEGRAM_CALLBACK
event_data: dict[str, Any] = {
ATTR_MSGID: callback_query.id,
ATTR_CHAT_INSTANCE: callback_query.chat_instance,
ATTR_DATA: callback_query.data,
ATTR_MSG: None,
ATTR_CHAT_ID: None,
}
if callback_query.message:
event_data[ATTR_MSG] = callback_query.message.to_dict()
event_data[ATTR_CHAT_ID] = callback_query.message.chat.id
# Split data into command and args if possible
event_data.update(self._get_command_event_data(callback_query.data))
return event_type, event_data
def authorize_update(self, update: Update) -> bool:
"""Make sure either user or chat is in allowed_chat_ids."""
from_user = update.effective_user.id if update.effective_user else None
from_chat = update.effective_chat.id if update.effective_chat else None
if from_user in self.allowed_chat_ids or from_chat in self.allowed_chat_ids:
return True
_LOGGER.error(
"Unauthorized update - neither user id %s nor chat id %s is in allowed chats: %s",
from_user,
from_chat,
self.allowed_chat_ids,
)
return False

View File

@ -3,31 +3,21 @@ import logging
from telegram import Update
from telegram.error import NetworkError, RetryAfter, TelegramError, TimedOut
from telegram.ext import CallbackContext, Dispatcher, Handler, Updater
from telegram.utils.types import HandlerArg
from telegram.ext import CallbackContext, TypeHandler, Updater
from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP
from . import CONF_ALLOWED_CHAT_IDS, BaseTelegramBotEntity, initialize_bot
from . import BaseTelegramBotEntity
_LOGGER = logging.getLogger(__name__)
async def async_setup_platform(hass, config):
async def async_setup_platform(hass, bot, config):
"""Set up the Telegram polling platform."""
bot = initialize_bot(config)
pol = TelegramPoll(bot, hass, config[CONF_ALLOWED_CHAT_IDS])
pollbot = PollBot(hass, bot, config)
def _start_bot(_event):
"""Start the bot."""
pol.start_polling()
def _stop_bot(_event):
"""Stop the bot."""
pol.stop_polling()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, _start_bot)
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _stop_bot)
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, pollbot.start_polling)
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, pollbot.stop_polling)
return True
@ -43,57 +33,28 @@ def process_error(update: Update, context: CallbackContext):
_LOGGER.error('Update "%s" caused error: "%s"', update, context.error)
def message_handler(handler):
"""Create messages handler."""
class PollBot(BaseTelegramBotEntity):
"""
Controls the Updater object that holds the bot and a dispatcher.
class MessageHandler(Handler):
"""Telegram bot message handler."""
def __init__(self):
"""Initialize the messages handler instance."""
super().__init__(handler)
def check_update(self, update):
"""Check is update valid."""
return isinstance(update, Update)
def handle_update(
self,
update: HandlerArg,
dispatcher: Dispatcher,
check_result: object,
context: CallbackContext = None,
):
"""Handle update."""
optional_args = self.collect_optional_args(dispatcher, update)
context.args = optional_args
return self.callback(update, context)
return MessageHandler()
class TelegramPoll(BaseTelegramBotEntity):
"""Asyncio telegram incoming message handler."""
def __init__(self, bot, hass, allowed_chat_ids):
"""Initialize the polling instance."""
BaseTelegramBotEntity.__init__(self, hass, allowed_chat_ids)
The dispatcher is set up by the super class to pass telegram updates to `self.handle_update`
"""
def __init__(self, hass, bot, config):
"""Create Updater and Dispatcher before calling super()."""
self.bot = bot
self.updater = Updater(bot=bot, workers=4)
self.dispatcher = self.updater.dispatcher
self.dispatcher.add_handler(message_handler(self.process_update))
self.dispatcher.add_handler(TypeHandler(Update, self.handle_update))
self.dispatcher.add_error_handler(process_error)
super().__init__(hass, config)
def start_polling(self):
def start_polling(self, event=None):
"""Start the polling task."""
_LOGGER.debug("Starting polling")
self.updater.start_polling()
def stop_polling(self):
def stop_polling(self, event=None):
"""Stop the polling task."""
_LOGGER.debug("Stopping polling")
self.updater.stop()
def process_update(self, update: HandlerArg, context: CallbackContext):
"""Process incoming message."""
self.process_message(update.to_dict())

View File

@ -4,90 +4,115 @@ from http import HTTPStatus
from ipaddress import ip_address
import logging
from telegram import Update
from telegram.error import TimedOut
from telegram.ext import Dispatcher, TypeHandler
from homeassistant.components.http import HomeAssistantView
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.helpers.network import get_url
from . import (
CONF_ALLOWED_CHAT_IDS,
CONF_TRUSTED_NETWORKS,
CONF_URL,
BaseTelegramBotEntity,
initialize_bot,
)
from . import CONF_TRUSTED_NETWORKS, CONF_URL, BaseTelegramBotEntity
_LOGGER = logging.getLogger(__name__)
TELEGRAM_HANDLER_URL = "/api/telegram_webhooks"
REMOVE_HANDLER_URL = ""
TELEGRAM_WEBHOOK_URL = "/api/telegram_webhooks"
REMOVE_WEBHOOK_URL = ""
async def async_setup_platform(hass, config):
async def async_setup_platform(hass, bot, config):
"""Set up the Telegram webhooks platform."""
pushbot = PushBot(hass, bot, config)
bot = initialize_bot(config)
current_status = await hass.async_add_executor_job(bot.getWebhookInfo)
if not (base_url := config.get(CONF_URL)):
base_url = get_url(hass, require_ssl=True, allow_internal=False)
# Some logging of Bot current status:
last_error_date = getattr(current_status, "last_error_date", None)
if (last_error_date is not None) and (isinstance(last_error_date, int)):
last_error_date = dt.datetime.fromtimestamp(last_error_date)
_LOGGER.info(
"Telegram webhook last_error_date: %s. Status: %s",
last_error_date,
current_status,
)
else:
_LOGGER.debug("telegram webhook Status: %s", current_status)
handler_url = f"{base_url}{TELEGRAM_HANDLER_URL}"
if not handler_url.startswith("https"):
_LOGGER.error("Invalid telegram webhook %s must be https", handler_url)
if not pushbot.webhook_url.startswith("https"):
_LOGGER.error("Invalid telegram webhook %s must be https", pushbot.webhook_url)
return False
def _try_to_set_webhook():
retry_num = 0
while retry_num < 3:
try:
return bot.setWebhook(handler_url, timeout=5)
except TimedOut:
retry_num += 1
_LOGGER.warning("Timeout trying to set webhook (retry #%d)", retry_num)
webhook_registered = await pushbot.register_webhook()
if not webhook_registered:
return False
if current_status and current_status["url"] != handler_url:
result = await hass.async_add_executor_job(_try_to_set_webhook)
if result:
_LOGGER.info("Set new telegram webhook %s", handler_url)
else:
_LOGGER.error("Set telegram webhook failed %s", handler_url)
return False
hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP, lambda event: bot.setWebhook(REMOVE_HANDLER_URL)
)
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, pushbot.deregister_webhook)
hass.http.register_view(
BotPushReceiver(
hass, config[CONF_ALLOWED_CHAT_IDS], config[CONF_TRUSTED_NETWORKS]
)
PushBotView(hass, bot, pushbot.dispatcher, config[CONF_TRUSTED_NETWORKS])
)
return True
class BotPushReceiver(HomeAssistantView, BaseTelegramBotEntity):
"""Handle pushes from Telegram."""
class PushBot(BaseTelegramBotEntity):
"""Handles all the push/webhook logic and passes telegram updates to `self.handle_update`."""
def __init__(self, hass, bot, config):
"""Create Dispatcher before calling super()."""
self.bot = bot
self.trusted_networks = config[CONF_TRUSTED_NETWORKS]
# Dumb dispatcher that just gets our updates to our handler callback (self.handle_update)
self.dispatcher = Dispatcher(bot, None)
self.dispatcher.add_handler(TypeHandler(Update, self.handle_update))
super().__init__(hass, config)
self.base_url = config.get(CONF_URL) or get_url(
hass, require_ssl=True, allow_internal=False
)
self.webhook_url = f"{self.base_url}{TELEGRAM_WEBHOOK_URL}"
def _try_to_set_webhook(self):
_LOGGER.debug("Registering webhook URL: %s", self.webhook_url)
retry_num = 0
while retry_num < 3:
try:
return self.bot.set_webhook(self.webhook_url, timeout=5)
except TimedOut:
retry_num += 1
_LOGGER.warning("Timeout trying to set webhook (retry #%d)", retry_num)
return False
async def register_webhook(self):
"""Query telegram and register the URL for our webhook."""
current_status = await self.hass.async_add_executor_job(
self.bot.get_webhook_info
)
# Some logging of Bot current status:
last_error_date = getattr(current_status, "last_error_date", None)
if (last_error_date is not None) and (isinstance(last_error_date, int)):
last_error_date = dt.datetime.fromtimestamp(last_error_date)
_LOGGER.debug(
"Telegram webhook last_error_date: %s. Status: %s",
last_error_date,
current_status,
)
else:
_LOGGER.debug("telegram webhook status: %s", current_status)
if current_status and current_status["url"] != self.webhook_url:
result = await self.hass.async_add_executor_job(self._try_to_set_webhook)
if result:
_LOGGER.info("Set new telegram webhook %s", self.webhook_url)
else:
_LOGGER.error("Set telegram webhook failed %s", self.webhook_url)
return False
return True
def deregister_webhook(self, event=None):
"""Query telegram and deregister the URL for our webhook."""
_LOGGER.debug("Deregistering webhook URL")
return self.bot.delete_webhook()
class PushBotView(HomeAssistantView):
"""View for handling webhook calls from Telegram."""
requires_auth = False
url = TELEGRAM_HANDLER_URL
url = TELEGRAM_WEBHOOK_URL
name = "telegram_webhooks"
def __init__(self, hass, allowed_chat_ids, trusted_networks):
"""Initialize the class."""
BaseTelegramBotEntity.__init__(self, hass, allowed_chat_ids)
def __init__(self, hass, bot, dispatcher, trusted_networks):
"""Initialize by storing stuff needed for setting up our webhook endpoint."""
self.hass = hass
self.bot = bot
self.dispatcher = dispatcher
self.trusted_networks = trusted_networks
async def post(self, request):
@ -98,10 +123,12 @@ class BotPushReceiver(HomeAssistantView, BaseTelegramBotEntity):
return self.json_message("Access denied", HTTPStatus.UNAUTHORIZED)
try:
data = await request.json()
update_data = await request.json()
except ValueError:
return self.json_message("Invalid JSON", HTTPStatus.BAD_REQUEST)
if not self.process_message(data):
return self.json_message("Invalid message", HTTPStatus.BAD_REQUEST)
update = Update.de_json(update_data, self.bot)
_LOGGER.debug("Received Update on %s: %s", self.url, update)
await self.hass.async_add_executor_job(self.dispatcher.process_update, update)
return None

View File

@ -26,6 +26,9 @@ PyQRCode==1.2.1
# homeassistant.components.rmvtransport
PyRMVtransport==0.3.3
# homeassistant.components.telegram_bot
PySocks==1.7.1
# homeassistant.components.switchbot
# PySwitchbot==0.13.3
@ -1259,6 +1262,9 @@ python-songpal==0.14.1
# homeassistant.components.tado
python-tado==0.12.0
# homeassistant.components.telegram_bot
python-telegram-bot==13.1
# homeassistant.components.awair
python_awair==0.2.3

View File

@ -0,0 +1 @@
"""Tests for telegram_bot integration."""

View File

@ -0,0 +1,187 @@
"""Tests for the telegram_bot integration."""
from unittest.mock import patch
import pytest
from telegram.ext.dispatcher import Dispatcher
from homeassistant.components.telegram_bot import (
CONF_ALLOWED_CHAT_IDS,
CONF_TRUSTED_NETWORKS,
DOMAIN,
)
from homeassistant.const import CONF_API_KEY, CONF_PLATFORM, CONF_URL
from homeassistant.setup import async_setup_component
@pytest.fixture
def config_webhooks():
"""Fixture for a webhooks platform configuration."""
return {
DOMAIN: [
{
CONF_PLATFORM: "webhooks",
CONF_URL: "https://test",
CONF_TRUSTED_NETWORKS: ["127.0.0.1"],
CONF_API_KEY: "1234567890:ABC",
CONF_ALLOWED_CHAT_IDS: [
# "me"
12345678,
# Some chat
-123456789,
],
}
]
}
@pytest.fixture
def config_polling():
"""Fixture for a polling platform configuration."""
return {
DOMAIN: [
{
CONF_PLATFORM: "polling",
CONF_API_KEY: "1234567890:ABC",
CONF_ALLOWED_CHAT_IDS: [
# "me"
12345678,
# Some chat
-123456789,
],
}
]
}
@pytest.fixture
def mock_register_webhook():
"""Mock calls made by telegram_bot when (de)registering webhook."""
with patch(
"homeassistant.components.telegram_bot.webhooks.PushBot.register_webhook",
return_value=True,
), patch(
"homeassistant.components.telegram_bot.webhooks.PushBot.deregister_webhook",
return_value=True,
):
yield
@pytest.fixture
def update_message_command():
"""Fixture for mocking an incoming update of type message/command."""
return {
"update_id": 1,
"message": {
"message_id": 1,
"from": {
"id": 12345678,
"is_bot": False,
"first_name": "Firstname",
"username": "some_username",
"language_code": "en",
},
"chat": {
"id": -123456789,
"title": "SomeChat",
"type": "group",
"all_members_are_administrators": True,
},
"date": 1644518189,
"text": "/command",
"entities": [
{
"type": "bot_command",
"offset": 0,
"length": 7,
}
],
},
}
@pytest.fixture
def update_message_text():
"""Fixture for mocking an incoming update of type message/text."""
return {
"update_id": 1,
"message": {
"message_id": 1,
"date": 1441645532,
"from": {
"id": 12345678,
"is_bot": False,
"last_name": "Test Lastname",
"first_name": "Test Firstname",
"username": "Testusername",
},
"chat": {
"last_name": "Test Lastname",
"id": 1111111,
"type": "private",
"first_name": "Test Firstname",
"username": "Testusername",
},
"text": "HELLO",
},
}
@pytest.fixture
def unauthorized_update_message_text(update_message_text):
"""Fixture for mocking an incoming update of type message/text that is not in our `allowed_chat_ids`."""
update_message_text["message"]["from"]["id"] = 1234
update_message_text["message"]["chat"]["id"] = 1234
return update_message_text
@pytest.fixture
def update_callback_query():
"""Fixture for mocking an incoming update of type callback_query."""
return {
"update_id": 1,
"callback_query": {
"id": "4382bfdwdsb323b2d9",
"from": {
"id": 12345678,
"type": "private",
"is_bot": False,
"last_name": "Test Lastname",
"first_name": "Test Firstname",
"username": "Testusername",
},
"chat_instance": "aaa111",
"data": "Data from button callback",
"inline_message_id": "1234csdbsk4839",
},
}
@pytest.fixture
async def webhook_platform(hass, config_webhooks, mock_register_webhook):
"""Fixture for setting up the webhooks platform using appropriate config and mocks."""
await async_setup_component(
hass,
DOMAIN,
config_webhooks,
)
await hass.async_block_till_done()
@pytest.fixture
async def polling_platform(hass, config_polling):
"""Fixture for setting up the polling platform using appropriate config and mocks."""
await async_setup_component(
hass,
DOMAIN,
config_polling,
)
await hass.async_block_till_done()
@pytest.fixture(autouse=True)
def clear_dispatcher():
"""Clear the singleton that telegram.ext.dispatcher.Dispatcher sets on itself."""
yield
Dispatcher._set_singleton(None)
# This is how python-telegram-bot resets the dispatcher in their test suite
Dispatcher._Dispatcher__singleton_semaphore.release()

View File

@ -0,0 +1,112 @@
"""Tests for the telegram_bot component."""
from telegram import Update
from telegram.ext.dispatcher import Dispatcher
from homeassistant.components.telegram_bot import DOMAIN, SERVICE_SEND_MESSAGE
from homeassistant.components.telegram_bot.webhooks import TELEGRAM_WEBHOOK_URL
from tests.common import async_capture_events
async def test_webhook_platform_init(hass, webhook_platform):
"""Test initialization of the webhooks platform."""
assert hass.services.has_service(DOMAIN, SERVICE_SEND_MESSAGE) is True
async def test_polling_platform_init(hass, polling_platform):
"""Test initialization of the polling platform."""
assert hass.services.has_service(DOMAIN, SERVICE_SEND_MESSAGE) is True
async def test_webhook_endpoint_generates_telegram_text_event(
hass, webhook_platform, hass_client, update_message_text
):
"""POST to the configured webhook endpoint and assert fired `telegram_text` event."""
client = await hass_client()
events = async_capture_events(hass, "telegram_text")
response = await client.post(TELEGRAM_WEBHOOK_URL, json=update_message_text)
assert response.status == 200
assert (await response.read()).decode("utf-8") == ""
# Make sure event has fired
await hass.async_block_till_done()
assert len(events) == 1
assert events[0].data["text"] == update_message_text["message"]["text"]
async def test_webhook_endpoint_generates_telegram_command_event(
hass, webhook_platform, hass_client, update_message_command
):
"""POST to the configured webhook endpoint and assert fired `telegram_command` event."""
client = await hass_client()
events = async_capture_events(hass, "telegram_command")
response = await client.post(TELEGRAM_WEBHOOK_URL, json=update_message_command)
assert response.status == 200
assert (await response.read()).decode("utf-8") == ""
# Make sure event has fired
await hass.async_block_till_done()
assert len(events) == 1
assert events[0].data["command"] == update_message_command["message"]["text"]
async def test_webhook_endpoint_generates_telegram_callback_event(
hass, webhook_platform, hass_client, update_callback_query
):
"""POST to the configured webhook endpoint and assert fired `telegram_callback` event."""
client = await hass_client()
events = async_capture_events(hass, "telegram_callback")
response = await client.post(TELEGRAM_WEBHOOK_URL, json=update_callback_query)
assert response.status == 200
assert (await response.read()).decode("utf-8") == ""
# Make sure event has fired
await hass.async_block_till_done()
assert len(events) == 1
assert events[0].data["data"] == update_callback_query["callback_query"]["data"]
async def test_polling_platform_message_text_update(
hass, polling_platform, update_message_text
):
"""Provide the `PollBot`s `Dispatcher` with an `Update` and assert fired `telegram_text` event."""
events = async_capture_events(hass, "telegram_text")
def telegram_dispatcher_callback():
dispatcher = Dispatcher.get_instance()
update = Update.de_json(update_message_text, dispatcher.bot)
dispatcher.process_update(update)
# python-telegram-bots `Updater` uses threading, so we need to schedule its callback in a sync context.
await hass.async_add_executor_job(telegram_dispatcher_callback)
# Make sure event has fired
await hass.async_block_till_done()
assert len(events) == 1
assert events[0].data["text"] == update_message_text["message"]["text"]
async def test_webhook_endpoint_unauthorized_update_doesnt_generate_telegram_text_event(
hass, webhook_platform, hass_client, unauthorized_update_message_text
):
"""Update with unauthorized user/chat should not trigger event."""
client = await hass_client()
events = async_capture_events(hass, "telegram_text")
response = await client.post(
TELEGRAM_WEBHOOK_URL, json=unauthorized_update_message_text
)
assert response.status == 200
assert (await response.read()).decode("utf-8") == ""
# Make sure any events would have fired
await hass.async_block_till_done()
assert len(events) == 0