8 changes: 7 additions & 1 deletion .gitignore
@@ -4,4 +4,10 @@ dist/
 .idea
 *.egg-info/
 *.safetensors
-outputs/
+outputs/
+outputs_past/
+packed_delta
+.cache
+delta_outputs/
+.io/
+outputs_exp/
6 changes: 6 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,6 @@
{
    "[python]": {
        "editor.defaultFormatter": "ms-python.black-formatter"
    },
    "python.formatting.provider": "none"
}
175 changes: 175 additions & 0 deletions autotune_gptq.py
@@ -0,0 +1,175 @@
import math
import time
import torch
import transformers
import torch.nn as nn
from loguru import logger
from quant import quantize

DEBUG = False

# Disable TF32 matmuls so the Hessian accumulation and Cholesky factorization
# below run in full fp32 precision.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

def hard_threshold(x, fraction_of_zero=0.1):
    # Zero out the `fraction_of_zero` smallest-magnitude entries of x.
    if fraction_of_zero == 0:
        return x
    y, _ = torch.sort(x.view(-1).abs().clone())
    num_params = torch.numel(x)
    thresh_index = int(num_params * fraction_of_zero)
    threshold = y[thresh_index]
    mask = x.abs().clone().gt(threshold).type(torch.cuda.HalfTensor)
    return mask * x
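
# A hypothetical illustration (not part of the diff): with fraction_of_zero=0.25
# on a 1000-element CUDA half tensor, at least 250 entries become exactly zero;
# gt() is strict, so values tied with the threshold are zeroed as well.
#
#   x = torch.randn(1000, device="cuda", dtype=torch.half)
#   x_sparse = hard_threshold(x, fraction_of_zero=0.25)
#   assert (x_sparse == 0).sum() >= 250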

class GPTQ:
    def __init__(self, layer):
        # Note: self.quantizer is attached by the caller before fasterquant() runs.
        self.layer = layer
        self.original_weight = layer.weight.data.clone()
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0

    def add_batch(self, inp, out):
        # Keep the calibration batch around so fasterquant() can report
        # reconstruction losses against the original layer output.
        self.inp1 = inp
        self.out1 = out
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()

        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())
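        # The two steps above maintain H = (2 / nsamples) * sum_i x_i x_i^T
        # incrementally: the old estimate is rescaled by n/(n+tmp) and the new
        # batch enters pre-scaled by sqrt(2/(n+tmp)), so the running estimate
        # stays exact after every batch without storing any inputs.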

    def fasterquant(
        self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False, write=True, sparsity=None
    ):
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        if write:
            del self.H
        # Columns whose Hessian diagonal is zero never fired during calibration;
        # give them a unit diagonal and zero the weights so they stay inert.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        if actorder:
            # Process columns in order of decreasing Hessian diagonal.
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H
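        # The three transforms above yield the upper Cholesky factor of H^{-1}
        # (cholesky -> cholesky_inverse -> cholesky(upper=True)); its diagonal
        # and rows supply the d and Hinv1[i, i:] terms the column loop needs,
        # which is cheaper and numerically more stable than inverting H directly.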

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1:
                    # Refit quantization parameters at every group boundary.
                    if (i1 + i) % groupsize == 0:
                        self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)

                q = quantize(
                    w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
                ).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d ** 2

                # Propagate the quantization error of column i into the
                # not-yet-quantized columns of the current block.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

            # Lazily apply the block's accumulated error to all later columns.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

            if DEBUG:
                # Track intermediate reconstruction error on the calibration batch.
                self.layer.weight.data[:, :i2] = Q[:, :i2]
                self.layer.weight.data[:, i2:] = W[:, i2:]
                print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
                print(torch.sum(Losses))

        torch.cuda.synchronize()
        total_time = time.time() - tick
        # print('time %.2f' % total_time)
        # error = torch.sum(Losses).item()
        # print('error', error)

        if actorder:
            invperm = torch.argsort(perm)
            Q = Q[:, invperm]

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()

        # Report the loss of the quantized (and optionally hard-thresholded)
        # layer against the original layer's output on the calibration batch.
        new_weight = Q.reshape(self.layer.weight.shape).to(self.layer.weight.dtype)
        losses = {}
        if sparsity is None:
            sparsed_new_weight = new_weight
            losses[0] = torch.sum((self.inp1 @ sparsed_new_weight.T - self.out1) ** 2)
        else:
            for s_sity in sparsity:
                if write:
                    logger.info(f"HT with: sparsity={s_sity}")
                sparsed_new_weight = hard_threshold(new_weight, fraction_of_zero=s_sity)
                losses[s_sity] = torch.sum((self.inp1 @ sparsed_new_weight.T - self.out1) ** 2)
                if losses[s_sity] > 100:
                    # Dump diagnostics when the reconstruction error explodes.
                    logger.info(f"{sparsed_new_weight}")
                    logger.info(f"{new_weight}")
                    logger.info(f"{sparsed_new_weight.shape}")
                    logger.info(f"{torch.max(torch.abs(self.inp1 @ sparsed_new_weight.T - self.out1))}")
        if write:
            self.layer.weight.data = sparsed_new_weight
        return losses

    def free(self):
        if DEBUG:
            self.inp1 = None
            self.out1 = None
        self.H = None
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()
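
For context, a minimal driver sketch for this class (not part of the diff; the Quantizer class and its configure() signature are assumed from the repository's quant module, following upstream GPTQ conventions):

import torch
import torch.nn as nn
from quant import Quantizer
from autotune_gptq import GPTQ

layer = nn.Linear(512, 512).cuda().half()
gptq = GPTQ(layer)
gptq.quantizer = Quantizer()  # fasterquant() expects the caller to attach this
gptq.quantizer.configure(8, perchannel=True, sym=True, mse=False)

x = torch.randn(16, 512, device="cuda", dtype=torch.half)  # calibration batch
gptq.add_batch(x, layer(x))                     # accumulate the Hessian estimate
losses = gptq.fasterquant(sparsity=[0.1, 0.2])  # loss per hard-threshold level
gptq.free()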
15 changes: 9 additions & 6 deletions cli.py
@@ -6,11 +6,11 @@
 from modelutils import get_opt
 from evaluation import opt_eval
 from datautils import get_loaders
-from save_and_load import save_lr_tensors, load_lr_tensors
 from core_compression import opt_delta_lr
+from save_and_load import save_lr_tensors, load_lr_tensors

 @torch.no_grad()
-def quantize_with_lowrank(base_model, target_model, dataloader, rank, wbits, n_samples):
+def quantize_with_lowrank(base_model, target_model, dataloader, rank, wbits, n_samples, decompose_only=False):
     # first do low rank approximation
     # then quantize
     original_finetuned_model = copy.deepcopy(target_model)
@@ -29,7 +29,8 @@ def quantize_with_lowrank(base_model, target_model, dataloader, rank, wbits, n_s
             'percdamp': 0.01,
             'groupsize': -1,
             'actorder': False,
-        }
+        },
+        decompose_only=decompose_only
     )

     target_model.to(base_model.device)
@@ -53,9 +54,11 @@ def quantize_with_lowrank(base_model, target_model, dataloader, rank, wbits, n_s
 argparser.add_argument('--save', type=str, default='', help='Path to save the quantized model')
 argparser.add_argument('--wbits', type=int, default=8, help='Number of bits to use for quantization')
 argparser.add_argument('--sym', action='store_true', default=True, help='Whether to use symmetric quantization')
+argparser.add_argument('--decompose-only', action='store_true', default=False, help='Skip quantization and apply only the low-rank decomposition')
 argparser.add_argument('--trits', action='store_true', default=False, help='Whether to use trits')

 args = argparser.parse_args()
+print(args)
 seed = args.seed

 base_model = get_opt(args.base_model)
@@ -77,10 +80,10 @@ def quantize_with_lowrank(base_model, target_model, dataloader, rank, wbits, n_s
     trainloader,
     args.rank,
     args.wbits,
-    args.nsamples
+    args.nsamples,
+    args.decompose_only,
 )
 if args.save:
-    save_lr_tensors(lr_tensors, f"{args.save}/{args.target_model.replace('/', '.')}-r{args.rank}-w{args.wbits}-lr.safetensors")
-
+    save_lr_tensors(lr_tensors, f"{args.save}/{args.target_model.replace('/', '.')}-r{args.rank}-w{args.wbits}-decompose.{args.decompose_only}-lr.safetensors")
 ppl = opt_eval(target_model, loader_enc, args, target_model.device)
 logger.info(f"Perplexity: {ppl}")