1 change: 1 addition & 0 deletions lightllm/models/__init__.py
@@ -32,6 +32,7 @@
from lightllm.models.qwen3_vl.model import Qwen3VLTpPartModel
from lightllm.models.qwen3_vl_moe.model import Qwen3VLMOETpPartModel
from lightllm.models.gemma3.model import Gemma3TpPartModel
from lightllm.models.glm4v.model import GLM4VTpPartModel
from lightllm.models.tarsier2.model import (
Tarsier2Qwen2TpPartModel,
Tarsier2Qwen2VLTpPartModel,
Empty file.
437 changes: 437 additions & 0 deletions lightllm/models/glm4v/glm4v_visual.py

Large diffs are not rendered by default.

Empty file.
104 changes: 104 additions & 0 deletions lightllm/models/glm4v/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,104 @@
import torch
import torch.functional as F
Contributor (severity: medium): The import torch.functional as F is unused in this file and can be removed. Note that torch.functional is also a deprecated alias for torch.nn.functional.

import torch.distributed as dist
import numpy as np
from typing import Tuple
from functools import partial

from lightllm.distributed import all_reduce
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused
from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.glm4v.layer_weight.transformer_layer_weight import Glm4VTransformerLayerWeight


class Glm4VTransformerLayerInfer(LlamaTransformerLayerInfer):
def __init__(self, layer_num, network_config, mode=[]):
super().__init__(layer_num, network_config, mode)
mrope_section = network_config["rope_parameters"]["mrope_section"]
self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda")
self.partial_rotary_factor = network_config["rope_parameters"]["partial_rotary_factor"]

def _post_self_att_norm(
self, input, infer_state: Qwen2VLInferStateInfo, layer_weight: Glm4VTransformerLayerWeight
) -> torch.Tensor:
out = self.alloc_tensor(input.shape, input.dtype)
rmsnorm_forward(input, weight=layer_weight._post_self_att_norm_weight_.weight, eps=self.eps_, out=out)
return out

def _post_mlp_norm(
self, input, infer_state: Qwen2VLInferStateInfo, layer_weight: Glm4VTransformerLayerWeight
) -> torch.Tensor:
out = self.alloc_tensor(input.shape, input.dtype)
rmsnorm_forward(input, weight=layer_weight._post_mlp_norm_weight_.weight, eps=self.eps_, out=out)
return out

def _get_qkv(self, input, infer_state, layer_weight):
q = layer_weight.q_proj.mm(input)
cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
mrope_triton_fused(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
cache_kv[:, : self.tp_k_head_num_, :],
infer_state.position_cos,
infer_state.position_sin,
self.mrope_section,
partial_rotary_factor=self.partial_rotary_factor,
is_interleaved=False,
is_glm4v=True,
)
return q, cache_kv

def context_forward(self, input_embdings, infer_state: Qwen2VLInferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
input1 = None
self._post_cache_kv(cache_kv, infer_state, layer_weight)

o = self._TransformerLayerInferTpl__context_attention_wrapper_run(
q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
)

q = None
o = self._get_o(o, infer_state, layer_weight)
if self.tp_world_size_ > 1:
all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
o = self._post_self_att_norm(o, infer_state, layer_weight)  # extra norm before the residual add
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None

input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
ffn_out = self._ffn(input1, infer_state, layer_weight)
ffn_out = self._post_mlp_norm(ffn_out, infer_state, layer_weight)  # extra norm after the MLP
input1 = None
if self.tp_world_size_ > 1:
all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return input_embdings

def token_forward(self, input_embdings, infer_state: Qwen2VLInferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
input1 = None
self._post_cache_kv(cache_kv, infer_state, layer_weight)
o = self._token_attention_kernel(q, infer_state, layer_weight)
q = None
o = self._get_o(o, infer_state, layer_weight)
if self.tp_world_size_ > 1:
all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
o = self._post_self_att_norm(o, infer_state, layer_weight)  # extra norm before the residual add
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None

input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
ffn_out = self._ffn(input1, infer_state, layer_weight)
ffn_out = self._post_mlp_norm(ffn_out, infer_state, layer_weight)  # extra norm after the MLP
input1 = None
if self.tp_world_size_ > 1:
all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return input_embdings

def _tpsp_get_qkv(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
# TODO
raise Exception("not impl")
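
The layer above applies an extra RMSNorm to the attention output before its residual add and to the MLP output before its residual add. A minimal single-device sketch of that flow (function names are hypothetical; the tensor-parallel all_reduce calls are omitted):

def glm4v_decoder_block(x, attn, mlp, input_norm, post_self_attn_norm, pre_mlp_norm, post_mlp_norm):
    h = input_norm(x)              # _att_norm
    h = attn(h)                    # attention + output projection (_get_qkv / _get_o)
    h = post_self_attn_norm(h)     # extra norm before the residual add
    x = x + h
    h = pre_mlp_norm(x)            # _ffn_norm
    h = mlp(h)                     # _ffn
    h = post_mlp_norm(h)           # extra norm after the MLP, before the residual add
    x = x + h
    return x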
Empty file.
14 changes: 14 additions & 0 deletions lightllm/models/glm4v/layer_weight/pre_and_post_layer_weight.py
@@ -0,0 +1,14 @@
import numpy as np
from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import rename_weight_keys


class Glm4VPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
return
Contributor (severity: medium): This return statement is redundant at the end of an __init__ method, which implicitly returns None. It can be removed for cleaner code.


def load_hf_weights(self, weights):
rename_weight_keys(weights)
super().load_hf_weights(weights)
return
Contributor (severity: medium): This return statement is redundant at the end of a method that doesn't explicitly return a value. It can be removed for cleaner code.

35 changes: 35 additions & 0 deletions lightllm/models/glm4v/layer_weight/transformer_layer_weight.py
@@ -0,0 +1,35 @@
from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight
from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight


class Glm4VTransformerLayerWeight(Qwen2TransformerLayerWeight):
def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None):
super().__init__(layer_num, data_type, network_config, mode, quant_cfg)

def _init_weight_names(self):
self._post_self_att_norm_weight_name = f"model.layers.{self.layer_num_}.post_self_attn_layernorm.weight"
self._post_self_att_norm_bias_name = None
self._post_mlp_norm_weight_name = f"model.layers.{self.layer_num_}.post_mlp_layernorm.weight"
self._post_mlp_norm_bias_name = None
super()._init_weight_names()

def load_hf_weights(self, weights):
gate_up_weight_name = f"model.layers.{self.layer_num_}.mlp.gate_up_proj.weight"
if gate_up_weight_name in weights:
intermediate_size = self.network_config_["intermediate_size"]
gate_up_proj = weights[gate_up_weight_name]
gate_weight_ = gate_up_proj[0:intermediate_size, :]
up_weight_ = gate_up_proj[intermediate_size:, :]
weights[self._gate_weight_name] = gate_weight_
weights[self._up_weight_name] = up_weight_
del weights[gate_up_weight_name]
super().load_hf_weights(weights)

def _init_norm(self):
self._post_self_att_norm_weight_ = NormWeight(
self._post_self_att_norm_weight_name, self.data_type_, bias_name=self._post_self_att_norm_bias_name
)
self._post_mlp_norm_weight_ = NormWeight(
self._post_mlp_norm_weight_name, self.data_type_, bias_name=self._post_mlp_norm_bias_name
)
super()._init_norm()
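
To illustrate the split performed in load_hf_weights above: the fused gate_up_proj weight is stacked along the output dimension, with the first intermediate_size rows forming the gate projection and the remainder forming the up projection. A small self-contained sketch (shapes chosen for the example, not taken from the GLM-4V config):

import torch

intermediate_size, hidden_size = 4, 3
gate_up_proj = torch.randn(2 * intermediate_size, hidden_size)
gate_weight = gate_up_proj[:intermediate_size, :]   # rows [0, intermediate_size) -> gate_proj.weight
up_weight = gate_up_proj[intermediate_size:, :]     # rows [intermediate_size, 2*intermediate_size) -> up_proj.weight
assert gate_weight.shape == up_weight.shape == (intermediate_size, hidden_size)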
87 changes: 87 additions & 0 deletions lightllm/models/glm4v/model.py
@@ -0,0 +1,87 @@
import os
import json
import numpy as np
from lightllm.common.build_utils import repair_config
from lightllm.models.registry import ModelRegistry
from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer
from lightllm.models.glm4v.layer_infer.transformer_layer_infer import Glm4VTransformerLayerInfer
from lightllm.models.glm4v.layer_weight.pre_and_post_layer_weight import Glm4VPreAndPostLayerWeight
from lightllm.models.glm4v.layer_weight.transformer_layer_weight import Glm4VTransformerLayerWeight
from lightllm.server.multimodal_params import MultimodalParams
from lightllm.models.qwen2_vl.model import QWen2VLTokenizer
from lightllm.models.qwen2.model import Qwen2TpPartModel


class GLM4VTokenizer(QWen2VLTokenizer):
def __init__(self, tokenizer=None, image_processor=None, **kwargs):
self.tokenizer = tokenizer
self.image_processor = image_processor
self.min_pixel = self.image_processor.size["shortest_edge"]
self.max_pixel = self.image_processor.size["longest_edge"]
self.patch_size = self.image_processor.patch_size
self.merge_size = self.image_processor.merge_size
self.image_start_id = kwargs["model_cfg"]["image_start_token_id"]
self.image_end_id = kwargs["model_cfg"]["image_end_token_id"]
self.image_token_id = kwargs["model_cfg"]["image_token_id"]

def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
origin_ids = self.tokenizer.encode(prompt)

# <img><image_pad></img> -> <img></img>
origin_ids = [token for token in origin_ids if token != self.image_token_id]
# <img></img> --> <img>id,id+1...id+num</img>
input_ids = []
image_id = 0
while True:
try:
start_idx = origin_ids.index(self.image_start_id)
if start_idx + 1 >= len(origin_ids):
break
if origin_ids[start_idx + 1] == self.image_end_id:
input_ids.extend(origin_ids[: start_idx + 1])
token_id = multimodal_params.images[image_id].token_id
token_num = multimodal_params.images[image_id].token_num
multimodal_params.images[image_id].start_idx = len(input_ids)
input_ids.extend(range(token_id, token_id + token_num))
input_ids.append(self.image_end_id)
origin_ids = origin_ids[start_idx + 2 :]
image_id += 1
else:
raise ValueError("image token error")
except ValueError:
break
input_ids.extend(origin_ids)
return input_ids
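
A worked example of the expansion above, with all token ids hypothetical: suppose image_start_id = 5, image_end_id = 6, image_token_id = 7, and the prompt contains one image assigned token_id = 1000 with token_num = 4.

# origin_ids after dropping the image_token_id placeholders:
#     [10, 11, 5, 6, 12]                          # text, <img>, </img>, text
# resulting input_ids:
#     [10, 11, 5, 1000, 1001, 1002, 1003, 6, 12]
# and multimodal_params.images[0].start_idx is set to 3,
# the position of the first expanded image token.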


@ModelRegistry(["glm4v"], is_multimodal=True)
class GLM4VTpPartModel(Qwen2TpPartModel):

pre_layer_infer_class = LlamaMultimodalPreLayerInfer
transformer_layer_infer_class = Glm4VTransformerLayerInfer

pre_and_post_weight_class = Glm4VPreAndPostLayerWeight
transformer_weight_class = Glm4VTransformerLayerWeight

infer_state_class = Qwen2VLInferStateInfo

def __init__(self, kvargs):
super().__init__(kvargs)
return
Contributor (severity: medium): This return statement is redundant at the end of an __init__ method, which implicitly returns None. It can be removed for cleaner code.


def _init_inferstate_cls(self):
pass

def _init_config(self):
with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
all_config = json.load(json_file)
self.config = all_config["text_config"]
# rename keys
repair_config(self.config, same_names=["num_attention_heads", "n_head"])
repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
if self.finetune_config:
self.config["vocab_size"] = self.finetune_config.vocab_size
return
Contributor (severity: medium): This return statement is redundant at the end of a method that doesn't explicitly return a value. It can be removed for cleaner code.
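
For reference, a sketch of the config.json layout that _init_config and the related GLM-4V classes in this PR read from (key names are taken from the code in this diff; all values are placeholders, not the real GLM-4V defaults):

glm4v_config_sketch = {
    "model_type": "glm4v",
    "image_start_token_id": 151339,        # placeholder id, read by GLM4VTokenizer
    "image_end_token_id": 151340,          # placeholder id
    "image_token_id": 151343,              # placeholder id
    "text_config": {                       # copied into self.config by _init_config
        "num_attention_heads": 32,
        "hidden_size": 4096,
        "num_hidden_layers": 40,
        "intermediate_size": 13696,
        "rope_parameters": {
            "rope_theta": 10000.0,
            "partial_rotary_factor": 0.5,
            "mrope_section": [8, 12, 12],  # placeholder split of the rotary dims
        },
    },
    "vision_config": {},                   # passed to Glm4vVisionTransformerPretrainedModel
}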

15 changes: 12 additions & 3 deletions lightllm/models/llama/model.py
@@ -108,6 +108,8 @@ def _init_custom(self):
Some model-specific initialization.
"""
rope_scaling = self.config.get("rope_scaling", None)
if rope_scaling is None:
rope_scaling = self.config.get("rope_parameters", None)
if rope_scaling is None:
self._init_to_get_rotary()
return
@@ -171,14 +173,21 @@ def _init_weights(self):
return

def _init_to_get_rotary(self, default_base=10000):
partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
rope_params = self.config.get("rope_parameters")
if rope_params is not None:
partial_rotary_factor = rope_params.get("partial_rotary_factor", 1)
base = rope_params.get("rope_theta", float(default_base))
else:
partial_rotary_factor = self.config.get("partial_rotary_factor", 1)
base = self.config.get("rope_theta", float(default_base))

partial_head_dim = int(partial_rotary_factor * self.head_dim_)

if self.config.get("rope_scaling", {}) is None:
rope_scaling_factor = 1.0
else:
rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)

base = self.config.get("rope_theta", float(default_base))

if "max_sequence_length" in self.config:
max_seq_len = self.config["max_sequence_length"]
else:
12 changes: 10 additions & 2 deletions lightllm/models/qwen2_vl/triton_kernel/mrope.py
@@ -85,6 +85,7 @@ def _mrope_triton_fused_kernel(
stride_kh,
stride_kd,
is_interleaved: tl.constexpr,
is_glm4v: tl.constexpr,
HEAD_Q: tl.constexpr,
HEAD_K: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
@@ -95,6 +96,10 @@
dim_range0 = tl.arange(0, BLOCK_DMODEL // 2)
dim_range1 = dim_range0 + BLOCK_DMODEL // 2

if is_glm4v:
dim_range0 = dim_range0 * 2
dim_range1 = dim_range0 + 1

t_cos = Cos + seq_index * stride_cosd
h_cos = Cos + stride_cosld + seq_index * stride_cosd
w_cos = Cos + 2 * stride_cosld + seq_index * stride_cosd
@@ -192,11 +197,13 @@ def mrope_triton_fused(
cos: torch.Tensor,
sin: torch.Tensor,
mrope_section: torch.Tensor,
is_interleaved: bool,
partial_rotary_factor: float = 1.0,
is_interleaved: bool = False,
is_glm4v: bool = False,
run_config: Optional[dict] = None,
):
head_num_q, head_num_k = q.shape[1], k.shape[1]
head_dim = int(q.shape[2])
head_dim = int(q.shape[2] * partial_rotary_factor)
num_tokens = q.shape[0]

if not run_config:
Expand Down Expand Up @@ -228,6 +235,7 @@ def mrope_triton_fused(
stride_kh=k.stride(1),
stride_kd=k.stride(2),
is_interleaved=is_interleaved,
is_glm4v=is_glm4v,
HEAD_Q=head_num_q,
HEAD_K=head_num_k,
BLOCK_DMODEL=head_dim,
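
For intuition on the is_glm4v branch added above: by default the kernel rotates dimension i together with dimension i + BLOCK_DMODEL/2 (split-half pairing), while is_glm4v=True pairs adjacent even/odd dimensions instead, and partial_rotary_factor shrinks the rotated width of each head. A small sketch with an illustrative head_dim:

head_dim = 8
half = head_dim // 2
default_pairs = [(i, i + half) for i in range(half)]      # [(0, 4), (1, 5), (2, 6), (3, 7)]
glm4v_pairs = [(2 * i, 2 * i + 1) for i in range(half)]   # [(0, 1), (2, 3), (4, 5), (6, 7)]
# With partial_rotary_factor = 0.5 the caller passes BLOCK_DMODEL = int(8 * 0.5) = 4,
# so only the first 4 dimensions of each head are rotated and the rest pass through unchanged.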
6 changes: 6 additions & 0 deletions lightllm/server/tokenizer.py
@@ -29,6 +29,7 @@
from ..models.qwen_vl.model import QWenVLTokenizer
from ..models.qwen2_vl.model import QWen2VLTokenizer
from ..models.qwen3_vl.model import QWen3VLTokenizer
from ..models.glm4v.model import GLM4VTokenizer
from ..models.internvl.model import InternvlTokenizer
from ..models.gemma3.model import Gemma3Tokenizer

@@ -104,5 +105,10 @@ def get_tokenizer(
tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name)
elif model_type == "gemma3":
tokenizer = Gemma3Tokenizer(tokenizer, model_cfg)
elif model_type == "glm4v":
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(tokenizer_name)
tokenizer = GLM4VTokenizer(tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg)

return tokenizer
5 changes: 5 additions & 0 deletions lightllm/server/visualserver/model_infer/model_rpc.py
@@ -19,6 +19,7 @@
from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel
from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel
from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel
from lightllm.models.glm4v.glm4v_visual import Glm4vVisionTransformerPretrainedModel
from lightllm.utils.infer_utils import set_random_seed
from lightllm.utils.dist_utils import init_vision_distributed_env
from lightllm.utils.graceful_utils import graceful_registry
@@ -78,6 +79,10 @@ def exposed_init_model(self, kvargs):
# self.model = InternVLVisionModel()
elif self.model_type == "gemma3":
self.model = Gemma3VisionModel()
elif self.model_type == "glm4v":
self.model = (
Glm4vVisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16()
)
else:
raise Exception(f"can not support {self.model_type} now")
