|
| 1 | +from typing import Any, cast |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.nn as nn |
| 5 | +from pydantic import BaseModel, ConfigDict |
| 6 | +from torch import distributed as dist |
| 7 | +from torch.distributed._functional_collectives import all_reduce |
| 8 | + |
| 9 | +from xtuner.v1.utils import get_torch_device_module |
| 10 | +from xtuner.v1.utils.router_offload import AsyncOffloadedTensor, async_offload_to_cpu, wait_async_offload |
| 11 | + |
| 12 | + |
| 13 | +DEVICE_MODULE = get_torch_device_module() |
| 14 | + |
| 15 | + |
class _AllReduce(torch.autograd.Function):
    """Autograd-aware all-reduce built on the functional collective.

    The backward pass applies the SAME reduction (same op, same group) to the
    incoming gradient, which is the correct adjoint for a "sum" all-reduce.
    """

    @staticmethod
    def forward(ctx, op, group, tensor):
        # Stash the reduction op and process group for the backward pass.
        ctx.group = group
        ctx.op = op
        # Clone into contiguous layout so the collective never aliases or
        # mutates the caller's tensor.
        tensor = tensor.clone(memory_format=torch.contiguous_format)
        tensor = all_reduce(tensor, op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # Two leading None entries: no gradient flows to `op` or `group`.
        # The gradient w.r.t. `tensor` is the all-reduced upstream gradient.
        return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),)
| 28 | + |
| 29 | + |
def all_reduce_autograd(tensor, op, group):
    """All-reduce ``tensor`` over ``group`` while keeping autograd history."""
    reduced = _AllReduce.apply(op, group, tensor)
    return reduced
