Monday, May 4, 2026

Series on distributed AI across multiple GPUs:

Introduction

In my previous post, I explained how distributed data parallelism (DDP) speeds up training by splitting batches across GPUs. Although DDP solves the throughput problem, it introduces a new challenge: memory redundancy.

In vanilla DDP, every GPU maintains a complete copy of the model parameters, gradients, and optimizer state. For large models like GPT-3 (175B parameters), this redundancy is a massive waste of valuable VRAM.

Image by author: in regular DDP, model parameters, gradients, and optimizer state are replicated on every GPU

ZeRO (Zero Redundancy Optimizer) solves this. There are three stages:

  • ZeRO-1: partitions the optimizer state only
  • ZeRO-2: partitions the optimizer state + gradients
  • ZeRO-3: partitions the optimizer state + gradients + model parameters

ZeRO is not a parallelism technique, because all GPUs still perform the same forward and backward passes. It is a memory optimization that eliminates redundancy between GPUs and lets you train larger models on the same hardware.

The DDP memory problem

Let’s analyze what actually consumes memory during training. For a model, the main consumers are:

  • Model parameters: one value per weight of the neural network
  • Gradients: one value per parameter
  • Optimizer state (Adam): two values per parameter (the first and second moments)
  • Activations: intermediate outputs saved during the forward pass for use in the backward pass

The first three scale with model size and are redundant across GPUs in DDP. Activations scale with batch size, sequence length, and layer width, and are unique per GPU because each GPU processes different data. ZeRO does not touch activation memory.

Let’s calculate the memory usage of a 7B-parameter model using Adam and FP32:

  • Parameters: 7 billion * 4 bytes = 28 GB
  • Gradients: 7 billion * 4 bytes = 28 GB
  • Optimizer state: 7 billion * 2 * 4 bytes = 56 GB
  • Memory per GPU in DDP: 112 GB
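As a quick sanity check, these numbers can be reproduced in a few lines (a sketch of my own, using GB = 1e9 bytes to match the figures above):

```python
# Per-GPU training memory for plain DDP in FP32 with Adam
# (activations excluded, GB = 1e9 bytes).
def ddp_memory_gb(num_params, bytes_per_value=4):
    params = num_params * bytes_per_value      # weights
    grads = num_params * bytes_per_value       # one gradient per weight
    optim = num_params * 2 * bytes_per_value   # Adam: first + second moments
    total = params + grads + optim
    return {k: v / 1e9 for k, v in
            {"params": params, "grads": grads,
             "optim": optim, "total": total}.items()}

print(ddp_memory_gb(7e9))
# → {'params': 28.0, 'grads': 28.0, 'optim': 56.0, 'total': 112.0}
```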

Activations add significant memory on top of this, but since that memory is unique per GPU, ZeRO cannot partition it. Techniques like activation checkpointing discard some activations and recompute them as needed during the backward pass, but that is outside the scope of this article.

To understand how ZeRO works, let’s implement it from scratch, starting with ZeRO-1 and ending with ZeRO-3.

ZeRO-1: Optimizer state partitioning

In ZeRO-1, the optimizer state is partitioned. Each GPU:

  • still holds the full model parameters and gradients
  • stores only 1/N of the optimizer state (N = number of GPUs)
  • updates only its corresponding 1/N of the parameters

These are the steps performed during training:

  1. Forward pass: each GPU processes its own micro-batch
  2. Backward pass: compute gradients
  3. All-reduce gradients: every GPU ends up with all gradients
  4. Optimizer step: each GPU updates its parameter partition
  5. All-gather parameters: synchronize the updated model across GPUs
Image by author: ZeRO-1 animation

Here is a simplified implementation:

import torch
import torch.distributed as dist


class ZeRO_1:
    def __init__(self, model, optimizer_cls):
        self.model = model
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()

        self.param_shards = []  # each rank holds only its shard of the optimizer states
        self.param_metadata = []  # metadata to reconstruct full tensors from shards

        for param in self.model.parameters():
            original_shape = param.data.shape
            flat = param.data.view(-1)
            numel = flat.numel()

            # pad so the flattened parameter splits evenly across ranks
            remainder = numel % self.world_size
            pad_size = (self.world_size - remainder) % self.world_size
            padded_numel = numel + pad_size
            shard_size = padded_numel // self.world_size

            shard_start = self.rank * shard_size
            shard_end = shard_start + shard_size

            self.param_metadata.append(
                {
                    "original_shape": original_shape,
                    "numel": numel,
                    "padded_numel": padded_numel,
                    "shard_size": shard_size,
                    "shard_start": shard_start,
                    "shard_end": shard_end,
                }
            )

            if pad_size > 0:
                flat_padded = torch.cat([flat, flat.new_zeros(pad_size)])
            else:
                flat_padded = flat

            shard = flat_padded[shard_start:shard_end].clone()
            shard.requires_grad_(True)
            self.param_shards.append(shard)

        self.optimizer = optimizer_cls(self.param_shards)

    def training_step(self, inputs, targets, loss_fn):
        output = self.model(inputs)  # forward
        loss = loss_fn(output, targets)  # compute loss
        loss.backward()  # backward

        self._sync_gradients()  # all-reduce gradients across GPUs
        self.optimizer.step()  # update local shard of parameters
        self._sync_params()  # all-gather model params

        # clear gradients for the next step
        for param in self.model.parameters():
            param.grad = None

    def _sync_gradients(self):
        for idx, param in enumerate(self.model.parameters()):
            meta = self.param_metadata[idx]

            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= self.world_size

            grad_flat = param.grad.view(-1)
            pad = meta["padded_numel"] - meta["numel"]
            if pad > 0:
                # pad so the last rank's slice matches its shard size
                grad_flat = torch.cat([grad_flat, grad_flat.new_zeros(pad)])
            self.param_shards[idx].grad = grad_flat[meta["shard_start"]:meta["shard_end"]]

    def _sync_params(self):
        for idx, param in enumerate(self.model.parameters()):
            meta = self.param_metadata[idx]

            full_flat = torch.empty(meta["padded_numel"], device=param.device, dtype=param.dtype)
            dist.all_gather_into_tensor(
                output_tensor=full_flat,
                input_tensor=self.param_shards[idx].data,
            )

            reconstructed = full_flat[:meta["numel"]].view(meta["original_shape"])
            param.data.copy_(reconstructed)

Note that the all-reduce synchronizes all gradients, yet each GPU only uses the gradients for its own parameter partition, leading to excessive communication. ZeRO-2 fixes this problem by sharding the gradients as well.

In practice, you’ll never use ZeRO-1: ZeRO-2 offers better memory savings at virtually the same cost. Still, it’s worth covering for learning purposes.

ZeRO-1, 7B model, memory with 8 GPUs:

  • Parameters: 28 GB (fully replicated)
  • Gradients: 28 GB (fully replicated)
  • Optimizer state: 56 GB / 8 = 7 GB
  • Total per GPU: 63 GB (down from 112 GB)

ZeRO-2: Gradient Partitioning

ZeRO-2 partitions the optimizer state and the gradients. Since each GPU only updates a partition of the parameters, it only needs the corresponding gradients.

ZeRO-1 uses all-reduce, which gives all gradients to all GPUs. ZeRO-2 replaces this with reduce-scatter: each GPU receives only the gradients it actually needs. This saves both memory and communication bandwidth.

Training steps:

  1. Forward pass: each GPU processes its own micro-batch
  2. Backward pass: compute gradients
  3. Reduce-scatter gradients: each GPU gets only its partition
  4. Optimizer step: each GPU updates its parameter partition
  5. All-gather parameters: synchronize the updated model across GPUs
Image by author: ZeRO-2 animation

The implementation is almost identical to ZeRO-1; the gradient synchronization step simply uses reduce-scatter instead of all-reduce.

But wait: if all GPUs compute all gradients during backprop, how is VRAM actually saved? Here’s how:

  • Gradients are computed layer by layer, so each one can be reduce-scattered immediately and its full local copy freed (the simplified implementation doesn’t do this).
  • During backprop, all we need to compute the current layer’s parameter gradients is the gradient of the next layer’s activations. That is, you don’t need to keep the full gradient graph.
  • This lets each GPU hold only its assigned gradient partitions, freeing full-gradient memory as backprop proceeds.

ZeRO-2, 7B model, memory with 8 GPUs:

  • Parameters: 28 GB (fully replicated)
  • Gradients: 28 GB / 8 = 3.5 GB
  • Optimizer state: 56 GB / 8 = 7 GB
  • Total per GPU: 38.5 GB (down from 112 GB)

ZeRO-3: Parameter Partitioning

ZeRO-3 partitions optimizer states, gradients, and parameters. Each GPU stores only 1/N of the full model’s state.

During the forward and backward passes, each layer needs its full parameters, but each GPU stores only a portion. So we gather the full parameters just in time, use them, and then discard them immediately.

Training steps:

  • Forward pass:
    • Gather all of the layer’s parameters from all GPUs
    • Run the layer’s forward pass, using the previous layer’s activations as input
    • Discard the gathered parameters (keep only the local partition)
    • Repeat until all layers are done
  • Backward pass (layer by layer, in reverse):
    • Gather all of the layer’s parameters again
    • Compute the current layer’s gradients using the activation gradients from the next layer
    • Reduce-scatter the gradients (each GPU keeps its shard)
    • Discard the gathered parameters (keep only the local partition)
    • Repeat until all layers are done
  • Each GPU runs an optimizer step on its partition
  • Parameters are gathered layer by layer during the forward pass, so a final all-gather is not required
Image by author: ZeRO-3 animation

Here is a simplified implementation:

class ZeRO_3(ZeRO_2):
    """
    ZeRO-3: Shard optimizer states (stage 1) + gradients (stage 2) + model parameters (stage 3).

    At rest, each rank holds only param_shards[idx], a 1/world_size slice
    of each parameter. Full parameters are materialized temporarily during
    the forward and backward passes via all_gather, then immediately freed.
    """

    def __init__(self, model, optimizer_cls):
        self.model = model
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()

        self.param_metadata = []
        shard_list = []

        self._param_to_idx = {}

        for idx, param in enumerate(self.model.parameters()):
            original_shape = param.data.shape
            flat = param.data.view(-1)
            numel = flat.numel()

            remainder = numel % self.world_size
            pad_size = (self.world_size - remainder) % self.world_size
            padded_numel = numel + pad_size
            shard_size = padded_numel // self.world_size

            shard_start = self.rank * shard_size
            shard_end = shard_start + shard_size

            self.param_metadata.append(
                {
                    "original_shape": original_shape,
                    "numel": numel,
                    "padded_numel": padded_numel,
                    "shard_size": shard_size,
                    "shard_start": shard_start,
                    "shard_end": shard_end,
                }
            )

            if pad_size > 0:
                flat_padded = torch.cat([flat, flat.new_zeros(pad_size)])
            else:
                flat_padded = flat

            shard = flat_padded[shard_start:shard_end].clone()
            shard_list.append(shard)

            # Replace the full tensor with only this rank's shard.
            # The model's param.data now points to a tiny slice; the full
            # weight will be reconstructed on demand during forward/backward.
            param.data = shard.detach()
            self._param_to_idx[param] = idx

        self.param_shards = [s.requires_grad_(True) for s in shard_list]
        self.optimizer = optimizer_cls(self.param_shards)

        self._register_hooks()

    def _gather_param(self, idx, device, dtype):
        """All-gather the full parameter tensor for parameter `idx`."""
        meta = self.param_metadata[idx]
        full_flat = torch.empty(meta["padded_numel"], device=device, dtype=dtype)
        dist.all_gather_into_tensor(
            output_tensor=full_flat,
            input_tensor=self.param_shards[idx].data,
        )
        return full_flat[: meta["numel"]].view(meta["original_shape"])

    def _gather_module_params(self, module):
        """Gather full params for every parameter that belongs directly to this module (not its children)."""
        for param in module.parameters(recurse=False):
            idx = self._param_to_idx[param]
            param.data = self._gather_param(idx, param.device, param.dtype)

    def _reshard_module_params(self, module):
        """Reshard params back to the local shard for every direct param of this module."""
        for param in module.parameters(recurse=False):
            idx = self._param_to_idx[param]
            param.data = self.param_shards[idx].data

    def _register_hooks(self):
        self._hooks = []
        for module in self.model.modules():
            # Skip container modules that have no direct parameters
            if not list(module.parameters(recurse=False)):
                continue

            # Forward: gather -> run -> reshard
            h1 = module.register_forward_pre_hook(
                lambda mod, _inputs: self._gather_module_params(mod)
            )
            h2 = module.register_forward_hook(
                lambda mod, _inputs, _output: self._reshard_module_params(mod)
            )

            # Backward: gather before grad computation -> reshard after
            h3 = module.register_full_backward_pre_hook(
                lambda mod, _grad_output: self._gather_module_params(mod)
            )
            h4 = module.register_full_backward_hook(
                lambda mod, _grad_input, _grad_output: self._reshard_module_params(mod)
            )

            self._hooks.extend([h1, h2, h3, h4])

    def training_step(self, inputs, targets, loss_fn):
        # Hooks handle all gather/reshard around each module automatically
        output = self.model(inputs)
        loss = loss_fn(output, targets)
        loss.backward()

        self._sync_gradients()

        # Each rank updates only its local shard
        self.optimizer.step()

        for param in self.model.parameters():
            param.grad = None

Each layer’s parameters are gathered just before they are needed and released immediately after use. This increases communication but minimizes peak memory. In practice, to hide the latency, real implementations overlap the all-gather for layer N+1 with the forward pass of layer N.
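That overlap can be sketched roughly as follows (purely illustrative: forward_with_prefetch and gather_params_async are hypothetical names of mine, and real FSDP-style implementations do this with hooks and CUDA streams rather than an explicit loop):

```python
def forward_with_prefetch(layers, x, gather_params_async):
    """gather_params_async(layer) starts an async all-gather for that layer's
    parameters and returns a handle whose .wait() blocks until they arrive."""
    handle = gather_params_async(layers[0])
    for i, layer in enumerate(layers):
        handle.wait()  # params for layer i are now resident
        if i + 1 < len(layers):
            # kick off the all-gather for layer i+1 ...
            handle = gather_params_async(layers[i + 1])
        x = layer(x)   # ... while layer i computes its forward pass
    return x
```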

ZeRO-3, 7B model, memory with 8 GPUs:

  • Parameters: 28 GB / 8 = 3.5 GB
  • Gradients: 28 GB / 8 = 3.5 GB
  • Optimizer state: 56 GB / 8 = 7 GB
  • Total per GPU: 14 GB (down from 112 GB)

That’s an 8x reduction, exactly as expected from partitioning across 8 GPUs.
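Putting the three stages side by side, the per-GPU totals above all follow from dividing each sharded term by the number of GPUs (an illustrative helper of my own, not code from the post):

```python
# Which memory terms get divided by the GPU count at each ZeRO stage
# (FP32 + Adam, GB = 1e9 bytes, activations excluded).
def zero_memory_gb(num_params, n_gpus, stage, bytes_per_value=4):
    params = num_params * bytes_per_value / 1e9      # model weights
    grads = num_params * bytes_per_value / 1e9       # gradients
    optim = num_params * 2 * bytes_per_value / 1e9   # Adam moments
    if stage >= 1:
        optim /= n_gpus   # ZeRO-1: shard optimizer state
    if stage >= 2:
        grads /= n_gpus   # ZeRO-2: also shard gradients
    if stage >= 3:
        params /= n_gpus  # ZeRO-3: also shard parameters
    return params + grads + optim

for stage in range(4):
    print(f"stage {stage}: {zero_memory_gb(7e9, 8, stage)} GB per GPU")
# → prints 112.0, 63.0, 38.5 and 14.0 GB for stages 0 through 3
```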

Utilizing ZeRO with PyTorch

PyTorch ships with two ZeRO-3 implementations: FSDP1 (older, less optimized) and FSDP2 (newer, recommended). Always use FSDP2.

FSDP (Fully Sharded Data Parallel) automatically handles parameter gathering, gradient scattering, communication overlap, and memory management.

from torch.distributed.fsdp import fully_shard

model = Transformer()
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)

You apply fully_shard to each layer and then wrap the entire model.

Conclusion

ZeRO trades memory for communication, so it isn’t a free lunch. It’s generally not worth it for smaller models (like BERT), but it’s a game changer for large models.

Congratulations on making it to the end! In this post, you learned:

  • The memory redundancy problem of standard DDP
  • How ZeRO partitions optimizer states, gradients, and parameters across GPUs
  • The three stages of ZeRO and their memory/communication tradeoffs
  • How to use ZeRO-3 via FSDP in PyTorch

The next article covers tensor parallelism, a model parallelism technique that speeds up layer computations by distributing work across GPUs.

References

  1. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models (the original paper)
  2. PyTorch FSDP tutorial
  3. FSDP API reference
  4. The Ultra-Scale Playbook (Hugging Face)