64 lines
1.6 KiB
Python
64 lines
1.6 KiB
Python
# encoding: utf-8
|
|
'''
@author: 温进
@file: cypher_generator.py
@time: 2023/11/24 10:17 AM
@desc: Generate graph queries (nGQL/Cypher) from natural-language questions with a chat model.
'''
|
|
from loguru import logger
|
|
|
|
from dev_opsgpt.llm_models.openai_model import getChatModel
|
|
from dev_opsgpt.utils.postprocess import replace_lt_gt
|
|
from langchain.schema import (
|
|
HumanMessage,
|
|
)
|
|
from langchain.chains.graph_qa.prompts import NGQL_GENERATION_PROMPT
|
|
|
|
|
|
# Textual description of the code graph, substituted into the ``schema`` slot
# of NGQL_GENERATION_PROMPT when a query is generated. It lists the node tags,
# the edge types, and the allowed relationship patterns; none of the tags or
# edges declare any properties.
# The literal must contain only the three description lines: stray '|' lines
# inside it would be sent to the model verbatim as part of the prompt.
# NOTE(review): the 'depend' edge is declared, but no relationship pattern
# below uses it -- confirm whether a '(:x)-[:depend]->(:y)' entry is missing.
schema = '''
Node properties: [{'tag': 'package', 'properties': []}, {'tag': 'class', 'properties': []}, {'tag': 'method', 'properties': []}]
Edge properties: [{'edge': 'contain', 'properties': []}, {'edge': 'depend', 'properties': []}]
Relationships: ['(:package)-[:contain]->(:class)', '(:class)-[:contain]->(:method)', '(:package)-[:contain]->(:package)']
'''
|
|
|
|
|
|
class CypherGenerator:
    '''
    Translate natural-language questions about the code graph into graph
    query text by prompting a chat model with NGQL_GENERATION_PROMPT and the
    module-level ``schema`` description.
    '''

    def __init__(self):
        # Chat model obtained from the project helper; reused for every request.
        self.model = getChatModel()

    def get_cypher(self, query: str):
        '''
        Generate a graph query for a natural-language question.
        @param query: the question to translate
        @return: the generated query text, or '' when ``post_process`` rejects it
        '''
        prompt_text = NGQL_GENERATION_PROMPT.format(schema=schema, question=query)

        # The whole prompt is sent as a single human message.
        reply = self.model.predict_messages([HumanMessage(content=prompt_text)])

        # replace_lt_gt comes from dev_opsgpt.utils.postprocess; presumably it
        # restores escaped '<' / '>' characters -- confirm against the helper.
        cleaned = replace_lt_gt(reply.content)

        return self.post_process(cleaned)

    def post_process(self, cypher_res: str):
        '''
        Sanity-check the generated query text.
        Only verifies that both '(' and ')' occur somewhere in the text; this
        is a cheap plausibility check, not a syntax validation.
        @param cypher_res: query text produced by the model
        @return: ``cypher_res`` unchanged when it passes the check, otherwise ''
        '''
        has_both_parens = '(' in cypher_res and ')' in cypher_res
        return cypher_res if has_both_parens else ''
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: generate a query for a sample question.
    # The question asks "how many classes are there in the code in total".
    query = '代码中一共有多少个类'

    # CypherGenerator.__init__ takes no arguments; passing engine='openai'
    # raised TypeError before any query could be generated.
    cg = CypherGenerator()

    # Log the result so the smoke test actually shows what was generated.
    cypher = cg.get_cypher(query)
    logger.info(cypher)
|