| 32 | + |
| 33 | + |
def select_nonpad(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> tuple[torch.Tensor, torch.Tensor]:
    """Gather the non-padding slices of ``tensor`` along ``dim``.

    Args:
        tensor (torch.Tensor): Input tensor.
        mask (torch.Tensor): Attention mask; nonzero entries mark valid
            positions. Assumes a 2D (batch, seq) mask — the second nonzero
            coordinate (the sequence index) is used for selection.
        dim (int): Dimension of ``tensor`` to select along.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Contiguous float copy of the
        selected slices, and the selected sequence indices.
    """
    nonpad_idx = mask.nonzero(as_tuple=True)[1]
    picked = tensor.index_select(dim, nonpad_idx)
    return picked.contiguous().float(), nonpad_idx
| 48 | + |
| 49 | + |
def maybe_offload_tensor(
    tensor: torch.Tensor,
    split_bal_loss: bool,
    offload_stream: torch.cuda.Stream,
) -> torch.Tensor | AsyncOffloadedTensor:
    """Asynchronously move ``tensor`` to CPU when split_bal_loss mode is on.

    When split_bal_loss is off the tensor is returned untouched and the
    stream is never used.
    """
    if not split_bal_loss:
        return tensor
    return async_offload_to_cpu(tensor, offload_stream)
| 59 | + |
| 60 | + |
def maybe_wait_offload_tensor(
    tensor: torch.Tensor | AsyncOffloadedTensor,
    split_bal_loss: bool,
) -> torch.Tensor:
    """Resolve a possibly-offloaded tensor back to a plain tensor.

    Blocks on the async offload when split_bal_loss is on; otherwise the
    input is assumed to be a regular tensor and a detached view is returned.
    """
    if not split_bal_loss:
        return cast(torch.Tensor, tensor).detach()
    return wait_async_offload(cast(AsyncOffloadedTensor, tensor))
| 69 | + |
| 70 | + |
class LayerBalancingLossConfig(BaseModel):
    """Configuration for layer-wise split balancing loss."""

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
    # All fields are optional on the config itself; build() enforces that the
    # layer/expert counts are supplied somewhere (config or call site).
    num_layers: int | None = None
    n_routed_experts: int | None = None
    device: torch.device | str | None = None

    def build(
        self,
        *,
        num_layers: int | None = None,
        n_routed_experts: int | None = None,
        device: torch.device | str | None = None,
    ) -> "LayerBalancingLoss":
        """Build a layer balancing loss context.

        Args:
            num_layers (int | None): Number of layers. Falls back to the
                config field when None.
            n_routed_experts (int | None): Number of routed experts. Falls
                back to the config field when None.
            device (torch.device | str | None): Device for internal
                accumulators. Fallback order: argument -> config field ->
                DEVICE_MODULE.current_device().

        Returns:
            LayerBalancingLoss: Built context.
        """
        layers = self.num_layers if num_layers is None else num_layers
        experts = self.n_routed_experts if n_routed_experts is None else n_routed_experts
        assert layers is not None, "num_layers must be provided either in config or build()."
        assert experts is not None, "n_routed_experts must be provided either in config or build()."

        target_device = self.device if device is None else device
        if target_device is None:
            target_device = DEVICE_MODULE.current_device()

        return LayerBalancingLoss(
            num_layers=layers,
            n_routed_experts=experts,
            device=target_device,
        )
| 111 | + |
| 112 | + |
class LayerBalancingLossKwargs(BaseModel):
    """Keyword arguments for layer-wise split balancing loss context."""

    model_config = ConfigDict(
        title="layer balancing loss keyword arguments",
        extra="forbid",
        arbitrary_types_allowed=True,
    )
    # Device hosting the per-layer accumulators.
    device: torch.device | str
| 120 | + |
| 121 | + |
class LayerBalancingLossContext(nn.Module):
    """Layer-wise balancing loss accumulator used by split_bal_loss mode.

    ``update`` collects per-layer expert-load statistics during the forward
    pass; ``finalize`` turns them into a single auxiliary balancing loss.
    ``local_load`` and ``local_load_logits`` hold one row per layer and one
    column per routed expert.
    """

    def __init__(self, loss_cfg: LayerBalancingLossConfig, loss_kwargs: LayerBalancingLossKwargs):
        super().__init__()
        self.loss_cfg = loss_cfg
        self.loss_kwargs = loss_kwargs
        num_layers = self.loss_cfg.num_layers
        n_routed_experts = self.loss_cfg.n_routed_experts
        assert num_layers is not None
        assert n_routed_experts is not None

        # (num_layers, n_experts) token counts derived from router *weights*.
        self.local_load = torch.zeros(num_layers, n_routed_experts, device=loss_kwargs.device)
        # One entry appended per update() call: per-expert sum of routing
        # weights for that layer. These tensors keep autograd history so
        # finalize() can backprop through the router weights.
        self.routing_weights_sum_list: list[torch.Tensor] = []
        # (num_layers, n_experts) token counts derived from router *logits*;
        # integer-valued and used only for logging / bias update.
        self.local_load_logits = torch.zeros(
            num_layers,
            n_routed_experts,
            dtype=torch.int64,
            device=loss_kwargs.device,
        )

    def update(
        self,
        layer_idx: int,
        router_weights: torch.Tensor,
        num_experts_per_tok: int,
        router_logits: torch.Tensor,
    ) -> None:
        """Update accumulators for one layer.

        Args:
            layer_idx (int): Layer index.
            router_weights (torch.Tensor): Router weights, shape (1, non_pad_seq, n_experts).
            num_experts_per_tok (int): Number of experts selected per token.
            router_logits (torch.Tensor): Router logits, shape (1, non_pad_seq, n_experts).
        """
        n_routed_experts = self.loss_cfg.n_routed_experts
        assert n_routed_experts is not None

        # Top-k over weights gives the experts each token is routed to.
        _, selected_experts = torch.topk(router_weights, num_experts_per_tok, dim=-1)
        # Histogram of expert indices -> tokens routed to each expert.
        # NOTE(review): torch.histc requires floating input on CPU while
        # selected_experts is int64 — this presumably runs on an accelerator
        # where integral histc is supported; confirm.
        tokens_per_expert = torch.histc(
            selected_experts.view(-1),
            bins=n_routed_experts,
            min=0,
            max=n_routed_experts,
        ).float()
        # Assigns (overwrites) this layer's row; a second update() for the
        # same layer replaces the counts but still appends another
        # weight-sum entry below — NOTE(review): confirm one call per layer.
        self.local_load[layer_idx] = tokens_per_expert
        # Per-expert routing-weight mass over the sequence, shape (n_experts,).
        self.routing_weights_sum_list.append(router_weights.sum(dim=1).squeeze(0))

        # Same histogram, but over the raw logits' top-k (no gradient needed,
        # stored as integer counts).
        _, selected_experts = torch.topk(router_logits, num_experts_per_tok, dim=-1)
        tokens_per_expert_logits = torch.histc(
            selected_experts.view(-1),
            bins=n_routed_experts,
            min=0,
            max=n_routed_experts,
        ).to(torch.long)
        self.local_load_logits[layer_idx] = tokens_per_expert_logits

    def finalize(
        self,
        dist_init: bool,
        num_experts_per_tok: int,
        non_pad_token: int,
        balancing_loss_weight: float = 1.0,
    ) -> torch.Tensor:
        """Finalize layer-wise balancing loss.

        Args:
            dist_init (bool): Whether to use distributed global average mode.
            num_experts_per_tok (int): Number of experts selected per token.
            non_pad_token (int): Number of non-padding tokens (only used in
                the non-distributed branch).
            balancing_loss_weight (float): Balancing loss weight.

        Returns:
            torch.Tensor: Final balancing loss.
        """
        n_routed_experts = self.loss_cfg.n_routed_experts
        assert n_routed_experts is not None
        # (num_update_calls, n_experts); presumably one row per layer,
        # matching local_load — see note in update().
        local_gating_sum = torch.stack(self.routing_weights_sum_list, dim=0)

        if dist_init:
            group = dist.group.WORLD
            assert group is not None
            # Sum token counts over ranks (no gradient needed for counts).
            tokens_per_expert_global = all_reduce(self.local_load, "sum", group)
            # Per-layer total routed slots; dividing by top-k recovers the
            # global sequence length per layer.
            tokens_global = tokens_per_expert_global.sum(-1)
            seqlen_global = tokens_global // num_experts_per_tok

            # Gating sums DO need gradient -> use the autograd-aware reduce.
            routing_weights_sum_global = all_reduce_autograd(local_gating_sum, "sum", group)
            routing_weights_mean_global = routing_weights_sum_global / seqlen_global.unsqueeze(-1)
            scale_global = n_routed_experts / tokens_global
        else:
            tokens_per_expert_global = self.local_load
            # Guard against division by zero on fully padded batches.
            valid_tokens = max(non_pad_token, 1)
            scale_global = n_routed_experts / (valid_tokens * num_experts_per_tok)
            routing_weights_mean_global = local_gating_sum / valid_tokens

        # Per-layer load-balancing term: scaled dot product of expert load and
        # mean gate probability, summed over experts; then summed over layers.
        loss = scale_global * (tokens_per_expert_global * routing_weights_mean_global).sum(-1)
        return loss.sum() * balancing_loss_weight

    def cal_tokens_per_expert(self) -> torch.Tensor:
        """Get tokens-per-expert tensor for logging/bias update."""
        if dist.is_initialized():
            group = dist.group.WORLD
            assert group is not None
            # Integer counts summed across ranks; no autograd required.
            return all_reduce(self.local_load_logits, "sum", group)
        return self.local_load_logits
| 228 | + |
| 229 | + |
class LayerBalancingLoss(LayerBalancingLossContext):
    """Backward-compatible wrapper keeping the original constructor usage."""

    def __init__(self, num_layers: int, n_routed_experts: int, device: torch.device | str):
        # Pack the flat constructor arguments into the config/kwargs pair
        # expected by the context base class.
        super().__init__(
            LayerBalancingLossConfig(num_layers=num_layers, n_routed_experts=n_routed_experts, device=device),
            LayerBalancingLossKwargs(device=device),
        )
| 237 | + |
| 238 | + |
def prepare_layer_balancing_loss(
    layer_balancing_cfg: LayerBalancingLossConfig | None,
    *,
    num_layers: int,
    n_routed_experts: int,
) -> LayerBalancingLoss | None:
    """Build a layer balancing loss object from config.

    Args:
        layer_balancing_cfg (LayerBalancingLossConfig | None): Layer balancing
            config, or None when the feature is disabled.
        num_layers (int): Number of decoder layers.
        n_routed_experts (int): Number of routed experts.

    Returns:
        LayerBalancingLoss | None: Built object when enabled, else None.
    """
    if layer_balancing_cfg is None:
        return None
    return layer_balancing_cfg.build(num_layers=num_layers, n_routed_experts=n_routed_experts)
| 262 | + |
| 263 | + |
def accumulate_layer_balancing_loss(
    layer_balancing_loss: LayerBalancingLoss | None,
    *,
    layer_idx: int,
    router_weights: torch.Tensor,
    router_logits: torch.Tensor,
    mask: torch.Tensor,
    dim: int,
    num_experts_per_tok: int,
) -> None:
    """Accumulate per-layer balancing statistics.

    Strips padding positions from the router outputs before handing them to
    the accumulator. No-op when layer balancing loss is disabled.
    """
    if layer_balancing_loss is None:
        return

    weights_nonpad, _ = select_nonpad(router_weights, mask, dim=dim)
    logits_nonpad, _ = select_nonpad(router_logits, mask, dim=dim)
    layer_balancing_loss.update(layer_idx, weights_nonpad, num_experts_per_tok, logits_nonpad)
| 289 | + |
| 290 | + |
def finalize_layer_balancing_loss(
    layer_balancing_loss: LayerBalancingLoss | None,
    *,
    balancing_ctx: Any,
    num_experts_per_tok: int,
    non_pad_token: int,
) -> tuple[torch.Tensor, torch.Tensor] | None:
    """Finalize balancing loss and tokens-per-expert from accumulated stats.

    Returns:
        tuple[torch.Tensor, torch.Tensor] | None: ``(balancing_loss,
        tokens_per_expert_global)``, or None when layer balancing is disabled
        or ``balancing_ctx`` is None.
    """
    if layer_balancing_loss is None or balancing_ctx is None:
        return None

    # Global averaging needs both the config flag and an initialized process
    # group. NOTE(review): balancing_loss_global_average / balancing_loss_alpha
    # are read off balancing_ctx.loss_cfg, which appears to be a different
    # config type than LayerBalancingLossConfig — confirm against callers.
    use_global_average = balancing_ctx.loss_cfg.balancing_loss_global_average and dist.is_initialized()
    loss = layer_balancing_loss.finalize(
        use_global_average,
        num_experts_per_tok,
        non_pad_token=non_pad_token,
        balancing_loss_weight=balancing_ctx.loss_cfg.balancing_loss_alpha,
    )
    return loss, layer_balancing_loss.cal_tokens_per_expert()
0 commit comments