发布于 

triton-01-softmax

介绍

Triton是OpenAI开发的一种编程语言,帮助没有CUDA经验的开发者快速编写高性能GPU算子,实现加速。

作为上手的第一次尝试,本文以Softmax为例,体验一下triton的编写流程。Softmax是深度学习中一个基础算子,可用来将原始分数(Logits)转化为概率分布,常用于分类网络。

一维场景:$X,Y \in \mathbb{R}^{N}$

二维场景:$X,Y \in \mathbb{R}^{M \times N}$

这里,$x_i$是一个长度为$N$的向量。

本文使用triton来实现一个二维softmax算子,并与torch softmax对比精度,确认实现的正确性。

Softmax实现

说明:
实验环境:AutoDL 容器 RTX 2080 Ti(11GB)
PyTorch 2.5.1
Python 3.12(ubuntu22.04)
CUDA 12.4

安装triton:

pip install triton

triton softmax实现(参考openai triton官网实现):

softmax计算流程图

代码:

import triton
import triton.language as tl

# 对二维Tensor计算Softmax
@triton.jit
def triton_softmax(Y, stride_ym, stride_yn, X, stride_xm, stride_xn, M, N):
# 取某一行的索引
m = tl.program_id(0)

# 取某一列的索引
BLOCK_SIZE : tl.constexpr = 1024
n = tl.arange(0, BLOCK_SIZE)

# X取某一行的向量
X = X + m * stride_xm + n * stride_xn
# 加载行向量,如果超过N,就用inf填充
x = tl.load(X, mask=n < N, other=-float('inf'))

# 计算softmax
# 首先,x - max(x);减去最大值,提高数值稳定性
z = x - tl.max(x, axis=0)
num = tl.exp(z)
denom = tl.sum(num, axis=0)

y = num / denom # 分子/分母

# 写回到Y
Y = Y + m * stride_ym + n * stride_yn
tl.store(Y, y, mask= n < N)

# 调用Triton Softmax算子,与torch实现对比精度。
import torch
torch.manual_seed(42)
X = torch.normal(0, 1, size=(583, 931), device='cuda')
Y = torch.empty_like(X)

# 设置SPMD的网格
grid = (X.shape[0], )
triton_softmax[grid](Y, Y.stride(0), Y.stride(1),
X, X.stride(0), X.stride(1),
X.shape[0], X.shape[1])

print(f'Y_triton = {Y}')

Y_torch = torch.nn.Softmax(dim=1)(X)
print(f'Y_torch = {Y_torch}')

ok = torch.allclose(Y, Y_torch)
print(f'equal ? {ok}')

输出结果:

Y_triton = tensor([[7.4673e-04, 5.3404e-03, 5.1782e-04,  ..., 3.1986e-03, 2.1141e-04,
5.1145e-04],
[1.9864e-03, 3.3511e-05, 2.1382e-03, ..., 7.2760e-04, 4.7654e-04,
3.8883e-03],
[1.8407e-03, 2.7075e-04, 7.4302e-04, ..., 8.6954e-04, 2.7257e-03,
2.6177e-03],
...,
[2.9115e-03, 1.3667e-04, 1.6691e-03, ..., 1.6625e-03, 2.2443e-03,
2.6430e-03],
[5.0886e-04, 2.1393e-04, 4.7878e-04, ..., 1.7784e-03, 1.0859e-04,
5.6565e-04],
[1.3451e-04, 2.1544e-03, 1.6709e-03, ..., 4.5203e-04, 9.8369e-05,
7.7892e-04]], device='cuda:0')
Y_torch = tensor([[7.4673e-04, 5.3404e-03, 5.1782e-04, ..., 3.1986e-03, 2.1140e-04,
5.1145e-04],
[1.9864e-03, 3.3511e-05, 2.1382e-03, ..., 7.2760e-04, 4.7654e-04,
3.8883e-03],
[1.8407e-03, 2.7075e-04, 7.4302e-04, ..., 8.6954e-04, 2.7257e-03,
2.6177e-03],
...,
[2.9115e-03, 1.3667e-04, 1.6691e-03, ..., 1.6625e-03, 2.2443e-03,
2.6430e-03],
[5.0886e-04, 2.1393e-04, 4.7879e-04, ..., 1.7784e-03, 1.0859e-04,
5.6565e-04],
[1.3451e-04, 2.1544e-03, 1.6709e-03, ..., 4.5203e-04, 9.8369e-05,
7.7892e-04]], device='cuda:0')
equal ? True

验证了triton和torch softmax输出结果一致。

补充:softmax数值计算技巧

当输入x值的范围很大时,exp(x)的数值也很大,会上溢为nan;当所有的x值都接近-inf时,sum(exp(x))就接近0,作为分母容易导致除法不稳定。

这里可以通过数学上的变换小技巧,通过减去x的最大值,可使exp()的值小于1,另外由于$x{max} - x{max} = 0$,由于0的存在,exp(0)=1,不会使得所有分母都接近0,推导如下:

参考

Introducing Triton: Open-source GPU programming for neural networks