Skip to content

Commit 9383c21

Browse files
shihaobai and sufubao authored
Model support: Qwen3-next and Qwen3.5 (#1233)
Co-authored-by: sufubao <sufubao@sensetime.com>
1 parent 56d9a29 commit 9383c21

153 files changed

Lines changed: 9618 additions & 97 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

lightllm/common/basemodel/basemodel.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
1313
from lightllm.common.basemodel.infer_struct import InferStateInfo
14+
from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache
1415
from lightllm.common.kv_cache_mem_manager import MemoryManager
1516
from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
1617
from lightllm.common.req_manager import ReqManager
@@ -53,6 +54,9 @@ class TpPartBaseModel:
5354
# infer state class
5455
infer_state_class = InferStateInfo
5556

57+
# radix cache class
58+
radix_cache_class = RadixCache
59+
5660
def __init__(self, kvargs):
5761
self.args = get_env_start_args()
5862
self.run_mode = kvargs["run_mode"]

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -62,20 +62,21 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor
6262
def _tpsp_ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
6363
raise Exception("need to impl")
6464

65-
def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
66-
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
67-
q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
68-
input1 = None
65+
def context_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
66+
q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight)
6967
self._post_cache_kv(cache_kv, infer_state, layer_weight)
70-
7168
o = self._context_attention_wrapper_run(
7269
q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
7370
)
74-
7571
q = None
7672
o = self._get_o(o, infer_state, layer_weight)
7773
if self.tp_world_size_ > 1:
7874
all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
75+
return o
76+
77+
def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
78+
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
79+
o = self.context_attention_forward(input1, infer_state, layer_weight)
7980
input_embdings.add_(o.view(-1, self.embed_dim_))
8081
o = None
8182

@@ -87,39 +88,42 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei
8788
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
8889
return input_embdings
8990

90-
def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
91-
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
92-
q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
93-
input1 = None
91+
def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
92+
q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight)
9493
self._post_cache_kv(cache_kv, infer_state, layer_weight)
9594
o = self._token_attention_kernel(q, infer_state, layer_weight)
9695
q = None
9796
o = self._get_o(o, infer_state, layer_weight)
9897
if self.tp_world_size_ > 1:
9998
all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
99+
return o
100+
101+
def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
102+
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
103+
o = self.token_attention_forward(input1, infer_state, layer_weight)
100104
input_embdings.add_(o.view(-1, self.embed_dim_))
101105
o = None
102106

103107
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
104108
ffn_out = self._ffn(input1, infer_state, layer_weight)
105-
input1 = None
106109
if self.tp_world_size_ > 1:
107110
all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
108111
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
109112
return input_embdings
110113

111-
def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
112-
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
113-
q, cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight)
114-
input1 = None
114+
def tpsp_context_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
115+
q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight)
115116
self._post_cache_kv(cache_kv, infer_state, layer_weight)
116-
117117
o = self._context_attention_wrapper_run(
118118
q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
119119
)
120-
121120
q = None
122121
o = self._tpsp_get_o(o, infer_state, layer_weight)
122+
return o
123+
124+
def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
125+
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
126+
o = self.tpsp_context_attention_forward(input1, infer_state, layer_weight)
123127
input_embdings.add_(o.view(-1, self.embed_dim_))
124128
o = None
125129

@@ -129,14 +133,17 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS
129133
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
130134
return input_embdings
131135

132-
def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
133-
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
134-
q, cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight)
135-
input1 = None
136+
def tpsp_token_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
137+
q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight)
136138
self._post_cache_kv(cache_kv, infer_state, layer_weight)
137139
o = self._token_attention_kernel(q, infer_state, layer_weight)
138140
q = None
139141
o = self._tpsp_get_o(o, infer_state, layer_weight)
142+
return o
143+
144+
def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
145+
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
146+
o = self.tpsp_token_attention_forward(input1, infer_state, layer_weight)
140147
input_embdings.add_(o.view(-1, self.embed_dim_))
141148
o = None
142149

lightllm/common/basemodel/layer_weights/meta_weights/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,16 @@
77
QKVROWNMMWeight,
88
COLMMWeight,
99
)
10-
from .norm_weight import TpRMSNormWeight, RMSNormWeight, LayerNormWeight, NoTpGEMMANormWeight, QKRMSNORMWeight
10+
from .norm_weight import (
11+
TpRMSNormWeight,
12+
RMSNormWeight,
13+
GatedRMSNormWeight,
14+
LayerNormWeight,
15+
NoTpGEMMANormWeight,
16+
QKRMSNORMWeight,
17+
QKGEMMANormWeight,
18+
)
1119
from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight
1220
from .att_sink_weight import TpAttSinkWeight
1321
from .fused_moe.fused_moe_weight import FusedMoeWeight
22+
from .parameter_weight import ParameterWeight, TpParameterWeight

lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward
66
from lightllm.common.basemodel.triton_kernel.norm.layernorm import layernorm_forward
77
from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward
8+
from lightllm.common.basemodel.triton_kernel.norm.gated_rmsnorm import gated_rmsnorm_forward
89
from .platform_op import PlatformAwareOp
910

1011

@@ -71,6 +72,55 @@ def __call__(
7172
return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func)
7273

7374

75+
class GatedRMSNormWeight(RMSNormWeight):
    """RMSNorm weight whose normalized output is modulated by a gate tensor.

    Only a Triton kernel exists for this op, so the CUDA and MUSA platform
    hooks both delegate to :meth:`_triton_forward`.
    """

    def _triton_forward(
        self,
        input: torch.Tensor,
        gate_value: torch.Tensor,
        eps: float,
        out: Optional[torch.Tensor] = None,
        alloc_func=torch.empty,
    ) -> torch.Tensor:
        assert (
            input.ndim in [2, 3] and self.weight.ndim == 1
        ), f"input.ndim: {input.ndim} != 2 or weight.ndim: {self.weight.ndim} != 1"
        # Allocate the destination lazily when the caller did not provide one.
        dst = alloc_func(input.shape, dtype=input.dtype, device=input.device) if out is None else out
        return gated_rmsnorm_forward(x=input, weight=self.weight, bias=None, eps=eps, z=gate_value, out=dst)

    def _cuda_forward(
        self,
        input: torch.Tensor,
        gate_value: torch.Tensor,
        eps: float,
        out: Optional[torch.Tensor] = None,
        alloc_func=torch.empty,
    ) -> torch.Tensor:
        # Only the Triton implementation is available on the CUDA platform.
        return self._triton_forward(input=input, gate_value=gate_value, eps=eps, out=out, alloc_func=alloc_func)

    def _musa_forward(
        self,
        input: torch.Tensor,
        gate_value: torch.Tensor,
        eps: float,
        out: Optional[torch.Tensor] = None,
        alloc_func=torch.empty,
    ) -> torch.Tensor:
        # MUSA runs the same Triton kernel.
        return self._triton_forward(input=input, gate_value=gate_value, eps=eps, out=out, alloc_func=alloc_func)

    def __call__(
        self,
        input: torch.Tensor,
        gate_value: torch.Tensor,
        eps: float,
        out: Optional[torch.Tensor] = None,
        alloc_func=torch.empty,
    ) -> torch.Tensor:
        # Dispatch to the platform-specific implementation (PlatformAwareOp).
        return self._forward(input=input, gate_value=gate_value, eps=eps, out=out, alloc_func=alloc_func)
122+
123+
74124
class LayerNormWeight(BaseWeightTpl, PlatformAwareOp):
75125
def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None):
76126
super().__init__(tp_rank=0, tp_world_size=1)
@@ -193,6 +243,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
193243
if self.weight_name in weights:
194244
self.weight.copy_(weights[self.weight_name])
195245
self.weight += 1
246+
self.weight.load_ok = True
196247

197248

