#include "pthrpoll.h"

/*
 * Prepare an empty task queue: no head/tail node, size 0, and its mutex and
 * condition variable initialised with default attributes.
 *
 * Returns 0 on success and -1 on failure. If the condition variable cannot be
 * initialised, the already-initialised mutex is destroyed again, so a failed
 * call leaves nothing for the caller to clean up.
 */
static int task_queue_init(TaskQueue *queue) {
    queue->head = queue->tail = NULL;
    queue->size = 0;

    if (pthread_mutex_init(&queue->mutex, NULL) != 0) {
        return -1;
    }
    if (pthread_cond_init(&queue->cond, NULL) == 0) {
        return 0;
    }

    // Roll back the mutex so the result is all-or-nothing.
    pthread_mutex_destroy(&queue->mutex);
    return -1;
}

/*
 * Append a task node holding (function, argument) to the tail of the queue
 * and wake one thread blocked on queue->cond.
 *
 * The node is heap-allocated here and belongs to the queue until it is
 * dequeued. `argument` is stored as-is (not copied). `function` is not
 * validated: a NULL function is stored too (the reaper in this file queues
 * such tasks).
 *
 * Returns 0 on success, or -1 if the node cannot be allocated, in which case
 * the queue is left unchanged.
 */
static int task_queue_add(TaskQueue *queue, void (*function)(void *), void *argument) {
    Task *node = malloc(sizeof *node);
    if (!node) {
        return -1;
    }
    node->function = function;
    node->argument = argument;
    node->next = NULL;

    pthread_mutex_lock(&queue->mutex);
    if (queue->tail != NULL) {
        queue->tail->next = node;   // link behind the current last node
    } else {
        queue->head = node;         // first node in an empty queue
    }
    queue->tail = node;
    queue->size++;
    pthread_cond_signal(&queue->cond);
    pthread_mutex_unlock(&queue->mutex);

    return 0;
}

/*
 * Dequeue the task at the head of `queue`, blocking while the queue is empty
 * and the pool is not shutting down.
 *
 * Returns the removed node, which the caller now owns and must free. Returns
 * NULL only when pool->shutdown is set AND the queue is empty, so tasks that
 * are already queued keep being handed out after shutdown begins.
 *
 * pool->shutdown is read while holding only queue->mutex.
 * NOTE(review): whoever sets shutdown should do so (or at least broadcast
 * queue->cond) while holding queue->mutex; otherwise a waiter can miss the
 * wakeup.
 */
static Task *task_queue_remove(TaskQueue *queue, ThreadPool *pool) {
    pthread_mutex_lock(&queue->mutex);

    // Re-checked after every wakeup (spurious wakeups are allowed).
    while (queue->size == 0) {
        if (pool->shutdown) {
            pthread_mutex_unlock(&queue->mutex);
            return NULL;
        }
        pthread_cond_wait(&queue->cond, &queue->mutex);
    }

    Task *front = queue->head;
    queue->head = front->next;
    if (!queue->head) {
        queue->tail = NULL;   // the queue just became empty
    }
    queue->size--;

    pthread_mutex_unlock(&queue->mutex);
    return front;
}

// 工作线程函数
static void *worker_thread(void *arg) {
    ThreadPool *pool = (ThreadPool *)arg;
    
    pthread_mutex_lock(&pool->mutex);
    pool->active_threads++;
    pthread_mutex_unlock(&pool->mutex);
    
    while (1) {
        Task *task = task_queue_remove(&pool->task_queue, pool);
        if (task == NULL) {
            break;  // 收到关闭信号且队列为空
        }
        
        task->function(task->argument);
        free(task->argument);  // 释放任务参数
        free(task);            // 释放任务本身
        
        // 通知回收线程可能有空闲线程
        pthread_cond_signal(&pool->reaper_cond);
    }
    
    pthread_mutex_lock(&pool->mutex);
    pool->active_threads--;
    pthread_mutex_unlock(&pool->mutex);
    
    return NULL;
}

/*
 * Reaper thread body; `arg` is the owning ThreadPool.
 *
 * Shrinks the pool back towards min_threads once it goes idle. It waits on
 * pool->reaper_cond (signalled by workers) until more than min_threads workers
 * are alive and the task queue is empty. It then queues one retire request
 * per surplus worker: a task with a NULL function and NULL argument. Workers
 * must treat such a task as "exit" and must not call it.
 *
 * Retire requests are only issued when the queue is empty. This guarantees
 * every request from an earlier round has already been taken, so requests do
 * not pile up while busy workers have not reached them yet. Requests are FIFO
 * behind any real tasks, so they could not take effect any sooner anyway.
 *
 * pool->shutdown is read only while holding pool->mutex and is re-checked
 * before every wait. A shutdown published under pool->mutex therefore cannot
 * be missed. Checking it only after waking could leave this thread blocked
 * forever.
 *
 * Lock order: pool->mutex, then the queue mutex (task_queue_add takes the
 * latter).
 */
