Pytorch 5 : Distributed Training & Parallelism

导言

  • 内存使用监控与优化
  • 多GPU使用
  • 分布式训练:了解如何使用PyTorch进行多GPU训练或分布式训练。
  • 并行策略实现

在深度学习中,内存格式(memory_format)决定了张量数据在内存中的存储顺序,直接影响计算效率和硬件兼容性。以下是不同内存格式的区别及转换的作用:


显存估计

假设

  • 参数数量用 $N$ 表示(例如 7B = $7{,}000{,}000{,}000$)。

  • 数据类型字节数:

    • bf16/16-bit:2 bytes(每个数占 2 字节)
    • fp32/32-bit:4 bytes(每个数占 4 字节)
  • 梯度(grad)通常以 fp32 存储/累积(即使参数是 bf16,梯度多为 fp32,除非特意用了fp16梯度优化)。

  • AdamW 经典实现为每个参数保留 两个状态张量exp_avg(一阶矩)和 exp_avg_sq(二阶矩),通常以 fp32 存储(各占 N 个元素)。有时也叫m和v。

  • 在 bf16 模型常见做法:在 GPU 上保留 bf16 的模型副本(用于前向/反向计算)+ 保留 fp32 的 master weights(用于更新)。因此会同时存在 bf16 参数与 fp32 主权重(master)。

  • 这里不计算激活(activation)/临时缓冲/optim buffers(如梯度缩放器小额开销)/cuBLAS workspace 等小项。

公式

设 $N$ 为参数总数(scalar 参数个数),则:

  • 参数占用(若模型主存为 bf16):

    $$
    P_{\text{bf16}} = N \times 2\ \text{bytes}
    $$

  • 参数占用(若模型为 fp32):

    $$
    P_{\text{fp32}} = N \times 4\ \text{bytes}
    $$

  • 梯度(通常 fp32):

    $$
    G = N \times 4\ \text{bytes}
    $$

  • AdamW 状态(exp_avg + exp_avg_sq,fp32):

    $$
    A = 2 \times N \times 4 = N \times 8\ \text{bytes}
    $$

  • bf16 情况常见的额外项 — fp32 master weights(如果有):

    $$
    M = N \times 4\ \text{bytes}
    $$

合并常见场景:

  • 场景 A(典型 bf16 + fp32 master + fp32 grads + AdamW)

    $$
    \text{Total}{A} = P{\text{bf16}} + M + G + A
    = N(2 + 4 + 4 + 8)\ \text{bytes} = N\times 18\ \text{bytes}
    $$

    (等价:bf16 参数 2 字节 + master 4 + grad 4 + AdamW 两个状态共 8 = 18 bytes/param)

  • 场景 B(bf16,但没有 master,仍 fp32 grads + AdamW)(不常见,训练稳定性可能受影响):

    $$
    \text{Total}{B} = P{\text{bf16}} + G + A = N(2 + 4 + 8) = N\times 14\ \text{bytes}
    $$

  • 场景 C(全部 fp32:参数、梯度、AdamW)

    $$
    \text{Total}{C} = P{\text{fp32}} + G + A = N(4 + 4 + 8) = N\times 16\ \text{bytes}
    $$

小结:在常见实现下(bf16 + fp32 master + AdamW),约 18 bytes/param;全部 fp32 时约 16 bytes/param。如果去掉 master(不常见),bf16 情况 ~ 14 bytes/param

快速估算

  • 1B 等价于 0.931 GB
  • bf16 + fp32 master + AdamW(典型):大约 $18 \times N$ bytes,总 GB ≈ $18N / 1024^3$.
  • 全 fp32(典型):大约 $16 \times N$ bytes,总 GB ≈ $16N / 1024^3$.
  • bf16 无 master(不常见):大约 $14 \times N$ bytes,总 GB ≈ $14N / 1024^3$.

考虑并行策略

考虑TP和DP的脚本 https://gitcode.com/Ascend/MindSpeed-MM/blob/master/mindspeed_mm/tools/mem_analysis.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
************************************************************
driver mem = free + reserved

allocated = OS + CANN + Driver + GE + PTA
(Currently, Driver consume 3GB)

PTA = fragmentation + Multi-stream overhead + allocated

allocated = static_mem + active_mem + worksapce

In a model,
Optimizer: param_fp32, momentum, variance. All are FP32
Model: model_param. Often is bf16/fp16

In specific module(not precisely),
Linear: B * S * (C_in + C_out)
Conv: B * C_in * H_in * W_in + B * C_out * H_out * W_out
LayerNorm: B * S * H
Residual Connection: B * S * H

************************************************************

