我之前一直好奇像下面这种顶级的自动躲避是如何形成的

然后听别人介绍了解了unity中的ML-agent训练学习,这里分享一下步骤以及大概是个什么情况

我的配置是:

unity版本:2023.2.20

ML-agents版本:3.0.0 

使用anaconda虚拟环境进行python训练

python版本:3.10.12

好像ML-agent对版本之间的兼容要求很高,如果你感兴趣想研究的话,建议和我版本一样或者和官方一样

官方链接:https://github.com/Unity-Technologies/ml-agents/blob/develop/docs/Installation.md#advanced-local-installation-for-development

里面版本及相关内容包什么的都有介绍

然后unity里面 编辑-项目设置-ML-agents 里面创建一个端口号为5004的setting Assest

我放置到我创建一个文件夹当中

ppo文件夹下面的TankAgent.yaml是对训练的配置文件

behaviors:
  TankAgent:
    trainer_type: ppo
    hyperparameters:
      batch_size: 64
      buffer_size: 2048
      learning_rate: 3.0e-4
      beta: 5.0e-4
      epsilon: 0.2
      lambd: 0.99
      num_epoch: 3
      learning_rate_schedule: linear
      beta_schedule: constant
      epsilon_schedule: linear
    network_settings:
      normalize: true
      vis_encode_type: simple
      obs_shapes:
        - [2, 4, 3]
      vector_size: 12
      hidden_units: 128
      num_layers: 2
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
    checkpoint_interval: 10000
    max_steps: 500000
    time_horizon: 64
    summary_freq: 10
    self_play:
      play_against_current_self_ratio: 0.5 
      save_steps: 5000 
      team_change: 500 

当你想要进行训练的时候,把unity和和虚拟环境进行连接

步骤1:激活环境

步骤2:进入项目的Assest文件夹地址

步骤3:对刚刚的.yaml后缀文件进行连接运行,给这次训练叫做TankAgent1 

出现下面信息是请求你按下unity的运行键

连接成功出现

此时你想要训练的物体就会动了

讲完上面的讲一下关于坦克动荡这个ai我的思路(所有配置和代码放最后)

1:给需要训练的坦克三个脚本并合理配置

2:因为需要训练非常的多数据才能完成,官方示例的:小球在一个正方形平面上(四周会掉落)向随机生成的物体位置移动,我用了5w步才实现流畅的ai移动,这个坦克动荡一开始我想做随机地图的但是后面发现一个地图就要训练非常久了,所以你可以先训练好一个地图的先试试

3:因为训练的时候是按几十倍几十倍的速度去训练的,不可能一步一步去训练,也就是不可能在训练的时候你一局一局的去训练ai,所以我弄的是ai对ai进行训练,能看情况训练的差不多了,你就可以搞一辆坦克跟其中一辆训练玩(把训练好的onnx文件挂载在ai的组件behavior parameters 的model上面就行 )

训练结束后没固定步数会生成一个onnx文件在自动生成的文件夹下面

4:我在处理让ai知道墙壁的问题的时候一开始用的是grid sensor,后面发现好像对2d不是很友好,一直出现一些问题,看别人是自己对这个组件进行重写来满足2d的,或者这个组件也能用但是我在某个地方错了没有找到,所有后面我另外一个组件 ray perception Sensor 2d来实现这个好像更简单

5:关于TankAgent脚本中如果你写的方法或者算法越好肯定时训练步数越少,里面的奖励设置要分配好,我之前奖励设置不好的时候(奖励稀疏,惩罚密集),ai会慢慢的不向奖励的方向训练而是逐渐找一个最小惩罚的行为重复执行

开始训练时这样的:

接下来贴所有配置和参数(我目前训练的ai还是傻傻的,训练数据太少和算法奖励并不是很完美,如果后面能搞得差不多会更新现在贴出来做个参考)

坦克的:

子弹完全弹性碰撞就好了

脚本:

using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;

public class TankAgent : Agent
{
    [Header("队伍与对手")]
    [Tooltip("0 或 1,用于区分自我对抗的两支队伍")]
    public int teamId;
    [Tooltip("对手坦克的 Transform")]
    public Transform opponentTransform;

