diff --git a/docs/VLLM_QUANTIZATION_STATUS.md b/docs/VLLM_QUANTIZATION_STATUS.md new file mode 100644 index 0000000..97f4319 --- /dev/null +++ b/docs/VLLM_QUANTIZATION_STATUS.md @@ -0,0 +1,75 @@ +# vLLM Quantization — Status Summary + +**Date**: 2026-05-13 +**Branch**: `feature/quantized-inference-testing` +**Status**: Partially fixed, needs GPU validation + +--- + +## Issues Found and Fixed + +### 1. BnB dtype error in `SwitchedLoRALinear` (commit `93db1b0`) + +**Problem**: When BitsAndBytes quantizes base layers, the weight dtype becomes `uint8`. +`SwitchedLoRALinear.__init__` used `base_layer.weight.dtype` to allocate LoRA params, +causing them to be `uint8` instead of float — breaking LoRA computation. + +**Fix** (`src/granite_switch/vllm/core/lora.py:107-110`): +```python +if not dtype.is_floating_point: + dtype = torch.bfloat16 +``` + +### 2. OOM on A100 80GB with BnB loading (commit `3a91e71`) + +**Problem**: BnB holds full-precision weights during quantization pass. Default +`gpu_memory_utilization=0.8` caused OOM. + +**Fix**: Lowered to `0.5` in the diagnostic script. + +### 3. xdist group marker (added `7a2b02b`, reverted `9cd26f6`) + +Attempted to group quantization tests for xdist but reverted — not the right approach. + +--- + +## What Exists + +| File | Purpose | Status | +|------|---------|--------| +| `quantization/test_quantized_inference.py` | Standalone vLLM BnB diagnostic script | Ready to run on GPU, no pytest | +| `quantization/quantization_testing.ipynb` | Exploration notebook (FP8/GPTQ/AWQ via API server) | Unused, no outputs | +| `tests/hf/test_quantization.py` | Formal HF backend pytest (BnB NF4/FP4, Quanto INT4/FP8) | Passing | + +--- + +## What's Missing + +1. **No pytest coverage for vLLM + BnB** — only the standalone diagnostic script exists +2. **No GPU validation yet** — the dtype fix (`93db1b0`) hasn't been confirmed on a real GPU + with the full vLLM loading path +3. **FP8/GPTQ/AWQ on vLLM** — the notebook was designed for this but never executed + +--- + +## Key Parameters for vLLM BnB Loading + +```python +from vllm import LLM + +llm = LLM( + model="ibm-granite/granite-switch-4.1-3b-preview", + quantization="bitsandbytes", + load_format="bitsandbytes", + gpu_memory_utilization=0.5, # Lower than default — BnB needs headroom + enforce_eager=True, # CUDA graphs don't work with BnB +) +``` + +--- + +## Next Steps (when resuming) + +1. Run `quantization/test_quantized_inference.py` on a GPU pod to confirm BnB fix works end-to-end +2. If it passes, convert into a proper pytest test in `tests/vllm/test_quantization.py` +3. Investigate FP8/GPTQ support via vLLM (different from BnB path) diff --git a/src/granite_switch/hf/modeling_granite_switch.py b/src/granite_switch/hf/modeling_granite_switch.py index 277d947..852be93 100644 --- a/src/granite_switch/hf/modeling_granite_switch.py +++ b/src/granite_switch/hf/modeling_granite_switch.py @@ -143,6 +143,11 @@ class GraniteSwitchPreTrainedModel(GraniteMoeHybridPreTrainedModel): config_class = GraniteSwitchConfig base_model_prefix = "model" _no_split_modules = ["GraniteSwitchAttentionDecoderLayer"] + _keys_to_ignore_on_load_unexpected = [ + r"model\.adapter_token_ids", + r"model\.token_to_group_mask", + r"model\.adapter_hiding_matrix", + ] _is_stateful = True @@ -167,8 +172,9 @@ def __init__(self, config: GraniteSwitchConfig): # --- Control token buffers --- # All values come from config (serialized in config.json). - # Stored as buffers (not nn.Parameter) so they follow .to(device) - # without appearing as trainable parameters. + # Stored as non-persistent buffers so they follow .to(device) + # without appearing as trainable parameters or in state_dict() + # (avoids accelerate device_map placement errors on multi-GPU). # # adapter_token_ids: Hidden-flavor control tokens, one per adapter. # The switch layer detects these in the input sequence to determine @@ -180,12 +186,14 @@ def __init__(self, config: GraniteSwitchConfig): self.register_buffer( "adapter_token_ids", torch.tensor(token_ids, dtype=torch.long), + persistent=False, ) else: # Build script hasn't populated yet — zeros placeholder self.register_buffer( "adapter_token_ids", torch.zeros(config.num_adapters, dtype=torch.long), + persistent=False, ) # --- Hiding group buffers --- @@ -208,7 +216,7 @@ def __init__(self, config: GraniteSwitchConfig): for g, tids in group_token_ids.items(): for tid in tids: token_to_group_mask[tid, g] = True - self.register_buffer("token_to_group_mask", token_to_group_mask) + self.register_buffer("token_to_group_mask", token_to_group_mask, persistent=False) # adapter_hiding_matrix: [num_adapter_slots, num_groups] boolean. # Index 0 = base, 1+ = adapters. True if adapter hides group g. @@ -216,6 +224,7 @@ def __init__(self, config: GraniteSwitchConfig): self.register_buffer( "adapter_hiding_matrix", torch.tensor(policy_matrix, dtype=torch.bool), + persistent=False, ) else: self.token_to_group_mask = None @@ -257,6 +266,34 @@ def __init__(self, config: GraniteSwitchConfig): # Initialize weights self.post_init() + def _rebuild_hiding_buffers(self, device: torch.device): + """Rebuild hiding group buffers from config on the given device. + + Called on first forward when accelerate's init_empty_weights() has + zeroed out the non-persistent buffers during from_pretrained. + """ + config = self.config + num_groups = config.num_hiding_groups + if num_groups > 0: + group_token_ids = config.get_hiding_group_token_ids() + all_known_ids = [tid for tids in group_token_ids.values() for tid in tids] + if config.adapter_token_ids: + all_known_ids.extend(config.adapter_token_ids) + max_tid = max(all_known_ids) if all_known_ids else -1 + table_size = max(config.vocab_size, max_tid + 1) + token_to_group_mask = torch.zeros( + table_size, num_groups, dtype=torch.bool, device=device + ) + for g, tids in group_token_ids.items(): + for tid in tids: + token_to_group_mask[tid, g] = True + self.token_to_group_mask = token_to_group_mask + + policy_matrix = config.get_adapter_hiding_policy_matrix() + self.adapter_hiding_matrix = torch.tensor( + policy_matrix, dtype=torch.bool, device=device + ) + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -318,6 +355,21 @@ def forward( # Compute adapter_indices using switch (BEFORE RoPE for position correction) hidden_count = None if self.switch is not None: + # Non-persistent buffers are zeroed by accelerate's init_empty_weights() + # during from_pretrained with device_map. Rebuild from config on first forward. + device = input_ids.device if input_ids is not None else inputs_embeds.device + if self.adapter_token_ids.sum() == 0 and self.config.adapter_token_ids: + self.adapter_token_ids = torch.tensor( + self.config.adapter_token_ids, dtype=torch.long, device=device + ) + self._rebuild_hiding_buffers(device) + elif self.adapter_token_ids.device != device: + self.adapter_token_ids = self.adapter_token_ids.to(device) + if self.token_to_group_mask is not None: + self.token_to_group_mask = self.token_to_group_mask.to(device) + if self.adapter_hiding_matrix is not None: + self.adapter_hiding_matrix = self.adapter_hiding_matrix.to(device) + adapter_indices = self.switch( input_ids=input_ids, adapter_token_ids=self.adapter_token_ids,