模型架构

RhoFold+是一个全可微分的端到端框架,从RNA序列输入直接输出全原子3D结构。架构分为几个关键模块:

  1. 输入处理和特征提取:

    • RNA-FM:一个12层Transformer模型,预训练于~23.7百万未标注RNA序列(来自~80万种物种)。它从输入序列中提取嵌入,捕捉进化信息和结构模式,而无需依赖MSA。

      1
      2
      3
      4
      5
      6
      RNA-FM
      └─ Transformer layer × 12
      ├─ Multi-head self-attention // 多头注意力
      ├─ Add and norm // 第一次残差+归一化
      ├─ Feed-forward network // 小的MLP
      └─ Add and norm // 第二次残差+归一化
    • MSA生成:并行搜索大型序列数据库(nt和Rfam)生成MSA特征,作为补充输入。

    • 输入:核苷酸序列的初始embedding

    • 输出:contextual embedding

  2. Rhoformer + Structure Module:

    • Rhoformer (12 layers):使用attention更新embedding
    • Structure Module (8 layers):使用IPA等机制生成全原子坐标
    • Recycling × 10

    IPA(Invariant Point Attention)
    在 Structure Module 中,IPA 用于将残基嵌入映射到三维坐标,同时保持几何不变性。它将每个残基表示映射为若干“点”,再对这些点使用注意力机制更新嵌入和坐标。核心特点是:

    • 保持旋转、平移不变性
    • 学习残基间空间相互作用
    • 支持全原子坐标预测

model

训练和优化

  • 数据集:从PDB中提取5,583条RNA链(分辨率<4.0 Å,长度16-256核苷酸),聚类后得到782个非冗余簇(80%序列相似度阈值)。

  • 损失函数:

    L_mlm:掩码语言模型损失:让模型学会 RNA 序列本身的上下文规律,即使没有 MSA 也能理解序列

    L_dis:距离图损失:监督关键原子对(如 P–C4’)的距离分布,提供全局几何约束

    L_ss:二级结构分类损失:让每个碱基正确预测是 helix、loop 还是 coil

    L_clash:原子冲突惩罚:防止预测结构里原子重叠

    L_FAPE:帧对齐点误差(权重最高 2.0):预测坐标与真实坐标的几何误差

    L_ss3d:3D 感知二级结构损失:要求预测的碱基配对在三维空间里靠得近、方向对

    L_plddt:置信度回归损失:让输出的 pLDDT 分数真实反映局部预测误差,越准分数越高

  • 训练策略:

    • 优化器:Adam(基础学习率 3e-4)
    • 学习率调度:10,000 步 warm-up → 之后 polynomial decay(多项式衰减)
    • 训练步数:300,000 迭代(相当于 1,600 epochs)
    • Batch size:16
    • Dropout:Rhoformer 和 Structure Module 均使用 0.1

硬件要求

训练(Training)

  • GPU:8 × NVIDIA A100 (80GB)
  • CPU:Intel Xeon Gold 6230(64 核)
  • 内存:768 GB
  • 总训练时长:≈ 7 天
  • 模型参数规模:126,913,743

推理(Inference)

  • 1 张 A100:0.14 秒出结构
  • 5070Ti (16 GB) :3.3 秒出结构