    [Header("移动与射击")]
    public float moveSpeed = 5f;       // 前后速度
    public float turnSpeed = 180f;     // 旋转速度
    public GameObject bulletPrefab;    // 子弹预制体
    public Transform firePoint;        // 发射口
    public float bulletSpeed = 10f;    // 子弹初速
    public float shootCooldown = 0.5f; // 射击冷却
    public float fireRange = 2f;  

    [Header("子弹检测设置")]
    public float bulletDetectionRadius = 1.5f; // 检测子弹的半径
    public int maxDetectedBullets = 10;       // 最大检测子弹数量
    public LayerMask bulletLayer;            // 子弹所在的物理层

    private Rigidbody2D rb;
    private float lastShootTime;

    

    public override void Initialize()
    {
        rb = GetComponent<Rigidbody2D>();
        rb.freezeRotation = true;
    }

    // OnEpisodeBegin方法:ML-Agents的重写方法,每个训练回合开始时调用
    public override void OnEpisodeBegin()
    {
        
        // 重置自己与对手的位置
        Vector2[] corners = new Vector2[] { new Vector2(-7, -3), new Vector2(7, 3) };
        // 队伍0在左下,队伍1在右上
        Vector2 startPos = corners[teamId];
        Vector2 oppPos = corners[1 - teamId];
        transform.position = startPos;
        transform.rotation = Quaternion.identity;
        rb.velocity = Vector2.zero;

        opponentTransform.position = oppPos;
        opponentTransform.rotation = Quaternion.Euler(0, 0, 180);
        opponentTransform.GetComponent<Rigidbody2D>().velocity = Vector2.zero;

        lastShootTime = -shootCooldown;
        // 清理场上所有子弹
        foreach (var b in GameObject.FindGameObjectsWithTag("Bullet")) Destroy(b);
    }

    // CollectObservations方法:收集环境观察数据供AI学习
    public override void CollectObservations(VectorSensor sensor)
    {
        // 自己的局部位置与朝向
        sensor.AddObservation(transform.localPosition);
        sensor.AddObservation(transform.up);

        // 对手的位置与朝向
        sensor.AddObservation(opponentTransform.localPosition);
        sensor.AddObservation(opponentTransform.up);

        DetectNearbyBullets(sensor);
        

    }

    private void DetectNearbyBullets(VectorSensor sensor)
    {
        // 1. 检测半径内的所有子弹
        Collider2D[] nearbyBullets = Physics2D.OverlapCircleAll(
            transform.position, 
            bulletDetectionRadius, 
            bulletLayer
        );
        
        // 2. 按距离排序(最近的优先)
        List<Collider2D> sortedBullets = new List<Collider2D>(nearbyBullets);
        sortedBullets.Sort((a, b) => 
            Vector2.Distance(a.transform.position, transform.position)
            .CompareTo(Vector2.Distance(b.transform.position, transform.position))
        );
        
        // 3. 添加子弹信息到观察向量
        int bulletsAdded = 0;
        foreach (var bulletCollider in sortedBullets)
        {
            if (bulletsAdded >= maxDetectedBullets) break;
            
            // 获取子弹信息组件
            BulletInfo bulletInfo = bulletCollider.GetComponent<BulletInfo>();
            if (bulletInfo == null) continue;
            
            // 计算子弹相对于自身的位置(极坐标)
            Vector2 relativePos = transform.InverseTransformPoint(bulletCollider.transform.position);
            
            // 添加观察数据:
            sensor.AddObservation(relativePos);       // 相对位置 (x,y)
            sensor.AddObservation(relativePos.magnitude / bulletDetectionRadius); // 归一化距离 [0-1]
            sensor.AddObservation(bulletInfo.teamId == teamId ? 1f : 0f); // 是否友方子弹
            
            bulletsAdded++;
        }
        
        // 4. 如果子弹数量不足,添加空白数据
        for (int i = bulletsAdded; i < maxDetectedBullets; i++)
        {
            sensor.AddObservation(Vector2.zero); // 位置
            sensor.AddObservation(0f);           // 距离
            sensor.AddObservation(0f);           // 是否友方
        }
        
    }


