forked from huggingface/transformers
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprocessing_chameleon.py
More file actions
170 lines (144 loc) · 8.21 KB
/
processing_chameleon.py
File metadata and controls
170 lines (144 loc) · 8.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# coding=utf-8
# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Chameleon.
"""
from typing import List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack, _validate_images_text_input_order
from ...tokenization_utils_base import PreTokenizedInput, TextInput
class ChameleonTextKwargs(TextKwargs, total=False):
return_for_text_completion: bool
class ChameleonProcessorKwargs(ProcessingKwargs, total=False):
text_kwargs: ChameleonTextKwargs
_defaults = {
"text_kwargs": {
"padding": False,
"return_for_text_completion": False,
},
"common_kwargs": {
"return_tensors": "pt",
},
}
class ChameleonProcessor(ProcessorMixin):
r"""
Constructs a Chameleon processor which wraps a Chameleon image processor and a Chameleon tokenizer into a single
processor.
[`ChameleonProcessor`] offers all the functionalities of [`ChameleonImageProcessor`] and [`LlamaTokenizerFast`].
See the [`~ChameleonProcessor.__call__`] and [`~ChameleonProcessor.decode`] for more information.
Args:
image_processor ([`ChameleonImageProcessor`]):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`]):
The tokenizer is a required input.
image_seq_length (`int`, *optional*, defaults to 1024):
Sequence length of one image embedding.
image_token (`str`, *optional*, defaults to `"<image>"`):
The special token used to indicate image in the text.
"""
attributes = ["image_processor", "tokenizer"]
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
image_processor_class = "ChameleonImageProcessor"
def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
self.image_seq_length = image_seq_length
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
self.image_start_token = (
tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "<racm3:break>"
) # fixed tokens for start and end, so can hardcode
self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "<eoss>"
super().__init__(image_processor, tokenizer)
def __call__(
self,
images: Optional[ImageInput] = None,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
audio=None,
videos=None,
**kwargs: Unpack[ChameleonProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
of the above two methods for more information.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
# check if images and text inputs are reversed for BC
images, text = _validate_images_text_input_order(images, text)
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise TypeError("Invalid input text. Please provide a string, or a list of strings")
if text is None and images is None:
raise ValueError("You must provide either text or images")
output_kwargs = self._merge_kwargs(
ChameleonProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
return_for_text_completion = output_kwargs["text_kwargs"].pop("return_for_text_completion", False)
# Replace the image token with the expanded image token sequence
prompt_strings = []
one_img_tokens = self.image_start_token + (self.image_token * self.image_seq_length) + self.image_end_token
for sample in text:
sample = sample.replace(self.image_token, one_img_tokens)
if not return_for_text_completion:
sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode
prompt_strings.append(sample)
data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
if images is not None:
data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"]["return_tensors"])
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))