diff --git a/TESTS/mbedmicro-rtos-mbed/condition_variable/main.cpp b/TESTS/mbedmicro-rtos-mbed/condition_variable/main.cpp index ad235854de..d33c9c7bac 100644 --- a/TESTS/mbedmicro-rtos-mbed/condition_variable/main.cpp +++ b/TESTS/mbedmicro-rtos-mbed/condition_variable/main.cpp @@ -92,6 +92,80 @@ void test_notify_all() t2.join(); } + +class TestConditionVariable : public ConditionVariable { + +public: + static void test_linked_list(void) + { + Waiter *list = NULL; + Waiter w1; + Waiter w2; + Waiter w3; + Waiter w4; + + TEST_ASSERT_EQUAL(0, validate_and_get_size(&list)); + + // Add 4 nodes + _add_wait_list(&list, &w1); + TEST_ASSERT_EQUAL(1, validate_and_get_size(&list)); + _add_wait_list(&list, &w2); + TEST_ASSERT_EQUAL(2, validate_and_get_size(&list)); + _add_wait_list(&list, &w3); + TEST_ASSERT_EQUAL(3, validate_and_get_size(&list)); + _add_wait_list(&list, &w4); + TEST_ASSERT_EQUAL(4, validate_and_get_size(&list)); + + // Remove a middle node + _remove_wait_list(&list, &w2); + TEST_ASSERT_EQUAL(3, validate_and_get_size(&list)); + + // Remove front node + _remove_wait_list(&list, &w1); + TEST_ASSERT_EQUAL(2, validate_and_get_size(&list)); + + // remove back node + _remove_wait_list(&list, &w4); + TEST_ASSERT_EQUAL(1, validate_and_get_size(&list)); + + // remove last node + _remove_wait_list(&list, &w3); + TEST_ASSERT_EQUAL(0, validate_and_get_size(&list)); + + TEST_ASSERT_EQUAL_PTR(NULL, list); + } + + /** + * Validate the linked list an return the number of elements + * + * If this list is invalid then this function asserts and does not + * return. + * + * Every node in a valid linked list has the properties: + * 1. node->prev->next == node + * 2. node->next->prev == node + */ + static int validate_and_get_size(Waiter **list) + { + Waiter *first = *list; + if (NULL == first) { + // List is empty + return 0; + } + + int size = 0; + Waiter *current = first; + do { + TEST_ASSERT_EQUAL_PTR(current, current->prev->next); + TEST_ASSERT_EQUAL_PTR(current, current->next->prev); + current = current->next; + size++; + } while (current != first); + return size; + } + +}; + utest::v1::status_t test_setup(const size_t number_of_cases) { GREENTEA_SETUP(10, "default_auto"); @@ -101,6 +175,7 @@ utest::v1::status_t test_setup(const size_t number_of_cases) Case cases[] = { Case("Test notify one", test_notify_one), Case("Test notify all", test_notify_all), + Case("Test linked list", TestConditionVariable::test_linked_list), }; Specification specification(test_setup, cases); diff --git a/rtos/ConditionVariable.cpp b/rtos/ConditionVariable.cpp index 418ec71bf7..69be38b936 100644 --- a/rtos/ConditionVariable.cpp +++ b/rtos/ConditionVariable.cpp @@ -20,7 +20,6 @@ * SOFTWARE. */ #include "rtos/ConditionVariable.h" -#include "rtos/Semaphore.h" #include "rtos/Thread.h" #include "mbed_error.h" @@ -28,17 +27,8 @@ namespace rtos { -#define RESUME_SIGNAL (1 << 15) -struct Waiter { - Waiter(); - Semaphore sem; - Waiter *prev; - Waiter *next; - bool in_list; -}; - -Waiter::Waiter(): sem(0), prev(NULL), next(NULL), in_list(false) +ConditionVariable::Waiter::Waiter(): sem(0), prev(NULL), next(NULL), in_list(false) { // No initialization to do } @@ -58,7 +48,7 @@ bool ConditionVariable::wait_for(uint32_t millisec) Waiter current_thread; MBED_ASSERT(_mutex.get_owner() == Thread::gettid()); MBED_ASSERT(_mutex._count == 1); - _add_wait_list(¤t_thread); + _add_wait_list(&_wait_list, ¤t_thread); _mutex.unlock(); @@ -68,7 +58,7 @@ bool ConditionVariable::wait_for(uint32_t millisec) _mutex.lock(); if (current_thread.in_list) { - _remove_wait_list(¤t_thread); + _remove_wait_list(&_wait_list, ¤t_thread); } return timeout; @@ -79,7 +69,7 @@ void ConditionVariable::notify_one() MBED_ASSERT(_mutex.get_owner() == Thread::gettid()); if (_wait_list != NULL) { _wait_list->sem.release(); - _remove_wait_list(_wait_list); + _remove_wait_list(&_wait_list, _wait_list); } } @@ -88,41 +78,50 @@ void ConditionVariable::notify_all() MBED_ASSERT(_mutex.get_owner() == Thread::gettid()); while (_wait_list != NULL) { _wait_list->sem.release(); - _remove_wait_list(_wait_list); + _remove_wait_list(&_wait_list, _wait_list); } } -void ConditionVariable::_add_wait_list(Waiter * waiter) +void ConditionVariable::_add_wait_list(Waiter **wait_list, Waiter *waiter) { - if (NULL == _wait_list) { + if (NULL == *wait_list) { // Nothing in the list so add it directly. - // Update prev pointer to reference self - _wait_list = waiter; + // Update prev and next pointer to reference self + *wait_list = waiter; + waiter->next = waiter; waiter->prev = waiter; } else { // Add after the last element - Waiter *last = _wait_list->prev; - last->next = waiter; + Waiter *first = *wait_list; + Waiter *last = (*wait_list)->prev; + + // Update new entry + waiter->next = first; waiter->prev = last; - _wait_list->prev = waiter; + + // Insert into the list + first->prev = waiter; + last->next = waiter; } waiter->in_list = true; } -void ConditionVariable::_remove_wait_list(Waiter * waiter) +void ConditionVariable::_remove_wait_list(Waiter **wait_list, Waiter *waiter) { - // Remove this element from the start of the list - Waiter * next = waiter->next; - if (waiter == _wait_list) { - _wait_list = next; - } - if (next != NULL) { - next = waiter->prev; - } - Waiter * prev = waiter->prev; - if (prev != NULL) { - prev = waiter->next; + Waiter *prev = waiter->prev; + Waiter *next = waiter->next; + + // Remove from list + prev->next = waiter->next; + next->prev = waiter->prev; + *wait_list = waiter->next; + + if (*wait_list == waiter) { + // This was the last element in the list + *wait_list = NULL; } + + // Invalidate pointers waiter->next = NULL; waiter->prev = NULL; waiter->in_list = false; diff --git a/rtos/ConditionVariable.h b/rtos/ConditionVariable.h index b2f51f8a58..dd091ec3d4 100644 --- a/rtos/ConditionVariable.h +++ b/rtos/ConditionVariable.h @@ -25,6 +25,7 @@ #include #include "cmsis_os.h" #include "rtos/Mutex.h" +#include "rtos/Semaphore.h" #include "platform/NonCopyable.h" @@ -192,9 +193,17 @@ public: ~ConditionVariable(); -private: - void _add_wait_list(Waiter * waiter); - void _remove_wait_list(Waiter * waiter); +protected: + struct Waiter { + Waiter(); + Semaphore sem; + Waiter *prev; + Waiter *next; + bool in_list; + }; + + static void _add_wait_list(Waiter **wait_list, Waiter *waiter); + static void _remove_wait_list(Waiter **wait_list, Waiter *waiter); Mutex &_mutex; Waiter *_wait_list; };