File size: 8,518 Bytes
7ff2ba3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
from collections import defaultdict

import torch
import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import
import intel_extension_for_pytorch._C as core  # pylint: disable=import-error, unused-import

# pylint: disable=protected-access, missing-function-docstring, line-too-long

# Short module-level aliases for the private pieces of IPEX's CPU GradScaler
# implementation that the replacement methods in this file refer to.
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = (
    ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
)


def _unscale_grads_(
    self, optimizer, inv_scale, found_inf, allow_fp16
):  # pylint: disable=unused-argument
    """Unscale all gradients owned by ``optimizer`` in place and collect the inf/nan flag.

    The IPEX kernel ``core._amp_foreach_non_finite_check_and_unscale_`` is always
    called with CPU tensors here, so every gradient is first brought to the CPU.
    ``Tensor.to("cpu")`` returns a *copy* for tensors that live on another device,
    which means the kernel's in-place result would be lost for those gradients.
    To keep the "unscale in place" contract, the unscaled CPU values are copied
    back into the original device-resident gradient after the kernel has run.

    Args:
        self: The GradScaler instance (not used by this implementation).
        optimizer: Optimizer whose ``param_groups`` hold the parameters to process.
            If it exposes ``sync_grad()``, that is called first.
        inv_scale: Tensor holding the factor the gradients are multiplied by
            (the caller passes the reciprocal of the loss scale).
        found_inf: Tensor handed to the kernel as its non-finite indicator.
        allow_fp16: When false, encountering a ``torch.float16`` gradient raises.

    Returns:
        The ``_per_device_tensors`` mapping of the ``found_inf`` replicator.

    Raises:
        ValueError: If ``allow_fp16`` is false and a gradient is ``torch.float16``.
    """
    per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
    per_device_found_inf = _MultiDeviceReplicator(found_inf)

    # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
    # There could be hundreds of grads, so we'd like to iterate through them just once.
    # However, we don't know their devices or dtypes in advance.

    # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
    # Google says mypy struggles with defaultdicts type annotations.
    per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]

    # (device-resident gradient, CPU copy) pairs whose unscaled values must be
    # written back once the CPU kernel has processed the copy.
    pending_writebacks = []

    # sync grad to master weight
    if hasattr(optimizer, "sync_grad"):
        optimizer.sync_grad()
    with torch.no_grad():
        for group in optimizer.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                if (not allow_fp16) and param.grad.dtype == torch.float16:
                    raise ValueError("Attempting to unscale FP16 gradients.")
                if param.grad.is_sparse:
                    # is_coalesced() == False means the sparse grad has values with duplicate indices.
                    # coalesce() deduplicates indices and adds all values that have the same index.
                    # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                    # so we should check the coalesced _values().
                    if param.grad.dtype is torch.float16:
                        param.grad = param.grad.coalesce()
                    to_unscale = param.grad._values()
                else:
                    to_unscale = param.grad

                # TODO: is there a way to split by device and dtype without appending in the inner loop?
                cpu_grad = to_unscale.to("cpu")
                if to_unscale.device.type != "cpu":
                    # .to("cpu") produced a separate tensor; remember where the
                    # unscaled result has to go afterwards.
                    pending_writebacks.append((to_unscale, cpu_grad))
                per_device_and_dtype_grads[cpu_grad.device][cpu_grad.dtype].append(
                    cpu_grad
                )

        for per_dtype_grads in per_device_and_dtype_grads.values():
            for grads in per_dtype_grads.values():
                core._amp_foreach_non_finite_check_and_unscale_(
                    grads,
                    per_device_found_inf.get("cpu"),
                    per_device_inv_scale.get("cpu"),
                )

        # Propagate the unscaled values back to the gradients the optimizer
        # will actually read; without this, off-CPU gradients stay scaled.
        for device_grad, cpu_grad in pending_writebacks:
            device_grad.copy_(cpu_grad)

    return per_device_found_inf._per_device_tensors


