Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
104 commits
Select commit Hold shift + click to select a range
ca572c5
rebase
hiworldwzj Apr 2, 2026
4b1a990
refine
Mar 20, 2026
a2076bb
fix
Mar 23, 2026
66b945d
fix
hiworldwzj Mar 24, 2026
3cf35c8
rebase
hiworldwzj Apr 2, 2026
e76ccc5
fix
hiworldwzj Mar 24, 2026
7e353ce
fix
hiworldwzj Mar 24, 2026
51e7e48
fix
hiworldwzj Mar 24, 2026
58d8d11
fix
hiworldwzj Mar 24, 2026
49567c8
fix
hiworldwzj Mar 24, 2026
379c724
fix
hiworldwzj Mar 24, 2026
2661349
fix
hiworldwzj Mar 24, 2026
a3dab0c
fix
hiworldwzj Mar 24, 2026
d538295
fix
hiworldwzj Mar 26, 2026
f7f340c
fix
hiworldwzj Mar 26, 2026
4aaf75f
fix
hiworldwzj Mar 26, 2026
5782012
fix
hiworldwzj Mar 26, 2026
a87aca9
fix
hiworldwzj Mar 26, 2026
b29ebb4
fix
hiworldwzj Mar 26, 2026
09fd654
fix
hiworldwzj Mar 26, 2026
4ded90a
fix
hiworldwzj Mar 26, 2026
5e08075
fix
hiworldwzj Mar 26, 2026
047363b
fix
hiworldwzj Mar 27, 2026
dfa0c38
fix
hiworldwzj Mar 27, 2026
d4aa502
fix
hiworldwzj Mar 27, 2026
e28f06e
fix
hiworldwzj Mar 27, 2026
c81a5fa
fix
hiworldwzj Mar 27, 2026
a1b05c7
fix
hiworldwzj Mar 27, 2026
a055338
fix
hiworldwzj Mar 27, 2026
597e8c4
fix
hiworldwzj Mar 27, 2026
4d5f30e
fix
hiworldwzj Mar 27, 2026
7390d48
fix
hiworldwzj Mar 27, 2026
e5f9363
fix
hiworldwzj Mar 28, 2026
8d2a1ef
fix
hiworldwzj Mar 28, 2026
2495953
fix
hiworldwzj Mar 28, 2026
afda7b3
fix
hiworldwzj Mar 28, 2026
9722844
fix
hiworldwzj Mar 28, 2026
c71a6c4
fix
hiworldwzj Mar 28, 2026
1ad3faf
fix
hiworldwzj Mar 28, 2026
8674126
fix
hiworldwzj Mar 28, 2026
c99deed
fix
hiworldwzj Mar 28, 2026
085236f
fix
hiworldwzj Mar 28, 2026
d9e7aaa
fix
hiworldwzj Mar 28, 2026
d2dcb69
fix
hiworldwzj Mar 28, 2026
ab1a9fb
fix
hiworldwzj Mar 28, 2026
6188eb6
fix
hiworldwzj Mar 28, 2026
1a9f7eb
fix
hiworldwzj Mar 28, 2026
72b1793
fix
hiworldwzj Mar 28, 2026
381f1f1
fix
hiworldwzj Mar 28, 2026
53fa1cc
fix
hiworldwzj Mar 28, 2026
5380a25
fix
hiworldwzj Mar 28, 2026
1d6a295
fix
hiworldwzj Mar 28, 2026
ef569b2
fix
hiworldwzj Mar 28, 2026
943d3a5
fix
hiworldwzj Mar 28, 2026
45abe9b
fix
hiworldwzj Mar 28, 2026
d4798fc
rebase
hiworldwzj Apr 2, 2026
963590b
fix
hiworldwzj Mar 30, 2026
4bfee7c
fix
hiworldwzj Mar 30, 2026
0b072d7
fix
hiworldwzj Mar 30, 2026
704ef95
fix
hiworldwzj Mar 30, 2026
0cf2132
fix
hiworldwzj Mar 30, 2026
5f45c40
fix
hiworldwzj Mar 30, 2026
328a55e
fix
hiworldwzj Mar 30, 2026
377ea79
fix
hiworldwzj Mar 30, 2026
83d9201
fix
hiworldwzj Mar 30, 2026
51ab7a3
fix
hiworldwzj Mar 30, 2026
669179e
fix
hiworldwzj Mar 30, 2026
980d049
fix
hiworldwzj Mar 30, 2026
bbb05e0
fix
hiworldwzj Mar 30, 2026
1656c2e
fix
hiworldwzj Mar 30, 2026
ae850a5
fi
hiworldwzj Mar 30, 2026
43b042c
fix
hiworldwzj Mar 30, 2026
7e95468
fix
hiworldwzj Mar 30, 2026
9d08874
fix
hiworldwzj Mar 30, 2026
926c9fc
fix
hiworldwzj Mar 30, 2026
b0f1710
fix
hiworldwzj Mar 30, 2026
655c86f
fix
hiworldwzj Mar 30, 2026
7bff2e2
fix
hiworldwzj Mar 30, 2026
b1c8951
fix
hiworldwzj Mar 30, 2026
f8bd98f
fix
hiworldwzj Mar 30, 2026
6fadf14
fix
hiworldwzj Mar 30, 2026
f2c29e8
fix
hiworldwzj Mar 30, 2026
2772104
fix
hiworldwzj Mar 30, 2026
22e37b4
fix
hiworldwzj Mar 30, 2026
b9c416a
fix
hiworldwzj Mar 30, 2026
cf270fb
fix
hiworldwzj Mar 30, 2026
43f21f8
fix
hiworldwzj Mar 31, 2026
aa8f3ea
fix
hiworldwzj Mar 31, 2026
a53e5b5
fix
hiworldwzj Apr 1, 2026
930a1d7
fix
hiworldwzj Apr 1, 2026
387c7a8
fix
hiworldwzj Apr 1, 2026
233ee7c
fix
hiworldwzj Apr 1, 2026
44b3eff
fix
hiworldwzj Apr 1, 2026
a0028f8
fix
hiworldwzj Apr 1, 2026
f2adc2f
fix
hiworldwzj Apr 1, 2026
b5387e4
fix
hiworldwzj Apr 1, 2026
997cd2d
fix
hiworldwzj Apr 1, 2026
e4b6be9
fix
hiworldwzj Apr 1, 2026
03bbc42
fix
hiworldwzj Apr 1, 2026
49c37fd
fix
hiworldwzj Apr 1, 2026
25f877b
fix
hiworldwzj Apr 1, 2026
4124740
fix
hiworldwzj Apr 2, 2026
ff1472f
fix
hiworldwzj Apr 2, 2026
0a71a76
fix
hiworldwzj Apr 2, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions lightllm/models/vit/triton_kernel/flashattention_nopad.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,7 @@ def flash_attention_v3_fwd(
sm_margin=0,
sinks=None,
)

