这是用于管理从患者症状输入到最终治疗建议的完整流程。它通过状态图(StateGraph)组织多个智能体(Agent)的调用,并利用数据库记录状态、通过RabbitMQ处理异步化验结果。下面将分模块详细解释:


1. MedicalWorkflowState 类

该类表示工作流中每一步的状态数据,实现了状态的持久化和恢复。

属性

def __init__(self):
    self.inquiry_record_id: Optional[int] = None
    self.patient_id: Optional[int] = None
    self.doctor_id: Optional[int] = None
    self.symptom_description: Optional[str] = None
    self.symptom_analysis_result: Optional[Dict[str, Any]] = None
    self.lab_tests_needed: List[str] = []
    self.lab_results: List[Dict[str, Any]] = []
    self.diagnosis_result: Optional[Dict[str, Any]] = None
    self.treatment_advice: Optional[Dict[str, Any]] = None
    self.current_status: InquiryStatus = InquiryStatus.INITIAL
    self.error_message: Optional[str] = None
    self.workflow_completed: bool = False
  • inquiry_record_id:数据库中的问询记录主键。

  • patient_id / doctor_id:患者和医生的ID。

  • symptom_description:原始症状描述。

  • symptom_analysis_result:症状分析智能体返回的结构化结果。

  • lab_tests_needed:症状分析后建议的化验项目列表。

  • lab_results:从化验系统返回的实际结果列表(每个结果是字典)。

  • diagnosis_result:诊断分析结果。

  • treatment_advice:治疗建议结果。

  • current_status:当前工作流所处的阶段(枚举值,如INITIALSYMPTOM_ANALYSISWAITING_LAB等)。

  • error_message:出错时的错误信息。

  • workflow_completed:标志整个工作流是否完成。

方法

  • to_dict():将对象转换为字典,便于序列化存储(例如保存到检查点)。

  • from_dict(cls, data):类方法,从字典恢复状态对象。


2. MedicalWorkflow 类

核心工作流引擎,负责构建LangGraph图、定义节点函数、启动和恢复工作流。

初始化

def __init__(self):
    self.memory = MemorySaver()
    self.graph = self._build_workflow_graph()
  • memory:使用MemorySaver在内存中保存每个执行步骤的状态,使得工作流可以中断和恢复(如等待化验结果时)。

  • graph:调用_build_workflow_graph()构建并编译状态图。

构建工作流图 (_build_workflow_graph)

workflow = StateGraph(MedicalWorkflowState)

创建一个状态图,状态类型为MedicalWorkflowState

添加节点

每个节点是一个函数,接收当前状态,执行某些操作,返回更新后的状态。

  • "create_inquiry" → _create_inquiry_record

  • "symptom_analysis" → _perform_symptom_analysis

  • "check_lab_needed" → _check_lab_tests_needed

  • "wait_for_lab" → _wait_for_lab_results

  • "diagnosis_analysis" → _perform_diagnosis_analysis

  • "treatment_advice" → _perform_treatment_advice

  • "complete_workflow" → _complete_workflow

  • "handle_error" → _handle_error

设置入口点和边
workflow.set_entry_point("create_inquiry")

工作流从create_inquiry节点开始。

workflow.add_edge("create_inquiry", "symptom_analysis")
workflow.add_edge("symptom_analysis", "check_lab_needed")

顺序连接:创建记录后立即进行症状分析,然后检查是否需要化验。

workflow.add_conditional_edges(
    "check_lab_needed",
    self._should_wait_for_lab,
    {
        "wait": "wait_for_lab",
        "continue": "diagnosis_analysis"
    }
)

条件分支:根据_should_wait_for_lab的返回值决定下一步是等待化验还是直接进入诊断分析。

workflow.add_edge("wait_for_lab", "diagnosis_analysis")
workflow.add_edge("diagnosis_analysis", "treatment_advice")
workflow.add_edge("treatment_advice", "complete_workflow")
workflow.add_edge("complete_workflow", END)

正常流程的后续边。

workflow.add_edge("handle_error", END)

如果任何节点出错,可跳转到错误处理节点并终止。

最后,编译图:

return workflow.compile(checkpointer=self.memory)

checkpointer使得图在执行时可以保存和恢复状态。


节点函数详解

每个节点函数接收一个MedicalWorkflowState对象,返回更新后的对象。

