Three steps, ten minutes: build an AI bot with TiDB

Preface

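Last weekend I attended TiDB's AI study session and wrote this article right afterwards, so that anyone who couldn't attend can try the course on their own computer.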

Step 1. Get an API key (about 1 minute)

Register an account on Zhipu AI (https://open.bigmodel.cn) with your mobile phone number and obtain an API key.

Step 2. Register an account on TiDB Serverless, which is hosted overseas (about 1 minute)

A machine in China works too; the main caveat is that the TiDB servers are overseas, so the connection may be relatively slow.

TiDB Serverless: Cost-Efficient, Simple, Modern MySQL That Scales Effortlessly.

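Once registered, a TiDB cluster is already there for you to use. Click "Generate Password" to create the credentials. With the connection parameters it gives you, your local computer can connect to TiDB.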
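Step 3 has you paste the user and password directly into commands and scripts. If you would rather not hard-code them, one option is to read them from environment variables. This is a minimal sketch under that assumption: the variable names TIDB_HOST, TIDB_PORT, TIDB_USER, TIDB_PASSWORD, TIDB_DATABASE and the helper `tidb_connect_kwargs` are hypothetical choices for illustration, not something TiDB Serverless defines. The defaults are the host, port, and database used in this article.

```python
import os
from typing import Mapping


def tidb_connect_kwargs(env: Mapping[str, str] = os.environ) -> dict:
    """Collect TiDB Serverless connection settings from (hypothetical) env vars."""
    return {
        "host": env.get("TIDB_HOST", "gateway01.us-west-2.prod.aws.tidbcloud.com"),
        "port": int(env.get("TIDB_PORT", "4000")),
        # Required values: a KeyError here means the variable was never exported
        "user": env["TIDB_USER"],
        "password": env["TIDB_PASSWORD"],
        "database": env.get("TIDB_DATABASE", "chatdb"),
    }
```

The keys match PyMySQL's keyword arguments, so the dict could be passed as `pymysql.connect(**kwargs, ssl_verify_cert=True, ssl_verify_identity=True)`; the same values could also fill in `get_db_url()` in the scripts.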

Step 3. Copy and paste, filling in the user and password you obtained earlier (about 8 minutes).

$ mysql \
      --comments -u '<your-user>' -p'<your-password>' \
      -h gateway01.us-west-2.prod.aws.tidbcloud.com -P 4000 \
      --ssl-mode=VERIFY_IDENTITY --ssl-ca=/etc/pki/tls/certs/ca-bundle.crt
mysql> DROP DATABASE IF EXISTS chatdb;
mysql> CREATE DATABASE chatdb;
mysql> EXIT
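The `--ssl-ca` path above, /etc/pki/tls/certs/ca-bundle.crt, is where RHEL, CentOS and Fedora keep the system CA bundle. On Debian/Ubuntu it is usually /etc/ssl/certs/ca-certificates.crt, and on macOS /etc/ssl/cert.pem. If you are unsure, Python's standard `ssl` module can report the default CA file OpenSSL was built to use. The helper name below is just for illustration. The result may be `None` when that default file does not exist on your system (common on Windows), in which case you need to locate a CA bundle yourself.

```python
import ssl


def default_ca_bundle():
    """Return the CA bundle file OpenSSL uses by default on this machine, or None."""
    paths = ssl.get_default_verify_paths()
    # cafile is None when the compiled-in default file does not exist
    return paths.cafile


print(default_ca_bundle())
```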

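Install the Python packages on your local computer, set your Zhipu AI API key, then create and run the two scripts: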

$ pip install \
      click==8.1.7 \
      PyMySQL==1.1.0 \
      SQLAlchemy==2.0.30 \
      tidb-vector==0.0.9 \
      pydantic==2.7.1 pydantic_core==2.18.2 \
      dspy-ai==2.4.12 \
      langchain-community==0.2.0 \
      wikipedia==1.4.0 \
      pyvis==0.3.1 \
      openai==1.27.0 \
      zhipuai==2.1.3
$ export ZHIPUAI_API_KEY='replace-with-your-Zhipu-AI-API-key'
$ cat > build-graph.py <<'EOF'
import os
from typing import List, Tuple

import dspy
from dspy.functional import TypedPredictor
from langchain_community.document_loaders import WikipediaLoader
from pydantic import BaseModel, Field
from pyvis.network import Network
from sqlalchemy import (
    Column,
    ForeignKey,
    Integer,
    String,
    Text,
    URL,
    create_engine,
    or_,
)
from sqlalchemy.orm import Session, declarative_base, joinedload, relationship
from tidb_vector.sqlalchemy import VectorType
from zhipuai import ZhipuAI


# Pydantic models describing the knowledge graph the LLM must return
class Entity(BaseModel):
    """List of entities extracted from the text to form the knowledge graph"""
    name: str = Field(
        description="Name of the entity, it should be a clear and concise term"
    )
    description: str = Field(
        description=(
            "Description of the entity, it should be a complete and comprehensive sentence, not just a few words. "
            "Sample description of entity 'TiDB in-place upgrade': "
            "'Upgrade TiDB component binary files to achieve upgrade, generally use rolling upgrade method'"
        )
    )


class Relationship(BaseModel):
    """List of relationships extracted from the text to form the knowledge graph"""
    source_entity: str = Field(
        description="Source entity name of the relationship, it should be an existing entity in the Entity list"
    )
    target_entity: str = Field(
        description="Target entity name of the relationship, it should be an existing entity in the Entity list"
    )
    relationship_desc: str = Field(
        description=(
            "Description of the relationship, it should be a complete and comprehensive sentence, not just a few words. "
            "Sample relationship description: 'TiDB will release a new LTS version every 6 months.'"
        )
    )


class KnowledgeGraph(BaseModel):
    """Graph representation of the knowledge for text."""
    entities: List[Entity] = Field(
        description="List of entities in the knowledge graph"
    )
    relationships: List[Relationship] = Field(
        description="List of relationships in the knowledge graph"
    )


# DSPy signature and module: ask the LLM to extract a KnowledgeGraph from text
class ExtractGraphTriplet(dspy.Signature):
    text = dspy.InputField(
        desc="a paragraph of text to extract entities and relationships to form a knowledge graph"
    )
    knowledge: KnowledgeGraph = dspy.OutputField(
        desc="Graph representation of the knowledge extracted from the text."
    )


class Extractor(dspy.Module):
    def __init__(self):
        super().__init__()
        self.prog_graph = TypedPredictor(ExtractGraphTriplet)

    def forward(self, text):
        return self.prog_graph(
            text=text,
            config={"response_format": {"type": "json_object"}},
        )


def interactive_graph(kg: KnowledgeGraph) -> str:
    """Render the extracted graph to kg.html for inspection in a browser."""
    net = Network(notebook=True, cdn_resources="remote")
    node_map = {}
    for index, entity in enumerate(kg.entities):
        node_map[entity.name] = index
        net.add_node(index, label=entity.name, title=entity.description)
    for relation in kg.relationships:
        net.add_edge(node_map[relation.source_entity], node_map[relation.target_entity])
    filename = "kg.html"
    net.save_graph(filename)
    return filename


def get_query_embedding(query: str):
    """Embed text with Zhipu AI's embedding-2 model (stored as a 1024-dim vector)."""
    zhipu_ai_client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
    response = zhipu_ai_client.embeddings.create(
        model="embedding-2",
        input=[query],
    )
    return response.data[0].embedding


def generate_result(query: str, entities, relationships):
    """Answer the question with glm-4-0520, using only the retrieved graph as context."""
    zhipu_ai_client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
    entities_prompt = "\n".join(
        f'(Name: "{e.name}", Description: "{e.description}")' for e in entities
    )
    relationships_prompt = "\n".join(f'"{r.relationship_desc}"' for r in relationships)
    response = zhipu_ai_client.chat.completions.create(
        model="glm-4-0520",
        messages=[
            {
                "role": "system",
                "content": "Please carefully think the user's "
                + "question and ONLY use the content below to generate answer:\n"
                + f"Entities: {entities_prompt}, Relationships: {relationships_prompt}",
            },
            {"role": "user", "content": query},
        ],
    )
    return response.choices[0].message.content


def get_db_url():
    # Replace the user and password with the ones generated in TiDB Serverless
    return URL(
        drivername="mysql+pymysql",
        username="<your-user>",
        password="<your-password>",
        host="gateway01.us-west-2.prod.aws.tidbcloud.com",
        port=4000,
        database="chatdb",
        query={"ssl_verify_cert": True, "ssl_verify_identity": True},
    )


engine = create_engine(get_db_url(), pool_recycle=300)
Base = declarative_base()


# Graph storage: one table of entities (with embeddings), one of relationships
class DatabaseEntity(Base):
    __tablename__ = "entities"
    id = Column(Integer, primary_key=True)
    name = Column(String(512))
    description = Column(Text)
    description_vec = Column(VectorType(1024), comment="HNSW(distance=cosine)")


class DatabaseRelationship(Base):
    __tablename__ = "relationships"
    id = Column(Integer, primary_key=True)
    source_entity_id = Column(Integer, ForeignKey("entities.id"))
    target_entity_id = Column(Integer, ForeignKey("entities.id"))
    relationship_desc = Column(Text)
    source_entity = relationship("DatabaseEntity", foreign_keys=[source_entity_id])
    target_entity = relationship("DatabaseEntity", foreign_keys=[target_entity_id])


def clean_knowledge_graph(kg: KnowledgeGraph) -> KnowledgeGraph:
    """Drop relationships whose endpoints are not in the extracted entity list."""
    entity_name_set = {e.name for e in kg.entities}
    kg.relationships = [
        r
        for r in kg.relationships
        if r.source_entity in entity_name_set and r.target_entity in entity_name_set
    ]
    return kg


def save_knowledge_graph(kg: KnowledgeGraph):
    """Embed each entity description, then insert entities and relationships."""
    data_entities = [
        DatabaseEntity(
            name=e.name,
            description=e.description,
            description_vec=get_query_embedding(e.description),
        )
        for e in kg.entities
    ]
    with Session(engine) as session:
        session.add_all(data_entities)
        session.flush()  # assigns auto-increment ids to the new entities
        entity_id_map = {e.name: e.id for e in data_entities}
        print(entity_id_map)
        data_relationships = [
            DatabaseRelationship(
                source_entity_id=entity_id_map[r.source_entity],
                target_entity_id=entity_id_map[r.target_entity],
                relationship_desc=r.relationship_desc,
            )
            for r in kg.relationships
        ]
        session.add_all(data_relationships)
        session.commit()


def retrieve_entities_relationships(
    question_embedding,
) -> Tuple[List[DatabaseEntity], List[DatabaseRelationship]]:
    """Vector search for the closest entity, then fetch every relationship touching it."""
    with Session(engine) as session:
        entity = (
            session.query(DatabaseEntity)
            .order_by(DatabaseEntity.description_vec.cosine_distance(question_embedding))
            .limit(1)
            .first()
        )
        if entity is None:
            return [], []
        entities = {entity.id: entity}
        relationships = (
            session.query(DatabaseRelationship)
            .options(
                joinedload(DatabaseRelationship.source_entity),
                joinedload(DatabaseRelationship.target_entity),
            )
            .filter(
                or_(
                    DatabaseRelationship.source_entity == entity,
                    DatabaseRelationship.target_entity == entity,
                )
            )
            .all()  # materialize the rows before the session closes
        )
        for r in relationships:
            entities.update(
                {
                    r.source_entity.id: r.source_entity,
                    r.target_entity.id: r.target_entity,
                }
            )
        return list(entities.values()), relationships


extractor = Extractor()
Base.metadata.create_all(engine)

# Point DSPy at Zhipu AI's OpenAI-compatible endpoint
zhipu_ai_client = dspy.OpenAI(
    model="glm-4-0520",
    api_base="https://open.bigmodel.cn/api/paas/v4/",
    api_key=os.getenv("ZHIPUAI_API_KEY"),
    model_type="chat",
    max_tokens=4096,
)
dspy.settings.configure(lm=zhipu_ai_client)

# Extract a graph from the TiDB Wikipedia page, write kg.html, then store it in TiDB
wiki = WikipediaLoader(query="TiDB").load()
pred = extractor(text=wiki[0].page_content)
knowledge_graph = clean_knowledge_graph(pred.knowledge)
interactive_graph(knowledge_graph)
save_knowledge_graph(knowledge_graph)
EOF
$ ls -l build-graph.py
$ python build-graph.py
$ cat > test-graph.py <<'EOF'
import os
from typing import List, Tuple

import click
from sqlalchemy import (
    Column,
    ForeignKey,
    Integer,
    String,
    Text,
    URL,
    create_engine,
    or_,
)
from sqlalchemy.orm import Session, declarative_base, joinedload, relationship
from tidb_vector.sqlalchemy import VectorType
from zhipuai import ZhipuAI


def get_db_url():
    # Replace the user and password with the ones generated in TiDB Serverless
    return URL(
        drivername="mysql+pymysql",
        username="<your-user>",
        password="<your-password>",
        host="gateway01.us-west-2.prod.aws.tidbcloud.com",
        port=4000,
        database="chatdb",
        query={"ssl_verify_cert": True, "ssl_verify_identity": True},
    )


engine = create_engine(get_db_url(), pool_recycle=300)
Base = declarative_base()


# Same table mappings as build-graph.py
class DatabaseEntity(Base):
    __tablename__ = "entities"
    id = Column(Integer, primary_key=True)
    name = Column(String(512))
    description = Column(Text)
    description_vec = Column(VectorType(1024), comment="HNSW(distance=cosine)")


class DatabaseRelationship(Base):
    __tablename__ = "relationships"
    id = Column(Integer, primary_key=True)
    source_entity_id = Column(Integer, ForeignKey("entities.id"))
    target_entity_id = Column(Integer, ForeignKey("entities.id"))
    relationship_desc = Column(Text)
    source_entity = relationship("DatabaseEntity", foreign_keys=[source_entity_id])
    target_entity = relationship("DatabaseEntity", foreign_keys=[target_entity_id])


def get_query_embedding(query: str):
    """Embed the question with the same model used for the stored entities."""
    zhipu_ai_client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
    response = zhipu_ai_client.embeddings.create(
        model="embedding-2",
        input=[query],
    )
    return response.data[0].embedding


def retrieve_entities_relationships(
    question_embedding,
) -> Tuple[List[DatabaseEntity], List[DatabaseRelationship]]:
    """Vector search for the closest entity, then fetch every relationship touching it."""
    with Session(engine) as session:
        entity = (
            session.query(DatabaseEntity)
            .order_by(DatabaseEntity.description_vec.cosine_distance(question_embedding))
            .limit(1)
            .first()
        )
        if entity is None:
            # The tables are empty: run build-graph.py first
            return [], []
        entities = {entity.id: entity}
        relationships = (
            session.query(DatabaseRelationship)
            .options(
                joinedload(DatabaseRelationship.source_entity),
                joinedload(DatabaseRelationship.target_entity),
            )
            .filter(
                or_(
                    DatabaseRelationship.source_entity == entity,
                    DatabaseRelationship.target_entity == entity,
                )
            )
            .all()  # materialize the rows before the session closes
        )
        for r in relationships:
            entities.update(
                {
                    r.source_entity.id: r.source_entity,
                    r.target_entity.id: r.target_entity,
                }
            )
        return list(entities.values()), relationships


def generate_result(query: str, entities, relationships):
    """Answer the question with glm-4-0520, using only the retrieved graph as context."""
    zhipu_ai_client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
    entities_prompt = "\n".join(
        f'(Name: "{e.name}", Description: "{e.description}")' for e in entities
    )
    relationships_prompt = "\n".join(f'"{r.relationship_desc}"' for r in relationships)
    response = zhipu_ai_client.chat.completions.create(
        model="glm-4-0520",
        messages=[
            {
                "role": "system",
                "content": "Please carefully think the user's "
                + "question and ONLY use the content below to generate answer:\n"
                + f"Entities: {entities_prompt}, Relationships: {relationships_prompt}",
            },
            {"role": "user", "content": query},
        ],
    )
    return response.choices[0].message.content


@click.command()
def start_chat():
    # Simple REPL: embed the question, retrieve from the graph, generate an answer
    while True:
        question = click.prompt("Enter your question")
        question_embedding = get_query_embedding(question)
        entities, relationships = retrieve_entities_relationships(question_embedding)
        result = generate_result(question, entities, relationships)
        click.echo(result)


if __name__ == "__main__":
    start_chat()
EOF
$ ls -l test-graph.py
$ python test-graph.py

Goal: use a knowledge graph to enhance a RAG application, upgrading from naive RAG to Graph RAG

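In the previous demo, you saw how RAG retrieves relevant information from an external knowledge source so that a large language model can answer questions about a document collection it has never seen. However, RAG falls short on broad questions that concern the text corpus as a whole. A knowledge graph lets you upgrade your RAG application from naive RAG to Graph RAG: instead of finding the most similar text to augment the LLM prompt, it retrieves the most relevant knowledge.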
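To make the naive-RAG versus Graph-RAG contrast concrete, here is a toy sketch in plain Python. The 2-D vectors are hand-made stand-ins for real embeddings (the scripts above use Zhipu AI's `embedding-2`), and `naive_rag` and `graph_rag` are illustrative names, not part of any library. Naive RAG returns whichever text chunks are most similar to the question. The graph variant uses similarity only to pick an entry entity and then follows that entity's relationships, so facts about neighboring entities come along even when their own text is not similar to the question.

```python
from dataclasses import dataclass
from math import sqrt
from typing import List, Tuple


def cosine_distance(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return 1.0 - dot / (sqrt(sum(x * x for x in a)) * sqrt(sum(y * y for y in b)))


@dataclass
class Entity:
    name: str
    vec: List[float]


@dataclass
class Edge:
    source: str
    target: str
    desc: str


def naive_rag(query_vec, chunks, k=1):
    """Naive RAG: the k text chunks whose vectors are closest to the query."""
    ranked = sorted(chunks, key=lambda c: cosine_distance(query_vec, c[1]))
    return [text for text, _ in ranked[:k]]


def graph_rag(query_vec, entities, edges) -> Tuple[List[str], List[str]]:
    """Graph RAG: nearest entity is the entry point, then expand its edges one hop."""
    entry = min(entities, key=lambda e: cosine_distance(query_vec, e.vec))
    hops = [e for e in edges if entry.name in (e.source, e.target)]
    names = {entry.name} | {e.source for e in hops} | {e.target for e in hops}
    return sorted(names), [e.desc for e in hops]


# Made-up toy data
CHUNKS = [
    ("TiDB is a distributed SQL database.", [1.0, 0.0]),
    ("TiKV is a distributed key-value store.", [0.0, 1.0]),
]
ENTITIES = [
    Entity("TiDB", [1.0, 0.0]),
    Entity("TiKV", [0.0, 1.0]),
    Entity("PingCAP", [0.6, 0.8]),
]
EDGES = [
    Edge("TiDB", "TiKV", "TiDB stores its data in TiKV."),
    Edge("PingCAP", "TiDB", "PingCAP develops TiDB."),
    Edge("PingCAP", "TiKV", "PingCAP created TiKV."),
]

query = [0.9, 0.1]  # stand-in embedding for a question such as "What is TiDB?"
print("naive:", naive_rag(query, CHUNKS))
print("graph:", graph_rag(query, ENTITIES, EDGES))
```

For this query the naive version returns only the TiDB chunk, while the graph version also brings in the TiKV and PingCAP relationships attached to TiDB; the edge that does not touch TiDB is left out.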

This exercise requires a Zhipu AI API key with access to the glm-4-0520 model.

Tips

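  • In this exercise, you use an open-source, Python-based implementation of Graph RAG to build a knowledge graph about TiDB from its Wikipedia page and store it in the TiDB database chatdb.
  • The Python implementation represents the graph with two TiDB tables, entities and relationships. Each entity's description and its vector embedding are stored in the entities table.
  • Vector search over the embeddings serves as the entry point into the graph, and JOIN operations traverse the graph to retrieve the most relevant knowledge nodes.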
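The two-table layout and the vector-search-then-JOIN flow described in these tips can be sketched with Python's built-in sqlite3 module. This is only a stand-in under stated assumptions: SQLite has no vector type, so embeddings are stored as JSON text and the nearest-entity step runs in Python, whereas the scripts above do it inside TiDB with `description_vec.cosine_distance(...)` in an ORDER BY. Table and column names mirror the `entities` and `relationships` models; the sample rows and 2-D vectors are made up.

```python
import json
import math
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE entities (
    id INTEGER PRIMARY KEY,
    name TEXT,
    description TEXT,
    description_vec TEXT  -- JSON array here; TiDB uses a vector column instead
);
CREATE TABLE relationships (
    id INTEGER PRIMARY KEY,
    source_entity_id INTEGER REFERENCES entities(id),
    target_entity_id INTEGER REFERENCES entities(id),
    relationship_desc TEXT
);
""")
conn.executemany("INSERT INTO entities VALUES (?, ?, ?, ?)", [
    (1, "TiDB", "A distributed SQL database.", json.dumps([1.0, 0.0])),
    (2, "TiKV", "A distributed key-value store.", json.dumps([0.0, 1.0])),
    (3, "PingCAP", "The company behind TiDB.", json.dumps([0.6, 0.8])),
])
conn.executemany("INSERT INTO relationships VALUES (?, ?, ?, ?)", [
    (1, 1, 2, "TiDB stores its data in TiKV."),
    (2, 3, 1, "PingCAP develops TiDB."),
])


def cosine_distance(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return 1.0 - dot / (math.hypot(*a) * math.hypot(*b))


def retrieve(query_vec):
    # Entry point: the entity whose embedding is closest to the question.
    # (Ranked in Python here because SQLite has no vector distance function.)
    rows = conn.execute("SELECT id, description_vec FROM entities").fetchall()
    entry_id = min(rows, key=lambda r: cosine_distance(query_vec, json.loads(r[1])))[0]
    # Traversal: JOIN each relationship touching the entry entity to both endpoints.
    return conn.execute("""
        SELECT s.name, t.name, r.relationship_desc
        FROM relationships r
        JOIN entities s ON s.id = r.source_entity_id
        JOIN entities t ON t.id = r.target_entity_id
        WHERE r.source_entity_id = ? OR r.target_entity_id = ?
        ORDER BY r.id
    """, (entry_id, entry_id)).fetchall()


print(retrieve([0.1, 0.9]))
```

The design point is that the graph needs no special storage: nodes and edges are ordinary rows, and one hop of traversal is a pair of JOINs on the entity ids.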

:clap: True to form, the expert gets straight to it.


So TiCool :call_me_hand: :call_me_hand: :call_me_hand:


Following along.

Learning from this.

Impressive, this looks really helpful. I'm going to try it.

Awesome! Why is tidb.ai in English?

I've read your article with great interest.

Is TiDB Cloud available in China?
So you had the py scripts written in advance :face_with_open_eyes_and_hand_over_mouth:

Really impressive indeed.

Yes, PingCAP has partnerships with both Alibaba Cloud and NWCD (AWS Ningxia).

Learning from this, so impressive.

Liked. This expert is seriously good.

Awesome! How does this work under the hood?