codefuse-chatbot/coagent/service/faiss_db_service.py

193 lines
7.0 KiB
Python

import os
import shutil
from typing import List
from functools import lru_cache
from loguru import logger
# from langchain.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores.utils import DistanceStrategy
# from configs.model_config import (
# KB_ROOT_PATH,
# CACHED_VS_NUM,
# EMBEDDING_MODEL,
# EMBEDDING_DEVICE,
# SCORE_THRESHOLD,
# FAISS_NORMALIZE_L2
# )
# from configs.model_config import embedding_model_dict
from coagent.base_configs.env_config import (
KB_ROOT_PATH,
CACHED_VS_NUM, SCORE_THRESHOLD, FAISS_NORMALIZE_L2
)
from .base_service import KBService, SupportedVSType
from coagent.utils.path_utils import *
from coagent.orm.utils import DocumentFile
from coagent.utils.server_utils import torch_gc
from coagent.embeddings.utils import load_embeddings, load_embeddings_from_path
from coagent.embeddings.faiss_m import FAISS
from coagent.llm_models.llm_config import EmbedConfig
# make HuggingFaceEmbeddings hashable
def _embeddings_hash(self):
return hash(self.model_name)
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
_VECTOR_STORE_TICKS = {}
# @lru_cache(CACHED_VS_NUM)
def load_vector_store(
knowledge_base_name: str,
embed_config: EmbedConfig,
embeddings: Embeddings = None,
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
kb_root_path: str = KB_ROOT_PATH,
):
print(f"loading vector store in '{knowledge_base_name}'.")
vs_path = get_vs_path(knowledge_base_name, kb_root_path)
if embeddings is None:
embeddings = load_embeddings_from_path(embed_config.embed_model_path, embed_config.model_device)
if not os.path.exists(vs_path):
os.makedirs(vs_path)
distance_strategy = DistanceStrategy.EUCLIDEAN_DISTANCE
if "index.faiss" in os.listdir(vs_path):
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=FAISS_NORMALIZE_L2, distance_strategy=distance_strategy)
else:
# create an empty vector store
doc = Document(page_content="init", metadata={})
search_index = FAISS.from_documents([doc], embeddings, normalize_L2=FAISS_NORMALIZE_L2, distance_strategy=distance_strategy)
ids = [k for k, v in search_index.docstore._dict.items()]
search_index.delete(ids)
search_index.save_local(vs_path)
if tick == 0: # vector store is loaded first time
_VECTOR_STORE_TICKS[knowledge_base_name] = 0
# search_index.embedding_function = embeddings.embed_documents
return search_index
def refresh_vs_cache(kb_name: str):
"""
make vector store cache refreshed when next loading
"""
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
class FaissKBService(KBService):
vs_path: str
kb_path: str
def vs_type(self) -> str:
return SupportedVSType.FAISS
@staticmethod
def get_vs_path(knowledge_base_name: str):
return os.path.join(FaissKBService.get_kb_path(knowledge_base_name), "vector_store")
@staticmethod
def get_kb_path(knowledge_base_name: str):
return os.path.join(KB_ROOT_PATH, knowledge_base_name)
def do_init(self):
self.kb_path = FaissKBService.get_kb_path(self.kb_name)
self.vs_path = FaissKBService.get_vs_path(self.kb_name)
def do_create_kb(self):
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
load_vector_store(self.kb_name, self.embed_config)
def do_drop_kb(self):
self.clear_vs()
shutil.rmtree(self.kb_path)
def do_search(self,
query: str,
top_k: int,
score_threshold: float = SCORE_THRESHOLD,
embeddings: Embeddings = None,
) -> List[Document]:
search_index = load_vector_store(self.kb_name,
self.embed_config,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name),
kb_root_path=self.kb_root_path)
docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
return docs
def get_all_documents(self, embeddings: Embeddings = None,):
search_index = load_vector_store(self.kb_name,
self.embed_config,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name),
kb_root_path=self.kb_root_path)
return search_index.get_all_documents()
def do_add_doc(self,
docs: List[Document],
embeddings: Embeddings,
**kwargs,
):
vector_store = load_vector_store(self.kb_name,
self.embed_config,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
kb_root_path=self.kb_root_path)
vector_store.embedding_function = embeddings.embed_documents
logger.info("loaded docs, docs' lens is {}".format(len(docs)))
vector_store.add_documents(docs)
torch_gc()
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
def do_delete_doc(self,
kb_file: DocumentFile,
**kwargs):
embeddings = self._load_embeddings()
vector_store = load_vector_store(self.kb_name,
self.embed_config,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
kb_root_path=self.kb_root_path)
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
if len(ids) == 0:
return None
vector_store.delete(ids)
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
return True
def do_clear_vs(self):
if os.path.exists(self.vs_path):
shutil.rmtree(self.vs_path)
os.makedirs(self.vs_path)
refresh_vs_cache(self.kb_name)
def exist_doc(self, file_name: str):
if super().exist_doc(file_name):
return "in_db"
content_path = os.path.join(self.kb_path, "content")
if os.path.isfile(os.path.join(content_path, file_name)):
return "in_folder"
else:
return False