March 21, 2025
年前实现了一个小巧的神经网络训练框架 TinyTorch,使用 C++ 模拟 PyTorch 的 API,支持 MNIST 训练,不过当时只实现了 CPU 版本,即使依赖第三方 BLAS 加速 gemm,跑起来还是很慢的(更多的是因为没做什么优化),那为什么不把 CUDA 给支持上呢,说干就干,过完年就开搞,到现在基本完成了,代码已提交:
https://github.com/keith2018/TinyTorch
算子流程
这里还是 follow PyTorch 的 API,首先定义了 Device 枚举:
enum class Device { CPU, CUDA };
然后 Tensor 的实现类 TensorImpl 会带有一个 device_
成员,以及一个 TensorOperations 类型的成员 ops_
class TensorImpl {
...
protected:
...
Device device_ = Device::CPU;
TensorOperations *ops_ = nullptr;
};
TensorOperations 抽象了所有的 Tensor 操作,比如加减乘除、维度变换等等,TensorImpl 在创建的时候,根据 device_
类型创建对应的 CPU 或 CUDA 版本的 ops_
,运行时再调用 ops_
对应的操作,即实现了算子的转发,举个最基本的加法例子:
TensorImpl TensorImpl::operator+(const float &other) const {
TENSOR_CHECK_EMPTY_RET(*this, {});
return ops_->add(*this, other);
}
然后是 CUDA 版本的算子实现:
TensorImpl TensorOpsCUDA::add(const TensorImpl& a, const float& b) {
return opPair<OpCudaAdd>(a, b);
}
进一步 opPair<OpCudaAdd>()
的实现中会调用 CUDA Kernel 函数:
template <typename OP>
TensorImpl TensorOpsCUDA::opPair(const TensorImpl& a, float b) const {
auto result = TensorImpl::shape(a.shape(), a.device_);
kPairScalarSecondOp<OP><<<getGridSize(a.elemCount_), getBlockSize()>>>(
result.data_, a.data_, b, a.elemCount_);
CUDA_KERNEL_CHECK();
return result;
}
Kernel 的实现如下:
struct OpCudaAdd {
__device__ float operator()(const float a, const float b) const {
return a + b;
}
};
template <typename OP>
__global__ void kPairScalarSecondOp(float* c, const float* a, const float b,
const int32_t n) {
const auto index = blockIdx.x * blockDim.x + threadIdx.x;
const OP opFunc;
if (index < n) {
c[index] = opFunc(a[index], b);
}
}
由于只是学习项目,这里还是和之前一样只支持 float 数据类型,一些特殊算子如 argmin、argmax 会在内部进行类型转换。
性能
目前共实现 CUDA Kernel 约 40个,主要包括 element-wise、reduce、index 等(矩阵乘法直接用了 cublas 的 gemm),大多只做了基础的优化,以及只实现了 Tensor 级别的 Kernel,还没有针对算子、算子融合做优化,也没有用到如 stream 并行、cuda graphs 等技术,优化的水很深,同时也需要结合实际需求来做。
内存分配上实现了个简单的缓存池,首先所有的分配 size 都 padding 到 512 的倍数,提高缓存复用概率:
#define ALLOC_ROUND(x) (((x) + 511) & ~511) // 512 bytes
void CachedAllocator::allocate(void** ptr, size_t size) {
size = ALLOC_ROUND(size);
...
}
然后维护两个列表:allocatedList_
和 freedList_
,分配的时候先尝试从 freedList_
中找,释放的时候不执行真正的释放,而是添加到 freedList_
,详见 Allocator.cpp 中的实现,这样下来如果缓存上限比实际内存峰值更高,就只会在第一个 batch 时真正调用到系统的 malloc,当然更进一步的优化是提前分配更大块的内存,从而减少系统 malloc 的次数,以及更精细的管理策略。
简单测试了下 MNIST 的训练耗时,在 batchSize 等训练参数都一致的情况下,CUDA 版本目前和 PyTorch 的耗时差异不大,在我的 T4 开发机上每个 Epoch 6s 左右,当然目前的优化大多都只针对 demo 的 MNIST 训练,项目本身也只是为了学习一下 CUDA 相关的东西。