# codefuse-chatbot/examples/gptq.py

import os
import sys
from dataclasses import dataclass, field
from os.path import isdir, isfile
from pathlib import Path
from typing import Optional

from transformers import AutoTokenizer
@dataclass
class GptqConfig:
    """Configuration for loading a GPTQ-quantized model checkpoint.

    Field `metadata["help"]` strings double as CLI help text when this
    dataclass is consumed by an argument parser.
    """

    # Path to the local GPTQ checkpoint (file or directory). Annotated
    # Optional[str] because the default is None (previously mis-typed `str`).
    ckpt: Optional[str] = field(
        default=None,
        metadata={
            "help": "Load quantized model. The path to the local GPTQ checkpoint."
        },
    )
    # Quantization bit-width; the default of 16 effectively means "unquantized".
    wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"})
    # -1 disables grouping and quantizes over the full row.
    groupsize: int = field(
        default=-1,
        metadata={"help": "Groupsize to use for quantization; default uses full row."},
    )
    # Activation-order heuristic; only honored by certain GPTQ-for-LLaMa branches.
    act_order: bool = field(
        default=True,
        metadata={"help": "Whether to apply the activation order GPTQ heuristic"},
    )
def load_quant_by_autogptq(model):
    """Load a GPTQ-quantized causal LM through the AutoGPTQ library.

    `model` is a model name or local path understood by AutoGPTQ's
    `from_quantized`. The quantized model is returned with its layers
    distributed automatically across available devices.
    """
    from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

    quantized_model = AutoGPTQForCausalLM.from_quantized(
        model,
        inject_fused_attention=False,
        inject_fused_mlp=False,
        use_cuda_fp16=True,
        disable_exllama=False,
        device_map='auto',
    )
    return quantized_model
def load_gptq_quantized(model_name, gptq_config: GptqConfig):
    """Load a GPTQ-quantized model and its tokenizer.

    When `gptq_config.act_order` is set, the model is loaded via a local
    GPTQ-for-LLaMa checkout (whose `llama.load_quant` honors `act_order`);
    otherwise it falls back to AutoGPTQ via `load_quant_by_autogptq`.

    Returns:
        (model, tokenizer) tuple.

    Exits the process with status -1 if the GPTQ-for-LLaMa repo cannot be
    imported.
    """
    print("Loading GPTQ quantized model...")
    if gptq_config.act_order:
        try:
            # Make the sibling GPTQ-for-LLaMa checkout importable.
            # NOTE(review): dirname(dirname(...)) plus the "../" segment walks
            # three levels up from this file before descending into
            # repositories/ — confirm this matches the actual repo layout.
            script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
            module_path = os.path.join(script_path, "../repositories/GPTQ-for-LLaMa")
            sys.path.insert(0, module_path)
            from llama import load_quant
            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
            # only `fastest-inference-4bit` branch cares about `act_order`
            model = load_quant(
                model_name,
                find_gptq_ckpt(gptq_config),
                gptq_config.wbits,
                gptq_config.groupsize,
                act_order=gptq_config.act_order,
            )
        except ImportError as e:
            # Missing/broken GPTQ-for-LLaMa checkout is fatal: point the user
            # at the setup docs and exit.
            print(f"Error: Failed to load GPTQ-for-LLaMa. {e}")
            print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md")
            sys.exit(-1)
    else:
        # other branches
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        model = load_quant_by_autogptq(model_name)
    return model, tokenizer
# Superseded earlier version of load_gptq_quantized, kept commented for reference:
# def load_gptq_quantized(model_name, gptq_config: GptqConfig):
# print("Loading GPTQ quantized model...")
# try:
# script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# module_path = os.path.join(script_path, "repositories/GPTQ-for-LLaMa")
# sys.path.insert(0, module_path)
# from llama import load_quant
# except ImportError as e:
# print(f"Error: Failed to load GPTQ-for-LLaMa. {e}")
# print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md")
# sys.exit(-1)
# tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
# # only `fastest-inference-4bit` branch cares about `act_order`
# if gptq_config.act_order:
# model = load_quant(
# model_name,
# find_gptq_ckpt(gptq_config),
# gptq_config.wbits,
# gptq_config.groupsize,
# act_order=gptq_config.act_order,
# )
# else:
# # other branches
# model = load_quant(
# model_name,
# find_gptq_ckpt(gptq_config),
# gptq_config.wbits,
# gptq_config.groupsize,
# )
# return model, tokenizer
def find_gptq_ckpt(gptq_config: GptqConfig):
    """Resolve the GPTQ checkpoint path from `gptq_config.ckpt`.

    If `ckpt` points at a file, it is returned as-is. If it is a directory,
    the last lexicographically-sorted matching weight file (*.pt, then *.bin)
    inside it is returned.

    Exits the process with status 1 when `ckpt` is unset or no weight file
    is found.
    """
    # Guard the unset case: Path(None) would raise TypeError instead of
    # reporting the missing checkpoint.
    if gptq_config.ckpt is None:
        print("Error: gptq checkpoint not found")
        sys.exit(1)
    if Path(gptq_config.ckpt).is_file():
        return gptq_config.ckpt
    # for ext in ["*.pt", "*.safetensors",]:
    for ext in ["*.pt", "*.bin"]:
        matched_result = sorted(Path(gptq_config.ckpt).glob(ext))
        if matched_result:
            return str(matched_result[-1])
    print("Error: gptq checkpoint not found")
    sys.exit(1)