From fff3b85bf1386beb3dd022828aa3d9bf3f19af85 Mon Sep 17 00:00:00 2001 From: itzikvaisman Date: Wed, 13 May 2026 10:52:00 +0300 Subject: [PATCH 1/4] Make HF model buffers non-persistent to fix multi-GPU device_map error accelerate raises ValueError when device_map="auto" encounters persistent buffers (adapter_token_ids, token_to_group_mask, adapter_hiding_matrix) that it cannot assign to a device. These buffers are config-derived metadata reconstructed at __init__ time, so they don't need to be in state_dict(). Also adds docs/VLLM_QUANTIZATION_STATUS.md summarizing vLLM quantization work status for future reference. Fixes: generative-computing/granite-switch#16 --- docs/VLLM_QUANTIZATION_STATUS.md | 75 +++++++++++++++++++ .../hf/modeling_granite_switch.py | 10 ++- 2 files changed, 82 insertions(+), 3 deletions(-) create mode 100644 docs/VLLM_QUANTIZATION_STATUS.md diff --git a/docs/VLLM_QUANTIZATION_STATUS.md b/docs/VLLM_QUANTIZATION_STATUS.md new file mode 100644 index 0000000..97f4319 --- /dev/null +++ b/docs/VLLM_QUANTIZATION_STATUS.md @@ -0,0 +1,75 @@ +# vLLM Quantization — Status Summary + +**Date**: 2026-05-13 +**Branch**: `feature/quantized-inference-testing` +**Status**: Partially fixed, needs GPU validation + +--- + +## Issues Found and Fixed + +### 1. BnB dtype error in `SwitchedLoRALinear` (commit `93db1b0`) + +**Problem**: When BitsAndBytes quantizes base layers, the weight dtype becomes `uint8`. +`SwitchedLoRALinear.__init__` used `base_layer.weight.dtype` to allocate LoRA params, +causing them to be `uint8` instead of float — breaking LoRA computation. + +**Fix** (`src/granite_switch/vllm/core/lora.py:107-110`): +```python +if not dtype.is_floating_point: + dtype = torch.bfloat16 +``` + +### 2. OOM on A100 80GB with BnB loading (commit `3a91e71`) + +**Problem**: BnB holds full-precision weights during quantization pass. Default +`gpu_memory_utilization=0.8` caused OOM. + +**Fix**: Lowered to `0.5` in the diagnostic script. + +### 3. xdist group marker (added `7a2b02b`, reverted `9cd26f6`) + +Attempted to group quantization tests for xdist but reverted — not the right approach. + +--- + +## What Exists + +| File | Purpose | Status | +|------|---------|--------| +| `quantization/test_quantized_inference.py` | Standalone vLLM BnB diagnostic script | Ready to run on GPU, no pytest | +| `quantization/quantization_testing.ipynb` | Exploration notebook (FP8/GPTQ/AWQ via API server) | Unused, no outputs | +| `tests/hf/test_quantization.py` | Formal HF backend pytest (BnB NF4/FP4, Quanto INT4/FP8) | Passing | + +--- + +## What's Missing + +1. **No pytest coverage for vLLM + BnB** — only the standalone diagnostic script exists +2. **No GPU validation yet** — the dtype fix (`93db1b0`) hasn't been confirmed on a real GPU + with the full vLLM loading path +3. **FP8/GPTQ/AWQ on vLLM** — the notebook was designed for this but never executed + +--- + +## Key Parameters for vLLM BnB Loading + +```python +from vllm import LLM + +llm = LLM( + model="ibm-granite/granite-switch-4.1-3b-preview", + quantization="bitsandbytes", + load_format="bitsandbytes", + gpu_memory_utilization=0.5, # Lower than default — BnB needs headroom + enforce_eager=True, # CUDA graphs don't work with BnB +) +``` + +--- + +## Next Steps (when resuming) + +1. Run `quantization/test_quantized_inference.py` on a GPU pod to confirm BnB fix works end-to-end +2. If it passes, convert into a proper pytest test in `tests/vllm/test_quantization.py` +3. Investigate FP8/GPTQ support via vLLM (different from BnB path) diff --git a/src/granite_switch/hf/modeling_granite_switch.py b/src/granite_switch/hf/modeling_granite_switch.py index 277d947..bd43a18 100644 --- a/src/granite_switch/hf/modeling_granite_switch.py +++ b/src/granite_switch/hf/modeling_granite_switch.py @@ -167,8 +167,9 @@ def __init__(self, config: GraniteSwitchConfig): # --- Control token buffers --- # All values come from config (serialized in config.json). - # Stored as buffers (not nn.Parameter) so they follow .to(device) - # without appearing as trainable parameters. + # Stored as non-persistent buffers so they follow .to(device) + # without appearing as trainable parameters or in state_dict() + # (avoids accelerate device_map placement errors on multi-GPU). # # adapter_token_ids: Hidden-flavor control tokens, one per adapter. # The switch layer detects these in the input sequence to determine @@ -180,12 +181,14 @@ def __init__(self, config: GraniteSwitchConfig): self.register_buffer( "adapter_token_ids", torch.tensor(token_ids, dtype=torch.long), + persistent=False, ) else: # Build script hasn't populated yet — zeros placeholder self.register_buffer( "adapter_token_ids", torch.zeros(config.num_adapters, dtype=torch.long), + persistent=False, ) # --- Hiding group buffers --- @@ -208,7 +211,7 @@ def __init__(self, config: GraniteSwitchConfig): for g, tids in group_token_ids.items(): for tid in tids: token_to_group_mask[tid, g] = True - self.register_buffer("token_to_group_mask", token_to_group_mask) + self.register_buffer("token_to_group_mask", token_to_group_mask, persistent=False) # adapter_hiding_matrix: [num_adapter_slots, num_groups] boolean. # Index 0 = base, 1+ = adapters. True if adapter hides group g. @@ -216,6 +219,7 @@ def __init__(self, config: GraniteSwitchConfig): self.register_buffer( "adapter_hiding_matrix", torch.tensor(policy_matrix, dtype=torch.bool), + persistent=False, ) else: self.token_to_group_mask = None From 3816d411e1f600d5f9734284dced43626d0a3d90 Mon Sep 17 00:00:00 2001 From: itzikvaisman Date: Wed, 13 May 2026 11:13:29 +0300 Subject: [PATCH 2/4] Move non-persistent buffers to input device in forward pass accelerate's device_map does not move non-persistent buffers to GPU. Add a one-time device sync at the start of the switch block so adapter_token_ids, token_to_group_mask, and adapter_hiding_matrix follow the input tensors to the correct device. --- src/granite_switch/hf/modeling_granite_switch.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/granite_switch/hf/modeling_granite_switch.py b/src/granite_switch/hf/modeling_granite_switch.py index bd43a18..9da2812 100644 --- a/src/granite_switch/hf/modeling_granite_switch.py +++ b/src/granite_switch/hf/modeling_granite_switch.py @@ -322,6 +322,16 @@ def forward( # Compute adapter_indices using switch (BEFORE RoPE for position correction) hidden_count = None if self.switch is not None: + # Non-persistent buffers aren't moved by accelerate's device_map; + # ensure they're on the same device as the input. + device = input_ids.device if input_ids is not None else inputs_embeds.device + if self.adapter_token_ids.device != device: + self.adapter_token_ids = self.adapter_token_ids.to(device) + if self.token_to_group_mask is not None: + self.token_to_group_mask = self.token_to_group_mask.to(device) + if self.adapter_hiding_matrix is not None: + self.adapter_hiding_matrix = self.adapter_hiding_matrix.to(device) + adapter_indices = self.switch( input_ids=input_ids, adapter_token_ids=self.adapter_token_ids, From 10c6ec1da5568912f86e12e2d5b18fb75a701084 Mon Sep 17 00:00:00 2001 From: itzikvaisman Date: Wed, 13 May 2026 11:29:01 +0300 Subject: [PATCH 3/4] Add _keys_to_ignore_on_load_unexpected for old checkpoint buffers Suppresses "UNEXPECTED" warnings when loading checkpoints that still contain the now non-persistent buffer keys (adapter_token_ids, token_to_group_mask, adapter_hiding_matrix). --- src/granite_switch/hf/modeling_granite_switch.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/granite_switch/hf/modeling_granite_switch.py b/src/granite_switch/hf/modeling_granite_switch.py index 9da2812..c9b5269 100644 --- a/src/granite_switch/hf/modeling_granite_switch.py +++ b/src/granite_switch/hf/modeling_granite_switch.py @@ -143,6 +143,11 @@ class GraniteSwitchPreTrainedModel(GraniteMoeHybridPreTrainedModel): config_class = GraniteSwitchConfig base_model_prefix = "model" _no_split_modules = ["GraniteSwitchAttentionDecoderLayer"] + _keys_to_ignore_on_load_unexpected = [ + r"model\.adapter_token_ids", + r"model\.token_to_group_mask", + r"model\.adapter_hiding_matrix", + ] _is_stateful = True From 059632956e12abb7662cad7b53ad994e0420c430 Mon Sep 17 00:00:00 2001 From: itzikvaisman Date: Wed, 13 May 2026 14:23:43 +0300 Subject: [PATCH 4/4] Rebuild non-persistent buffers from config on first forward accelerate's init_empty_weights() zeros all buffers during from_pretrained with device_map. Since non-persistent buffers aren't restored from the checkpoint, they stay as zeros. Detect this on first forward call and rebuild adapter_token_ids, token_to_group_mask, and adapter_hiding_matrix from config values. --- .../hf/modeling_granite_switch.py | 39 +++++++++++++++++-- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/src/granite_switch/hf/modeling_granite_switch.py b/src/granite_switch/hf/modeling_granite_switch.py index c9b5269..852be93 100644 --- a/src/granite_switch/hf/modeling_granite_switch.py +++ b/src/granite_switch/hf/modeling_granite_switch.py @@ -266,6 +266,34 @@ def __init__(self, config: GraniteSwitchConfig): # Initialize weights self.post_init() + def _rebuild_hiding_buffers(self, device: torch.device): + """Rebuild hiding group buffers from config on the given device. + + Called on first forward when accelerate's init_empty_weights() has + zeroed out the non-persistent buffers during from_pretrained. + """ + config = self.config + num_groups = config.num_hiding_groups + if num_groups > 0: + group_token_ids = config.get_hiding_group_token_ids() + all_known_ids = [tid for tids in group_token_ids.values() for tid in tids] + if config.adapter_token_ids: + all_known_ids.extend(config.adapter_token_ids) + max_tid = max(all_known_ids) if all_known_ids else -1 + table_size = max(config.vocab_size, max_tid + 1) + token_to_group_mask = torch.zeros( + table_size, num_groups, dtype=torch.bool, device=device + ) + for g, tids in group_token_ids.items(): + for tid in tids: + token_to_group_mask[tid, g] = True + self.token_to_group_mask = token_to_group_mask + + policy_matrix = config.get_adapter_hiding_policy_matrix() + self.adapter_hiding_matrix = torch.tensor( + policy_matrix, dtype=torch.bool, device=device + ) + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -327,10 +355,15 @@ def forward( # Compute adapter_indices using switch (BEFORE RoPE for position correction) hidden_count = None if self.switch is not None: - # Non-persistent buffers aren't moved by accelerate's device_map; - # ensure they're on the same device as the input. + # Non-persistent buffers are zeroed by accelerate's init_empty_weights() + # during from_pretrained with device_map. Rebuild from config on first forward. device = input_ids.device if input_ids is not None else inputs_embeds.device - if self.adapter_token_ids.device != device: + if self.adapter_token_ids.sum() == 0 and self.config.adapter_token_ids: + self.adapter_token_ids = torch.tensor( + self.config.adapter_token_ids, dtype=torch.long, device=device + ) + self._rebuild_hiding_buffers(device) + elif self.adapter_token_ids.device != device: self.adapter_token_ids = self.adapter_token_ids.to(device) if self.token_to_group_mask is not None: self.token_to_group_mask = self.token_to_group_mask.to(device)