MART:基于多尺度关系Transformer网络的多智能体轨迹预测
多智能体轨迹预测大多依赖于GNN,图transformer,超图神经网络,基于超图的还有待研究,因此提出一个多尺度的交互transformer网络(MART),可以在transformer中考虑个体和群体的行为。
2024 ECCV
1、介绍
1.1、主要贡献
多智能体轨迹预测大多依赖于GNN,图transformer,超图神经网络,基于超图的还有待研究,因此提出一个多尺度的交互transformer网络(MART),可以在transformer中考虑个体和群体的行为
为了解决超图Transformer架构中的社会关系问题,引入编码器MARTE,其解码器包括一个对交互transformer(PRT)和一个超交互transformer(HRT),该编码器通过引入HRT扩展了PRT的功能,HRT将超边特征集成到tansformer机制中,促进注意力权重专注于群智能的内部关系
在GroupNet中,群体中的智能体数需要手动定义,并且随着群尺度数量的增加,需要更多的编码器,从而导致更高的计算成本,为了解决上述限制,提出了自适应群体估计器 (AGE),旨在推断现实环境中复杂的群体关系,利用自适应阈值将高度相关的智能体纳入同一组,可以估计重叠的群体关系,认识到智能体可以与多个群体相关联,不需要手动定义智能体数量
轨迹预测主要集中于GNN,和transformer
如EqMotion通过GCN来解释场景中智能体的交互关系
在考虑个体和群体时,GroupNet是一个多尺度超图信息传递网络,DynGroupNet可以捕获配对和群体尺度上的时变交互,但transformer机制并没有完全应用
在NBA数据集实验中,MART优于EqMotion
1.2、关系transformer(RT)
RT将边向量作为基本元件,类似于信息传递神经网络,RT包含两个过程:节点更新和边更新
节点更新中,RT引入关系注意力(RA),改变原有的Q,K,V的生成,将边向量合并到注意力的计算中

节点更新的其余部分与transformer机制相同
边更新为

将该边的两个更新节点和边聚合得到信息mij,Fe为transformer的更新公式

2、模型

首先特诊提取得到节点特征,成对边特征,和超边特征,其中,成对边特征是双向的

然后分别利用L层PRT编码成对边特征,HRT编码超边特征,分别得到成对尺度下和群尺度下的节点特征

将这两种尺度下的节点特征和初始节点特征连接,,输入给多头解码器,每个解码器即一个三层MLP,最后得到k个预测轨迹
实验中,利用k个预测中L2范数的最小值进行反向传播,求的平均值即为损失值

2.1、特征提取
首先将过去的历史轨迹通过MLP得到节点特征,来表示节点特征和边特征
其次分别得到成对边特征和超边特征
将两个相邻节点特征连接通过MLP得到成对边特征
由于边关系在现实环境中通常没有明确定义,因此通过聚合相邻节点的特征来表示超边特征,首先将对应超边的节点特征求平均值,然后通过一个MLP得到超边特征
2.2、AGE
AGE通过自适应阈值来有效推断群体关系,从而将高相关的智能体分组,通过一个超图来计算得到群事件矩阵G,Gij当为1时,表示第i个智能体属于第j个超边
首先计算相关矩阵,反映两个智能体之间的相关度
利用单位阶跃函数,当相关度大于一个可学习阈值时,群事件值为1,否则为0

由于单位阶跃函数不可微分,因此利用一个STE技巧,估计该函数的梯度来实现反向传播,通过公式8来推理得到公式7


2.3、MARTE
引入MARTE通过超图transformer架构来提取多尺度的社会交互
PRT和HRT分别解决个人和群体行为

PRT就是一个RT的结构
HRT则是RT的一个扩展,引入了一个超关系注意力(HRA),此外还引入了一个信息函数来更新超边
节点更新,通过聚合平均智能体对应的超边特征,将聚合信息集成到Q,K,V中,其中归属关系来源于群事件矩阵G

超边更新,通过将超边特征,和相应更新节点特征的平均值进行处理得到信息

再通过一个基于前馈网络的边更新函数更新超边特征

3、实验
在三大数据集上优于DynGroupNet、LED、EqMotion等方法
没有考虑未来轨迹的时空交互
缺乏多模态轨迹
更多推荐


所有评论(0)