static void *reaper_thread(void *arg) {
    ThreadPool *pool = arg;

    pthread_mutex_lock(&pool->mutex);
    while (!pool->shutdown) {
        // The queue length is owned by the queue mutex; snapshot it there.
        pthread_mutex_lock(&pool->task_queue.mutex);
        int queue_idle = (pool->task_queue.size == 0);
        pthread_mutex_unlock(&pool->task_queue.mutex);

        int surplus = pool->active_threads - pool->min_threads;
        if (surplus <= 0 || !queue_idle) {
            pthread_cond_wait(&pool->reaper_cond, &pool->mutex);
            continue;  // re-check shutdown and the pool state from scratch
        }

        for (int i = 0; i < surplus; i++) {
            task_queue_add(&pool->task_queue, NULL, NULL);
        }

        // Back off so the retire requests can take effect before re-checking.
        pthread_mutex_unlock(&pool->mutex);
        sleep(1);
        pthread_mutex_lock(&pool->mutex);
    }
    pthread_mutex_unlock(&pool->mutex);

    return NULL;
}

/*
 * Abort helper for thread_pool_init: stop and join the first `started`
 * workers. shutdown is set and task_queue.cond broadcast while holding the
 * queue mutex. Workers test shutdown under that mutex, so none of them can
 * miss the wakeup and hang the joins below.
 */
static void init_abort_workers(ThreadPool *pool, int started) {
    pthread_mutex_lock(&pool->task_queue.mutex);
    pool->shutdown = 1;
    pthread_cond_broadcast(&pool->task_queue.cond);
    pthread_mutex_unlock(&pool->task_queue.mutex);

    for (int i = 0; i < started; i++) {
        pthread_join(pool->threads[i], NULL);
    }
}

/*
 * Create a thread pool.
 *
 * The pool starts min_threads workers plus one reaper thread. It can later
 * grow to max_threads workers (room for max_threads thread handles is
 * reserved up front).
 *
 * Returns the new pool, or NULL on failure with everything already released.
 * errno is set as follows:
 *   - EINVAL when the bounds are not 0 < min_threads <= max_threads;
 *   - by malloc/calloc when an allocation fails;
 *   - to the pthread error code when a mutex, condition variable or thread
 *     cannot be created.
 * errno is not set here when task_queue_init fails.
 */
ThreadPool *thread_pool_init(int min_threads, int max_threads) {
    int rc;

    if (min_threads <= 0 || max_threads <= 0 || min_threads > max_threads) {
        errno = EINVAL;
        return NULL;
    }

    ThreadPool *pool = malloc(sizeof *pool);
    if (pool == NULL) {
        return NULL;
    }

    // calloc checks max_threads * sizeof(pthread_t) for overflow.
    pool->threads = calloc((size_t)max_threads, sizeof *pool->threads);
    if (pool->threads == NULL) {
        goto fail_pool;
    }

    if (task_queue_init(&pool->task_queue) != 0) {
        goto fail_threads;
    }

    pool->min_threads = min_threads;
    pool->max_threads = max_threads;
    pool->thread_count = min_threads;
    pool->active_threads = 0;
    pool->shutdown = 0;

    rc = pthread_mutex_init(&pool->mutex, NULL);
    if (rc != 0) {
        errno = rc;
        goto fail_queue;
    }
    rc = pthread_cond_init(&pool->reaper_cond, NULL);
    if (rc != 0) {
        errno = rc;
        goto fail_mutex;  // pool->mutex is already initialised: destroy it too
    }

    for (int i = 0; i < min_threads; i++) {
        rc = pthread_create(&pool->threads[i], NULL, worker_thread, pool);
        if (rc != 0) {
            init_abort_workers(pool, i);
            errno = rc;
            goto fail_cond;
        }
    }

    rc = pthread_create(&pool->reaper_thread, NULL, reaper_thread, pool);
    if (rc != 0) {
        init_abort_workers(pool, min_threads);
        errno = rc;
        goto fail_cond;
    }

    return pool;

    // Unwind in reverse order of acquisition.
fail_cond:
    pthread_cond_destroy(&pool->reaper_cond);
fail_mutex:
    pthread_mutex_destroy(&pool->mutex);
fail_queue:
    pthread_cond_destroy(&pool->task_queue.cond);
    pthread_mutex_destroy(&pool->task_queue.mutex);
fail_threads:
    free(pool->threads);
fail_pool:
    free(pool);
    return NULL;
}

