-
Notifications
You must be signed in to change notification settings - Fork 43
Expand file tree
/
Copy pathrange.py
More file actions
173 lines (137 loc) · 6.15 KB
/
range.py
File metadata and controls
173 lines (137 loc) · 6.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import typing
import numpy as np
import torch
from fast_llm.config import Field, config_class
from fast_llm.data.preprocessing.abstract import PreprocessingConfig
from fast_llm.data.sample.abstract import (
Batch,
MemmapReader,
MemmapReaderBase,
MemmapReaderBaseConfig,
MemmapReaderConfig,
MemmapWriter,
Sample,
)
from fast_llm.utils import Assert, get_unique
def crop_ranges(ranges: list[tuple[int, int]], begin: int, end: int) -> list[tuple[int, int]]:
    """
    Shift a set of ranges into the coordinate frame of the window `[begin, end)`,
    clamping each to the window and dropping any that end up empty.
    """
    window = end - begin
    cropped: list[tuple[int, int]] = []
    for range_begin, range_end in ranges:
        clamped_begin = max(range_begin - begin, 0)
        clamped_end = min(range_end - begin, window)
        if clamped_end > clamped_begin:
            cropped.append((clamped_begin, clamped_end))
    return cropped
class RangeSample(Sample):
    """
    A reusable component holding a set of ranges in a sample.
    """

    def __init__(self, ranges: list[tuple[int, int]], sample_size: int):
        self.ranges = ranges
        self.sample_size = sample_size

    @classmethod
    def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self:
        """
        Merge ranges from multiple documents, i.e. when multiple documents are
        packed together, offsetting each document's ranges by its position in
        the packed sample.
        """
        merged: list[tuple[int, int]] = []
        offset = 0
        for document in documents:
            merged.extend((begin + offset, end + offset) for begin, end in document.ranges)
            offset += document.sample_size
        return cls(merged, offset)

    def crop(self, begin: int, end: int) -> typing.Self:
        # Keep only the (shifted) portions of ranges that overlap the window.
        return self.__class__(crop_ranges(self.ranges, begin, end), end - begin)

    def __len__(self) -> int:
        return self.sample_size

    def get_padding(self, size: int) -> typing.Self:
        # Padding carries no ranges, only a size.
        return self.__class__([], size)
class RangeBatch(Batch):
    """
    A batch of range sets, one list of ranges per sample, all samples sharing
    a common sample size.
    """

    def __init__(self, ranges: list[list[tuple[int, int]]], sample_size: int):
        self.sample_size = sample_size
        self.ranges = ranges

    @classmethod
    def from_samples(cls, samples: typing.Iterable[RangeSample]) -> typing.Self:
        sample_ranges = [sample.ranges for sample in samples]
        # All samples in a batch must agree on the sample size.
        common_size = get_unique(sample.sample_size for sample in samples)
        return cls(sample_ranges, common_size)

    def crop(self, begin: int, end: int) -> typing.Self:
        cropped = [crop_ranges(sample_ranges, begin, end) for sample_ranges in self.ranges]
        return self.__class__(cropped, end - begin)
@config_class()
class RangeReaderBaseConfig(MemmapReaderBaseConfig):
    # Common base for range reader configs; marked non-abstract so it can be
    # selected directly (e.g. paired with `EmptyRangeReader` when no range data
    # is stored).
    _abstract = False
@config_class(dynamic_type={MemmapReaderBaseConfig: "range"})
class RangeReaderConfig(RangeReaderBaseConfig, MemmapReaderConfig):
    """
    Configuration for a memory-mapped range reader/writer pair.

    Buffer layout (between `header` and `footer`): `num_ranges` int32
    (begin, end) pairs, followed by an int32 cumulative-count array of length
    `num_documents + 1` mapping each document index to its slice of the range
    array.
    """

    # Byte markers delimiting this reader's region of the memmap buffer.
    header: typing.ClassVar[bytes] = b"range begin"
    footer: typing.ClassVar[bytes] = b"range end"
    num_documents: int = Field()
    num_ranges: int = Field()

    @property
    def reader_class(self) -> "type[RangeReader]":
        return RangeReader

    @property
    def writer_class(self) -> "type[RangeWriter]":
        return RangeWriter

    @property
    def _expected_buffer_size(self) -> int:
        # Two int32 per range, plus `num_documents + 1` int32 for the
        # cumulative-count index.
        return self.num_ranges * torch.int32.itemsize * 2 + (self.num_documents + 1) * torch.int32.itemsize

    def get_metadata(self) -> dict[str, typing.Any]:
        return {
            "num_documents": self.num_documents,
            "num_ranges": self.num_ranges,
        }

    @classmethod
    def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
        # Blending concatenates document sets, so counts simply add up.
        return {
            "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata),
            "num_ranges": sum(metadata_["num_ranges"] for metadata_ in metadata),
        }
class RangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]):
def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None):
super().__init__(config, buffer, model_preprocessing)
self._ranges = torch.frombuffer(
self._buffer,
dtype=torch.int32,
count=self._config.num_ranges * 2,
).view(-1, 2)
self._count_cumsums = torch.frombuffer(
self._buffer,
dtype=torch.int32,
count=self._config.num_documents + 1,
offset=self._ranges.nbytes,
)
def get_document(self, index: int, begin: int, end: int) -> Sample:
sample_size = end - begin
cropped_ranges = (
(max(begin_ - begin, 0), min(end_ - begin, sample_size))
for begin_, end_ in self._ranges[self._count_cumsums[index] : self._count_cumsums[index + 1]].tolist()
)
return RangeSample([(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_], sample_size)
def get_split(self, begin_index: int, end_index: int) -> dict[str, typing.Any]:
Assert.custom(lambda x: x == sorted(x), [0, begin_index, end_index, self._config.num_documents])
return {
"num_documents": end_index - begin_index,
"num_ranges": self._count_cumsums[end_index].item() - self._count_cumsums[begin_index].item(),
}
class EmptyRangeReader[ConfigType: RangeReaderBaseConfig](MemmapReaderBase[ConfigType]):
    # Fallback reader for datasets that store no range data: every document
    # yields an empty range set of the requested size, with no backing buffer.
    def get_document(self, index: int, begin: int, end: int) -> Sample:
        return RangeSample([], end - begin)
class RangeWriter(MemmapWriter):
    """
    Writer counterpart of `RangeReader`: streams each document's (begin, end)
    pairs as int32, then appends the cumulative-count index on exit.
    """

    def __enter__(self):
        super().__enter__()
        # Running cumulative range count; one entry per written document plus
        # the leading zero.
        self._count_cumsum = [0]
        return self

    def write(self, document: RangeSample):
        super().write(document)
        encoded_ranges = np.array(document.ranges, dtype=np.int32).tobytes(order="C")
        self._stream.write(encoded_ranges)
        self._count_cumsum.append(self._count_cumsum[-1] + len(document.ranges))

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            # The index is stored as int32; the total range count must fit.
            Assert.lt(self._count_cumsum[-1], np.iinfo(np.int32).max)
            encoded_index = np.array(self._count_cumsum, dtype=np.int32).tobytes(order="C")
            self._stream.write(encoded_index)
        super().__exit__(exc_type, exc_val, exc_tb)

    @classmethod
    def _get_config_class(cls) -> type[RangeReaderConfig]:
        return RangeReaderConfig

    def _get_config(self, begin: int, end: int):
        return RangeReaderConfig(
            begin=begin,
            end=end,
            num_documents=len(self._count_cumsum) - 1,
            num_ranges=self._count_cumsum[-1],
            preprocessing=self._preprocessing_config,
        )