March 26, 2025
最近对之前写的 TinyTorch 的内存管理(Allocator)进行了升级,主要是 Copy 了 PyTorch 的实现 CUDACachingAllocator.cpp,去除了其中 stream、cuda graph、expandable segments 以及数据统计相关的逻辑,仅保留最基础的 DeviceCachingAllocator
的核心实现,对第一次接触这块的同学来说会更容易看懂,完整代码详见:Allocator.cpp
基本思路
首先建议读下这篇文章:
该文章对 PyTorch CUDA 内存管理做了宏观介绍,并给出了伪代码,这里总结一下:
- 对要分配的 size 进行向上取整,比如 512 的倍数,从而提高 cache 复用概率进而减少碎片;
- 实际分配的时候,分配更大尺寸的内存块,这一块剩下的部分可以供之后的分配使用,从而减少调用系统
cudaMalloc
的次数; - 区别对待小内存分配和大内存分配,使用不同的内存池管理,减少碎片产生;
当然还有其它的一些设计,比如每个 stream 使用独立内存池、通过 Events 跟踪对象生命周期、OOM 处理等,在我的项目里目前都不涉及,先忽略。
数据结构
首先是 Block
,字段如下:
struct Block {
size_t size; // block size in bytes
BlockPool* pool; // owning memory pool
void* ptr; // memory address
bool allocated; // in-use flag
Block* prev; // prev block if split from a larger allocation
Block* next; // next block if split from a larger allocation
};
Block 是内存块的抽象,调用系统 cudaMalloc
分配出来的一个大块内存上可以创建多个 Block, 通过 prev
、next
指针形成一个双链表结构,存在两种情况:
- 分配内存的时候,如果大块内存比需求的更大,可以在后面
split
出一个空闲 block,供之后的分配使用; - 释放内存的时候,如果双链表前后两个 block 都是空闲的,就可以
merge
成一个(更大的) block,供之后的分配使用;
然后是缓存池 BlockPool
的定义:
typedef bool (*Comparison)(const Block*, const Block*);
struct BlockPool {
std::set<Block*, Comparison> blocks;
const bool isSmall;
};
BlockPool 管理的都是空闲的 block,通过 std::set 来存储(自动排序可以提高查找性能),如前面所说区别对待大、小内存分配,所以用个 isSmall 字段来标识。
最后是 Allocator :
class CachedAllocatorImpl {
...
private:
BlockPool largeBlocks;
BlockPool smallBlocks;
std::unordered_map<void*, Block*> activeBlocks;
};
Allocator 的 activeBlocks 存储所有已对外分配(使用中)的 blocks,smallBlocks 和 largeBlocks 则是两个空闲的缓存池。
实现细节
从上面的的数据结构来看整体是比较简单直接的,接下来看看具体的实现,主要包括分配、释放、缓存清理三部分:
分配
1、对要分配的 size 进行向上取整:
constexpr size_t kMinBlockSize = 512; // all sizes are rounded to at least 512 bytes
static size_t roundSize(size_t size) {
if (size < kMinBlockSize) {
return kMinBlockSize;
}
return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize);
}
即取整为 512 的倍数,这里做了简化,实际上 PyTorch 还有个 roundup_power2_divisions() 的逻辑,用来提高大内存分配的复用概率。
2、根据要分配的 size 获取对应的缓存池,以 1MB 为分界线
constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB
BlockPool& getPool(size_t size) {
if (size <= kSmallSize) {
return smallBlocks;
}
return largeBlocks;
}
3、计算要分配的大块内存的大小:
constexpr size_t kSmallBuffer = 2097152; // "small" allocations are packed in 2 MiB blocks
constexpr size_t kLargeBuffer = 20971520; // "large" allocations may be packed in 20 MiB blocks
constexpr size_t kMinLargeAlloc = 10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB
static size_t getAllocationSize(size_t size) {
if (size <= kSmallSize) {
return kSmallBuffer;
}
if (size < kMinLargeAlloc) {
return kLargeBuffer;
}
return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
}
- <= 1MB:分配 2MB
- > 1MB && < 10MB:分配 20MB
- > 10MB:向上取整到 2MB 的倍数
4、从步骤2的缓存池中查找满足大小需求的空闲块:
static bool getFreeBlock(AllocParams& p) {
BlockPool& pool = *p.pool;
const auto it = pool.blocks.lower_bound(&p.searchKey);
if (it == pool.blocks.end()) {
return false;
}
p.retBlock = *it;
pool.blocks.erase(it);
return true;
}
这里的 searchKey 就是步骤3返回的要分配的大小,通过 lower_bound
获取最小的 >= searchKey 的 block
5、如果步骤4没有找到空闲块,则调用系统接口进行分配
bool allocBlock(AllocParams& p) {
size_t size = p.allocSize;
void* ptr = nullptr;
base_->allocate(&ptr, size);
if (!ptr) {
return false;
}
totalAllocatedSize_ += size;
p.retBlock = new Block(size, p.pool, ptr);
assert(p.retBlock != nullptr && p.retBlock->ptr != nullptr);
return true;
}
这里的 base_->allocate() 是一个封装,会进一步调用到 cudaMalloc
,分配成功后会 new Block 创建一个新的 block
6、经过步骤4、5 得到一个 block(可能来自缓存也可能来自系统分配),接下来判断是否可以 split
static bool shouldSplit(const Block* block, size_t size) {
size_t remaining = block->size - size;
if (block->pool->isSmall) {
return remaining >= kMinBlockSize;
}
return remaining > kSmallSize;
}
即剩余的大小是否大于最小分配大小
7、如果可以 split
,则 split 出一个新的 block
if (shouldSplit(block, size)) {
Block* remaining = block;
block = new Block(size, &pool, block->ptr);
block->prev = remaining->prev;
if (block->prev) {
block->prev->next = block;
}
block->next = remaining;
remaining->prev = block;
remaining->ptr = static_cast<char*>(remaining->ptr) + size;
remaining->size -= size;
pool.blocks.insert(remaining);
}
这里主要是双链表的指针操作,然后将新的 block(remaining)添加到缓存池
8、到这里,整个分配过程已经完成,设置标记,将 block 存到已分配列表 activeBlocks,并将指针返回给调用方
block->allocated = true;
activeBlocks[block->ptr] = block;
*ptr = block->ptr;
整个分配过程的流程串联可以查看 mallocImpl
函数。
释放
1、从已分配列表 activeBlocks 中查找输入指针对应的 block,如果存在,调用 freeImpl()
进行 block 的释放,并从 activeBlocks 中移除
auto it = activeBlocks.find(ptr);
if (it != activeBlocks.end()) {
freeImpl(it->second);
activeBlocks.erase(it);
} else {
LOGE("deallocate error, ptr not valid: %p", ptr);
}
2、清除标记
block->allocated = false;
3、尝试双链表相邻 block 之间的 merge
const std::array<Block*, 2> mergeCandidates = {block->prev, block->next};
for (Block* candidate : mergeCandidates) {
tryMergeBlocks(block, candidate, pool);
}
static size_t tryMergeBlocks(Block* dst, Block* src, BlockPool& pool) {
if (!src || src->allocated) {
return 0;
}
assert(dst->isSplit() && src->isSplit());
if (dst->prev == src) {
// [src dst]
dst->ptr = src->ptr;
dst->prev = src->prev;
if (dst->prev) {
dst->prev->next = dst;
}
} else {
// [dest src]
dst->next = src->next;
if (dst->next) {
dst->next->prev = dst;
}
}
const size_t subsumedSize = src->size;
dst->size += subsumedSize;
pool.blocks.erase(src);
delete src;
return subsumedSize;
}
首先 merge 只会发生在当前 block 的 prev 和 next 两个点,然后只有空闲块之间才可以 merge,merge 本身就是双链表节点的删除,然后更新 dst block 的 size
4、将 block 添加到缓存池
pool.blocks.insert(block);
至此,整个释放流程完成,可以看到主要是先尝试 merge,然后归还缓存池
缓存清理
前面的 mallocImpl 中有一段 OOM 处理的逻辑:
blockFound = allocBlock(params);
if (!blockFound) {
// retry after release caches
releaseCachedBlocks();
blockFound = allocBlock(params);
}
if (!blockFound) {
LOGE("Out of memory. failed to allocate size: %lld", allocSize);
return nullptr;
}
即在 allocBlock 失败后,会 releaseCachedBlocks
清理缓存,然后重试分配,现在来看下清理缓存的实现:
1、释放 small、large 两个缓存池中的完整(未分割过)空闲 block :
void releaseCachedBlocks() {
releaseBlocks(largeBlocks);
releaseBlocks(smallBlocks);
}
void releaseBlocks(BlockPool& pool) {
// Frees all non-split blocks
auto it = pool.blocks.begin();
while (it != pool.blocks.end()) {
Block* block = *it;
++it;
if (!block->prev && !block->next) {
releaseBlock(block);
}
}
}
根据 prev、next 指针来判断是否完整,从前面的释放逻辑分析可知,如果存在分割,则必然有部分 block 在 activeBlock 中,所以不能释放
2、对完整的空闲 block 调用系统接口进行释放
void releaseBlock(Block* block) {
base_->deallocate(block->ptr);
totalAllocatedSize_ -= block->size;
auto* pool = block->pool;
pool->blocks.erase(block);
delete block;
}
这里 base_->deallocate 是对系统内存释放接口的封装,会进一步调到 cudaFree
效果实测
引入上述 Allocator 机制后,TinyTorch 项目中的 MNIST demo 完整跑一次,cudaMalloc
调用次数从之前的 283 次下降到了 18 次,优化效果明显。