/*
 * Queue function(argument) for execution by the pool.
 *
 * The pool stores a private heap copy of *argument; the worker frees that copy
 * after the task runs, so the caller keeps ownership of `argument` itself.
 * If `argument` is NULL, the task receives NULL.
 * NOTE(review): exactly sizeof(int) bytes are copied, so `argument` is
 * presumably an int* — confirm with callers; any larger payload is truncated.
 *
 * If the queue is longer than the number of live workers and the pool is
 * below max_threads, one extra worker is started. This is best effort: if
 * pthread_create fails, the task is still queued for the existing workers.
 *
 * Returns 0 on success. Returns -1 if pool or function is NULL, the pool is
 * shutting down, or memory cannot be allocated.
 */
int thread_pool_add_task(ThreadPool *pool, void (*function)(void *), void *argument) {
    if (pool == NULL || function == NULL) {
        return -1;
    }

    pthread_mutex_lock(&pool->mutex);
    if (pool->shutdown) {
        pthread_mutex_unlock(&pool->mutex);
        return -1;
    }

    // The queue length is modified under the queue mutex, so read it there.
    // Taking it while holding pool->mutex follows the pool -> queue order.
    pthread_mutex_lock(&pool->task_queue.mutex);
    int backlogged = pool->task_queue.size > pool->active_threads;
    pthread_mutex_unlock(&pool->task_queue.mutex);

    if (backlogged && pool->thread_count < pool->max_threads) {
        if (pthread_create(&pool->threads[pool->thread_count], NULL, worker_thread, pool) == 0) {
            pool->thread_count++;
        }
    }
    pthread_mutex_unlock(&pool->mutex);

    // Size the copy by its element type. sizeof(*argument) would be
    // sizeof(void), which is invalid C (GNU C yields 1) and overflows when
    // the int is stored.
    int *arg_copy = NULL;
    if (argument != NULL) {
        arg_copy = malloc(sizeof *arg_copy);
        if (arg_copy == NULL) {
            return -1;
        }
        *arg_copy = *(const int *)argument;
    }

    if (task_queue_add(&pool->task_queue, function, arg_copy) != 0) {
        free(arg_copy);  // never queued, so no worker will free it
        return -1;
    }
    return 0;
}

/*
 * Shut the pool down and release everything it owns, including `pool`.
 *
 * Sets pool->shutdown, wakes all workers and the reaper, and joins them.
 * Workers keep taking tasks until the queue is empty before they exit, so
 * tasks queued before this call still run. Any nodes left in the queue after
 * the joins are freed, together with their arguments, without being run.
 * Joining the reaper may take up to about a second if it is in its back-off
 * sleep.
 *
 * The caller must ensure no other thread uses `pool` during or after this
 * call. Passing NULL is a no-op.
 */
void thread_pool_destroy(ThreadPool *pool) {
    if (pool == NULL) {
        return;
    }

    // Publish shutdown while holding BOTH locks (pool -> queue, the same order
    // the reaper uses). Workers test shutdown under the queue mutex. Setting
    // it and broadcasting there closes the window in which a worker has seen
    // shutdown == 0 but has not yet blocked. Such a worker would otherwise
    // miss the broadcast and make pthread_join hang.
    pthread_mutex_lock(&pool->mutex);
    pthread_mutex_lock(&pool->task_queue.mutex);
    pool->shutdown = 1;
    pthread_cond_broadcast(&pool->task_queue.cond);
    pthread_mutex_unlock(&pool->task_queue.mutex);
    pthread_cond_broadcast(&pool->reaper_cond);
    // thread_count only grows under pool->mutex while shutdown == 0, so this
    // snapshot is final.
    int thread_count = pool->thread_count;
    pthread_mutex_unlock(&pool->mutex);

    for (int i = 0; i < thread_count; i++) {
        pthread_join(pool->threads[i], NULL);
    }
    pthread_join(pool->reaper_thread, NULL);

    // Free whatever is still queued (e.g. retire requests nobody consumed).
    for (Task *task = pool->task_queue.head; task != NULL;) {
        Task *next = task->next;
        free(task->argument);  // free(NULL) is a no-op
        free(task);
        task = next;
    }

    free(pool->threads);
    pthread_mutex_destroy(&pool->task_queue.mutex);
    pthread_cond_destroy(&pool->task_queue.cond);
    pthread_mutex_destroy(&pool->mutex);
    pthread_cond_destroy(&pool->reaper_cond);
    free(pool);
}