1. _create_inquiry_record
  • 获取数据库会话。

  • 调用InquiryRecordDAO.create_inquiry_record插入一条新的问询记录,状态为初始。

  • 将生成的inquiry_record_id保存到状态中,并将current_status更新为SYMPTOM_ANALYSIS

  • 同时更新数据库中的记录状态为SYMPTOM_ANALYSIS

  • 如果出现异常,记录错误并设置状态为FAILED
     

    def _create_inquiry_record(self, state: MedicalWorkflowState) -> MedicalWorkflowState:
            """创建问询记录"""
            try:
                session = next(get_db_session())
                
                # 创建问询记录
                inquiry_record = InquiryRecordDAO.create_inquiry_record(
                    session=session,
                    patient_id=state.patient_id,
                    doctor_id=state.doctor_id
                )
                
                state.inquiry_record_id = inquiry_record.id
                state.current_status = InquiryStatus.SYMPTOM_ANALYSIS
                
                # 更新问询记录状态
                InquiryRecordDAO.update_inquiry_status(
                    session=session,
                    inquiry_id=inquiry_record.id,
                    status=InquiryStatus.SYMPTOM_ANALYSIS
                )
                
                logger.info(f"创建问询记录成功,ID: {inquiry_record.id}")
                
            except Exception as e:
                logger.error(f"创建问询记录失败: {e}")
                state.error_message = str(e)
                state.current_status = InquiryStatus.FAILED
            
            finally:
                session.close()
            
            return state

2. _perform_symptom_analysis
  • 构建SymptomAnalysisRequest对象,包含患者ID、医生ID和症状描述。

  • 调用symptom_analysis_agent.analyze_symptoms(request)进行症状分析。该智能体返回一个字典,包含successanalysis_resultlab_tests_needed

  • 如果成功,将结果存入状态,并调用symptom_analysis_agent.save_analysis_result将结果保存到数据库(通过SymptomAnalysisResult模型)。

  • 失败则抛出异常,进入错误处理。

     

        def _perform_symptom_analysis(self, state: MedicalWorkflowState) -> MedicalWorkflowState:
            """执行症状分析"""
            try:
                # 构建症状分析请求
                request = SymptomAnalysisRequest(
                    patient_id=state.patient_id,
                    doctor_id=state.doctor_id,
                    symptom_description=state.symptom_description
                )
                
                # 调用症状分析智能体
                result = symptom_analysis_agent.analyze_symptoms(request)
                
                if result["success"]:
                    state.symptom_analysis_result = result["analysis_result"]
                    state.lab_tests_needed = result["lab_tests_needed"]
                    
                    # 保存症状分析结果
                    analysis_result = SymptomAnalysisResult(**result["analysis_result"])
                    symptom_analysis_agent.save_analysis_result(
                        inquiry_record_id=state.inquiry_record_id,
                        patient_id=state.patient_id,
                        analysis_result=analysis_result
                    )
                    
                    logger.info(f"症状分析完成,问询记录ID: {state.inquiry_record_id}")
                else:
                    raise Exception(result["error"])
                    
            except Exception as e:
                logger.error(f"症状分析失败: {e}")
                state.error_message = str(e)
                state.current_status = InquiryStatus.FAILED
            

3. _check_lab_tests_needed
  • 根据state.lab_tests_needed是否有值更新current_status

    • 有化验项 → WAITING_LAB

    • 无化验项 → DIAGNOSIS_ANALYSIS

  • 返回状态(不执行实际逻辑,仅为状态转换做准备)。

     

      def _check_lab_tests_needed(self, state: MedicalWorkflowState) -> MedicalWorkflowState:
            """检查是否需要化验检查"""
            if state.lab_tests_needed:
                state.current_status = InquiryStatus.WAITING_LAB
                logger.info(f"需要化验检查: {state.lab_tests_needed}")
            else:
                state.current_status = InquiryStatus.DIAGNOSIS_ANALYSIS
                logger.info("无需化验检查,直接进入诊断分析")
            
            return state

4. _should_wait_for_lab

条件判断函数,返回字符串,用于决定分支走向:

  • 如果lab_tests_needed非空且lab_results为空,返回"wait"

  • 否则返回"continue"
     

        def _should_wait_for_lab(self, state: MedicalWorkflowState) -> str:
            """判断是否需要等待化验结果"""
            if state.lab_tests_needed and not state.lab_results:
                return "wait"
            return "continue"

5. _wait_for_lab_results
  • 此节点目前仅打印日志,表示正在等待化验结果。实际等待是通过外部异步机制(RabbitMQ)触发工作流继续。

  • 工作流在此节点会中断(因为图执行到这里后,没有自动前进的边),直到外部事件调用handle_lab_results方法恢复执行。
     

     def _wait_for_lab_results(self, state: MedicalWorkflowState) -> MedicalWorkflowState:
            """等待化验结果"""
            # 这里应该订阅RabbitMQ消息队列,等待化验结果
            # 在实际实现中,这个节点会被异步触发
            logger.info(f"等待化验结果,问询记录ID: {state.inquiry_record_id}")
            return state

