Fix and add test for ConditionVariable

Fix the circular linked list handling in ConditionVariable and add a
test to validate the linked list implementation.
pull/5508/head
Russ Butler 2017-11-15 16:15:27 -06:00
parent 41eb565d9c
commit 93cf15d57c
3 changed files with 120 additions and 37 deletions

View File

@ -92,6 +92,80 @@ void test_notify_all()
t2.join(); t2.join();
} }
class TestConditionVariable : public ConditionVariable {
public:
static void test_linked_list(void)
{
Waiter *list = NULL;
Waiter w1;
Waiter w2;
Waiter w3;
Waiter w4;
TEST_ASSERT_EQUAL(0, validate_and_get_size(&list));
// Add 4 nodes
_add_wait_list(&list, &w1);
TEST_ASSERT_EQUAL(1, validate_and_get_size(&list));
_add_wait_list(&list, &w2);
TEST_ASSERT_EQUAL(2, validate_and_get_size(&list));
_add_wait_list(&list, &w3);
TEST_ASSERT_EQUAL(3, validate_and_get_size(&list));
_add_wait_list(&list, &w4);
TEST_ASSERT_EQUAL(4, validate_and_get_size(&list));
// Remove a middle node
_remove_wait_list(&list, &w2);
TEST_ASSERT_EQUAL(3, validate_and_get_size(&list));
// Remove front node
_remove_wait_list(&list, &w1);
TEST_ASSERT_EQUAL(2, validate_and_get_size(&list));
// remove back node
_remove_wait_list(&list, &w4);
TEST_ASSERT_EQUAL(1, validate_and_get_size(&list));
// remove last node
_remove_wait_list(&list, &w3);
TEST_ASSERT_EQUAL(0, validate_and_get_size(&list));
TEST_ASSERT_EQUAL_PTR(NULL, list);
}
/**
* Validate the linked list an return the number of elements
*
* If this list is invalid then this function asserts and does not
* return.
*
* Every node in a valid linked list has the properties:
* 1. node->prev->next == node
* 2. node->next->prev == node
*/
static int validate_and_get_size(Waiter **list)
{
Waiter *first = *list;
if (NULL == first) {
// List is empty
return 0;
}
int size = 0;
Waiter *current = first;
do {
TEST_ASSERT_EQUAL_PTR(current, current->prev->next);
TEST_ASSERT_EQUAL_PTR(current, current->next->prev);
current = current->next;
size++;
} while (current != first);
return size;
}
};
utest::v1::status_t test_setup(const size_t number_of_cases) utest::v1::status_t test_setup(const size_t number_of_cases)
{ {
GREENTEA_SETUP(10, "default_auto"); GREENTEA_SETUP(10, "default_auto");
@ -101,6 +175,7 @@ utest::v1::status_t test_setup(const size_t number_of_cases)
Case cases[] = { Case cases[] = {
Case("Test notify one", test_notify_one), Case("Test notify one", test_notify_one),
Case("Test notify all", test_notify_all), Case("Test notify all", test_notify_all),
Case("Test linked list", TestConditionVariable::test_linked_list),
}; };
Specification specification(test_setup, cases); Specification specification(test_setup, cases);

View File

@ -20,7 +20,6 @@
* SOFTWARE. * SOFTWARE.
*/ */
#include "rtos/ConditionVariable.h" #include "rtos/ConditionVariable.h"
#include "rtos/Semaphore.h"
#include "rtos/Thread.h" #include "rtos/Thread.h"
#include "mbed_error.h" #include "mbed_error.h"
@ -28,17 +27,8 @@
namespace rtos { namespace rtos {
#define RESUME_SIGNAL (1 << 15)
struct Waiter { ConditionVariable::Waiter::Waiter(): sem(0), prev(NULL), next(NULL), in_list(false)
Waiter();
Semaphore sem;
Waiter *prev;
Waiter *next;
bool in_list;
};
Waiter::Waiter(): sem(0), prev(NULL), next(NULL), in_list(false)
{ {
// No initialization to do // No initialization to do
} }
@ -58,7 +48,7 @@ bool ConditionVariable::wait_for(uint32_t millisec)
Waiter current_thread; Waiter current_thread;
MBED_ASSERT(_mutex.get_owner() == Thread::gettid()); MBED_ASSERT(_mutex.get_owner() == Thread::gettid());
MBED_ASSERT(_mutex._count == 1); MBED_ASSERT(_mutex._count == 1);
_add_wait_list(&current_thread); _add_wait_list(&_wait_list, &current_thread);
_mutex.unlock(); _mutex.unlock();
@ -68,7 +58,7 @@ bool ConditionVariable::wait_for(uint32_t millisec)
_mutex.lock(); _mutex.lock();
if (current_thread.in_list) { if (current_thread.in_list) {
_remove_wait_list(&current_thread); _remove_wait_list(&_wait_list, &current_thread);
} }
return timeout; return timeout;
@ -79,7 +69,7 @@ void ConditionVariable::notify_one()
MBED_ASSERT(_mutex.get_owner() == Thread::gettid()); MBED_ASSERT(_mutex.get_owner() == Thread::gettid());
if (_wait_list != NULL) { if (_wait_list != NULL) {
_wait_list->sem.release(); _wait_list->sem.release();
_remove_wait_list(_wait_list); _remove_wait_list(&_wait_list, _wait_list);
} }
} }
@ -88,41 +78,50 @@ void ConditionVariable::notify_all()
MBED_ASSERT(_mutex.get_owner() == Thread::gettid()); MBED_ASSERT(_mutex.get_owner() == Thread::gettid());
while (_wait_list != NULL) { while (_wait_list != NULL) {
_wait_list->sem.release(); _wait_list->sem.release();
_remove_wait_list(_wait_list); _remove_wait_list(&_wait_list, _wait_list);
} }
} }
void ConditionVariable::_add_wait_list(Waiter * waiter) void ConditionVariable::_add_wait_list(Waiter **wait_list, Waiter *waiter)
{ {
if (NULL == _wait_list) { if (NULL == *wait_list) {
// Nothing in the list so add it directly. // Nothing in the list so add it directly.
// Update prev pointer to reference self // Update prev and next pointer to reference self
_wait_list = waiter; *wait_list = waiter;
waiter->next = waiter;
waiter->prev = waiter; waiter->prev = waiter;
} else { } else {
// Add after the last element // Add after the last element
Waiter *last = _wait_list->prev; Waiter *first = *wait_list;
last->next = waiter; Waiter *last = (*wait_list)->prev;
// Update new entry
waiter->next = first;
waiter->prev = last; waiter->prev = last;
_wait_list->prev = waiter;
// Insert into the list
first->prev = waiter;
last->next = waiter;
} }
waiter->in_list = true; waiter->in_list = true;
} }
void ConditionVariable::_remove_wait_list(Waiter * waiter) void ConditionVariable::_remove_wait_list(Waiter **wait_list, Waiter *waiter)
{ {
// Remove this element from the start of the list
Waiter * next = waiter->next;
if (waiter == _wait_list) {
_wait_list = next;
}
if (next != NULL) {
next = waiter->prev;
}
Waiter *prev = waiter->prev; Waiter *prev = waiter->prev;
if (prev != NULL) { Waiter *next = waiter->next;
prev = waiter->next;
// Remove from list
prev->next = waiter->next;
next->prev = waiter->prev;
*wait_list = waiter->next;
if (*wait_list == waiter) {
// This was the last element in the list
*wait_list = NULL;
} }
// Invalidate pointers
waiter->next = NULL; waiter->next = NULL;
waiter->prev = NULL; waiter->prev = NULL;
waiter->in_list = false; waiter->in_list = false;

View File

@ -25,6 +25,7 @@
#include <stdint.h> #include <stdint.h>
#include "cmsis_os.h" #include "cmsis_os.h"
#include "rtos/Mutex.h" #include "rtos/Mutex.h"
#include "rtos/Semaphore.h"
#include "platform/NonCopyable.h" #include "platform/NonCopyable.h"
@ -192,9 +193,17 @@ public:
~ConditionVariable(); ~ConditionVariable();
private: protected:
void _add_wait_list(Waiter * waiter); struct Waiter {
void _remove_wait_list(Waiter * waiter); Waiter();
Semaphore sem;
Waiter *prev;
Waiter *next;
bool in_list;
};
static void _add_wait_list(Waiter **wait_list, Waiter *waiter);
static void _remove_wait_list(Waiter **wait_list, Waiter *waiter);
Mutex &_mutex; Mutex &_mutex;
Waiter *_wait_list; Waiter *_wait_list;
}; };