aboutsummaryrefslogtreecommitdiff
path: root/src/std/thread_pool.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/std/thread_pool.c')
-rw-r--r--src/std/thread_pool.c187
1 files changed, 187 insertions, 0 deletions
diff --git a/src/std/thread_pool.c b/src/std/thread_pool.c
new file mode 100644
index 0000000..8f6e180
--- /dev/null
+++ b/src/std/thread_pool.c
@@ -0,0 +1,187 @@
+#include "../../include/std/thread_pool.h"
+#include "../../include/std/alloc.h"
+#include "../../include/std/mem.h"
+#include "../../include/std/log.h"
+
+/* Sets @result to NULL if there are no available tasks */
+static CHECK_RESULT int thread_pool_take_task(amal_thread_pool *self, amal_thread_pool_task **result) {
+ *result = NULL;
+ cleanup_if_error(amal_mutex_lock(&self->task_select_mutex, "thread_pool_take_task"));
+ if(self->num_finished_queued_tasks < (int)buffer_get_size(&self->queued_tasks, amal_thread_pool_task) && !self->dead)
+ *result = buffer_get(&self->queued_tasks, self->num_finished_queued_tasks, sizeof(amal_thread_pool_task));
+ cleanup:
+ amal_mutex_tryunlock(&self->task_select_mutex);
+ return 0;
+}
+
+static void* thread_pool_thread_callback(void *userdata) {
+ amal_thread_pool_callback_data *thread_pool_data = userdata;
+ if(thread_pool_data->thread_pool->dead)
+ goto cleanup;
+
+ if(thread_pool_data->callback(thread_pool_data->userdata) != 0) {
+ thread_pool_mark_dead(thread_pool_data->thread_pool);
+ goto cleanup;
+ }
+
+ for(;;) {
+ amal_thread_pool_task *new_task;
+ cleanup_if_error(thread_pool_take_task(thread_pool_data->thread_pool, &new_task));
+ if(!new_task)
+ break;
+
+ if(new_task->callback(new_task->userdata) != 0) {
+ thread_pool_mark_dead(thread_pool_data->thread_pool);
+ goto cleanup;
+ }
+ }
+
+ cleanup:
+ thread_pool_data->thread_pool_thread->status = THREAD_POOL_THREAD_STATUS_IDLE;
+ am_free(thread_pool_data);
+ return NULL;
+}
+
+static void thread_pool_thread_init(amal_thread_pool_thread *self) {
+ am_memset(&self->thread, 0, sizeof(self->thread));
+ self->status = THREAD_POOL_THREAD_STATUS_NEW;
+}
+
+static void thread_pool_thread_deinit(amal_thread_pool_thread *self) {
+ ignore_result_int(amal_thread_deinit(&self->thread));
+}
+
+static CHECK_RESULT int thread_pool_thread_start(amal_thread_pool_thread *self, amal_thread_pool *thread_pool, amal_thread_job_callback callback, void *userdata) {
+ amal_thread_pool_callback_data *callback_data;
+ return_if_error(am_malloc(sizeof(amal_thread_pool_callback_data), (void**)&callback_data));
+ callback_data->thread_pool = thread_pool;
+ callback_data->thread_pool_thread = self;
+ callback_data->callback = callback;
+ callback_data->userdata = userdata;
+
+ return_if_error(amal_thread_deinit(&self->thread));
+ return_if_error(amal_thread_create(&self->thread, AMAL_THREAD_JOINABLE, "thread_pool_thread_start", thread_pool_thread_callback, callback_data));
+ self->status = THREAD_POOL_THREAD_STATUS_RUNNING;
+ return 0;
+}
+
+static CHECK_RESULT int thread_pool_thread_join(amal_thread_pool_thread *self) {
+ if(self->status == THREAD_POOL_THREAD_STATUS_NEW)
+ return 0;
+ return amal_thread_join(&self->thread, NULL);
+}
+
+int thread_pool_init(amal_thread_pool *self, int num_threads) {
+ int i;
+ self->num_threads = num_threads != 0 ? num_threads : amal_get_usable_thread_count();
+ if(self->num_threads == 0) {
+ amal_log_warning("Unable to get the number of threads available on the system, using 1 thread.");
+ self->num_threads = 1;
+ }
+
+ self->dead = bool_false;
+ self->num_finished_queued_tasks = 0;
+ self->threads = NULL;
+
+ ignore_result_int(buffer_init(&self->queued_tasks, NULL));
+ return_if_error(amal_mutex_init(&self->task_select_mutex));
+ cleanup_if_error(am_malloc(self->num_threads * sizeof(amal_thread_pool_thread), (void**)&self->threads));
+ for(i = 0; i < self->num_threads; ++i)
+ thread_pool_thread_init(&self->threads[i]);
+ return 0;
+
+ cleanup:
+ am_free(self->threads);
+ self->num_threads = 0;
+ return -1;
+}
+
+void thread_pool_deinit(amal_thread_pool *self) {
+ if(self->threads) {
+ int i;
+ for(i = 0; i < self->num_threads; ++i)
+ thread_pool_thread_deinit(&self->threads[i]);
+ am_free(self->threads);
+ }
+ amal_mutex_deinit(&self->task_select_mutex);
+ buffer_deinit(&self->queued_tasks);
+}
+
+int thread_pool_add_task(amal_thread_pool *self, amal_thread_job_callback callback, void *userdata) {
+ int i;
+ bool found_available_thread = bool_false;
+ int result = -1;
+
+ if(self->dead)
+ return result;
+
+ cleanup_if_error(amal_mutex_lock(&self->task_select_mutex, "thread_pool_add_task"));
+ for(i = 0; i < self->num_threads; ++i) {
+ amal_thread_pool_thread *thread = &self->threads[i];
+ if(thread->status != THREAD_POOL_THREAD_STATUS_RUNNING) {
+ cleanup_if_error(thread_pool_thread_start(thread, self, callback, userdata));
+ found_available_thread = bool_true;
+ break;
+ }
+ }
+
+ if(!found_available_thread) {
+ amal_thread_pool_task task;
+ task.callback = callback;
+ task.userdata = userdata;
+ cleanup_if_error(buffer_append(&self->queued_tasks, &task, sizeof(task)));
+ }
+
+ result = 0;
+ cleanup:
+ amal_mutex_tryunlock(&self->task_select_mutex);
+ return result;
+}
+
+bool thread_pool_join_all_tasks(amal_thread_pool *self) {
+ bool died;
+ for(;;) {
+ /*
+ Joining running threads. After checking one running thread another one might start up,
+ so this is mostly to wait for threads to finish and to sleep without doing work.
+ The check after that (thread_pool_check_threads_finished) check that all threads have finished correctly
+ */
+ int i;
+ bool finished = bool_true;
+ for(i = 0; i < self->num_threads; ++i) {
+ amal_thread_pool_thread_status thread_status;
+ if(amal_mutex_lock(&self->task_select_mutex, "thread_pool_join_all_tasks") != 0)
+ thread_pool_mark_dead(self);
+ thread_status = self->threads[i].status;
+ amal_mutex_tryunlock(&self->task_select_mutex);
+ /* TODO: What to do if join fails? */
+ switch(thread_status) {
+ case THREAD_POOL_THREAD_STATUS_NEW:
+ break;
+ case THREAD_POOL_THREAD_STATUS_RUNNING:
+ finished = bool_false;
+ /* fallthrough */
+ case THREAD_POOL_THREAD_STATUS_IDLE:
+ ignore_result_int(thread_pool_thread_join(&self->threads[i]));
+ break;
+ }
+ }
+
+ if(finished)
+ break;
+ }
+
+ died = self->dead;
+ self->dead = bool_false;
+ buffer_clear(&self->queued_tasks);
+ self->num_finished_queued_tasks = 0;
+ return !died;
+}
+
+void thread_pool_mark_dead(amal_thread_pool *self) {
+ self->dead = bool_true;
+}
+
+BufferView thread_pool_get_threads(amal_thread_pool *self) {
+ return create_buffer_view((const char*)self->threads, self->num_threads * sizeof(amal_thread_pool_thread));
+}