    // OnActionReceived方法:处理AI决策的动作输入
    public override void OnActionReceived(ActionBuffers actions)
    {
        // 离散动作:0=静止,1=前,2=后
        int move = actions.DiscreteActions[0];
        // 离散动作:0=静止,1=左,2=右
        int turn = actions.DiscreteActions[1];
        // 离散动作:0=不射,1=射击
        int fire = actions.DiscreteActions[2];

        // 移动
        Vector2 vel = Vector2.zero;
        if (move == 1) vel = transform.up * moveSpeed;
        else if (move == 2) vel = -transform.up * moveSpeed;
        rb.velocity = vel;

        // 旋转
        float delta = 0f;
        if (turn == 1) delta = turnSpeed * Time.fixedDeltaTime;
        else if (turn == 2) delta = -turnSpeed * Time.fixedDeltaTime;
        rb.MoveRotation(rb.rotation + delta);
        
        if (vel.sqrMagnitude < 0.01f && Mathf.Abs(delta) > 1f)
            AddReward(-0.02f);


        float distToOpponent = Vector2.Distance(transform.position, opponentTransform.position);
        Vector2 toOpponent = opponentTransform.position - transform.position;
        float angleToOpponent = Vector2.Angle(transform.up, toOpponent);

        AddReward( Mathf.Clamp01((fireRange - distToOpponent) / fireRange) * 0.02f ); // 距离越近奖越多
        AddReward((1f - angleToOpponent / 180f) * 0.01f);                                // 越对准奖越多


        // 射击
        if (fire == 1 && Time.time - lastShootTime >= shootCooldown)
        {
            if (distToOpponent <= fireRange && angleToOpponent < 30f)
            {
                // 在允许的射程范围内,真正发射子弹
                AddReward(+1f); // 可以保留原有的射击开销奖励

                var b = Instantiate(bulletPrefab, firePoint.position, transform.rotation);
                b.tag = "Bullet";
                var bRb = b.GetComponent<Rigidbody2D>();
                bRb.velocity = transform.up * bulletSpeed;

                // 记录子弹所属队伍
                var info = b.AddComponent<BulletInfo>();
                info.teamId = teamId;
                info.shooterAgent = this;

                lastShootTime = Time.time;
            }
            // else
            // {
            //     // 不在射程内,忽略发射
            //     // (可选:给一个轻微负奖励,让它学会“不对着远处傻打”)
            //     AddReward(-0.05f);
            // }
        }

        AddReward(-0.01f);
    }

    private void OnDrawGizmosSelected()
    {
        // 确保 fireRange 有意义
        if (fireRange <= 0f) return;

        // 设置 Gizmos 颜色(你可以根据喜好修改颜色或透明度)
        Gizmos.color = new Color(1f, 0f, 0f, 0.5f); // 半透明红色
        Gizmos.DrawWireSphere(transform.position, fireRange);

        Gizmos.color = new Color(0f, 1f, 1f, 0.3f);
        Gizmos.DrawWireSphere(transform.position, bulletDetectionRadius);
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var da = actionsOut.DiscreteActions;
        da[0] = Input.GetKey(KeyCode.W) ? 1 : Input.GetKey(KeyCode.S) ? 2 : 0;
        da[1] = Input.GetKey(KeyCode.A) ? 1 : Input.GetKey(KeyCode.D) ? 2 : 0;
        da[2] = Input.GetKey(KeyCode.Space) ? 1 : 0;
    }

    private void OnCollisionEnter2D(Collision2D col)
    {
        // 撞到子弹
        if (col.collider.CompareTag("Bullet"))
        {
            var info = col.collider.GetComponent<BulletInfo>();
            Destroy(col.collider.gameObject);
            if (info == null) return;

            if (info.shooterAgent == this)
            {
                // 「自己的子弹」打到自己 → 自己扣分
                AddReward(-5f);
                EndEpisode();
            }
            else
            {
                // 「我的子弹」打到敌人 → 
                // 1) 给自己加分
                info.shooterAgent.AddReward(+10f);
                // 2) 给被击中的敌人扣分
                AddReward(-5f);

                EndEpisode();
            }
        }

        if (col.collider.CompareTag("wall"))
        {
            // 碰到墙壁,扣一个小分数(例如 -1f)
            AddReward(-0.05f);
        }
    }
}




子弹脚本(不用挂载在子弹预制体上,TankAgent脚本里面有添加方法)

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

// 附加在子弹上的小脚本,用于标记所属队伍
public class BulletInfo : MonoBehaviour
{
    public int teamId;
    public TankAgent shooterAgent;
}

如果你有更好的建议或者训练方法,可以分享出来一起学习,有用的话点个赞

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