from dataclasses import dataclass, field
import os
from os.path import isdir, isfile
from pathlib import Path
import sys

from transformers import AutoTokenizer


@dataclass
class GptqConfig:
    ckpt: str = field(
        default=None,
        metadata={
            "help": "Load quantized model. The path to the local GPTQ checkpoint."
        },
    )
    wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"})
    groupsize: int = field(
        default=-1,
        metadata={"help": "Groupsize to use for quantization; default uses full row."},
    )
    act_order: bool = field(
        default=True,
        metadata={"help": "Whether to apply the activation order GPTQ heuristic"},
    )
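
# A minimal, illustrative configuration (the path below is a placeholder, not a
# checkpoint shipped with this repo):
#
#   gptq_config = GptqConfig(
#       ckpt="/path/to/gptq-checkpoint",  # a file, or a directory containing *.pt / *.bin
#       wbits=4,
#       groupsize=128,
#       act_order=True,
#   )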


def load_quant_by_autogptq(model):
    """Load a GPTQ-quantized model and its tokenizer from a name or local path."""
    # Path used for Qwen-72B-Int4: load through ModelScope's AutoModel API.
    from modelscope import AutoTokenizer, AutoModelForCausalLM

    # Note: The default behavior now has injection attack prevention off.
    tokenizer = AutoTokenizer.from_pretrained(
        model, revision="master", trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        model, device_map="auto", trust_remote_code=True
    ).eval()
    return model, tokenizer

    # Alternative path used for CodeLlama-34B-Int4: load through AutoGPTQ instead.
    # from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
    # tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, trust_remote_code=True)
    # model = AutoGPTQForCausalLM.from_quantized(
    #     model, inject_fused_attention=False, inject_fused_mlp=False, use_cuda_fp16=True,
    #     disable_exllama=False, trust_remote_code=True, device_map="auto",
    # )
    # return model, tokenizer


def load_gptq_quantized(model_name, gptq_config: GptqConfig):
    """Load a GPTQ-quantized model and tokenizer; `gptq_config` is currently unused here."""
    print("Loading GPTQ quantized model...")
    model, tokenizer = load_quant_by_autogptq(model_name)
    return model, tokenizer
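
# Illustrative end-to-end usage (the model id below is a placeholder; substitute a
# local path or hub id for your own quantized checkpoint):
#
#   model, tokenizer = load_gptq_quantized("qwen/Qwen-72B-Chat-Int4", GptqConfig(wbits=4))
#   inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
#   output_ids = model.generate(**inputs, max_new_tokens=16)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))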


# Previous implementation, kept for reference; it loads through a local checkout of
# GPTQ-for-LLaMa rather than AutoGPTQ/ModelScope.
# def load_gptq_quantized(model_name, gptq_config: GptqConfig):
#     print("Loading GPTQ quantized model...")
#
#     try:
#         script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
#         module_path = os.path.join(script_path, "repositories/GPTQ-for-LLaMa")
#
#         sys.path.insert(0, module_path)
#         from llama import load_quant
#     except ImportError as e:
#         print(f"Error: Failed to load GPTQ-for-LLaMa. {e}")
#         print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md")
#         sys.exit(-1)
#
#     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
#     # only `fastest-inference-4bit` branch cares about `act_order`
#     if gptq_config.act_order:
#         model = load_quant(
#             model_name,
#             find_gptq_ckpt(gptq_config),
#             gptq_config.wbits,
#             gptq_config.groupsize,
#             act_order=gptq_config.act_order,
#         )
#     else:
#         # other branches
#         model = load_quant(
#             model_name,
#             find_gptq_ckpt(gptq_config),
#             gptq_config.wbits,
#             gptq_config.groupsize,
#         )
#
#     return model, tokenizer


def find_gptq_ckpt(gptq_config: GptqConfig):
    """Resolve `gptq_config.ckpt` to a concrete checkpoint file path, or exit."""
    # A direct path to a checkpoint file is returned as-is.
    if Path(gptq_config.ckpt).is_file():
        return gptq_config.ckpt

    # Otherwise treat `ckpt` as a directory and pick the last matching file.
    # for ext in ["*.pt", "*.safetensors"]:
    for ext in ["*.pt", "*.bin"]:
        matched_result = sorted(Path(gptq_config.ckpt).glob(ext))
        if len(matched_result) > 0:
            return str(matched_result[-1])

    print("Error: gptq checkpoint not found")
    sys.exit(1)
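
# Illustrative resolution (the directory below is a placeholder):
#
#   find_gptq_ckpt(GptqConfig(ckpt="/models/llama-7b-gptq"))
#   # returns the path itself if it is a file, otherwise the lexicographically last
#   # "*.pt" or "*.bin" file in the directory, and exits if nothing matches.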