return
return o

except ImportError:
print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.")
Expand Down
53 changes: 52 additions & 1 deletion lightllm/server/api_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,16 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--run_mode",
type=str,
choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"],
choices=[
"normal",
"prefill",
"decode",
"nixl_prefill",
"nixl_decode",
"pd_master",
"config_server",
"visual_only",
],
default="normal",
help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode,
config_server is for pd split mode used to register pd_master node, and get pd_master node list,
Expand Down Expand Up @@ -61,6 +70,14 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=None,
help="The port number for the config server in config_server mode.",
)
parser.add_argument(
"--config_server_visual_redis_port",
type=int,
default=None,
help="""when run_mode is config_server, set this params will start a redis server,
when a llm infer node start to set this params, the visual infer module will start a
proxy module use config server to find remote vit infer nodes to infer img""",
)
parser.add_argument(
"--nixl_pd_kv_page_num",
type=int,
Expand Down Expand Up @@ -458,6 +475,24 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=None,
help="List of NCCL ports to build a distributed environment for Vit, e.g., 29500 29501 29502",
)
parser.add_argument(
"--visual_rpyc_port",
type=int,
default=None,
help="""
when run_mode is visual_only, set this port, make others to call local visual infer to
transfer image to embed.
""",
)
parser.add_argument(
"--visual_use_proxy_mode",
action="store_true",
help="""
when run_mode is normal, set this params,
will call remote visual infer to transfer image to embed,
need set --config_server_host, --config_server_port,
--config_server_visual_redis_port""",
)
parser.add_argument(
"--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway"
)
Expand Down Expand Up @@ -629,6 +664,22 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=0.03,
help="""The interval of the schedule time, default is 30ms.""",
)
parser.add_argument(
"--afs_image_embed_dir",
type=str,
default=None,
help="path for vit embed, when use vit remote infer mode",
)
parser.add_argument(
"--afs_embed_capacity",
type=int,
default=250000,
help="""
capacity for vit embed in remote infer mode,
it control how many image can be cached in afs,
when the cache is full, the least recently used
image embed will be removed""",
)
parser.add_argument(
"--enable_cpu_cache",
action="store_true",
Expand Down
4 changes: 3 additions & 1 deletion lightllm/server/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess
parser = make_argument_parser()
args = parser.parse_args()
from .api_start import pd_master_start, normal_or_p_d_start, config_server_start
from .api_start import pd_master_start, normal_or_p_d_start, visual_only_start, config_server_start

if args.run_mode == "pd_master":
pd_master_start(args)
elif args.run_mode == "config_server":
config_server_start(args)
elif args.run_mode in ["visual_only"]:
visual_only_start(args)
else:
normal_or_p_d_start(args)
117 changes: 99 additions & 18 deletions lightllm/server/api_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .router.manager import start_router_process
from lightllm.utils.process_check import is_process_active
from lightllm.utils.multinode_utils import send_and_receive_node_ip
from lightllm.utils.redis_utils import start_redis_service
from lightllm.utils.shm_size_check import check_recommended_shm_size
from lightllm.utils.config_utils import has_audio_module, has_vision_module

Expand Down Expand Up @@ -57,7 +58,8 @@ def signal_handler(sig, frame):
signal.signal(signal.SIGINT, signal_handler)

logger.info(f"start process pid {os.getpid()}")
logger.info(f"http server pid {http_server_process.pid}")
if http_server_process:
logger.info(f"http server pid {http_server_process.pid}")
return


Expand All @@ -73,7 +75,7 @@ def normal_or_p_d_start(args):

enable_mps()

if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode"]:
if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "visual_only"]:
return

