diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py index b43f8f95af..b21464a02b 100644 --- a/lightllm/models/vit/triton_kernel/flashattention_nopad.py +++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py @@ -204,8 +204,7 @@ def flash_attention_v3_fwd( sm_margin=0, sinks=None, ) - - return + return o except ImportError: print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.") diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index af39e5da7b..7f4bf2513e 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -7,7 +7,16 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--run_mode", type=str, - choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"], + choices=[ + "normal", + "prefill", + "decode", + "nixl_prefill", + "nixl_decode", + "pd_master", + "config_server", + "visual_only", + ], default="normal", help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode, config_server is for pd split mode used to register pd_master node, and get pd_master node list, @@ -61,6 +70,14 @@ def make_argument_parser() -> argparse.ArgumentParser: default=None, help="The port number for the config server in config_server mode.", ) + parser.add_argument( + "--config_server_visual_redis_port", + type=int, + default=None, + help="""when run_mode is config_server, set this params will start a redis server, + when a llm infer node start to set this params, the visual infer module will start a + proxy module use config server to find remote vit infer nodes to infer img""", + ) parser.add_argument( "--nixl_pd_kv_page_num", type=int, @@ -458,6 +475,24 @@ def make_argument_parser() -> argparse.ArgumentParser: default=None, help="List of NCCL ports to build a distributed environment for Vit, e.g., 29500 29501 29502", ) 
+ parser.add_argument( + "--visual_rpyc_port", + type=int, + default=None, + help=""" + when run_mode is visual_only, set this port, make others to call local visual infer to + transfer image to embed. + """, + ) + parser.add_argument( + "--visual_use_proxy_mode", + action="store_true", + help=""" + when run_mode is normal, set this params, + will call remote visual infer to transfer image to embed, + need set --config_server_host, --config_server_port, + --config_server_visual_redis_port""", + ) parser.add_argument( "--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway" ) @@ -629,6 +664,22 @@ def make_argument_parser() -> argparse.ArgumentParser: default=0.03, help="""The interval of the schedule time, default is 30ms.""", ) + parser.add_argument( + "--afs_image_embed_dir", + type=str, + default=None, + help="path for vit embed, when use vit remote infer mode", + ) + parser.add_argument( + "--afs_embed_capacity", + type=int, + default=250000, + help=""" + capacity for vit embed in remote infer mode, + it control how many image can be cached in afs, + when the cache is full, the least recently used + image embed will be removed""", + ) parser.add_argument( "--enable_cpu_cache", action="store_true", diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index b4447d808a..6e04d5d47e 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -5,11 +5,13 @@ torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess parser = make_argument_parser() args = parser.parse_args() - from .api_start import pd_master_start, normal_or_p_d_start, config_server_start + from .api_start import pd_master_start, normal_or_p_d_start, visual_only_start, config_server_start if args.run_mode == "pd_master": pd_master_start(args) elif args.run_mode == "config_server": config_server_start(args) + elif args.run_mode in ["visual_only"]: + 
visual_only_start(args) else: normal_or_p_d_start(args) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 364f9ca281..c2388830da 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -15,6 +15,7 @@ from .router.manager import start_router_process from lightllm.utils.process_check import is_process_active from lightllm.utils.multinode_utils import send_and_receive_node_ip +from lightllm.utils.redis_utils import start_redis_service from lightllm.utils.shm_size_check import check_recommended_shm_size from lightllm.utils.config_utils import has_audio_module, has_vision_module @@ -57,7 +58,8 @@ def signal_handler(sig, frame): signal.signal(signal.SIGINT, signal_handler) logger.info(f"start process pid {os.getpid()}") - logger.info(f"http server pid {http_server_process.pid}") + if http_server_process: + logger.info(f"http server pid {http_server_process.pid}") return @@ -73,7 +75,7 @@ def normal_or_p_d_start(args): enable_mps() - if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode"]: + if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "visual_only"]: return # 通过模型的参数判断是否是多模态模型,包含哪几种模态, 并设置是否启动相应得模块 @@ -174,6 +176,10 @@ def normal_or_p_d_start(args): assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + if args.afs_image_embed_dir is not None: + os.makedirs(args.afs_image_embed_dir, mode=0o777, exist_ok=True) + os.chmod(args.afs_image_embed_dir, 0o777) + # 检查GPU数量是否足够 if args.visual_gpu_ids is None: args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) @@ -240,6 +246,8 @@ def normal_or_p_d_start(args): already_uesd_ports.append(args.nccl_port) if args.pd_decode_rpyc_port is not None: already_uesd_ports.append(args.pd_decode_rpyc_port) + if args.visual_nccl_ports is not None: + already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp]) # 提前锁定端口,防止在单个机器上启动多个实列的时候,要到模型启动的时候才能 # 捕获到端口设置冲突的问题 @@ -248,7 +256,8 @@ def 
normal_or_p_d_start(args): node_world_size = args.tp // args.nnodes can_use_ports = alloc_can_use_network_port( - num=10 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_ports=already_uesd_ports + num=10 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp, + used_ports=already_uesd_ports, ) logger.info(f"alloced ports: {can_use_ports}") ( @@ -265,14 +274,14 @@ def normal_or_p_d_start(args): ) = can_use_ports[0:10] can_use_ports = can_use_ports[10:] - visual_model_tp_ports = [] visual_nccl_ports = [] for _ in range(args.visual_dp): - tp_ports_for_dp = can_use_ports[0 : args.visual_tp] - visual_model_tp_ports.append(tp_ports_for_dp) - can_use_ports = can_use_ports[args.visual_tp :] - visual_nccl_ports.append(can_use_ports[0]) - can_use_ports = can_use_ports[1:] + if args.visual_nccl_ports is None: + visual_nccl_ports.append(can_use_ports[0]) + can_use_ports = can_use_ports[1:] + + if args.visual_nccl_ports is not None: + visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] # 将申请好的端口放入args参数中 if args.nccl_port is None: @@ -323,16 +332,29 @@ def normal_or_p_d_start(args): ) if not args.disable_vision: - from .visualserver.manager import start_visual_process - process_manager.start_submodule_processes( - start_funcs=[ - start_visual_process, - ], - start_args=[ - (args, visual_model_tp_ports), - ], - ) + if not args.visual_use_proxy_mode: + from .visualserver.manager import start_visual_process + + process_manager.start_submodule_processes( + start_funcs=[ + start_visual_process, + ], + start_args=[ + (args,), + ], + ) + else: + from .visualserver.proxy_manager import start_visual_process + + process_manager.start_submodule_processes( + start_funcs=[ + start_visual_process, + ], + start_args=[ + (args,), + ], + ) if not args.disable_audio: from .audioserver.manager import start_audio_process @@ -469,6 +491,62 @@ def pd_master_start(args): http_server_process.wait() +def visual_only_start(args): + from 
lightllm.server.core.objs.start_args_type import StartArgs + + args: StartArgs = args + if args.afs_image_embed_dir is not None: + os.makedirs(args.afs_image_embed_dir, mode=0o777, exist_ok=True) + os.chmod(args.afs_image_embed_dir, 0o777) + + already_uesd_ports = [] + already_uesd_ports.append(args.visual_rpyc_port) + can_use_ports = alloc_can_use_network_port( + num=5 + args.visual_dp * args.visual_tp + args.visual_dp, + used_ports=already_uesd_ports, + ) + + if args.visual_gpu_ids is None: + args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) + if args.visual_infer_batch_size is None: + args.visual_infer_batch_size = args.visual_dp + if args.data_type is None: + from lightllm.utils.config_utils import get_dtype + + args.data_type = get_dtype(args.model_dir) + assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] + + logger.info(f"alloced ports: {can_use_ports}") + + args.visual_nccl_ports = can_use_ports[: args.visual_dp] + can_use_ports = can_use_ports[args.visual_dp :] + args.visual_node_id = uuid.uuid4().int + + logger.info(f"all start args:{args}") + + set_env_start_args(args) + + from .visualserver.visual_only_manager import start_visual_process + + process_manager.start_submodule_processes( + start_funcs=[ + start_visual_process, + ], + start_args=[ + (args,), + ], + ) + setup_signal_handlers(None, process_manager) + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + logger.info("Received keyboard interrupt, shutting down...") + process_manager.terminate_all_processes() + logger.info("All processes have been terminated gracefully.") + sys.exit(0) + + def config_server_start(args): set_unique_server_name(args) if args.run_mode != "config_server": @@ -476,6 +554,9 @@ def config_server_start(args): logger.info(f"all start args:{args}") + if args.config_server_visual_redis_port is not None: + start_redis_service(args) + set_env_start_args(args) command = [ diff --git 
a/lightllm/server/config_server/api_http.py b/lightllm/server/config_server/api_http.py index c5505acda4..5c015f234c 100644 --- a/lightllm/server/config_server/api_http.py +++ b/lightllm/server/config_server/api_http.py @@ -9,6 +9,7 @@ from typing import Dict, List from fastapi.responses import JSONResponse from lightllm.utils.log_utils import init_logger +from lightllm.server.visualserver.objs import VIT_Obj from ..pd_io_struct import PD_Master_Obj from .nccl_tcp_store import start_tcp_store_server from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name @@ -19,7 +20,9 @@ app = FastAPI() registered_pd_master_objs: Dict[str, PD_Master_Obj] = {} +registered_visual_server_objs: Dict[str, VIT_Obj] = {} registered_pd_master_obj_lock = Lock() +registered_visual_server_obj_lock = Lock() global_req_id = 0 global_req_id_lock = Lock() @@ -72,6 +75,30 @@ async def websocket_endpoint(websocket: WebSocket): return +@app.websocket("/visual_register") +async def visual_websocket_endpoint(websocket: WebSocket): + await websocket.accept() + client_ip, client_port = websocket.client + logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}") + registered_visual_server_obj: VIT_Obj = pickle.loads(await websocket.receive_bytes()) + logger.info(f"recieved registered_visual_server_obj {registered_visual_server_obj}") + with registered_visual_server_obj_lock: + registered_visual_server_objs[registered_visual_server_obj.node_id] = registered_visual_server_obj + + try: + while True: + data = await websocket.receive_text() + assert data == "heartbeat" + except (WebSocketDisconnect, Exception, RuntimeError) as e: + logger.error(f"registered_visual_server_obj {registered_visual_server_obj} has error {str(e)}") + logger.exception(str(e)) + finally: + logger.error(f"registered_visual_server_obj {registered_visual_server_obj} removed") + with registered_visual_server_obj_lock: + registered_visual_server_objs.pop(registered_visual_server_obj.node_id, None) + 
return + + @app.get("/registered_objects") async def get_registered_objects(): with registered_pd_master_obj_lock: @@ -80,6 +107,14 @@ async def get_registered_objects(): return {"data": base64_encoded} +@app.get("/registered_visual_objects") +async def get_vit_registered_objects(): + with registered_visual_server_obj_lock: + serialized_data = pickle.dumps(registered_visual_server_objs) + base64_encoded = base64.b64encode(serialized_data).decode("utf-8") + return {"data": base64_encoded} + + @app.get("/allocate_global_unique_id_range") async def allocate_global_id_range(): """ diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index c7c3975f7d..4382899ba8 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -8,7 +8,9 @@ class StartArgs: run_mode: str = field( default="normal", - metadata={"choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode"]}, + metadata={ + "choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode", "visual_only"] + }, ) host: str = field(default="127.0.0.1") port: int = field(default=8000) @@ -20,6 +22,9 @@ class StartArgs: pd_master_port: int = field(default=1212) config_server_host: str = field(default=None) config_server_port: int = field(default=None) + config_server_visual_redis_port: int = field(default=None) + afs_image_embed_dir: str = field(default=None) + afs_embed_capacity: int = field(default=250000) pd_decode_rpyc_port: int = field(default=None) select_p_d_node_strategy: str = field(default=None) model_name: str = field(default="default_model_name") @@ -81,6 +86,7 @@ class StartArgs: enable_multimodal: bool = field(default=False) disable_vision: Optional[bool] = field(default=None) disable_audio: Optional[bool] = field(default=None) + visual_use_proxy_mode: bool = field(default=False) enable_tpsp_mix_mode: bool = field(default=False) enable_dp_prefill_balance: 
bool = field(default=False) enable_decode_microbatch_overlap: bool = field(default=False) @@ -99,12 +105,14 @@ class StartArgs: job_name: str = field(default="lightllm") grouping_key: List[str] = field(default_factory=list) push_interval: int = field(default=10) + visual_node_id: int = field(default=None) visual_infer_batch_size: int = field(default=None) visual_send_batch_size: int = field(default=1) visual_gpu_ids: List[int] = field(default_factory=lambda: [0]) visual_tp: int = field(default=1) visual_dp: int = field(default=1) visual_nccl_ports: List[int] = field(default=None) + visual_rpyc_port: Optional[int] = field(default=None) enable_monitor_auth: bool = field(default=False) disable_cudagraph: bool = field(default=False) enable_prefill_cudagraph: bool = field(default=False) diff --git a/lightllm/server/embed_cache/afs_utils.py b/lightllm/server/embed_cache/afs_utils.py new file mode 100644 index 0000000000..5dcbac8d61 --- /dev/null +++ b/lightllm/server/embed_cache/afs_utils.py @@ -0,0 +1,164 @@ +import os +import time +import torch +import uuid +import itertools +from typing import List, Tuple, Optional +from pathlib import Path +from .redis_utils import RedisMetadataLib +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class AfsUtils: + def __init__(self, base_dir: str, dir_depth: int = 2): + self.args = get_env_start_args() + self.base_dir = base_dir + # 判断 base_dir 是否存在,不存在则创建并赋予777权限,让其他人也可以写入 + if not os.path.exists(base_dir): + os.makedirs(base_dir, mode=0o777, exist_ok=True) + + # build sub dirs + parent_dir = Path(base_dir) + subdirs = ["".join(p) for p in itertools.product("0123456789abcdef", repeat=dir_depth)] + for sub in subdirs: + sub_dir_path = parent_dir / sub + os.makedirs(sub_dir_path, mode=0o777, exist_ok=True) + return + + def save_tensor_afs(self, name: str, tensor: torch.Tensor) -> bool: + try: + target_path = self._get_afs_path(name) + if 
target_path.exists(): + return True + tmp_path = self._get_afs_path(name=name, uuid_tail_str=str(uuid.uuid4())) + with open(tmp_path, "wb") as f: + tensor = tensor.detach().cpu() + dest = torch.empty_like(tensor) + dest.copy_(tensor) + torch.save(dest, f, _use_new_zipfile_serialization=False, pickle_protocol=4) + os.rename(tmp_path, target_path) + if self.args.detail_log: + logger.debug(f"save tensor to afs success, name: {name} target_path: {target_path}") + os.chmod(target_path, 0o777) + return True + except Exception as e: + logger.warning(f"failed to save embed tensor file: {target_path} tmp_path: {tmp_path} excetion {str(e)}") + return False + finally: + try: + tmp_path.unlink(missing_ok=True) + except: + pass + + def load_tensor_afs(self, name: str) -> Optional[torch.Tensor]: + try: + path = self._get_afs_path(name) + with open(path, "rb") as f: + return torch.load(f, weights_only=False) + except Exception as e: + logger.warning(f"fail to load afs file {name} error: {str(e)}") + return None + + def free_afs(self, name: str) -> bool: + try: + path = self._get_afs_path(name) + if not path.exists(): + return True + path.unlink(missing_ok=True) + return True + except Exception as e: + logger.warning(f"free_afs name: {name} error: {str(e)}") + return False + return + + def exist_afs(self, name: str) -> bool: + try: + path = self._get_afs_path(name) + return path.exists() + except Exception as e: + logger.warning(f"exist_afs name: {name} error: {str(e)}") + return False + + def _get_afs_path(self, name: str, uuid_tail_str: Optional[str] = None) -> Path: + if uuid_tail_str is None: + return Path(self.base_dir) / name[0:2] / name + else: + return Path(self.base_dir) / name[0:2] / f"{name}.{uuid_tail_str}" + + +class SepEmbedHandler: + def __init__( + self, + afs_embed_dir: str, + redis_host: str, + redis_port: int, + capacity: int = 250000, + evict_fraction: float = 0.3, + ) -> None: + if not (0.0 <= evict_fraction <= 1.0): + raise ValueError("evict_fraction must be 
0..1") + if capacity < 100: + raise ValueError("capacity must be >= 100") + + redis_url = f"redis://{redis_host}:{redis_port}/0" + self.redis_client = RedisMetadataLib(redis_url=redis_url) + self.capacity = capacity + self.remove_count = int(self.capacity * evict_fraction) # full的时候,每次清理的数量 + self.afs_embed_dir = afs_embed_dir + self.afs_utils = AfsUtils(self.afs_embed_dir) + self.args = get_env_start_args() + + def full_to_clean(self): + remove_objs: List[str] = self.redis_client.get_eviction_candidates( + remove_size=self.remove_count, capacity=self.capacity + ) + for obj in remove_objs: + try: + if self.afs_utils.free_afs(obj): + self.redis_client.remove([obj]) + if self.args.detail_log: + logger.debug(f"full_to_clean remove md5 {obj} from redis and afs success") + except BaseException as e: + logger.warning(f"full_to_clean md5 {obj} error {str(e)}") + + def insert(self, md5: str, tensor: torch.Tensor) -> bool: + self.full_to_clean() + try: + # 保证一定会有清理的可能性 + self.redis_client.update(md5) + ans = self.afs_utils.save_tensor_afs(md5, tensor) + self.redis_client.update(md5) + return ans + except: + return False + + def load(self, md5: str) -> Optional[torch.Tensor]: + try: + ans = self.afs_utils.load_tensor_afs(md5) + if ans is not None: + self.redis_client.update(md5) + return ans + else: + return None + except Exception as e: + logger.warning(f"load md5 {md5} error {str(e)}") + return None + + def check_ready(self, md5_list: List[str]) -> List[bool]: + try: + tmp1 = self.redis_client.check_and_update(md5_list) + assert len(tmp1) == len(md5_list) + start = time.time() + tmp2 = [exists and self.afs_utils.exist_afs(md5) for md5, exists in zip(md5_list, tmp1)] + cost_time = time.time() - start + if cost_time > 0.05: + logger.warning(f"slow afs check exist {cost_time} seconds, md5_list size: {len(md5_list)}") + assert len(tmp1) == len(tmp2) + ans = [a and b for a, b in zip(tmp1, tmp2)] + return ans + except Exception as e: + logger.warning(f"check_ready error 
{str(e)}") + return [False] * len(md5_list) diff --git a/lightllm/server/embed_cache/embed_cache_client.py b/lightllm/server/embed_cache/embed_cache_client.py index b72d8b2c5e..2d62cb73e5 100644 --- a/lightllm/server/embed_cache/embed_cache_client.py +++ b/lightllm/server/embed_cache/embed_cache_client.py @@ -15,7 +15,7 @@ class CpuEmbedCacheClient(object): This class is responsible for handling cpu kv cache meta data. """ - def __init__(self, create_meta_data: bool, init_shm_data: bool): + def __init__(self, create_meta_data: bool, init_shm_data: bool, pin_shm: bool = True): self.args = get_env_start_args() # to do here need calcu from from settings. self.embed_cache_tensor_meta = calcu_embed_cache_meta() @@ -37,7 +37,7 @@ def __init__(self, create_meta_data: bool, init_shm_data: bool): cache_tensor_creator = CpuCacheCreator(tensor_spec=cache_tensor_spec) self.cpu_embed_cache_tensor, _ = cache_tensor_creator.create_or_attach( init_shm_data=init_shm_data, - pin=not init_shm_data, + pin=pin_shm, pin_no_blocking=False, ) return @@ -69,8 +69,6 @@ def copy_vision_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache: ) return - return - if __name__ == "__main__": mem = MemoryManager(total_size=2000) diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index 5ad26fbcc8..73a0e0b250 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -48,7 +48,7 @@ def __init__(self, args) -> None: self.token_id_range_start = 0 self.token_id_range_end = 0 self.use_config_server = self.args.config_server_host and self.args.config_server_port - self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=True, init_shm_data=True) + self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=True, init_shm_data=True, pin_shm=False) def _check_and_set_new_id_range(self, alloced_token_num): need_update_range = 
self.token_id_range_start + alloced_token_num >= self.token_id_range_end @@ -128,18 +128,26 @@ def _free_to_alloc(self, free_min_count: int, new_md5_dict: Dict[str, int]) -> D def _add_ref(self, md5_sum): rec: Record = self._md5_to_record[md5_sum] - self._sorted_records.remove(rec) - rec.ref += 1 - self._sorted_records.add(rec) + self._update_record_ref(rec, 1) return def _del_ref(self, md5_sum): rec: Record = self._md5_to_record[md5_sum] + self._update_record_ref(rec, -1) + return + + def _update_record_ref(self, rec: Record, delta: int): self._sorted_records.remove(rec) - rec.ref -= 1 + rec.ref += delta + rec.visittime = time.time() self._sorted_records.add(rec) return + def _update_record_ref_by_id(self, id_: int, delta: int): + rec: Record = self._id_to_records[id_] + self._update_record_ref(rec, delta) + return + def _judge_enough_token_cache(self, md5sum_list: list[str], token_num_list: list[int]) -> bool: tmp_dict = {} for md5, token_num in zip(md5sum_list, token_num_list): @@ -167,6 +175,10 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[l alloc_md5_dict = self._free_to_alloc( free_min_count=new_needed - (self.capacity - self.occupied), new_md5_dict=new_md5_dict ) + for md5 in add_ref_m_list: + # 解锁 + self._del_ref(md5) + if len(alloc_md5_dict) == len(new_md5_dict): for md5sum, mem_block in alloc_md5_dict.items(): token_num = new_md5_dict[md5sum] @@ -190,10 +202,6 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[l self._sorted_records.add(rec) self.occupied += 1 - for md5 in add_ref_m_list: - # 解锁 - self._del_ref(md5) - # 遍历加 ref results = [] for md5 in md5sum_list: @@ -215,10 +223,7 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[l def release(self, ids: list[int]) -> None: with self.lock: for id_ in ids: - rec: Record = self._id_to_records[id_] - self._sorted_records.remove(rec) - rec.ref -= 1 - self._sorted_records.add(rec) + 
self._update_record_ref_by_id(id_, -1) def set_items_data(self, ids: list[int]) -> None: for id_ in ids: diff --git a/lightllm/server/embed_cache/redis_utils.py b/lightllm/server/embed_cache/redis_utils.py new file mode 100644 index 0000000000..7f238d6d8d --- /dev/null +++ b/lightllm/server/embed_cache/redis_utils.py @@ -0,0 +1,167 @@ +import redis +from typing import List, Tuple, Union, Optional + + +class RedisMetadataLib: + """ + # 代码任务 + 创建一个基于redis 管理的元数据操作库代码。 + 要求: + 2. 提供一个包装的 redis 操作client 库,提供以下功能: + (1) 提供一个时间排序队列,向队列中插入md5,并更新时间错(单位为s即可). + (2) 输入为(md5_list,), 向队列中插入所有的md5, 并更新其对应时间错。 + (3) 输入为(md5_list,), 将队列中的md5进行删除。 + (4) 输入为(md5_list,), 返回 md5_list 中所有md5 每个是否在链表中存在,返回一个bool list来标识,同时对所有存在的md5,更新时间错到最新。 + (5) 输入为(remove_size, capcity), 当时间排序队列中的元素数量大于等于capcity, 返回时间排序队列中排在前面的 remove_size 个元素,其内容为 md5。 + (6) 所有操作都使用lua 脚本,以实现原子化操作,同时返回的错误要能区分具体错误的原因,注意lua脚本的可读性,和相关函数的输入输出测试。时间错为server端s级别的参数。 + """ + + def __init__(self, redis_url: str = "redis://localhost:6379/0", prefix: str = "meta"): + # decode_responses=True 确保返回的是字符串而非字节 + self.r = redis.Redis.from_url(redis_url, decode_responses=True) + self.lru_key = f"{prefix}:queue:lru" + self._register_scripts() + + def _register_scripts(self): + """注册 Lua 脚本实现原子化操作""" + + # (1) & (2) 更新/插入:支持传入单个或多个 MD5 + # 逻辑:获取服务器时间,循环执行 ZADD + self._lua_update = self.r.register_script( + """ + local lru_key = KEYS[1] + local now = redis.call('TIME')[1] + local count = 0 + for i, md5 in ipairs(ARGV) do + redis.call('ZADD', lru_key, now, md5) + count = count + 1 + end + return count + """ + ) + + # (3) 删除:从队列中移除指定的 MD5 + self._lua_remove = self.r.register_script( + """ + local lru_key = KEYS[1] + local count = 0 + for i, md5 in ipairs(ARGV) do + count = count + redis.call('ZREM', lru_key, md5) + end + return count + """ + ) + + # (4) 检查并更新:判断是否存在,存在则刷新时间,返回 bool 状态列表 + self._lua_check_update = self.r.register_script( + """ + local lru_key = KEYS[1] + local now = redis.call('TIME')[1] + local results = {} + for i, md5 
in ipairs(ARGV) do + if redis.call('ZSCORE', lru_key, md5) then + redis.call('ZADD', lru_key, now, md5) + table.insert(results, 1) + else + table.insert(results, 0) + end + end + return results + """ + ) + + # (5) 容量清理:检查容量并获取候选列表 + self._lua_evict = self.r.register_script( + """ + local lru_key = KEYS[1] + local remove_size = tonumber(ARGV[1]) + local capacity = tonumber(ARGV[2]) + + local current_size = redis.call('ZCARD', lru_key) + if current_size >= capacity then + -- 按照分数(时间戳)从小到大排列,获取最旧的 N 个 + return redis.call('ZRANGE', lru_key, 0, remove_size - 1) + else + return {} + end + """ + ) + + def _to_list(self, data: Union[str, List[str]]) -> List[str]: + """内部工具:将输入统一转为列表形式""" + if isinstance(data, str): + return [data] + return data + + def update(self, md5_list: Union[str, List[str]]) -> int: + """ + 功能 (1) & (2):插入或更新 md5 的时间戳。 + 支持传入单个字符串或字符串列表。 + """ + items = self._to_list(md5_list) + if not items: + return 0 + return self._lua_update(keys=[self.lru_key], args=items) + + def remove(self, md5_list: Union[str, List[str]]) -> int: + """ + 功能 (3):将队列中的 md5 进行删除。 + 支持传入单个字符串或字符串列表。 + """ + items = self._to_list(md5_list) + if not items: + return 0 + return self._lua_remove(keys=[self.lru_key], args=items) + + def check_and_update(self, md5_list: List[str]) -> List[bool]: + """ + 功能 (4):返回 md5_list 中每个 md5 是否在队列中存在。 + 对存在的 md5 会同时更新时间戳到最新。 + """ + if not md5_list: + return [] + raw_res = self._lua_check_update(keys=[self.lru_key], args=md5_list) + return [res == 1 for res in raw_res] + + def get_eviction_candidates(self, remove_size: int, capacity: int) -> List[str]: + """ + 功能 (5):当队列数量 >= capacity 时,返回排在前面的 remove_size 个 md5。 + """ + return self._lua_evict(keys=[self.lru_key], args=[remove_size, capacity]) + + +# ---------------- 功能测试 ---------------- + + +def test_meta_lib(): + lib = RedisMetadataLib(prefix="test_service") + # 清理历史数据 + lib.r.delete(lib.lru_key) + + print("1. 
测试更新 (update)") + lib.update("file_0") # 单个 + lib.update(["file_1", "file_2", "file_3"]) # 批量 + print(f"当前队列大小: {lib.r.zcard(lib.lru_key)}") + + print("\n2. 测试检查并更新 (check_and_update)") + # file_1 存在,file_none 不存在,file_3 存在 + check_list = ["file_1", "file_none", "file_3"] + exists_results = lib.check_and_update(check_list) + for m, exists in zip(check_list, exists_results): + print(f"MD5: {m}, 存在状态: {exists}") + + print("\n3. 测试容量逐出 (get_eviction_candidates)") + # 当前有 4 个元素,设容量为 3,要求返回最旧的 2 个 + candidates = lib.get_eviction_candidates(remove_size=2, capacity=3) + print(f"容量达到3时,建议删除的最旧2个元素: {candidates}") + + print("\n4. 测试删除 (remove)") + removed_count = lib.remove(["file_0", "file_1"]) + print(f"成功移除数量: {removed_count}") + + final_check = lib.check_and_update(["file_1", "file_2"]) + print(f"最终检查 [file_1, file_2]: {final_check}") + + +if __name__ == "__main__": + test_meta_lib() diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index e28e4c93ad..8126d76446 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -174,6 +174,7 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, token_num = self.tokenizer.get_image_token_length(img) md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params))) md5sums.append(md5sum) + img.md5 = md5sum tokens_nums.append(token_num) datas.append(data) items.append(img) @@ -183,6 +184,7 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, token_num = self.tokenizer.get_audio_token_length(audio) md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(audio.extra_params))) md5sums.append(md5sum) + audio.md5 = md5sum tokens_nums.append(token_num) datas.append(data) items.append(audio) @@ -594,6 +596,9 @@ async def _wait_to_token_package( except asyncio.TimeoutError: pass + if req_status.aborted: + raise Exception(f"req_id {group_request_id} aborted notifyed by 
other module") + if not self.disable_abort and request is not None and await request.is_disconnected(): await self.abort(group_request_id) raise Exception(f"req_id {group_request_id} disconnected") @@ -718,11 +723,16 @@ async def recycle_resource_loop(self): for req_status in release_req_status: self.req_id_to_out_inf.pop(req_status.group_req_objs.group_req_id, None) + _is_aborted = False for req in req_status.group_req_objs.shm_req_objs: + _is_aborted = _is_aborted or req.is_aborted logger.debug(f"httpserver release req_id {req.request_id}, index {req.index_in_shm_mem}") await self.shm_req_manager.async_put_back_req_obj(req) await self.shm_req_manager.async_release_req_index(req.index_in_shm_mem) await self._release_multimodal_resources(req_status.group_req_objs.multimodal_params) + if _is_aborted: + req_status.aborted = True + logger.debug(f"mark req_id {req_status.group_req_objs.group_req_id} aborted in recycle loop") # 先保留这个关键得日志,用于方便定位重构中的问题。 if time.time() - pre_time_mark > 120: @@ -833,6 +843,7 @@ def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], sta time_mark=start_time, ) self.out_token_info_list = [] + self.aborted = False def can_release(self): for req in self.group_req_objs.shm_req_objs: diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 09a07455b3..05dd479411 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -24,6 +24,8 @@ def __init__(self, **kwargs): self.start_index_in_embed_cache = None # the audio token num self.token_num = None + # the data md5 sum + self.md5 = None # the audio length self.audio_length = None @@ -65,6 +67,7 @@ def to_dict(self): ret["token_id"] = self.token_id ret["token_num"] = self.token_num ret["start_index_in_embed_cache"] = self.start_index_in_embed_cache + ret["md5"] = self.md5 return ret def to_origin_dict(self): @@ -89,6 +92,8 @@ def __init__(self, **kwargs): self.start_index_in_embed_cache = None # the 
image token num self.token_num = None + # the data md5 sum + self.md5 = None # the start index of the image in the input_ids # used for mrope position id calculation self.start_idx = None @@ -141,6 +146,7 @@ def to_dict(self): ret["token_num"] = self.token_num ret["grid_thwd"] = self.grid_thwd ret["start_idx"] = self.start_idx + ret["md5"] = self.md5 return ret def to_origin_dict(self): diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 8fba9f08d7..a165be78f2 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -7,13 +7,15 @@ import pickle import inspect import setproctitle +import threading +import collections from typing import List from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes from lightllm.server.core.objs import ShmReqManager, StartArgs asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from lightllm.server.multimodal_params import MultimodalParams, ImageItem -from .model_infer.model_rpc import start_model_process, VisualModelRpcClient +from .model_infer import start_model_process, VisualModelRpcClient from lightllm.common.basemodel.attention_vit.create_utils import init_vit_att_backend from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import graceful_registry @@ -29,8 +31,8 @@ class VisualManager: def __init__( self, args: StartArgs, - visual_model_rpc_ports, ): + self.args = args context = zmq.Context(2) enable_audio = not args.disable_audio if enable_audio: @@ -48,47 +50,38 @@ def __init__( self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - self.cache_port = args.cache_port - self.waiting_reqs: List[GroupReqIndexes] = [] self.model_weightdir = args.model_dir - self.tp_world_size = args.tp 
self.vit_dp = args.visual_dp self.vit_tp = args.visual_tp + # image 最大推理 batch size self.infer_batch_size = args.visual_infer_batch_size - self.trust_remote_code = args.trust_remote_code - self.args = args - self.visual_model_rpc_ports = visual_model_rpc_ports self.send_batch_size = args.visual_send_batch_size self.shm_req_manager = ShmReqManager() + self.lock = asyncio.Lock() async def wait_to_model_ready(self): self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] self.vit_attn_backend = init_vit_att_backend(index=0) for dp_rank_id in range(self.vit_dp): - tp_ports_each_dp = self.visual_model_rpc_ports[dp_rank_id] for tp_rank_id in range(self.vit_tp): - device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id] - rpc_model = await start_model_process( - port=tp_ports_each_dp[tp_rank_id], vit_tp=self.vit_tp, device_id=device_id - ) + + rpc_model = await start_model_process() self.model_rpcs[dp_rank_id].append(rpc_model) init_model_ret = [] for dp_rank_id in range(self.vit_dp): # async init model process for tp_rank_id in range(self.vit_tp): + device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id] kvargs = { "weight_dir": self.model_weightdir, - "trust_remote_code": self.trust_remote_code, - "vit_dp": self.vit_dp, + "device_id": device_id, "vit_tp": self.vit_tp, - "cache_port": self.cache_port, + "cache_port": self.args.cache_port, "tp_rank_id": tp_rank_id, "dp_rank_id": dp_rank_id, - "vit_rank_id": dp_rank_id * self.vit_tp + tp_rank_id, "data_type": self.args.data_type, "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], - "visual_gpu_ids": self.args.visual_gpu_ids, "quant_type": self.args.vit_quant_type, "quant_cfg": self.args.vit_quant_cfg, "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), @@ -98,125 +91,114 @@ async def wait_to_model_ready(self): await asyncio.gather(*init_model_ret) return - async def infer_imgs(self, images: List[ImageItem]): - if len(images) == 0: + 
def get_need_infer_images(self, group_req_indexes: GroupReqIndexes) -> List[ImageItem]: + shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0]) + is_aborted = shm_req.is_aborted + disable_prompt_cache = shm_req.sample_params.disable_prompt_cache + self.shm_req_manager.put_back_req_obj(shm_req) + # case 0 + if is_aborted: + # 因为连接断开 aborted 掉的请求也需要传输到后续的模块进行处理 + # 因为采用 shm 来映射所有的 req 对象以后,引用管理情况复杂了 + # 需要一些一致的流程来保证不出现异步问题。 + return [] + + multimodal_params = group_req_indexes.multimodal_params + img_uuids = [img.uuid for img in multimodal_params.images] + # disable prompt cache通常用来测试,需要也去掉image cache的影响 + if disable_prompt_cache: + ready_image = [False] * len(img_uuids) + else: + if len(img_uuids) > 0: + ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids)) + else: + ready_image = [] + + images_need_infer = [] + for img, ready in zip(multimodal_params.images, ready_image): + if not ready: + images_need_infer.append(img) + + return images_need_infer + + async def handle_group_indexes(self, group_req_indexes: GroupReqIndexes): + images_need_infer = self.get_need_infer_images(group_req_indexes) + + if len(images_need_infer) == 0: + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + return + else: + await self.handle_images(images_need_infer) + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) return - tasks = [] - for vit_dp_rank in range(self.vit_dp): - assigned_images = [images[i] for i in range(vit_dp_rank, len(images), self.vit_dp)] - if assigned_images: - for vit_tp_rank in range(self.vit_tp): - task = asyncio.create_task(self.model_rpcs[vit_dp_rank][vit_tp_rank].encode(assigned_images)) - tasks.append(task) + async def handle_images(self, images_need_infer: List[ImageItem]): + if not hasattr(self, "cur_dp_index"): + self.cur_dp_index = 0 + + dp_to_handle_images = collections.defaultdict(list) + for image in images_need_infer: + 
self.cur_dp_index += 1 + select_dp = self.cur_dp_index % self.vit_dp + dp_to_handle_images[select_dp].append((image, threading.Event())) + + taskes = [] + for dp_index in range(self.vit_dp): + _images = dp_to_handle_images[dp_index] + if _images: + taskes.append( + self.infer_images(dp_index, images=[e[0] for e in _images], events=[e[1] for e in _images]) + ) - await asyncio.gather(*tasks) + async with self.lock: + try: + await asyncio.gather(*taskes) + except BaseException as e: + logger.exception(str(e)) + raise e + + # 等待推理通知已经 ok + for dp_index in range(self.vit_dp): + _images = dp_to_handle_images[dp_index] + if _images: + await asyncio.to_thread(_images[-1][1].wait) return - async def loop_for_fwd(self): - while True: - if len(self.waiting_reqs) == 0: - await asyncio.sleep(0.01) # 10ms - else: - processing_group_reqs = [] - images_need_infer = [] - ready_to_send = [] - - def flush_ready(force: bool = False): - if not ready_to_send: - return - if not force and len(ready_to_send) < self.send_batch_size: - return - - for group_req_indexes in ready_to_send: - self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) - ready_to_send.clear() - - while len(self.waiting_reqs) > 0: - group_req_indexes = self.waiting_reqs.pop(0) - shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0]) - is_aborted = shm_req.is_aborted - disable_prompt_cache = shm_req.sample_params.disable_prompt_cache - self.shm_req_manager.put_back_req_obj(shm_req) - if is_aborted: - # 因为连接断开 aborted 掉的请求也需要传输到后续的模块进行处理 - # 因为采用 shm 来映射所有的 req 对象以后,引用管理情况复杂了 - # 需要一些一致的流程来保证不出现异步问题。 - self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) - continue - - multimodal_params = group_req_indexes.multimodal_params - - img_uuids = [img.uuid for img in multimodal_params.images] - # disable prompt cache通常用来测试,需要也去掉image cache的影响 - if disable_prompt_cache: - ready_image = [False] * len(img_uuids) - else: - 
ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids)) - - for img, ready in zip(multimodal_params.images, ready_image): - if not ready: - images_need_infer.append(img) - - if len(images_need_infer) == self.infer_batch_size: - await self.infer_imgs(images_need_infer) - images_need_infer = [] - ready_to_send.extend(processing_group_reqs) - processing_group_reqs = [] - flush_ready(force=False) - - if len(images_need_infer) == 0: - ready_to_send.append(group_req_indexes) - flush_ready(force=False) - else: - processing_group_reqs.append(group_req_indexes) - - if len(images_need_infer) > 0: - await self.infer_imgs(images_need_infer) - images_need_infer = [] - - # 这些处理完 image 的 group 也 ready 了 - ready_to_send.extend(processing_group_reqs) - processing_group_reqs = [] - flush_ready(force=True) + async def infer_images(self, dp_index: int, images, events): + taskes = [] + for vit_tp_rank in range(self.vit_tp): + task = self.model_rpcs[dp_index][vit_tp_rank].run_task(images, events) + taskes.append(task) + await asyncio.gather(*taskes) async def loop_for_netio_req(self): - if not hasattr(self, "visual_recv_max_count"): - self.visual_recv_max_count = 64 - - while True: - try: - for _ in range(self.visual_recv_max_count): - recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - if isinstance(recv_req, GroupReqIndexes): - logger.info( - f"visual recv req id {recv_req.group_req_id} " - f"img count {len(recv_req.multimodal_params.images)}" - ) - self.waiting_reqs.append(recv_req) - else: - assert False, f"Error Req Inf {recv_req}" - self.visual_recv_max_count = int(min(self.visual_recv_max_count * 1.3, 256)) - except zmq.ZMQError: - # 当队列已经开始清空的时候,将一次接受数量下调 - self.visual_recv_max_count = 64 - await asyncio.sleep(0.01) + try: + while True: + recv_req: GroupReqIndexes = await asyncio.to_thread(self.zmq_recv_socket.recv_pyobj) + if isinstance(recv_req, GroupReqIndexes): + logger.info( + f"visual recv req id {recv_req.group_req_id} " + f"img 
count {len(recv_req.multimodal_params.images)}" + ) + asyncio.create_task(self.handle_group_indexes(group_req_indexes=recv_req)) + else: + assert False, f"Error Req Inf {recv_req}" + except Exception as e: + logger.exception(str(e)) def clean_up(self): - for model_rpc in self.model_rpcs: - model_rpc.rpc_server_process.kill() - for model_rpc in self.model_rpcs: - model_rpc.rpc_server_process.join() return -def start_visual_process(args, model_rpc_ports, pipe_writer): +def start_visual_process(args, pipe_writer): + import lightllm.utils.rpyc_fix_utils as _ + # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") start_parent_check_thread() try: - visualserver = VisualManager(args=args, visual_model_rpc_ports=model_rpc_ports) + visualserver = VisualManager(args=args) asyncio.run(visualserver.wait_to_model_ready()) except Exception as e: logger.exception(str(e)) @@ -231,6 +213,5 @@ def handle_exception(loop, context): loop = asyncio.new_event_loop() loop.set_exception_handler(handle_exception) asyncio.set_event_loop(loop) - loop.create_task(visualserver.loop_for_fwd()) loop.run_until_complete(visualserver.loop_for_netio_req()) return diff --git a/lightllm/server/visualserver/model_infer/__init__.py b/lightllm/server/visualserver/model_infer/__init__.py index e69de29bb2..ae3c4204db 100644 --- a/lightllm/server/visualserver/model_infer/__init__.py +++ b/lightllm/server/visualserver/model_infer/__init__.py @@ -0,0 +1,61 @@ +import asyncio +import rpyc +import inspect +import uuid +import os +import multiprocessing +from lightllm.utils.retry_utils import retry +from rpyc.utils.factory import unix_connect +from rpyc.utils.classic import obtain +from rpyc.utils.server import ThreadedServer +from lightllm.utils.graceful_utils import graceful_registry +from lightllm.utils.envs_utils import get_env_start_args +from .model_rpc_client import VisualModelRpcClient +from .model_rpc 
import VisualModelRpcServer +from ..objs import rpyc_config + + +def _init_env(socket_path: str, success_event): + # 注册graceful 退出的处理 + graceful_registry(inspect.currentframe().f_code.co_name) + + import lightllm.utils.rpyc_fix_utils as _ + + t = ThreadedServer(VisualModelRpcServer(), socket_path=socket_path, protocol_config=rpyc_config) + success_event.set() + t.start() + return + + +async def start_model_process(): + import lightllm.utils.rpyc_fix_utils as _ + + socket_path = _generate_unix_socket_path() + if os.path.exists(socket_path): + os.remove(socket_path) + + success_event = multiprocessing.Event() + proc = multiprocessing.Process( + target=_init_env, + args=( + socket_path, + success_event, + ), + ) + proc.start() + await asyncio.to_thread(success_event.wait, timeout=40) + assert proc.is_alive() + + conn = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config=rpyc_config) + assert proc.is_alive() + + # 服务端需要调用客户端传入的event所以,客户端需要一个后台线程进行相关的处理。 + conn._bg_thread = rpyc.BgServingThread(conn, sleep_interval=0.001) + + return VisualModelRpcClient(conn) + + +def _generate_unix_socket_path() -> str: + """Generate a random Unix socket path""" + unique_id = uuid.uuid4().hex[:8] + return f"/tmp/lightllm_model_infer_{unique_id}.sock" diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 741707bf93..55f4704a31 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -1,14 +1,14 @@ -import asyncio -import numpy as np import rpyc import torch import socket -import inspect -from datetime import timedelta -from typing import Dict, List, Tuple +import torch.multiprocessing as mp +import queue +import threading +import time +import torch.distributed as dist +from typing import Dict, List, Tuple, Deque, Optional from transformers.configuration_utils import PretrainedConfig from rpyc.utils.classic import obtain 
-from rpyc.utils.server import ThreadedServer from lightllm.models.qwen_vl.qwen_visual import QWenVisionTransformer from lightllm.models.llava.llava_visual import LlavaVisionModel from lightllm.models.internvl.internvl_visual import InternVLVisionModel @@ -22,27 +22,43 @@ from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_visual import Qwen3OmniMoeVisionTransformerPretrainedModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_vision_distributed_env -from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient from lightllm.server.visualserver import set_vit_att_backend +from lightllm.server.embed_cache.afs_utils import SepEmbedHandler +from lightllm.utils.log_utils import init_logger + + +logger = init_logger(__name__) class VisualModelRpcServer(rpyc.Service): def exposed_init_model(self, kvargs): kvargs = obtain(kvargs) - import torch - import torch.distributed as dist - self.vit_dp = kvargs["vit_dp"] + # kvargs = { + # "weight_dir": self.model_weightdir, + # "device_id": device_id, + # "vit_tp": self.vit_tp, + # "cache_port": self.args.cache_port, + # "tp_rank_id": tp_rank_id, + # "dp_rank_id": dp_rank_id, + # "data_type": self.args.data_type, + # "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], + # "quant_type": self.args.vit_quant_type, + # "quant_cfg": self.args.vit_quant_cfg, + # "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), + # "vit_attn_backend": self.vit_attn_backend, + # } + + weight_dir = kvargs["weight_dir"] + self.infer_max_batch_size = kvargs["max_batch_size"] + self.device_id = kvargs["device_id"] self.vit_tp = kvargs["vit_tp"] self.dp_rank_id = kvargs["dp_rank_id"] self.tp_rank_id = kvargs["tp_rank_id"] self.cache_port = kvargs["cache_port"] - weight_dir = kvargs["weight_dir"] - self.vit_rank_id = kvargs["vit_rank_id"] - 
self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) - self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + self.is_visual_only_mode = get_env_start_args().run_mode == "visual_only" self.data_type = kvargs["data_type"] self.vit_attn_backend = kvargs["vit_attn_backend"] set_vit_att_backend(self.vit_attn_backend) @@ -95,7 +111,24 @@ def exposed_init_model(self, kvargs): self.model.load_model(weight_dir) self.model = self.model.cuda() - self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) + if not self.is_visual_only_mode: + self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) + self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) + else: + # 独立部署vit模式下,不需要连接 cache_client, 结果是写入 afs + args = get_env_start_args() + self.args = args + assert args.visual_dp == 1 + if self.tp_rank_id == 0: + self.afs_handler = SepEmbedHandler( + afs_embed_dir=self.args.afs_image_embed_dir, + redis_host=self.args.config_server_host, + redis_port=self.args.config_server_visual_redis_port, + capacity=self.args.afs_embed_capacity, + ) + + self._init_taskes() except Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e)) @@ -107,112 +140,194 @@ def exposed_init_model(self, kvargs): set_random_seed(2147483647) return + def exposed_run_task(self, images: List["ImageItem"], ref_event_list: List[threading.Event]): + try: + images = obtain(images) + for i in range(len(images)): + images[i].event = ref_event_list[i] + images[i].start_time = time.time() + self.infer_queue.put(images[i]) + + except BaseException as e: + logger.exception(str(e)) + raise e + return + + def _log_latency(self, image: ImageItem, stage: str): + latency = time.time() - image.start_time + if latency > 
0.02: + logger.info(f"{stage} latency {latency:.4f} seconds for image with md5 {image.md5}") + image.start_time = time.time() + + def _init_taskes(self): + self.args = get_env_start_args() + + # 异步队列, 用于接受任务 + self.infer_queue = queue.Queue() + # 将计算得到的结果放入 afs 或者 embed cache 的 queue + self.store_queue = queue.Queue() + + # 限制并发, 主要是为了控制内存用量,防止过多造成内存OOM + self.sempare = threading.Semaphore(self.infer_max_batch_size * 8) + + # 用于同步各个推理tp每次拿到一样的image数量建立的gloo通信组 + self.gloo_group = dist.new_group(ranks=list(range(self.vit_tp)), backend="gloo") + + # 启动任务处理线程 + self._infer_thread = threading.Thread(target=self._infer_worker, daemon=True) + self._infer_thread.start() + + self._store_thread = threading.Thread(target=self._store_worker, daemon=True) + self._store_thread.start() + return + # @calculate_time(show=True, min_cost_ms=150) @torch.no_grad() - def forward(self, images: List[ImageItem]): + def _forward(self, images: List[ImageItem]): return self.model.encode(images) - # @calculate_time(show=False, min_cost_ms=300) - def exposed_encode(self, images: List[ImageItem]): - images = obtain(images) - all_img_embeds, uuids, valid_ids = self.forward(images) - all_img_embeds = all_img_embeds.to(torch.device("cuda")) + def _get_image_items_from_infer_queue(self, max_num: int, force_same: bool = False) -> List[ImageItem]: + """ + 从队列中批量获取任务,直到达到 max_num 或队列为空。 + """ + tasks = [] + # 至少获取一个任务,阻塞 + self.sempare.acquire() + task = self.infer_queue.get(block=True) + tasks.append(task) - if self.tp_rank_id == 0: - ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) - ids_to_set = [] - for i, ready in enumerate(ready_flags): - if ready: - continue - uid = uuids[i] - start, end = valid_ids[i] - image = images[i] - self.cpu_embed_cache_client.copy_vision_to_cache( - embed_tensor=all_img_embeds[start:end], start_index_in_cache=image.start_index_in_embed_cache - ) - ids_to_set.append(uid) - if ids_to_set: - self.cache_client.root.set_items_embed(ids_to_set) - 
torch.cuda.current_stream().synchronize() - return + if not force_same: + # 尝试继续获取更多任务,直到达到 max_num + while len(tasks) < max_num: + try: + self.sempare.acquire() + task = self.infer_queue.get(block=False) + tasks.append(task) + except queue.Empty: + self.sempare.release() + break + else: + while len(tasks) < max_num: + self.sempare.acquire() + task = self.infer_queue.get(block=True) + tasks.append(task) + return tasks -class VisualModelRpcClient: - def __init__(self, model_rpc, vit_tp, rpc_server_process=None): - self.model: VisualModelRpcServer = model_rpc - self.vit_tp = vit_tp - self.rpc_server_process = rpc_server_process - self.use_rpc = True - if self.use_rpc: + def _get_image_items_from_store_queue(self, max_num: int) -> List[ImageItem]: + """ + 从队列中批量获取任务,直到达到 max_num 或队列为空。 + """ + tasks = [] + # 至少获取一个任务,阻塞 + task = self.store_queue.get(block=True) + tasks.append(task) - def async_wrap(f): - f = rpyc.async_(f) + while len(tasks) < max_num: + try: + task = self.store_queue.get(block=False) + tasks.append(task) + except queue.Empty: + break - async def _func(*args, **kwargs): - ans = f(*args, **kwargs) - await asyncio.to_thread(ans.wait) - # raise if exception - return ans.value + return tasks - return _func + def _infer_worker(self): + """ + 任务处理循环: 从队列中取出任务, 执行完成后通知调用者 + """ + torch.cuda.set_device(self.device_id) + while True: + try: + # 从队列获取任务, 阻塞等待 + if self.tp_rank_id == 0: + images = self._get_image_items_from_infer_queue(max_num=self.infer_max_batch_size) + dist.broadcast_object_list([len(images)], src=0, group=self.gloo_group) + else: + ans = [None] + dist.broadcast_object_list(ans, src=0, group=self.gloo_group) + images = self._get_image_items_from_infer_queue(max_num=ans[0], force_same=True) - self._init_model = async_wrap(self.model.init_model) - self._encode = async_wrap(self.model.encode) - else: - self._init_model = self.model.exposed_init_model - self._encode = self.model.exposed_encode - return + for image in images: + 
self._log_latency(image, stage="queue_cost_time") - async def init_model(self, kvargs): - ans: rpyc.AsyncResult = self._init_model(kvargs) - if self.use_rpc: - await ans - return - else: - return + # 执行任务: 调用父类的forward方法处理图像 + all_img_embeds, uuids, valid_ids = self._forward(images) + all_img_embeds = all_img_embeds.to(torch.device("cuda")) - async def encode(self, images: List[ImageItem]): - ans = self._encode(images) - if self.use_rpc: - return await ans - else: - return ans + if self.is_visual_only_mode: + self._store_to_afs(all_img_embeds, valid_ids, images) + else: + self._store_to_cpu_cache(all_img_embeds, valid_ids, images) + + except Exception as e: + logger.exception(str(e)) + raise e + def _store_to_cpu_cache(self, all_img_embeds, valid_ids, images): + for i in range(len(images)): + start, end = valid_ids[i] + image = images[i] + if self.tp_rank_id == 0: + self.cpu_embed_cache_client.copy_vision_to_cache( + embed_tensor=all_img_embeds[start:end], start_index_in_cache=image.start_index_in_embed_cache + ) + cuda_event = torch.cuda.Event() + cuda_event.record() + image.cuda_event = cuda_event + self.store_queue.put(image) -def _init_env(port, device_id): - # 注册graceful 退出的处理 - graceful_registry(inspect.currentframe().f_code.co_name) + def _store_to_afs(self, all_img_embeds, valid_ids, images): + all_img_embeds = all_img_embeds.detach().cpu() + for image, valid_id in zip(images, valid_ids): + self._log_latency(image, stage="inference") + start, end = valid_id + gen_embed = all_img_embeds[start:end] + image.gen_embed = gen_embed + self.store_queue.put(image) - import lightllm.utils.rpyc_fix_utils as _ + def _store_worker(self): + """ + 任务处理循环: 从队列中取出ImageItem和embed 放入 afs中, 执行完成后通知调用者 + """ + while True: + try: + # 从队列获取任务, 阻塞等待 + images: List[ImageItem] = self._get_image_items_from_store_queue(max_num=self.infer_max_batch_size) - t = ThreadedServer(VisualModelRpcServer(), port=port, protocol_config={"allow_pickle": True}) - t.start() - return + if 
self.is_visual_only_mode: + self._commit_to_afs(images=images) + else: + self._commit_to_cpu_cache(images=images) + for _ in images: + self.sempare.release() -async def start_model_process(port, vit_tp, device_id): - import multiprocessing + except Exception as e: + logger.exception(str(e)) + raise e - proc = multiprocessing.Process( - target=_init_env, - args=( - port, - device_id, - ), - ) - proc.start() - await asyncio.sleep(2) - repeat_count = 0 - while repeat_count < 20: - try: - con = rpyc.connect("localhost", port, config={"allow_pickle": True}) - con._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - break - except BaseException: - await asyncio.sleep(1) - repeat_count += 1 - if repeat_count == 20: - raise Exception("init rpc env error!") - - assert proc.is_alive() - return VisualModelRpcClient(con.root, vit_tp, rpc_server_process=proc) + def _commit_to_afs(self, images): + if self.tp_rank_id == 0: + for image in images: + self.afs_handler.insert(image.md5, image.gen_embed) + self._log_latency(image, stage="store_to_afs") + image.event.set() + self._log_latency(image, stage="set_event") + + def _commit_to_cpu_cache(self, images): + if self.tp_rank_id == 0: + for image in images: + # 等待拷贝到cpu cache 完成。 + image.cuda_event.synchronize() + self._log_latency(image, stage="inference") + + uuids = [image.uuid for image in images] + self.cache_client.root.set_items_embed(uuids) + + for image in images: + self._log_latency(image, stage="set_items_embed") + + for image in images: + image.event.set() + self._log_latency(image, stage="set_event") diff --git a/lightllm/server/visualserver/model_infer/model_rpc_client.py b/lightllm/server/visualserver/model_infer/model_rpc_client.py new file mode 100644 index 0000000000..682d6affcc --- /dev/null +++ b/lightllm/server/visualserver/model_infer/model_rpc_client.py @@ -0,0 +1,36 @@ +import asyncio +import rpyc +import threading +from typing import Dict, List, Tuple, Deque, Optional, Union +from 
lightllm.server.multimodal_params import ImageItem +from .model_rpc import VisualModelRpcServer + + +class VisualModelRpcClient: + def __init__(self, rpc_conn): + self.rpc_conn: VisualModelRpcServer = rpc_conn + + def async_wrap(f): + f = rpyc.async_(f) + + async def _func(*args, **kwargs): + ans = f(*args, **kwargs) + await asyncio.to_thread(ans.wait) + # raise if exception + return ans.value + + return _func + + self._init_model = async_wrap(self.rpc_conn.root.init_model) + self._run_task = async_wrap(self.rpc_conn.root.run_task) + + return + + async def init_model(self, kvargs): + ans: rpyc.AsyncResult = self._init_model(kvargs) + await ans + return + + async def run_task(self, images: List[ImageItem], ref_event_list: List[threading.Event]): + ans = self._run_task(images, ref_event_list) + return await ans diff --git a/lightllm/server/visualserver/objs.py b/lightllm/server/visualserver/objs.py new file mode 100644 index 0000000000..656f3d3eae --- /dev/null +++ b/lightllm/server/visualserver/objs.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +rpyc_config = { + "allow_pickle": True, + "allow_all_attrs": True, + "allow_getattr": True, + "allow_setattr": True, +} + + +@dataclass +class VIT_Obj: + node_id: int + host_ip: str + port: int + + def to_log_str(self): + return f"VIT host_ip_port: {self.host_ip}:{self.port}, node_id: {self.node_id}" diff --git a/lightllm/server/visualserver/proxy_manager.py b/lightllm/server/visualserver/proxy_manager.py new file mode 100644 index 0000000000..2cf02d19e6 --- /dev/null +++ b/lightllm/server/visualserver/proxy_manager.py @@ -0,0 +1,232 @@ +import time +import asyncio +import uvloop +import rpyc +import socket +import pickle +import inspect +import setproctitle +import threading +import base64 +import httpx +import random +import copy +from typing import List, Dict, Optional +from lightllm.server.core.objs.io_objs.group_req import 
GroupReqIndexes +from lightllm.server.core.objs import ShmReqManager, StartArgs +from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +from lightllm.server.embed_cache.afs_utils import SepEmbedHandler +from lightllm.server.multimodal_params import MultimodalParams, ImageItem +from lightllm.utils.log_utils import init_logger +from lightllm.utils.graceful_utils import graceful_registry +from lightllm.utils.process_check import start_parent_check_thread +from lightllm.utils.envs_utils import get_unique_server_name +from rpyc.utils.classic import obtain +from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data +from .manager import VisualManager +from .objs import VIT_Obj + +logger = init_logger(__name__) + + +class ProxyVisualManager(VisualManager): + def __init__( + self, + args: StartArgs, + ): + super().__init__(args) + assert self.vit_dp == 1 and self.vit_tp == 1 + self.id_to_rpyc_conn: Dict[str, rpyc.Connection] = {} + self.conn_lock = threading.Lock() + + self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False, pin_shm=False) + + self.afs_handler = SepEmbedHandler( + afs_embed_dir=self.args.afs_image_embed_dir, + redis_host=self.args.config_server_host, + redis_port=self.args.config_server_visual_redis_port, + capacity=self.args.afs_embed_capacity, + ) + + async def handle_group_indexes(self, group_req_indexes: GroupReqIndexes): + images_need_infer = self.get_need_infer_images(group_req_indexes) + + # case 1 + if len(images_need_infer) == 0: + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + return + + try: + + def _get_not_afs_ready_images(): + readys = self.afs_handler.check_ready([image.md5 for image in images_need_infer]) + not_readys_images = [image for image, ready in zip(images_need_infer, readys) if not ready] + # 将 images_need_infer 按照 self.infer_batch_size 切分成多个 
batch,发送给不同的 visual server 进行推理,\ + # 最后等待所有推理完成后再发送给下一个模块 + images_batches = [ + not_readys_images[i : i + self.infer_batch_size] + for i in range(0, len(not_readys_images), self.infer_batch_size) + ] + return images_batches + + images_batches = await asyncio.to_thread(_get_not_afs_ready_images) + taskes = [] + + for images_batch in images_batches: + conn = self.select_vit_conn() + taskes.append(asyncio.to_thread(self.run_task, conn, images_batch)) + + if len(taskes) > 0: + + await asyncio.gather(*taskes) + + # 将需要处理的 image 从 afs 中写入到 cpu cache 中 + def _load_to_cpu_cache(): + for image in images_need_infer: + tensor = self.afs_handler.load(md5=image.md5) + if tensor is None: + raise Exception(f"Failed to load tensor from afs for image with md5 {image.md5}") + start = image.start_index_in_embed_cache + end = start + tensor.shape[0] + assert end - start == image.token_num + self.cpu_embed_cache_client.cpu_embed_cache_tensor[start:end].copy_(tensor) + self.cache_client.root.set_items_embed([image.uuid for image in images_need_infer]) + + await asyncio.to_thread(_load_to_cpu_cache) + + except Exception as e: + # mark aborted + for shm_req_index in group_req_indexes.shm_req_indexes: + shm_req = self.shm_req_manager.get_req_obj_by_index(shm_req_index) + shm_req.is_aborted = True + self.shm_req_manager.put_back_req_obj(shm_req) + + logger.exception(str(e)) + + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + return + + def select_vit_conn(self) -> Optional[rpyc.Connection]: + with self.conn_lock: + if not self.id_to_rpyc_conn: + return None + ids = list(self.id_to_rpyc_conn.keys()) + id = random.choice(ids) + return self.id_to_rpyc_conn[id] + + def run_task(self, conn: rpyc.Connection, images: List[ImageItem]): + event = threading.Event() + # 避免修改原始的 image 对象,主要是为了避免在后续的流程中出现问题,因为后续的流程可能会对 image 对象进行访问, + # 尤其是一些 cache 的逻辑,如果直接修改了原始的 image 对象,可能会导致一些不可预期的问题。 + images = copy.deepcopy(images) + # 将 bytes 从 shm 中读取出来,放到 
image.data_bytes 中,供远端的 vit 进行推理使用。 + for image in images: + image.data_bytes = read_shm(get_shm_name_data(image.uuid)) + if self.args.detail_log: + start = time.time() + logger.info(f"Start to remote infer images {[image.md5 for image in images]}") + conn.root.remote_infer_images(images, event) + event.wait(timeout=600) + if self.args.detail_log: + logger.info( + f"Remote infer images done for images {[image.md5 for image in images]}" + f" cost time {time.time() - start} s" + ) + return + + async def loop_to_connect_remote_visual_server(self): + counter = 0 + error_counter = 0 + while True: + uri = f"http://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get(uri) + if response.status_code == 200: + base64data = response.json()["data"] + id_to_vit_obj = pickle.loads(base64.b64decode(base64data)) + + counter += 1 + if counter % 6 == 0: + logger.info(f"Got visual server info from config server: {id_to_vit_obj}") + + for node_id in list(self.id_to_rpyc_conn.keys()): + if node_id not in id_to_vit_obj: + logger.info(f"Visual server {node_id} is removed, closing connection") + with self.conn_lock: + self.id_to_rpyc_conn.pop(node_id).close() + + for node_id, vit_obj in id_to_vit_obj.items(): + vit_obj: VIT_Obj = vit_obj + if node_id not in self.id_to_rpyc_conn: + + def _connect(): + from .objs import rpyc_config + + conn = rpyc.connect(vit_obj.host_ip, vit_obj.port, config=rpyc_config) + conn._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + conn._bg_thread = rpyc.BgServingThread(conn, sleep_interval=0.001) + logger.info( + f"Connected to visual server {node_id} at {vit_obj.host_ip}:{vit_obj.port}" + ) + return conn + + try: + new_conn = await asyncio.to_thread(_connect) + with self.conn_lock: + self.id_to_rpyc_conn[node_id] = new_conn + except Exception as e: + logger.exception(str(e)) + else: + logger.error(f"Failed 
to get VIT instances: {response.status_code}") + except Exception as e: + logger.error(f"Error occurred while connecting to config server: {e}") + error_counter += 1 + + if error_counter >= 6: + logger.error( + "Failed to connect to config server for a long time, remove all connections to visual servers" + ) + error_counter = 0 + try: + with self.conn_lock: + for node_id, conn in self.id_to_rpyc_conn.items(): + logger.info(f"Closing connection to visual server {node_id}") + conn.close() + self.id_to_rpyc_conn.clear() + except Exception as e: + logger.exception(str(e)) + + # 在没有连接的时候,高频率更新,有的时候降低更新频率 + if len(self.id_to_rpyc_conn) == 0: + await asyncio.sleep(10) + else: + await asyncio.sleep(30) + + +def start_visual_process(args, pipe_writer): + import lightllm.utils.rpyc_fix_utils as _ + + # 注册graceful 退出的处理 + graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") + start_parent_check_thread() + try: + visualserver = ProxyVisualManager(args=args) + except Exception as e: + logger.exception(str(e)) + raise e + + pipe_writer.send("init ok") + + def handle_exception(loop, context): + logger.exception(f"VisualServer Caught exception: {str(context)}") + + loop = asyncio.new_event_loop() + loop.set_exception_handler(handle_exception) + asyncio.set_event_loop(loop) + loop.create_task(visualserver.loop_to_connect_remote_visual_server()) + loop.run_until_complete(visualserver.loop_for_netio_req()) + return diff --git a/lightllm/server/visualserver/visual_only_manager.py b/lightllm/server/visualserver/visual_only_manager.py new file mode 100644 index 0000000000..27275c1e8c --- /dev/null +++ b/lightllm/server/visualserver/visual_only_manager.py @@ -0,0 +1,209 @@ +import asyncio +import uvloop +import rpyc +import inspect +import setproctitle +import threading +import uuid +import pickle +import websockets +import socket +import sys +import os +import signal +import time +from 
class VisualOnlyManager(rpyc.Service):
    """rpyc service that runs ViT inference on behalf of remote callers.

    The instance owns a private asyncio event loop that runs in a daemon
    thread. Model start-up, registration with the config server and image
    inference are all scheduled onto that loop; ``exposed_remote_infer_images``
    only submits work to it.
    """

    def __init__(
        self,
        args: StartArgs,
    ):
        """Store the start-up arguments and start the background event loop.

        Args:
            args: start-up arguments; ``visual_dp`` must be 1.
        """
        self.args = args
        self.model_weightdir = args.model_dir
        self.vit_dp = args.visual_dp
        assert self.vit_dp == 1
        self.vit_tp = args.visual_tp
        # Maximum number of images inferred in one batch.
        self.infer_batch_size = args.visual_infer_batch_size
        self.lock = asyncio.Lock()

        self.new_loop = asyncio.new_event_loop()

        def _event_loop():
            asyncio.set_event_loop(self.new_loop)
            self.new_loop.run_forever()

        # Daemon thread: it must not keep the process alive on shutdown.
        t = threading.Thread(target=_event_loop, daemon=True)
        t.start()

    async def register_to_config_server_loop(self, args: StartArgs):
        """Register this node with the config server and keep the link alive.

        Connects to ``ws://<config_server_host>:<config_server_port>/visual_register``,
        sends a pickled ``VIT_Obj`` describing this node, then sends a
        ``"heartbeat"`` message every 40 seconds. Any error is logged and the
        whole connect/register cycle is retried after 10 seconds, forever.

        If ``args.host`` is a loopback address the node could not be reached
        by remote callers, so the parent process is sent SIGTERM and
        ``sys.exit(-1)`` is called.
        """
        if args.host in ["127.0.0.1", "localhost"]:
            logger.error("remote visual server must specify host ip, can not be localhost or 127.0.0.1")
            # kill father process to trigger graceful exit, avoid orphan process
            os.kill(os.getppid(), signal.SIGTERM)
            sys.exit(-1)

        # A wildcard bind address is not routable; advertise the host's own ip.
        if args.host in ["0.0.0.0"]:
            host_ip = get_hostname_ip()
        else:
            host_ip = args.host

        while True:
            try:
                uri = f"ws://{args.config_server_host}:{args.config_server_port}/visual_register"
                async with websockets.connect(uri, max_queue=(2048 * 1024, 2048 * 1023)) as websocket:

                    sock = websocket.transport.get_extra_info("socket")
                    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                    vit_obj = VIT_Obj(node_id=args.visual_node_id, host_ip=host_ip, port=args.visual_rpyc_port)

                    await websocket.send(pickle.dumps(vit_obj))
                    logger.info(f"Sent registration vit_obj: {vit_obj}")

                    while True:
                        await websocket.send("heartbeat")
                        await asyncio.sleep(40)

            except Exception as e:
                logger.error("connetion to config_server has error")
                logger.exception(str(e))
                await asyncio.sleep(10)
                logger.info("reconnection to config_server")

    async def wait_to_model_ready(self):
        """Start one model process per (dp, tp) rank and initialise them all.

        The processes are started one after another; their ``init_model``
        calls are then awaited together with ``asyncio.gather``.
        """
        self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)]
        self.vit_attn_backend = init_vit_att_backend(index=0)
        for dp_rank_id in range(self.vit_dp):
            for tp_rank_id in range(self.vit_tp):

                rpc_model = await start_model_process()
                self.model_rpcs[dp_rank_id].append(rpc_model)

        init_model_ret = []
        for dp_rank_id in range(self.vit_dp):  # async init model process
            for tp_rank_id in range(self.vit_tp):
                device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id]
                kvargs = {
                    "weight_dir": self.model_weightdir,
                    "device_id": device_id,
                    "vit_tp": self.vit_tp,
                    "cache_port": None,  # the embed cache is not used in visual-only mode
                    "tp_rank_id": tp_rank_id,
                    "dp_rank_id": dp_rank_id,
                    "data_type": self.args.data_type,
                    "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id],
                    "quant_type": self.args.vit_quant_type,
                    "quant_cfg": self.args.vit_quant_cfg,
                    # NOTE(review): min(x, 1) can never exceed 1, so the configured
                    # infer batch size has no effect here (and 0 yields 0). This looks
                    # like it was meant to be max(x, 1) - confirm before changing,
                    # since a larger value may affect memory use in the model process.
                    "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1),
                    "vit_attn_backend": self.vit_attn_backend,
                }
                init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs))
        await asyncio.gather(*init_model_ret)
        return

    async def handle_images(self, images_need_infer: List[ImageItem]):
        """Run ``VisualManager.handle_images`` with this object as ``self``.

        NOTE(review): this class does not inherit from ``VisualManager``; the
        borrowed method only works while this instance provides every
        attribute that method reads - verify when either class changes.
        """
        await VisualManager.handle_images(self, images_need_infer=images_need_infer)

    async def infer_images(self, dp_index: int, images, events):
        """Run ``VisualManager.infer_images`` with this object as ``self``."""
        await VisualManager.infer_images(self, dp_index=dp_index, images=images, events=events)

    def clean_up(self):
        """Release resources held by the manager (currently nothing to do)."""
        return

    def exposed_remote_infer_images(self, images: List[ImageItem], ref_event: threading.Event):
        """Schedule inference for ``images`` and signal ``ref_event`` when done.

        The image bytes are copied into shared memory under freshly generated
        uuids, inference is scheduled on the background event loop, and the
        method returns without waiting for the result. When the scheduled work
        ends - successfully or not - ``ref_event`` is set and the shared
        memory created here is released.

        Args:
            images: image items from the caller; they are fetched by value
                with ``obtain`` before use.
            ref_event: event set once the inference future is done. It is set
                on failure as well, because the event is the only completion
                signal available; failures are reported in this node's log.

        Raises:
            BaseException: anything raised while preparing or scheduling the
                request, after the shared memory created so far was released.
        """
        # Only segments created by this call may be freed by this call.
        created_uuids: List[str] = []

        def _release_shm():
            for image_uuid in created_uuids:
                try:
                    free_shm(get_shm_name_data(image_uuid))
                except Exception as e:
                    # Keep going so one bad segment does not leak the others.
                    logger.exception(str(e))
            created_uuids.clear()

        try:
            images = obtain(images)
            md5s = [img.md5 for img in images]
            logger.info(f"Received infer_images request with {len(images)} images md5s: {md5s}")
            start = time.time()
            # Write the image contents into shm. The original uuid is replaced:
            # the remote ViT holds no embed-cache reference that would keep the
            # caller's id valid here, so a new unique id is used for inference.
            # The caller still gets the right data as long as the md5 of the
            # written target is the same.
            for image in images:
                image.uuid = str(uuid.uuid4())
                create_shm(get_shm_name_data(image.uuid), image.data_bytes)
                created_uuids.append(image.uuid)
                del image.data_bytes

            handle = asyncio.run_coroutine_threadsafe(self.handle_images(images_need_infer=images), loop=self.new_loop)

            def _callback(fut):
                try:
                    if fut.cancelled():
                        logger.error(f"infer_images for images {md5s} was cancelled")
                    elif fut.exception() is not None:
                        logger.error(
                            f"infer_images for images {md5s} failed: {fut.exception()!r}",
                            exc_info=fut.exception(),
                        )
                    elif time.time() - start > 0.05:
                        logger.info(
                            f"Finished infer_images for images {[image.md5 for image in images]}"
                            f" with latency {time.time() - start} seconds"
                        )
                finally:
                    try:
                        # Always wake the caller, otherwise it would wait forever.
                        ref_event.set()
                    finally:
                        # Release resources.
                        _release_shm()
                        logger.info(
                            f"Finished infer_images request for images {[image.md5 for image in images]}"
                            f" and cleaned up shm resources"
                        )

            handle.add_done_callback(_callback)
        except BaseException as e:
            logger.exception(str(e))
            # Release resources.
            _release_shm()
            raise
        return


def start_visual_process(args: StartArgs, pipe_writer):
    """Entry point of the visual-only server process.

    Creates the manager, waits until every model process is initialised,
    starts the config-server registration loop in the background, reports
    ``"init ok"`` through ``pipe_writer`` and finally serves rpyc requests on
    ``args.visual_rpyc_port`` (``ThreadedServer.start`` blocks).

    Raises:
        Exception: any initialisation error, after it was logged and the
            manager (if it was created) was cleaned up.
    """
    # Imported for its import-time effects only; the module object is unused.
    import lightllm.utils.rpyc_fix_utils as _

    # Register the graceful-exit handler.
    graceful_registry(inspect.currentframe().f_code.co_name)
    setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server")
    start_parent_check_thread()

    # Bound before the try block so the error path can tell whether the
    # constructor itself failed.
    visualserver = None
    try:
        visualserver = VisualOnlyManager(args=args)

        def handle_exception(loop, context):
            # This runs outside any `except` block, so the traceback has to be
            # taken from the context dict asyncio passes in.
            logger.error(
                f"VisualServer Caught exception: {str(context)}",
                exc_info=context.get("exception"),
            )

        visualserver.new_loop.set_exception_handler(handle_exception)

        future = asyncio.run_coroutine_threadsafe(visualserver.wait_to_model_ready(), loop=visualserver.new_loop)
        future.result()

        asyncio.run_coroutine_threadsafe(
            visualserver.register_to_config_server_loop(args=args), loop=visualserver.new_loop
        )

        from .objs import rpyc_config

        t = rpyc.ThreadedServer(visualserver, port=args.visual_rpyc_port, protocol_config=rpyc_config)
    except Exception as e:
        logger.exception(str(e))
        if visualserver is not None:
            visualserver.clean_up()
        raise

    pipe_writer.send("init ok")

    t.start()
    return
def start_redis_service(args):
    """Start a ``redis-server`` child process and wait until it answers PING.

    Before starting, any redis instance reachable at the same host and port
    is flushed and shut down on a best-effort basis. The new server runs in
    the foreground of the child process with RDB snapshots and AOF disabled.

    Args:
        args: object providing ``config_server_host`` (bind and probe address)
            and ``config_server_visual_redis_port``.

    Returns:
        The ``subprocess.Popen`` handle of the running server, or ``None`` if
        it could not be started or did not answer within 10 seconds. No child
        process is left running when ``None`` is returned.
    """
    import time

    config_server_host = args.config_server_host
    redis_port = args.config_server_visual_redis_port

    # Best effort only: there is usually nothing listening yet, and redis-cli
    # may be missing or time out. None of that should prevent the start below.
    try:
        for redis_cli_cmd in (("FLUSHALL", "ASYNC"), ("SHUTDOWN", "NOSAVE")):
            subprocess.run(
                ["redis-cli", "-h", config_server_host, "-p", str(redis_port), *redis_cli_cmd],
                check=False,
                timeout=2,
            )
    except Exception as e:
        logger.debug(f"Ignoring failure while stopping a previous Redis instance: {e}")

    redis_process = None
    try:
        # Imported before spawning so a missing client library cannot leave a
        # server process behind.
        import redis

        redis_command = [
            "redis-server",
            "--port",
            str(redis_port),
            "--bind",
            f"{config_server_host}",
            "--daemonize",
            "no",
            "--logfile",
            "/dev/stdout",
            "--loglevel",
            "notice",
            "--save",
            # Never trigger RDB snapshots. No shell is involved here, so the
            # empty value must be a real empty argument, not the two quote
            # characters one would type on a command line.
            "",
            "--appendonly",
            "no",  # disable AOF
        ]

        logger.info(f"Starting Redis service on port {redis_port}")
        redis_process = subprocess.Popen(redis_command)

        max_wait = 10
        start_time = time.time()

        while time.time() - start_time < max_wait:
            try:
                r = redis.Redis(host=config_server_host, port=redis_port, socket_connect_timeout=1)
                r.ping()
            except Exception as e:
                logger.error(f"Error occurred while checking Redis service: {e}")
                time.sleep(0.5)
                # The server exited on its own: retrying is pointless.
                if redis_process.poll() is not None:
                    logger.error("Redis service failed to start")
                    return None
            else:
                logger.info(f"Redis service started successfully on port {redis_port}")
                return redis_process

        logger.error("Redis service startup timeout")
        if redis_process.poll() is None:
            redis_process.terminate()
        return None

    except Exception as e:
        logger.error(f"Failed to start Redis service: {e}")
        # Do not orphan a server that was spawned before the failure.
        if redis_process is not None and redis_process.poll() is None:
            redis_process.terminate()
        return None
+_BUFF_SIZE = 8 * 1024 * 1024 def _fix_connect( diff --git a/requirements.txt b/requirements.txt index 5b0b201ae3..5331227586 100644 --- a/requirements.txt +++ b/requirements.txt @@ -94,3 +94,5 @@ partial_json_parser==0.2.1.1.post6 websockets==15.0.1 cupy-cuda12x==13.6.0 nixl==0.8.0 +xformers==0.0.33.post2 +redis==7.3.0 diff --git a/test/acc/test_vit_sep_mode.sh b/test/acc/test_vit_sep_mode.sh new file mode 100644 index 0000000000..1ab73794a1 --- /dev/null +++ b/test/acc/test_vit_sep_mode.sh @@ -0,0 +1,46 @@ +# 安装 redis-server +sudo apt-get update +sudo apt-get install redis-server + +# 启动 config_server +python -m lightllm.server.api_server \ +--run_mode config_server \ +--config_server_host 0.0.0.0 \ +--config_server_port 8090 \ +--config_server_visual_redis_port 6000 + + +# 启动 visual_only 模式的推理服务, --visual_rpyc_port 8091 visual_only 模式需要设置这个参数,提供给其他服务调用本地视觉推理接口 +# --config_server_host 应该是启动config_server 服务的ip, 这里因为测试是在同一台机器上,所以是0.0.0.0。 +CUDA_VISIBLE_DEVICES=0 python -m lightllm.server.api_server \ +--run_mode visual_only \ +--host 0.0.0.0 \ +--config_server_host 0.0.0.0 \ +--config_server_port 8090 \ +--config_server_visual_redis_port 6000 \ +--model_dir /mtc/models/Qwen3-VL-8B-Instruct \ +--visual_dp 1 \ +--visual_tp 1 \ +--afs_image_embed_dir /mtc/afs/vit_embed_dir \ +--afs_embed_capacity 250000 \ +--visual_rpyc_port 8091 + + +# 启动 llm 推理服务,normal 模式 +CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server \ +--run_mode normal \ +--model_dir /mtc/models/Qwen3-VL-8B-Instruct \ +--tp 2 \ +--port 8089 \ +--config_server_host 0.0.0.0 \ +--config_server_port 8090 \ +--config_server_visual_redis_port 6000 \ +--visual_dp 1 \ +--afs_image_embed_dir /mtc/afs/vit_embed_dir \ +--afs_embed_capacity 250000 \ +--visual_use_proxy_mode + + + +# todo test +1. 
def get_random_length(reqs_num: int, length: int, range_ratio: float) -> List[int]:
    """Draw ``reqs_num`` random integer lengths.

    Each value is sampled with ``np.random.randint`` from the inclusive range
    ``[max(int(length * range_ratio), 1), length]``.

    Args:
        reqs_num: number of lengths to draw.
        length: largest length that can be drawn.
        range_ratio: fraction of ``length`` used as the smallest length
            (never below 1).

    Returns:
        A list of ``reqs_num`` Python ints.
    """
    shortest = max(int(length * range_ratio), 1)
    # randint's upper bound is exclusive, hence the +1.
    sampled = np.random.randint(shortest, length + 1, size=reqs_num)
    return sampled.tolist()
def get_custom_input_data(data_path, output_len, tokenizer, range_ratio):
    """Load chat prompts from a JSON-lines file.

    Every non-empty line must be a JSON object with a ``"messages"`` entry;
    the messages are rendered to text with the tokenizer's chat template.

    Args:
        data_path: path of the JSON-lines file (read as UTF-8).
        output_len: largest number of new tokens to request per prompt.
        tokenizer: tokenizer providing ``apply_chat_template`` and ``encode``.
        range_ratio: forwarded to ``get_random_length`` for the output lengths.

    Returns:
        ``(prompts, output_lens)``. Each prompt is
        ``[text, token_count, images]`` with an empty image list, which is the
        three-field shape the request functions in this file unpack.
    """
    prompts = []
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines, e.g. a trailing newline at the end of the file.
            if not line.strip():
                continue
            data_line = json.loads(line)
            input_data = tokenizer.apply_chat_template(
                data_line["messages"], add_generation_prompt=True, tokenize=False
            )
            input_len = len(tokenizer.encode(input_data))
            # The third field (images) is required: the senders unpack
            # `text_input, input_len, images = prompt`.
            prompts.append([input_data, input_len, []])
    output_lens = get_random_length(len(prompts), output_len, range_ratio)
    print("Load random data finish.")
    return prompts, output_lens
text_input + text_input = "a" + text_input + "<|im_start|>assistant\n" + content = [{"type": "text", "text": text_input}] + for img in images: + b64 = img["data"] + mime = "image/png" + content.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}}) + messages = [{"role": "user", "content": content}] + # print(messages) + data = { + # "model": "../InternVL2_5-26B", + "model": "/mnt/mtc/niushengxiao/251024_math_ocr_15b_v1.6.5", + "messages": messages, + "max_tokens": max_new_tokens, + "ignore_eos": True, + "stream": True, + "temperature": 0.0, + # "best_of": 1, + } + headers = {"Content-Type": "application/json"} + used_time = [] + start_time = time.time() + last_time = start_time + async with session.post(url, headers=headers, json=data) as response: + if response.status != 200: + return [], input_len + + while True: + line = await response.content.readline() + if not line: + break + line = line.strip() + if line: + current_time = time.time() + elapsed_time = current_time - last_time + used_time.append(elapsed_time) + last_time = current_time + return used_time, input_len + except Exception as e: + print(f"openai request failed: {repr(e)}") + return [], input_len + + +async def async_post_stream_lightllm(url, prompt, max_new_tokens, session): + input_len = 0 + try: + text_input, input_len, images = prompt + if images: + text_input = build_image_placeholders(model_name[-1], len(images)) + text_input + data = { + "inputs": text_input, + "parameters": { + "do_sample": False, + "ignore_eos": True, + "max_new_tokens": max_new_tokens, + "add_special_tokens": False, + "return_details": True, + # "image_max_patch_num": 1 + }, + "multimodal_params": { + "images": images, + }, + } + headers = {"Content-Type": "application/json"} + used_time = [] + start_time = time.time() + last_time = start_time + async with session.post(url, headers=headers, json=data) as response: + if response.status != 200: + return [], input_len + + while True: + line = await 
async def continuous_sender(
    session,
    pending_tasks,
    async_task,
    url,
    prompts,
    max_new_tokens,
    request_queue,
    stop_send,
    sent_count,
    input_qps,
    max_count,
    continuous_send,
):
    """Fire requests at a fixed rate until told to stop.

    On every round one request task is created with
    ``async_task(url, prompt, max_tokens, session)``, appended to
    ``pending_tasks`` and put on ``request_queue``. Prompts and token budgets
    are taken round-robin from ``prompts`` and ``max_new_tokens``.

    Sending ends when ``stop_send`` is set, or - unless ``continuous_send`` is
    truthy - once ``sent_count[0]`` has reached ``max_count``.

    Args:
        session: passed through to ``async_task``.
        pending_tasks: list that collects every created task.
        async_task: coroutine function performing one request.
        url: passed through to ``async_task``.
        prompts: non-empty sequence of prompts.
        max_new_tokens: non-empty sequence of per-request token budgets.
        request_queue: queue the created tasks are put on.
        stop_send: event that ends sending when set.
        sent_count: one-element list; element 0 is incremented per request.
        input_qps: requests per second; the pause between two requests is
            ``1.0 / input_qps`` seconds.
        max_count: request budget when ``continuous_send`` is falsy.
        continuous_send: truthy to ignore ``max_count``.
    """
    cursor = 0
    while True:
        if stop_send.is_set():
            return
        if not continuous_send and sent_count[0] >= max_count:
            return

        current_prompt = prompts[cursor % len(prompts)]
        token_budget = max_new_tokens[cursor % len(max_new_tokens)]

        request = asyncio.create_task(async_task(url, current_prompt, token_budget, session))
        pending_tasks.append(request)
        await request_queue.put(request)

        cursor += 1
        sent_count[0] += 1
        # Throttle the send rate.
        await asyncio.sleep(1.0 / input_qps)
def _report_percentiles(label, samples, percentiles, scale=1, unit="s"):
    """Print the given percentiles of ``samples`` and return them as a ``{"P<n>": value}`` dict."""
    table = {}
    for percentile, value in zip(percentiles, np.percentile(samples, percentiles)):
        print(f"{label} P{percentile}: {value * scale:.6f}{unit}")
        table[f"P{percentile}"] = value * scale
    return table


def main():
    """Run the multimodal serving benchmark from the command line.

    Parses the arguments, prepares random or custom prompts, sends them at
    the requested rate, then prints throughput and latency percentiles. With
    ``--dump_file`` the summary is also written as JSON; if that file already
    exists, its content is printed and no benchmark is run.

    Raises:
        Exception: if ``--server_api`` is neither ``lightllm`` nor ``openai``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--url",
        type=str,
        default="http://localhost:18009/generate_stream",
        help="lightllm:http://127.0.0.1:18007/generate_stream, openai:http://127.0.0.1:18007/v1/completions",
    )
    parser.add_argument("--num_clients", type=int, default=100)
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default="/data_vqa/wangruohui/train_internvl_qwen3/RUN/qwen3_32B_vit_300m_mlp_vit_10k_sft/5000_hf",
    )
    parser.add_argument("--data_path", type=str, default=None)
    parser.add_argument("--input_num", type=int, default=200)
    parser.add_argument("--input_qps", type=float, default=30.0)
    parser.add_argument("--input_len", type=int, default=4096)
    parser.add_argument("--output_len", type=int, default=1)
    parser.add_argument("--server_api", type=str, default="lightllm")
    parser.add_argument("--dump_file", type=str, default="")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--range_ratio", type=float, default=1.0)
    parser.add_argument(
        "--force_terminate",
        type=int,
        default=0,
        help="0: waiting all reqs return; 1: only waiting input_num reqs return",
    )
    parser.add_argument(
        "--continuous_send",
        type=int,
        default=0,
        help="0: only send input_num reqs; 1: send continuously until receiving input_num reqs",
    )
    parser.add_argument("--num_images", type=int, default=13)
    args = parser.parse_args()
    if args.dump_file and os.path.exists(args.dump_file):
        # An earlier run already produced this file: read and print its JSON content.
        with open(args.dump_file, "r") as json_file:
            content = json.load(json_file)
            print(json.dumps(content, indent=4))
        return

    assert args.tokenizer_path is not None
    model_name.append(args.tokenizer_path)
    seed_all(args.seed)
    url = args.url
    tokenizer = get_tokenizer(args.tokenizer_path)
    if args.data_path is not None:
        prompts, max_new_tokens = get_custom_input_data(args.data_path, args.output_len, tokenizer, args.range_ratio)
        args.input_num = len(prompts)
    else:
        # In continuous (qps) mode the number of requests sent is not fixed,
        # so 10x input_num prompts are prepared for now.
        prompts, max_new_tokens = gen_random_data(
            args.input_len,
            args.output_len,
            args.input_num if not args.continuous_send else 10 * args.input_num,
            tokenizer,
            args.range_ratio,
            num_images=args.num_images,
        )

    percentiles = [25, 50, 75, 90, 95, 99, 100]
    if args.server_api == "lightllm":
        async_post_stream = async_post_stream_lightllm
    elif args.server_api == "openai":
        async_post_stream = async_post_stream_openai
    else:
        raise Exception(f"Not support {args.server_api} server_api.")

    dump_dict = {}
    dump_dict["backend"] = args.server_api
    dump_dict["clients"] = args.num_clients

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    start_time = time.time()
    results, sent_reqs, end_time = loop.run_until_complete(
        run_continuous_benchmark(
            async_post_stream,
            url,
            prompts,
            max_new_tokens,
            args.input_num,
            args.num_clients,
            args.input_qps,
            args.force_terminate,
            args.continuous_send,
        )
    )
    loop.close()
    if not end_time:
        # The end time is only recorded once the target number of successful
        # responses is reached. If some requests failed it is still 0.0, which
        # would make every duration below negative.
        end_time = time.time()
    elapsed = end_time - start_time

    print(len(results))
    if not results:
        # Every statistic below divides by the number of results.
        print("No request finished successfully; nothing to report.")
        return

    first_token_time = []
    decode_token_time = []
    request_time = []
    final_output_lens = []
    valid_num = 0
    input_lens = []
    for result, input_len in results:
        first_token_time.append(result[0])
        if len(result) > 1:
            # Mean gap between chunks after the first one.
            decode_token_time.append(sum(result[1:]) / len(result[1:]))
        else:
            decode_token_time.append(0)  # no decode
        request_time.append(sum(result))
        final_output_lens.append(len(result))
        input_lens.append(input_len)
        valid_num += 1

    print(
        f"\n\nvalid num = {valid_num}; all data num = {len(results)}; valid ratio = {valid_num * 1.0 / len(results)}\n"
    )
    print(f"Total QPS: {valid_num / elapsed}")
    print(f"Sender QPS: {sent_reqs / elapsed}")
    print(f"Avg Input Length: {sum(input_lens) / len(input_lens)}")
    print(f"Avg Output Length: {sum(final_output_lens) / len(final_output_lens)}")
    print(f"Total Throughput: {(sum(input_lens) + sum(final_output_lens)) / elapsed} token/s")
    print(f"Input Throughput: {sum(input_lens) / elapsed} token/s")
    print(f"Output Throughput: {sum(final_output_lens) / elapsed} token/s")
    print("-" * 10)
    dump_dict["request_num"] = valid_num
    dump_dict["Total QPS"] = valid_num / elapsed
    dump_dict["Sender QPS"] = sent_reqs / elapsed
    dump_dict["Avg Input Length"] = sum(input_lens) / len(input_lens)
    dump_dict["Avg Output Length"] = sum(final_output_lens) / len(final_output_lens)
    dump_dict["Total Throughput"] = (sum(input_lens) + sum(final_output_lens)) / elapsed
    dump_dict["Input Throughput"] = sum(input_lens) / elapsed
    dump_dict["Output Throughput"] = sum(final_output_lens) / elapsed

    dump_dict["request_time"] = _report_percentiles("request_time", request_time, percentiles)
    print("-" * 10)

    dump_dict["first_token_time_dict"] = _report_percentiles("first_token_time", first_token_time, percentiles)
    print("-" * 10)

    # Decode latency is reported in milliseconds.
    dump_dict["decode_token_time_dict"] = _report_percentiles(
        "decode_token_time", decode_token_time, percentiles, scale=1000, unit="ms"
    )
    print(dump_dict)

    if args.dump_file:
        with open(args.dump_file, "w") as json_file:
            json.dump(dump_dict, json_file, indent=4)
        print(f"Results have been written to {args.dump_file}")
print("Generated random data.") +# return prompts, output_lens + + +# def get_prompts_from_json(json_path, tokenizer, output_len, range_ratio): +# with open(json_path, "r") as f: +# data = json.load(f) +# prompts = [] +# for item in data: +# input_data = tokenizer.apply_chat_template(item["messages"], add_generation_prompt=True, tokenize=False) +# input_len = len(tokenizer.encode(input_data)) +# images = [] +# for img_path in item.get("images", []): +# if os.path.exists(img_path): +# img_b64 = encode_image_to_base64(img_path) +# images.append({"type": "base64", "data": img_b64}) +# else: +# print(f"[Warning] Image not found: {img_path}") +# prompts.append((input_data, input_len, images)) +# output_lens = get_random_length(len(prompts), output_len, range_ratio) +# print(f"Loaded {len(prompts)} prompts from JSON file.") +# return prompts, output_lens + + +# # -------------------- 请求函数 -------------------- + +# async def async_post_stream_lightllm(url, prompt, max_new_tokens, session): +# try: +# text_input, input_len, images = prompt +# text_input += "<|vision_start|><|image_pad|><|vision_end|>\n<|im_start|>assistant\n" +# print(f"text_input is {text_input}") +# data = { +# "inputs": text_input, +# "parameters": { +# "do_sample": False, +# "ignore_eos": True, +# "max_new_tokens": max_new_tokens, +# "add_special_tokens": False, +# "return_details": True, +# }, +# "multimodal_params": {"images": images}, +# } +# headers = {"Content-Type": "application/json"} +# used_time = [] +# start = time.time() +# last = start +# async with session.post(url, headers=headers, json=data) as resp: +# if resp.status != 200: +# return [], input_len +# async for line in resp.content: +# if line and line.startswith(b"data:"): +# now = time.time() +# used_time.append(now - last) +# last = now +# try: +# line = json.loads(line[5:]) +# input_len = int(line["token"]["prompt_tokens"]) +# except Exception: +# pass +# return used_time, input_len +# except Exception as e: +# print(f"[Error] {e}") 
+# return [], 0 + + +# async def async_post_stream_openai(url, prompt, max_new_tokens, session): +# try: +# text_input, input_len, images = prompt +# text_input = "a" + text_input + "<|im_start|>assistant\n" +# content = [{"type": "text", "text": text_input}] +# for img in images: +# mime = "image/png" +# content.append({ +# "type": "image_url", +# "image_url": {"url": f"data:{mime};base64,{img['data']}"} +# }) +# messages = [{"role": "user", "content": content}] +# data = { +# "model": "test_model", +# "messages": messages, +# "max_tokens": max_new_tokens, +# "ignore_eos": True, +# "stream": True, +# "temperature": 0.0, +# } +# headers = {"Content-Type": "application/json"} +# used_time = [] +# start = time.time() +# last = start +# async with session.post(url, headers=headers, json=data) as resp: +# if resp.status != 200: +# return [], input_len +# async for line in resp.content: +# line = line.strip() +# if line: +# now = time.time() +# used_time.append(now - last) +# last = now +# return used_time, input_len +# except Exception as e: +# print(f"[Error] {e}") +# return [], 0 + + +# # -------------------- 并发控制 -------------------- + +# async def worker(semaphore, session, async_task, url, prompt, max_new_tokens, results): +# async with semaphore: +# res = await async_task(url, prompt, max_new_tokens, session) +# if res and len(res[0]) > 0: +# results.append(res) + + +# async def run_fixed_concurrency_benchmark(async_task, url, prompts, max_new_tokens, num_concurrent): +# timeout = aiohttp.ClientTimeout(total=3600) +# semaphore = asyncio.Semaphore(num_concurrent) +# results = [] +# start = time.time() +# async with aiohttp.ClientSession(timeout=timeout) as session: +# tasks = [ +# asyncio.create_task(worker(semaphore, session, async_task, url, p, max_new_tokens[i], results)) +# for i, p in enumerate(prompts) +# ] +# await asyncio.gather(*tasks) +# end = time.time() +# return results, start, end + + +# # -------------------- 主程序 -------------------- + +# def 
main(): +# parser = argparse.ArgumentParser() +# parser.add_argument("--url", type=str, default="http://localhost:18009/generate_stream") +# parser.add_argument("--num_concurrent", type=int, default=5, help="并发数量") +# parser.add_argument("--tokenizer_path", type=str, required=True) +# parser.add_argument("--data_path", type=str, default=None) +# parser.add_argument("--input_num", type=int, default=100) +# parser.add_argument("--input_len", type=int, default=4096) +# parser.add_argument("--output_len", type=int, default=1) +# parser.add_argument("--server_api", type=str, default="lightllm") +# parser.add_argument("--range_ratio", type=float, default=1.0) +# parser.add_argument("--num_images", type=int, default=2) +# parser.add_argument("--dump_file", type=str, default="") +# parser.add_argument("--seed", type=int, default=0) +# args = parser.parse_args() + +# seed_all(args.seed) +# tokenizer = get_tokenizer(args.tokenizer_path) + +# if args.data_path: +# prompts, max_new_tokens = get_prompts_from_json(args.data_path, tokenizer, args.output_len, args.range_ratio) +# else: +# prompts, max_new_tokens = gen_random_data( +# args.input_len, args.output_len, args.input_num, tokenizer, args.range_ratio, args.num_images +# ) + +# if args.server_api == "lightllm": +# async_task = async_post_stream_lightllm +# else: +# async_task = async_post_stream_openai + +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) +# results, start_time, end_time = loop.run_until_complete( +# run_fixed_concurrency_benchmark(async_task, args.url, prompts, max_new_tokens, args.num_concurrent) +# ) +# loop.close() + +# # ---------------- 统计部分(与原始保持一致) ---------------- +# percentiles = [25, 50, 75, 90, 95, 99, 100] +# first_token_time = [] +# decode_token_time = [] +# request_time = [] +# final_output_lens = [] +# input_lens = [] + +# for result, input_len in results: +# if len(result) > 1: +# first_token_time.append(result[0]) +# decode_token_time.append(sum(result[1:]) / 
len(result[1:])) +# else: +# first_token_time.append(result[0]) +# decode_token_time.append(0) +# request_time.append(sum(result)) +# final_output_lens.append(len(result)) +# input_lens.append(input_len) + +# valid_num = len(results) +# print(f"\nvalid num = {valid_num}; all data num = {len(prompts)}; valid ratio = {valid_num / len(prompts):.4f}") +# print(f"Total QPS: {valid_num / (end_time - start_time)}") +# print(f"Avg Input Length: {np.mean(input_lens):.2f}") +# print(f"Avg Output Length: {np.mean(final_output_lens):.2f}") +# print(f"Total Throughput: {(sum(input_lens) + sum(final_output_lens)) / (end_time - start_time):.2f} token/s") +# print(f"Input Throughput: {sum(input_lens) / (end_time - start_time):.2f} token/s") +# print(f"Output Throughput: {sum(final_output_lens) / (end_time - start_time):.2f} token/s") +# print("-" * 10) + +# dump_dict = { +# "backend": args.server_api, +# "clients": args.num_concurrent, +# "request_num": valid_num, +# "Total QPS": valid_num / (end_time - start_time), +# "Avg Input Length": np.mean(input_lens), +# "Avg Output Length": np.mean(final_output_lens), +# "Total Throughput": (sum(input_lens) + sum(final_output_lens)) / (end_time - start_time), +# "Input Throughput": sum(input_lens) / (end_time - start_time), +# "Output Throughput": sum(final_output_lens) / (end_time - start_time), +# } + +# # 各项延迟分位数 +# for name, arr, scale, unit in [ +# ("request_time", request_time, 1, "s"), +# ("first_token_time", first_token_time, 1, "s"), +# ("decode_token_time", decode_token_time, 1000, "ms"), +# ]: +# vals = np.percentile(arr, percentiles) +# d = {} +# for p, v in zip(percentiles, vals): +# print(f"{name} P{p}: {v * scale:.6f}{unit}") +# d[f"P{p}"] = v * scale +# dump_dict[name] = d +# print("-" * 10) + +# if args.dump_file: +# with open(args.dump_file, "w") as f: +# json.dump(dump_dict, f, indent=4) +# print(f"Results have been written to {args.dump_file}") + + +# if __name__ == "__main__": +# main() diff --git 
a/test/performance/test_vl.sh b/test/performance/test_vl.sh new file mode 100644 index 0000000000..72dba9eae3 --- /dev/null +++ b/test/performance/test_vl.sh @@ -0,0 +1 @@ +python ./multimodal_test.py --num_clients 2 --tokenizer_path /mtc/models/Qwen3-VL-8B-Instruct --input_num 500 --output_len 128 --num_images 1 --input_len 1024 --url http://localhost:8089/v1/chat/completions --server_api openai --input_qps 500 --seed 1234 \ No newline at end of file