1 change: 1 addition & 0 deletions deployment/docker/datamate/docker-compose.yml
@@ -30,6 +30,7 @@ services:
- log_level=DEBUG
- pgsql_password=${DB_PASSWORD:-password}
- datamate_jwt_enable=${DATAMATE_JWT_ENABLE:-false}
- milvus_uri=${MILVUS_URI:-http://milvus:19530}
volumes:
- dataset_volume:/dataset
- flow_volume:/flow
2 changes: 2 additions & 0 deletions deployment/helm/datamate/values.yaml
@@ -140,6 +140,8 @@ backend-python:
key: DB_PASSWORD
- name: datamate_jwt_enable
value: *DATAMATE_JWT_ENABLE
- name: milvus_uri
value: "http://milvus:19530"
volumes:
- *datasetVolume
- *flowVolume
7 changes: 7 additions & 0 deletions runtime/datamate-python/.env.example
@@ -20,3 +20,10 @@ LABEL_STUDIO_BASE_URL=http://localhost:30001
LABEL_STUDIO_USER_TOKEN="demo_dev_token"

DATAMATE_JWT_ENABLE=false

# Milvus settings (Vector Database)
# Development: use localhost
MILVUS_URI=http://localhost:19530
MILVUS_TOKEN=
# Production: use service name
# MILVUS_URI=http://milvus:19530
2 changes: 1 addition & 1 deletion runtime/datamate-python/app/core/config.py
@@ -78,7 +78,7 @@ def build_database_url(self):
datamate_jwt_enable: bool = False

# Milvus configuration
milvus_uri: str = "http://localhost:19530"
milvus_uri: str = "http://milvus:19530"
milvus_token: str = ""

# File storage configuration (shared filesystem)
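The new default points the Python runtime at the in-cluster Milvus service, while the MILVUS_URI / MILVUS_TOKEN variables added in docker-compose, the Helm values and .env override it per environment. A minimal sketch of how these two settings might be consumed, assuming pydantic-settings for env loading and pymilvus's MilvusClient; neither assumption is confirmed by the diff:

# Hedged sketch: the field names mirror the diff; pydantic-settings and
# pymilvus usage are assumptions about the surrounding project.
from pydantic_settings import BaseSettings
from pymilvus import MilvusClient


class Settings(BaseSettings):
    # MILVUS_URI / MILVUS_TOKEN from docker-compose, Helm values or .env
    # override these defaults (env var matching is case-insensitive).
    milvus_uri: str = "http://milvus:19530"
    milvus_token: str = ""


settings = Settings()


def get_milvus_client() -> MilvusClient:
    # An empty token works for an unauthenticated local Milvus instance.
    return MilvusClient(uri=settings.milvus_uri, token=settings.milvus_token)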
52 changes: 38 additions & 14 deletions runtime/datamate-python/app/module/rag/interface/rag_interface.py
@@ -1,34 +1,58 @@
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exception import SuccessResponse
from app.db.session import get_db
from app.module.rag.service.rag_service import RAGService
from app.module.rag.service.knowledge_base_service import KnowledgeBaseService
from ..schema.request import QueryRequest

router = APIRouter(prefix="/rag", tags=["Knowledge Graph RAG"])

@router.post("/{knowledge_base_id}/process")
async def process_knowledge_base(knowledge_base_id: str, rag_service: RAGService = Depends()):
async def process_knowledge_base(
knowledge_base_id: str,
background_tasks: BackgroundTasks,
db: AsyncSession = Depends(get_db),
):
"""
Process all unprocessed files in the knowledge base (LightRAG)

Endpoint path changes:
- Old path: /rag/process/{id}
- New path: /rag/graph/{id}/process
Manually trigger processing of knowledge base files (deprecated; processing is triggered automatically when files are added)

This endpoint is kept for backward compatibility and manual reprocessing of files
"""
await rag_service.init_graph_rag(knowledge_base_id)
service = KnowledgeBaseService(db)
kb = await service.kb_repo.get_by_id(knowledge_base_id)

if not kb:
return SuccessResponse(
data=None,
message="Knowledge base not found"
)

files = await service.file_repo.get_unprocessed_files(knowledge_base_id)
if not files:
return SuccessResponse(
data=None,
message="No files to process"
)

service.file_processor.start_background_processing(
background_tasks=background_tasks,
knowledge_base_id=str(kb.id),
knowledge_base_name=str(kb.name),
knowledge_base_type=str(kb.type or "DOCUMENT"),
request_data={"knowledge_base_id": knowledge_base_id, "files": []},
)

return SuccessResponse(
data=None,
message="Processing started for knowledge base."
message=f"Started processing {len(files)} files"
)

@router.post("/query")
async def query_knowledge_graph(payload: QueryRequest, rag_service: RAGService = Depends()):
"""
Query the knowledge graph (LightRAG) with the given query text and knowledge base ID

Endpoint path changes:
- Old path: /rag/query
- New path: /rag/graph/query
Query the knowledge graph (LightRAG) or vector index with the given query text and knowledge base ID
"""
result = await rag_service.query_rag(payload.query, payload.knowledge_base_id)
return SuccessResponse(data=result)
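The process endpoint is now a thin wrapper that looks up the knowledge base, fetches its unprocessed files and hands them to the background processor, while the query endpoint is unchanged apart from its docstring. A hypothetical client-side usage sketch of the two routes; the base URL, port and knowledge base ID are placeholders, and any extra API prefix the app mounts is omitted:

# Hypothetical usage of the endpoints above; BASE_URL and KB_ID are placeholders.
import httpx

BASE_URL = "http://localhost:8000"  # assumed service address
KB_ID = "my-kb-id"                  # placeholder knowledge base ID

with httpx.Client(base_url=BASE_URL) as client:
    # Manually re-trigger processing of unprocessed files (normally processing
    # starts automatically when files are added to the knowledge base).
    resp = client.post(f"/rag/{KB_ID}/process")
    print(resp.json())  # e.g. {"data": None, "message": "Started processing 2 files"}

    # Query the knowledge base; GRAPH bases go through LightRAG,
    # DOCUMENT bases through vector retrieval.
    resp = client.post(
        "/rag/query",
        json={"query": "What is DataMate?", "knowledge_base_id": KB_ID},
    )
    print(resp.json()["data"])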
126 changes: 117 additions & 9 deletions runtime/datamate-python/app/module/rag/service/file_processor.py
@@ -2,6 +2,7 @@
File processor

Handles background ETL processing of files: loading, chunking, vectorization, storage.
Supports two knowledge base types: DOCUMENT (vector retrieval) and GRAPH (knowledge graph).
Uses the global WorkerPool for concurrency control, with at most 10 files processed in parallel.
"""
import logging
@@ -13,8 +14,7 @@
from fastapi import BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exception import BusinessError, ErrorCodes
from app.db.models.knowledge_gen import KnowledgeBase, RagFile, FileStatus
from app.db.models.knowledge_gen import KnowledgeBase, RagFile, FileStatus, RagType
from app.db.session import AsyncSessionLocal
from app.module.rag.infra.document import ingest_file_to_chunks, DocumentChunk
from app.module.rag.infra.embeddings import EmbeddingFactory
@@ -42,6 +42,7 @@ def start_background_processing(
background_tasks: BackgroundTasks,
knowledge_base_id: str,
knowledge_base_name: str,
knowledge_base_type: str,
request_data: dict,
) -> None:
"""Start background file processing
@@ -50,20 +51,23 @@
background_tasks: FastAPI BackgroundTasks
knowledge_base_id: Knowledge base ID
knowledge_base_name: Knowledge base name
knowledge_base_type: Knowledge base type (DOCUMENT/GRAPH)
request_data: Add-files request data
"""
background_tasks.add_task(
self._process_files_background,
knowledge_base_id,
knowledge_base_name,
knowledge_base_type,
request_data,
)
logger.info("Registered background task: knowledge base=%s", knowledge_base_name)
logger.info("Registered background task: knowledge base=%s, type=%s", knowledge_base_name, knowledge_base_type)

async def _process_files_background(
self,
knowledge_base_id: str,
knowledge_base_name: str,
knowledge_base_type: str,
request_data: dict,
) -> None:
"""Process files in the background (using a new database session)"""
@@ -84,10 +88,12 @@ async def _process_files_background(
logger.info("Knowledge base %s has no files to process", knowledge_base_name)
return

logger.info("Starting to process %d files, knowledge base: %s", len(files), knowledge_base_name)
logger.info("Starting to process %d files, knowledge base: %s, type: %s", len(files), knowledge_base_name, knowledge_base_type)

# Process files concurrently (at most 10 in parallel)
await self._process_files_concurrently(db, files, knowledge_base, request)
if knowledge_base_type == RagType.GRAPH.value:
await self._process_graph_files(db, files, knowledge_base)
else:
await self._process_document_files(db, files, knowledge_base, request)

logger.info("Finished processing files for knowledge base %s", knowledge_base_name)

@@ -96,14 +102,14 @@ async def _process_files_background(
finally:
await db.close()

async def _process_files_concurrently(
async def _process_document_files(
self,
db: AsyncSession,
files: List[RagFile],
knowledge_base: KnowledgeBase,
request: AddFilesReq,
) -> None:
"""Process multiple files concurrently (at most 10 in parallel)"""
"""Process DOCUMENT-type files (vectorization)"""
import asyncio

async def process_with_semaphore(rag_file: RagFile):
@@ -117,6 +123,109 @@ async def process_with_semaphore(rag_file: RagFile):
tasks = [process_with_semaphore(f) for f in files]
await asyncio.gather(*tasks, return_exceptions=True)

async def _process_graph_files(
self,
db: AsyncSession,
files: List[RagFile],
knowledge_base: KnowledgeBase,
) -> None:
"""Process GRAPH-type files (knowledge graph)"""
from app.module.shared.llm import LLMFactory
from app.module.shared.common.document_loaders import load_documents

try:
rag_instance = await self._initialize_graph_rag(db, knowledge_base, LLMFactory)

for rag_file in files:
await self._process_single_graph_file(db, rag_file, rag_instance, load_documents)

except Exception as e:
logger.exception("Failed to initialize knowledge graph: %s", e)
for rag_file in files:
file_repo = RagFileRepository(db)
await self._mark_failed(db, file_repo, str(rag_file.id), f"Knowledge graph initialization failed: {str(e)}") # type: ignore

async def _initialize_graph_rag(self, db: AsyncSession, knowledge_base: KnowledgeBase, LLMFactory):
"""Initialize the GraphRAG instance"""
from .graph_rag import (
DEFAULT_WORKING_DIR,
build_embedding_func,
build_llm_model_func,
initialize_rag,
)

embedding_entity = await get_model_by_id(db, str(knowledge_base.embedding_model)) # type: ignore
if not embedding_entity:
raise ValueError(f"Embedding model not found: {knowledge_base.embedding_model}")

chat_entity = await get_model_by_id(db, str(knowledge_base.chat_model)) # type: ignore
if not chat_entity:
raise ValueError(f"Chat model not found: {knowledge_base.chat_model}")

llm_callable = await build_llm_model_func(
str(chat_entity.model_name), str(chat_entity.base_url), str(chat_entity.api_key) # type: ignore
)
embedding_callable = await build_embedding_func(
str(embedding_entity.model_name),
str(embedding_entity.base_url),
str(embedding_entity.api_key),
embedding_dim=LLMFactory.get_embedding_dimension(
str(embedding_entity.model_name), str(embedding_entity.base_url), str(embedding_entity.api_key) # type: ignore
),
)

kb_working_dir = os.path.join(DEFAULT_WORKING_DIR, str(knowledge_base.name)) # type: ignore
return await initialize_rag(llm_callable, embedding_callable, kb_working_dir)

async def _process_single_graph_file(
self,
db: AsyncSession,
rag_file: RagFile,
rag_instance,
load_documents,
) -> None:
"""Process a single GRAPH-type file"""
file_repo = RagFileRepository(db)

try:
await self._update_status(db, file_repo, str(rag_file.id), FileStatus.PROCESSING, 10) # type: ignore
await db.commit()

dataset_file = await self._get_dataset_file(db, str(rag_file.file_id)) # type: ignore
if not dataset_file:
await self._mark_failed(db, file_repo, str(rag_file.id), "Dataset file not found") # type: ignore
return

documents = load_documents(str(dataset_file.file_path)) # type: ignore
if not documents:
await self._mark_failed(db, file_repo, str(rag_file.id), "File parsing failed; no documents were produced") # type: ignore
return

await self._update_progress(db, file_repo, str(rag_file.id), 30) # type: ignore
await db.commit()

for idx, doc in enumerate(documents):
logger.info("Inserting document into knowledge graph: %s, progress: %d/%d", str(rag_file.file_name), idx + 1, len(documents)) # type: ignore
await rag_instance.ainsert(input=doc.page_content, file_paths=[str(dataset_file.file_path)]) # type: ignore

await self._mark_success(db, file_repo, str(rag_file.id), len(documents)) # type: ignore
logger.info("Knowledge graph processing completed for file %s", str(rag_file.file_name))

except Exception as e:
logger.exception("Knowledge graph processing failed for file %s: %s", str(rag_file.file_name), e) # type: ignore
await self._mark_failed(db, file_repo, str(rag_file.id), str(e)) # type: ignore

async def _get_dataset_file(self, db: AsyncSession, file_id: str): # type: ignore
"""Fetch the dataset file"""
from sqlalchemy import select
from app.db.models.dataset_management import DatasetFiles


result = await db.execute(
select(DatasetFiles).where(DatasetFiles.id == file_id)
)
return result.scalar_one_or_none()

async def _process_single_file(
self,
db: AsyncSession,
@@ -592,7 +701,6 @@ def _clean_text(self, text: str) -> str:
Returns:
The cleaned text, or an empty string if the input is invalid
"""
import re

if not text or not isinstance(text, str):
return ""
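file_processor.py now dispatches on the knowledge base type: GRAPH files are inserted into the LightRAG instance one at a time, while DOCUMENT files keep the bounded-concurrency path (a semaphore wrapped around asyncio.gather, as process_with_semaphore shows above). A standalone sketch of that dispatch and concurrency pattern; the limit of 10 and the type strings mirror the diff, the coroutine bodies are stand-ins only:

import asyncio

MAX_PARALLEL_FILES = 10  # mirrors the "at most 10 files in parallel" limit above


async def process_one_file(file_id: str) -> None:
    # Stand-in for load -> chunk -> embed -> store of a single file.
    await asyncio.sleep(0.1)


async def process_document_files(file_ids: list[str]) -> None:
    # Bounded concurrency: a semaphore caps how many files run at once.
    semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)

    async def with_semaphore(file_id: str) -> None:
        async with semaphore:
            await process_one_file(file_id)

    # return_exceptions=True keeps one failed file from cancelling the rest.
    await asyncio.gather(*(with_semaphore(f) for f in file_ids), return_exceptions=True)


async def process_files(kb_type: str, file_ids: list[str]) -> None:
    if kb_type == "GRAPH":
        # Graph ingestion is sequential: each document is inserted into the
        # LightRAG instance one at a time.
        for file_id in file_ids:
            await process_one_file(file_id)
    else:
        await process_document_files(file_ids)


asyncio.run(process_files("DOCUMENT", [f"file-{i}" for i in range(25)]))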
@@ -216,27 +216,24 @@ async def add_files(
Returns:
A dict containing the counts of successful and skipped files
"""
# 1. Verify the knowledge base exists
knowledge_base = await self.kb_repo.get_by_id(request.knowledge_base_id)
if not knowledge_base:
raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND)

# 2. Create the file records
rag_files, skipped_file_ids = await self._create_rag_files(request)

# 3. Commit the transaction immediately so the endpoint can return
await self.db.commit()

# 4. Register the background task (asynchronous processing)
if rag_files and background_tasks:
kb_type = knowledge_base.type if knowledge_base.type else "DOCUMENT"
self.file_processor.start_background_processing(
background_tasks=background_tasks,
knowledge_base_id=knowledge_base.id,
knowledge_base_name=knowledge_base.name,
knowledge_base_id=str(knowledge_base.id),
knowledge_base_name=str(knowledge_base.name),
knowledge_base_type=str(kb_type),
request_data=request.model_dump(),
)

# 5. Return the result
return {
"success_count": len(rag_files),
"skipped_count": len(skipped_file_ids),
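The add_files method above commits the new file records before registering the background task, so the HTTP response is not held up by ETL work and the background processor (which opens its own session) sees the committed rows. A minimal FastAPI sketch of that ordering; the route path, payload and helper names are hypothetical, not the project's actual wiring:

# Hypothetical sketch of the commit-then-schedule ordering used by add_files.
from fastapi import BackgroundTasks, FastAPI

app = FastAPI()


async def run_processing(kb_id: str) -> None:
    # The real processor opens a fresh session here (AsyncSessionLocal in the
    # diff), because the request-scoped session is closed once the response
    # has been sent.
    ...


@app.post("/knowledge-bases/{kb_id}/files")
async def add_files(kb_id: str, background_tasks: BackgroundTasks) -> dict:
    # 1. Create the file records and commit, so the rows are visible to the
    #    background worker and the endpoint can return immediately.
    #    rag_files = await create_rag_files(...); await db.commit()
    # 2. Only after the commit, register the background task; FastAPI runs it
    #    after the response has gone out.
    background_tasks.add_task(run_processing, kb_id)
    return {"success_count": 0, "skipped_count": 0}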