# 通过模型的参数判断是否是多模态模型,包含哪几种模态, 并设置是否启动相应得模块
Expand Down Expand Up @@ -174,6 +176,10 @@ def normal_or_p_d_start(args):
assert args.mtp_draft_model_dir is None
assert args.mtp_step == 0

if args.afs_image_embed_dir is not None:
os.makedirs(args.afs_image_embed_dir, mode=0o777, exist_ok=True)
os.chmod(args.afs_image_embed_dir, 0o777)

# 检查GPU数量是否足够
if args.visual_gpu_ids is None:
args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp))
Expand Down Expand Up @@ -240,6 +246,8 @@ def normal_or_p_d_start(args):
already_uesd_ports.append(args.nccl_port)
if args.pd_decode_rpyc_port is not None:
already_uesd_ports.append(args.pd_decode_rpyc_port)
if args.visual_nccl_ports is not None:
already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp])

# 提前锁定端口,防止在单个机器上启动多个实列的时候,要到模型启动的时候才能
# 捕获到端口设置冲突的问题
Expand All @@ -248,7 +256,8 @@ def normal_or_p_d_start(args):

node_world_size = args.tp // args.nnodes
can_use_ports = alloc_can_use_network_port(
num=10 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_ports=already_uesd_ports
num=10 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp,
used_ports=already_uesd_ports,
)
logger.info(f"alloced ports: {can_use_ports}")
(
Expand All @@ -265,14 +274,14 @@ def normal_or_p_d_start(args):
) = can_use_ports[0:10]
can_use_ports = can_use_ports[10:]

