codefuse-chatbot/coagent/codechat/code_search/cypher_generator.py

83 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# encoding: utf-8
'''
@author: 温进
@file: cypher_generator.py
@time: 2023/11/24 上午10:17
@desc:
'''
from langchain import PromptTemplate
from loguru import logger
from coagent.llm_models.openai_model import getChatModelFromConfig
from coagent.llm_models.llm_config import LLMConfig
from coagent.utils.postprocess import replace_lt_gt
from langchain.schema import (
HumanMessage,
)
from langchain.chains.graph_qa.prompts import NGQL_GENERATION_PROMPT, CYPHER_GENERATION_TEMPLATE
schema = '''
Node properties: [{'tag': 'package', 'properties': []}, {'tag': 'class', 'properties': []}, {'tag': 'method', 'properties': []}]
Edge properties: [{'edge': 'contain', 'properties': []}, {'edge': 'depend', 'properties': []}]
Relationships: ['(:package)-[:contain]->(:class)', '(:class)-[:contain]->(:method)', '(:package)-[:contain]->(:package)']
'''
class CypherGenerator:
def __init__(self, llm_config: LLMConfig):
self.model = getChatModelFromConfig(llm_config)
NEBULAGRAPH_EXTRA_INSTRUCTIONS = """
Instructions:
First, generate cypher then convert it to NebulaGraph Cypher dialect(rather than standard):
1. it requires explicit label specification only when referring to node properties: v.`Foo`.name
2. note explicit label specification is not needed for edge properties, so it's e.name instead of e.`Bar`.name
3. it uses double equals sign for comparison: `==` rather than `=`
4. only use id(Foo) to get the name of node or edge
```\n"""
NGQL_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
"Generate Cypher", "Generate NebulaGraph Cypher"
).replace("Instructions:", NEBULAGRAPH_EXTRA_INSTRUCTIONS)
self.NGQL_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question"], template=NGQL_GENERATION_TEMPLATE
)
def get_cypher(self, query: str):
'''
get cypher from query
@param query:
@return:
'''
content = self.NGQL_GENERATION_PROMPT.format(schema=schema, question=query)
logger.info(content)
ans = ''
message = [HumanMessage(content=content)]
chat_res = self.model.predict_messages(message)
ans = chat_res.content
ans = replace_lt_gt(ans)
ans = self.post_process(ans)
return ans
def post_process(self, cypher_res: str):
'''
判断是否为正确的 cypher
@param cypher_res:
@return:
'''
if '(' not in cypher_res or ')' not in cypher_res:
return ''
return cypher_res
if __name__ == '__main__':
query = '代码库里有哪些函数返回5个就可以'
cg = CypherGenerator()
ans = cg.get_cypher(query)
logger.debug(f'ans=\n{ans}')