Commit 1bd80f0

Author: wentiange (committed)
[Feature] Layer-wise MoE balance loss computation
1 parent 18e5533 commit 1bd80f0

4 files changed

Lines changed: 523 additions & 34 deletions

File tree

xtuner/v1/loss/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -1,6 +1,12 @@
 from .base_loss_ctx import BaseLossConfig, BaseLossContext, BaseLossKwargs
 from .ce_loss import CELossConfig, CELossContext, LMHeadLossContext
 from .chunk_loss import ChunkLoss
+from .layer_moe_loss import (
+    LayerBalancingLoss,
+    LayerBalancingLossConfig,
+    LayerBalancingLossContext,
+    LayerBalancingLossKwargs,
+)
 from .moe_loss import (
     BalancingLoss,
     BalancingLossConfig,
@@ -20,6 +26,10 @@
     "BalancingLossConfig",
     "BalancingLossContext",
     "BalancingLossKwargs",
+    "LayerBalancingLoss",
+    "LayerBalancingLossConfig",
+    "LayerBalancingLossContext",
+    "LayerBalancingLossKwargs",
     "ZLoss",
     "ZLossConfig",
     "ZLossContext",
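
With these re-exports, the layer-wise balancing API becomes importable from the package root. A minimal sketch of the import this hunk enables (symbol names taken from the diff above):

    # Import the four symbols re-exported by this commit from the package
    # root instead of reaching into the layer_moe_loss submodule.
    from xtuner.v1.loss import (
        LayerBalancingLoss,
        LayerBalancingLossConfig,
        LayerBalancingLossContext,
        LayerBalancingLossKwargs,
    )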

xtuner/v1/loss/layer_moe_loss.py

Lines changed: 314 additions & 0 deletions
@@ -0,0 +1,314 @@
+from typing import Any, cast
+
+import torch
+import torch.nn as nn
+from pydantic import BaseModel, ConfigDict
+from torch import distributed as dist
+from torch.distributed._functional_collectives import all_reduce
+
+from xtuner.v1.utils import get_torch_device_module
+from xtuner.v1.utils.router_offload import AsyncOffloadedTensor, async_offload_to_cpu, wait_async_offload
+
+
+DEVICE_MODULE = get_torch_device_module()
+
+
+class _AllReduce(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, op, group, tensor):
+        ctx.group = group
+        ctx.op = op
+        tensor = tensor.clone(memory_format=torch.contiguous_format)
+        tensor = all_reduce(tensor, op, group=group)
+        return tensor
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),)
+
+
+def all_reduce_autograd(tensor, op, group):
+    return _AllReduce.apply(op, group, tensor)
+
+
+def select_nonpad(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> tuple[torch.Tensor, torch.Tensor]:
+    """Select non-padding positions from a tensor.
+
+    Args:
+        tensor (torch.Tensor): Input tensor.
+        mask (torch.Tensor): Attention mask.
+        dim (int): Dimension to select along.
+
+    Returns:
+        tuple[torch.Tensor, torch.Tensor]: Selected tensor and selected indices.
+    """
+    indices = torch.nonzero(mask, as_tuple=True)[1]
+    selected = torch.index_select(tensor, dim, indices).contiguous().float()
+    return selected, indices
+
+
+def maybe_offload_tensor(
+    tensor: torch.Tensor,
+    split_bal_loss: bool,
+    offload_stream: torch.cuda.Stream,
+) -> torch.Tensor | AsyncOffloadedTensor:
+    """Offload the tensor to CPU only when split_bal_loss is enabled."""
+    if split_bal_loss:
+        return async_offload_to_cpu(tensor, offload_stream)
+    return tensor
+
+
+def maybe_wait_offload_tensor(
+    tensor: torch.Tensor | AsyncOffloadedTensor,
+    split_bal_loss: bool,
+) -> torch.Tensor:
+    """Wait for the offloaded tensor only when split_bal_loss is enabled."""
+    if split_bal_loss:
+        return wait_async_offload(cast(AsyncOffloadedTensor, tensor))
+    return cast(torch.Tensor, tensor).detach()
+
+
+class LayerBalancingLossConfig(BaseModel):
+    """Configuration for layer-wise split balancing loss."""
+
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+    num_layers: int | None = None
+    n_routed_experts: int | None = None
+    device: torch.device | str | None = None
+
+    def build(
+        self,
+        *,
+        num_layers: int | None = None,
+        n_routed_experts: int | None = None,
+        device: torch.device | str | None = None,
+    ) -> "LayerBalancingLoss":
+        """Build the layer balancing loss context.
+
+        Args:
+            num_layers (int | None): Number of layers. Falls back to the config value when None.
+            n_routed_experts (int | None): Number of routed experts. Falls back to the config value when None.
+            device (torch.device | str | None): Device used for internal accumulators.
+                Fallback order: argument -> config field -> DEVICE_MODULE.current_device().
+
+        Returns:
+            LayerBalancingLoss: Built context.
+        """
+        resolved_num_layers = num_layers if num_layers is not None else self.num_layers
+        resolved_n_routed_experts = n_routed_experts if n_routed_experts is not None else self.n_routed_experts
+        assert resolved_num_layers is not None, "num_layers must be provided either in config or build()."
+        assert resolved_n_routed_experts is not None, "n_routed_experts must be provided either in config or build()."
+
+        resolved_device = device if device is not None else self.device
+        if resolved_device is None:
+            resolved_device = DEVICE_MODULE.current_device()
+
+        return LayerBalancingLoss(
+            num_layers=resolved_num_layers,
+            n_routed_experts=resolved_n_routed_experts,
+            device=resolved_device,
+        )
+
+
+class LayerBalancingLossKwargs(BaseModel):
+    """Keyword arguments for the layer-wise split balancing loss context."""
+
+    model_config = ConfigDict(
+        title="layer balancing loss keyword arguments", extra="forbid", arbitrary_types_allowed=True
+    )
+    device: torch.device | str
+
+
+class LayerBalancingLossContext(nn.Module):
+    """Layer-wise balancing loss accumulator used by split_bal_loss mode."""
+
+    def __init__(self, loss_cfg: LayerBalancingLossConfig, loss_kwargs: LayerBalancingLossKwargs):
+        super().__init__()
+        self.loss_cfg = loss_cfg
+        self.loss_kwargs = loss_kwargs
+        num_layers = self.loss_cfg.num_layers
+        n_routed_experts = self.loss_cfg.n_routed_experts
+        assert num_layers is not None
+        assert n_routed_experts is not None
+
+        self.local_load = torch.zeros(num_layers, n_routed_experts, device=loss_kwargs.device)
+        self.routing_weights_sum_list: list[torch.Tensor] = []
+        self.local_load_logits = torch.zeros(
+            num_layers,
+            n_routed_experts,
+            dtype=torch.int64,
+            device=loss_kwargs.device,
+        )
+
+    def update(
+        self,
+        layer_idx: int,
+        router_weights: torch.Tensor,
+        num_experts_per_tok: int,
+        router_logits: torch.Tensor,
+    ) -> None:
+        """Update accumulators for one layer.
+
+        Args:
+            layer_idx (int): Layer index.
+            router_weights (torch.Tensor): Router weights, shape (1, non_pad_seq, n_experts).
+            num_experts_per_tok (int): Number of experts selected per token.
+            router_logits (torch.Tensor): Router logits, shape (1, non_pad_seq, n_experts).
+        """
+        n_routed_experts = self.loss_cfg.n_routed_experts
+        assert n_routed_experts is not None
+
+        _, selected_experts = torch.topk(router_weights, num_experts_per_tok, dim=-1)
+        tokens_per_expert = torch.histc(
+            selected_experts.view(-1).float(),  # histc only supports floating point inputs
+            bins=n_routed_experts,
+            min=0,
+            max=n_routed_experts,
+        ).float()
+        self.local_load[layer_idx] = tokens_per_expert
+        self.routing_weights_sum_list.append(router_weights.sum(dim=1).squeeze(0))
+
+        _, selected_experts = torch.topk(router_logits, num_experts_per_tok, dim=-1)
+        tokens_per_expert_logits = torch.histc(
+            selected_experts.view(-1).float(),  # histc only supports floating point inputs
+            bins=n_routed_experts,
+            min=0,
+            max=n_routed_experts,
+        ).to(torch.long)
+        self.local_load_logits[layer_idx] = tokens_per_expert_logits
+
+    def finalize(
+        self,
+        dist_init: bool,
+        num_experts_per_tok: int,
+        non_pad_token: int,
+        balancing_loss_weight: float = 1.0,
+    ) -> torch.Tensor:
+        """Finalize the layer-wise balancing loss.
+
+        Args:
+            dist_init (bool): Whether to use distributed global-average mode.
+            num_experts_per_tok (int): Number of experts selected per token.
+            non_pad_token (int): Number of non-padding tokens.
+            balancing_loss_weight (float): Balancing loss weight.
+
+        Returns:
+            torch.Tensor: Final balancing loss.
+        """
+        n_routed_experts = self.loss_cfg.n_routed_experts
+        assert n_routed_experts is not None
+        local_gating_sum = torch.stack(self.routing_weights_sum_list, dim=0)
+
+        if dist_init:
+            group = dist.group.WORLD
+            assert group is not None
+            tokens_per_expert_global = all_reduce(self.local_load, "sum", group)
+            tokens_global = tokens_per_expert_global.sum(-1)
+            seqlen_global = tokens_global // num_experts_per_tok
+
+            routing_weights_sum_global = all_reduce_autograd(local_gating_sum, "sum", group)
+            routing_weights_mean_global = routing_weights_sum_global / seqlen_global.unsqueeze(-1)
+            scale_global = n_routed_experts / tokens_global
+        else:
+            tokens_per_expert_global = self.local_load
+            valid_tokens = max(non_pad_token, 1)
+            scale_global = n_routed_experts / (valid_tokens * num_experts_per_tok)
+            routing_weights_mean_global = local_gating_sum / valid_tokens
+
+        loss = scale_global * (tokens_per_expert_global * routing_weights_mean_global).sum(-1)
+        return loss.sum() * balancing_loss_weight
+
+    def cal_tokens_per_expert(self) -> torch.Tensor:
+        """Get the tokens-per-expert tensor for logging/bias update."""
+        if dist.is_initialized():
+            group = dist.group.WORLD
+            assert group is not None
+            return all_reduce(self.local_load_logits, "sum", group)
+        return self.local_load_logits
+
+
+class LayerBalancingLoss(LayerBalancingLossContext):
+    """Backward-compatible wrapper keeping the original constructor usage."""
+
+    def __init__(self, num_layers: int, n_routed_experts: int, device: torch.device | str):
+        cfg = LayerBalancingLossConfig(num_layers=num_layers, n_routed_experts=n_routed_experts, device=device)
+        kwargs = LayerBalancingLossKwargs(device=device)
+        super().__init__(cfg, kwargs)
+
+
+def prepare_layer_balancing_loss(
+    layer_balancing_cfg: LayerBalancingLossConfig | None,
+    *,
+    num_layers: int,
+    n_routed_experts: int,
+) -> LayerBalancingLoss | None:
+    """Build the layer balancing loss object from config.
+
+    Args:
+        layer_balancing_cfg (LayerBalancingLossConfig | None): Layer balancing config.
+        num_layers (int): Number of decoder layers.
+        n_routed_experts (int): Number of routed experts.
+
+    Returns:
+        LayerBalancingLoss | None: Built object when enabled, else None.
+    """
+    if layer_balancing_cfg is None:
+        return None
+
+    return layer_balancing_cfg.build(
+        num_layers=num_layers,
+        n_routed_experts=n_routed_experts,
+    )
+
+
+def accumulate_layer_balancing_loss(
+    layer_balancing_loss: LayerBalancingLoss | None,
+    *,
+    layer_idx: int,
+    router_weights: torch.Tensor,
+    router_logits: torch.Tensor,
+    mask: torch.Tensor,
+    dim: int,
+    num_experts_per_tok: int,
+) -> None:
+    """Accumulate per-layer balancing statistics.
+
+    This is a no-op when layer balancing loss is disabled.
+    """
+    if layer_balancing_loss is None:
+        return
+
+    router_weights_selected, _ = select_nonpad(router_weights, mask, dim=dim)
+    router_logits_selected, _ = select_nonpad(router_logits, mask, dim=dim)
+    layer_balancing_loss.update(
+        layer_idx,
+        router_weights_selected,
+        num_experts_per_tok,
+        router_logits_selected,
+    )
+
+
+def finalize_layer_balancing_loss(
+    layer_balancing_loss: LayerBalancingLoss | None,
+    *,
+    balancing_ctx: Any,
+    num_experts_per_tok: int,
+    non_pad_token: int,
+) -> tuple[torch.Tensor, torch.Tensor] | None:
+    """Finalize the balancing loss and tokens-per-expert from accumulated layer
+    stats.
+
+    Returns None when layer balancing is disabled or balancing_ctx is None.
+    """
+    if layer_balancing_loss is None or balancing_ctx is None:
+        return None
+
+    dist_init = balancing_ctx.loss_cfg.balancing_loss_global_average and dist.is_initialized()
+    balancing_loss = layer_balancing_loss.finalize(
+        dist_init,
+        num_experts_per_tok,
+        non_pad_token=non_pad_token,
+        balancing_loss_weight=balancing_ctx.loss_cfg.balancing_loss_alpha,
+    )
+    tokens_per_expert_global = layer_balancing_loss.cal_tokens_per_expert()
+    return balancing_loss, tokens_per_expert_global
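
The new helpers define a three-phase flow: prepare_layer_balancing_loss builds the per-layer accumulator, accumulate_layer_balancing_loss feeds it each layer's router outputs after masking out padding, and finalize reduces everything to one scalar. In the non-distributed branch that scalar is, per layer l, (n_experts / (valid_tokens * top_k)) * sum_e load[l][e] * gate_mean[l][e], summed over layers. Below is a minimal single-process sketch of this flow; the shapes, the random router outputs, and the 0.01 weight are illustrative assumptions, and since the trainer-side balancing_ctx consumed by finalize_layer_balancing_loss lives outside this commit, the sketch calls finalize() directly:

    import torch

    from xtuner.v1.loss import LayerBalancingLossConfig
    from xtuner.v1.loss.layer_moe_loss import (
        accumulate_layer_balancing_loss,
        prepare_layer_balancing_loss,
    )

    num_layers, n_experts, top_k, seq_len = 2, 8, 2, 16  # illustrative sizes

    # Phase 1: build the accumulator (returns None when the config is None).
    loss_ctx = prepare_layer_balancing_loss(
        LayerBalancingLossConfig(device="cpu"),
        num_layers=num_layers,
        n_routed_experts=n_experts,
    )
    assert loss_ctx is not None

    # Phase 2: feed each layer's router outputs; mask marks non-padding tokens.
    mask = torch.ones(1, seq_len, dtype=torch.bool)
    for layer_idx in range(num_layers):
        router_logits = torch.randn(1, seq_len, n_experts)
        router_weights = router_logits.softmax(dim=-1)
        accumulate_layer_balancing_loss(
            loss_ctx,
            layer_idx=layer_idx,
            router_weights=router_weights,
            router_logits=router_logits,
            mask=mask,
            dim=1,
            num_experts_per_tok=top_k,
        )

    # Phase 3: reduce to a scalar. dist_init=False takes the local branch,
    # which normalizes by the non-padding token count instead of all-reducing.
    loss = loss_ctx.finalize(
        dist_init=False,
        num_experts_per_tok=top_k,
        non_pad_token=seq_len,
        balancing_loss_weight=0.01,  # illustrative weight
    )
    tokens_per_expert = loss_ctx.cal_tokens_per_expert()  # (num_layers, n_experts)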
