Pytorch 多任务训练框架

深度学习 · 2023-11-15 · 487 人浏览

最近看到一个写的挺好的多任务框架(https://github.com/SwinTransformer/AiT,参考了detectron2),分享一下~

多任务训练一般分为两种:data mixing 和 batch mixing。简单来说,对于前者,一个batch中的样本可以来自不同任务,而后者一个batch中任务都是一样的。两者相比,后者实现更加容易,效率更高,并且做数据增强也更方便一点(由Pix2Seq提出,但其实数据增强我认为并不是两者的主要差异)。

先给出我写的一个batch mixing的例子(来自JJJYmmm/Pix2SeqV2-Pytorch):

def get_multi_task_loaders(tokenizer,tasks):

    assert set(tasks) <= set(['detection', 'keypoint', 'segmentation', 'captioning'])

    train_loaders = {}
    valid_loaders = {}

    if 'detection' in tasks:
        detection_train_loader, detection_valid_loader = detection_loaders(
        CFG.dir_root, tokenizer, CFG.img_size, CFG.batch_size, CFG.max_len, tokenizer.PAD_code)
        train_loaders['detection'] = detection_train_loader
        valid_loaders['detection'] = detection_valid_loader

    if 'keypoint' in tasks: 
        keypoint_train_loader, keypoint_valid_loader = keypoint_loaders(
        CFG.dir_root, tokenizer,person_kps_info, CFG.img_size, CFG.batch_size, CFG.max_len, tokenizer.PAD_code)
        train_loaders['keypoint'] = keypoint_train_loader
        valid_loaders['keypoint'] = keypoint_valid_loader

    if 'segmentation' in tasks: 
        segmentation_train_loader, segmentation_valid_loader = segmentation_loaders(
        CFG.dir_root, tokenizer,person_kps_info, CFG.img_size, CFG.batch_size, CFG.max_len, tokenizer.PAD_code)
        train_loaders['segmentation'] = segmentation_train_loader
        valid_loaders['segmentation'] = segmentation_valid_loader

    if 'captioning' in tasks:
        img_caption_train_loader, img_caption_valid_loader = img_caption_loaders(
        CFG.dir_root, tokenizer,vocab, CFG.img_size, CFG.batch_size, CFG.max_len, tokenizer.PAD_code)
        train_loaders['captioning'] = img_caption_train_loader
        valid_loaders['captioning'] = img_caption_valid_loader

    return train_loaders, valid_loaders

以上代码首先维护一个多任务的dataloader字典,在这里之前(即创建dataset)就可以针对不同任务做对应的数据增强。

# get longest dataloader
epoch_size = 0
longest_loader = None
for name, loader in train_loaders.items():
    if len(loader) > epoch_size:
        epoch_size = len(loader)
        longest_loader = name

# create iter for dataloaders
loader_iters = dict()
for k, v in train_loaders.items():
    if k != longest_loader:
        loader_iters[k] = iter(v)

# iter longest dataloader
tqdm_object = tqdm(train_loaders[longest_loader], total=len(train_loaders[longest_loader]))
for iteration,(x, y, init_lens) in enumerate(tqdm_object):

    optimizer.zero_grad()

    total_loss = torch.zeros(1, requires_grad=False, device=CFG.device)
    total_batch = x.size(0)
    # loss_1
    loss = cal_loss_multi_task(model, criterion, x, y, init_lens, task_id = task_ids[longest_loader])
    total_loss = total_loss + loss.item() * task_weights[longest_loader]
    loss *= task_weights[longest_loader] # mul weight
    loss.backward()
    # calculate other tasks' loss
    for k, v in loader_iters.items():
        try:
            (x, y, init_lens) = next(v)
        except StopIteration: # recover other tasks' iter
            loader_iters[k] = iter(train_loaders[k])
            (x, y, init_lens) = next(loader_iters[k])
        total_batch += x.size(0)
        # loss_i
        loss = cal_loss_multi_task(model, criterion, x, y, init_lens, task_id=task_ids[k])
        total_loss = total_loss + loss.item() * task_weights[k]
        loss *= task_weights[k]
        loss.backward()

    # total_loss.backward()
    optimizer.step()

训练时,首先确定batch数最多的任务(数据集A),把它作为训练的最外层,这一部分和单任务训练一致,对于其他任务,则分别创建一个迭代器iterator负责取数据。之后在循环数据集A的时候,每次算出loss_A后,会从其他迭代器中取出对应的数据并计算loss_B/loss_C...,之后根据任务权重对loss进行加权平均,并进行反向传播。

在上述代码中,为了节省显存,每次计算完loss,我都直接乘上权重反向传播了,这个好处在于,每个loss计算完后对应的计算图会被自动释放,如果显式显出加权平均,那么所有任务的计算图都会被保留~

可以看出,batch mixing其实只是在训练过程中加入多个数据集的batch,然后分别算出loss并反向传播罢了。(data mixing也差不多,只是粒度更细一点)

关于data mixing,其实就比batch mixing多了一步操作,就是把所有的数据集拼起来,然后对于每个样本都添加一个字段表示任务。取数据的时候直接从大数据集里面取就可以了。

class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes

到这里也就知道,data mixing的数据增强也很方便,在拼接数据集之前,各个数据集定义自己的增强方式即可。但是data mixing的最大问题在于:训练过程中,需要循环batch里的每个样本,根据任务的不同分配给不同的Heads处理loss,并行度其实很差。当然,对于大模型训练来说,有时候单张GPU就放一个样本,那这个劣势就相当于没有了~

最后总结就是,data mixing相比于batch mixing,任务粒度更细,并且对于多任务的支持更好(不像batch mixing每加一个任务就要改train代码,data mixing只需要写好数据集和对应处理的Head即可),但是并行度相比于batch mixing较差。

Pytorch
Theme Jasmine by Kent Liao