@@ -104,9 +104,15 @@ def __call__(
104104 self, batch_input_ids: torch.Tensor, batch_scores: torch.Tensor
105105 ) -> torch.Tensor:
106106 i_batch, _ = batch_input_ids.shape
107- s_batch, s_vocab = batch_scores.shape
107+ s_batch, _ = batch_scores.shape
108108 assert i_batch == s_batch
109- assert s_vocab == self.vocab_size
109+
110+ # s_batch, s_vocab = batch_scores.shape
111+ # assert s_vocab == self.vocab_size
112+ #
113+ # NOTE: somehow, this does not hold. s_vocab is not the same as either of
114+ # * self._tokenizer._tokenizer.get_vocab_size(with_added_tokens=True) == self.vocab_size == ll_tokenizer.vocab_size
115+ # * self._tokenizer._tokenizer.get_vocab_size(with_added_tokens=False)
110116
111117 if self.batch_size != i_batch:
112118 self.batch_size = i_batch
@@ -232,6 +238,10 @@ def __init__(
232238 self._llguidance_tokenizer: llguidance.LLTokenizer = (
233239 llguidance.hf.from_tokenizer(self._tokenizer)  # type:ignore
234240 )
241+ assert (
242+ self._llguidance_tokenizer.vocab_size
243+ == self._tokenizer._tokenizer.get_vocab_size(with_added_tokens=True)
244+ ), "vocab size mismatch between llguidance and huggingface tokenizers ... wtf?"
235245
236246 self._use_caches = use_caches
237247 self._cache = cache if cache is not None else SimpleLRUCache(3)
0 commit comments