visual_model_tp_ports = []
visual_nccl_ports = []
for _ in range(args.visual_dp):
tp_ports_for_dp = can_use_ports[0 : args.visual_tp]
visual_model_tp_ports.append(tp_ports_for_dp)
can_use_ports = can_use_ports[args.visual_tp :]
visual_nccl_ports.append(can_use_ports[0])
can_use_ports = can_use_ports[1:]
if args.visual_nccl_ports is None:
visual_nccl_ports.append(can_use_ports[0])
can_use_ports = can_use_ports[1:]

if args.visual_nccl_ports is not None:
visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]

# 将申请好的端口放入args参数中
if args.nccl_port is None:
Expand Down Expand Up @@ -323,16 +332,29 @@ def normal_or_p_d_start(args):
)

if not args.disable_vision:
from .visualserver.manager import start_visual_process

process_manager.start_submodule_processes(
start_funcs=[
start_visual_process,
],
start_args=[
(args, visual_model_tp_ports),
],
)
if not args.visual_use_proxy_mode:
from .visualserver.manager import start_visual_process

process_manager.start_submodule_processes(
start_funcs=[
start_visual_process,
],
start_args=[
(args,),
],
)
else:
from .visualserver.proxy_manager import start_visual_process

process_manager.start_submodule_processes(
start_funcs=[
start_visual_process,
],
start_args=[
(args,),
],
)

if not args.disable_audio:
from .audioserver.manager import start_audio_process
Expand Down Expand Up @@ -469,13 +491,72 @@ def pd_master_start(args):
http_server_process.wait()


def visual_only_start(args):
    """Start a standalone visual (ViT) inference node for run_mode == "visual_only".

    Prepares the optional AFS image-embed cache directory, allocates network
    ports for the visual data-parallel/tensor-parallel workers, fills in
    defaulted args (GPU ids, batch size, data type), publishes the start args
    to the environment, launches the visual-only manager subprocess, and then
    blocks until interrupted.

    Args:
        args: parsed command-line namespace (typed as StartArgs below).

    Raises:
        AssertionError: if the resolved data_type is not a supported dtype.
    """
    from lightllm.server.core.objs.start_args_type import StartArgs

    args: StartArgs = args

    # Create the shared embed-cache directory with open permissions so that
    # other nodes/processes (potentially under different uids) can read/write it.
    if args.afs_image_embed_dir is not None:
        os.makedirs(args.afs_image_embed_dir, mode=0o777, exist_ok=True)
        os.chmod(args.afs_image_embed_dir, 0o777)

    # Reserve ports that are already taken by explicit CLI settings so the
    # allocator does not hand them out again.  Guard against None: the rpyc
    # port is optional, and the allocator should not receive a None entry
    # (matches the `is not None` guards used by normal_or_p_d_start).
    already_used_ports = []
    if args.visual_rpyc_port is not None:
        already_used_ports.append(args.visual_rpyc_port)
    # 5 misc ports + one port per visual tp worker per dp group + one nccl
    # port per dp group.
    can_use_ports = alloc_can_use_network_port(
        num=5 + args.visual_dp * args.visual_tp + args.visual_dp,
        used_ports=already_used_ports,
    )

    # Fill in defaults that depend on the visual parallelism configuration.
    if args.visual_gpu_ids is None:
        args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp))
    if args.visual_infer_batch_size is None:
        args.visual_infer_batch_size = args.visual_dp
    if args.data_type is None:
        from lightllm.utils.config_utils import get_dtype

        args.data_type = get_dtype(args.model_dir)
    assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]

    logger.info(f"alloced ports: {can_use_ports}")

    # One NCCL port per visual dp group; the remainder stays available for
    # the manager's internal use.
    args.visual_nccl_ports = can_use_ports[: args.visual_dp]
    can_use_ports = can_use_ports[args.visual_dp :]
    # Random unique id so the config server can distinguish visual nodes.
    args.visual_node_id = uuid.uuid4().int

    logger.info(f"all start args:{args}")

    set_env_start_args(args)

    from .visualserver.visual_only_manager import start_visual_process

    process_manager.start_submodule_processes(
        start_funcs=[
            start_visual_process,
        ],
        start_args=[
            (args,),
        ],
    )
    # No HTTP server process in visual_only mode, hence None.
    setup_signal_handlers(None, process_manager)
    try:
        # Keep the parent alive; workers run as subprocesses.
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt, shutting down...")
        process_manager.terminate_all_processes()
        logger.info("All processes have been terminated gracefully.")
        sys.exit(0)