逐行解释(按脚本顺序)

  1. # optimizer
    m, v = fp32 * model_size, fp32 * model_size     # self.grad_data 
    fp32_param = fp32 * model_size
    
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228

    * `m` 与 `v`:Adam/AdamW 的两个状态向量(first moment `m`,second moment `v`),每个元素以 fp32 存储,所以大小 = `4 bytes * model_size`。脚本把它们列为 `m, v`。
    * `fp32_param`:如果把模型参数以 fp32 表示(即 master weights 的总大小),则大小为 `4 * model_size` bytes。
    这三项都表示**全模型维度**的字节数(还没除以任何并行切分因子)。

    2. `grad_data = fp32 * model_size / tp # megatron/core/distributed/param_and_grad_buffer.py:366`
    这里把**梯度缓冲**按 `tp` 切分:`grad_data` 表示当前设备需要存的梯度大小。原因是 **tensor parallel (TP)** 会把参数切分成 `tp` 份,所以每个 TP shard 只保存对应分片的梯度 → 所以除以 `tp`。
    注意脚本**没有**再除以 `dp`,因为在传统的 Data Parallel(没有 optimizer state sharding 的情况下),数据并行组内各副本最终都会拥有同一份“本地参数分片的梯度”,即梯度在 DP 维度上是\*\*被 replicate(通过 all-reduce 后每个 DP 副本都保存)\*\*的,所以不再继续除 `dp`。

    3. `main_params_shard = fp32_param / tp / dp # fp32_params`
    这是重点(`main_params_shard`):它表示**当前设备上存放的 fp32 主权重(master fp32 params) 大小**,脚本把总的 `fp32_param` 同时除以 `tp` 和 `dp`,含义如下:

    * `/ tp`:按 tensor-parallel 把参数切分,单个 TP 分片大小是 `1/tp`。
    * `/ dp`:把 master weights 也按 data-parallel 再切分(即主权重在 DP 维度上也被**分片**而不是 replicated)。这意味着脚本假设**optimizer/主权重做了跨 DP 的分片**(类似 ZeRO 的思想),所以每个 DP 副本不再保留全量的 master 权重,而只保留 `1/dp` 的份额。

    重要的区分(常见场景):

    * **没有 optimizer-state sharding(常规 Megatron)**:master weights 通常在 DP 维度是**replicated**的,只需要 `/tp`,**不应该**再 `/dp`。那时 `main_params_shard = fp32_param / tp`。
    * **启用了 optimizer-state sharding / ZeRO**:master/optimizer states 在 DP 维度会被分片,这时 `main_params_shard = fp32_param / tp / dp` 是正确的(每个物理 GPU 只保留一小块 master)。

    所以脚本里 `/dp` 的存在意味着它在估算时**把主权重也按 DP 分片**(即在做 ZeRO-like 的分布式优化器分片)。

    4. `optimizer = grad_data + main_params_shard + (m + v) / tp / dp`
    `optimizer` 总量是当前设备为了优化器需要常驻的内存和梯度:

    * `grad_data`:TP 分片的梯度(见上)
    * `main_params_shard`:当前设备的 fp32 master 权重分片(见上)
    * `(m + v) / tp / dp`:Adam 的两份状态 `m`、`v`(总和 `m+v`),脚本把它们也按 `tp` 和 `dp` 分片,表示 `m`、`v` 也被在 TP 和 DP 上分散(即 optimizer-state 被分片)。
    → 综上:这行计算的是 **每个设备上(静态/长期存在的)optimizer & grad 的峰值存储**(假设 optimizer 状态被分片到 DP)。

    再强调:如果没有做 DP 级别的 optimizer sharding,那么 `(m+v)` 只需 `/(tp)`,而不是再 `/dp`。


    ---

    核心点(对 `main_params_shard` 的直观理解)

    * `main_params_shard` = “**本卡要保存的 fp32 主权重(master weights)**”。
    * 是否要 `/dp`,取决于你是否在 DP 维度把主权重/optimizer states 做了 **分片(ZeRO)**:

    * **若没有 ZeRO**(常规 DP):主权重在 DP 上是**replicated**,应只 `/tp`,**不 `/dp`**。
    * **若有 ZeRO / optimizer-state sharding**:主权重在 DP 上也被切分,每张卡只保存 `1/dp` → 则应 `/tp/dp`。
    * 结论:脚本里写 `/tp / dp` 表示 “脚本假设 optimizer state(包括 master weights)在 DP 维度也被分片(节省每卡内存)”。

    ---

    ### 推理显存

    * 不需要 **optimizer**
    * 不需要 **grad**
    * 不需要 **fp32 master**
    * 只保留 **模型权重本身**

    常见精度下的参数显存

    | 精度 | bytes / param | 说明 |
    | ----------- | ------------- | ---------- |
    | fp32 | 4 | 几乎不用 |
    | bf16 / fp16 | 2 | 主流 |
    | int8 | 1 | 量化 |
    | int4 | 0.5 | GPTQ / AWQ |


    > **推理参数显存 ≈ N × 2 bytes(bf16/fp16)**

    👉 相比训练:**18 → 2 bytes / param(↓ 9×)**


    ---

    ### vLLM显存

    vLLM 的核心显存消耗分两块:

    #### (1) 模型参数(和普通推理一样)

    * bf16 / fp16
    * TP shard
    * **无 optimizer / grad**

    > 👉 这部分 **几乎不变**

    ---

    #### (2) KV Cache(重点‼️)

    vLLM **显存大头 = KV cache**,而不是参数。

    KV cache 结构:

    对 **每一层、每一个 token**:

    * Key: `[n_heads, head_dim]`
    * Value: `[n_heads, head_dim]`

    单 token KV cache 显存

    $$
    \text{KV/token} = 2 \times L \times H \times D \times \text{dtype\_bytes}
    $$

    其中:

    * 2:K + V
    * L:layer 数
    * H:heads
    * D:head_dim
    * dtype_bytes:通常 **2 bytes(fp16/bf16)**

    换成常见近似写法:

    $$
    \text{KV/token} \approx 2 \times L \times d_{\text{model}} \times 2
    = 4L d_{\text{model}} \text{ bytes}
    $$

    ---

    #### TP 对 KV cache 的影响(很关键)

    > **KV cache 也会按 TP shard!**

    * Attention heads 按 TP 切
    * 每张卡只保存 `H / TP` 的 KV

    $$
    \text{KV}*{\text{per GPU}} = \frac{\text{KV}*{\text{total}}}{\text{TP}}
    $$

    👉 **TP 对 KV cache 同样是线性下降**



    #### 每张 GPU 的显存组成

    $$
    \text{Total}*{\text{infer}} =
    \underbrace{\frac{N \times b_p}{TP}}*{\text{params}}
    +
    \underbrace{\frac{\text{KV}(B, S)}{TP}}_{\text{KV cache}}
    +
    \text{overhead}
    $$

    其中:

    * (b_p = 2) bytes(bf16)
    * (B):并发 request 数
    * (S):上下文长度
    * overhead ≈ 1–3 GB

    > **vLLM + TP 后,参数显存几乎“不是问题”,真正决定能跑多大 batch / 多长 context 的是 KV cache。**


    ## 内存优化

    ### 思路

    显存占用一般由:

    1. ​​参数显存​​: 模型权重,model init后并move2npu加载完;
    1. 问题是显存占用往往几倍大于 权重tensor size * dtype大小
    2. dtype一般是BF16 2字节
    2. 激活显存:forward时,传递的中间变量大小。
    1. forward时存在
    2. 一般使用重计算消除
    3. 梯度显存:用于反向的梯度值,一般认为和权重一样大(float32,4字节)。
    1. 分配时机:理应是backward时存在,但是megatron会在model init处提前声明buffer,
    2. 占比极大:FP32梯度一直是BF16参数显存的两倍,无论采取什么并行策略。
    3. 下降现象:megatron会先声明个大buffer,然后规整释放tensor出现下降现象
    4. 优化器(如AdaW每个参数占8字节,动量+方差):
    1. 组成:分成FP32参数副本和FP32的动量组成两部分;
    2. 分配时机:megatron会在梯度buffer声明后,forward前声明,FP32参数副本。
    3. 占比极大:两部分,每部分都有梯度显存那么大,但可以通过分布式优化器分配到每张卡上。
    4. 调用 optimizer.zero_grad(set_to_none=True) 会释放梯度显存,但​​优化器状态显存会保留​​直到优化器被销毁。

    ### **1. 内存格式类型及区别**

    #### (1) `contiguous_format`

    (默认连续格式,如 `NCHW`)**

    * **存储顺序**:数据按维度顺序连续存储。
    * 示例:对于形状为 `(N, C, H, W)` 的张量(批次、通道、高度、宽度),内存中按 `N→C→H→W` 顺序排列。
    * 内存布局:`NCHW` → `[N][C][H][W]`。
    * **适用场景**:
    * 大多数深度学习框架(如PyTorch)的默认格式。
    * 昇腾NPU、部分CPU场景下性能最佳。
    * **优点**:
    * 内存连续,访问效率高。
    * 兼容性强,支持所有硬件。

    #### (2) `channels_last`

    (通道最后格式,如 `NHWC`)

    * **存储顺序**:通道维度放在最后。
    * 示例:形状为 `(N, H, W, C)`,内存中按 `N→H→W→C` 顺序排列。
    * 内存布局:`NHWC` → `[N][H][W][C]`。
    * **适用场景**:
    * GPU上使用Tensor Core加速时(如混合精度训练)。
    * 某些卷积操作(如Depthwise Conv)在`NHWC`下效率更高。
    * **优点**:
    * 更适合并行计算,减少内存访问跳跃。
    * 在GPU上通常比`NCHW`快10%~30%。

    #### (3) `preserve_format`

    (保持原格式)

    * **作用**:保留输入张量的现有内存格式(不主动修改)。
    * 例如:若输入是`NHWC`,输出仍为`NHWC`;若输入是`NCHW`,输出仍为`NCHW`。
    * **适用场景**:
    * 当需要保持与输入一致的内存格式时(如模型中间层)。

    ---

    ### **2. 内存格式转换的作用**

    #### **(1) 优化计算效率**
    * **硬件适配**:
    * **GPU**:`NHWC`格式更适合利用Tensor Core加速(尤其是FP16计算)。
    * **NPU/ASIC**:可能仅支持`NCHW`(如昇腾NPU),需强制使用连续格式。
    * **示例**:
    ```python
    # 在GPU上使用NHWC加速卷积
    x = x.to(memory_format=torch.channels_last) # 转换为NHWC
    conv_output = conv(x) # 速度更快

(2) 减少内存碎片

  • 连续性保证
    • contiguous()to(memory_format=...) 可确保张量在内存中连续,避免因维度操作(如permute)导致内存不连续,从而减少计算错误或性能下降。

(3) 兼容性处理

  • 昇腾NPU限制
    • 昇腾NPU的torch_npu库目前仅支持contiguous_formatpreserve_format,强行使用channels_last会报错。
    • 需通过其他方式优化性能(如调整数据预处理或使用NPU专用算子)。

3. 昇腾NPU的特殊性

(1) 为何不支持 channels_last

  • 硬件设计差异
    • NPU的计算单元和内存控制器针对NCHW格式优化,强行使用NHWC可能导致内存访问冲突或计算效率下降。
  • 软件栈限制
    • 昇腾AI软件栈(CANN)可能未实现NHWC格式的底层算子支持。

(2) 替代优化方案

  1. 使用默认连续格式
    1
    x = x.permute(0,3,1,2).contiguous()  # 确保NCHW且内存连续
  2. 调整模型结构
    • 在数据输入前转换为NCHW,避免运行时转换。
  3. 使用NPU专用优化
    • 调用昇腾提供的高性能算子(如torch_npu.npu_format_cast)。

4. 代码示例对比

(1) GPU优化

(使用channels_last

1
2
3
4
# 在GPU上启用NHWC加速
x = torch.randn(1, 3, 224, 224).to("cuda").to(memory_format=torch.channels_last)
conv = nn.Conv2d(3, 64, kernel_size=3).to("cuda").to(memory_format=torch.channels_last)
output = conv(x) # 速度更快

(2) 昇腾NPU适配

(强制连续格式)

1
2
3
4
5
# 在NPU上使用默认NCHW格式
x = torch.randn(1, 3, 224, 224).npu() # 数据加载到NPU
x = x.permute(0,3,1,2).contiguous() # 确保连续内存
conv = nn.Conv2d(3, 64, kernel_size=3).npu()
output = conv(x) # 兼容NPU

总结

  • 区别contiguous_format(NCHW)是通用格式,channels_last(NHWC)适合GPU加速,preserve_format保持原格式。
  • 作用:转换内存格式可优化计算效率,但需适配硬件限制。
  • 昇腾NPU:强制使用NCHW,需通过contiguous()或调整数据预处理保证兼容性。

怎么使用GPU,怎么多GPU

在GPU上训练 就像你怎么把一个张量转移到GPU上一样,你要将神经网络转到GPU上。 如果CUDA可以用,让我们首先定义下我们的设备为第一个可见的cuda设备。

1
2
3
4
5
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Assume that we are on a CUDA machine, then this should print a CUDA device:

print(device) # cuda:0
1
2
3
net=Net()
net.to(device)
outputs = net(inputs)
1
2
3
input = torch.randn(1, 1, 32, 32)
inputs, labels = inputs.to(device), labels.to(device)
out = net(input)

多GPU

如果你想要来看到大规模加速,使用你的所有GPU,请查看:数据并行性(https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html)。PyTorch 60 分钟入门教程:数据并行处理

http://pytorchchina.com/2018/12/11/optional-data-parallelism/

Author

Shaojie Tan

Posted on

2023-06-13

Updated on

2026-02-03

Licensed under