diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index 1d36c72d0b..54e379b32d 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -35,6 +35,7 @@ from lightllm.utils.infer_utils import post_empty_cache
 from .attention import get_prefill_att_backend_class, get_decode_att_backend_class
 from .attention import BaseAttBackend
+from lightllm.utils.profiler import GlobalPerfContext
 
 logger = init_logger(__name__)
 
@@ -255,6 +256,7 @@ def _init_att_backend1(self):
         self.decode_att_backend1: BaseAttBackend = None
         return
 
+    @GlobalPerfContext.disable()
     def _init_cudagraph(self):
         self.graph = (
             None if self.disable_cudagraph else CudaGraph(self.graph_max_batch_size, self.graph_max_len_in_batch)
@@ -552,6 +554,7 @@ def _decode(
 
     @final
     def _context_forward(self, infer_state: InferStateInfo):
+        GlobalPerfContext.begin_with_sample_rate(sample_rate=1)
         run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
         input_ids = infer_state.input_ids
         cuda_input_ids = input_ids
@@ -602,10 +605,13 @@ def prefill_func(input_tensors, infer_state):
         # 在开启使用deepep的时候,需要调用clear_deepep_buffer做资源清理,没有启用的时候
         # 该调用没有实际意义
         dist_group_manager.clear_deepep_buffer()
+        rank = torch.cuda.current_device()
+        GlobalPerfContext.finalize_async(marker=f"_context_forward_{rank}", save_jsonl=True, save_log=True)
         return model_output
 
     @final
     def _token_forward(self, infer_state: InferStateInfo):
+        GlobalPerfContext.begin_with_sample_rate(sample_rate=0.05)
         run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
         input_ids = infer_state.input_ids
         cuda_input_ids = input_ids
@@ -632,6 +638,13 @@ def _token_forward(self, infer_state: InferStateInfo):
 
         if infer_state.is_cuda_graph:
             model_output.to_no_ref_tensor()
+        rank = torch.cuda.current_device()
+        GlobalPerfContext.finalize_async(marker=f"_token_forward_{rank}", save_jsonl=True, save_log=True)
         return model_output
 
     @torch.no_grad()
diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py
index dd29c9a833..957aae3f03 100644
--- a/lightllm/common/basemodel/cuda_graph.py
+++ b/lightllm/common/basemodel/cuda_graph.py
@@ -7,6 +7,7 @@
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
+from lightllm.utils.profiler import GlobalPerfContext
 from .infer_struct import InferStateInfo
 
@@ -45,7 +46,8 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192):
         logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}")
 
     def can_run(self, batch_size, max_len_in_batch):
-        return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch
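+        # When this decode step is sampled for profiling, report that the graph
+        # cannot run so the step falls back to the eager path: CUDA timing
+        # events recorded during graph capture/replay would not give usable
+        # measurements.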
+        do_profile = GlobalPerfContext.cudagraph_helper(sample_rate=0.05)
+        return not do_profile and batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch
 
     def need_capture(self, batch_size):
         find_batch_size = self.find_closest_graph_batch_size(batch_size)
diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
index 646f998642..4804a25128 100755
--- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
+++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
@@ -1,6 +1,8 @@
 import os
 import torch
 import torch.distributed as dist
+
+from lightllm.utils.profiler import PerfCounter
 from ..transformer_layer_infer import TransformerLayerInfer
 from ...infer_struct import InferStateInfo
 from lightllm.distributed import all_reduce
@@ -62,6 +64,7 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor
     def _tpsp_ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
         raise Exception("need to impl")
 
+    @PerfCounter(type="LAYER")
     def context_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
         q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight)
         self._post_cache_kv(cache_kv, infer_state, layer_weight)
@@ -74,6 +77,7 @@ def context_attention_forward(self, input_embdings, infer_state: InferStateInfo,
             all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
         return o
 
+    @PerfCounter(type="LAYER")
     def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
         input1 = self._att_norm(input_embdings, infer_state, layer_weight)
         o = self.context_attention_forward(input1, infer_state, layer_weight)
@@ -88,6 +92,7 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return input_embdings
 
+    @PerfCounter(type="LAYER")
     def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
         q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight)
         self._post_cache_kv(cache_kv, infer_state, layer_weight)
@@ -98,6 +103,7 @@ def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, l
             all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
         return o
 
+    @PerfCounter(type="LAYER")
     def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
         input1 = self._att_norm(input_embdings, infer_state, layer_weight)
         o = self.token_attention_forward(input1, infer_state, layer_weight)
@@ -111,6 +117,7 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return input_embdings
 
+    @PerfCounter(type="LAYER")
     def tpsp_context_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
         q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight)
         self._post_cache_kv(cache_kv, infer_state, layer_weight)
@@ -121,6 +128,7 @@ def tpsp_context_attention_forward(self, input_embdings: torch.Tensor, infer_sta
         o = self._tpsp_get_o(o, infer_state, layer_weight)
         return o
 
+    @PerfCounter(type="LAYER")
     def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
         input1 = self._att_norm(input_embdings, infer_state, layer_weight)
         o = self.tpsp_context_attention_forward(input1, infer_state, layer_weight)
@@ -133,6 +141,7 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return input_embdings
 
+    @PerfCounter(type="LAYER")
     def tpsp_token_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
         q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight)
         self._post_cache_kv(cache_kv, infer_state, layer_weight)
@@ -141,6 +150,7 @@ def tpsp_token_attention_forward(self, input_embdings: torch.Tensor, infer_state
         o = self._tpsp_get_o(o, infer_state, layer_weight)
         return o
 
+    @PerfCounter(type="LAYER")
     def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
         input1 = self._att_norm(input_embdings, infer_state, layer_weight)
         o = self.tpsp_token_attention_forward(input1, infer_state, layer_weight)
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py
index 5021699143..8232bae1cc 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py
@@ -11,6 +11,7 @@
 from lightllm.common.quantization.no_quant import NoQuantization
 from lightllm.utils.dist_utils import get_current_device_id
 from lightllm.utils.log_utils import init_logger
+from lightllm.utils.profiler import PerfCounter
 from .mm_slicer import SliceMixinTpl
 
 logger = init_logger(__name__)
@@ -53,9 +54,11 @@ def __init__(
         self._create_weight()
         self.gen_weight_quant_param_names()
 
+    @PerfCounter(type="GEMM_OP")
     def mm(
         self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True
     ) -> torch.Tensor:
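+        # record_shape() resolves to the per-call counter that the @PerfCounter
+        # decorator installs on the wrapped function; m/k come from the input
+        # activation and n from the weight's output dimension.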
+        self.mm.record_shape(m=input_tensor.shape[0], k=input_tensor.shape[1], n=self.mm_param.weight.shape[1])
         return self.quant_method.apply(
             input_tensor, self.mm_param, out, use_custom_tensor_mananger=use_custom_tensor_mananger, bias=self.bias
         )
@@ -215,6 +218,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
     def verify_load(self):
         return self.weight.load_ok
 
+    @PerfCounter(type="GEMM_OP")
     def bmm(
         self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True
     ) -> torch.Tensor:
diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py b/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py
index 101d316937..da0368352b 100644
--- a/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py
+++ b/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py
@@ -5,6 +5,8 @@
 import triton.language as tl
 from typing import Dict
 
+from lightllm.utils.profiler import PerfCounter
+
 
 @triton.jit
 def _fwd_kernel_ep_scatter_1(
@@ -87,6 +89,7 @@ def _fwd_kernel_ep_scatter_2(
     tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
 
+@PerfCounter(type="COMM_OP")
 @torch.no_grad()
 def ep_scatter(
     recv_x: torch.Tensor,
@@ -104,6 +107,7 @@ def ep_scatter(
     num_warps = 8
     num_experts = num_recv_tokens_per_expert.shape[0]  # 获取num_recv_tokens_per_expert的元素个数
     hidden_size = recv_x.shape[1]
+    ep_scatter.record_shape(
+        size=recv_x.element_size() * recv_x.numel(),
+        hidden_size=hidden_size,
+        num_experts=num_experts,
+        recv_x_shape=recv_x.shape,
+        output_tensor_shape=output_tensor.shape,
+    )
     # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
     grid = num_experts
 
@@ -194,6 +198,7 @@ def _fwd_kernel_ep_gather(
     )
 
+@PerfCounter(type="COMM_OP")
 @torch.no_grad()
 def ep_gather(
     input_tensor: torch.Tensor,
@@ -206,6 +211,7 @@ def ep_gather(
     num_warps = 2
     num_tokens = output_tensor.shape[0]
     hidden_size = input_tensor.shape[1]
+    ep_gather.record_shape(
+        size=input_tensor.element_size() * input_tensor.numel(),
+        hidden_size=hidden_size,
+        num_tokens=num_tokens,
+        input_tensor_shape=input_tensor.shape,
+        output_tensor_shape=output_tensor.shape,
+    )
     assert hidden_size % BLOCK_D == 0
     grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
     _fwd_kernel_ep_gather[grid](
diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py
index 638abbd6ca..eeea92b651 100644
--- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py
+++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py
@@ -23,6 +23,7 @@
 import triton.language as tl
 from typing import Any, Callable, Dict, Optional, Tuple
 from lightllm.utils.log_utils import init_logger
+from lightllm.utils.profiler import PerfCounter
 from lightllm.utils.vllm_utils import vllm_ops
 from lightllm.utils.device_utils import triton_support_tensor_descriptor
 from .moe_silu_and_mul import silu_and_mul_fwd
@@ -267,6 +268,7 @@ def _get_moe_align_fused_configs():
     ]
 
+@PerfCounter("moe_align_fused", type="OTHER_OP")
 @autotune(
     kernel_name="moe_align_fused:v1",
     configs_gen_func=_get_moe_align_fused_configs,
@@ -670,6 +672,7 @@ def _get_grouped_matmul_configs():
     ]
 
+@PerfCounter("grouped_matmul", type="GEMM_OP")
 @autotune(
     kernel_name="grouped_matmul:v1",
     configs_gen_func=_get_grouped_matmul_configs,
@@ -716,6 +719,10 @@ def grouped_matmul(
     assert expert_to_weights.is_contiguous()
     assert expert_weights.is_contiguous()
 
+    if grouped_matmul.is_perf_counter_active():
+        # guard the .item() device sync so it only happens on sampled (profiled) steps
+        m_total = int(expert_to_token_num.sum().item())
+        grouped_matmul.record_shape(m=m_total, k=k, n=n)
+
     # for deepseek_v3 block-wise quant
     block_size_n = 0
     block_size_k = 0
@@ -1157,6 +1163,7 @@ def outplace_fused_experts_impl_fake(
     )
 
+@PerfCounter(type="BLOCK")
 def fused_experts(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py
index 2c6d013bd5..01d3cb31d2 100644
--- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py
+++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py
@@ -19,6 +19,8 @@
 from lightllm.common.triton_utils.autotuner import Autotuner
 import numpy as np
 
+from lightllm.utils.profiler import PerfCounter
+
 logger = init_logger(__name__)
 
 try:
@@ -31,6 +33,7 @@
     HAS_DEEPGEMM = False
 
+@PerfCounter(type="BLOCK")
 def masked_group_gemm(
     recv_x: Tuple[torch.Tensor, torch.Tensor],
     masked_m: torch.Tensor,
@@ -59,6 +62,7 @@ def masked_group_gemm(
     return gemm_out_b
 
+@PerfCounter(type="BLOCK")
 def fused_experts_impl(
     hidden_states: torch.Tensor,  # [M, K]
     w1: torch.Tensor,  # [group, N, K]
@@ -116,6 +120,8 @@ def fused_experts_impl(
         # recv_topk_idx [recive_num_tokens, topk_num]
         # recv_topk_weights [recive_num_tokens, topk_num]
         # num_recv_tokens_per_expert_list list [cur_node_expert_num] padding with expert_alignment=128
+        p = PerfCounter(type="COMM_OP", name="deepep_buffer.dispatch")
+        p.start()
         recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch(
             (qinput_tensor, input_scale),
             topk_idx=topk_idx,
@@ -129,6 +135,24 @@ def fused_experts_impl(
             allocate_on_comm_stream=False,
             expert_alignment=128,
         )
+        # Per-rank communication volume: bytes actually received by this rank
+        # (activations + scales + topk metadata).
+        _comm_size = (
+            recv_x[0].element_size() * recv_x[0].numel()
+            + recv_x[1].element_size() * recv_x[1].numel()
+            + recv_topk_idx.element_size() * recv_topk_idx.numel()
+            + recv_topk_weights.element_size() * recv_topk_weights.numel()
+        )
+        p.record_shape(
+            size=_comm_size,
+            hidden_size=K,
+            num_experts=num_experts,
+            topk_num=topk_idx.shape[1],
+            recv_x_shape_0=recv_x[0].shape,
+            recv_x_shape_1=recv_x[1].shape,
+            recv_topk_idx_shape=recv_topk_idx.shape,
+            recv_topk_weights_shape=recv_topk_weights.shape,
+        )
+        p.stop()
 
         # scatter
         all_tokens = sum(num_recv_tokens_per_expert_list)  # calcu padding all nums.
@@ -198,6 +222,9 @@ def fused_experts_impl(
         _gemm_out_a, _silu_out = None, None
 
         # normal combine
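+        # For combine, `size` counts the bytes this rank sends back (gather_out);
+        # shapes may be recorded before start(), since they attach to the counter
+        # object rather than to the CUDA timing events.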
+        p = PerfCounter(type="COMM_OP", name="deepep_buffer.combine")
+        p.record_shape(
+            size=gather_out.element_size() * gather_out.numel(),
+            hidden_size=K,
+            num_experts=num_experts,
+            topk_num=topk_idx.shape[1],
+            gather_out_shape=gather_out.shape,
+        )
+        p.start()
         combined_x, _, event = buffer.combine(
             gather_out,
             handle,
@@ -206,10 +233,13 @@ def fused_experts_impl(
             previous_event=previous_event,
             allocate_on_comm_stream=False,
         )
+        p.stop()
     else:
         # low latency dispatch
         num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank()
         expected_m = triton.cdiv(hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1], num_experts)
+        p = PerfCounter(type="COMM_OP", name="deepep_buffer.low_latency_dispatch")
+        p.start()
         recv_x, masked_m, handle, event, hook = buffer.low_latency_dispatch(
             hidden_states,
             topk_idx,
@@ -219,21 +249,69 @@ def fused_experts_impl(
             async_finish=False,
             return_recv_hook=False,
         )
+        # Per-rank communication volume: bytes received by this rank during
+        # low-latency dispatch (activations + scales).
+        _comm_size_ll = recv_x[0].element_size() * recv_x[0].numel() + recv_x[1].element_size() * recv_x[1].numel()
+        p.record_shape(
+            size=_comm_size_ll,
+            hidden_size=K,
+            num_experts=num_experts,
+            topk_num=topk_idx.shape[1],
+            recv_x_shape_0=recv_x[0].shape,
+            recv_x_shape_1=recv_x[1].shape,
+            masked_m_shape=masked_m.shape,
+        )
+        p.stop()
 
         # deepgemm
         gemm_out_b = masked_group_gemm(recv_x, masked_m, hidden_states.dtype, w1, w1_scale, w2, w2_scale, expected_m)
 
         # low latency combine
+        p = PerfCounter(type="COMM_OP", name="deepep_buffer.low_latency_combine")
+        p.record_shape(
+            size=gemm_out_b.element_size() * gemm_out_b.numel(),
+            hidden_size=K,
+            num_experts=num_experts,
+            topk_num=topk_idx.shape[1],
+            gemm_out_b_shape=gemm_out_b.shape,
+        )
+        p.start()
         combined_x, event_overlap, hook = buffer.low_latency_combine(
             gemm_out_b, topk_idx, topk_weights, handle, async_finish=False, return_recv_hook=False
         )
+        p.stop()
 
     return combined_x
 
+@PerfCounter(type="GEMM_OP")
 def _deepgemm_grouped_fp8_nt_contiguous(
     input_tuple: Tuple[torch.Tensor, torch.Tensor],
     w_tuple: Tuple[torch.Tensor, torch.Tensor],
     out: torch.Tensor,
     m_indices: torch.Tensor,
 ):
+    # record GEMM shapes and FLOPs for the contiguous layout
+    if _deepgemm_grouped_fp8_nt_contiguous.is_perf_counter_active():
+        try:
+            E = int(w_tuple[0].shape[0])
+            # input activations are [all_tokens, K_in] or [all_tokens, N/2]
+            K_in = int(input_tuple[0].shape[-1])
+            w = w_tuple[0]
+            # determine output N_out from weight shape relative to K_in
+            if int(w.shape[2]) == K_in:
+                N_out = int(w.shape[1])
+            elif int(w.shape[1]) == K_in:
+                N_out = int(w.shape[2])
+            else:
+                N_out = int(out.shape[-1])
+            # m_indices maps tokens -> expert id; count per expert
+            m_counts = torch.bincount(m_indices, minlength=E)
+            total_m = int(m_counts.sum().item())
+            total_flops = int(2 * total_m * K_in * N_out)
+            _deepgemm_grouped_fp8_nt_contiguous.record_shape(
+                experts=E,
+                k=K_in,
+                n=N_out,
+                total_m=total_m,
+                # m_per_expert=m_counts.tolist(),
+                flops=total_flops,
+                input_dtype=str(input_tuple[0].dtype),
+                weight_dtype=str(w_tuple[0].dtype),
+                out_dtype=str(out.dtype),
+            )
+        except Exception:
+            pass
     if HAS_DEEPGEMM:
         if hasattr(deep_gemm, "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous"):
             return deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(input_tuple, w_tuple, out, m_indices)
@@ -242,6 +320,7 @@ def _deepgemm_grouped_fp8_nt_contiguous(
     raise RuntimeError("deep_gemm does not provide grouped_gemm_fp8 NT contiguous GEMM kernel in this version")
 
+@PerfCounter(type="GEMM_OP")
 def _deepgemm_grouped_fp8_nt_masked(
     input_tuple: Tuple[torch.Tensor, torch.Tensor],
     w_tuple: Tuple[torch.Tensor, torch.Tensor],
@@ -249,6 +328,35 @@ def _deepgemm_grouped_fp8_nt_masked(
     masked_m: torch.Tensor,
     expected_m: int,
 ):
+    # record GEMM shapes and FLOPs for the masked layout
+    if _deepgemm_grouped_fp8_nt_masked.is_perf_counter_active():
+        try:
+            E = int(w_tuple[0].shape[0])
+            # masked_m holds the number of active rows per expert
+            total_m = int(masked_m.sum().item())
+            K_in = int(input_tuple[0].shape[-1])
+            w = w_tuple[0]
+            # determine output N_out from weight shape relative to K_in
+            if int(w.shape[2]) == K_in:
+                N_out = int(w.shape[1])
+            elif int(w.shape[1]) == K_in:
+                N_out = int(w.shape[2])
+            else:
+                N_out = int(out.shape[-1])
+            total_flops = int(2 * total_m * K_in * N_out)
+            _deepgemm_grouped_fp8_nt_masked.record_shape(
+                experts=E,
+                k=K_in,
+                n=N_out,
+                total_m=total_m,
+                expected_m=int(expected_m),
+                flops=total_flops,
+                input_dtype=str(input_tuple[0].dtype),
+                weight_dtype=str(w_tuple[0].dtype),
+                out_dtype=str(out.dtype),
+            )
+        except Exception:
+            pass
     if HAS_DEEPGEMM:
         if hasattr(deep_gemm, "m_grouped_fp8_gemm_nt_masked"):
             return deep_gemm.m_grouped_fp8_gemm_nt_masked(input_tuple, w_tuple, out, masked_m, expected_m)
diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py
index fb0323cd4b..e9525fee2f 100644
--- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py
+++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py
@@ -4,6 +4,8 @@
 import triton.language as tl
 from triton.language.standard import _log2, sum, zeros_like
 
+from lightllm.utils.profiler import PerfCounter
+
 
 @triton.jit
 def _compare_and_swap(x, x_1, ids, flip, i: tl.core.constexpr, n_dims: tl.core.constexpr):
@@ -202,6 +204,7 @@ def grouped_topk_kernel(
     return
 
+@PerfCounter(type="OTHER_OP")
 def triton_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py
index d7bcc17743..93b19689f3 100644
--- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py
+++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py
@@ -2,6 +2,8 @@
 import triton
 import triton.language as tl
 
+from lightllm.utils.profiler import PerfCounter
 from lightllm.common.triton_utils.autotuner import autotune
 
@@ -98,6 +100,7 @@ def _get_silu_and_mul_static_key(input: torch.Tensor, output: torch.Tensor):
     return {"N": input.shape[-1] // 2, "out_dtype": str(output.dtype)}
 
+@PerfCounter(type="ACT_OP", name="silu_and_mul_fwd")
 @autotune(
     kernel_name="silu_and_mul_fwd:v1",
     configs_gen_func=_get_silu_and_mul_configs,
diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py
index d2c44b2953..ce72fa1c4b 100644
--- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py
+++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py
@@ -3,6 +3,8 @@
 import triton
 import triton.language as tl
 
+from lightllm.utils.profiler import PerfCounter
+
 
 @triton.jit
 def _silu_and_mul_post_quant_kernel(
@@ -65,6 +67,7 @@ def _silu_and_mul_post_quant_kernel(
     )
 
+@PerfCounter(type="ACT_OP")
 def silu_and_mul_masked_post_quant_fwd(
     input: torch.Tensor, output: torch.Tensor, output_scale: torch.Tensor, quant_group_size: int, masked_m: torch.Tensor
 ):
diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py
index e16351eec8..5bb8e189ab 100644
--- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py
+++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py
@@ -2,6 +2,9 @@
 import triton
 import triton.language as tl
 from typing import Dict
+
+from lightllm.utils.profiler import PerfCounter
 from lightllm.common.triton_utils.autotuner import autotune
 
@@ -60,6 +63,7 @@ def _get_moe_sum_reduce_configs():
     ]
 
+@PerfCounter("moe_sum_reduce", type="OTHER_OP")
 @autotune(
     kernel_name="moe_sum_reduce:v1",
     configs_gen_func=_get_moe_sum_reduce_configs,
diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py
index 59d1f825a3..e1de785add 100644
--- a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py
+++ b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py
@@ -19,6 +19,7 @@
 import os
 import torch
 
+from lightllm.utils.profiler import PerfCounter
 from lightllm.utils.sgl_utils import sgl_ops
 from lightllm.utils.light_utils import light_ops
 from typing import Callable, List, Optional, Tuple
@@ -28,6 +29,7 @@
 use_cuda_grouped_topk = os.getenv("LIGHTLLM_CUDA_GROUPED_TOPK", "False").upper() in ["ON", "TRUE", "1"]
 
+@PerfCounter(type="OTHER_OP")
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -128,6 +130,7 @@ def biased_grouped_topk(
 
 # This is used by the Deepseek-V2 model
+@PerfCounter(type="OTHER_OP")
 def cuda_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -165,6 +168,7 @@ def cuda_grouped_topk(
     return topk_weights, topk_indices
 
+@PerfCounter(type="BLOCK")
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
@@ -218,6 +222,9 @@ def select_experts(
             hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize
         )
     else:
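+        # Manual start()/stop() is used here because the profiled callable
+        # (custom_routing_function) is supplied by the caller and cannot be
+        # decorated at definition time.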
+        p = PerfCounter(type="OTHER_OP", name="custom_routing_function")
+        p.start()
         topk_weights, topk_ids = custom_routing_function(
             hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize
         )
+        p.stop()
diff --git a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py
index ca8f9a1c81..d516d4d10f 100644
--- a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py
+++ b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py
@@ -4,6 +4,8 @@
 import triton.language as tl
 import os
 
+from lightllm.utils.profiler import PerfCounter
+
 rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "8"))
 
@@ -44,6 +46,7 @@ def _rms_norm_fwd_fused(
     tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask)
 
+@PerfCounter(type="ACT_OP")
 def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None):
     # allocate output
     y = torch.empty_like(x) if out is None else out
@@ -53,6 +56,7 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None)
     assert x_arg.shape[-1] == weight.shape[0] and x_arg.shape == y_arg.shape
     assert y.data_ptr() == y_arg.data_ptr()
     M, N = x_arg.shape
+    rmsnorm_forward.record_shape(m=M, n=N)
     # Less than 64KB per feature: enqueue fused kernel
     MAX_FUSED_SIZE = 65536 // x.element_size()
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
diff --git a/lightllm/common/basemodel/triton_kernel/quantization/fp8act_quant_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/fp8act_quant_kernel.py
index 0a68372887..b7b5f439fe 100644
--- a/lightllm/common/basemodel/triton_kernel/quantization/fp8act_quant_kernel.py
+++ b/lightllm/common/basemodel/triton_kernel/quantization/fp8act_quant_kernel.py
@@ -3,6 +3,7 @@
 import triton.language as tl
 
 from lightllm.common.kernel_config import KernelConfigs
+from lightllm.utils.profiler import PerfCounter
 from lightllm.utils.sgl_utils import HAS_SGL_KERNEL, sgl_ops
 from frozendict import frozendict
 from functools import lru_cache
@@ -113,6 +114,7 @@ def lightllm_per_token_group_quant_fp8(
     return
 
+@PerfCounter(type="QUANT_OP")
 def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
@@ -213,6 +215,7 @@ def _tma_align_input_scale_kernel(
     tl.store(output_offset, input_data, mask=k_offsets < k_div_block_size)
 
+@PerfCounter(type="QUANT_OP")
 def tma_align_input_scale(input_scale: torch.Tensor):
     assert input_scale.dim() == 2
     m, k_div_block_size = input_scale.shape
diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py
index 52d4e61da8..67ca71c48e 100644
--- a/lightllm/distributed/communication_op.py
+++ b/lightllm/distributed/communication_op.py
@@ -39,6 +39,7 @@
     create_dp_special_inter_group,
 )
 from lightllm.utils.device_utils import get_device_sm_count
+from lightllm.utils.profiler import GlobalPerfContext, PerfCounter
 from lightllm.utils.sgl_utils import HAS_SGL_KERNEL
 from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL
 from contextlib import nullcontext, contextmanager
@@ -183,6 +184,7 @@ def clear_deepep_buffer(self):
         self.ep_buffer.clean_low_latency_buffer(self.ll_num_tokens, self.ll_hidden, self.ll_num_experts)
 
+@PerfCounter(type="COMM_OP")
 def all_reduce(
     input_: torch.Tensor,
     group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None,
@@ -191,12 +193,14 @@ def all_reduce(
 ) -> None:
     if _is_single_group(group=group):
         return
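+    # `size` is the raw payload in bytes; the GB/s figure the profiler derives
+    # from it ignores the collective's algorithm factor, hence the "(pseudo)"
+    # suffix in the log output.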
+    all_reduce.record_shape(size=input_.element_size() * input_.nelement())
     if isinstance(group, CustomProcessGroup):
         return group.all_reduce(input_)
     else:
         return dist.all_reduce(input_, op, group, async_op)
 
+@PerfCounter(type="COMM_OP")
 def all_gather_into_tensor(
     output_: torch.Tensor,
     input_: torch.Tensor,
@@ -206,12 +210,14 @@ def all_gather_into_tensor(
     if _is_single_group(group=group):
         output_.copy_(input_)
         return
+    all_gather_into_tensor.record_shape(size=input_.element_size() * input_.nelement())
     if isinstance(group, CustomProcessGroup):
         return group.all_gather_into_tensor(output_, input_)
     else:
         return dist.all_gather_into_tensor(output_, input_, group, async_op)
 
+@PerfCounter(type="COMM_OP")
 def all_gather(
     output_: List[torch.Tensor],
     input_: torch.Tensor,
@@ -223,12 +229,14 @@ def all_gather(
         output_[0].copy_(input_)
         return
     # todo 目前还没有定制算子的支持。
+    all_gather.record_shape(size=input_.element_size() * input_.nelement())
     if isinstance(group, CustomProcessGroup):
         return dist.all_gather(output_, input_, group.device_group, async_op)
     else:
         return dist.all_gather(output_, input_, group, async_op)
 
+@PerfCounter(type="COMM_OP")
 def reduce_scatter_tensor(
     output: torch.Tensor,
     input: torch.Tensor,
@@ -240,6 +248,7 @@ def reduce_scatter_tensor(
         output.copy_(input)
         return
     # 目前还没有定制算子实现。
+    reduce_scatter_tensor.record_shape(size=input.element_size() * input.nelement())
     if isinstance(group, CustomProcessGroup):
         return dist.reduce_scatter_tensor(output, input, op=op, group=group.device_group, async_op=async_op)
     else:
diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
index dc6f10be59..7a79f92447 100644
--- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -9,6 +9,7 @@
 from lightllm.common.basemodel import TransformerLayerInferTpl
 from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor
 from lightllm.utils.log_utils import init_logger
+from lightllm.utils.profiler import PerfCounter
 
 logger = init_logger(__name__)
 
@@ -37,6 +38,7 @@ def _bind_norm(self):
         self._ffn_norm = partial(LlamaTransformerLayerInfer._ffn_norm, self)
         return
 
+    @PerfCounter(type="BLOCK")
     def _context_attention_kernel(
         self,
         q: torch.Tensor,
@@ -46,15 +48,18 @@ def _context_attention_kernel(
     ) -> torch.Tensor:
         _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_)
         _q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
-        o_tensor = infer_state.prefill_att_state.prefill_att(
-            q=_q,
-            k=_k,
-            v=_v,
-            alloc_func=self.alloc_tensor,
-        )
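+        # Name the counter after the attention backend's kernel function so
+        # different backends show up under distinct names in the perf log.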
+        with PerfCounter(type="ATTN_OP", name=infer_state.prefill_att_state.prefill_att.__name__) as p:
+            p.record_shape(seqlen=q.size(0), bs=infer_state.batch_size)
+            o_tensor = infer_state.prefill_att_state.prefill_att(
+                q=_q,
+                k=_k,
+                v=_v,
+                alloc_func=self.alloc_tensor,
+            )
         o_tensor = o_tensor.view(q.shape)
         return o_tensor
 
+    @PerfCounter(type="BLOCK")
     def _token_attention_kernel(
         self,
         q: torch.Tensor,
@@ -63,19 +68,24 @@ def _token_attention_kernel(
     ) -> torch.Tensor:
         _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_)
         _q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
-        o_tensor = infer_state.decode_att_state.decode_att(q=_q, k=_k, v=_v, alloc_func=self.alloc_tensor)
+        with PerfCounter(type="ATTN_OP", name=infer_state.decode_att_state.decode_att.__name__) as p:
+            p.record_shape(seqlen=q.size(0), bs=infer_state.batch_size)
+            o_tensor = infer_state.decode_att_state.decode_att(q=_q, k=_k, v=_v, alloc_func=self.alloc_tensor)
         return o_tensor.view(q.shape)
 
+    @PerfCounter(type="BLOCK")
     def _att_norm(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
         return layer_weight.att_norm_weight_(input=input, eps=self.eps_, alloc_func=self.alloc_tensor)
 
+    @PerfCounter(type="BLOCK")
     def _ffn_norm(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
         return layer_weight.ffn_norm_weight_(input=input, eps=self.eps_, alloc_func=self.alloc_tensor)
 
+    @PerfCounter(type="BLOCK")
     def _get_qkv(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
@@ -90,6 +100,7 @@ def _get_qkv(
         )
         return q, cache_kv
 
+    @PerfCounter(type="BLOCK")
     def _tpsp_get_qkv(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
@@ -117,6 +128,7 @@ def _tpsp_get_qkv(
 
         return q, cache_kv
 
+    @PerfCounter(type="BLOCK")
     def _get_o(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
@@ -124,6 +136,7 @@ def _get_o(
         o_tensor = layer_weight.o_proj.mm(input)
         return o_tensor
 
+    @PerfCounter(type="BLOCK")
     def _tpsp_get_o(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
@@ -152,6 +165,7 @@ def _tpsp_get_o(
 
         return o_tensor
 
+    @PerfCounter(type="BLOCK")
     def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight) -> torch.Tensor:
         input = input.view(-1, self.embed_dim_)
         up_gate_out = layer_weight.gate_up_proj.mm(input)
@@ -163,6 +177,7 @@ def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTrans
         ffn1_out = None
         return ffn2_out
 
+    @PerfCounter(type="BLOCK")
     def _tpsp_ffn(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
@@ -205,6 +220,7 @@ def _tpsp_ffn(
     #     gate_out, up_out = None, None
     #     return ffn2_out
 
+    @PerfCounter(type="LAYER")
     def overlap_tpsp_token_forward(
         self,
         input_embdings: torch.Tensor,
@@ -217,6 +233,7 @@ def overlap_tpsp_token_forward(
         input_embdings1 = self.tpsp_token_forward(input_embdings1, infer_state1, layer_weight=layer_weight)
         return input_embdings, input_embdings1
 
+    @PerfCounter(type="LAYER")
     def overlap_tpsp_context_forward(
         self,
         input_embdings: torch.Tensor,
diff --git a/lightllm/models/llama/triton_kernel/rotary_emb.py b/lightllm/models/llama/triton_kernel/rotary_emb.py
index c6d4f3010d..6ad197fe33 100755
--- a/lightllm/models/llama/triton_kernel/rotary_emb.py
+++ b/lightllm/models/llama/triton_kernel/rotary_emb.py
@@ -3,6 +3,8 @@
 import triton
 import triton.language as tl
 
+from lightllm.utils.profiler import PerfCounter
+
 
 @triton.jit
 def _rotary_kernel(
@@ -115,11 +117,13 @@ def _rotary_kernel(
     return
 
+@PerfCounter(type="OTHER_OP")
 @torch.no_grad()
 def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.):
     total_len = q.shape[0]
     head_num_q, head_num_k = q.shape[1], k.shape[1]
     head_dim = int(q.shape[2] * partial_rotary_factor)
+    rotary_emb_fwd.record_shape(seqlen=total_len, q_heads=head_num_q, k_heads=head_num_k, head_dim=head_dim)
     assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
     assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}"
diff --git a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py
index 725b0cc02e..785a5381c5 100644
--- a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py
@@ -5,6 +5,7 @@
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
 from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.utils.log_utils import init_logger
+from lightllm.utils.profiler import PerfCounter
 
 logger = init_logger(__name__)
 
@@ -15,6 +16,7 @@ def __init__(self, layer_num, network_config):
         self.head_dim_ = network_config["head_dim"]
         return
 
+    @PerfCounter(type="BLOCK")
     def _get_qkv(
         self,
         input: torch.Tensor,
diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
index 721893a4cd..d60dc0faa5 100644
--- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
@@ -16,6 +16,7 @@
 from lightllm.utils.dist_utils import get_global_world_size
 from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor
 from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.profiler import PerfCounter
 
 logger = init_logger(__name__)
 
@@ -54,6 +55,7 @@ def _bind_ffn(self):
             self._ffn = partial(LlamaTransformerLayerInfer._ffn, self)
             self._tpsp_ffn = self._tpsp_ffn_tp
 
+    @PerfCounter(type="BLOCK")
     def _get_qkv(
         self,
         input: torch.Tensor,
@@ -79,6 +81,7 @@ def _get_qkv(
         )
         return q, cache_kv
 
+    @PerfCounter(type="BLOCK")
     def _tpsp_get_qkv(
         self,
         input: torch.Tensor,
@@ -116,6 +119,7 @@ def _tpsp_get_qkv(
 
         return q, cache_kv
 
+    @PerfCounter(type="BLOCK")
     def _moe_ffn(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight
     ) -> torch.Tensor:
@@ -134,6 +138,7 @@ def _moe_ffn(
         )
         return hidden_states.view(num_tokens, hidden_dim)
 
+    @PerfCounter(type="BLOCK")
     def _moe_ffn_edp(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight
     ) -> torch.Tensor:
@@ -156,11 +161,13 @@ def _moe_ffn_edp(
         ep_output = ep_output.view(token_num, hidden_dim)
         return ep_output
 
+    @PerfCounter(type="BLOCK")
     def _tpsp_ffn(
         self, input: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight
     ):
         raise Exception("need bind to real impl")
 
+    @PerfCounter(type="BLOCK")
     def _tpsp_ffn_tp(
         self, input: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight
     ) -> torch.Tensor:
@@ -186,6 +193,7 @@ def _tpsp_ffn_tp(
         ffn2_out = reduce_o_tensor
         return ffn2_out
 
+    @PerfCounter(type="BLOCK")
     def _tpsp_ffn_ep(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight
     ) -> torch.Tensor:
@@ -195,6 +203,7 @@ def _tpsp_ffn_ep(
 
         return ffn2_out
 
+    @PerfCounter(type="BLOCK")
     def overlap_tpsp_token_forward(
         self,
         input_embdings: torch.Tensor,
@@ -307,6 +316,7 @@ def _1_hook_post():
 
         return input_embdings, input_embdings1
 
+    @PerfCounter(type="BLOCK")
     def overlap_tpsp_context_forward(
         self,
         input_embdings: torch.Tensor,
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index af39e5da7b..e381e40214 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -709,4 +709,19 @@ def make_argument_parser() -> argparse.ArgumentParser:
         If the op is not implemented for the platform and the hardware support triton,
         it will use triton implementation.""",
     )
+    parser.add_argument(
+        "--enable_profiling",
+        type=str,
+        choices=["torch_profiler", "nvtx"],
+        default=None,
+        help="""Enable profiler support.
+        This exposes the '/profiler_start' and '/profiler_stop' APIs;
+        the profiling features below are only enabled between those calls.
+        Options:
+        'torch_profiler': sets up torch.profiler.profile(); trace files are saved to './trace'
+        or to the directory given by the 'LIGHTLLM_TRACE_DIR' env var;
+        'nvtx': adds NVTX marks for an external profiler such as NVIDIA Nsight Systems
+        (which you must set up yourself).
+        An NVTX range named 'LIGHTLLM_PROFILE' will be added around the profiling range.""",
+    )
     return parser
diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py
index 230da5b369..be0fda4712 100755
--- a/lightllm/server/api_http.py
+++ b/lightllm/server/api_http.py
@@ -337,6 +337,24 @@ async def kv_move_status(websocket: WebSocket):
     return
 
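+# Illustrative usage (assuming the server listens on localhost:8000):
+#   curl http://localhost:8000/profiler_start
+#   ... run the workload to capture ...
+#   curl http://localhost:8000/profiler_stop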
+@app.get("/profiler_start")
+async def profiler_start() -> Response:
+    if g_objs.args.enable_profiling:
+        await g_objs.httpserver_manager.profiler_cmd("start")
+        return JSONResponse({"status": "ok"})
+    else:
+        return JSONResponse({"message": "Profiling support not enabled"}, status_code=400)
+
+
+@app.get("/profiler_stop")
+async def profiler_stop() -> Response:
+    if g_objs.args.enable_profiling:
+        await g_objs.httpserver_manager.profiler_cmd("stop")
+        return JSONResponse({"status": "ok"})
+    else:
+        return JSONResponse({"message": "Profiling support not enabled"}, status_code=400)
+
+
 @app.on_event("shutdown")
 async def shutdown():
     logger.info("Received signal to shutdown. Performing graceful shutdown...")
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index e28e4c93ad..7294f1faeb 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -13,7 +13,7 @@
 from frozendict import frozendict
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
-from typing import Union, List, Tuple, Dict, Optional, AsyncGenerator
+from typing import Literal, Union, List, Tuple, Dict, Optional, AsyncGenerator
 from websockets import ClientConnection
 from fastapi import Request
 from ..tokenizer import get_tokenizer
@@ -35,6 +35,7 @@
 from lightllm.utils.config_utils import get_vocab_size
 from lightllm.utils.envs_utils import get_unique_server_name
 from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken
+from lightllm.utils.profiler import ProfilerCmd
 from rpyc.utils.classic import obtain
 
 logger = init_logger(__name__)
@@ -698,6 +699,16 @@ async def abort(self, group_req_id: int) -> bool:
             logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}")
         return True
 
+    async def profiler_cmd(self, cmd: Literal["start", "stop"]):
+        receivers = [self.send_to_router]
+        if self.pd_mode.is_P_or_NORMAL() and self.enable_multimodal:
+            receivers.append(self.send_to_visual)
+        for receiver in receivers:
+            receiver.send_pyobj(
+                ProfilerCmd(cmd),
+                protocol=pickle.HIGHEST_PROTOCOL,
+            )
+
     async def recycle_resource_loop(self):
         pre_time_mark = time.time()
 
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 0d2705fab2..e34ac0a61b 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -26,6 +26,7 @@
 from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient
 from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
 from lightllm.utils.log_utils import init_logger, log_time_ready
+from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
 from lightllm.server.router.token_load import TokenLoad
 from lightllm.server.metrics.manager import MetricClient
 from lightllm.common.basemodel.infer_lock import g_router_lock
@@ -107,6 +108,9 @@ def __init__(self, args: StartArgs):
             else CpuKvCacheClient(only_create_meta_data=True, init_shm_data=False)
         )
         self.router_statics = RouterStatics(self.args)
+
+        profiler_mode = args.enable_profiling
+        self.profiler = ProcessProfiler(mode=profiler_mode, name="lightllm-router") if profiler_mode else None
         return
 
     async def wait_to_model_ready(self):
@@ -507,6 +511,16 @@ def _multinode_tp_generate_new_batch(self):
                 raise e
         return
 
+    async def _profiler_cmd(self, cmd_obj: ProfilerCmd):
+        self.profiler.cmd(cmd_obj)
+
+        cmd = ProfilerCmd(cmd=cmd_obj.cmd)
+        while not self.shm_reqs_io_buffer.is_empty():
+            await asyncio.sleep(0.02)
+
+        self.shm_reqs_io_buffer.write_obj([cmd])
+        self.shm_reqs_io_buffer.set_ready()
+
     async def _recv_new_reqs_and_schedule(self):
         if not hasattr(self, "recv_max_count"):
             self.recv_max_count = 64
@@ -514,9 +528,11 @@ async def _recv_new_reqs_and_schedule(self):
         try:
             # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。
             for _ in range(self.recv_max_count):
-                recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
+                recv_req: Union[GroupReqIndexes, ProfilerCmd] = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
                 if isinstance(recv_req, GroupReqIndexes):
                     self._add_req(recv_req)
+                elif isinstance(recv_req, ProfilerCmd):
+                    await self._profiler_cmd(recv_req)
                 else:
                     assert False, f"Error Req Inf {recv_req}"
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 49a113b1ba..555700ce35 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -46,6 +46,7 @@
 from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
 from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
 from .multi_level_kv_cache import MultiLevelKvCacheModule
+from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
 
 
 class ModeBackend:
@@ -242,6 +243,10 @@ def init_model(self, kvargs):
         if self.args.mtp_mode:
             self.init_mtp_draft_model(kvargs)
 
+        prof_name = f"lightllm-model_backend-node{self.node_rank}_dev{get_current_device_id()}"
+        prof_mode = self.args.enable_profiling
+        self.profiler = ProcessProfiler(mode=prof_mode, name=prof_name, use_multi_thread=True) if prof_mode else None
+
         # 启动infer_loop_thread, 启动两个线程进行推理,对于具备双batch推理折叠得场景
         # 可以降低 cpu overhead,大幅提升gpu得使用率。
         self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True)
@@ -365,6 +370,10 @@ def _try_read_new_reqs(self):
             self._try_read_new_reqs_multinode_tp()
         else:
             self._try_read_new_reqs_normal()
+
+        # called regularly on each infer loop thread
+        if self.profiler is not None:
+            self.profiler.multi_thread_helper()
         return
 
     def _try_read_new_reqs_normal(self):
@@ -430,6 +439,8 @@ def _read_reqs_buffer_and_init_reqs(self):
                 if obj.req_id in g_infer_context.requests_mapping:
                     req: InferReq = g_infer_context.requests_mapping[obj.req_id]
                     req.infer_aborted = True
+            elif isinstance(obj, ProfilerCmd):
+                self.profiler.cmd(obj)
             else:
                 assert False, f"error type {type(obj)}"
         if init_reqs:
diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py
index 8fba9f08d7..273447ff57 100644
--- a/lightllm/server/visualserver/manager.py
+++ b/lightllm/server/visualserver/manager.py
@@ -7,7 +7,7 @@
 import pickle
 import inspect
 import setproctitle
-from typing import List
+from typing import List, Union
 from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes
 from lightllm.server.core.objs import ShmReqManager, StartArgs
@@ -19,6 +19,7 @@
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.process_check import start_parent_check_thread
 from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
 from rpyc.utils.classic import obtain
 
@@ -60,6 +61,8 @@ def __init__(
         self.visual_model_rpc_ports = visual_model_rpc_ports
         self.send_batch_size = args.visual_send_batch_size
         self.shm_req_manager = ShmReqManager()
+        prof_mode = args.enable_profiling
+        self.profiler = ProcessProfiler(prof_mode, name="lightllm-visual_server") if prof_mode else None
 
     async def wait_to_model_ready(self):
 
@@ -187,13 +190,21 @@ async def loop_for_netio_req(self):
         while True:
             try:
                 for _ in range(self.visual_recv_max_count):
-                    recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
+                    recv_req: Union[GroupReqIndexes, ProfilerCmd] = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
                     if isinstance(recv_req, GroupReqIndexes):
                         logger.info(
                             f"visual recv req id {recv_req.group_req_id} "
                             f"img count {len(recv_req.multimodal_params.images)}"
                         )
                         self.waiting_reqs.append(recv_req)
+                    elif isinstance(recv_req, ProfilerCmd):
+                        self.profiler.cmd(recv_req)
+                        tasks = []
+                        for dp in range(self.vit_dp):
+                            for tp in range(self.vit_tp):
+                                task = asyncio.create_task(self.model_rpcs[dp][tp].profiler_cmd(recv_req))
+                                tasks.append(task)
+                        await asyncio.gather(*tasks)
                     else:
                         assert False, f"Error Req Inf {recv_req}"
                 self.visual_recv_max_count = int(min(self.visual_recv_max_count * 1.3, 256))
diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py
index 741707bf93..ef0f32072e 100644
--- a/lightllm/server/visualserver/model_infer/model_rpc.py
+++ b/lightllm/server/visualserver/model_infer/model_rpc.py
@@ -26,6 +26,7 @@
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
 from lightllm.server.visualserver import set_vit_att_backend
+from lightllm.utils.profiler import ProcessProfiler
 
 
 class VisualModelRpcServer(rpyc.Service):
@@ -46,6 +47,10 @@ def exposed_init_model(self, kvargs):
         self.data_type = kvargs["data_type"]
         self.vit_attn_backend = kvargs["vit_attn_backend"]
         set_vit_att_backend(self.vit_attn_backend)
+
+        prof_mode = get_env_start_args().enable_profiling
+        prof_name = f"lightllm-visual-vit_dp{self.dp_rank_id}_tp{self.tp_rank_id}"
+        self.profiler = ProcessProfiler(mode=prof_mode, name=prof_name) if prof_mode else None
         init_vision_distributed_env(kvargs)
         model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir)
@@ -136,6 +141,10 @@ def exposed_encode(self, images: List[ImageItem]):
         torch.cuda.current_stream().synchronize()
         return
 
+    def exposed_profiler_cmd(self, cmd_obj):
+        cmd_obj = obtain(cmd_obj)
+        self.profiler.cmd(cmd_obj)
+
 
 class VisualModelRpcClient:
     def __init__(self, model_rpc, vit_tp, rpc_server_process=None):
@@ -158,9 +167,11 @@ async def _func(*args, **kwargs):
 
             self._init_model = async_wrap(self.model.init_model)
             self._encode = async_wrap(self.model.encode)
+            self._profiler_cmd = async_wrap(self.model.profiler_cmd)
         else:
             self._init_model = self.model.exposed_init_model
             self._encode = self.model.exposed_encode
+            self._profiler_cmd = self.model.exposed_profiler_cmd
         return
 
     async def init_model(self, kvargs):
@@ -178,6 +189,14 @@ async def encode(self, images: List[ImageItem]):
         else:
             return ans
 
+    async def profiler_cmd(self, cmd_obj):
+        ans: rpyc.AsyncResult = self._profiler_cmd(cmd_obj)
+        if self.use_rpc:
+            await ans
+            return
+        else:
+            return
+
 
 def _init_env(port, device_id):
     # 注册graceful 退出的处理
diff --git a/lightllm/utils/profiler.py b/lightllm/utils/profiler.py
new file mode 100644
index 0000000000..4dc1b977d5
--- /dev/null
+++ b/lightllm/utils/profiler.py
@@ -0,0 +1,698 @@
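+"""Lightweight performance counters and process-level profiler helpers.
+
+Illustrative sketch of the intended call pattern (mirrors the call sites in
+this patch; `run_gemm` and the shapes are placeholders):
+
+    GlobalPerfContext.begin_with_sample_rate(sample_rate=0.05)  # per forward
+    with PerfCounter(type="GEMM_OP", name="my_gemm") as p:
+        p.record_shape(m=4096, n=4096, k=4096)
+        run_gemm(...)
+    GlobalPerfContext.finalize_async(marker="decode", save_jsonl=True, save_log=True)
+"""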
+from contextlib import contextmanager
+from collections import defaultdict
+from dataclasses import dataclass
+from datetime import datetime
+import json
+import os
+import random
+import threading
+import time
+import traceback
+from typing import IO, Any, List, Literal, Optional, Tuple, get_args
+import functools
+import torch
+
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class GlobalPerfContext:
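+    """Process-wide registry of PerfCounters for one sampled forward pass.
+
+    begin_with_sample_rate() resets the per-thread counter stacks and decides
+    whether this pass is profiled; finalize_async() synchronizes CUDA, resolves
+    event timings on a worker thread, and (with LAZY_SAVE) defers file writes
+    until the process has been idle for a while.
+    """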
+    # static control flags
+    STATIC_DISABLED = True
+    LAZY_SAVE = True
+
+    # runtime control flags (user control)
+    global_enabled = True
+
+    cuda_lock = threading.Lock()
+    threadlocal_active_counter_stack: dict[int, list['PerfCounter']] = defaultdict(list)
+    recorded_counters: list['PerfCounter'] = []
+    _recorded_counters_lock = threading.Lock()
+
+    _initialized = False
+    _program_disabled = False
+    _effective_enabled = False
+    _force_do_profile_once = False
+    _file_save_lock = threading.Lock()
+    # (filename, marker, counter_list, timestamp, save_jsonl, save_log)
+    _lazy_save_chunks: List[Tuple[str, str, list['PerfCounter'], datetime, bool, bool]] = []
+    _lazy_save_chunks_lock = threading.Lock()
+
+    @classmethod
+    def _init(cls):
+        if cls._initialized:
+            raise RuntimeError("GlobalPerfContext already initialized")
+        cls._initialized = True
+
+        if not cls.STATIC_DISABLED and cls.LAZY_SAVE:
+            def _lazy_saver_worker():
+                # only trigger save when idle 25s+
+                idle_seconds_threshold = 25
+                while True:
+                    time.sleep(1)
+                    with cls._lazy_save_chunks_lock:
+                        if not cls._lazy_save_chunks:
+                            continue
+                        # get last recorded timestamp
+                        last_timestamp = cls._lazy_save_chunks[-1][3]
+                        idle_time_sec = (datetime.now() - last_timestamp).total_seconds()
+                        if idle_time_sec < idle_seconds_threshold:
+                            continue
+                        logger.debug(f"Lazy saver worker triggered after {idle_time_sec:.1f} s idle")
+                        chunks_to_save = cls._lazy_save_chunks
+                        cls._lazy_save_chunks = []
+
+                    with cls._file_save_lock:
+                        jsonl_files: dict[str, IO] = {}
+                        log_files: dict[str, IO] = {}
+                        try:
+                            for filename, marker, counter_list, timestamp, save_jsonl, save_log in chunks_to_save:
+                                if save_jsonl:
+                                    if filename not in jsonl_files:
+                                        logger.debug(f"open {filename}.jsonl: exist={os.path.exists(f'{filename}.jsonl')}")
+                                        if not os.path.exists(f"{filename}.jsonl"):
+                                            open(f"{filename}.jsonl", "w").close()
+                                        f = open(f"{filename}.jsonl", "r+")
+                                        f.seek(0, os.SEEK_END)
+                                        jsonl_files[filename] = f
+                                        logger.debug(
+                                            f"create/append mode, offset={jsonl_files[filename].tell()}, "
+                                            f"inode={os.fstat(jsonl_files[filename].fileno()).st_ino}, "
+                                            f"fd={jsonl_files[filename].fileno()}"
+                                        )
+                                    f_jsonl = jsonl_files[filename]
+                                    cls._save_jsonl(f_jsonl, counter_list, marker, timestamp)
+                                if save_log:
+                                    if filename not in log_files:
+                                        logger.debug(f"open {filename}.log: exist={os.path.exists(f'{filename}.log')}")
+                                        if not os.path.exists(f"{filename}.log"):
+                                            open(f"{filename}.log", "w").close()
+                                        f = open(f"{filename}.log", "r+")
+                                        f.seek(0, os.SEEK_END)
+                                        log_files[filename] = f
+                                        logger.debug(
+                                            f"create/append mode, offset={log_files[filename].tell()}, "
+                                            f"inode={os.fstat(log_files[filename].fileno()).st_ino}, "
+                                            f"fd={log_files[filename].fileno()}"
+                                        )
+                                    f_log = log_files[filename]
+                                    cls._save_log(f_log, counter_list, marker, timestamp)
+                        finally:
+                            for n, f in jsonl_files.items():
+                                f.close()
+                                logger.debug(f"closed {n}.jsonl")
+                            for n, f in log_files.items():
+                                f.close()
+                                logger.debug(f"closed {n}.log")
+                        logger.debug(f"Lazy saver worker finished saving {len(chunks_to_save)} chunks")
+
+            threading.Thread(target=_lazy_saver_worker, daemon=True).start()
+
+    @classmethod
+    @contextmanager
+    def disable(cls):
+        """Disable profiling within this context manager."""
+        cls._program_disabled = True
+        try:
+            yield
+        finally:
+            cls._program_disabled = False
+
+    @classmethod
+    def _eligible(cls) -> bool:
+        return not cls.STATIC_DISABLED and cls.global_enabled and not cls._program_disabled
+
+    @classmethod
+    def cudagraph_helper(cls, sample_rate: float = 1) -> bool:
+        """helper function to decide whether to do cudagraph profiling"""
+        if not cls._eligible():
+            return False
+        do_profile = random.random() < sample_rate
+        cls._force_do_profile_once = do_profile
+        return do_profile
+
+    @classmethod
+    def set_counter(cls, counter: 'PerfCounter') -> None:
+        tid = _get_thread_id()
+        with cls._recorded_counters_lock:
+            cls.threadlocal_active_counter_stack[tid].append(counter)
+            counter.depth = len(cls.threadlocal_active_counter_stack[tid]) - 1
+            cls.recorded_counters.append(counter)
+
+    @classmethod
+    def unset_counter(cls, counter: 'PerfCounter') -> None:
+        tid = _get_thread_id()
+        with cls._recorded_counters_lock:
+            stack = cls.threadlocal_active_counter_stack
+            if tid not in stack or not stack[tid] or stack[tid][-1] is not counter:
+                logger.error(
+                    f"Mismatched PerfCounter unset operation for {counter.name} (called in different thread?)"
+                )
+                return
+
+            stack[tid].pop()
+
+    @classmethod
+    def begin_with_sample_rate(cls, sample_rate: float = 1) -> None:
+        with cls._recorded_counters_lock:
+            cls.threadlocal_active_counter_stack = defaultdict(list)
+            cls.recorded_counters = []
+        not_capturing = not torch.cuda.is_current_stream_capturing()
+        sample_hit = random.random() < sample_rate
+        if cls._force_do_profile_once:
+            sample_hit = True
+            cls._force_do_profile_once = False
+        cls._effective_enabled = cls._eligible() and not_capturing and sample_hit
+
+    @classmethod
+    def _finalize_counters(cls, stacks: dict[int, list['PerfCounter']], counters: list['PerfCounter']) -> None:
+        with cls.cuda_lock:
+            if any(stack for stack in stacks.values()):
+                logger.error("Some PerfCounters are still active during finalize, which will be ignored.")
+                logger.error(f"Active counters: {stacks}")
+                still_active_counter_set = {counter for stack in stacks.values() for counter in stack}
+                counters = [counter for counter in counters if counter not in still_active_counter_set]
+            if not counters:
+                return
+
+            # find the earliest start event to use as the time origin
+            first_event = counters[0].es
+            for counter in counters[1:]:
+                if first_event.elapsed_time(counter.es) < 0:
+                    first_event = counter.es
+
+            for counter in counters:
+                counter.ref_t_start(first_event)
+
+            for counter in counters:
+                counter.finalize()
+
+    @classmethod
+    def finalize(cls) -> Optional[list['PerfCounter']]:
+        with cls._recorded_counters_lock:
+            stacks = cls.threadlocal_active_counter_stack
+            cls.threadlocal_active_counter_stack = defaultdict(list)
+            counter_list = cls.recorded_counters
+            cls.recorded_counters = []
+
+        if not counter_list:
+            return
+        if torch.cuda.is_current_stream_capturing():
+            return
+
+        torch.cuda.synchronize()
+        try:
+            cls._finalize_counters(stacks, counter_list)
+        except Exception as e:
+            logger.error(f"Error in finalizing counters: {e}")
+            raise e
+
+        return counter_list
+
+    @classmethod
+    def _get_perf_str(cls, counter: 'PerfCounter') -> str:
+        if counter.type == 'GEMM_OP' and all(k in counter.shapes for k in ("m", "n", "k")):
+            m, n, k = counter.shapes["m"], counter.shapes["n"], counter.shapes["k"]
+            tflops = 2 * m * n * k / (counter.t_elapsed_ms * 1e9)
+            return f", {tflops:.3f} TFLOPS"
+        if counter.type == 'GEMM_OP' and 'flops' in counter.shapes:
+            flops = counter.shapes['flops']
+            tflops = flops / (counter.t_elapsed_ms * 1e9)
+            return f", {tflops:.3f} TFLOPS"
+        if counter.type == 'COMM_OP' and 'size' in counter.shapes:
+            num_bytes = counter.shapes['size']
+            gbps = num_bytes / (counter.t_elapsed_ms * 1e6)
+            return f", {gbps:.3f} GB/s (pseudo)"
+        return ''
+    @classmethod
+    def _save_jsonl(cls, f: IO, counter_list: list['PerfCounter'], marker: str, timestamp: datetime) -> None:
+        time_str = f" @ {timestamp.isoformat(timespec='milliseconds')}"
+        f.write(f"# GlobalPerfContext.finalize {marker}{time_str}\n")
+        for counter in counter_list:
+            keys = ("name", "type", "shapes", "depth", "t_start_ms", "t_elapsed_ms", "t_start_cpu_timestamp")
+            json_str = json.dumps({key: getattr(counter, key) for key in keys} | {"marker": marker})
+            _ = json.loads(json_str)  # validate
+            f.write(json_str + '\n')
+
+    @classmethod
+    def _save_log(cls, f: IO, counter_list: list['PerfCounter'], marker: str, timestamp: datetime) -> None:
+        time_str = f" @ {timestamp.isoformat(timespec='milliseconds')}"
+        f.write(f"GlobalPerfContext.finalize [{marker}]{time_str}:\n")
+        for i in range(len(counter_list)):
+            counter = counter_list[i]
+            shape_str = ", ".join(f"{k}={v}" for k, v in counter.shapes.items()) if counter.shapes is not None else ""
+            tabs = ' ' * counter.depth
+            last_t = 0.0
+            time_delta = ''
+            if i > 0:
+                if counter_list[i - 1].depth < counter.depth:
+                    # previous counter is this one's parent
+                    last_t = counter_list[i - 1].t_start_ms
+                    time_delta = f"(+{counter.t_start_ms - last_t:.3f} ms)"
+                else:
+                    # previous counter is a sibling (or deeper)
+                    last_t = counter_list[i - 1].t_start_ms + counter_list[i - 1].t_elapsed_ms
+                    last_t_same_level = last_t
+                    if counter_list[i - 1].depth != counter.depth:
+                        # find the last counter at the same depth
+                        for j in range(i - 2, -1, -1):
+                            if counter_list[j].depth == counter.depth:
+                                last_t_same_level = counter_list[j].t_start_ms + counter_list[j].t_elapsed_ms
+                                break
+                        time_delta = f"(+{counter.t_start_ms - last_t:.3f} ms / ~{counter.t_start_ms - last_t_same_level:.3f})"
+                    else:
+                        time_delta = f"(+{counter.t_start_ms - last_t:.3f} ms)"
+            line = f"{tabs}{counter.name}.{counter.type}({shape_str}): {time_delta} {counter.t_elapsed_ms:.3f} ms"
+            line += cls._get_perf_str(counter)
+            f.write(line + "\n")
+        f.write("\n")
+
+    @classmethod
+    def finalize_async(
+        cls, marker: str = "perf", save_jsonl: bool = False, save_log: bool = False, filename: Optional[str] = None
+    ) -> None:
+        main_t_start = time.time()
+        if filename is None:
+            filename = marker
+
+        with cls._recorded_counters_lock:
+            stacks = cls.threadlocal_active_counter_stack
+            cls.threadlocal_active_counter_stack = defaultdict(list)
+            counter_list = cls.recorded_counters
+            cls.recorded_counters = []
+
+        if not counter_list:
+            return
+        if not save_jsonl and not save_log:
+            return
+        if torch.cuda.is_current_stream_capturing():
+            return
+
+        torch.cuda.synchronize()
+
+        def _worker():
+            thread_t_start = time.time()
+            if torch.cuda.is_current_stream_capturing():
+                return
+            try:
+                cls._finalize_counters(stacks, counter_list)
+            except Exception as e:
+                logger.error(f"Error in finalizing counters: {e}")
+                # raise e
+                return
+
+            thread_t_finalize_end = time.time()
+
+            if not cls.LAZY_SAVE:
+                with cls._file_save_lock:
+                    if save_jsonl:
+                        with open(f"{filename}.jsonl", "a") as f_jsonl:
+                            cls._save_jsonl(f_jsonl, counter_list, marker, datetime.now())
+                    if save_log:
+                        with open(f"{filename}.log", "a") as f_log:
+                            cls._save_log(f_log, counter_list, marker, datetime.now())
+            else:
+                with cls._lazy_save_chunks_lock:
+                    cls._lazy_save_chunks.append((filename, marker, counter_list, datetime.now(), save_jsonl, save_log))
+
+            thread_t_end = time.time()
+            # logger.debug(f"GlobalPerfContext.finalize_async worker took {thread_t_end - thread_t_start:.3f} s "
+            #              f"({thread_t_finalize_end - thread_t_start:.3f} s in _finalize_counters)")
+
+        threading.Thread(target=_worker).start()
+        main_t_end = time.time()
+        # logger.debug(f"GlobalPerfContext.finalize_async main took {main_t_end - main_t_start:.3f} s")
+
+
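+# Counter taxonomy as used in this patch: the *_OP types tag individual kernels
+# or collectives, BLOCK tags a sub-module (attention, FFN, expert dispatch),
+# LAYER a whole transformer layer, MODEL is reserved for a full forward, and
+# '_OTHER' is the default for ad-hoc counters.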
+
+
+_PerfType = Literal[
+    'GEMM_OP', 'ATTN_OP', 'COMM_OP', 'ACT_OP', 'QUANT_OP', 'OTHER_OP',
+    'LAYER', 'BLOCK', 'MODEL',
+    '_OTHER'
+]
+
+
+class PerfCounter:
+    __slots__ = ("name", "type", "shapes", "depth",
+                 "_is_active", "_is_stopped",
+                 "finalized", "t_start_ms", "t_start_cpu_timestamp", "t_elapsed_ms",
+                 "es", "ee")
+    TYPES = get_args(_PerfType)
+
+    def __init__(self, name: Optional[str] = None, type: _PerfType = '_OTHER'):
+        self.name = name
+        self.type = type
+        self.shapes: dict[str, Any] = {}
+        self.depth = 0
+        self._is_active = False
+        self._is_stopped = False
+
+        self.finalized = False
+        self.t_start_ms = None
+        self.t_start_cpu_timestamp = None
+        self.t_elapsed_ms = None
+
+    def __enter__(self) -> 'PerfCounter':
+        """context manager usage:
+        ```
+        with PerfCounter(type="GEMM_OP", name="my_gemm") as p:
+            p.record_shape(m=..., n=..., k=...)
+            my_function(...)
+            ...
+        ```
+        """
+        if not GlobalPerfContext._effective_enabled:
+            return self
+        if torch.cuda.is_current_stream_capturing():
+            return self
+
+        if self._is_active or self._is_stopped:
+            raise RuntimeError("PerfCounter already in use, cannot restart")
+        self._is_active = True
+        self.t_start_cpu_timestamp = time.time()
+        self.es = torch.cuda.Event(enable_timing=True)
+        self.ee = torch.cuda.Event(enable_timing=True)
+        GlobalPerfContext.set_counter(self)
+        self.es.record()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not GlobalPerfContext._effective_enabled:
+            return
+        if torch.cuda.is_current_stream_capturing():
+            return
+
+        self.ee.record()
+        # synchronization is deferred until finalize
+        GlobalPerfContext.unset_counter(self)
+        self._is_active = False
+        self._is_stopped = True
+
+    def start(self):
+        self.__enter__()
+
+    def stop(self):
+        self.__exit__(None, None, None)
+
+    def is_active(self) -> bool:
+        return self._is_active
+
+    def __call__(self, func):
+        """decorator usage:
+        ```
+        @PerfCounter(type="GEMM_OP")
+        def my_function(...):
+            my_function.record_shape(m=..., n=..., k=...)
+        ```
+        """
+        # When used as a decorator, this object is not used directly;
+        # it acts as a factory that creates a fresh PerfCounter for each call.
+        if hasattr(func, "_is_perf_counter_factory_wrapped") and func._is_perf_counter_factory_wrapped:
+            return func
+
+        if not self.name and hasattr(func, "__name__"):
+            self.name = func.__name__
+
+        @functools.wraps(func)
+        def wrapped_func(*args, **kwds):
+            counter_obj = PerfCounter(self.name, self.type)
+            wrapped_func.is_perf_counter_active = counter_obj.is_active
+            wrapped_func.record_shape = counter_obj.record_shape
+            wrapped_func._current_perf_counter = counter_obj
+            with counter_obj:
+                return func(*args, **kwds)
+
+        wrapped_func._is_perf_counter_factory_wrapped = True
+        wrapped_func._original_func = func
+        return wrapped_func
+
+    def wrap_func(self, func):
+        """explicit wrapping usage:
+        ```
+        my_function = PerfCounter(type="GEMM_OP").wrap_func(my_function)  # must re-wrap before every call
+        my_function(...)
+        ```
+        """
+        # In the .wrap_func() case, unwrap first if already wrapped so the new PerfCounter takes effect.
+        if hasattr(func, "_is_perf_counter_factory_wrapped") and func._is_perf_counter_factory_wrapped:
+            logger.warning("wrong usage: .wrap_func() called on a function already wrapped by the @PerfCounter decorator; unwrapping to use the new PerfCounter")
+            func = func._original_func
+        if hasattr(func, "_is_perf_counter_wrapped") and func._is_perf_counter_wrapped:
+            func = func._original_func
+        if not self.name and hasattr(func, "__name__"):
+            self.name = func.__name__
+
+        @functools.wraps(func)
+        def wrapped_func(*args, **kwds):
+            with self:
+                return func(*args, **kwds)
+
+        wrapped_func.is_perf_counter_active = self.is_active
+        wrapped_func.record_shape = self.record_shape
+        wrapped_func._is_perf_counter_wrapped = True
+        wrapped_func._original_func = func
+        wrapped_func._current_perf_counter = self
+        return wrapped_func
+
+    def finalize(self) -> None:
+        if self.finalized:
+            logger.error(f"PerfCounter {self.name} already finalized")
+            return
+
+        self.finalized = True
+        self.t_elapsed_ms = self.es.elapsed_time(self.ee)
+        self.es = None
+        self.ee = None
+
+    def ref_t_start(self, first_es: torch.cuda.Event):
+        self.t_start_ms = first_es.elapsed_time(self.es)
+
+    def record_shape(self, **kwds: Any) -> None:
+        if self.shapes:
+            raise RuntimeError("PerfCounter shapes already recorded")
+        self.shapes = kwds
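
The three wrapping styles differ in lifetime: the context manager and `wrap_func` reuse a single instance (which cannot be restarted), while the decorator mints a fresh counter per call. A sketch; `my_matmul`, `my_allreduce`, and the shapes are illustrative:

```
@PerfCounter(type="GEMM_OP")
def my_matmul(a, b):
    # record_shape() feeds the TFLOPS estimate (2*m*n*k / elapsed) in _get_perf_str()
    my_matmul.record_shape(m=a.shape[0], n=b.shape[1], k=a.shape[1])
    return a @ b

def my_allreduce(t):
    with PerfCounter(type="COMM_OP", name="allreduce") as p:
        p.record_shape(size=t.numel() * t.element_size())  # bytes, feeds the GB/s line
        ...  # the actual collective
```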
+
+
+GlobalPerfContext._init()
+
+
+if GlobalPerfContext.STATIC_DISABLED:
+    # Replace PerfCounter with a zero-overhead stub so all profiling hooks become no-ops.
+    def _dummy_wrap(func):
+        try:
+            func.is_perf_counter_active = lambda: False
+            func.record_shape = lambda **kwds: None
+            return func
+        except Exception:
+            # some built-in functions may not allow setting attributes
+            def no_op(*args, **kwds):
+                return func(*args, **kwds)
+
+            no_op.is_perf_counter_active = lambda: False
+            no_op.record_shape = lambda **kwds: None
+            return no_op
+
+    class PerfCounter_noop:
+        def __init__(self, *args, **kwds): pass
+        def __enter__(self) -> 'PerfCounter_noop': return self
+        def __exit__(self, exc_type, exc_val, exc_tb): pass
+        def start(self): pass
+        def stop(self): pass
+        def is_active(self) -> bool: return False
+        def __call__(self, func): return _dummy_wrap(func)
+        def wrap_func(self, func): return _dummy_wrap(func)
+        def record_shape(self, **kwds: Any) -> None: pass
+
+    PerfCounter = PerfCounter_noop
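
When `STATIC_DISABLED` is set, the swap above keeps call sites unchanged; a quick sanity sketch (the function `f` is hypothetical):

```
@PerfCounter(type="OTHER_OP")   # resolves to PerfCounter_noop here
def f(x):
    return x * 2

assert f(21) == 42                         # behavior is unchanged
assert f.is_perf_counter_active() is False
f.record_shape(m=1)                        # no-op, records nothing
```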
+ """ + self.mode = mode + self.name = name or "unnamed" + self.use_multi_thread = use_multi_thread + self.torch_profiler_with_stack = torch_profiler_with_stack + + self.is_active: bool = False # Process-level logical state + self._threadlocal = threading.local() + + # make sure only one active torch.profiler per process + self._lock = threading.Lock() + self._process_torch_profiler_active_tid: int | None = None + + if self.mode == "torch_profiler": + self._trace_dir = os.getenv("LIGHTLLM_TRACE_DIR", "./trace") + os.makedirs(self._trace_dir, exist_ok=True) + elif self.mode == "nvtx": + self._nvtx_toplevel_mark = "LIGHTLLM_PROFILE" + else: + raise ValueError("invalid profiler mode") + + self._log_init_info() + + @property + def _local(self): + """Lazy initialization of thread-local storage.""" + if not hasattr(self._threadlocal, "initialized"): + self._threadlocal.initialized = True + self._threadlocal.is_active = False + self._threadlocal.profiler_obj = None + self._threadlocal.nvtx_range_id = None + return self._threadlocal + + def _log_init_info(self): + logger.warning("-" * 50) + logger.warning( + f"[pid={os.getpid()} tid={_get_thread_id()}] Profiler <{self.name}> initialized with mode: {self.mode}" + ) + if self.mode == "torch_profiler": + logger.warning( + "Profiler support for torch.profiler enabled (--enable_profiling=torch_profiler), " + "trace files will be saved to %s (change it with LIGHTLLM_TRACE_DIR env var)", + self._trace_dir, + ) + elif self.mode == "nvtx": + logger.warning( + "Profiler support for NVTX enabled (--enable_profiling=nvtx), toplevel NVTX mark is '%s'\n" + "you can use it with external profiling tools like NVIDIA Nsight Systems.", + self._nvtx_toplevel_mark, + ) + logger.warning( + "e.g. nsys profile --capture-range=nvtx --nvtx-capture=%s --trace=cuda,nvtx " + "-e NSYS_NVTX_PROFILER_REGISTER_ONLY=0 [other nsys options] " + "python -m lightllm.server.api_server --enable_profiling=nvtx [other lightllm options]", + self._nvtx_toplevel_mark, + ) + logger.warning("Use /profiler_start and /profiler_stop HTTP GET APIs to start/stop profiling") + logger.warning("DO NOT enable this feature in production environment") + logger.warning("-" * 50) + + def _torch_profiler_start(self) -> None: + with self._lock: + if self._process_torch_profiler_active_tid is not None: + return + self._process_torch_profiler_active_tid = _get_thread_id() + + torch.cuda.synchronize() + worker_name = f"{self.name}_tid{_get_thread_id()}" if self.use_multi_thread else self.name + + trace_handler = torch.profiler.tensorboard_trace_handler( + self._trace_dir, + worker_name=worker_name, + use_gzip=True, + ) + + p = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=None, + with_stack=self.torch_profiler_with_stack, + record_shapes=True, + on_trace_ready=trace_handler, + ) + + self._local.profiler_obj = p + p.start() + torch.cuda.synchronize() + + def _nvtx_start(self) -> None: + torch.cuda.synchronize() + self._local.nvtx_range_id = torch.cuda.nvtx.range_start(self._nvtx_toplevel_mark) + torch.cuda.synchronize() + + def _thread_start(self) -> None: + if self._local.is_active: + return + + try: + logger.info(f"[{self.name} @ tid={_get_thread_id()}] Start Profiler.") + if self.mode == "torch_profiler": + self._torch_profiler_start() + elif self.mode == "nvtx": + self._nvtx_start() + + self._local.is_active = True + except Exception as e: + logger.error( + f"[{self.name} @ tid={_get_thread_id()}] Failed to start profiler 
in thread {_get_thread_id()}: {e}" + ) + traceback.print_exc() + # Reset state on failure to prevent infinite retry loops + self._local.is_active = False + + def _torch_profiler_stop(self) -> None: + if self._process_torch_profiler_active_tid != _get_thread_id(): + return + + torch.cuda.synchronize() + logger.info(f"[{self.name} @ tid={_get_thread_id()}] Saving trace (blocking)...") + try: + if self._local.profiler_obj: + self._local.profiler_obj.stop() + except Exception as e: + logger.error(f"[{self.name} @ tid={_get_thread_id()}] Error stopping torch profiler: {e}") + finally: + self._local.profiler_obj = None # Explicitly release reference to allow GC + self._process_torch_profiler_active_tid = None + + torch.cuda.synchronize() + + def _nvtx_stop(self) -> None: + torch.cuda.synchronize() + if self._local.nvtx_range_id is not None: + torch.cuda.nvtx.range_end(self._local.nvtx_range_id) + self._local.nvtx_range_id = None + torch.cuda.synchronize() + + def _thread_stop(self) -> None: + if not self._local.is_active: + return + + try: + if self.mode == "torch_profiler": + self._torch_profiler_stop() + elif self.mode == "nvtx": + self._nvtx_stop() + logger.info(f"[{self.name} @ tid={_get_thread_id()}] Profiler stopped.") + except Exception as e: + logger.error(f"[{self.name} @ tid={_get_thread_id()}] Failed to stop profiler: {e}") + finally: + # Mark inactive regardless of success to avoid repeated errors + self._local.is_active = False + + def start(self) -> None: + self.is_active = True + if not self.use_multi_thread: + self._thread_start() + + def stop(self) -> None: + self.is_active = False + if not self.use_multi_thread: + self._thread_stop() + + def multi_thread_helper(self) -> None: + """ + **only for multi-threading use cases** + Worker polling method. Must be called within the inference loop. 
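
Intended driving pattern, per the docstrings above: an HTTP handler flips the process-level flag via `cmd()`, and in multi-thread mode each worker converges to that flag by polling. A sketch; the loop body is hypothetical:

```
profiler = ProcessProfiler(mode="nvtx", name="decode_worker", use_multi_thread=True)

# e.g. triggered by the /profiler_start endpoint:
profiler.cmd(ProfilerCmd(cmd="start"))

while True:  # hypothetical per-thread inference loop
    profiler.multi_thread_helper()  # align this thread's profiler with the flag
    ...                             # one inference step
```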
+ """ + if not self.use_multi_thread: + return + + # Catch-all to prevent profiler errors from crashing inference logic + try: + local_active = self._local.is_active + + if self.is_active and not local_active: + self._thread_start() + elif not self.is_active and local_active: + self._thread_stop() + except Exception: + pass + + def cmd(self, cmd_obj: ProfilerCmd) -> None: + if cmd_obj.cmd == "start": + self.start() + elif cmd_obj.cmd == "stop": + self.stop() + else: + raise ValueError(f"Invalid profiler cmd: {cmd_obj.cmd}") diff --git a/perf_jsonl_op_pct.py b/perf_jsonl_op_pct.py new file mode 100644 index 0000000000..fc54e4c205 --- /dev/null +++ b/perf_jsonl_op_pct.py @@ -0,0 +1,51 @@ +# 统计jsonl文件中每种OP的时间占比,默认跳过前两轮推理(warmup),可以通过第二个参数调整跳过的轮数 + +import sys, json +from typing import DefaultDict + +jsonl_file = sys.argv[1] +skip_layers = int(sys.argv[2]) if len(sys.argv) > 2 else 0 + +print("===", jsonl_file) + +OMIT_FIRST_INFER = skip_layers + +op_times = DefaultDict(float) +layer_times = [] +with open(jsonl_file, "r") as f: + layer_cnt = 0 + for i, line in enumerate(f): + if line[0] == '#': + layer_cnt += 1 + if layer_cnt > OMIT_FIRST_INFER: + omit_first = i + break + + # Rewind to start reading again + f.seek(0) + for i, line in enumerate(f): + if i < omit_first: + continue + if line[0] == '#': + continue + # {"name": "mm", "type": "GEMM_OP", "shapes": {"m": 1024, "k": 8192, "n": 29568}, "depth": 2, "t_start_ms": 14.655648231506348, "t_elapsed_ms": 4.317984104156494, "marker": "_context_forward1"} + try: + data = json.loads(line) + except json.JSONDecodeError: + print(f"Skipping invalid JSON line ({i}):", line) + continue + if data["type"].endswith("_OP"): + op_times[data["type"]] += data["t_elapsed_ms"] + if data["type"] == "LAYER": + layer_times.append(data["t_elapsed_ms"]) + + +all_op_time = sum(op_times.values()) +print(f"Total time: {all_op_time} ms") +results = [f"{op_type}: {time / all_op_time * 100:.2f}%" for op_type, time in op_times.items()] +results.sort() +for result in results: + print(result) + +print("Average layer time:", sum(layer_times) / len(layer_times) if layer_times else 0) +print("Median layer time:", sorted(layer_times)[len(layer_times) // 2] if layer_times else 0)