def config_server_start(args):
set_unique_server_name(args)
if args.run_mode != "config_server":
return

logger.info(f"all start args:{args}")

if args.config_server_visual_redis_port is not None:
start_redis_service(args)

set_env_start_args(args)

command = [
Expand Down
35 changes: 35 additions & 0 deletions lightllm/server/config_server/api_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Dict, List
from fastapi.responses import JSONResponse
from lightllm.utils.log_utils import init_logger
from lightllm.server.visualserver.objs import VIT_Obj
from ..pd_io_struct import PD_Master_Obj
from .nccl_tcp_store import start_tcp_store_server
from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name
Expand All @@ -19,7 +20,9 @@
app = FastAPI()

registered_pd_master_objs: Dict[str, PD_Master_Obj] = {}
registered_visual_server_objs: Dict[str, VIT_Obj] = {}
registered_pd_master_obj_lock = Lock()
registered_visual_server_obj_lock = Lock()

global_req_id = 0
global_req_id_lock = Lock()
Expand Down Expand Up @@ -72,6 +75,30 @@ async def websocket_endpoint(websocket: WebSocket):
return


@app.websocket("/visual_register")
async def visual_websocket_endpoint(websocket: WebSocket):
    """Registration endpoint for remote visual (ViT) inference nodes.

    A visual node connects, sends one pickled VIT_Obj describing itself, and
    then sends periodic "heartbeat" text frames.  The node stays registered in
    `registered_visual_server_objs` until the socket errors/disconnects or a
    bad frame arrives, at which point it is removed.
    """
    await websocket.accept()
    client_ip, client_port = websocket.client
    logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}")
    # SECURITY NOTE(review): pickle.loads on bytes received from the network
    # allows arbitrary code execution if the peer is untrusted.  This is only
    # acceptable on a trusted/authenticated internal channel — consider a safe
    # format (e.g. JSON) otherwise.
    registered_visual_server_obj: VIT_Obj = pickle.loads(await websocket.receive_bytes())
    logger.info(f"received registered_visual_server_obj {registered_visual_server_obj}")
    with registered_visual_server_obj_lock:
        registered_visual_server_objs[registered_visual_server_obj.node_id] = registered_visual_server_obj

    try:
        while True:
            data = await websocket.receive_text()
            # Explicit raise instead of assert: assert is stripped under -O,
            # which would silently accept malformed frames.  The exception is
            # caught below and triggers deregistration, same as before.
            if data != "heartbeat":
                raise RuntimeError(f"unexpected frame from visual node: {data!r}")
    except Exception as e:
        # Exception already covers WebSocketDisconnect and RuntimeError.
        logger.error(f"registered_visual_server_obj {registered_visual_server_obj} has error {str(e)}")
        logger.exception(str(e))
    finally:
        # Always deregister so stale nodes are never handed out to callers.
        logger.error(f"registered_visual_server_obj {registered_visual_server_obj} removed")
        with registered_visual_server_obj_lock:
            registered_visual_server_objs.pop(registered_visual_server_obj.node_id, None)
    return


@app.get("/registered_objects")
async def get_registered_objects():
with registered_pd_master_obj_lock:
Expand All @@ -80,6 +107,14 @@ async def get_registered_objects():
return {"data": base64_encoded}


@app.get("/registered_visual_objects")
async def get_vit_registered_objects():
    """Return the currently registered visual nodes, pickled and base64-wrapped in JSON."""
    # Snapshot the registry under the lock so concurrent register/unregister
    # cannot mutate the dict mid-serialization.
    with registered_visual_server_obj_lock:
        snapshot_bytes = pickle.dumps(registered_visual_server_objs)
    encoded = base64.b64encode(snapshot_bytes).decode("utf-8")
    return {"data": encoded}


@app.get("/allocate_global_unique_id_range")
async def allocate_global_id_range():
"""
Expand Down
Loading
Loading