diff --git a/ha_test/test_component_chromecast.py b/ha_test/test_component_chromecast.py index c653c84b74f..75ac9765c63 100644 --- a/ha_test/test_component_chromecast.py +++ b/ha_test/test_component_chromecast.py @@ -63,7 +63,7 @@ class TestChromecast(unittest.TestCase): calls = mock_service(self.hass, chromecast.DOMAIN, service_name) service_method(self.hass) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertEqual(1, len(calls)) call = calls[-1] @@ -71,7 +71,7 @@ class TestChromecast(unittest.TestCase): self.assertEqual(service_name, call.service) service_method(self.hass, self.test_entity) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertEqual(2, len(calls)) call = calls[-1] diff --git a/ha_test/test_component_core.py b/ha_test/test_component_core.py index 927bfb98a7c..2c53d578277 100644 --- a/ha_test/test_component_core.py +++ b/ha_test/test_component_core.py @@ -44,7 +44,7 @@ class TestComponentsCore(unittest.TestCase): comps.turn_on(self.hass, 'light.Ceiling') - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertEqual(1, len(runs)) @@ -56,6 +56,6 @@ class TestComponentsCore(unittest.TestCase): comps.turn_off(self.hass, 'light.Bowl') - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertEqual(1, len(runs)) diff --git a/ha_test/test_component_demo.py b/ha_test/test_component_demo.py index 72be2d12525..b687653a0bc 100644 --- a/ha_test/test_component_demo.py +++ b/ha_test/test_component_demo.py @@ -35,21 +35,21 @@ class TestDemo(unittest.TestCase): self.hass.services.call( domain, SERVICE_TURN_ON, {ATTR_ENTITY_ID: entity_id}) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertEqual(STATE_ON, self.hass.states.get(entity_id).state) self.hass.services.call( domain, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: entity_id}) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertEqual(STATE_OFF, self.hass.states.get(entity_id).state) # Act on all self.hass.services.call(domain, SERVICE_TURN_ON) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() for entity_id in self.hass.states.entity_ids(domain): self.assertEqual( @@ -57,7 +57,7 @@ class TestDemo(unittest.TestCase): self.hass.services.call(domain, SERVICE_TURN_OFF) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() for entity_id in self.hass.states.entity_ids(domain): self.assertEqual( diff --git a/ha_test/test_component_device_scanner.py b/ha_test/test_component_device_scanner.py index 09900951eeb..3c6385bc42f 100644 --- a/ha_test/test_component_device_scanner.py +++ b/ha_test/test_component_device_scanner.py @@ -110,7 +110,7 @@ class TestComponentsDeviceTracker(unittest.TestCase): device_tracker.DOMAIN, device_tracker.SERVICE_DEVICE_TRACKER_RELOAD) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() dev1 = device_tracker.ENTITY_ID_FORMAT.format('Device_1') dev2 = device_tracker.ENTITY_ID_FORMAT.format('Device_2') @@ -154,7 +154,7 @@ class TestComponentsDeviceTracker(unittest.TestCase): device_tracker.DOMAIN, device_tracker.SERVICE_DEVICE_TRACKER_RELOAD) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() # Test what happens if a device comes home and another leaves self.assertTrue(device_tracker.is_on(self.hass)) @@ -171,7 +171,7 @@ class TestComponentsDeviceTracker(unittest.TestCase): self.hass.bus.fire( ha.EVENT_TIME_CHANGED, {ha.ATTR_NOW: nowAlmostMinGone}) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertTrue(device_tracker.is_on(self.hass)) self.assertTrue(device_tracker.is_on(self.hass, dev1)) @@ -182,7 +182,7 @@ class TestComponentsDeviceTracker(unittest.TestCase): # Now test if gone for longer then error margin self.hass.bus.fire(ha.EVENT_TIME_CHANGED, {ha.ATTR_NOW: nowMinGone}) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertTrue(device_tracker.is_on(self.hass)) self.assertTrue(device_tracker.is_on(self.hass, dev1)) diff --git a/ha_test/test_component_group.py b/ha_test/test_component_group.py index 204494b02e8..4e6307aa2b5 100644 --- a/ha_test/test_component_group.py +++ b/ha_test/test_component_group.py @@ -86,7 +86,7 @@ class TestComponentsGroup(unittest.TestCase): # Turn the Bowl off and see if group turns off self.hass.states.set('light.Bowl', STATE_OFF) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() group_state = self.hass.states.get(self.group_name) self.assertEqual(STATE_OFF, group_state.state) @@ -94,7 +94,7 @@ class TestComponentsGroup(unittest.TestCase): # Turn the Ceiling on and see if group turns on self.hass.states.set('light.Ceiling', STATE_ON) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() group_state = self.hass.states.get(self.group_name) self.assertEqual(STATE_ON, group_state.state) @@ -103,7 +103,7 @@ class TestComponentsGroup(unittest.TestCase): """ Test is_on method. """ self.assertTrue(group.is_on(self.hass, self.group_name)) self.hass.states.set('light.Bowl', STATE_OFF) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertFalse(group.is_on(self.hass, self.group_name)) # Try on non existing state diff --git a/ha_test/test_component_http.py b/ha_test/test_component_http.py index 15c2966292a..98b976cf099 100644 --- a/ha_test/test_component_http.py +++ b/ha_test/test_component_http.py @@ -199,7 +199,7 @@ class TestHTTP(unittest.TestCase): _url(remote.URL_API_EVENTS_EVENT.format("test.event_no_data")), headers=HA_HEADERS) - hass._pool.block_till_done() + hass.pool.block_till_done() self.assertEqual(1, len(test_value)) @@ -221,7 +221,7 @@ class TestHTTP(unittest.TestCase): data=json.dumps({"test": 1}), headers=HA_HEADERS) - hass._pool.block_till_done() + hass.pool.block_till_done() self.assertEqual(1, len(test_value)) @@ -241,7 +241,7 @@ class TestHTTP(unittest.TestCase): data=json.dumps('not an object'), headers=HA_HEADERS) - hass._pool.block_till_done() + hass.pool.block_till_done() self.assertEqual(422, req.status_code) self.assertEqual(0, len(test_value)) @@ -252,7 +252,7 @@ class TestHTTP(unittest.TestCase): data=json.dumps([1, 2, 3]), headers=HA_HEADERS) - hass._pool.block_till_done() + hass.pool.block_till_done() self.assertEqual(422, req.status_code) self.assertEqual(0, len(test_value)) @@ -297,7 +297,7 @@ class TestHTTP(unittest.TestCase): "test_domain", "test_service")), headers=HA_HEADERS) - hass._pool.block_till_done() + hass.pool.block_till_done() self.assertEqual(1, len(test_value)) @@ -319,7 +319,7 @@ class TestHTTP(unittest.TestCase): data=json.dumps({"test": 1}), headers=HA_HEADERS) - hass._pool.block_till_done() + hass.pool.block_till_done() self.assertEqual(1, len(test_value)) diff --git a/ha_test/test_component_light.py b/ha_test/test_component_light.py index 4d781c70132..e9cb219d07b 100644 --- a/ha_test/test_component_light.py +++ b/ha_test/test_component_light.py @@ -63,7 +63,7 @@ class TestLight(unittest.TestCase): xy_color='xy_color_val', profile='profile_val') - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertEqual(1, len(turn_on_calls)) call = turn_on_calls[-1] @@ -86,7 +86,7 @@ class TestLight(unittest.TestCase): light.turn_off( self.hass, entity_id='entity_id_val', transition='transition_val') - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertEqual(1, len(turn_off_calls)) call = turn_off_calls[-1] @@ -115,7 +115,7 @@ class TestLight(unittest.TestCase): light.turn_off(self.hass, entity_id=dev1.entity_id) light.turn_on(self.hass, entity_id=dev2.entity_id) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertFalse(light.is_on(self.hass, dev1.entity_id)) self.assertTrue(light.is_on(self.hass, dev2.entity_id)) @@ -123,7 +123,7 @@ class TestLight(unittest.TestCase): # turn on all lights light.turn_on(self.hass) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertTrue(light.is_on(self.hass, dev1.entity_id)) self.assertTrue(light.is_on(self.hass, dev2.entity_id)) @@ -132,7 +132,7 @@ class TestLight(unittest.TestCase): # turn off all lights light.turn_off(self.hass) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertFalse(light.is_on(self.hass, dev1.entity_id)) self.assertFalse(light.is_on(self.hass, dev2.entity_id)) @@ -145,7 +145,7 @@ class TestLight(unittest.TestCase): self.hass, dev2.entity_id, rgb_color=[255, 255, 255]) light.turn_on(self.hass, dev3.entity_id, xy_color=[.4, .6]) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() method, data = dev1.last_call('turn_on') self.assertEqual( @@ -171,7 +171,7 @@ class TestLight(unittest.TestCase): self.hass, dev2.entity_id, profile=prof_name, brightness=100, xy_color=[.4, .6]) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() method, data = dev1.last_call('turn_on') self.assertEqual( @@ -190,7 +190,7 @@ class TestLight(unittest.TestCase): light.turn_on(self.hass, dev2.entity_id, xy_color=["bla-di-bla", 5]) light.turn_on(self.hass, dev3.entity_id, rgb_color=[255, None, 2]) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() method, data = dev1.last_call('turn_on') self.assertEqual({}, data) @@ -206,7 +206,7 @@ class TestLight(unittest.TestCase): self.hass, dev1.entity_id, profile=prof_name, brightness='bright', rgb_color='yellowish') - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() method, data = dev1.last_call('turn_on') self.assertEqual( @@ -267,7 +267,7 @@ class TestLight(unittest.TestCase): light.turn_on(self.hass, dev1.entity_id, profile='test') - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() method, data = dev1.last_call('turn_on') diff --git a/ha_test/test_component_sun.py b/ha_test/test_component_sun.py index daf8970f406..a587f60bff5 100644 --- a/ha_test/test_component_sun.py +++ b/ha_test/test_component_sun.py @@ -93,7 +93,7 @@ class TestSun(unittest.TestCase): self.hass.bus.fire(ha.EVENT_TIME_CHANGED, {ha.ATTR_NOW: test_time + dt.timedelta(seconds=5)}) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertEqual(test_state, self.hass.states.get(sun.ENTITY_ID).state) diff --git a/ha_test/test_component_switch.py b/ha_test/test_component_switch.py index daab9cde4d1..687df62ed5f 100644 --- a/ha_test/test_component_switch.py +++ b/ha_test/test_component_switch.py @@ -50,7 +50,7 @@ class TestSwitch(unittest.TestCase): switch.turn_off(self.hass, self.switch_1.entity_id) switch.turn_on(self.hass, self.switch_2.entity_id) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertTrue(switch.is_on(self.hass)) self.assertFalse(switch.is_on(self.hass, self.switch_1.entity_id)) @@ -59,7 +59,7 @@ class TestSwitch(unittest.TestCase): # Turn all off switch.turn_off(self.hass) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertFalse(switch.is_on(self.hass)) self.assertEqual( @@ -72,7 +72,7 @@ class TestSwitch(unittest.TestCase): # Turn all on switch.turn_on(self.hass) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertTrue(switch.is_on(self.hass)) self.assertEqual( diff --git a/ha_test/test_core.py b/ha_test/test_core.py index 2da8151261b..73513eee502 100644 --- a/ha_test/test_core.py +++ b/ha_test/test_core.py @@ -53,7 +53,7 @@ class TestHomeAssistant(unittest.TestCase): self.assertTrue(blocking_thread.is_alive()) self.hass.services.call(ha.DOMAIN, ha.SERVICE_HOMEASSISTANT_STOP) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() # hass.block_till_stopped checks every second if it should quit # we have to wait worst case 1 second @@ -76,23 +76,23 @@ class TestHomeAssistant(unittest.TestCase): lambda x: runs.append(1), birthday_paulus) self._send_time_changed(before_birthday) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertEqual(0, len(runs)) self._send_time_changed(birthday_paulus) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertEqual(1, len(runs)) # A point in time tracker will only fire once, this should do nothing self._send_time_changed(birthday_paulus) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertEqual(1, len(runs)) self.hass.track_point_in_time( lambda x: runs.append(1), birthday_paulus) self._send_time_changed(after_birthday) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertEqual(2, len(runs)) def test_track_time_change(self): @@ -105,17 +105,17 @@ class TestHomeAssistant(unittest.TestCase): lambda x: specific_runs.append(1), second=[0, 30]) self._send_time_changed(datetime(2014, 5, 24, 12, 0, 0)) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertEqual(1, len(specific_runs)) self.assertEqual(1, len(wildcard_runs)) self._send_time_changed(datetime(2014, 5, 24, 12, 0, 15)) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertEqual(1, len(specific_runs)) self.assertEqual(2, len(wildcard_runs)) self._send_time_changed(datetime(2014, 5, 24, 12, 0, 30)) - self.hass._pool.block_till_done() + self.hass.pool.block_till_done() self.assertEqual(2, len(specific_runs)) self.assertEqual(3, len(wildcard_runs)) diff --git a/ha_test/test_remote.py b/ha_test/test_remote.py index cc317d63960..f6de538e54a 100644 --- a/ha_test/test_remote.py +++ b/ha_test/test_remote.py @@ -100,7 +100,7 @@ class TestRemoteMethods(unittest.TestCase): remote.fire_event(master_api, "test.event_no_data") - hass._pool.block_till_done() + hass.pool.block_till_done() self.assertEqual(1, len(test_value)) @@ -165,7 +165,7 @@ class TestRemoteMethods(unittest.TestCase): remote.call_service(master_api, "test_domain", "test_service") - hass._pool.block_till_done() + hass.pool.block_till_done() self.assertEqual(1, len(test_value)) @@ -204,7 +204,7 @@ class TestRemoteClasses(unittest.TestCase): # Wait till slave tells master slave._pool.block_till_done() # Wait till master gives updated state - hass._pool.block_till_done() + hass.pool.block_till_done() self.assertEqual("remote.statemachine test", slave.states.get("remote.test").state) @@ -224,7 +224,7 @@ class TestRemoteClasses(unittest.TestCase): # Wait till slave tells master slave._pool.block_till_done() # Wait till master gives updated event - hass._pool.block_till_done() + hass.pool.block_till_done() self.assertEqual(1, len(test_value)) diff --git a/homeassistant/__init__.py b/homeassistant/__init__.py index 887ec74c68e..336fa6b433d 100644 --- a/homeassistant/__init__.py +++ b/homeassistant/__init__.py @@ -30,22 +30,10 @@ TIMER_INTERVAL = 10 # seconds # How long we wait for the result of a service call SERVICE_CALL_LIMIT = 10 # seconds -# Define number of worker threads -# -# There are two categories of Home Assistant jobs: -# - jobs that poll external components that are mostly waiting for IO -# - jobs that respond to events that happen inside HA (state_changed, etc) -# -# Based on different setups I see 3 times as many events responding to events -# then that there are ones that poll components. We therefore want to set the -# number of threads to 1.25 of the CPU count, we will round it up so the -# minimum number of threads is 2. -# -# We want to have atleast 2 threads because a call to the homeassistant.turn_on -# will wait till the service is executed which is in a different thread. -# -# If os.cpu_count() cannot determine the cpu_count, we will assume there is 1. -POOL_NUM_THREAD = int((os.cpu_count() or 1) * 1.25) + 1 +# Define number of MINIMUM worker threads. +# During bootstrap of HA (see bootstrap.from_config_dict()) worker threads +# will be added for each component that polls devices. +MIN_WORKER_THREAD = 2 # Pattern for validating entity IDs (format: .) ENTITY_ID_PATTERN = re.compile(r"^(?P\w+)\.(?P\w+)$") @@ -57,8 +45,7 @@ class HomeAssistant(object): """ Core class to route all communication to right components. """ def __init__(self): - self._pool = pool = create_worker_pool() - + self.pool = pool = create_worker_pool() self.bus = EventBus(pool) self.services = ServiceRegistry(self.bus, pool) self.states = StateMachine(self.bus) @@ -71,6 +58,9 @@ class HomeAssistant(object): def start(self): """ Start home assistant. """ + _LOGGER.info( + "Starting Home Assistant (%d threads)", self.pool.worker_count) + Timer(self) self.bus.fire(EVENT_HOMEASSISTANT_START) @@ -165,9 +155,9 @@ class HomeAssistant(object): self.bus.fire(EVENT_HOMEASSISTANT_STOP) # Wait till all responses to homeassistant_stop are done - self._pool.block_till_done() + self.pool.block_till_done() - self._pool.stop() + self.pool.stop() def get_entity_ids(self, domain_filter=None): """ @@ -266,7 +256,7 @@ class JobPriority(util.OrderedEnum): return JobPriority.EVENT_DEFAULT -def create_worker_pool(thread_count=POOL_NUM_THREAD): +def create_worker_pool(): """ Creates a worker pool to be used. """ def job_handler(job): @@ -279,18 +269,18 @@ def create_worker_pool(thread_count=POOL_NUM_THREAD): # We do not want to crash our ThreadPool _LOGGER.exception("BusHandler:Exception doing job") - def busy_callback(current_jobs, pending_jobs_count): + def busy_callback(worker_count, current_jobs, pending_jobs_count): """ Callback to be called when the pool queue gets too big. """ _LOGGER.warning( "WorkerPool:All %d threads are busy and %d jobs pending", - thread_count, pending_jobs_count) + worker_count, pending_jobs_count) for start, job in current_jobs: _LOGGER.warning("WorkerPool:Current job from %s: %s", util.datetime_to_str(start), job) - return util.ThreadPool(thread_count, job_handler, busy_callback) + return util.ThreadPool(job_handler, MIN_WORKER_THREAD, busy_callback) class EventOrigin(enum.Enum): diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py index 112a8018d63..1b2a8ee7312 100644 --- a/homeassistant/bootstrap.py +++ b/homeassistant/bootstrap.py @@ -41,24 +41,39 @@ def from_config_dict(config, hass=None): components = (key for key in config.keys() if ' ' not in key and key != homeassistant.DOMAIN) - # Setup the components - if core_components.setup(hass, config): - logger.info("Home Assistant core initialized") - - for domain in loader.load_order_components(components): - try: - if loader.get_component(domain).setup(hass, config): - logger.info("component %s initialized", domain) - else: - logger.error("component %s failed to initialize", domain) - - except Exception: # pylint: disable=broad-except - logger.exception("Error during setup of component %s", domain) - - else: + if not core_components.setup(hass, config): logger.error(("Home Assistant core failed to initialize. " "Further initialization aborted.")) + return hass + + logger.info("Home Assistant core initialized") + + # Setup the components + + # We assume that all components that load before the group component loads + # are components that poll devices. As their tasks are IO based, we will + # add an extra worker for each of them. + add_worker = True + + for domain in loader.load_order_components(components): + component = loader.get_component(domain) + + try: + if component.setup(hass, config): + logger.info("component %s initialized", domain) + + add_worker = add_worker and domain != "group" + + if add_worker: + hass.pool.add_worker() + + else: + logger.error("component %s failed to initialize", domain) + + except Exception: # pylint: disable=broad-except + logger.exception("Error during setup of component %s", domain) + return hass diff --git a/homeassistant/util.py b/homeassistant/util.py index fbf686c39ab..1932d9ada63 100644 --- a/homeassistant/util.py +++ b/homeassistant/util.py @@ -308,67 +308,79 @@ class Throttle(object): return wrapper -# Reason why I decided to roll my own ThreadPool instead of using -# multiprocessing.dummy.pool or even better, use multiprocessing.pool and -# not be hurt by the GIL in the cpython interpreter: -# 1. The built in threadpool does not allow me to create custom workers and so -# I would have to wrap every listener that I passed into it with code to log -# the exceptions. Saving a reference to the logger in the worker seemed -# like a more sane thing to do. -# 2. Most event listeners are simple checks if attributes match. If the method -# that they will call takes a long time to complete it might be better to -# put that request in a seperate thread. This is for every component to -# decide on its own instead of enforcing it for everyone. class ThreadPool(object): - """ A simple queue-based thread pool. - - Will initiate it's workers using worker(queue).start() """ + """ A priority queue-based thread pool. """ # pylint: disable=too-many-instance-attributes - def __init__(self, worker_count, job_handler, busy_callback=None): + def __init__(self, job_handler, worker_count=0, busy_callback=None): """ - worker_count: number of threads to run that handle jobs job_handler: method to be called from worker thread to handle job + worker_count: number of threads to run that handle jobs busy_callback: method to be called when queue gets too big. - Parameters: list_of_current_jobs, number_pending_jobs + Parameters: worker_count, list of current_jobs, + pending_jobs_count """ - self.work_queue = work_queue = queue.PriorityQueue() - self.current_jobs = current_jobs = [] - self.worker_count = worker_count - self.busy_callback = busy_callback - self.busy_warning_limit = worker_count**2 + self._job_handler = job_handler + self._busy_callback = busy_callback + + self.worker_count = 0 + self.busy_warning_limit = 0 + self._work_queue = queue.PriorityQueue() + self.current_jobs = [] self._lock = threading.RLock() self._quit_task = object() - for _ in range(worker_count): - worker = threading.Thread(target=_threadpool_worker, - args=(work_queue, current_jobs, - job_handler, self._quit_task)) - worker.daemon = True - worker.start() - self.running = True - def add_job(self, priority, job): - """ Add a job to be sent to the workers. """ + for _ in range(worker_count): + self.add_worker() + + def add_worker(self): + """ Adds a worker to the thread pool. Resets warning limit. """ with self._lock: if not self.running: raise RuntimeError("ThreadPool not running") - self.work_queue.put(PriorityQueueItem(priority, job)) + worker = threading.Thread(target=self._worker) + worker.daemon = True + worker.start() + + self.worker_count += 1 + self.busy_warning_limit = self.worker_count * 3 + + def remove_worker(self): + """ Removes a worker from the thread pool. Resets warning limit. """ + with self._lock: + if not self.running: + raise RuntimeError("ThreadPool not running") + + self._work_queue.put(PriorityQueueItem(0, self._quit_task)) + + self.worker_count -= 1 + self.busy_warning_limit = self.worker_count * 3 + + def add_job(self, priority, job): + """ Add a job to the queue. """ + with self._lock: + if not self.running: + raise RuntimeError("ThreadPool not running") + + self._work_queue.put(PriorityQueueItem(priority, job)) # check if our queue is getting too big - if self.work_queue.qsize() > self.busy_warning_limit \ - and self.busy_callback is not None: + if self._work_queue.qsize() > self.busy_warning_limit \ + and self._busy_callback is not None: # Increase limit we will issue next warning self.busy_warning_limit *= 2 - self.busy_callback(self.current_jobs, self.work_queue.qsize()) + self._busy_callback( + self.worker_count, self.current_jobs, + self._work_queue.qsize()) def block_till_done(self): """ Blocks till all work is done. """ - self.work_queue.join() + self._work_queue.join() def stop(self): """ Stops all the threads. """ @@ -376,19 +388,41 @@ class ThreadPool(object): if not self.running: return - # Clear the queue - while self.work_queue.qsize() > 0: - self.work_queue.get() - self.work_queue.task_done() + # Ensure all current jobs finish + self.block_till_done() # Tell the workers to quit for _ in range(self.worker_count): - self.add_job(1000, self._quit_task) + self.remove_worker() self.running = False + # Wait till all workers have quit self.block_till_done() + def _worker(self): + """ Handles jobs for the thread pool. """ + while True: + # Get new item from work_queue + job = self._work_queue.get().item + + if job == self._quit_task: + self._work_queue.task_done() + return + + # Add to current running jobs + job_log = (datetime.now(), job) + self.current_jobs.append(job_log) + + # Do the job + self._job_handler(job) + + # Remove from current running job + self.current_jobs.remove(job_log) + + # Tell work_queue the task is done + self._work_queue.task_done() + class PriorityQueueItem(object): """ Holds a priority and a value. Used within PriorityQueue. """ @@ -400,27 +434,3 @@ class PriorityQueueItem(object): def __lt__(self, other): return self.priority < other.priority - - -def _threadpool_worker(work_queue, current_jobs, job_handler, quit_task): - """ Provides the base functionality of a worker for the thread pool. """ - while True: - # Get new item from work_queue - job = work_queue.get().item - - if job == quit_task: - work_queue.task_done() - return - - # Add to current running jobs - job_log = (datetime.now(), job) - current_jobs.append(job_log) - - # Do the job - job_handler(job) - - # Remove from current running job - current_jobs.remove(job_log) - - # Tell work_queue a task is done - work_queue.task_done()