6. _perform_diagnosis_analysis
  • 构建DiagnosisAnalysisRequest,包含问询记录ID、患者ID、症状分析结果(JSON序列化)和已有化验结果。

  • 调用diagnosis_analysis_agent.analyze_diagnosis进行诊断。

  • 成功则将结果存入状态,并保存到数据库,同时更新状态为TREATMENT_ADVICE
     

      def _perform_diagnosis_analysis(self, state: MedicalWorkflowState) -> MedicalWorkflowState:
            """执行诊断分析"""
            try:
                # 构建诊断分析请求
                request = DiagnosisAnalysisRequest(
                    inquiry_record_id=state.inquiry_record_id,
                    patient_id=state.patient_id,
                    symptom_analysis_result=json.dumps(state.symptom_analysis_result),
                    lab_results=state.lab_results
                )
                
                # 调用诊断分析智能体
                result = diagnosis_analysis_agent.analyze_diagnosis(request)
                
                if result["success"]:
                    state.diagnosis_result = result["analysis_result"]
                    
                    # 保存诊断分析结果
                    analysis_result = DiagnosisAnalysisResult(**result["analysis_result"])
                    diagnosis_analysis_agent.save_analysis_result(
                        inquiry_record_id=state.inquiry_record_id,
                        patient_id=state.patient_id,
                        analysis_result=analysis_result
                    )
                    
                    state.current_status = InquiryStatus.TREATMENT_ADVICE
                    logger.info(f"诊断分析完成,问询记录ID: {state.inquiry_record_id}")
                else:
                    raise Exception(result["error"])
                    
            except Exception as e:
                logger.error(f"诊断分析失败: {e}")
                state.error_message = str(e)
                state.current_status = InquiryStatus.FAILED
            
            return state
        

7. _perform_treatment_advice
  • 构建TreatmentAdviceRequest,包含问询记录ID、患者ID、诊断结果(JSON序列化),过敏史从患者档案获取(此处硬编码为空列表)。

  • 调用treatment_advice_agent.suggest_treatment获取治疗建议。

  • 成功则保存结果,更新状态为COMPLETED

 def _perform_treatment_advice(self, state: MedicalWorkflowState) -> MedicalWorkflowState:
        """执行治疗建议"""
        try:
            # 构建治疗建议请求
            request = TreatmentAdviceRequest(
                inquiry_record_id=state.inquiry_record_id,
                patient_id=state.patient_id,
                diagnosis_result=json.dumps(state.diagnosis_result),
                allergy_history=[]  # 这里应该从患者档案获取过敏史
            )
            
            # 调用治疗建议智能体
            result = treatment_advice_agent.suggest_treatment(request)
            
            if result["success"]:
                state.treatment_advice = result["advice_result"]
                
                # 保存治疗建议结果
                advice_result = TreatmentAdviceResult(**result["advice_result"])
                treatment_advice_agent.save_advice_result(
                    inquiry_record_id=state.inquiry_record_id,
                    patient_id=state.patient_id,
                    advice_result=advice_result
                )
                
                state.current_status = InquiryStatus.COMPLETED
                logger.info(f"治疗建议完成,问询记录ID: {state.inquiry_record_id}")
            else:
                raise Exception(result["error"])
                
        except Exception as e:
            logger.error(f"治疗建议失败: {e}")
            state.error_message = str(e)
            state.current_status = InquiryStatus.FAILED
        
        return state
8. _complete_workflow
  • 更新数据库中问询记录的状态为COMPLETED,设置workflow_completed为True。

     

    def _complete_workflow(self, state: MedicalWorkflowState) -> MedicalWorkflowState:
            """完成工作流"""
            try:
                session = next(get_db_session())
                
                # 更新问询记录状态为完成
                InquiryRecordDAO.update_inquiry_status(
                    session=session,
                    inquiry_id=state.inquiry_record_id,
                    status=InquiryStatus.COMPLETED
                )
                
                state.workflow_completed = True
                logger.info(f"工作流完成,问询记录ID: {state.inquiry_record_id}")
                
            except Exception as e:
                logger.error(f"完成工作流失败: {e}")
                state.error_message = str(e)
            
            finally:
                session.close()
            
            return state

9. _handle_error
  • 当任何节点抛出异常时,可通过条件边进入此节点(本代码未显式添加错误边,但可以在外部通过try/except捕获后手动触发)。它更新数据库记录状态为FAILED,并记录失败原因。

     

    def _handle_error(self, state: MedicalWorkflowState) -> MedicalWorkflowState:
            """处理错误"""
            try:
                session = next(get_db_session())
                
                # 更新问询记录状态为失败
                InquiryRecordDAO.update_inquiry_status(
                    session=session,
                    inquiry_id=state.inquiry_record_id,
                    status=InquiryStatus.FAILED,
                    failed_reason=1  # 可以根据错误类型设置不同的失败原因
                )
                
                logger.error(f"工作流失败,问询记录ID: {state.inquiry_record_id}, 错误: {state.error_message}")
                
            except Exception as e:
                logger.error(f"处理错误失败: {e}")
            
            finally:
                session.close()
            
            return state

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