
Commit ecf34ae

Author: 钮圣虓 (committed)
feat: vit seperation
1 parent a087340, commit ecf34ae

File tree: 26 files changed (+1461, -178 lines)
Lines changed: 9 additions & 32 deletions
@@ -1,6 +1,7 @@
 import dataclasses
 import torch
 from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend
+from lightllm.utils.sgl_utils import flash_attn_varlen_func


 class Fa3VitAttBackend(BaseVitAttBackend):
@@ -17,42 +18,18 @@ def _vit_att_fwd(
         head_dim = q.shape[-1]
         softmax_scale = head_dim ** -0.5
         window_size = (-1, -1)
-        torch.ops.sgl_kernel.fwd.default(
+        o = flash_attn_varlen_func(
             q,
             k,
             v,
-            None, # k_new
-            None, # v_new
-            None, # qv
-            o, # out
-            cu_seqlens,
-            cu_seqlens,
-            None, # cu_seqlens_k_new
-            None,
-            None,
-            max_seqlen,
-            max_seqlen,
-            None, # page_table,
-            None, # kv_batch_idx
-            None, # leftpad_k
-            None, # rotary cos
-            None, # rotary sin
-            None, # seqlens_rotary
-            None,
-            None,
-            None,
-            softmax_scale,
-            False,
-            window_size[0],
-            window_size[1],
+            cu_seqlens_q=cu_seqlens,
+            cu_seqlens_k=cu_seqlens,
+            max_seqlen_q=max_seqlen,
+            max_seqlen_k=max_seqlen,
+            softmax_scale=softmax_scale,
+            causal=False,
+            window_size=window_size,
             attention_chunk=0,
             softcap=0.0,
-            is_rotary_interleaved=False,
-            scheduler_metadata=None,
-            num_splits=1,
-            pack_gqa=None,
-            sm_margin=0,
-            sinks=None,
         )
-
         return o
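For context, a minimal sketch of how a varlen call like the one above is typically fed: all sequences are packed along the token axis, and cu_seqlens holds the int32 prefix sums that delimit them. The sequence lengths, head counts, dtype, and CUDA device below are illustrative assumptions, not values from this commit.

import torch
from lightllm.utils.sgl_utils import flash_attn_varlen_func

# Two images packed into one batch: 196 and 256 tokens.
seq_lens = torch.tensor([196, 256], dtype=torch.int32, device="cuda")
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seq_lens, 0, dtype=torch.int32), (1, 0))
# cu_seqlens == [0, 196, 452]: sequence i spans tokens cu_seqlens[i]:cu_seqlens[i + 1]
max_seqlen = int(seq_lens.max())

total, num_heads, head_dim = int(seq_lens.sum()), 16, 64
q = torch.randn(total, num_heads, head_dim, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

o = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
    softmax_scale=head_dim ** -0.5,
    causal=False,  # ViT attention is bidirectional
)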

lightllm/models/internvl/model.py

Lines changed: 7 additions & 6 deletions
@@ -68,21 +68,22 @@ def init_imageitem_extral_params(
             img.extra_params["image_patch_max_num"] = 6
         elif num_images > 6:
             img.extra_params["image_patch_max_num"] = 0
+        img.patch_num = self.get_image_patch(img)
         return

     def init_audioitem_extral_params(
         self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams
     ):
         return

-    def get_image_token_length(self, img: ImageItem):
-        return (
-            self.get_image_patch_func(
-                img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True
-            )
-            * self.image_length
+    def get_image_patch(self, img: ImageItem):
+        return self.get_image_patch_func(
+            img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True
         )

+    def get_image_token_length(self, img: ImageItem):
+        return self.get_image_patch(img) * self.image_length
+
     def get_audio_token_length(self, audio: AudioItem):
         L = audio.audio_length
         audio_token_num = 0
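The refactor separates tile counting from token counting: get_image_patch() reports how many tiles the dynamic-resolution split produces for an image (optionally including a thumbnail), and the token length is that tile count times image_length. A minimal sketch of the relationship; the tile count and the 256-tokens-per-tile figure are illustrative assumptions.

patch_num = model.get_image_patch(img)      # e.g. 7 tiles (6 tiles + thumbnail)
token_len = patch_num * model.image_length  # e.g. 7 * 256 = 1792 embedding tokens
# init_imageitem_extral_params() now caches the same count as img.patch_num,
# so later stages can read it without recomputing the split.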

lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py

Lines changed: 35 additions & 1 deletion
@@ -1,11 +1,15 @@
+import rpyc
+import socket
 import torch
 import torch.distributed as dist

 from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
 from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
+from lightllm.server.embed_cache.utils import get_shm_name_embed, load_tensor_afs
 from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
 from lightllm.distributed.communication_op import all_reduce
+from lightllm.utils.envs_utils import get_env_start_args


 """
@@ -26,24 +30,41 @@
 class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer):
     def __init__(self, network_config):
         super().__init__(network_config)
+        self.args = get_env_start_args()
+        self.cache_client = None
+        if self.args.enable_remote_vit:
+            self.cache_client = rpyc.connect("localhost", self.args.cache_port, config={"allow_pickle": True})
+            self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+        return
+
+    def _copy_loaded_embed_to_cache(
+        self, embed_tensor: torch.Tensor, cpu_embed_cache_tensor: torch.Tensor, start_index: int
+    ):
+        if embed_tensor.ndim == 2:
+            embed_tensor = embed_tensor.unsqueeze(1)
+
+        token_num, layer_num, hidden_size = embed_tensor.shape
+        cpu_embed_cache_tensor[start_index : start_index + token_num, :layer_num, :hidden_size].copy_(embed_tensor)
         return

     def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
         img_start_token_ids = []
         img_token_lens = []
         img_start_locs_in_cache = []
+        unique_uids = []
         device = layer_weight.wte_weight_.weight.device
         dtype = layer_weight.wte_weight_.weight.dtype
         hidden_size = layer_weight.wte_weight_.weight.shape[1]

-        for batch_id, p in enumerate(infer_state.multimodal_params):
+        for _, p in enumerate(infer_state.multimodal_params):
             for img in p["images"] + p["audios"]:
                 # skip the same image
                 if img["token_id"] in img_start_token_ids:
                     continue
                 img_start_token_ids.append(img["token_id"])
                 img_token_lens.append(img["token_num"])
                 img_start_locs_in_cache.append(img["start_index_in_embed_cache"])
+                unique_uids.append(img["uuid"])
         out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device)

         from lightllm.server.router.model_infer.infer_batch import g_infer_context
@@ -55,6 +76,19 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
             else cpu_embed_cache_client.cpu_embed_cache_tensor
         )

+        if self.args.enable_remote_vit:
+            release_ids = []
+            for _, p in enumerate(infer_state.multimodal_params):
+                for img in p["images"] + p["audios"]:
+                    release_ids.append(img["uuid"])
+
+            for uid, start_index_in_embed_cache in zip(unique_uids, img_start_locs_in_cache):
+                embed_tensor = load_tensor_afs(get_shm_name_embed(uid), self.args.image_embed_dir)
+                self._copy_loaded_embed_to_cache(embed_tensor, cpu_embed_cache_tensor, start_index_in_embed_cache)
+
+            if release_ids:
+                self.cache_client.root.release(release_ids)
+
         assert cpu_embed_cache_tensor.shape[2] == hidden_size, (
             f"Dimension mismatch: text weight dimension is {hidden_size}, "
             f"but image embed dimension is {cpu_embed_cache_tensor.shape[2]}"

lightllm/models/vit/triton_kernel/flashattention_nopad.py

Lines changed: 10 additions & 34 deletions
@@ -167,44 +167,20 @@ def flash_attention_v3_fwd(
         head_dim = q.shape[-1]
         softmax_scale = head_dim ** -0.5
         window_size = (-1, -1)
-        torch.ops.sgl_kernel.fwd.default(
+        o = flash_attn_varlen_func(
             q,
             k,
             v,
-            None, # k_new
-            None, # v_new
-            None, # qv
-            o, # out
-            cu_seqlens,
-            cu_seqlens,
-            None, # cu_seqlens_k_new
-            None,
-            None,
-            max_seqlen,
-            max_seqlen,
-            None, # page_table,
-            None, # kv_batch_idx
-            None, # leftpad_k
-            None, # rotary cos
-            None, # rotary sin
-            None, # seqlens_rotary
-            None,
-            None,
-            None,
-            softmax_scale,
-            False,
-            window_size[0],
-            window_size[1],
-            0.0,
-            is_rotary_interleaved=False,
-            scheduler_metadata=None,
-            num_splits=1,
-            pack_gqa=None,
-            sm_margin=0,
-            sinks=None,
+            cu_seqlens_q=cu_seqlens,
+            cu_seqlens_k=cu_seqlens,
+            max_seqlen_q=max_seqlen,
+            max_seqlen_k=max_seqlen,
+            softmax_scale=softmax_scale,
+            causal=False,
+            window_size=window_size,
+            softcap=0.0,
         )
-
-        return
+        return o

 except ImportError:
     print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.")

lightllm/server/api_cli.py

Lines changed: 45 additions & 1 deletion
@@ -7,7 +7,17 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--run_mode",
         type=str,
-        choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"],
+        choices=[
+            "normal",
+            "prefill",
+            "decode",
+            "nixl_prefill",
+            "nixl_decode",
+            "pd_master",
+            "config_server",
+            "visual",
+            "visual_only",
+        ],
         default="normal",
         help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode,
         config_server is for pd split mode used to register pd_master node, and get pd_master node list,
@@ -605,6 +615,40 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default=0.03,
         help="""The interval of the schedule time, default is 30ms.""",
     )
+    parser.add_argument(
+        "--image_embed_dir",
+        type=str,
+        default=None,
+        help="path for vit embed",
+    )
+    parser.add_argument(
+        "--enable_remote_vit",
+        action="store_true",
+        help="Whether to enable remote vit for multimodal service.",
+    )
+    parser.add_argument(
+        "--remote_vit_port",
+        type=int,
+        default=12346,
+        help="The port number for the remote vit service.",
+    )
+    parser.add_argument(
+        "--redis_port",
+        type=int,
+        default=6379,
+        help="The port number for the redis service in config_server mode.",
+    )
+    parser.add_argument(
+        "--redis_evict_fraction",
+        type=float,
+        default=0.3,
+        help="The evict fraction for the redis service in config_server mode.",
+    )
+    parser.add_argument(
+        "--start_redis",
+        action="store_true",
+        help="Whether to start the redis service in config_server mode.",
+    )
     parser.add_argument(
         "--enable_cpu_cache",
         action="store_true",
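Taken together, the new flags suggest a launch shape roughly like the following. This is a guess assembled from the argparse definitions above, not a documented invocation; the model path and embed directory are placeholders.

python -m lightllm.server.api_server \
    --run_mode visual \
    --model_dir /path/to/vl-model \
    --enable_remote_vit \
    --remote_vit_port 12346 \
    --image_embed_dir /path/to/image_embeds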

lightllm/server/api_http.py

Lines changed: 18 additions & 2 deletions
@@ -43,7 +43,7 @@
 from .multimodal_params import MultimodalParams
 from .httpserver.manager import HttpServerManager
 from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster
-from .api_lightllm import lightllm_get_score
+from .api_lightllm import lightllm_get_score, lightllm_get_image_embedding
 from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.error_utils import ServerBusyError
@@ -92,6 +92,8 @@ def set_args(self, args: StartArgs):
             self.httpserver_manager = HttpServerManagerForPDMaster(
                 args=args,
             )
+        elif args.run_mode == "visual":
+            self.metric_client = MetricClient(args.metric_port)
         else:
             init_tokenizer(args) # for openai api
             SamplingParams.load_generation_cfg(args.model_dir)
@@ -136,7 +138,7 @@ def get_model_name():
 @app.get("/health", summary="Check server health")
 @app.head("/health", summary="Check server health")
 async def healthcheck(request: Request):
-    if g_objs.args.run_mode == "pd_master":
+    if g_objs.args.run_mode in ["pd_master", "visual"]:
         return JSONResponse({"message": "Ok"}, status_code=200)

     if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true":
@@ -221,6 +223,18 @@ async def get_score(request: Request) -> Response:
         return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))


+@app.post("/get_image_embedding")
+async def get_image_embed(request: Request) -> Response:
+    try:
+        return await lightllm_get_image_embedding(request, g_objs.httpserver_manager)
+    except ServerBusyError as e:
+        logger.error("%s", str(e), exc_info=True)
+        return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e))
+    except Exception as e:
+        logger.error("An error occurred: %s", str(e), exc_info=True)
+        return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))
+
+
 @app.post("/")
 async def compat_generate(request: Request) -> Response:
     if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]:
@@ -359,6 +373,8 @@ async def startup_event():
     logger.info("server start up")
     loop = asyncio.get_event_loop()
     g_objs.set_args(get_env_start_args())
+    if g_objs.httpserver_manager is None:
+        return
     loop.create_task(g_objs.httpserver_manager.handle_loop())
     logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}")
     return

lightllm/server/api_lightllm.py

Lines changed: 17 additions & 1 deletion
@@ -1,7 +1,7 @@
 import collections
 from typing import AsyncGenerator
 from fastapi import BackgroundTasks, Request
-from fastapi.responses import Response, StreamingResponse
+from fastapi.responses import Response, StreamingResponse, JSONResponse
 from lightllm.server.core.objs.sampling_params import SamplingParams
 from .multimodal_params import MultimodalParams
 from .httpserver.manager import HttpServerManager
@@ -150,3 +150,19 @@ async def stream_results() -> AsyncGenerator[bytes, None]:

     background_tasks = BackgroundTasks()
     return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)
+
+
+async def lightllm_get_image_embedding(request: Request, httpserver_manager: HttpServerManager) -> Response:
+    request_dict = await request.json()
+    # request_dict: {'parameters': {'max_new_tokens': 128},
+    #     'multimodal_params': {'images': [{'type': 'base64', 'data': 'base64'}]}}
+    sample_params_dict = request_dict["parameters"]
+    sampling_params = SamplingParams()
+    sampling_params.init(tokenizer=None, **sample_params_dict)
+    sampling_params.verify()
+    multimodal_params_dict = request_dict.get("multimodal_params", {})
+    multimodal_params = MultimodalParams(**multimodal_params_dict)
+
+    await httpserver_manager.get_image_embeding(sampling_params, multimodal_params, request=request)
+
+    return JSONResponse({"message": "OK"}, status_code=200)
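The payload comment inside lightllm_get_image_embedding implies a client call like the sketch below. The host, port, and image file are placeholders; the endpoint path comes from the api_http.py hunk above.

import base64
import requests

with open("example.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode()

resp = requests.post(
    "http://localhost:8000/get_image_embedding",
    json={
        "parameters": {"max_new_tokens": 128},
        "multimodal_params": {"images": [{"type": "base64", "data": img_b64}]},
    },
)
print(resp.status_code, resp.json())  # 200 {"message": "OK"} on success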

lightllm/server/api_server.py

Lines changed: 3 additions & 1 deletion
@@ -5,11 +5,13 @@
     torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess
     parser = make_argument_parser()
     args = parser.parse_args()
-    from .api_start import pd_master_start, normal_or_p_d_start, config_server_start
+    from .api_start import pd_master_start, normal_or_p_d_start, visual_start, config_server_start

     if args.run_mode == "pd_master":
         pd_master_start(args)
     elif args.run_mode == "config_server":
         config_server_start(args)
+    elif args.run_mode == "visual":
+        visual_start(args)
     else:
         normal_or_p_d_start(args)
