1+ import rpyc
2+ import socket
13import torch
24import torch .distributed as dist
35
46from lightllm .models .llama .layer_weights .pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
57from lightllm .models .llama .infer_struct import LlamaInferStateInfo
68from lightllm .models .llama .layer_infer .pre_layer_infer import LlamaPreLayerInfer
9+ from lightllm .server .embed_cache .utils import get_shm_name_embed , load_tensor_afs
710from lightllm .common .basemodel .triton_kernel .multimodal_emb import multimodal_emb
811from lightllm .distributed .communication_op import all_reduce
12+ from lightllm .utils .envs_utils import get_env_start_args
913
1014
1115"""
2630class LlamaMultimodalPreLayerInfer (LlamaPreLayerInfer ):
2731 def __init__ (self , network_config ):
2832 super ().__init__ (network_config )
33+ self .args = get_env_start_args ()
34+ self .cache_client = None
35+ if self .args .enable_remote_vit :
36+ self .cache_client = rpyc .connect ("localhost" , self .args .cache_port , config = {"allow_pickle" : True })
37+ self .cache_client ._channel .stream .sock .setsockopt (socket .IPPROTO_TCP , socket .TCP_NODELAY , 1 )
38+ return
39+
40+ def _copy_loaded_embed_to_cache (
41+ self , embed_tensor : torch .Tensor , cpu_embed_cache_tensor : torch .Tensor , start_index : int
42+ ):
43+ if embed_tensor .ndim == 2 :
44+ embed_tensor = embed_tensor .unsqueeze (1 )
45+
46+ token_num , layer_num , hidden_size = embed_tensor .shape
47+ cpu_embed_cache_tensor [start_index : start_index + token_num , :layer_num , :hidden_size ].copy_ (embed_tensor )
2948 return
3049
3150 def context_forward (self , input_ids , infer_state : LlamaInferStateInfo , layer_weight : LlamaPreAndPostLayerWeight ):
3251 img_start_token_ids = []
3352 img_token_lens = []
3453 img_start_locs_in_cache = []
54+ unique_uids = []
3555 device = layer_weight .wte_weight_ .weight .device
3656 dtype = layer_weight .wte_weight_ .weight .dtype
3757 hidden_size = layer_weight .wte_weight_ .weight .shape [1 ]
3858
39- for batch_id , p in enumerate (infer_state .multimodal_params ):
59+ for _ , p in enumerate (infer_state .multimodal_params ):
4060 for img in p ["images" ] + p ["audios" ]:
4161 # skip the same image
4262 if img ["token_id" ] in img_start_token_ids :
4363 continue
4464 img_start_token_ids .append (img ["token_id" ])
4565 img_token_lens .append (img ["token_num" ])
4666 img_start_locs_in_cache .append (img ["start_index_in_embed_cache" ])
67+ unique_uids .append (img ["uuid" ])
4768 out = torch .zeros ((len (input_ids ), hidden_size ), dtype = dtype , device = device )
4869
4970 from lightllm .server .router .model_infer .infer_batch import g_infer_context
@@ -55,6 +76,19 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
5576 else cpu_embed_cache_client .cpu_embed_cache_tensor
5677 )
5778
79+ if self .args .enable_remote_vit :
80+ release_ids = []
81+ for _ , p in enumerate (infer_state .multimodal_params ):
82+ for img in p ["images" ] + p ["audios" ]:
83+ release_ids .append (img ["uuid" ])
84+
85+ for uid , start_index_in_embed_cache in zip (unique_uids , img_start_locs_in_cache ):
86+ embed_tensor = load_tensor_afs (get_shm_name_embed (uid ), self .args .image_embed_dir )
87+ self ._copy_loaded_embed_to_cache (embed_tensor , cpu_embed_cache_tensor , start_index_in_embed_cache )
88+
89+ if release_ids :
90+ self .cache_client .root .release (release_ids )
91+
5892 assert cpu_embed_cache_tensor .shape [2 ] == hidden_size , (
5993 f"Dimension mismatch: text weight dimension is { hidden_size } , "
6094 f"but image embed dimension is { cpu_embed_cache_tensor .shape [2 ]} "
0 commit comments