33 changes: 20 additions & 13 deletions docs/guides/data_input_pipeline.md
-->

(data-input-pipeline)=

# Data pipelines

Currently MaxText has three data input pipelines:

| Pipeline | Dataset formats | Features | Limitations |
| -------- | --------------- | -------- | ----------- |
| **[Grain](data_input_pipeline/data_input_grain.md)** (recommended) | [ArrayRecord](https://github.com/google/array_record) (random access, available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview) or [conversion](https://github.com/google/array_record/tree/main/beam))<br>[TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) (sequential access, available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview))<br>[Parquet](https://arrow.apache.org/docs/python/parquet.html) (sequential access) | With ArrayRecord: fully deterministic, resilient to preemption; global shuffle<br>With Parquet: performant; fully deterministic, resilient to preemption; hierarchical shuffle | |
| **[Hugging Face](data_input_pipeline/data_input_hf.md)** | datasets on the [Hugging Face Hub](https://huggingface.co/datasets)<br>local/Cloud Storage datasets in json, parquet, arrow, csv, txt (sequential access) | no download needed, convenient;<br>multiple formats | limited scalability when using the Hugging Face Hub (no limit using Cloud Storage);<br>non-deterministic with preemption<br>(deterministic without preemption) |
| **[TFDS](data_input_pipeline/data_input_tfds.md)** | TFRecord (sequential access), available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview) | performant | only supports TFRecords;<br>non-deterministic with preemption<br>(deterministic without preemption) |

## Multi-host data loading best practices

Training in a multi-host environment presents unique challenges for data input pipelines. An effective data loading strategy must address three key issues:

1. **Concurrent access**: Multiple hosts need to read from the same dataset simultaneously without causing conflicts.
2. **Data uniqueness**: Each host must be fed a unique, non-overlapping subset of the data so that no example is duplicated or skipped.
3. **Uneven completion**: Handling the scenario where some hosts run out of data before others, which can cause the job to hang.
The approaches to these challenges depend on whether your dataset supports random access or only sequential access.
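These requirements can be sketched with a toy index-sharding helper (a hypothetical illustration, not MaxText code): each host takes every N-th record index, so the shards are disjoint, cover the whole dataset, and differ in size by at most one record.

```python
def host_indices(num_records: int, host_id: int, num_hosts: int) -> list:
    """Give host `host_id` every `num_hosts`-th record index (disjoint shards)."""
    return list(range(host_id, num_records, num_hosts))

# 10 records over 4 hosts: every record is read by exactly one host,
# and shard sizes differ by at most one.
shards = [host_indices(10, h, 4) for h in range(4)]
```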

### Random access dataset (Recommended)

Random-access formats are highly recommended for multi-host training because they allow any part of the file to be read directly by its index.<br>
In MaxText, this is best supported by the ArrayRecord format using the Grain input pipeline. This approach gracefully handles the key challenges:
- **Concurrent access and uniqueness**: Grain assigns a unique set of indices to each host. ArrayRecord allows different hosts to read from different indices in the same file.

- **Uneven completion**: Data indices are distributed evenly among hosts. Without packing, the data imbalance between hosts will be at most one batch. To handle the final steps where some hosts run out of data, you can enable the `generate_padding_batch_train`/`generate_padding_batch_eval` flags in `src/maxtext/configs/base.yml` or through command-line arguments. These direct hosts to generate empty "padding" batches until the target number of training or evaluation steps is reached.
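The padding-batch behavior can be sketched in plain Python (an illustrative toy based on the flag semantics described above; the real implementation lives inside the Grain pipeline):

```python
def pad_to_steps(host_batches, target_steps, pad_batch):
    """Yield a host's real batches, then repeat `pad_batch` until
    `target_steps` batches have been produced, keeping hosts in lockstep."""
    count = 0
    for batch in host_batches:
        if count == target_steps:
            return
        yield batch
        count += 1
    while count < target_steps:
        yield pad_batch
        count += 1

# A host holding only 3 real batches still emits 5 steps.
steps = list(pad_to_steps([1, 2, 3], target_steps=5, pad_batch=0))
```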

```{note}
When sequence packing is enabled, the difference in the number of packed examples per host can be larger. The `generate_padding_batch_train`/`generate_padding_batch_eval` flags still handle this.
If all hosts exhaust their data before the target step count is reached, both `t
```

### Sequential access dataset
- **Concurrent access and uniqueness**: Sequential-access datasets (e.g., Parquet, JSON, TFRecord) cannot be accessed by index, requiring a different strategy -- file-based sharding, where each host is given exclusive access to a specific subset of data files. **Key requirement**: `(Number of data files) % (Number of data-loading hosts) == 0`. If the file count isn't a multiple of the host count, the files will be distributed unevenly. For example, with 10 files and 8 hosts, some hosts will get two files while others get one, significantly worsening the "uneven completion" problem. If you have fewer files than hosts, performance will be severely degraded as all hosts are concurrently accessing all the files.
- **Uneven completion**: Similar to random-access datasets, you can use the `generate_padding_batch_train`/`generate_padding_batch_eval` flag to handle hosts that finish their file shards early.
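The divisibility requirement can be checked with a small round-robin sketch (hypothetical helper, not MaxText code): with 10 files and 8 hosts, two hosts receive twice as many files as the rest.

```python
def shard_files(files, host_id, num_hosts):
    """Round-robin file sharding: host h gets files h, h + num_hosts, ..."""
    return files[host_id::num_hosts]

files = [f"part-{i:05d}.parquet" for i in range(10)]
shard_sizes = [len(shard_files(files, h, 8)) for h in range(8)]
# Hosts 0 and 1 get two files each; hosts 2-7 get one -> uneven completion.
```

With 8 files the shards come out even, which is why `(number of files) % (number of hosts) == 0` is the key requirement.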

```{toctree}
---
hidden:
---
data_input_pipeline/data_input_grain
data_input_pipeline/data_input_hf
data_input_pipeline/data_input_tfds
6 changes: 3 additions & 3 deletions docs/guides/data_input_pipeline/data_input_grain.md
Grain ensures determinism in data input pipelines by saving the pipeline's state

## Using Grain

1. Grain currently supports three data formats: [ArrayRecord](https://github.com/google/array_record) (random access), [Parquet](https://arrow.apache.org/docs/python/parquet.html) (partial random access through row groups), and [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) (sequential access). Only the ArrayRecord format supports the global shuffle mentioned above. For converting a dataset into ArrayRecord, see [Apache Beam Integration for ArrayRecord](https://github.com/google/array_record/tree/main/beam). Additionally, other random-access data sources can be supported via a custom [data source](https://google-grain.readthedocs.io/en/latest/data_sources.html) class.
- **Community resource**: The MaxText community has created [ArrayRecord documentation](https://array-record.readthedocs.io/). Note: we appreciate this community contribution, but it has not yet been verified by the MaxText or ArrayRecord developers.
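As a minimal sketch of such a custom source (the class and its contents are made up for illustration; Grain's random-access protocol only requires `__len__` and `__getitem__`):

```python
class SquaresSource:
    """Hypothetical random-access source whose record i is i squared."""

    def __init__(self, num_records: int):
        self._num_records = num_records

    def __len__(self) -> int:
        return self._num_records

    def __getitem__(self, index: int) -> int:
        # Random access: any record can be fetched directly by index.
        if not 0 <= index < self._num_records:
            raise IndexError(index)
        return index * index

source = SquaresSource(5)
```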
2. If the dataset is hosted on a Cloud Storage bucket, a `gs://` path can be provided directly. However, for best performance, it is recommended to read the bucket through [Cloud Storage FUSE](https://cloud.google.com/storage/docs/gcs-fuse). This significantly improves performance for the ArrayRecord format, since metadata caching speeds up random access. The installation of Cloud Storage FUSE is included in [setup.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/setup.sh). The user then needs to mount the Cloud Storage bucket to a local path on each worker, using the script [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh). The script configures some parameters for the mount.

```sh
bash tools/setup/setup_gcsfuse.sh \
MOUNT_PATH=${MOUNT_PATH?}
```

Note that `FILE_PATH` is optional; when provided, the script runs `ls -R` for pre-filling the metadata cache (see ["Performance tuning best practices" on the Google Cloud documentation](https://cloud.google.com/storage/docs/cloud-storage-fuse/performance#improve-first-time-reads)).

3. Set `dataset_type=grain`, `grain_file_type={arrayrecord|parquet|tfrecord}`, and `grain_train_files` in `src/maxtext/configs/base.yml` or through command line arguments to match the file pattern on the mounted local path.

4. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scenes](https://google-grain.readthedocs.io/en/latest/behind_the_scenes.html), [grain_pool.py](https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)). If you use a large number of workers, check your gcsfuse config in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh) to avoid gcsfuse throttling.

1 change: 1 addition & 0 deletions src/maxtext/configs/base.yml
grain_ram_budget_mb: 1024 # RAM budget (MB) for auto-tuning worker count. Only u
grain_num_threads_eval: 16
grain_prefetch_buffer_size_eval: 500
grain_data_source_max_workers: 16 # Max workers for ThreadPoolExecutor when mixing multiple Grain data sources.
grain_shuffle_buffer_size: 100 # Shuffle buffer size for sequential-access formats such as Parquet and TFRecord.
# for using pathways
colocated_python_data_input: False # experimental feature, under testing

15 changes: 6 additions & 9 deletions src/maxtext/configs/types.py
class GrainDataset(BaseModel):
"",
description="Path to a JSON file specifying the mixture weights for Grain training data.",
)
grain_file_type: str = Field(
"arrayrecord", description="File type for Grain data. Supported: arrayrecord, tfrecord, parquet."
)
grain_worker_count: int = Field(1, description="Number of workers for Grain data loading.")
grain_per_worker_buffer_size: int = Field(1, description="Per-worker buffer size for Grain train data loading.")
grain_worker_count_eval: int = Field(1, description="Number of workers for Grain eval data loading.")
grain_per_worker_buffer_size_eval: int = Field(1, description="Per-worker buffer size for Grain eval data loading.")
grain_ram_budget_mb: int = Field(1024, description="RAM budget (MB) for auto-tuning worker count.")
grain_num_threads: int = Field(16, description="Number of threads for Grain ReadOptions during training.")
grain_prefetch_buffer_size: int = Field(500, description="Prefetch buffer size for Grain ReadOptions during training.")
grain_data_source_max_workers: int = Field(
16,
description="Max workers for ThreadPoolExecutor when mixing multiple Grain data sources.",
)
grain_shuffle_buffer_size: int = Field(100, description="Shuffle buffer size when using Parquet or TFRecord.")


class FineTuning(BaseModel):