def unscale_(self, optimizer):
    """
    Divide ("unscale") the gradient tensors of ``optimizer`` by the current scale factor.

    Calling :meth:`unscale_` explicitly is optional. It is meant for cases where
    gradients have to be :ref:`modified or inspected<working-with-unscaled-gradients>`
    between the backward pass(es) and :meth:`step`. When it is not called
    explicitly, gradients are unscaled automatically during :meth:`step`.

    Example, clipping unscaled gradients::
        ...
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)
        scaler.update()

    Args:
        optimizer (torch.optim.Optimizer):  Optimizer that owns the gradients to be unscaled.

    Raises:
        RuntimeError: If this optimizer's recorded stage is already ``UNSCALED`` or ``STEPPED``.

    .. warning::
        :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
        and only after all gradients for that optimizer's assigned parameters have been accumulated.
        Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
    .. warning::
        :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
    """
    # A disabled scaler is a no-op.
    if not self._enabled:
        return

    self._check_scale_growth_tracker("unscale_")

    state = self._per_optimizer_states[id(optimizer)]
    stage = state["stage"]

    # Guard clauses: reject repeated or out-of-order calls.
    if stage is OptState.UNSCALED:
        raise RuntimeError(
            "unscale_() has already been called on this optimizer since the last update()."
        )
    if stage is OptState.STEPPED:
        raise RuntimeError("unscale_() is being called after step().")

    # FP32 division can be imprecise for certain compile options, so the
    # reciprocal is computed in FP64 on the CPU and then moved back.
    assert self._scale is not None
    scale_device = self._scale.device
    scale_fp64 = self._scale.to("cpu").double()
    inv_scale = scale_fp64.reciprocal().float().to(scale_device)
    found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=scale_device)

    state["found_inf_per_device"] = self._unscale_grads_(
        optimizer, inv_scale, found_inf, False
    )
    state["stage"] = OptState.UNSCALED


def update(self, new_scale=None):
    """
    Updates the scale factor.

    If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
    to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
    the scale is multiplied by ``growth_factor`` to increase it.

    Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
    used directly, it's used to fill GradScaler's internal scale tensor. So if
    ``new_scale`` was a tensor, later in-place changes to that tensor will not further
    affect the scale GradScaler uses internally.)

    The IPEX kernel ``core._amp_update_scale_`` is run on CPU tensors. When the scale
    and growth tracker live on another device, ``.to("cpu")`` yields copies, so the
    updated values are copied back into the original tensors afterwards; otherwise
    the update would be silently discarded.

    Args:
        new_scale (float or :class:`torch.FloatTensor`, optional, default=None):  New scale factor.

    Raises:
        AssertionError: If ``new_scale`` is neither a float nor a 1-element
            ``torch.FloatTensor`` with ``requires_grad=False``, or if no inf checks
            were recorded before an automatic update.

    .. warning::
        :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
        been invoked for all optimizers used this iteration.
    """
    if not self._enabled:
        return

    _scale, _growth_tracker = self._check_scale_growth_tracker("update")

    if new_scale is not None:
        # Accept a new user-defined scale.
        if isinstance(new_scale, float):
            self._scale.fill_(new_scale)  # type: ignore[union-attr]
        else:
            reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
            assert isinstance(new_scale, torch.FloatTensor), reason  # type: ignore[attr-defined]
            assert new_scale.numel() == 1, reason
            assert new_scale.requires_grad is False, reason
            self._scale.copy_(new_scale)  # type: ignore[union-attr]
    else:
        # Consume shared inf/nan data collected from optimizers to update the scale.
        found_infs = [
            found_inf.to(device="cpu", non_blocking=True)
            for state in self._per_optimizer_states.values()
            for found_inf in state["found_inf_per_device"].values()
        ]

        assert len(found_infs) > 0, "No inf checks were recorded prior to update."

        # Accumulate every recorded flag into the first one.
        found_inf_combined = found_infs[0]
        for extra_found_inf in found_infs[1:]:
            found_inf_combined += extra_found_inf

        cpu_scale = _scale.to("cpu")
        cpu_growth_tracker = _growth_tracker.to("cpu")

        core._amp_update_scale_(
            cpu_scale,
            cpu_growth_tracker,
            found_inf_combined,
            self._growth_factor,
            self._backoff_factor,
            self._growth_interval,
        )

        # Write the result back in place when the kernel worked on a copy.
        # NOTE(review): assumes the tensors returned by _check_scale_growth_tracker
        # are the scaler's own state (the in-place CPU path already relies on this).
        if _scale.device.type != "cpu":
            _scale.copy_(cpu_scale)
        if _growth_tracker.device.type != "cpu":
            _growth_tracker.copy_(cpu_growth_tracker)

    # To prepare for next iteration, clear the data collected from optimizers this iteration.
    self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)


def gradscaler_init():
    """Expose IPEX's CPU GradScaler as ``torch.xpu.amp.GradScaler`` with patched methods.

    The class object is shared, so replacing ``_unscale_grads_``, ``unscale_`` and
    ``update`` with the functions defined in this module also affects the class
    when it is reached through its IPEX name.

    Returns:
        The patched GradScaler class.
    """
    scaler_cls = ipex.cpu.autocast._grad_scaler.GradScaler
    torch.xpu.amp.GradScaler = scaler_cls

    overrides = (
        ("_unscale_grads_", _unscale_grads_),
        ("unscale_", unscale_),
        ("update", update),
    )
    for method_name, replacement in overrides:
        setattr(scaler_cls, method_name, replacement)

    return scaler_cls