Skip to content

Commit c2225f5

Browse files
committed
add unittest
1 parent 2f77572 commit c2225f5

1 file changed

Lines changed: 237 additions & 0 deletions

File tree

tests/model_executor/test_thinking_budget.py

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
from collections import OrderedDict
23
from dataclasses import asdict # Import asdict
34
from types import SimpleNamespace
45
from unittest.mock import MagicMock, patch # Import MagicMock
@@ -783,6 +784,242 @@ def test_common_engine_line_break_id_from_dict(self):
783784
with self.assertRaises(RuntimeError):
784785
common_engine_module.EngineService._start_worker_service(engine)
785786

787+
def test_text_encode_with_cache_branches(self):
788+
processor = TextDataProcessor.__new__(TextDataProcessor)
789+
processor._tokenize_cache = OrderedDict()
790+
processor._tokenize_cache_capacity = 1
791+
call_counter = {"np": 0, "iter": 0}
792+
793+
def _text2ids(text, max_model_len=None, add_special_tokens=False):
794+
if text == "np":
795+
call_counter["np"] += 1
796+
return np.array([11, 12], dtype=np.int64)
797+
call_counter["iter"] += 1
798+
return (v for v in [21, 22])
799+
800+
processor.text2ids = _text2ids
801+
802+
self.assertEqual(processor.encode_with_cache("np"), [11, 12])
803+
self.assertEqual(processor.encode_with_cache("np"), [11, 12])
804+
self.assertEqual(call_counter["np"], 1)
805+
self.assertEqual(processor.encode_with_cache("iter"), [21, 22])
806+
self.assertNotIn(("np", False), processor._tokenize_cache)
807+
808+
def test_v1_encode_with_cache_branches(self):
809+
processor = V1TextDataProcessor.__new__(V1TextDataProcessor)
810+
processor._tokenize_cache = OrderedDict()
811+
processor._tokenize_cache_capacity = 1
812+
call_counter = {"np": 0, "iter": 0}
813+
814+
def _text2ids(text, max_model_len=None, add_special_tokens=False):
815+
if text == "np":
816+
call_counter["np"] += 1
817+
return np.array([31, 32], dtype=np.int64)
818+
call_counter["iter"] += 1
819+
return (v for v in [41, 42])
820+
821+
processor.text2ids = _text2ids
822+
823+
self.assertEqual(processor.encode_with_cache("np"), [31, 32])
824+
self.assertEqual(processor.encode_with_cache("np"), [31, 32])
825+
self.assertEqual(call_counter["np"], 1)
826+
self.assertEqual(processor.encode_with_cache("iter"), [41, 42])
827+
self.assertNotIn(("np", False), processor._tokenize_cache)
828+
829+
def test_text_update_thinking_prompt_state_branches(self):
830+
processor = TextDataProcessor.__new__(TextDataProcessor)
831+
processor._think_token_ids = None
832+
processor.tokenizer = DummyTokenizerForTextProcessor()
833+
834+
self.assertEqual(processor._update_thinking_prompt_state([1], "not-dict"), "not-dict")
835+
self.assertEqual(
836+
processor._update_thinking_prompt_state([1], {"thinking_budget": -1}), {"thinking_budget": -1}
837+
)
838+
self.assertEqual(
839+
processor._update_thinking_prompt_state([1], {"thinking_budget": 1, "think_prompt_checked": True}),
840+
{"thinking_budget": 1, "think_prompt_checked": True},
841+
)
842+
self.assertEqual(processor._update_thinking_prompt_state(None, {"thinking_budget": 1}), {"thinking_budget": 1})
843+
self.assertEqual(processor._update_thinking_prompt_state([], {"thinking_budget": 1}), {"thinking_budget": 1})
844+
845+
processor.tokenizer = SimpleNamespace(get_vocab=lambda: {})
846+
self.assertEqual(processor._update_thinking_prompt_state([1], {"thinking_budget": 1}), {"thinking_budget": 1})
847+
848+
processor._think_token_ids = None
849+
processor.tokenizer = DummyTokenizerForTextProcessor()
850+
without_start = processor._update_thinking_prompt_state(
851+
[999, 998],
852+
{"thinking_budget": 1, "think_prompt_last_token_id": 777},
853+
)
854+
self.assertTrue(without_start["think_prompt_checked"])
855+
self.assertFalse(without_start["think_prompt_started"])
856+
self.assertNotIn("think_prompt_last_token_id", without_start)
857+
858+
with_start_no_end = processor._update_thinking_prompt_state(
859+
np.array([1, THINKING_START_TOKEN_ID, 2, 3], dtype=np.int64),
860+
{"thinking_budget": 4},
861+
)
862+
self.assertTrue(with_start_no_end["think_prompt_started"])
863+
self.assertFalse(with_start_no_end["think_prompt_ended"])
864+
self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 2)
865+
self.assertEqual(with_start_no_end["think_prompt_last_token_id"], 3)
866+
867+
# 命中 _get_think_token_ids 的缓存分支
868+
self.assertEqual(processor._get_think_token_ids(), (THINKING_START_TOKEN_ID, THINKING_END_TOKEN_ID))
869+
870+
def test_v1_update_thinking_prompt_state_branches(self):
871+
processor = V1TextDataProcessor.__new__(V1TextDataProcessor)
872+
processor._think_token_ids = None
873+
processor.tokenizer = DummyTokenizerForTextProcessor()
874+
875+
self.assertEqual(processor._update_thinking_prompt_state([1], "not-dict"), "not-dict")
876+
self.assertEqual(
877+
processor._update_thinking_prompt_state([1], {"thinking_budget": -1}), {"thinking_budget": -1}
878+
)
879+
self.assertEqual(processor._update_thinking_prompt_state(None, {"thinking_budget": 1}), {"thinking_budget": 1})
880+
881+
with_start_no_end = processor._update_thinking_prompt_state(
882+
np.array([1, THINKING_START_TOKEN_ID, 2, 3], dtype=np.int64),
883+
{"thinking_budget": 4},
884+
)
885+
self.assertTrue(with_start_no_end["think_prompt_started"])
886+
self.assertFalse(with_start_no_end["think_prompt_ended"])
887+
self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 2)
888+
self.assertEqual(with_start_no_end["think_prompt_last_token_id"], 3)
889+
890+
# 命中 _get_think_token_ids 的缓存分支
891+
self.assertEqual(processor._get_think_token_ids(), (THINKING_START_TOKEN_ID, THINKING_END_TOKEN_ID))
892+
893+
def test_text_process_request_think_stop_sentence(self):
894+
processor = TextDataProcessor.__new__(TextDataProcessor)
895+
processor._apply_default_parameters = lambda request: request
896+
processor.eos_token_ids = [1]
897+
processor.update_stop_seq = lambda *args, **kwargs: None
898+
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
899+
processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [101, 102]
900+
processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
901+
processor.reasoning_parser = None
902+
903+
request = DummyRequestV1(
904+
request_id="req_text",
905+
eos_token_ids=[1],
906+
prompt_token_ids=[8],
907+
prompt=None,
908+
messages=None,
909+
logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"},
910+
bad_words=None,
911+
bad_words_token_ids=None,
912+
max_tokens=1,
913+
temperature=1.0,
914+
top_p=0.9,
915+
)
916+
with patch("fastdeploy.input.text_processor.process_stop_token_ids", lambda *args, **kwargs: None):
917+
processed = processor.process_request(request, max_model_len=16)
918+
self.assertEqual(
919+
processed.logits_processors_args.get("think_stop_sentence_token_ids"),
920+
[23, 101, 102],
921+
)
922+
self.assertNotIn("think_stop_sentence", processed.logits_processors_args)
923+
924+
def test_text_process_request_dict_think_stop_sentence(self):
925+
processor = TextDataProcessor.__new__(TextDataProcessor)
926+
processor._apply_default_parameters = lambda request: request
927+
processor.eos_token_ids = [1]
928+
processor.update_stop_seq = lambda *args, **kwargs: None
929+
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
930+
processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [201, 202]
931+
processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
932+
processor.reasoning_parser = None
933+
934+
request = {
935+
"request_id": "req_text_dict",
936+
"eos_token_ids": [1],
937+
"prompt_token_ids": [9],
938+
"prompt": None,
939+
"messages": None,
940+
"bad_words": None,
941+
"bad_words_token_ids": None,
942+
"logits_processors_args": {"thinking_budget": 20, "think_stop_sentence": "done"},
943+
"max_tokens": 1,
944+
"temperature": 1.0,
945+
"top_p": 0.9,
946+
}
947+
with patch("fastdeploy.input.text_processor.process_stop_token_ids", lambda *args, **kwargs: None):
948+
processed = processor.process_request_dict(request, max_model_len=16)
949+
self.assertEqual(
950+
processed["logits_processors_args"].get("think_stop_sentence_token_ids"),
951+
[23, 201, 202],
952+
)
953+
self.assertNotIn("think_stop_sentence", processed["logits_processors_args"])
954+
955+
def test_v1_process_request_think_stop_sentence(self):
956+
processor = V1TextDataProcessor.__new__(V1TextDataProcessor)
957+
processor._apply_default_parameters = lambda request: request
958+
processor.eos_token_ids = [1]
959+
processor.update_stop_seq = lambda *args, **kwargs: None
960+
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
961+
processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [301, 302]
962+
processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
963+
processor.reasoning_parser = None
964+
965+
request = DummyRequestV1(
966+
request_id="req_v1",
967+
eos_token_ids=[1],
968+
prompt_token_ids=[10],
969+
prompt=None,
970+
messages=None,
971+
logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"},
972+
bad_words=None,
973+
bad_words_token_ids=None,
974+
max_tokens=1,
975+
temperature=1.0,
976+
top_p=0.9,
977+
)
978+
with patch("fastdeploy.input.v1.text_processor.process_stop_token_ids", lambda *args, **kwargs: None):
979+
processed = processor.process_request(request, max_model_len=16)
980+
self.assertEqual(
981+
processed.logits_processors_args.get("think_stop_sentence_token_ids"),
982+
[23, 301, 302],
983+
)
984+
self.assertNotIn("think_stop_sentence", processed.logits_processors_args)
985+
986+
def test_v1_process_request_dict_think_stop_sentence(self):
987+
processor = V1TextDataProcessor.__new__(V1TextDataProcessor)
988+
processor._apply_default_parameters = lambda request: request
989+
processor.eos_token_ids = [1]
990+
processor.update_stop_seq = lambda *args, **kwargs: None
991+
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
992+
processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [401, 402]
993+
processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
994+
processor.reasoning_parser = None
995+
996+
request = DummyRequestV1(
997+
request_id="req_v1_dict",
998+
eos_token_ids=[1],
999+
prompt_token_ids=[11],
1000+
prompt=None,
1001+
messages=None,
1002+
chat_template_kwargs=None,
1003+
sampling_params=SimpleNamespace(
1004+
bad_words=None,
1005+
bad_words_token_ids=None,
1006+
max_tokens=1,
1007+
temperature=1.0,
1008+
top_p=0.9,
1009+
repetition_penalty=1.0,
1010+
frequency_penalty=0.0,
1011+
presence_penalty=0.0,
1012+
logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"},
1013+
),
1014+
)
1015+
with patch("fastdeploy.input.v1.text_processor.process_stop_token_ids", lambda *args, **kwargs: None):
1016+
processed = processor.process_request_dict(request, max_model_len=16)
1017+
self.assertEqual(
1018+
processed.sampling_params.logits_processors_args.get("think_stop_sentence_token_ids"),
1019+
[23, 401, 402],
1020+
)
1021+
self.assertNotIn("think_stop_sentence", processed.sampling_params.logits_processors_args)
1022+
7861023

7871024
# Allow running this test module directly (python test_thinking_budget.py).
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)