Executing Tasks Asynchronously with Pools

1. Async Thread Pools

In asynchronous programming, you sometimes need to run blocking I/O or CPU-bound work in the background. A thread pool lets you do this without blocking the event loop, while also capping the number of concurrent threads.

1.1 Basic Usage

python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

async def main():
    start = time.perf_counter()
    # Inside a coroutine, get_running_loop() is the preferred way to get the loop
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor(max_workers=4) as executor:
        # 16 blocking 2-second sleeps scheduled on a pool of 4 threads
        tasks_list = [loop.run_in_executor(executor, time.sleep, 2) for _ in range(16)]
        await asyncio.gather(*tasks_list)
    end = time.perf_counter()
    print(f"Time elapsed: {end - start:.2f}s")

asyncio.run(main())

In this example, 16 tasks run on 4 threads, so the total time is roughly 8 seconds (16 / 4 × 2).
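
If you do not need precise control over the pool size, `loop.run_in_executor` also accepts `None` as the executor, in which case the event loop's default thread pool is used. A minimal sketch:

python
import asyncio
import time

async def main():
    loop = asyncio.get_running_loop()
    # Passing None as the executor uses the loop's default ThreadPoolExecutor
    await loop.run_in_executor(None, time.sleep, 1)

asyncio.run(main())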

1.2 Running a Custom Function

python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

def blocking_io_task(name):
    print(f"Starting {name}")
    time.sleep(1)  # simulate blocking I/O
    print(f"Finished {name}")
    return f"Result from {name}"

async def main():
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor(max_workers=3) as executor:
        # Submit the blocking function to the pool via loop.run_in_executor
        futures = [
            loop.run_in_executor(executor, blocking_io_task, f"task-{i}")
            for i in range(5)
        ]
        results = await asyncio.gather(*futures)
        for result in results:
            print(result)

asyncio.run(main())

1.3 Using asyncio.to_thread (Python 3.9+)

Python 3.9+ provides the more concise asyncio.to_thread() function:

python
import asyncio
import time

def blocking_function(name):
    print(f"Starting {name}")
    time.sleep(1)
    return f"Completed {name}"

async def main():
    # asyncio.to_thread runs the call in the default thread pool; no manual pool needed
    tasks = [
        asyncio.to_thread(blocking_function, f"task-{i}")
        for i in range(5)
    ]
    results = await asyncio.gather(*tasks)
    print(results)

asyncio.run(main())
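
One convenience of `asyncio.to_thread` is that it forwards positional and keyword arguments directly, so no `functools.partial` wrapper is needed. A small illustration (the `download` function here is a made-up placeholder):

python
import asyncio
import time

def download(url, timeout=5):
    time.sleep(0.1)  # simulate blocking I/O
    return f"fetched {url} (timeout={timeout})"

async def main():
    # Keyword arguments pass straight through to the target function
    result = await asyncio.to_thread(download, "https://example.com", timeout=10)
    print(result)

asyncio.run(main())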

1.4 Passing Multiple Arguments

python
import asyncio
import time
from functools import partial
from concurrent.futures import ThreadPoolExecutor

def calculate(x, y, operation):
    time.sleep(0.5)
    if operation == "add":
        return x + y
    elif operation == "multiply":
        return x * y
    return None

async def main():
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as executor:
        # run_in_executor only takes positional args, so bind the extras
        # with functools.partial (or a lambda)
        tasks = [
            loop.run_in_executor(executor, partial(calculate, 10, 20, "add")),
            loop.run_in_executor(executor, partial(calculate, 10, 20, "multiply")),
        ]
        
        results = await asyncio.gather(*tasks)
        print(results)  # [30, 200]

asyncio.run(main())

2. Async Process Pools

For CPU-bound tasks, multiprocessing is more effective than multithreading because of Python's GIL (Global Interpreter Lock), which prevents threads from executing Python bytecode in parallel.

2.1 Basic Usage

python
from multiprocessing import Pool
import time

def say_hello(name) -> str:
    time.sleep(1)
    return f"hello, {name}"

if __name__ == "__main__":
    time_start = time.perf_counter()
    with Pool() as pool:
        # apply_async returns an AsyncResult right away; get() blocks for the value
        hi1_async = pool.apply_async(say_hello, args=("satori",))
        hi2_async = pool.apply_async(say_hello, args=("koishi",))
        print(hi1_async.get())
        print(hi2_async.get())
    time_end = time.perf_counter()
    print(f"Time elapsed: {time_end - time_start:.2f}s")

2.2 Batch Processing with map

python
from multiprocessing import Pool

def square(x):
    return x * x

if __name__ == "__main__":
    with Pool(processes=4) as pool:
        # map blocks until every result has been returned
        results = pool.map(square, range(10))
        print(results)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

2.3 Iterating over Results with imap

imap returns an iterator, so results can be consumed one at a time:

python
from multiprocessing import Pool
import time

def process_item(x):
    time.sleep(0.5)
    return x * x

if __name__ == "__main__":
    with Pool(processes=4) as pool:
        # imap yields each result as soon as it is ready, preserving input order
        for result in pool.imap(process_item, range(10)):
            print(f"Got result: {result}")

2.4 Handling Multiple Arguments with starmap

python
from multiprocessing import Pool

def multiply(x, y):
    return x * y

if __name__ == "__main__":
    data = [(1, 2), (3, 4), (5, 6)]
    with Pool() as pool:
        # starmap unpacks each tuple into multiply(x, y)
        results = pool.starmap(multiply, data)
        print(results)  # [2, 12, 30]

2.5 The Asynchronous Version of map

python
from multiprocessing import Pool
import time

def slow_square(x):
    time.sleep(1)
    return x * x

if __name__ == "__main__":
    start = time.perf_counter()
    with Pool(processes=4) as pool:
        # map_async returns immediately instead of blocking
        result = pool.map_async(slow_square, range(8))

        # the parent process can do other work while waiting
        print("Doing other work...")

        # get() blocks until all the results are ready
        results = result.get()
        print(results)
    
    end = time.perf_counter()
    print(f"Time elapsed: {end - start:.2f}s")

3. Using a Process Pool with asyncio

A process pool can be combined with asyncio:

python
import asyncio
from concurrent.futures import ProcessPoolExecutor

def cpu_intensive_task(n):
    """CPU-bound task: sum the squares below n."""
    total = 0
    for i in range(n):
        total += i ** 2
    return total

async def main():
    loop = asyncio.get_running_loop()
    with ProcessPoolExecutor(max_workers=4) as executor:
        tasks = [
            loop.run_in_executor(executor, cpu_intensive_task, 1000000)
            for _ in range(8)
        ]
        results = await asyncio.gather(*tasks)
        print(f"Results: {results}")

if __name__ == "__main__":
    asyncio.run(main())
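
The `if __name__ == "__main__"` guard matters here: under the spawn start method (the default on Windows and macOS), worker processes re-import the main module, and an unguarded pool creation would try to spawn workers recursively.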

4. Thread Pools vs Process Pools

4.1 Performance Comparison

python
import time
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor

def cpu_bound(n):
    """CPU-bound task."""
    return sum(i * i for i in range(n))

def io_bound(n):
    """I/O-bound task."""
    time.sleep(0.1)
    return n

def benchmark(executor_class, task, task_name, data):
    start = time.perf_counter()
    with executor_class(max_workers=4) as executor:
        results = list(executor.map(task, data))
    end = time.perf_counter()
    print(f"{executor_class.__name__} with {task_name}: {end - start:.2f}s")

if __name__ == "__main__":
    cpu_data = [1000000] * 8
    io_data = [1] * 20
    
    print("CPU 密集型任务:")
    benchmark(ThreadPoolExecutor, cpu_bound, "CPU-bound", cpu_data)
    benchmark(ProcessPoolExecutor, cpu_bound, "CPU-bound", cpu_data)
    
    print("\nI/O 密集型任务:")
    benchmark(ThreadPoolExecutor, io_bound, "IO-bound", io_data)
    benchmark(ProcessPoolExecutor, io_bound, "IO-bound", io_data)
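
On a typical multi-core machine, the process pool should win the CPU-bound benchmark, since the GIL prevents threads from executing Python bytecode in parallel, while the thread pool should match or beat it on the I/O-bound one, where process startup and argument pickling are pure overhead.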

4.2 Selection Guide

When to use a thread pool

  • I/O-bound tasks (network requests, file reads and writes)
  • Tasks that need to share memory
  • Workloads with many small tasks, since threads are cheap to start

When to use a process pool

  • CPU-bound tasks (numeric computation, image processing)
  • Work that must bypass the GIL
  • Independent tasks that share no state

5. Error Handling

5.1 Thread Pool Error Handling

python
import asyncio
from concurrent.futures import ThreadPoolExecutor

def may_fail(x):
    if x == 5:
        raise ValueError(f"Cannot process {x}")
    return x * 2

async def main():
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as executor:
        tasks = [
            loop.run_in_executor(executor, may_fail, i)
            for i in range(10)
        ]
        
        # gather re-raises the first exception it encounters
        try:
            results = await asyncio.gather(*tasks)
        except ValueError as e:
            print(f"Caught error: {e}")
        
        # use return_exceptions=True to collect all results and exceptions
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                print(f"Task {i} failed: {result}")
            else:
                print(f"Task {i} succeeded: {result}")

asyncio.run(main())

5.2 Process Pool Error Handling

python
from multiprocessing import Pool

def risky_operation(x):
    if x == 5:
        raise ValueError(f"Cannot process {x}")
    return x * 2

if __name__ == "__main__":
    with Pool() as pool:
        try:
            result = pool.apply_async(risky_operation, (5,))
            print(result.get())  # the worker's exception is re-raised here
        except ValueError as e:
            print(f"Caught error: {e}")
        
        # Alternatively, register an error_callback that runs in the parent process
        def error_callback(e):
            print(f"Error occurred: {e}")
        
        result = pool.apply_async(
            risky_operation,
            (5,),
            error_callback=error_callback
        )
        
        # close() stops new submissions; join() waits for the tasks to finish
        pool.close()
        pool.join()

6. Best Practices

6.1 Setting the Right Number of Worker Threads/Processes

python
import os
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor

# I/O-bound tasks can use many more threads than there are CPU cores
io_workers = min(32, (os.cpu_count() or 1) * 5)

# CPU-bound tasks usually want one worker per CPU core
cpu_workers = os.cpu_count() or 1

print(f"I/O workers: {io_workers}")
print(f"CPU workers: {cpu_workers}")

6.2 Use Context Managers

Always use a context manager (the with statement) to make sure resources are cleaned up:

python
from concurrent.futures import ThreadPoolExecutor

# Recommended
with ThreadPoolExecutor() as executor:
    # use the executor here
    pass
# the executor shuts down automatically

# Not recommended
executor = ThreadPoolExecutor()
# use the executor here
executor.shutdown(wait=True)  # must be shut down manually

6.3 Avoid Unpicklable Objects with Process Pools

Inter-process communication pickles the objects it sends, and some objects cannot be pickled; lambdas are a common offender:

python
from multiprocessing import Pool

# Wrong: a lambda cannot be pickled
def wrong_example():
    with Pool() as pool:
        # this fails because the lambda cannot be serialized!
        results = pool.map(lambda x: x * 2, range(10))

# Correct: use a module-level function
def double(x):
    return x * 2

def correct_example():
    with Pool() as pool:
        results = pool.map(double, range(10))
        print(results)

if __name__ == "__main__":
    correct_example()
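
When the worker needs extra arguments, `functools.partial` wrapping a module-level function remains picklable, so it works where a lambda would not:

python
from functools import partial
from multiprocessing import Pool

def scale(factor, x):
    return x * factor

if __name__ == "__main__":
    with Pool() as pool:
        # partial(scale, 3) pickles because scale is a module-level function
        results = pool.map(partial(scale, 3), range(5))
        print(results)  # [0, 3, 6, 9, 12]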
