75 changes: 75 additions & 0 deletions docs/VLLM_QUANTIZATION_STATUS.md
@@ -0,0 +1,75 @@
# vLLM Quantization — Status Summary
**Collaborator comment:** Remove this file

**Date**: 2026-05-13
**Branch**: `feature/quantized-inference-testing`
**Status**: Partially fixed, needs GPU validation

---

## Issues Found and Fixed

### 1. BnB dtype error in `SwitchedLoRALinear` (commit `93db1b0`)

**Problem**: When BitsAndBytes quantizes base layers, the weight dtype becomes `uint8`.
`SwitchedLoRALinear.__init__` used `base_layer.weight.dtype` to allocate LoRA params,
causing them to be `uint8` instead of float — breaking LoRA computation.

**Fix** (`src/granite_switch/vllm/core/lora.py:107-110`):
```python
# base_layer.weight.dtype is uint8 under BnB quantization; fall back to a float dtype
if not dtype.is_floating_point:
    dtype = torch.bfloat16
```
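
For context, a minimal sketch of where the guard sits relative to the LoRA parameter allocation. Everything around the two guarded lines (signature, shapes, attribute names) is an assumption for illustration, not the actual implementation:

```python
import torch
import torch.nn as nn

class SwitchedLoRALinear(nn.Module):
    """Hypothetical sketch; the real class has more machinery (switching, scaling, etc.)."""

    def __init__(self, base_layer: nn.Linear, r: int = 16, num_adapters: int = 4):
        super().__init__()
        self.base_layer = base_layer
        dtype = base_layer.weight.dtype
        # BnB stores quantized weights as uint8; LoRA params must stay floating point
        if not dtype.is_floating_point:
            dtype = torch.bfloat16
        in_f, out_f = base_layer.in_features, base_layer.out_features
        self.lora_A = nn.Parameter(torch.zeros(num_adapters, r, in_f, dtype=dtype))
        self.lora_B = nn.Parameter(torch.zeros(num_adapters, out_f, r, dtype=dtype))
```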

### 2. OOM on A100 80GB with BnB loading (commit `3a91e71`)

**Problem**: BnB holds full-precision weights during the quantization pass. The default
`gpu_memory_utilization=0.8` caused OOM.

**Fix**: Lowered to `0.5` in the diagnostic script.

### 3. xdist group marker (added `7a2b02b`, reverted `9cd26f6`)

Attempted to group the quantization tests onto a single pytest-xdist worker, then reverted it; this wasn't the right approach.
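
For reference, grouping tests onto one worker with pytest-xdist is normally done with the `xdist_group` mark plus `--dist loadgroup`. The group name and test below are illustrative only, not code from this branch:

```python
import pytest

# Run with: pytest -n auto --dist loadgroup
# Tests sharing an xdist_group name are scheduled onto the same worker process,
# which avoids two workers loading a quantized model onto the same GPU at once.
@pytest.mark.xdist_group(name="quantization")
def test_bnb_nf4_generation():
    ...
```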

---

## What Exists

| File | Purpose | Status |
|------|---------|--------|
| `quantization/test_quantized_inference.py` | Standalone vLLM BnB diagnostic script | Ready to run on GPU, no pytest |
| `quantization/quantization_testing.ipynb` | Exploration notebook (FP8/GPTQ/AWQ via API server) | Unused, no outputs |
| `tests/hf/test_quantization.py` | Formal HF backend pytest (BnB NF4/FP4, Quanto INT4/FP8) | Passing |

---

## What's Missing

1. **No pytest coverage for vLLM + BnB** — only the standalone diagnostic script exists
2. **No GPU validation yet** — the dtype fix (`93db1b0`) hasn't been confirmed on a real GPU with the full vLLM loading path
3. **FP8/GPTQ/AWQ on vLLM** — the notebook was designed for this but never executed (sketch below)
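
A rough sketch of what the notebook was aiming at. The notebook targeted an API server; the offline `LLM` entry point is used here only because it is shorter, and FP8/GPTQ/AWQ support for this particular checkpoint is unverified:

```python
from vllm import LLM

# Online FP8: vLLM quantizes an unquantized checkpoint at load time (hardware permitting).
# GPTQ/AWQ take a different route: they need a pre-quantized checkpoint, and vLLM then
# picks up the method from the model's quantization config.
llm = LLM(
    model="ibm-granite/granite-switch-4.1-3b-preview",
    quantization="fp8",
)
```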

---

## Key Parameters for vLLM BnB Loading

```python
from vllm import LLM

llm = LLM(
    model="ibm-granite/granite-switch-4.1-3b-preview",
    quantization="bitsandbytes",
    load_format="bitsandbytes",
    gpu_memory_utilization=0.5,  # lower than the 0.8 default; BnB needs headroom during quantization
    enforce_eager=True,  # CUDA graphs don't work with BnB
)
```
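
A quick smoke test once the model loads (standard vLLM offline API; the prompt is arbitrary):

```python
from vllm import SamplingParams

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(max_tokens=16, temperature=0.0),
)
print(outputs[0].outputs[0].text)
```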

---

## Next Steps (when resuming)

1. Run `quantization/test_quantized_inference.py` on a GPU pod to confirm the BnB fix works end-to-end
2. If it passes, convert it into a proper pytest test in `tests/vllm/test_quantization.py` (sketch below)
3. Investigate FP8/GPTQ support via vLLM (a different path from BnB)
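
A minimal sketch of what that pytest could look like, assuming the same model and parameters as the diagnostic script. The file does not exist yet and the test name is a placeholder:

```python
# tests/vllm/test_quantization.py (proposed; does not exist yet)
import pytest
import torch

requires_gpu = pytest.mark.skipif(not torch.cuda.is_available(), reason="needs a CUDA GPU")


@requires_gpu
def test_bnb_loading_and_generation():
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="ibm-granite/granite-switch-4.1-3b-preview",
        quantization="bitsandbytes",
        load_format="bitsandbytes",
        gpu_memory_utilization=0.5,
        enforce_eager=True,
    )
    out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
    assert out[0].outputs[0].text  # generated something non-empty
```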
58 changes: 55 additions & 3 deletions src/granite_switch/hf/modeling_granite_switch.py
@@ -143,6 +143,11 @@ class GraniteSwitchPreTrainedModel(GraniteMoeHybridPreTrainedModel):
    config_class = GraniteSwitchConfig
    base_model_prefix = "model"
    _no_split_modules = ["GraniteSwitchAttentionDecoderLayer"]
    _keys_to_ignore_on_load_unexpected = [
        r"model\.adapter_token_ids",
        r"model\.token_to_group_mask",
        r"model\.adapter_hiding_matrix",
    ]
    _is_stateful = True


@@ -167,8 +172,9 @@ def __init__(self, config: GraniteSwitchConfig):

        # --- Control token buffers ---
        # All values come from config (serialized in config.json).
        # Stored as buffers (not nn.Parameter) so they follow .to(device)
        # without appearing as trainable parameters.
        # Stored as non-persistent buffers so they follow .to(device)
        # without appearing as trainable parameters or in state_dict()
        # (avoids accelerate device_map placement errors on multi-GPU).
        #
        # adapter_token_ids: Hidden-flavor control tokens, one per adapter.
        # The switch layer detects these in the input sequence to determine
@@ -180,12 +186,14 @@ def __init__(self, config: GraniteSwitchConfig):
            self.register_buffer(
                "adapter_token_ids",
                torch.tensor(token_ids, dtype=torch.long),
                persistent=False,
            )
        else:
            # Build script hasn't populated yet — zeros placeholder
            self.register_buffer(
                "adapter_token_ids",
                torch.zeros(config.num_adapters, dtype=torch.long),
                persistent=False,
            )

        # --- Hiding group buffers ---
@@ -208,14 +216,15 @@ def __init__(self, config: GraniteSwitchConfig):
            for g, tids in group_token_ids.items():
                for tid in tids:
                    token_to_group_mask[tid, g] = True
            self.register_buffer("token_to_group_mask", token_to_group_mask)
            self.register_buffer("token_to_group_mask", token_to_group_mask, persistent=False)

            # adapter_hiding_matrix: [num_adapter_slots, num_groups] boolean.
            # Index 0 = base, 1+ = adapters. True if adapter hides group g.
            policy_matrix = config.get_adapter_hiding_policy_matrix()
            self.register_buffer(
                "adapter_hiding_matrix",
                torch.tensor(policy_matrix, dtype=torch.bool),
                persistent=False,
            )
        else:
            self.token_to_group_mask = None
@@ -257,6 +266,34 @@ def __init__(self, config: GraniteSwitchConfig):
        # Initialize weights
        self.post_init()

    def _rebuild_hiding_buffers(self, device: torch.device):
        """Rebuild hiding group buffers from config on the given device.

        Called on first forward when accelerate's init_empty_weights() has
        zeroed out the non-persistent buffers during from_pretrained.
        """
        config = self.config
        num_groups = config.num_hiding_groups
        if num_groups > 0:
            group_token_ids = config.get_hiding_group_token_ids()
            all_known_ids = [tid for tids in group_token_ids.values() for tid in tids]
            if config.adapter_token_ids:
                all_known_ids.extend(config.adapter_token_ids)
            max_tid = max(all_known_ids) if all_known_ids else -1
            table_size = max(config.vocab_size, max_tid + 1)
            token_to_group_mask = torch.zeros(
                table_size, num_groups, dtype=torch.bool, device=device
            )
            for g, tids in group_token_ids.items():
                for tid in tids:
                    token_to_group_mask[tid, g] = True
            self.token_to_group_mask = token_to_group_mask

            policy_matrix = config.get_adapter_hiding_policy_matrix()
            self.adapter_hiding_matrix = torch.tensor(
                policy_matrix, dtype=torch.bool, device=device
            )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
@@ -318,6 +355,21 @@ def forward(
        # Compute adapter_indices using switch (BEFORE RoPE for position correction)
        hidden_count = None
        if self.switch is not None:
            # Non-persistent buffers are zeroed by accelerate's init_empty_weights()
            # during from_pretrained with device_map. Rebuild from config on first forward.
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            if self.adapter_token_ids.sum() == 0 and self.config.adapter_token_ids:
                self.adapter_token_ids = torch.tensor(
                    self.config.adapter_token_ids, dtype=torch.long, device=device
                )
                self._rebuild_hiding_buffers(device)
            elif self.adapter_token_ids.device != device:
                self.adapter_token_ids = self.adapter_token_ids.to(device)
                if self.token_to_group_mask is not None:
                    self.token_to_group_mask = self.token_to_group_mask.to(device)
                if self.adapter_hiding_matrix is not None:
                    self.adapter_hiding_matrix = self.adapter_hiding_matrix.to(device)

            adapter_indices = self.switch(
                input_ids=input_ids,
                adapter_token_ids=self.adapter_token_ids,