V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
wisefree
V2EX  ›  C++

请教大家一个 C++线程池的问题

  •  
  •   wisefree · 9 天前 · 1063 次点击

    最近在找一个简单的 C++11 线程池实现,发现网上有很多相关的代码,在 CSDN 网上看到一个比较简洁的。但是总感觉是不是实现错了。

    1. Any 类 noncopyable 的,仅仅支持移动语义,
    2. Result 类使用了 Any 实例作为成员变量,那么 Result 类应该也是 noncopyable 的,
    3. Result SubmitTask(std::shared_ptr<Task> taskPtr);直接使用了复制语义,应该是有问题吧,可是代码能够被 vs2022 正常编译。

    threadpool.h

    #pragma once
    #include <vector>
    #include <cstdint>
    #include <queue>
    #include <memory>
    #include <atomic>
    #include <mutex>
    #include <thread>
    #include <condition_variable>
    #include <functional>
    #include <sstream>
    #include <unordered_map>
    
    
    // Any 类型:可以接收任意数据的类型
    // 任意其他类型 template
    // 能让一个类型指向其他类型,基类指针可以指向子类
    class Any
    {
    public:
    	Any() = default;
    	~Any() = default;
    	Any(const Any&) = delete;
    	Any& operator=(const Any&) = delete;
    	Any(Any&&) = default;
    	Any& operator=(Any&&) = default;
    
    	template<typename T>
    	Any(T data) : m_base(std::make_unique<Derive<T>>(data)) {}
    
    	template<typename T>
    	T cast_()
    	{
    		Derive<T>* pd = dynamic_cast<Derive<T>*>(m_base.get());
    
    		if (pd == nullptr) {
    			throw "type is unmath!!";
    		}
    
    		return pd->m_data;
    	}
    
    private:
    	// 基类
    	class Base
    	{
    	public:
    		virtual ~Base() = default;
    	};
    
    	// 派生类
    	template<typename T>
    	class Derive : public Base
    	{
    	public:
    		Derive(T data) : m_data(data) {}
    	public:
    		T m_data;
    	};
    
    private:
    	std::unique_ptr<Base> m_base;
    };
    
    
    // 实现一个信号量类
    class Semaphore
    {
    public:
    	Semaphore(int limit = 0) : m_resLimit(limit)
    	{}
    
    	~Semaphore() = default;
    
    	// 获取一个信号量资源
    	void wait()
    	{
    		std::unique_lock<std::mutex> lock(m_mtx);
    		// 如果没有资源,阻塞线程
    		while (m_resLimit < 1) {
    			m_cond.wait(lock);
    		}
    
    		m_resLimit--;
    	}
    
    	// 增加一个信号量资源
    	void post()
    	{
    		std::unique_lock<std::mutex> lock(m_mtx);
    		m_resLimit++;
    		m_cond.notify_all();
    
    	}
    private:
    	int m_resLimit;  // 资源量
    	std::mutex m_mtx;
    	std::condition_variable m_cond;
    };
    
    
    // Task 类型前置声明
    class Task;
    
    // 实现接收提交到线程池的 task 任务执行完成后的返回值类型
    class Result
    {
    public:
    	Result(std::shared_ptr<Task> task, bool isValid = true);
    	~Result() = default;
    
    	// setVal
    	void setVal(Any result);
    
    	// get 方法,用户调用这个方法获取 task 的返回值
    	Any get();
    private:
    	Any m_any;
    	Semaphore m_sem;
    	std::shared_ptr<Task> m_task;
    	std::atomic_bool m_isValid;
    };
    
    
    // 任务抽象基类
    class Task
    {
    public:
    	void exec();
    	void setResult(Result* res);
    	virtual Any run() = 0;
    
    private:
    	Result* m_result{ nullptr };  // 不要用智能指针,task 含有 Result  Result 含有 task ,可能导致问题
    };
    
    class MyTask : public Task
    {
    public:
    	MyTask(int start, int end) : m_start(start), m_end(end) {}
    
    	Any run()
    	{
    		std::ostringstream ostr;
    		ostr << std::this_thread::get_id();
    		printf("thead %s, task start \n", ostr.str().c_str());
    
    		uint64_t sum = 0;
    
    		for (int i = m_start; i <= m_end; i++) {
    			sum += i;
    		}
    
    		printf("sum %llu\n", sum);
    		std::this_thread::sleep_for(std::chrono::seconds(2));
    		printf("thread %s, task finish \n", ostr.str().c_str());
    
    		return sum;
    	}
    
    private:
    	int m_start;
    	int m_end;
    };
    
    enum ThreadPoolMode
    {
    	MODE_FIXED,  // 固定数量的线程
    	MODE_CACHED,  // 线程数量可以动态增长
    };
    
    class Thread
    {
    public:
    	using ThreadFunc = std::function<void(int)>;
    
    	Thread(ThreadFunc func);
    	~Thread();
    
    	void Start();
    	int GetId() { return m_threadId; }
    private:
    	ThreadFunc m_func;
    	static int generateId;
    	int m_threadId;
    };
    
    
    class ThreadPool
    {
    public:
    	ThreadPool();
    	~ThreadPool();
    
    	// 设置线程池工作模式
    	void SetMode(ThreadPoolMode mode);
    
    	// 设置任务数量上限
    	void SetTaskQueMaxThreshold(int value);
    
    	// 给线程池提交任务
    	Result SubmitTask(std::shared_ptr<Task> taskPtr);
    
    	// 开启线程池
    	void Start(int initThreadSize = std::thread::hardware_concurrency());
    
    private:
    	ThreadPool(const ThreadPool&) = delete;
    	ThreadPool& operator=(const ThreadPool&) = delete;
    
    	// 定义线程函数
    	void ThreadFunc(int threadId);
    	bool CheckRunningState() const;
    
    private:
    	std::unordered_map<int, std::unique_ptr<Thread>> m_threadMap;  // 线程列表
    	int m_initThreadSize;  // 初始的线程数量
    	std::atomic_int m_curThreadSize;  // 当前线程数量
    
    	std::queue<std::shared_ptr<Task>> m_taskQue;  // 任务队列
    	std::atomic_int m_taskSize;  // 任务的数量
    	int m_taskQueMaxThreshold;  // 任务队列的数量上限
    
    	std::mutex m_taskQueMtx;  // 保证任务队列的线程安全
    	std::condition_variable m_taskQueNotFullCv;  // 表示任务队列不满
    	std::condition_variable m_taskQueNotEmptyCv;  // 表示任务队列不空
    	std::condition_variable m_exitCv;  // 退出线程池
    
    	ThreadPoolMode m_poolMode;  // 当前线程池的工作模式
    	std::atomic_bool m_isPoolRuning;  // 当前线程工作状态
    };
    

    threadpool.cpp

    #include "threadpool.h"
    #include <functional>
    #include <iostream>
    
    constexpr int TASK_MAX_THRESHOLD = 1024;
    
    ThreadPool::ThreadPool() : m_initThreadSize(4), m_taskSize(0),
    m_taskQueMaxThreshold(TASK_MAX_THRESHOLD),
    m_poolMode(ThreadPoolMode::MODE_FIXED)
    {
    }
    
    ThreadPool::~ThreadPool()
    {
    	m_isPoolRuning = false;
    	std::unique_lock<std::mutex> lock(m_taskQueMtx);
    
    	// 线程 要么在阻塞中 要么在工作中
    	while (m_threadMap.size() > 0) {
    		m_taskQueNotEmptyCv.notify_all();  // 唤醒等待的工作线程
    		m_exitCv.wait(lock);
    	}
    }
    
    void ThreadPool::SetMode(ThreadPoolMode mode)
    {
    	if (m_isPoolRuning) { return; }  // 线程池启动后,不允许设置线程池一些参数
    
    	m_poolMode = mode;
    }
    
    void ThreadPool::SetTaskQueMaxThreshold(int value)
    {
    	if (m_isPoolRuning) { return; }
    
    	m_taskQueMaxThreshold = value;
    }
    
    Result ThreadPool::SubmitTask(std::shared_ptr<Task> taskPtr)
    {
    	// 获取锁
    	std::unique_lock<std::mutex> lock(m_taskQueMtx);
    
    	// 线程通信,检查任务队列是否有空余
    	while (m_taskQue.size() >= m_taskQueMaxThreshold) {
    
    		// 用于提交任务,不能阻塞太长时间,如果超过 1s ,给用户返回提交失败
    		if (m_taskQueNotFullCv.wait_for(lock, std::chrono::seconds(1)) == std::cv_status::timeout) {
    			return Result(taskPtr, false);
    		}
    	}
    
    	// 如果有空余,把任务提交到任务队列中
    	m_taskQue.emplace(taskPtr);
    	m_taskSize++;
    
    	// 因为新放了任务,任务队列肯定不为空了,在 m_taskQueNotEmptyCv 进行通知,赶快分配线程执行这个任务
    	m_taskQueNotEmptyCv.notify_all();
    
    	return Result(taskPtr);
    }
    
    void ThreadPool::Start(int initThreadSize)
    {
    	m_initThreadSize = initThreadSize;
    	m_curThreadSize = initThreadSize;
        m_isPoolRuning = true;
    
    	// 创建线程对象
    	for (int i = 0; i < m_initThreadSize; i++) {
    		auto ptr = std::make_unique<Thread>(std::bind(&ThreadPool::ThreadFunc, this, std::placeholders::_1));
    		int threadId = ptr->GetId();
    		m_threadMap.emplace(threadId, std::move(ptr));
    	}
    
    	// 启动所有线程
    	for (auto iter = m_threadMap.cbegin(); iter != m_threadMap.end(); iter++) {
    		iter->second->Start();
    	}
    }
    
    void ThreadPool::ThreadFunc(int threadId)
    {
    	while (true) {
    
    		// 获取锁
    		std::unique_lock<std::mutex> lock(m_taskQueMtx);
    
    		std::ostringstream ostr;
    		ostr << std::this_thread::get_id();
    		printf("thead %s, To Get task \n", ostr.str().c_str());
    
    		// 判断任务队列是否为空
    		while (m_taskQue.empty()) {
    			if (!m_isPoolRuning) {
    				m_threadMap.erase(threadId);
    				m_exitCv.notify_all();
    
    				printf("deconstructor thread exit, id = %d\n", threadId);
    				return;
    			}
                
    			m_taskQueNotEmptyCv.wait(lock);
    
    		}
    
    		printf("thead %s, Getted task \n", ostr.str().c_str());
    		// 不为空,获取任务
    		auto taskPtr = m_taskQue.front();  // front()返回引用,auto 忽略引用属性,正好满足需要
    		m_taskQue.pop();
    		m_taskSize--;
    
    		lock.unlock();  // 释放锁;
    
    		// 如果任务队列还有任务,通知其他线程执行任务
    		if (m_taskQue.size() > 0) {
    			m_taskQueNotEmptyCv.notify_all();
    		}
    
    		// 通知队列已经不满
    		m_taskQueNotFullCv.notify_all();
    
    		taskPtr->exec();
    
    		if (!m_isPoolRuning) {
    			m_threadMap.erase(threadId);
    			m_exitCv.notify_all();
    
    			printf("deconstructor thread exit, id = %d\n", threadId);
    			return;
    		}
    
    	}
    }
    
    bool ThreadPool::CheckRunningState() const
    {
    	if (m_isPoolRuning) {
    		return true;
    	}
    
    	return false;
    }
    
    // 线程方法
    int Thread::generateId = 0;
    
    Thread::Thread(ThreadFunc func) : m_func(func),
    								m_threadId(generateId++)
    {
    }
    
    Thread::~Thread()
    {
    }
    
    void Thread::Start()
    {
    	std::thread t(m_func, m_threadId);
    	t.detach();
    }
    
    Result::Result(std::shared_ptr<Task> task, bool isValid) : m_task(task), m_isValid(isValid)
    {
    	m_task->setResult(this);
    }
    
    void Result::setVal(Any result)
    {
    	m_any = std::move(result);
    	m_sem.post();  // 通知已经获得结果
    }
    
    Any Result::get()
    {
    	if (!m_isValid) {
    		return "";
    	}
    
    	m_sem.wait();  // 等待结果
    	return std::move(m_any);
    }
    
    
    void Task::exec()
    {
    	if (m_result != nullptr) {
    		Any result = run();  // 这里发生多态调用
    
    		m_result->setVal(std::move(result));
    	}
    }
    
    void Task::setResult(Result* res)
    {
    	m_result = res;
    }
    
    

    main.cpp

    #include "threadpool.h"
    
    #include <chrono>
    #include <iostream>
    
    using std::cout;
    using std::endl;
    
    
    int main(int argc, char* argv[])
    {
    	{
    		ThreadPool pool;
    		pool.Start(4);
    
    		Result res1 = pool.SubmitTask(std::make_shared<MyTask>(1, 100000000));
    		Result res2 = pool.SubmitTask(std::make_shared<MyTask>(100000001, 200000000));
    		Result res3 = pool.SubmitTask(std::make_shared<MyTask>(200000001, 300000000));
    
    		//uint64_t sum1 = res1.get().cast_<uint64_t>();
    		//uint64_t sum2 = res2.get().cast_<uint64_t>();
    		//uint64_t sum3 = res3.get().cast_<uint64_t>();
    
    		//cout << (sum1 + sum2 + sum3) << endl;
    	}
    
    	cout << "main over" << endl;
    
    	getchar();
    	return 0;
    }
    
    10 条回复    2024-06-18 08:56:06 +08:00
    donaldturinglee
        1
    donaldturinglee  
       9 天前   ❤️ 1
    你在 Github 按照 star 挑几个高星的看看
    ysc3839
        2
    ysc3839  
       9 天前   ❤️ 1
    shared_ptr 复制只是增加引用计数吧?底层对象没复制。
    zhaoloving
        3
    zhaoloving  
       9 天前 via Android   ❤️ 1
    有右值构造函数,函数返回一个右值就好了
    wisefree
        4
    wisefree  
    OP
       9 天前
    @donaldturinglee 好的,随手搜了一个,看到挺清晰的,就看了一下。。。
    wisefree
        5
    wisefree  
    OP
       9 天前
    @ysc3839 SubmitTask 函数返回了 Result 这个动作,是有 copyable 语义的,我也不太懂这个例子
    wisefree
        6
    wisefree  
    OP
       9 天前
    @zhaoloving 嗯嗯,我也想这么干,只是这个例子能编译通过,我没想明白
    ysc3839
        7
    ysc3839  
       9 天前   ❤️ 1
    @wisefree 看错了,如果说的是返回值,那大概是 RVO 消除了复制。
    leonshaw
        8
    leonshaw  
       9 天前 via Android
    Any(Any&&) = default;
    wisefree
        9
    wisefree  
    OP
       9 天前
    @ysc3839 是的,我也想过 RVO 最有可能
    xyz1001
        10
    xyz1001  
       8 天前   ❤️ 1
    返回的 Result 是临时变量,属于将亡值,也是右值的一种,走的右值拷贝构造
    关于   ·   帮助文档   ·   博客   ·   API   ·   FAQ   ·   实用小工具   ·   2358 人在线   最高记录 6679   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 25ms · UTC 15:55 · PVG 23:55 · LAX 08:55 · JFK 11:55
    Developed with CodeLauncher
    ♥ Do have faith in what you're doing.