diff --git a/deployment/docker/datamate/docker-compose.yml b/deployment/docker/datamate/docker-compose.yml index e2d97e41..1bb368b0 100644 --- a/deployment/docker/datamate/docker-compose.yml +++ b/deployment/docker/datamate/docker-compose.yml @@ -30,6 +30,7 @@ services: - log_level=DEBUG - pgsql_password=${DB_PASSWORD:-password} - datamate_jwt_enable=${DATAMATE_JWT_ENABLE:-false} + - milvus_uri=${MILVUS_URI:-http://milvus:19530} volumes: - dataset_volume:/dataset - flow_volume:/flow diff --git a/deployment/helm/datamate/values.yaml b/deployment/helm/datamate/values.yaml index a95bdf34..fb5044ef 100644 --- a/deployment/helm/datamate/values.yaml +++ b/deployment/helm/datamate/values.yaml @@ -140,6 +140,8 @@ backend-python: key: DB_PASSWORD - name: datamate_jwt_enable value: *DATAMATE_JWT_ENABLE + - name: milvus_uri + value: "http://milvus:19530" volumes: - *datasetVolume - *flowVolume diff --git a/runtime/datamate-python/.env.example b/runtime/datamate-python/.env.example index d1883839..2f95c434 100644 --- a/runtime/datamate-python/.env.example +++ b/runtime/datamate-python/.env.example @@ -20,3 +20,10 @@ LABEL_STUDIO_BASE_URL=http://localhost:30001 LABEL_STUDIO_USER_TOKEN="demo_dev_token" DATAMATE_JWT_ENABLE=false + +# Milvus settings (Vector Database) +# Development: use localhost +MILVUS_URI=http://localhost:19530 +MILVUS_TOKEN= +# Production: use service name +# MILVUS_URI=http://milvus:19530 diff --git a/runtime/datamate-python/app/core/config.py b/runtime/datamate-python/app/core/config.py index 49dd3320..77139ab6 100644 --- a/runtime/datamate-python/app/core/config.py +++ b/runtime/datamate-python/app/core/config.py @@ -78,7 +78,7 @@ def build_database_url(self): datamate_jwt_enable: bool = False # Milvus 配置 - milvus_uri: str = "http://localhost:19530" + milvus_uri: str = "http://milvus:19530" milvus_token: str = "" # 文件存储配置(共享文件系统) diff --git a/runtime/datamate-python/app/module/rag/interface/rag_interface.py b/runtime/datamate-python/app/module/rag/interface/rag_interface.py index 910954a3..57b9f9b9 100644 --- a/runtime/datamate-python/app/module/rag/interface/rag_interface.py +++ b/runtime/datamate-python/app/module/rag/interface/rag_interface.py @@ -1,34 +1,58 @@ -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, BackgroundTasks +from sqlalchemy.ext.asyncio import AsyncSession from app.core.exception import SuccessResponse +from app.db.session import get_db from app.module.rag.service.rag_service import RAGService +from app.module.rag.service.knowledge_base_service import KnowledgeBaseService from ..schema.request import QueryRequest router = APIRouter(prefix="/rag", tags=["知识图谱 RAG"]) @router.post("/{knowledge_base_id}/process") -async def process_knowledge_base(knowledge_base_id: str, rag_service: RAGService = Depends()): +async def process_knowledge_base( + knowledge_base_id: str, + background_tasks: BackgroundTasks, + db: AsyncSession = Depends(get_db), +): """ - 处理知识库中所有未处理的文件(LightRAG) - - 接口路径调整: - - 旧路径: /rag/process/{id} - - 新路径: /rag/graph/{id}/process + 手动触发知识库文件处理(已废弃,文件处理在添加时自动触发) + + 此接口保留用于向后兼容或手动重新处理文件 """ - await rag_service.init_graph_rag(knowledge_base_id) + service = KnowledgeBaseService(db) + kb = await service.kb_repo.get_by_id(knowledge_base_id) + + if not kb: + return SuccessResponse( + data=None, + message="知识库不存在" + ) + + files = await service.file_repo.get_unprocessed_files(knowledge_base_id) + if not files: + return SuccessResponse( + data=None, + message="没有待处理的文件" + ) + + service.file_processor.start_background_processing( + background_tasks=background_tasks, + knowledge_base_id=str(kb.id), + knowledge_base_name=str(kb.name), + knowledge_base_type=str(kb.type or "DOCUMENT"), + request_data={"knowledge_base_id": knowledge_base_id, "files": []}, + ) + return SuccessResponse( data=None, - message="Processing started for knowledge base." + message=f"已开始处理 {len(files)} 个文件" ) @router.post("/query") async def query_knowledge_graph(payload: QueryRequest, rag_service: RAGService = Depends()): """ - 使用给定的查询文本和知识库 ID 查询知识图谱(LightRAG) - - 接口路径调整: - - 旧路径: /rag/query - - 新路径: /rag/graph/query + 使用给定的查询文本和知识库 ID 查询知识图谱(LightRAG)或向量检索 """ result = await rag_service.query_rag(payload.query, payload.knowledge_base_id) return SuccessResponse(data=result) diff --git a/runtime/datamate-python/app/module/rag/service/file_processor.py b/runtime/datamate-python/app/module/rag/service/file_processor.py index d542862b..c64c1832 100644 --- a/runtime/datamate-python/app/module/rag/service/file_processor.py +++ b/runtime/datamate-python/app/module/rag/service/file_processor.py @@ -2,6 +2,7 @@ 文件处理器 负责文件的后台 ETL 处理:加载、分块、向量化、存储。 +支持两种知识库类型:DOCUMENT(向量检索)和 GRAPH(知识图谱)。 使用全局 WorkerPool 实现并发控制,最多 10 个文件并行处理。 """ import logging @@ -13,8 +14,7 @@ from fastapi import BackgroundTasks from sqlalchemy.ext.asyncio import AsyncSession -from app.core.exception import BusinessError, ErrorCodes -from app.db.models.knowledge_gen import KnowledgeBase, RagFile, FileStatus +from app.db.models.knowledge_gen import KnowledgeBase, RagFile, FileStatus, RagType from app.db.session import AsyncSessionLocal from app.module.rag.infra.document import ingest_file_to_chunks, DocumentChunk from app.module.rag.infra.embeddings import EmbeddingFactory @@ -42,6 +42,7 @@ def start_background_processing( background_tasks: BackgroundTasks, knowledge_base_id: str, knowledge_base_name: str, + knowledge_base_type: str, request_data: dict, ) -> None: """启动后台文件处理 @@ -50,20 +51,23 @@ def start_background_processing( background_tasks: FastAPI BackgroundTasks knowledge_base_id: 知识库 ID knowledge_base_name: 知识库名称 + knowledge_base_type: 知识库类型 (DOCUMENT/GRAPH) request_data: 添加文件请求数据 """ background_tasks.add_task( self._process_files_background, knowledge_base_id, knowledge_base_name, + knowledge_base_type, request_data, ) - logger.info("已注册后台任务: 知识库=%s", knowledge_base_name) + logger.info("已注册后台任务: 知识库=%s, 类型=%s", knowledge_base_name, knowledge_base_type) async def _process_files_background( self, knowledge_base_id: str, knowledge_base_name: str, + knowledge_base_type: str, request_data: dict, ) -> None: """后台处理文件(使用新的数据库 session)""" @@ -84,10 +88,12 @@ async def _process_files_background( logger.info("知识库 %s 没有待处理的文件", knowledge_base_name) return - logger.info("开始处理 %d 个文件,知识库: %s", len(files), knowledge_base_name) + logger.info("开始处理 %d 个文件,知识库: %s, 类型: %s", len(files), knowledge_base_name, knowledge_base_type) - # 并发处理文件(最多 10 个并行) - await self._process_files_concurrently(db, files, knowledge_base, request) + if knowledge_base_type == RagType.GRAPH.value: + await self._process_graph_files(db, files, knowledge_base) + else: + await self._process_document_files(db, files, knowledge_base, request) logger.info("知识库 %s 文件处理完成", knowledge_base_name) @@ -96,14 +102,14 @@ async def _process_files_background( finally: await db.close() - async def _process_files_concurrently( + async def _process_document_files( self, db: AsyncSession, files: List[RagFile], knowledge_base: KnowledgeBase, request: AddFilesReq, ) -> None: - """并发处理多个文件(最多10个并行)""" + """处理 DOCUMENT 类型文件(向量化)""" import asyncio async def process_with_semaphore(rag_file: RagFile): @@ -117,6 +123,109 @@ async def process_with_semaphore(rag_file: RagFile): tasks = [process_with_semaphore(f) for f in files] await asyncio.gather(*tasks, return_exceptions=True) + async def _process_graph_files( + self, + db: AsyncSession, + files: List[RagFile], + knowledge_base: KnowledgeBase, + ) -> None: + """处理 GRAPH 类型文件(知识图谱)""" + from app.module.shared.llm import LLMFactory + from app.module.shared.common.document_loaders import load_documents + + try: + rag_instance = await self._initialize_graph_rag(db, knowledge_base, LLMFactory) + + for rag_file in files: + await self._process_single_graph_file(db, rag_file, rag_instance, load_documents) + + except Exception as e: + logger.exception("初始化知识图谱失败: %s", e) + for rag_file in files: + file_repo = RagFileRepository(db) + await self._mark_failed(db, file_repo, str(rag_file.id), f"知识图谱初始化失败: {str(e)}") # type: ignore + + async def _initialize_graph_rag(self, db: AsyncSession, knowledge_base: KnowledgeBase, LLMFactory): + """初始化 GraphRAG 实例""" + from .graph_rag import ( + DEFAULT_WORKING_DIR, + build_embedding_func, + build_llm_model_func, + initialize_rag, + ) + + embedding_entity = await get_model_by_id(db, str(knowledge_base.embedding_model)) # type: ignore + if not embedding_entity: + raise ValueError(f"嵌入模型不存在: {knowledge_base.embedding_model}") + + chat_entity = await get_model_by_id(db, str(knowledge_base.chat_model)) # type: ignore + if not chat_entity: + raise ValueError(f"聊天模型不存在: {knowledge_base.chat_model}") + + llm_callable = await build_llm_model_func( + str(chat_entity.model_name), str(chat_entity.base_url), str(chat_entity.api_key) # type: ignore + ) + embedding_callable = await build_embedding_func( + str(embedding_entity.model_name), + str(embedding_entity.base_url), + str(embedding_entity.api_key), + embedding_dim=LLMFactory.get_embedding_dimension( + str(embedding_entity.model_name), str(embedding_entity.base_url), str(embedding_entity.api_key) # type: ignore + ), + ) + + kb_working_dir = os.path.join(DEFAULT_WORKING_DIR, str(knowledge_base.name)) # type: ignore + return await initialize_rag(llm_callable, embedding_callable, kb_working_dir) + + async def _process_single_graph_file( + self, + db: AsyncSession, + rag_file: RagFile, + rag_instance, + load_documents, + ) -> None: + """处理单个 GRAPH 类型文件""" + file_repo = RagFileRepository(db) + + try: + await self._update_status(db, file_repo, str(rag_file.id), FileStatus.PROCESSING, 10) # type: ignore + await db.commit() + + dataset_file = await self._get_dataset_file(db, str(rag_file.file_id)) # type: ignore + if not dataset_file: + await self._mark_failed(db, file_repo, str(rag_file.id), "数据集文件不存在") # type: ignore + return + + documents = load_documents(str(dataset_file.file_path)) # type: ignore + if not documents: + await self._mark_failed(db, file_repo, str(rag_file.id), "文件解析失败,未生成文档") # type: ignore + return + + await self._update_progress(db, file_repo, str(rag_file.id), 30) # type: ignore + await db.commit() + + for idx, doc in enumerate(documents): + logger.info("插入文档到知识图谱: %s, 进度: %d/%d", str(rag_file.file_name), idx + 1, len(documents)) # type: ignore + await rag_instance.ainsert(input=doc.page_content, file_paths=[str(dataset_file.file_path)]) # type: ignore + + await self._mark_success(db, file_repo, str(rag_file.id), len(documents)) # type: ignore + logger.info("文件 %s 知识图谱处理完成", str(rag_file.file_name)) + + except Exception as e: + logger.exception("文件 %s 知识图谱处理失败: %s", str(rag_file.file_name), e) # type: ignore + await self._mark_failed(db, file_repo, str(rag_file.id), str(e)) # type: ignore + + async def _get_dataset_file(self, db: AsyncSession, file_id: str): # type: ignore + """获取数据集文件""" + from sqlalchemy import select + from app.db.models.dataset_management import DatasetFiles + + + result = await db.execute( + select(DatasetFiles).where(DatasetFiles.id == file_id) + ) + return result.scalar_one_or_none() + async def _process_single_file( self, db: AsyncSession, @@ -592,7 +701,6 @@ def _clean_text(self, text: str) -> str: Returns: 清理后的文本,如果无效则返回空字符串 """ - import re if not text or not isinstance(text, str): return "" diff --git a/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py b/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py index 27b1eff9..ef2a2625 100644 --- a/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py +++ b/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py @@ -216,27 +216,24 @@ async def add_files( Returns: 包含成功和跳过文件数量的字典 """ - # 1. 验证知识库存在 knowledge_base = await self.kb_repo.get_by_id(request.knowledge_base_id) if not knowledge_base: raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) - # 2. 创建文件记录 rag_files, skipped_file_ids = await self._create_rag_files(request) - # 3. 立即提交事务,接口返回 await self.db.commit() - # 4. 注册后台任务(异步处理) if rag_files and background_tasks: + kb_type = knowledge_base.type if knowledge_base.type else "DOCUMENT" self.file_processor.start_background_processing( background_tasks=background_tasks, - knowledge_base_id=knowledge_base.id, - knowledge_base_name=knowledge_base.name, + knowledge_base_id=str(knowledge_base.id), + knowledge_base_name=str(knowledge_base.name), + knowledge_base_type=str(kb_type), request_data=request.model_dump(), ) - # 5. 返回结果 return { "success_count": len(rag_files), "skipped_count": len(skipped_file_ids), diff --git a/runtime/datamate-python/app/module/rag/service/rag_service.py b/runtime/datamate-python/app/module/rag/service/rag_service.py index 2773f66a..423f3753 100644 --- a/runtime/datamate-python/app/module/rag/service/rag_service.py +++ b/runtime/datamate-python/app/module/rag/service/rag_service.py @@ -1,6 +1,5 @@ import os -import asyncio -from typing import Optional, Sequence +from typing import Optional from fastapi import Depends from langchain_core.prompts import ChatPromptTemplate @@ -9,24 +8,20 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core.logging import get_logger -from app.db.models.dataset_management import DatasetFiles -from app.db.models.knowledge_gen import KnowledgeBase, RagFile, FileStatus -from app.db.session import get_db, AsyncSessionLocal +from app.db.models.knowledge_gen import KnowledgeBase +from app.db.session import get_db from app.module.rag.infra.embeddings import EmbeddingFactory from app.module.rag.infra.vectorstore import VectorStoreFactory -from app.module.shared.common.document_loaders import load_documents from .graph_rag import ( DEFAULT_WORKING_DIR, build_embedding_func, build_llm_model_func, initialize_rag, ) -from app.module.shared.llm import LLMFactory from ...system.service.common_service import get_model_by_id logger = get_logger(__name__) -# DOCUMENT 类型 RAG 使用 LangChain 检索链 RAG_DOCUMENT_PROMPT = ChatPromptTemplate.from_messages([ ("system", "根据以下上下文回答问题。如果上下文中没有相关信息,请说明。\n\n上下文:\n{context}"), ("human", "{input}"), @@ -34,25 +29,12 @@ class RAGService: - def __init__( - self, - db: AsyncSession = Depends(get_db), - - ): + def __init__(self, db: AsyncSession = Depends(get_db)): self.db = db - self.background_tasks = None self.rag = None - async def get_unprocessed_files(self, knowledge_base_id: str) -> Sequence[RagFile]: - result = await self.db.execute( - select(RagFile).where( - RagFile.knowledge_base_id == knowledge_base_id, - RagFile.status != FileStatus.PROCESSED, - ) - ) - return result.scalars().all() - async def init_graph_rag(self, knowledge_base_id: str): + """初始化知识图谱 RAG 实例(用于查询)""" kb = await self._get_knowledge_base(knowledge_base_id) embedding_model = await self._get_models(kb.embedding_model) chat_model = await self._get_models(kb.chat_model) @@ -60,6 +42,7 @@ async def init_graph_rag(self, knowledge_base_id: str): llm_callable = await build_llm_model_func( chat_model.model_name, chat_model.base_url, chat_model.api_key ) + from app.module.shared.llm import LLMFactory embedding_callable = await build_embedding_func( embedding_model.model_name, embedding_model.base_url, @@ -72,61 +55,8 @@ async def init_graph_rag(self, knowledge_base_id: str): kb_working_dir = os.path.join(DEFAULT_WORKING_DIR, kb.name) self.rag = await initialize_rag(llm_callable, embedding_callable, kb_working_dir) - await self._schedule_file_processing(knowledge_base_id) - return {"status": "initialized", "knowledge_base_id": knowledge_base_id} - async def _schedule_file_processing(self, knowledge_base_id: str): - if self.background_tasks is not None: - self.background_tasks.add_task(self._process_with_fresh_session, knowledge_base_id, self.rag) - else: - asyncio.create_task(self._process_with_fresh_session(knowledge_base_id, self.rag)) - - @staticmethod - async def _process_with_fresh_session(knowledge_base_id: str, rag_instance): - async with AsyncSessionLocal() as session: - service = RAGService(session) - service.rag = rag_instance - await service._process_pending_files(knowledge_base_id) - - async def _process_pending_files(self, knowledge_base_id: str): - rag_files = await self.get_unprocessed_files(knowledge_base_id) - if not rag_files: - logger.info(f"No pending files to process for knowledge base {knowledge_base_id}") - return - - for rag_file in rag_files: - await self._process_single_file(rag_file) - - async def _process_single_file(self, rag_file: RagFile): - try: - await self._mark_file_status(rag_file, FileStatus.PROCESSING) - dataset_file = await self._get_dataset_file(rag_file.file_id) - documents = load_documents(dataset_file.file_path) - for doc in documents: - logger.info(f"Processing document {doc.page_content}") - await self.rag.ainsert(input=doc.page_content, file_paths=[dataset_file.file_path]) - except Exception: # noqa: BLE001 - logger.exception("Failed to process rag file %s", rag_file.id) - await self._mark_file_status(rag_file, FileStatus.PROCESS_FAILED) - return - await self._mark_file_status(rag_file, FileStatus.PROCESSED) - - async def _get_dataset_file(self, file_id: str) -> DatasetFiles: - result = await self.db.execute( - select(DatasetFiles).where(DatasetFiles.id == file_id) - ) - dataset_file = result.scalars().first() - if not dataset_file: - raise ValueError(f"Dataset file with ID {file_id} not found.") - return dataset_file - - async def _mark_file_status(self, rag_file: RagFile, status: FileStatus): - rag_file.status = status - self.db.add(rag_file) - await self.db.commit() - await self.db.refresh(rag_file) - async def _get_knowledge_base(self, knowledge_base_id: str): result = await self.db.execute( select(KnowledgeBase).where(KnowledgeBase.id == knowledge_base_id)