# Source code for ssds.core.data_parallel

from torch.nn.parallel import DataParallel
import torch
from torch.nn.parallel._functions import Scatter
from torch.nn.parallel.parallel_apply import parallel_apply


def scatter(inputs, target_gpus, chunk_sizes, dim=0):
    r"""
    Slices tensors into chunks of the requested sizes and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.

    Args:
        inputs: a tensor, or a (possibly nested) tuple, list or dict whose
            leaves are tensors and/or arbitrary objects.
        target_gpus: devices to scatter onto; the result has one entry per
            device.
        chunk_sizes: per-device chunk sizes along ``dim``; passed through
            unchanged to ``Scatter.apply``.
        dim: dimension along which tensors are split.

    Returns:
        A per-device sequence. Tensors are split by ``Scatter.apply``;
        non-empty tuples, lists and dicts are rebuilt per device from their
        scattered members; every other object (including empty containers)
        is repeated once per target device.

    Raises:
        RuntimeError: if ``Scatter.apply`` fails for a tensor. The message
            carries the tensor size, ``dim``, ``chunk_sizes`` and the target
            devices, and the original error is chained as ``__cause__``.
    """

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            try:
                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
            except Exception as exc:
                # Surface the diagnostics through the exception instead of
                # printing them and terminating the interpreter, so callers
                # keep the traceback and can decide how to react.
                raise RuntimeError(
                    "failed to scatter tensor of size {} along dim {} "
                    "with chunk_sizes {} onto devices {}".format(
                        tuple(obj.size()), dim, chunk_sizes, target_gpus
                    )
                ) from exc
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        # Anything else is shared by reference: one entry per target device.
        return [obj for _ in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
    r"""Scatter positional and keyword arguments side by side.

    Each of ``inputs`` and ``kwargs`` is scattered with ``scatter`` when it is
    truthy and treated as an empty result otherwise. The shorter result is
    then padded (with ``()`` for positional arguments, ``{}`` for keyword
    arguments) so both have the same length.

    Returns:
        A pair ``(inputs, kwargs)`` of equally long tuples.
    """
    scattered_args = (
        scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
    )
    scattered_kwargs = (
        scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
    )

    # Positive: positional side is short; negative: keyword side is short.
    shortfall = len(scattered_kwargs) - len(scattered_args)
    if shortfall > 0:
        scattered_args.extend(() for _ in range(shortfall))
    elif shortfall < 0:
        scattered_kwargs.extend({} for _ in range(-shortfall))

    return tuple(scattered_args), tuple(scattered_kwargs)


class BalancedDataParallel(DataParallel):
    """
    This class is used to replace the original pytorch DataParallel and
    balance the first GPU memory usage.

    The original script is from:
    https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/utils/data_parallel.py

    Args:
        gpu0_bsz: number of samples (along ``self.dim``) sent to the first
            device in ``device_ids``. ``0`` means the first device receives
            no data when other devices are available.
        *args, **kwargs: forwarded unchanged to ``DataParallel.__init__``.
    """

    def __init__(self, gpu0_bsz, *args, **kwargs):
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)

    def forward(self, *inputs, **kwargs):
        """Run the wrapped module on the scattered inputs and gather outputs."""
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        # The first device can only be left out of the computation when there
        # is at least one other device to take the work.
        skip_gpu0 = self.gpu0_bsz == 0 and len(self.device_ids) > 1
        device_ids = self.device_ids[1:] if skip_gpu0 else self.device_ids

        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])

        # Replicas are created starting from the first device. When that
        # device is skipped its replica is discarded below, so one extra
        # replica is requested to keep len(replicas) == len(inputs).
        num_replicas = len(inputs) + 1 if skip_gpu0 else len(inputs)
        replicas = self.replicate(self.module, self.device_ids[:num_replicas])
        if skip_gpu0:
            replicas = replicas[1:]

        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        """Apply each replica to its input chunk on the matching device."""
        return parallel_apply(replicas, inputs, kwargs, device_ids[: len(inputs)])

    def scatter(self, inputs, kwargs, device_ids):
        """Split the batch so the first device receives ``gpu0_bsz`` samples.

        The remaining samples are divided as evenly as possible over the other
        devices. Falls back to ``DataParallel.scatter`` when balancing is not
        applicable: a single configured device, no leading tensor in
        ``inputs`` to read the batch size from, or ``gpu0_bsz`` not smaller
        than the per-device share of the other devices.
        """
        num_dev = len(self.device_ids)
        if num_dev < 2 or not inputs or not isinstance(inputs[0], torch.Tensor):
            # num_dev < 2 would otherwise divide by zero below.
            return super().scatter(inputs, kwargs, device_ids)

        bsz = inputs[0].size(self.dim)
        gpu0_bsz = self.gpu0_bsz
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
        if gpu0_bsz >= bsz_unit:
            return super().scatter(inputs, kwargs, device_ids)

        chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
        # Hand the remainder out one sample at a time, starting with the
        # second device so the first keeps exactly gpu0_bsz samples.
        delta = bsz - sum(chunk_sizes)
        for i in range(delta):
            chunk_sizes[i + 1] += 1
        if gpu0_bsz == 0:
            # The first device is not in device_ids in this case.
            chunk_sizes = chunk_sizes[1:]
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)