# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import torch
from typing import Any
from deepspeed.utils import logger
from deepspeed.utils.tensor_fragment import map_to_flat_opt_states
from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank, see_memory_usage
from deepspeed.runtime.torch_autocast import get_comm_dtype, is_autocast_initialized
from deepspeed.runtime.utils import maybe_loss_for_backward


class DeepSpeedOptimizer(object):
    pass


class BackwardHookStateManager:
    """Manages backward pass state for ZeRO optimizers.

    This class handles the complex state management needed for gradient accumulation hooks
    to work correctly with:

    1. **Reentrant Gradient Checkpointing** (use_reentrant=True):
       With reentrant checkpointing, gradient hooks fire in multiple phases within a
       single backward() call. For example, with model: linear1 (checkpointed) -> linear2:
         - Phase 1: Hooks for linear2 fire (non-checkpointed params)
         - Checkpoint recomputes linear1's forward
         - Phase 2: Hooks for linear1 fire (checkpointed params)
       The challenge is that `count_used_parameters_in_backward()` only sees params
       currently in the backward graph. During Phase 1, it returns 2 (linear2's params),
       but after checkpoint recomputation, it returns 4 (all params). We must NOT run
       the epilogue prematurely after Phase 1.
       Solution: Queue a post-backward callback on the autograd engine at the start of
       backward and run the epilogue when the graph task completes. This avoids premature
       epilogues across reentrant phases. The `_max_expected_hooks_seen` counter remains
       as a fallback when the callback API is unavailable.

    2. **TiledFusedLogitsLoss and Similar Custom Autograd Functions**:
       Some custom autograd functions call `torch.autograd.backward()` from their
       forward pass BEFORE the user calls `engine.backward(loss)`. These internal
       backward calls trigger ZeRO's gradient hooks, but we must NOT run the epilogue
       until the user's actual backward pass.
       Solution: Track `_backward_active_depth`, which is only incremented when
       `enter_backward()` is called (from engine.backward or user code). Hooks check
       this depth before running the epilogue.

    3. **Multiple Backward Phases with Exit/Re-entry**:
       When the epilogue runs after Phase 1 (with reentrant checkpointing), it calls
       `exit_backward()`, setting `_backward_active_depth` to 0. When Phase 2's hooks
       fire, we need to re-enter the backward context.
       Solution: The `_backward_seen_this_step` flag tracks whether backward was ever
       active this step. Combined with `_backward_active_depth == 0`, this detects
       Phase 2 and calls `enter_backward()` again.

    Attributes:
        remaining_grad_acc_hooks: Count of hooks remaining before the epilogue should run
        backward_active_depth: Nesting depth of backward() calls (0 = not in backward)
        backward_seen_this_step: True if enter_backward() was called this step
        epilogue_ran_this_backward: True if the epilogue ran (for micro_step_id management)
        hooks_fired_this_backward: Count of gradient hooks that have fired
        max_expected_hooks_seen: Maximum expected hook count seen (grows with reentrant checkpointing)
        post_backward_callback_queued: True if a post-backward callback is queued
        post_backward_callback_graph_task_id: Graph task id for the queued callback
    """

    def __init__(self):
        self.remaining_grad_acc_hooks = 0
        self._grad_acc_post_hooks = []
        self.backward_active_depth = 0
        self.backward_seen_this_step = False
        self.epilogue_ran_this_backward = False
        self.hooks_fired_this_backward = 0
        self.max_expected_hooks_seen = 0
        self.post_backward_callback_queued = False
        self.post_backward_callback_graph_task_id = None

    def register_grad_acc_post_hook(self, hook):
        """Register a callback to run when all gradient hooks have fired."""
        self._grad_acc_post_hooks.append(hook)

    def unregister_grad_acc_post_hooks(self):
        """Remove all registered gradient accumulation post hooks."""
        self._grad_acc_post_hooks = []

    def run_grad_acc_post_hooks(self):
        """Run all registered post hooks if backward is active.

        Custom autograd Functions (e.g., TiledFusedLogitsLoss) can invoke
        `torch.autograd.backward()` from their *forward* pass before the user
        ever calls `engine.backward(loss)`. Those early backward calls still
        trigger ZeRO's grad hooks, but we must not run the engine's
        post-backward logic (which reduces/clears grads) until the outer/user
        backward is active. The depth guard filters out only those pre-user
        invocations while still allowing backward calls that happen during
        the real user backward.
        """
        if self.backward_active_depth == 0:
            return
        for hook in self._grad_acc_post_hooks:
            hook()
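
    # A minimal illustrative sketch of the depth guard (not part of the library
    # API): hooks fired while the depth is zero are ignored; once enter_backward()
    # has been called they run.
    #
    #   state = BackwardHookStateManager()
    #   state.register_grad_acc_post_hook(lambda: print("post-backward"))
    #   state.run_grad_acc_post_hooks()   # no output: depth == 0 (pre-user backward)
    #   state.enter_backward()
    #   state.run_grad_acc_post_hooks()   # prints "post-backward"
    #   state.exit_backward()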

    def enter_backward(self):
        """Enter backward context. Call at the start of backward pass."""
        # On first real backward entry of a step, reset counters that may have been
        # polluted by pre-user-backward hooks (e.g. TiledFusedLogitsLoss calling
        # torch.autograd.backward() from forward). Do NOT reset on reentrant
        # phase re-entry (backward_seen_this_step == True) so phase-to-phase
        # state remains intact.
        if self.backward_active_depth == 0 and not self.backward_seen_this_step:
            self.hooks_fired_this_backward = 0
            self.max_expected_hooks_seen = 0
            self.remaining_grad_acc_hooks = 0
            self.post_backward_callback_queued = False
            self.post_backward_callback_graph_task_id = None
        self.backward_active_depth += 1
        # Track that backward has been active at some point in this step.
        # This is used to detect subsequent gradient hook phases with reentrant checkpointing.
        self.backward_seen_this_step = True

    def exit_backward(self):
        """Exit backward context. Call at the end of backward pass."""
        if self.backward_active_depth > 0:
            self.backward_active_depth -= 1

    def reset_for_new_step(self):
        """Reset state at the start of each forward/backward step."""
        self.backward_seen_this_step = False
        self.hooks_fired_this_backward = 0
        self.max_expected_hooks_seen = 0
        self.epilogue_ran_this_backward = False
        self.post_backward_callback_queued = False
        self.post_backward_callback_graph_task_id = None

    def should_refresh_expected_hook_count(self):
        """Return True when count_used_parameters_in_backward() should be re-evaluated.

        Refresh is needed in two cases:
        1. First hook of a backward (or backward phase): hooks_fired == 0.
        2. A new reentrant phase started: remaining hooks exhausted, we exited
           backward, but backward was active earlier this step.

        The predicate must be evaluated BEFORE reenter_backward_if_needed()
        because re-entering changes backward_active_depth and hides the
        phase-boundary signal.
        """
        return (self.hooks_fired_this_backward == 0
                or (self.remaining_grad_acc_hooks == 0 and self.backward_active_depth == 0
                    and self.backward_seen_this_step))

    def reenter_backward_if_needed(self):
        """Re-enter backward context for subsequent phases in reentrant checkpointing.

        With reentrant gradient checkpointing, gradient hooks can fire in multiple phases
        within a single backward call. When the epilogue runs after a phase, it calls
        exit_backward(), setting backward_active_depth to 0. When the next phase starts,
        we need to re-enter backward.

        We detect subsequent phases by checking:
        1. remaining_grad_acc_hooks == 0 (epilogue ran or new backward)
        2. backward_active_depth == 0 (we've exited from the previous phase)
        3. backward_seen_this_step == True (backward was active earlier)

        This distinguishes the case from TiledFusedLogitsLoss, which calls backward()
        during forward; there, backward_seen_this_step is False because enter_backward()
        was never called.
        """
        if self.remaining_grad_acc_hooks == 0:
            if self.backward_active_depth == 0 and self.backward_seen_this_step:
                self.enter_backward()

    def queue_post_backward_callback(self):
        """Queue post-backward hooks to run after the current graph task finishes."""
        if self.post_backward_callback_queued:
            return True
        if self.backward_active_depth == 0:
            return False
        engine = getattr(torch.autograd.Variable, "_execution_engine", None)
        if engine is None or not hasattr(engine, "queue_callback"):
            return False
        if not hasattr(torch._C, "_current_graph_task_id"):
            return False
        graph_task_id = torch._C._current_graph_task_id()
        if graph_task_id == -1:
            return False

        def _run_post_backward():
            self.run_grad_acc_post_hooks()

        engine.queue_callback(_run_post_backward)
        self.post_backward_callback_queued = True
        self.post_backward_callback_graph_task_id = graph_task_id
        return True
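
    # The method above relies on torch internals. A hedged sketch of the idea
    # (these are private APIs and may change between PyTorch releases): from
    # inside a gradient hook, i.e. while autograd is executing a graph task,
    #
    #   engine = torch.autograd.Variable._execution_engine
    #   engine.queue_callback(lambda: print("graph task finished"))
    #
    # defers the callback until the whole backward graph task, including regions
    # recomputed by reentrant checkpointing, has completed.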

    def update_hook_state_and_maybe_run_epilogue(self, current_expected_count):
        """Update hook state after a gradient hook fires and run the epilogue if all hooks have fired.

        With reentrant gradient checkpointing, count_used_parameters_in_backward() returns the
        count of params that will execute in the current backward graph. This count grows as
        checkpointed regions are recomputed. We track the MAXIMUM count seen to ensure we don't
        run the epilogue until all params that will ever participate have been processed.
        Counters are reset at forward() time via reset_for_new_step().

        Args:
            current_expected_count: The current expected number of hooks, from
                count_used_parameters_in_backward() plus any leaf modules.
        """
        self.hooks_fired_this_backward += 1
        self.max_expected_hooks_seen = max(self.max_expected_hooks_seen, current_expected_count)

        # Prefer running post-backward hooks via an autograd engine callback when available.
        # This avoids premature epilogues with reentrant checkpointing.
        if self.queue_post_backward_callback():
            self.remaining_grad_acc_hooks = max(self.max_expected_hooks_seen - self.hooks_fired_this_backward, 0)
            return

        # Fallback: run the epilogue only when we've processed ALL params that will participate.
        # This is the maximum count we've seen (accounts for late-joining params
        # from reentrant checkpointing) and also excludes unused params.
        if self.hooks_fired_this_backward >= self.max_expected_hooks_seen:
            self.remaining_grad_acc_hooks = 0
            self.run_grad_acc_post_hooks()
        else:
            self.remaining_grad_acc_hooks = self.max_expected_hooks_seen - self.hooks_fired_this_backward
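

# A minimal lifecycle sketch for BackwardHookStateManager (illustrative only; in
# DeepSpeed the engine and the ZeRO gradient hooks drive these calls):
#
#   state = BackwardHookStateManager()
#   state.reset_for_new_step()          # at forward() time
#   state.enter_backward()              # at the start of the user's backward
#   for _ in range(4):                  # one gradient hook per participating param
#       state.reenter_backward_if_needed()
#       state.update_hook_state_and_maybe_run_epilogue(current_expected_count=4)
#   state.exit_backward()               # the epilogue may already have run by now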


class ZeROOptimizer(DeepSpeedOptimizer):
    """Base class for ZeRO optimizer implementations (stages 1, 2, and 3)."""

    def __init__(self):
        self._backward_hook_state = BackwardHookStateManager()

    # Delegate backward hook state management to the manager.
    # These properties provide backward compatibility with code that accesses
    # these attributes directly (e.g., in stage3.py and stage_1_and_2.py).
    @property
    def _remaining_grad_acc_hooks(self):
        return self._backward_hook_state.remaining_grad_acc_hooks

    @_remaining_grad_acc_hooks.setter
    def _remaining_grad_acc_hooks(self, value):
        self._backward_hook_state.remaining_grad_acc_hooks = value

    @property
    def _backward_active_depth(self):
        return self._backward_hook_state.backward_active_depth

    @_backward_active_depth.setter
    def _backward_active_depth(self, value):
        self._backward_hook_state.backward_active_depth = value

    @property
    def _backward_seen_this_step(self):
        return self._backward_hook_state.backward_seen_this_step

    @_backward_seen_this_step.setter
    def _backward_seen_this_step(self, value):
        self._backward_hook_state.backward_seen_this_step = value

    @property
    def _epilogue_ran_this_backward(self):
        return self._backward_hook_state.epilogue_ran_this_backward

    @_epilogue_ran_this_backward.setter
    def _epilogue_ran_this_backward(self, value):
        self._backward_hook_state.epilogue_ran_this_backward = value

    @property
    def _hooks_fired_this_backward(self):
        return self._backward_hook_state.hooks_fired_this_backward

    @_hooks_fired_this_backward.setter
    def _hooks_fired_this_backward(self, value):
        self._backward_hook_state.hooks_fired_this_backward = value

    @property
    def _max_expected_hooks_seen(self):
        return self._backward_hook_state.max_expected_hooks_seen

    @_max_expected_hooks_seen.setter
    def _max_expected_hooks_seen(self, value):
        self._backward_hook_state.max_expected_hooks_seen = value

    @property
    def _grad_acc_post_hooks(self):
        return self._backward_hook_state._grad_acc_post_hooks

    @_grad_acc_post_hooks.setter
    def _grad_acc_post_hooks(self, value):
        self._backward_hook_state._grad_acc_post_hooks = value
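
    # Illustrative only: the delegating properties keep legacy attribute access
    # and the new state manager in sync. For example, assuming `opt` is a
    # concrete ZeRO optimizer instance (stage 1, 2, or 3):
    #
    #   opt._remaining_grad_acc_hooks = 3
    #   assert opt._backward_hook_state.remaining_grad_acc_hooks == 3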

    def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
        checkpoint_dir = os.path.join(checkpoint_dir, "zero")
        optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
        assert os.path.isfile(
            optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
        optim_sd = torch.load(optim_state_path, weights_only=False)

        self._load_global_state(optim_sd)

        tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
        if self.mpu is None:
            logger.warning("MPU is not provided, setting tp size to 1 in checkpoint loading.")
            tp_world_size = 1
        else:
            tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
                else self.mpu.get_tensor_model_parallel_world_size()

        for i, (param_group,
                loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])):
            # We have an assumption that all params in the same param_group have the same keys
            opt_keys = set()
            steps = []

            lp_groups = getattr(self, lp_groups_name)
            for lp in lp_groups[i]:
                if lp._hp_mapping is not None:
                    #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                    step = lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                       tp_world_size)
                    for key in lp._hp_mapping.get_optim_state_keys():
                        opt_keys.add(key)
                    steps.append(step)

            hp_param = param_group['params'][0]
            assert all(step == steps[0] for step in steps), f"Steps {steps} are not equal"
            if steps[0] is not None:
                self.optimizer.state[hp_param]['step'] = steps[0]

            map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys)

            for key, value in loaded_param_group.items():
                if key == 'params':
                    continue
                param_group[key] = value
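
    # A hedged sketch of the universal-checkpoint layout this loader expects,
    # inferred from the paths used above (the actual files are produced by
    # DeepSpeed's checkpoint conversion tooling, not shown here):
    #
    #   <checkpoint_dir>/
    #     zero/
    #       optimizer_state.pt      # global optimizer state and param_groups
    #       <param_name>/...        # per-parameter HP state, one folder per entry
    #                               # in self.param_names, consumed by
    #                               # lp.load_hp_checkpoint_state()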

    def report_ipg_memory_usage(self, tag, param_elems, dtype=None):
        dtypes = self.ipg_buckets.keys() if dtype is None else [dtype]
        for dt in dtypes:
            bucket = self.ipg_buckets[dt]
            elem_count = bucket.elements + param_elems
            percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size
            see_memory_usage(
                f"{tag}: elems in_bucket {dt} {bucket.elements} param {param_elems} max_percent {percent_of_bucket_size}"
            )
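
    # For illustration only: with a reduce_bucket_size of 5e8 elements, a bucket
    # already holding 2e8 elements plus an incoming 1e8-element param reports
    #   (100.0 * (2e8 + 1e8)) // 5e8 == 60.0   # i.e. 60% of the bucket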

    def get_param_comm_dtype(self, param):
        if is_autocast_initialized():
            return get_comm_dtype(param)
        else:
            return self.communication_data_type

    def needs_scaler(self) -> bool:
        """
        Check whether this optimizer requires loss scaling for a correct backward pass.

        Returns True if any of the following conditions are met:
        - A custom loss scaler is enabled
        - The torch.autocast gradient scaler is active (fp16 only)
        - Dynamic loss scaling is enabled (fp16 with DeepSpeed's loss scaler)

        Returns False for bf16 or fp32, which don't require gradient scaling.
        """
        return (self.custom_loss_scaler or self.torch_autocast_gradscaler is not None
                or (hasattr(self, 'dynamic_loss_scale') and self.dynamic_loss_scale))

    def scale_if_loss(self, value: Any) -> Any:
        """
        Applies loss scaling to the input value if it is a loss tensor.
        """
        if maybe_loss_for_backward(value):
            if self.custom_loss_scaler:
                return self.external_loss_scale * value
            if self.torch_autocast_gradscaler:
                return self.torch_autocast_gradscaler.scale(value)
            # Only call loss_scaler if it exists (not present in BF16_Optimizer)
            if hasattr(self, 'loss_scaler') and self.loss_scaler is not None:
                return self.loss_scaler.scale_loss(value)
        return value
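
    # Illustrative sketch (assuming `opt` is a concrete ZeRO optimizer): a scalar
    # loss tensor is scaled by whichever scaler is active (custom scale first,
    # then the torch.autocast GradScaler, then DeepSpeed's loss scaler), while a
    # value that is not a loss tensor is returned unchanged.
    #
    #   loss = model_output.mean()        # scalar loss -> scaled if a scaler is active
    #   grad = torch.ones(8)              # non-loss tensor -> returned as-is
    #   scaled = opt.scale_if_loss(loss)
    #   same = opt.scale_if_loss(grad)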

    def backward_prologue(self, loss):
        """Pre-backward hook: apply loss scaling and return the (possibly scaled) loss.

        Subclasses may override this to add their own pre-backward work.
        """
        return self.scale_if_loss(loss)

    def backward_epilogue(self, **kwargs):
        pass

    def backward(self, loss, **kwargs):
        assert maybe_loss_for_backward(loss), "Optimizer's backward() only accepts a scalar tensor"
        scaled_loss = self.backward_prologue(loss)
        retain_graph = kwargs.pop('retain_graph', False)
        self.enter_backward()
        scaled_loss.backward(retain_graph=retain_graph)
        self.backward_epilogue()
        self.exit_backward()

    def register_grad_acc_post_hook(self, hook):
        """Register a callback to run when all gradient hooks have fired."""
        self._backward_hook_state.register_grad_acc_post_hook(hook)

    def unregister_grad_acc_post_hooks(self):
        """Remove all registered gradient accumulation post hooks."""
        self._backward_hook_state.unregister_grad_acc_post_hooks()
def run_grad_acc_post_hooks(self):
"""Run all registered post hooks if backward is active."""
self._backward_hook_state.run_grad_acc_post_hooks()
def enter_backward(self):
"""Enter backward context. Call at the start of backward pass."""
self._backward_hook_state.enter_backward()
def exit_backward(self):
"""Exit backward context. Call at the end of backward pass."""
self._backward_hook_state.exit_backward()
def clear_backward_seen_flag(self):
"""Clear the backward seen flag and reset hook counters at the start of each step."""
self._backward_hook_state.reset_for_new_step()
def should_refresh_expected_hook_count(self):
"""Return True when count_used_parameters_in_backward() should be re-evaluated."""
return self._backward_hook_state.should_refresh_expected_hook_count()
def reenter_backward_if_needed(self):
"""Re-enter backward context for subsequent phases in reentrant checkpointing."""
self._backward_hook_state.reenter_backward_if_needed()
def update_hook_state_and_maybe_run_epilogue(self, current_expected_count):
"""Update hook state after a gradient hook fires and run epilogue if all hooks have fired."""
self._backward_hook_state.update_hook_state_and_maybe_run_epilogue(current_expected_count)
def queue_post_backward_callback(self):
"""Queue post-backward hooks to run after autograd completes."""
return self._backward_hook_state.queue_post_backward_callback()
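    # Hedged usage sketch: the call sequence below is an assumption about how an engine
    # might drive these helpers around a PyTorch-style tensor.backward() call; `engine`,
    # `output`, and `out_grad` are illustrative names, not part of this module.
    #
    #   engine.clear_backward_seen_flag()        # reset hook counters for a new step
    #   engine.enter_backward()                  # mark backward as active
    #   engine.queue_post_backward_callback()    # schedule post-backward hooks with autograd
    #   output.backward(out_grad)                # autograd fires the registered grad-acc hooks
    #   engine.exit_backward()                   # mark backward as finished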
Low-precision master params/grads/optimizer states (#7700)

The DeepSpeed optimizer always creates fp32 master params/gradients/optimizer states. However, we sometimes want to keep them in lower precision given [torch.autocast support](https://deepspeed.readthedocs.io/en/latest/training.html#mixed-precision-training). This PR allows lower-precision master params/grads/optimizer states when bf16/fp16 is enabled.

DeepSpeed currently accepts the `fp16_master_weights_and_gradients` option under the `fp16` section (not documented) with ZeRO 1/2. This PR extends this to bf16 and also to ZeRO 3. The `bf16` section gains two new items, `bf16_master_weights_and_grads` and `bf16_optimizer_states`. Similarly to `fp16_master_weights_and_grads`, `bf16_master_weights_and_grads` keeps master parameters in bf16; `bf16_optimizer_states` also keeps optimizer states in bf16. Here is an example configuration:

```json
"bf16": {
    "enabled": true,
    "bf16_master_weights_and_grads": true,
    "bf16_optimizer_states": true
}
```

Note that `bf16_master_weights_and_grads==True` with `bf16_optimizer_states==False` is supported only with CPU offloading. Also, there is no `fp16_optimizer_states`, as it would not be practical. More details are described in [`config-json.md`](https://github.com/tohtana/DeepSpeed/blob/88e0bbdfba89c4712d815980ddb28353d6da5b2e/docs/_pages/config-json.md).

Previously, `torch.autocast` support (the `torch_autocast` section in the config) was not compatible with `bf16`/`fp16` enabled, but the combination is now accepted. This PR also adds test cases for these configurations as well as for the combination with `torch.autocast`.

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
2025-12-04 12:53:37 +09:00
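For illustration, a hedged sketch of how these options might be combined with `torch.autocast` in a config dict passed to `deepspeed.initialize`. The `bf16` items come from the PR description above; the `torch_autocast` sub-keys and the surrounding sections (`train_batch_size`, `zero_optimization`) are illustrative assumptions, not taken from this PR:

```python
# Hedged sketch, not a definitive configuration: only the "bf16" items are
# confirmed by the PR description; the other keys are assumptions.
import deepspeed

ds_config = {
    "train_batch_size": 8,
    "zero_optimization": {"stage": 3},
    "bf16": {
        "enabled": True,
        "bf16_master_weights_and_grads": True,
        "bf16_optimizer_states": True,
    },
    "torch_autocast": {"enabled": True},  # assumed key name under torch_autocast
}

# model_engine, optimizer, _, _ = deepspeed.initialize(
#     model=model, model_parameters=model.parameters(), config=ds_config)
```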
def _configure_master_weights(self,
fp16_master_weights_and_gradients=False,
bf16_master_weights_and_gradients=False,
bf16_optimizer_states=False,
fp16_offload_validator=None,
bf16_fp32_offload_validator=None):
"""
Common validation and dtype selection for ZeRO optimizer master-weight settings.
Optionally accepts callables that enforce backend-specific offload requirements.
"""
self.fp16_master_weights_and_gradients = fp16_master_weights_and_gradients
self.bf16_master_weights_and_gradients = bf16_master_weights_and_gradients
assert not (self.fp16_master_weights_and_gradients and self.bf16_master_weights_and_gradients), \
"fp16_master_weights_and_gradients and bf16_master_weights_and_gradients are mutually exclusive."
self.bf16_optimizer_states = bf16_optimizer_states
if self.bf16_optimizer_states:
assert self.bf16_master_weights_and_gradients, \
"bf16_optimizer_states requires bf16_master_weights_and_gradients."
if (self.bf16_master_weights_and_gradients and not self.bf16_optimizer_states
and bf16_fp32_offload_validator is not None):
bf16_fp32_offload_validator()
if self.fp16_master_weights_and_gradients and fp16_offload_validator is not None:
fp16_offload_validator()
if self.fp16_master_weights_and_gradients:
return torch.float16
elif self.bf16_master_weights_and_gradients:
return torch.bfloat16
else:
return torch.float32
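    # Hedged usage sketch: one way a ZeRO optimizer's __init__ might call this helper.
    # The surrounding names (zero_config, cpu_offload, _assert_cpu_offload_for_fp16_master)
    # are illustrative assumptions, not code from this repository.
    #
    #   def _assert_cpu_offload_for_fp16_master():
    #       assert cpu_offload, "fp16_master_weights_and_gradients requires CPU offloading"
    #
    #   self.master_dtype = self._configure_master_weights(
    #       fp16_master_weights_and_gradients=zero_config.fp16_master_weights_and_gradients,
    #       bf16_master_weights_and_gradients=zero_config.bf16_master_weights_and_grads,
    #       bf16_optimizer_states=zero_config.bf16_optimizer_states,
    #       fp16_offload_validator=_assert_cpu_offload_for_fp16_master,
    #   )
    #   # The returned dtype (torch.float16 / torch.bfloat16 / torch.float32) would then be
    #   # used when creating the flat master-parameter partitions.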