Merged
Changes from 7 commits
3 changes: 3 additions & 0 deletions fastdeploy/config.py
@@ -1806,6 +1806,9 @@ def postprocess(self):
# It will hang when real batch_size < tp_size
self.graph_opt_config.filter_capture_size(tp_size=self.parallel_config.tensor_parallel_size)

if self.routing_replay_config is not None and self.routing_replay_config.enable_routing_replay:
assert self.model_config.runner_type != "pooling", "Routing replay can only work with non-pooling models."

if ErnieArchitectures.is_ernie5_arch(self.model_config.architectures):
# ernie5 model does not support chunked_mm_input
self.cache_config.disable_chunked_mm_input = True
@@ -173,6 +173,7 @@ def __init__(self, fd_config: FDConfig, block_table, total_block_num):
num_experts = fd_config.model_config.moe_num_experts + fd_config.model_config.moe_num_shared_experts
self.routing_dtype = self.get_routing_dtype(num_experts=num_experts)
self._init_routing_cache(dtype=self.routing_dtype, total_block_num=total_block_num)
self.pending_update_positions = None

# Initialize routing store wrapper
if self.tp_rank == 0:
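The routing cache dtype above is derived from the total expert count (`moe_num_experts + moe_num_shared_experts`). The body of `get_routing_dtype` is not shown in this hunk; a minimal sketch of what such a selection could look like, assuming the goal is simply the narrowest integer type that can hold any expert index (the thresholds and the function name below are hypothetical, not this PR's code):

```python
def get_routing_dtype_sketch(num_experts: int) -> str:
    """Pick the narrowest integer dtype able to store an expert index (assumed logic)."""
    if num_experts <= 127:
        return "int8"
    if num_experts <= 32767:
        return "int16"
    return "int32"
```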
@@ -397,7 +398,7 @@ def __init__(self, fd_config: False) -> None:
# Initialize task queue
moe_layer_num = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index
max_num_seqs = fd_config.scheduler_config.max_num_seqs
self.queue_max_size = moe_layer_num * max_num_seqs * 10
self.queue_max_size = moe_layer_num * max_num_seqs * 1000
Review comment (Collaborator Author): magic number?
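One way to address the comment would be to give the multiplier a name so its intent is recorded. A minimal sketch follows; the constant name, its documented meaning, and the helper function are assumptions, not part of this PR:

```python
# Hypothetical refactor: replace the bare 1000 with a named head-room factor.
ROUTING_TASK_QUEUE_HEADROOM = 1000  # assumed meaning: queued routing entries allowed per MoE layer per sequence


def routing_task_queue_size(moe_layer_num: int, max_num_seqs: int) -> int:
    """Upper bound on routing-replay tasks kept in the multiprocessing queue."""
    return moe_layer_num * max_num_seqs * ROUTING_TASK_QUEUE_HEADROOM
```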


self.manager = multiprocessing.Manager()
self._task_queue = self.manager.Queue(maxsize=self.queue_max_size)
52 changes: 52 additions & 0 deletions fastdeploy/model_executor/pre_and_post_process.py
@@ -96,6 +96,9 @@
calculate_logits_entropy,
speculate_calculate_logits_entropy,
)
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
RoutingReplayManager,
)
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.output.pooler import PoolerOutput, PoolingSequenceGroupOutput
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
@@ -326,6 +329,7 @@ def post_process_normal(
think_end_id: int = -1,
line_break_id: int = -1,
enable_entropy: bool = False,
routing_replay_manager: RoutingReplayManager = None,
):
"""Post-processing steps after completing a single token generation."""
if think_end_id > 0:
@@ -394,6 +398,21 @@ def post_process_normal(
if enable_entropy:
calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature)

# Routing replay
if routing_replay_manager is not None:
# Update host cache
slot_mapping = routing_replay_manager.compute_slot_mapping(
positions=routing_replay_manager.pending_update_positions
)
routing_replay_manager.update_host_cache(
positions=routing_replay_manager.pending_update_positions, slot_mapping=slot_mapping
)

# Put routing of finished requests to store
finished_batch_ids = paddle.isin(sampler_output.sampled_token_ids, model_output.eos_token_id)[:, 0]
context_lens = model_output.seq_lens_decoder + model_output.seq_lens_encoder
routing_replay_manager.put_finished_batch(finished_batch_ids=finished_batch_ids, seq_lens_decoder=context_lens)

# 2. Update the input buffer of the model
with paddle.framework._no_check_dy2st_diff():
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
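To make the finished-request check added in this hunk concrete: `paddle.isin(sampler_output.sampled_token_ids, model_output.eos_token_id)[:, 0]` flags a row as finished when the token sampled this step is one of the EOS ids. A self-contained sketch of the same masking with toy tensors (the values are invented; only `paddle.isin` mirrors the API actually used in the diff):

```python
import paddle

# Toy batch of 4 sequences, one sampled token each (shape [batch, 1]).
sampled_token_ids = paddle.to_tensor([[11], [2], [7], [2]])
eos_token_id = paddle.to_tensor([2])  # EOS id(s) to test against

# Rows whose sampled token is an EOS token are finished this step.
finished_batch_ids = paddle.isin(sampled_token_ids, eos_token_id)[:, 0]
print(finished_batch_ids.numpy())  # -> False, True, False, True
```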
@@ -465,6 +484,7 @@ def post_process_specualate(
think_end_id: int = -1,
line_break_id: int = -1,
enable_entropy: bool = False,
routing_replay_manager: RoutingReplayManager = None,
):
if think_end_id > 0:
speculate_limit_thinking_content_length(
@@ -494,6 +514,29 @@ def post_process_specualate(
if enable_entropy:
speculate_calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature)

if routing_replay_manager is not None:
# Update host cache
slot_mapping = routing_replay_manager.compute_slot_mapping(
positions=routing_replay_manager.pending_update_positions
)
routing_replay_manager.update_host_cache(
positions=routing_replay_manager.pending_update_positions, slot_mapping=slot_mapping
)

# Put routing of finished requests to store
last_accept_token = paddle.full_like(model_output.accept_tokens, -1)
col_indices = paddle.arange(model_output.accept_tokens.shape[1], dtype=model_output.accept_num.dtype)
mask = col_indices < paddle.unsqueeze(model_output.accept_num, 1)
last_accept_token[mask] = model_output.accept_tokens[mask]
eos_tokens_flat = model_output.eos_token_id.flatten()
isin_mask = paddle.isin(last_accept_token, eos_tokens_flat)
finished_batch_ids = isin_mask.any(axis=-1)
context_lens = model_output.seq_lens_encoder + model_output.seq_lens_decoder
routing_replay_manager.put_finished_batch(
finished_batch_ids=finished_batch_ids,
seq_lens_decoder=context_lens,
)

speculate_update(
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,
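Because speculative decoding accepts a variable number of draft tokens per row, the block above first masks out unaccepted positions (filling them with -1) before testing for EOS. A standalone illustration of that masking with toy values (the paddle ops mirror the diff; the numbers are invented):

```python
import paddle

# Toy case: 2 sequences, up to 3 draft tokens per step.
accept_tokens = paddle.to_tensor([[5, 9, 2], [7, 3, 8]])
accept_num = paddle.to_tensor([3, 1])        # row 0 accepted 3 tokens, row 1 accepted 1
eos_token_id = paddle.to_tensor([[2]])

# Keep only accepted positions; everything else becomes -1 and can never match EOS.
last_accept_token = paddle.full_like(accept_tokens, -1)
col_indices = paddle.arange(accept_tokens.shape[1], dtype=accept_num.dtype)
mask = col_indices < paddle.unsqueeze(accept_num, 1)
last_accept_token[mask] = accept_tokens[mask]   # [[5, 9, 2], [7, -1, -1]]

# A row is finished if any accepted token is an EOS id.
isin_mask = paddle.isin(last_accept_token, eos_token_id.flatten())
finished_batch_ids = isin_mask.any(axis=-1)     # [True, False]
```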
@@ -564,6 +607,7 @@ def post_process(
think_end_id: int = -1,
line_break_id: int = -1,
enable_entropy: bool = False,
routing_replay_manager: RoutingReplayManager = None,
) -> None:
"""Post-processing steps after completing a single token generation."""

@@ -576,6 +620,7 @@ def post_process(
save_each_rank,
skip_save_output,
async_output_queue,
routing_replay_manager,
)
else:
if speculative_decoding:
@@ -589,6 +634,7 @@
think_end_id,
line_break_id,
enable_entropy,
routing_replay_manager,
)
else:
post_process_normal(
@@ -603,6 +649,7 @@
think_end_id,
line_break_id,
enable_entropy,
routing_replay_manager,
)


@@ -899,6 +946,7 @@ def post_process_pooling(
save_each_rank: bool = False,
skip_save_output: bool = False,
async_output_queue: queue.Queue = None,
routing_replay_manager: RoutingReplayManager = None,
) -> None:

paddle.assign(
@@ -916,6 +964,10 @@
model_output.stop_flags,
)

# Routing replay
if routing_replay_manager is not None:
raise NotImplementedError

with paddle.framework._no_check_dy2st_diff():
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
dummy_sampled_tokens = paddle.full_like(model_output.next_tokens, -1, dtype="int64")
19 changes: 3 additions & 16 deletions fastdeploy/worker/gpu_model_runner.py
@@ -1179,7 +1179,6 @@ def _init_share_inputs(self, max_num_seqs: int):
self.share_inputs["seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.seq_lens_routing_buffer = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["step_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["prompt_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
@@ -2370,7 +2369,7 @@ class at the server level, which is too granular for ModelRunner.
self._prepare_inputs()
self.sampler.pre_process(p_done_idxs)
if self.fd_config.routing_replay_config.enable_routing_replay:
self.positions = self.routing_replay_manager.get_token_positions(
self.routing_replay_manager.pending_update_positions = self.routing_replay_manager.get_token_positions(
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_this_time=self.seq_lens_this_time_buffer,
)
@@ -2450,6 +2449,7 @@ class at the server level, which is too granular for ModelRunner.
skip_save_output=False,
async_output_queue=self.async_output_queue,
enable_entropy=self.enable_entropy and self.parallel_config.tensor_parallel_rank == 0,
routing_replay_manager=self.routing_replay_manager,
)

return None
@@ -2579,6 +2579,7 @@ class at the server level, which is too granular for ModelRunner.
think_end_id=self.model_config.think_end_id,
line_break_id=self.model_config.line_break_id,
enable_entropy=self.enable_entropy and self.parallel_config.tensor_parallel_rank == 0,
routing_replay_manager=self.routing_replay_manager,
)
if self.guided_backend is not None and sampler_output is not None:
self.sampler.post_process(sampler_output.sampled_token_ids)
@@ -2626,20 +2627,6 @@ class at the server level, which is too granular for ModelRunner.
self.speculative_config.num_speculative_tokens,
)

# Routing replay
if self.fd_config.routing_replay_config.enable_routing_replay:
# Update host cache
slot_mapping = self.routing_replay_manager.compute_slot_mapping(positions=self.positions)
self.routing_replay_manager.update_host_cache(positions=self.positions, slot_mapping=slot_mapping)

# Put routing of finished requests to store
finished_batch_ids = paddle.isin(sampler_output.sampled_token_ids, self.share_inputs["eos_token_id"])[:, 0]
self.routing_replay_manager.put_finished_batch(
finished_batch_ids=finished_batch_ids,
seq_lens_decoder=self.seq_lens_routing_buffer,
)
paddle.assign(self.share_inputs["seq_lens_decoder"], self.seq_lens_routing_buffer)

return None

def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Optional[ModelRunnerOutput]: