Skip to content

Commit 6756621

Browse files
authored
[bugfix] fix multimodal mtp (#21)
1 parent 35e6d2a commit 6756621

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/mcore_bridge/model/gpt_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,8 @@ def _postprocess(
         output_weight = None
         if self.share_embeddings_and_output_weights:
             output_weight = self.shared_embedding_or_output_weight()
-        if self.config.is_multimodal and self.config.context_parallel_size > 1:
+        if self.config.is_multimodal and self.config.context_parallel_size > 1 and input_ids is not None:
+            # input_ids is required by MTP.
             input_ids = split_cp_inputs(input_ids, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)
 
         if self.mtp_process and labels is not None:

0 commit comments

Comments
 (0)