这两天在看PyTorch DistributedDataParallel(DDP)相关文章,发现有个系列写的还不错。
- https://zhuanlan.zhihu.com/p/178402798
- https://zhuanlan.zhihu.com/p/187610959
- https://zhuanlan.zhihu.com/p/250471767
虽然讲的是torch.distributed.launch(快被torchrun替代),但是整个思路应该还是有参考意义的。
看的过程中遇到一些问题,顺便补几个知识点。
- contextmanager decorator
- SyncBN 这篇写的很细,推荐精读
补充SyncBN里的一个问题:2.1.5 eval部分,在torch 1.13版本里,只要满足eval模式或track_running_stats=True,就会使用统计量(
running_mean, running_var
)进行计算了。源码如下:# torch.nn.modules.batchnorm return F.batch_norm( input, # If buffers are not to be tracked, ensure that they won't be updated self.running_mean if not self.training or self.track_running_stats else None, self.running_var if not self.training or self.track_running_stats else None, self.weight, self.bias, bn_training, exponential_average_factor, self.eps,)