codefuse-chatbot/coagent/utils/path_utils.py

71 lines
2.6 KiB
Python
Raw Normal View History

2023-09-28 10:58:58 +08:00
import os
from langchain.document_loaders import CSVLoader, PyPDFLoader, UnstructuredFileLoader, TextLoader, PythonLoader
from coagent.document_loaders import JSONLLoader, JSONLoader
# from configs.model_config import (
# embedding_model_dict,
# KB_ROOT_PATH,
# )
2023-09-28 10:58:58 +08:00
from loguru import logger
# Registry from a loader's name (the keys used in LOADER2EXT_DICT) to the
# loader class that is instantiated for that kind of file.
LOADERNAME2LOADER_DICT = {
    "UnstructuredFileLoader": UnstructuredFileLoader,
    "CSVLoader": CSVLoader,
    "PyPDFLoader": PyPDFLoader,
    "TextLoader": TextLoader,
    "PythonLoader": PythonLoader,
    "JSONLoader": JSONLoader,
    "JSONLLoader": JSONLLoader,
}
# File extensions handled by each loader, keyed by the loader's name in
# LOADERNAME2LOADER_DICT. Extensions are lower-case and include the leading
# dot. Order matters: SUPPORTED_EXTS is built by flattening these lists
# in key order.
LOADER2EXT_DICT = {
    "UnstructuredFileLoader": [
        ".eml", ".html", ".md", ".msg", ".rst",
        ".rtf", ".xml",
        ".doc", ".docx", ".epub", ".odt",
        ".ppt", ".pptx", ".tsv",
    ],
    "CSVLoader": [".csv"],
    "PyPDFLoader": [".pdf"],
    "TextLoader": [".txt"],
    "PythonLoader": [".py"],
    "JSONLoader": [".json"],
    "JSONLLoader": [".jsonl"],
}
# Reverse index: file extension -> loader class, derived from LOADER2EXT_DICT
# and LOADERNAME2LOADER_DICT. If an extension were listed under more than one
# loader, the loader that appears last in LOADER2EXT_DICT would win.
EXT2LOADER_DICT = {
    ext: LOADERNAME2LOADER_DICT[loader_name]
    for loader_name in LOADER2EXT_DICT
    for ext in LOADER2EXT_DICT[loader_name]
}
# Every registered extension as a flat list, in LOADER2EXT_DICT order
# (duplicates, if any, are kept).
SUPPORTED_EXTS = [
    ext
    for ext_list in LOADER2EXT_DICT.values()
    for ext in ext_list
]
def validate_kb_name(knowledge_base_id: str) -> bool:
    """Return True if ``knowledge_base_id`` is safe to use as a knowledge-base name.

    The name is joined onto ``kb_root_path`` by ``get_kb_path`` (and the
    helpers built on it), so it must not be able to resolve to the root
    itself or to anything outside it. A name is rejected when it:

    * is empty (joining ``""`` onto the root yields the root directory);
    * contains a NUL byte;
    * is an absolute path (``os.path.join`` would discard the root) or starts
      with ``/`` or ``\\``;
    * has a ``..`` path component, with either ``/`` or ``\\`` as separator;
    * contains the substring ``"../"`` (the original check, kept so that
      nothing it rejected is now accepted).

    Args:
        knowledge_base_id: Candidate knowledge-base name.

    Returns:
        True if the name passes every check above, False otherwise.
    """
    # Reject unexpected characters and path-traversal patterns.
    if not knowledge_base_id or "\x00" in knowledge_base_id:
        return False
    # os.path.join() drops everything before an absolute component, so an
    # absolute name would escape kb_root_path entirely. Leading separators
    # are checked explicitly so the result does not depend on the host OS.
    if os.path.isabs(knowledge_base_id) or knowledge_base_id.startswith(("/", "\\")):
        return False
    # Normalise Windows separators, then look for any ".." component; this
    # catches "..", "kb/.." and "..\\x", which the substring test misses.
    components = knowledge_base_id.replace("\\", "/").split("/")
    if ".." in components:
        return False
    # Original check, retained for backward compatibility (e.g. "a../b").
    if "../" in knowledge_base_id:
        return False
    return True
