最近看到一个写的挺好的多任务框架(https://github.com/SwinTransformer/AiT,参考了detectron2),分享一下~
多任务训练一般分为两种:data mixing 和 batch mixing。简单来说,对于前者,一个batch中的样本可以来自不同任务,而后者一个batch中任务都是一样的。两者相比,后者实现更加容易,效率更高,并且做数据增强也更方便一点(由Pix2Seq提出,但其实数据增强我认为并不是两者的主要差异)。
先给出我写的一个batch mixing的例子(来自JJJYmmm/Pix2SeqV2-Pytorch):
def get_multi_task_loaders(tokenizer,tasks):
assert set(tasks) <= set(['detection', 'keypoint', 'segmentation', 'captioning'])
train_loaders = {}
valid_loaders = {}
if 'detection' in tasks:
detection_train_loader, detection_valid_loader = detection_loaders(
CFG.dir_root, tokenizer, CFG.img_size, CFG.batch_size, CFG.max_len, tokenizer.PAD_code)
train_loaders['detection'] = detection_train_loader
valid_loaders['detection'] = detection_valid_loader
if 'keypoint' in tasks:
keypoint_train_loader, keypoint_valid_loader = keypoint_loaders(
CFG.dir_root, tokenizer,person_kps_info, CFG.img_size, CFG.batch_size, CFG.max_len, tokenizer.PAD_code)
train_loaders['keypoint'] = keypoint_train_loader
valid_loaders['keypoint'] = keypoint_valid_loader
if 'segmentation' in tasks:
segmentation_train_loader, segmentation_valid_loader = segmentation_loaders(
CFG.dir_root, tokenizer,person_kps_info, CFG.img_size, CFG.batch_size, CFG.max_len, tokenizer.PAD_code)
train_loaders['segmentation'] = segmentation_train_loader
valid_loaders['segmentation'] = segmentation_valid_loader
if 'captioning' in tasks:
img_caption_train_loader, img_caption_valid_loader = img_caption_loaders(
CFG.dir_root, tokenizer,vocab, CFG.img_size, CFG.batch_size, CFG.max_len, tokenizer.PAD_code)
train_loaders['captioning'] = img_caption_train_loader
valid_loaders['captioning'] = img_caption_valid_loader
return train_loaders, valid_loaders
以上代码首先维护一个多任务的dataloader字典,在这里之前(即创建dataset)就可以针对不同任务做对应的数据增强。
# get longest dataloader
epoch_size = 0
longest_loader = None
for name, loader in train_loaders.items():
if len(loader) > epoch_size:
epoch_size = len(loader)
longest_loader = name
# create iter for dataloaders
loader_iters = dict()
for k, v in train_loaders.items():
if k != longest_loader:
loader_iters[k] = iter(v)
# iter longest dataloader
tqdm_object = tqdm(train_loaders[longest_loader], total=len(train_loaders[longest_loader]))
for iteration,(x, y, init_lens) in enumerate(tqdm_object):
optimizer.zero_grad()
total_loss = torch.zeros(1, requires_grad=False, device=CFG.device)
total_batch = x.size(0)
# loss_1
loss = cal_loss_multi_task(model, criterion, x, y, init_lens, task_id = task_ids[longest_loader])
total_loss = total_loss + loss.item() * task_weights[longest_loader]
loss *= task_weights[longest_loader] # mul weight
loss.backward()
# calculate other tasks' loss
for k, v in loader_iters.items():
try:
(x, y, init_lens) = next(v)
except StopIteration: # recover other tasks' iter
loader_iters[k] = iter(train_loaders[k])
(x, y, init_lens) = next(loader_iters[k])
total_batch += x.size(0)
# loss_i
loss = cal_loss_multi_task(model, criterion, x, y, init_lens, task_id=task_ids[k])
total_loss = total_loss + loss.item() * task_weights[k]
loss *= task_weights[k]
loss.backward()
# total_loss.backward()
optimizer.step()
训练时,首先确定batch数最多的任务(数据集A),把它作为训练的最外层,这一部分和单任务训练一致,对于其他任务,则分别创建一个迭代器iterator负责取数据。之后在循环数据集A的时候,每次算出loss_A后,会从其他迭代器中取出对应的数据并计算loss_B/loss_C...,之后根据任务权重对loss进行加权平均,并进行反向传播。
在上述代码中,为了节省显存,每次计算完loss,我都直接乘上权重反向传播了,这个好处在于,每个loss计算完后对应的计算图会被自动释放,如果显式显出加权平均,那么所有任务的计算图都会被保留~
可以看出,batch mixing其实只是在训练过程中加入多个数据集的batch,然后分别算出loss并反向传播罢了。(data mixing也差不多,只是粒度更细一点)
关于data mixing,其实就比batch mixing多了一步操作,就是把所有的数据集拼起来,然后对于每个样本都添加一个字段表示任务。取数据的时候直接从大数据集里面取就可以了。
class ConcatDataset(Dataset[T_co]):
r"""Dataset as a concatenation of multiple datasets.
This class is useful to assemble different existing datasets.
Args:
datasets (sequence): List of datasets to be concatenated
"""
datasets: List[Dataset[T_co]]
cumulative_sizes: List[int]
@staticmethod
def cumsum(sequence):
r, s = [], 0
for e in sequence:
l = len(e)
r.append(l + s)
s += l
return r
def __init__(self, datasets: Iterable[Dataset]) -> None:
super(ConcatDataset, self).__init__()
self.datasets = list(datasets)
assert len(self.datasets) > 0, 'datasets should not be an empty iterable' # type: ignore[arg-type]
for d in self.datasets:
assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
self.cumulative_sizes = self.cumsum(self.datasets)
def __len__(self):
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
if idx < 0:
if -idx > len(self):
raise ValueError("absolute value of index should not exceed dataset length")
idx = len(self) + idx
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx][sample_idx]
@property
def cummulative_sizes(self):
warnings.warn("cummulative_sizes attribute is renamed to "
"cumulative_sizes", DeprecationWarning, stacklevel=2)
return self.cumulative_sizes
到这里也就知道,data mixing的数据增强也很方便,在拼接数据集之前,各个数据集定义自己的增强方式即可。但是data mixing的最大问题在于:训练过程中,需要循环batch里的每个样本,根据任务的不同分配给不同的Heads处理loss,并行度其实很差。当然,对于大模型训练来说,有时候单张GPU就放一个样本,那这个劣势就相当于没有了~
最后总结就是,data mixing相比于batch mixing,任务粒度更细,并且对于多任务的支持更好(不像batch mixing每加一个任务就要改train代码,data mixing只需要写好数据集和对应处理的Head即可),但是并行度相比于batch mixing较差。