小天管理 发表于 2024年6月17日 发表于 2024年6月17日 最近在找一个简单的 C++11 线程池实现,发现网上有很多相关的代码,在 CSDN 网上看到一个比较简洁的。但是总感觉是不是实现错了。 Any 类 noncopyable 的,仅仅支持移动语义, Result 类使用了 Any 实例作为成员变量,那么 Result 类应该也是 noncopyable 的, Result SubmitTask(std::shared_ptr<Task> taskPtr);直接使用了复制语义,应该是有问题吧,可是代码能够被 vs2022 正常编译。 threadpool.h #pragma once #include <vector> #include <cstdint> #include <queue> #include <memory> #include <atomic> #include <mutex> #include <thread> #include <condition_variable> #include <functional> #include <sstream> #include <unordered_map> // Any 类型:可以接收任意数据的类型 // 任意其他类型 template // 能让一个类型指向其他类型,基类指针可以指向子类 class Any { public: Any() = default; ~Any() = default; Any(const Any&) = delete; Any& operator=(const Any&) = delete; Any(Any&&) = default; Any& operator=(Any&&) = default; template<typename T> Any(T data) : m_base(std::make_unique<Derive<T>>(data)) {} template<typename T> T cast_() { Derive<T>* pd = dynamic_cast<Derive<T>*>(m_base.get()); if (pd == nullptr) { throw "type is unmath!!"; } return pd->m_data; } private: // 基类 class Base { public: virtual ~Base() = default; }; // 派生类 template<typename T> class Derive : public Base { public: Derive(T data) : m_data(data) {} public: T m_data; }; private: std::unique_ptr<Base> m_base; }; // 实现一个信号量类 class Semaphore { public: Semaphore(int limit = 0) : m_resLimit(limit) {} ~Semaphore() = default; // 获取一个信号量资源 void wait() { std::unique_lock<std::mutex> lock(m_mtx); // 如果没有资源,阻塞线程 while (m_resLimit < 1) { m_cond.wait(lock); } m_resLimit--; } // 增加一个信号量资源 void post() { std::unique_lock<std::mutex> lock(m_mtx); m_resLimit++; m_cond.notify_all(); } private: int m_resLimit; // 资源量 std::mutex m_mtx; std::condition_variable m_cond; }; // Task 类型前置声明 class Task; // 实现接收提交到线程池的 task 任务执行完成后的返回值类型 class Result { public: Result(std::shared_ptr<Task> task, bool isValid = true); ~Result() = default; // setVal void setVal(Any result); // get 方法,用户调用这个方法获取 task 的返回值 Any get(); private: Any m_any; Semaphore m_sem; std::shared_ptr<Task> m_task; std::atomic_bool m_isValid; }; // 任务抽象基类 class Task { public: void exec(); void setResult(Result* res); virtual Any run() = 0; private: Result* m_result{ nullptr }; // 不要用智能指针,task 含有 Result Result 含有 task ,可能导致问题 }; class MyTask : public Task { public: MyTask(int start, int end) : m_start(start), m_end(end) {} Any run() { std::ostringstream ostr; ostr << std::this_thread::get_id(); printf("thead %s, task start \n", ostr.str().c_str()); uint64_t sum = 0; for (int i = m_start; i <= m_end; i++) { sum += i; } printf("sum %llu\n", sum); std::this_thread::sleep_for(std::chrono::seconds(2)); printf("thread %s, task finish \n", ostr.str().c_str()); return sum; } private: int m_start; int m_end; }; enum ThreadPoolMode { MODE_FIXED, // 固定数量的线程 MODE_CACHED, // 线程数量可以动态增长 }; class Thread { public: using ThreadFunc = std::function<void(int)>; Thread(ThreadFunc func); ~Thread(); void Start(); int GetId() { return m_threadId; } private: ThreadFunc m_func; static int generateId; int m_threadId; }; class ThreadPool { public: ThreadPool(); ~ThreadPool(); // 设置线程池工作模式 void SetMode(ThreadPoolMode mode); // 设置任务数量上限 void SetTaskQueMaxThreshold(int value); // 给线程池提交任务 Result SubmitTask(std::shared_ptr<Task> taskPtr); // 开启线程池 void Start(int initThreadSize = std::thread::hardware_concurrency()); private: ThreadPool(const ThreadPool&) = delete; ThreadPool& operator=(const ThreadPool&) = delete; // 定义线程函数 void ThreadFunc(int threadId); bool CheckRunningState() const; private: std::unordered_map<int, std::unique_ptr<Thread>> m_threadMap; // 线程列表 int m_initThreadSize; // 初始的线程数量 std::atomic_int m_curThreadSize; // 当前线程数量 std::queue<std::shared_ptr<Task>> m_taskQue; // 任务队列 std::atomic_int m_taskSize; // 任务的数量 int m_taskQueMaxThreshold; // 任务队列的数量上限 std::mutex m_taskQueMtx; // 保证任务队列的线程安全 std::condition_variable m_taskQueNotFullCv; // 表示任务队列不满 std::condition_variable m_taskQueNotEmptyCv; // 表示任务队列不空 std::condition_variable m_exitCv; // 退出线程池 ThreadPoolMode m_poolMode; // 当前线程池的工作模式 std::atomic_bool m_isPoolRuning; // 当前线程工作状态 }; threadpool.cpp #include "threadpool.h" #include <functional> #include <iostream> constexpr int TASK_MAX_THRESHOLD = 1024; ThreadPool::ThreadPool() : m_initThreadSize(4), m_taskSize(0), m_taskQueMaxThreshold(TASK_MAX_THRESHOLD), m_poolMode(ThreadPoolMode::MODE_FIXED) { } ThreadPool::~ThreadPool() { m_isPoolRuning = false; std::unique_lock<std::mutex> lock(m_taskQueMtx); // 线程 要么在阻塞中 要么在工作中 while (m_threadMap.size() > 0) { m_taskQueNotEmptyCv.notify_all(); // 唤醒等待的工作线程 m_exitCv.wait(lock); } } void ThreadPool::SetMode(ThreadPoolMode mode) { if (m_isPoolRuning) { return; } // 线程池启动后,不允许设置线程池一些参数 m_poolMode = mode; } void ThreadPool::SetTaskQueMaxThreshold(int value) { if (m_isPoolRuning) { return; } m_taskQueMaxThreshold = value; } Result ThreadPool::SubmitTask(std::shared_ptr<Task> taskPtr) { // 获取锁 std::unique_lock<std::mutex> lock(m_taskQueMtx); // 线程通信,检查任务队列是否有空余 while (m_taskQue.size() >= m_taskQueMaxThreshold) { // 用于提交任务,不能阻塞太长时间,如果超过 1s ,给用户返回提交失败 if (m_taskQueNotFullCv.wait_for(lock, std::chrono::seconds(1)) == std::cv_status::timeout) { return Result(taskPtr, false); } } // 如果有空余,把任务提交到任务队列中 m_taskQue.emplace(taskPtr); m_taskSize++; // 因为新放了任务,任务队列肯定不为空了,在 m_taskQueNotEmptyCv 进行通知,赶快分配线程执行这个任务 m_taskQueNotEmptyCv.notify_all(); return Result(taskPtr); } void ThreadPool::Start(int initThreadSize) { m_initThreadSize = initThreadSize; m_curThreadSize = initThreadSize; m_isPoolRuning = true; // 创建线程对象 for (int i = 0; i < m_initThreadSize; i++) { auto ptr = std::make_unique<Thread>(std::bind(&ThreadPool::ThreadFunc, this, std::placeholders::_1)); int threadId = ptr->GetId(); m_threadMap.emplace(threadId, std::move(ptr)); } // 启动所有线程 for (auto iter = m_threadMap.cbegin(); iter != m_threadMap.end(); iter++) { iter->second->Start(); } } void ThreadPool::ThreadFunc(int threadId) { while (true) { // 获取锁 std::unique_lock<std::mutex> lock(m_taskQueMtx); std::ostringstream ostr; ostr << std::this_thread::get_id(); printf("thead %s, To Get task \n", ostr.str().c_str()); // 判断任务队列是否为空 while (m_taskQue.empty()) { if (!m_isPoolRuning) { m_threadMap.erase(threadId); m_exitCv.notify_all(); printf("deconstructor thread exit, id = %d\n", threadId); return; } m_taskQueNotEmptyCv.wait(lock); } printf("thead %s, Getted task \n", ostr.str().c_str()); // 不为空,获取任务 auto taskPtr = m_taskQue.front(); // front()返回引用,auto 忽略引用属性,正好满足需要 m_taskQue.pop(); m_taskSize--; lock.unlock(); // 释放锁; // 如果任务队列还有任务,通知其他线程执行任务 if (m_taskQue.size() > 0) { m_taskQueNotEmptyCv.notify_all(); } // 通知队列已经不满 m_taskQueNotFullCv.notify_all(); taskPtr->exec(); if (!m_isPoolRuning) { m_threadMap.erase(threadId); m_exitCv.notify_all(); printf("deconstructor thread exit, id = %d\n", threadId); return; } } } bool ThreadPool::CheckRunningState() const { if (m_isPoolRuning) { return true; } return false; } // 线程方法 int Thread::generateId = 0; Thread::Thread(ThreadFunc func) : m_func(func), m_threadId(generateId++) { } Thread::~Thread() { } void Thread::Start() { std::thread t(m_func, m_threadId); t.detach(); } Result::Result(std::shared_ptr<Task> task, bool isValid) : m_task(task), m_isValid(isValid) { m_task->setResult(this); } void Result::setVal(Any result) { m_any = std::move(result); m_sem.post(); // 通知已经获得结果 } Any Result::get() { if (!m_isValid) { return ""; } m_sem.wait(); // 等待结果 return std::move(m_any); } void Task::exec() { if (m_result != nullptr) { Any result = run(); // 这里发生多态调用 m_result->setVal(std::move(result)); } } void Task::setResult(Result* res) { m_result = res; } main.cpp #include "threadpool.h" #include <chrono> #include <iostream> using std::cout; using std::endl; int main(int argc, char* argv[]) { { ThreadPool pool; pool.Start(4); Result res1 = pool.SubmitTask(std::make_shared<MyTask>(1, 100000000)); Result res2 = pool.SubmitTask(std::make_shared<MyTask>(100000001, 200000000)); Result res3 = pool.SubmitTask(std::make_shared<MyTask>(200000001, 300000000)); //uint64_t sum1 = res1.get().cast_<uint64_t>(); //uint64_t sum2 = res2.get().cast_<uint64_t>(); //uint64_t sum3 = res3.get().cast_<uint64_t>(); //cout << (sum1 + sum2 + sum3) << endl; } cout << "main over" << endl; getchar(); return 0; }
已推荐帖子