2024-01-26 14:03:25 +08:00
|
|
|
import os
|
|
|
|
from functools import lru_cache
|
|
|
|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
2024-03-12 15:31:06 +08:00
|
|
|
from langchain.embeddings.base import Embeddings
|
|
|
|
|
2024-01-26 14:03:25 +08:00
|
|
|
# from configs.model_config import embedding_model_dict
|
|
|
|
from loguru import logger
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache(maxsize=1)
def _load_hf_embeddings(model_path: str, device: str):
    """Construct a HuggingFaceEmbeddings model, memoizing the most recent one.

    Cached only on ``(model_path, device)``; both are plain strings, so they
    are hashable, which ``lru_cache`` requires of every argument.
    """
    return HuggingFaceEmbeddings(model_name=model_path,
                                 model_kwargs={'device': device})


def load_embeddings(model: str, device: str, embedding_model_dict: dict):
    """Load the HuggingFace embedding model registered under ``model``.

    Args:
        model: Key to look up in ``embedding_model_dict``.
        device: Device string forwarded to the model via
            ``model_kwargs={'device': device}``.
        embedding_model_dict: Mapping from model key to the ``model_name``
            (hub name or local path) passed to ``HuggingFaceEmbeddings``.

    Returns:
        A ``HuggingFaceEmbeddings`` instance. The most recently loaded
        ``(model_path, device)`` pair is memoized, so repeated calls with the
        same resolved path and device reuse the same object.

    Raises:
        KeyError: If ``model`` is not a key of ``embedding_model_dict``.
    """
    # The dict itself is unhashable, so it cannot be part of an lru_cache
    # key; resolve it to the hashable model path before hitting the cache.
    try:
        model_path = embedding_model_dict[model]
    except KeyError:
        raise KeyError(
            f"unknown embedding model {model!r}; "
            f"expected one of {list(embedding_model_dict)}"
        ) from None
    return _load_hf_embeddings(model_path, device)


# Keep the cache-management API the decorated function used to expose.
load_embeddings.cache_clear = _load_hf_embeddings.cache_clear
load_embeddings.cache_info = _load_hf_embeddings.cache_info
|
|
|
|
|
|
|
|
|
2024-03-12 15:31:06 +08:00
|
|
|
def load_embeddings_from_path(model_path: str, device: str, langchain_embeddings: Embeddings = None):
    """Return an embedding model, preferring a caller-supplied instance.

    Args:
        model_path: Hub name or local path passed to ``HuggingFaceEmbeddings``
            as ``model_name``.
        device: Device string forwarded via ``model_kwargs={'device': device}``.
        langchain_embeddings: An already-constructed LangChain ``Embeddings``
            object, or ``None``. When supplied, it is returned unchanged and
            ``model_path``/``device`` are ignored.

    Returns:
        ``langchain_embeddings`` if it is not ``None``; otherwise a newly
        constructed ``HuggingFaceEmbeddings``. This function is not memoized,
        so every call without ``langchain_embeddings`` builds a new model.
    """
    # Compare against the None sentinel explicitly: a truthiness test would
    # also discard a supplied embeddings object that happens to be falsy.
    if langchain_embeddings is not None:
        return langchain_embeddings

    return HuggingFaceEmbeddings(model_name=model_path,
                                 model_kwargs={'device': device})
|
|
|
|
|