198249
class QKRMSNORMWeight(BaseWeightTpl, PlatformAwareOp):
@@ -276,3 +327,23 @@ def __call__(
276327
eps: float,
277328
) -> None:
278329
return self._forward(q=q, k=k, eps=eps)
330+
331+
332+
class QKGEMMANormWeight(QKRMSNORMWeight):
    """Fused Q/K RMSNorm using the Gemma weight convention.

    Gemma checkpoints store ``weight - 1``, so ``+1`` is applied at load
    time, and the multiply is done in fp32 (``fp32_multiply=True``).
    """

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
        for name, param in ((self.q_weight_name, self.q_weight), (self.k_weight_name, self.k_weight)):
            if name in weights:
                param.copy_(weights[name])
                param += 1  # Gemma stores (weight - 1) in the checkpoint
                param.load_ok = True

    def _triton_forward(self, q: torch.Tensor, k: torch.Tensor, eps: float) -> tuple:
        assert q.ndim == 2 and self.q_weight.ndim == 1
        assert k.ndim == 2 and self.k_weight.ndim == 1
        # Llama computes x.to(float16) * w whereas Gemma computes
        # (x * w).to(float16) — see https://github.com/huggingface/transformers/pull/29402 —
        # so fp32_multiply must be True here.
        return qk_rmsnorm_fused_forward(q=q, k=k, w_q=self.q_weight, w_k=self.k_weight, eps=eps, fp32_multiply=True)
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import torch
2+
from typing import Dict, Optional, Tuple
3+
from .base_weight import BaseWeightTpl
4+
from lightllm.utils.dist_utils import get_dp_world_size
5+
6+
7+
class ParameterWeight(BaseWeightTpl):
    """A plain (non tensor-parallel) parameter loaded verbatim from a HF checkpoint.

    Destination tensors are pre-allocated from ``weight_shape`` / ``bias_shape``
    and each carries a ``load_ok`` flag that :meth:`verify_load` checks after
    all checkpoint shards have been consumed.
    """

    def __init__(
        self,
        weight_name: str,
        data_type: torch.dtype,
        weight_shape: Optional[Tuple[int, ...]],
        bias_name: Optional[str] = None,
        bias_shape: Optional[Tuple[int, ...]] = None,
    ):
        super().__init__()
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.data_type_ = data_type
        self.weight_shape = weight_shape
        self.bias_shape = bias_shape
        self.weight: Optional[torch.Tensor] = None
        self.bias: Optional[torch.Tensor] = None
        # With weight_shape=None allocation is deferred (e.g. to a subclass).
        if weight_shape is not None:
            self._create_weight()

    def _create_weight(self):
        # Allocate the destination tensors on this rank's device; they are
        # marked not-loaded until load_hf_weights fills them.
        if self.weight_shape is not None:
            self.weight = torch.empty(*self.weight_shape, dtype=self.data_type_, device=self.device_id_)
            self.weight.load_ok = False
        if self.bias_name is not None and self.bias_shape is not None:
            self.bias = torch.empty(*self.bias_shape, dtype=self.data_type_, device=self.device_id_)
            self.bias.load_ok = False

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
        """Copy matching tensors from a checkpoint shard into place.

        Raises an informative AssertionError (instead of an opaque
        AttributeError on None) when a matching key arrives before the
        destination tensor was allocated.
        """
        if self.weight_name in weights:
            assert self.weight is not None, f"destination for {self.weight_name} was never allocated"
            self.weight.copy_(weights[self.weight_name].to(self.data_type_))
            self.weight.load_ok = True
        if self.bias_name is not None and self.bias_name in weights:
            assert self.bias is not None, f"destination for {self.bias_name} was never allocated"
            self.bias.copy_(weights[self.bias_name].to(self.data_type_))
            self.bias.load_ok = True

    def verify_load(self) -> bool:
        """Return True when every allocated tensor has been filled from the checkpoint."""
        for tensor in (self.weight, self.bias):
            if tensor is not None and not getattr(tensor, "load_ok", False):
                return False
        return True
return True
51+
52+
53+
class TpParameterWeight(ParameterWeight):
    """A parameter split along ``dim`` across the ranks of one DP group.

    Each rank allocates only its slice (``n_embed // world_size`` along
    ``dim``) and copies just that slice out of the checkpoint tensor.
    """

    def __init__(
        self,
        weight_name: str,
        data_type: torch.dtype,
        bias_name: Optional[str] = None,
        weight_shape: Optional[Tuple[int, ...]] = None,
        bias_shape: Optional[Tuple[int, ...]] = None,
        dim: int = 0,  # the default split dimension is 0
    ):
        # The full shape must be known up front to compute the TP split;
        # check explicitly so a missing shape fails with a clear message
        # instead of `TypeError: len() of NoneType` from the assert below.
        assert weight_shape is not None, "TpParameterWeight requires an explicit weight_shape"
        assert (
            0 <= dim < len(weight_shape)
        ), f"split dimension: {dim} must be less than the length of weight_shape: {weight_shape}"
        n_embed = weight_shape[dim]
        tp_world_size = get_dp_world_size()
        assert (
            n_embed % tp_world_size == 0
        ), f"weight_shape[{dim}]={weight_shape[dim]} must be divisible by tp_world_size_: {tp_world_size}"
        self.dim = dim
        self.split_n_embed = n_embed // tp_world_size
        tp_weight_shape = weight_shape[:dim] + (self.split_n_embed,) + weight_shape[dim + 1 :]
        tp_bias_shape = None
        if bias_shape is not None:
            tp_bias_shape = bias_shape[:dim] + (self.split_n_embed,) + bias_shape[dim + 1 :]
        super().__init__(weight_name, data_type, tp_weight_shape, bias_name, tp_bias_shape)

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
        """Copy this rank's contiguous slice along the split dimension."""
        start = self.split_n_embed * self.tp_rank_
        if self.weight_name in weights:
            sliced = weights[self.weight_name].narrow(self.dim, start, self.split_n_embed)
            self.weight.copy_(sliced.to(self.data_type_))
            self.weight.load_ok = True
        if self.bias_name is not None and self.bias_name in weights:
            sliced = weights[self.bias_name].narrow(self.dim, start, self.split_n_embed)
            self.bias.copy_(sliced.to(self.data_type_))
            self.bias.load_ok = True

0 commit comments

Comments
 (0)