def get_kb_path(knowledge_base_name: str, kb_root_path: str) -> str:
    """Return the directory of knowledge base ``knowledge_base_name``.

    The result is ``kb_root_path`` joined with ``knowledge_base_name``. The
    name is not validated here; use ``validate_kb_name`` for that.
    """
    kb_dir = os.path.join(kb_root_path, knowledge_base_name)
    return kb_dir
2023-09-28 10:58:58 +08:00
def get_doc_path(knowledge_base_name: str, kb_root_path: str) -> str:
    """Return the ``content`` sub-directory of a knowledge base.

    That is ``<kb_root_path>/<knowledge_base_name>/content``, the directory
    whose files ``list_docs_from_folder`` enumerates.
    """
    kb_dir = get_kb_path(knowledge_base_name, kb_root_path)
    content_dir = os.path.join(kb_dir, "content")
    return content_dir
2023-09-28 10:58:58 +08:00
def get_vs_path(knowledge_base_name: str, kb_root_path: str) -> str:
    """Return the ``vector_store`` sub-directory of a knowledge base.

    That is ``<kb_root_path>/<knowledge_base_name>/vector_store``.
    """
    kb_dir = get_kb_path(knowledge_base_name, kb_root_path)
    vector_store_dir = os.path.join(kb_dir, "vector_store")
    return vector_store_dir
2023-09-28 10:58:58 +08:00
def get_file_path(knowledge_base_name: str, doc_name: str, kb_root_path: str) -> str:
    """Return the path of document ``doc_name`` in a knowledge base's content dir.

    That is ``<kb_root_path>/<knowledge_base_name>/content/<doc_name>``.

    NOTE(review): ``doc_name`` is joined as-is, so a name containing ``..``
    components or an absolute path can resolve outside the content
    directory; callers should validate it before use.
    """
    content_dir = get_doc_path(knowledge_base_name, kb_root_path)
    doc_path = os.path.join(content_dir, doc_name)
    return doc_path
2023-09-28 10:58:58 +08:00
def list_kbs_from_folder(kb_root_path: str):
    """Return the names of the sub-directories of ``kb_root_path``.

    Each sub-directory is treated as one knowledge base; plain files in the
    root are ignored. Order follows ``os.listdir`` (arbitrary). Raises
    whatever ``os.listdir`` raises, e.g. FileNotFoundError if the root is
    missing.
    """
    def _is_kb_dir(entry_name: str) -> bool:
        # os.path.isdir follows symlinks, so a link to a directory counts too.
        return os.path.isdir(os.path.join(kb_root_path, entry_name))

    return list(filter(_is_kb_dir, os.listdir(kb_root_path)))
2023-09-28 10:58:58 +08:00
def list_docs_from_folder(kb_name: str, kb_root_path: str):
    """Return the names of the files directly inside a knowledge base's content dir.

    Sub-directories are skipped (the filter is ``os.path.isfile``). Returns
    an empty list when the content directory does not exist.
    """
    content_dir = get_doc_path(kb_name, kb_root_path)
    # Guard clause: a knowledge base without a content dir has no documents.
    if not os.path.exists(content_dir):
        return []
    return [
        entry
        for entry in os.listdir(content_dir)
        if os.path.isfile(os.path.join(content_dir, entry))
    ]
2023-09-28 10:58:58 +08:00
def get_LoaderClass(file_extension):
    """Return the loader *name* registered for ``file_extension``.

    Despite the function name, the value returned is a key of
    ``LOADER2EXT_DICT`` (e.g. ``"PyPDFLoader"``), not a class; look it up in
    ``LOADERNAME2LOADER_DICT`` to obtain the class. Matching is exact, so the
    extension must include the leading dot and match the table's case
    (e.g. ``".pdf"``). Returns None when no loader handles the extension.
    """
    matching_names = (
        loader_name
        for loader_name, extensions in LOADER2EXT_DICT.items()
        if file_extension in extensions
    )
    # First match in table order, mirroring an early return inside a loop.
    return next(matching_names, None)