"""Aiohttp test utils.""" import asyncio from contextlib import contextmanager from http import HTTPStatus import json as _json import re from unittest import mock from urllib.parse import parse_qs from aiohttp import ClientSession from aiohttp.client_exceptions import ClientError, ClientResponseError from aiohttp.streams import StreamReader from multidict import CIMultiDict from yarl import URL from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE RETYPE = type(re.compile("")) def mock_stream(data): """Mock a stream with data.""" protocol = mock.Mock(_reading_paused=False) stream = StreamReader(protocol, limit=2 ** 16) stream.feed_data(data) stream.feed_eof() return stream class AiohttpClientMocker: """Mock Aiohttp client requests.""" def __init__(self): """Initialize the request mocker.""" self._mocks = [] self._cookies = {} self.mock_calls = [] def request( self, method, url, *, auth=None, status=HTTPStatus.OK, text=None, data=None, content=None, json=None, params=None, headers={}, exc=None, cookies=None, side_effect=None, ): """Mock a request.""" if not isinstance(url, RETYPE): url = URL(url) if params: url = url.with_query(params) self._mocks.append( AiohttpClientMockResponse( method=method, url=url, status=status, response=content, json=json, text=text, cookies=cookies, exc=exc, headers=headers, side_effect=side_effect, ) ) def get(self, *args, **kwargs): """Register a mock get request.""" self.request("get", *args, **kwargs) def put(self, *args, **kwargs): """Register a mock put request.""" self.request("put", *args, **kwargs) def post(self, *args, **kwargs): """Register a mock post request.""" self.request("post", *args, **kwargs) def delete(self, *args, **kwargs): """Register a mock delete request.""" self.request("delete", *args, **kwargs) def options(self, *args, **kwargs): """Register a mock options request.""" self.request("options", *args, **kwargs) def patch(self, *args, **kwargs): """Register a mock patch request.""" self.request("patch", *args, **kwargs) @property def call_count(self): """Return the number of requests made.""" return len(self.mock_calls) def clear_requests(self): """Reset mock calls.""" self._mocks.clear() self._cookies.clear() self.mock_calls.clear() def create_session(self, loop): """Create a ClientSession that is bound to this mocker.""" session = ClientSession(loop=loop) # Setting directly on `session` will raise deprecation warning object.__setattr__(session, "_request", self.match_request) return session async def match_request( self, method, url, *, data=None, auth=None, params=None, headers=None, allow_redirects=None, timeout=None, json=None, cookies=None, **kwargs, ): """Match a request against pre-registered requests.""" data = data or json url = URL(url) if params: url = url.with_query(params) for response in self._mocks: if response.match_request(method, url, params): self.mock_calls.append((method, url, data, headers)) if response.side_effect: response = await response.side_effect(method, url, data) if response.exc: raise response.exc return response assert False, "No mock registered for {} {} {}".format( method.upper(), url, params ) class AiohttpClientMockResponse: """Mock Aiohttp client response.""" def __init__( self, method, url, status=HTTPStatus.OK, response=None, json=None, text=None, cookies=None, exc=None, headers=None, side_effect=None, ): """Initialize a fake response.""" if json is not None: text = _json.dumps(json) if text is not None: response = text.encode("utf-8") if response is None: response = b"" self.method = method self._url = url self.status = status self.response = response self.exc = exc self.side_effect = side_effect self._headers = CIMultiDict(headers or {}) self._cookies = {} if cookies: for name, data in cookies.items(): cookie = mock.MagicMock() cookie.value = data self._cookies[name] = cookie def match_request(self, method, url, params=None): """Test if response answers request.""" if method.lower() != self.method.lower(): return False # regular expression matching if isinstance(self._url, RETYPE): return self._url.search(str(url)) is not None if ( self._url.scheme != url.scheme or self._url.host != url.host or self._url.path != url.path ): return False # Ensure all query components in matcher are present in the request request_qs = parse_qs(url.query_string) matcher_qs = parse_qs(self._url.query_string) for key, vals in matcher_qs.items(): for val in vals: try: request_qs.get(key, []).remove(val) except ValueError: return False return True @property def headers(self): """Return content_type.""" return self._headers @property def cookies(self): """Return dict of cookies.""" return self._cookies @property def url(self): """Return yarl of URL.""" return self._url @property def content_type(self): """Return yarl of URL.""" return self._headers.get("content-type") @property def content(self): """Return content.""" return mock_stream(self.response) async def read(self): """Return mock response.""" return self.response async def text(self, encoding="utf-8", errors="strict"): """Return mock response as a string.""" return self.response.decode(encoding, errors=errors) async def json(self, encoding="utf-8", content_type=None): """Return mock response as a json.""" return _json.loads(self.response.decode(encoding)) def release(self): """Mock release.""" def raise_for_status(self): """Raise error if status is 400 or higher.""" if self.status >= 400: request_info = mock.Mock(real_url="http://example.com") raise ClientResponseError( request_info=request_info, history=None, code=self.status, headers=self.headers, ) def close(self): """Mock close.""" @contextmanager def mock_aiohttp_client(): """Context manager to mock aiohttp client.""" mocker = AiohttpClientMocker() def create_session(hass, *args, **kwargs): session = mocker.create_session(hass.loop) async def close_session(event): """Close session.""" await session.close() hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, close_session) return session with mock.patch( "homeassistant.helpers.aiohttp_client._async_create_clientsession", side_effect=create_session, ): yield mocker class MockLongPollSideEffect: """Imitate a long_poll request. It should be created and used as a side effect for a GET/PUT/etc. request. Once created, actual responses are queued with queue_response If queue is empty, will await until done. """ def __init__(self): """Initialize the queue.""" self.semaphore = asyncio.Semaphore(0) self.response_list = [] self.stopping = False async def __call__(self, method, url, data): """Fetch the next response from the queue or wait until the queue has items.""" if self.stopping: raise ClientError() await self.semaphore.acquire() kwargs = self.response_list.pop(0) return AiohttpClientMockResponse(method=method, url=url, **kwargs) def queue_response(self, **kwargs): """Add a response to the long_poll queue.""" self.response_list.append(kwargs) self.semaphore.release() def stop(self): """Stop the current request and future ones. This avoids an exception if there is someone waiting when exiting test. """ self.stopping = True self.queue_response(exc=ClientError())