零基础神经网络第六课零代码Excel表手搓回归任务价格预测神经网络AI模型

神经网络可以简单理解为一个 “会自己学规律的智能计算器”,它的核心原理和人类学习的过程很像:通过不断试错来调整自己的判断逻辑,最后学会从输入信息中找出规律并做出准确预测。

1. 神经网络的本质:用数学模拟 “学习”

  • 类比场景:比如小孩学认苹果,第一次看到红通通、圆圆的东西,可能猜是 “西红柿”,大人告诉他 “错了,这是苹果”,他就会记住苹果的特征(红色、圆形、甜味等),下次看到类似的东西就会调整判断。
    神经网络也是如此:先 “猜” 一个结果,再根据真实答案调整自己的 “判断逻辑”,直到猜得越来越准。
  • 核心逻辑:通过数学公式(如加权求和、激活函数)模拟信息处理过程,用 “误差反馈” 来修正参数,最终学会输入(如箱数)和输出(如价格)之间的映射关系。

2. 神经网络的基本结构:像搭积木一样分层

  • 输入层:接收原始信息,比如车厘子的箱数、数字图片的像素值。
  • 隐藏层:中间的 “思考层”,用数学公式对信息进行加工(比如计算 “箱数 × 权重 + 偏置”),多层隐藏层可以理解为 “多步思考”。
  • 输出层:给出最终结果,比如预测的价格、识别的数字类别。
  • 关键组件
    • 权重(w):类似 “重要程度” 的旋钮,决定输入信息对结果的影响大小(比如箱数对价格的影响权重)。
    • 偏置(b):类似 “基准值” 的调整,让模型能适应更复杂的规律(比如除了箱数,还有其他因素影响价格的基准值)。

3. 神经网络如何 “学习”:前向预测 + 反向调参

  • 前向传播(先猜结果)
    比如预测 1 箱车厘子价格,先按初始权重(如 w=0.2)和偏置(b=0)计算:
    预测值 = Sigmoid(0.2×1 + 0)(Sigmoid 是把数值 “翻译” 成概率的函数),得到一个预测价格。
  • 反向传播(根据错误调参数)
    发现预测值和真实价格(249.79 元)有差距,就计算误差,然后像拧螺丝一样调整权重和偏置:
    • 如果预测值比真实值高,就减小权重;如果低,就增大权重。
    • 这个过程通过公式新参数 = 旧参数 - 学习率×误差×调整系数实现,反复迭代直到误差最小。

4. 视频中 Excel 手搓的核心:用表格模拟数学计算

  • 不用写代码,直接在 Excel 里用公式实现:
    • 输入箱数到 A 列,真实价格到 B 列;
    • 用公式计算加权和(如=w*A2+b)、Sigmoid 值(如=1/(1+EXP(-C2));
    • 计算误差(真实值 - 预测值),再按公式更新 w 和 b,重复几百次直到预测准确。
  • 为什么用 Excel?:直观看到每一步计算,理解 “参数怎么变”“误差怎么降”,就像手动调试一个智能机器。

5. 回归任务 vs 分类任务:神经网络的两种 “应用题”

  • 回归任务(如价格预测):输出是连续的数值(比如 30 箱车厘子总价 4525 元),用 “均方误差” 衡量预测值和真实值的差距(误差 = 真实值 - 预测值的平方平均)。
  • 分类任务(如数字识别):输出是类别(比如图片里的数字是 “5”),用 “交叉熵” 等损失函数衡量分类准确性,就像考试判对错。

一句话总结

神经网络是一个 “能自己改答案的数学计算器”,通过分层计算、反复试错调整参数,学会从输入数据中找出规律,无论是预测价格还是识别数字,本质都是用数学模拟 “学习” 的过程,而 Excel 手搓就是把这个过程拆分成表格里的公式,让普通人也能看懂智能机器如何思考。

一、数据准备与表格搭建

  1. 输入基础数据

    • 在 Excel 中创建表格,第一列输入车厘子箱数(如 1、30)作为输入特征x
    • 第二列输入实际总价(如 249.79、4525.74 元)作为目标值y_t
    • 归一化处理:将价格y_t转换为适合 Sigmoid 函数的区间[0.5, 1],公式为:
      归一化后价格 = (y_t / 10000) + 0.5(先除以 10000 缩放到 0~0.5,再加 0.5 平移到 0.5~1)。
  2. 初始化参数

    • 在表格中单独设置单元格存放初始权重w(如设为 0.2)和偏置b(如设为 0),方便后续引用和更新。

二、前向传播:计算预测值

  1. 计算加权和i

    • 对每个x,计算i = w * x + b,例如:
      x=1w=0.2b=0,则i = 0.2*1 + 0 = 0.2
    • 在 Excel 中用公式=w单元格 * x单元格 + b单元格实现。
  2. 通过 Sigmoid 激活函数计算预测值y_o

    • Sigmoid 公式:y_o = 1 / (1 + EXP(-i)),例如i=0.2时,y_o = 1/(1+EXP(-0.2))≈0.54
    • 在 Excel 中输入公式=1/(1+EXP(-i单元格)),得到归一化的预测价格。

三、误差计算:衡量预测偏差

  1. 计算实际值与预测值的误差

    • 误差Delta_y = 归一化后实际价格 - y_o,例如:
      若归一化后实际价格为(249.79/10000)+0.5=0.524979y_o=0.54,则Delta_y=0.524979-0.54=-0.015021
    • 在 Excel 中用公式=归一化后实际价格单元格 - y_o单元格计算。
  2. 计算 Sigmoid 导数(用于参数更新)

    • 导数公式:y_o * (1 - y_o),例如y_o=0.54时,导数为0.54*(1-0.54)=0.2484
    • 在 Excel 中用公式=y_o单元格*(1-y_o单元格)计算,作为中间值备用。

四、反向传播:更新权重和偏置

  1. 设置学习率α

    • 例如设α=0.2,在表格中单独存放,方便调整。
  2. 计算权重更新量Delta_w和偏置更新量Delta_b

    • 公式:
      Delta_w = α * Delta_y * 导数 * x
      Delta_b = α * Delta_y * 导数
    • 例如:α=0.2Delta_y=-0.015021,导数 = 0.2484,x=1,则:
      Delta_w=0.2*(-0.015021)*0.2484*1≈-0.000747
      Delta_b=0.2*(-0.015021)*0.2484≈-0.000747
  3. 更新参数

    • 新参数公式:
      w_new = w_old - Delta_w
      b_new = b_old - Delta_b
    • 例如初始w=0.2,则w_new=0.2 - (-0.000747)≈0.200747b_new=0 - (-0.000747)≈0.000747
    • 在 Excel 中用公式=w_old单元格 - Delta_w单元格更新参数,覆盖原单元格值。

五、迭代优化:重复计算直到误差减小

  1. 循环前向传播和参数更新

    • 用更新后的wb重新计算所有x对应的y_oDelta_y,再次更新参数,重复此过程(如迭代 2000 次)。
    • 关键:每轮迭代后观察误差Delta_y是否逐渐趋近于 0,若误差波动或增大,可减小学习率α
  2. 记录迭代过程

    • 在表格中扩展列,记录每轮的wby_oDelta_y,方便观察参数收敛情况。

六、价格预测:用优化后的参数计算新数据

  1. 输入新箱数x_new(如 31 箱)

    • 用最终优化的wb计算i = w * x_new + b,再通过 Sigmoid 得到y_o_new
  2. 反归一化得到实际价格

    • 公式:实际预测价格 = (y_o_new - 0.5) * 10000,例如y_o_new=0.6时,实际价格为(0.6-0.5)*10000=1000元。

Excel 操作核心技巧

  • 公式引用:用单元格地址(如A1B2)代替具体数值,方便自动计算。
  • 批量计算:对多组数据(如 1 箱和 30 箱),可横向复制公式,同时计算多个样本的误差和更新量。
  • 可视化:用 Excel 图表绘制误差随迭代次数的变化曲线,直观判断模型收敛情况。

通过以上步骤,无需代码即可在 Excel 中手动实现神经网络的价格预测,核心是理解 “前向计算预测→反向更新参数” 的循环逻辑,并通过表格公式模拟数学运算。

<!DOCTYPE html>
<html lang="zh-CN">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>神经网络--回归任务</title>
  <script src="https://cdn.tailwindcss.com"></script>
  <link href="https://cdn.jsdelivr.net/npm/font-awesome@4.7.0/css/font-awesome.min.css" rel="stylesheet">
  <script src="https://cdn.jsdelivr.net/npm/chart.js@4.4.8/dist/chart.umd.min.js"></script>
  <script>
    tailwind.config = {
      theme: {
        extend: {
          colors: {
            primary: '#165DFF',
            secondary: '#7B61FF',
            success: '#00B42A',
            warning: '#FF7D00',
            danger: '#F53F3F',
            info: '#86909C',
            light: '#F2F3F5',
            dark: '#1D2129',
          },
          fontFamily: {
            inter: ['Inter', 'system-ui', 'sans-serif'],
          },
        },
      }
    }
  </script>
  <style type="text/tailwindcss">
    @layer utilities {
      .content-auto {
        content-visibility: auto;
      }
      .card-shadow {
        box-shadow: 0 10px 30px -5px rgba(0, 0, 0, 0.1);
      }
      .glass-effect {
        backdrop-filter: blur(10px);
        background-color: rgba(255, 255, 255, 0.7);
      }
      .gradient-bg {
        background: linear-gradient(135deg, #165DFF 0%, #7B61FF 100%);
      }
      .btn-hover {
        @apply transition-all duration-300 hover:shadow-lg transform hover:-translate-y-1;
      }
      .input-focus {
        @apply focus:ring-2 focus:ring-primary/50 focus:border-primary;
      }
    }
  </style>
</head>
<body class="font-inter bg-gray-50 min-h-screen">
  <!-- 顶部导航栏 -->
  <header class="sticky top-0 z-50 gradient-bg text-white shadow-md">
    <div class="container mx-auto px-4 py-4 flex justify-between items-center">
      <div class="flex items-center space-x-3">
        <i class="fa fa-line-chart text-2xl"></i>
        <h1 class="text-xl md:text-2xl font-bold">神经网络可视化工具——回归任务</h1>
      </div>
      <div class="hidden md:flex items-center space-x-6">
        <a href="#" class="hover:text-light transition-colors">首页</a>
        <a href="#" class="hover:text-light transition-colors">文档</a>
        <a href="#" class="hover:text-light transition-colors">关于</a>
      </div>
      <button class="md:hidden text-xl">
        <i class="fa fa-bars"></i>
      </button>
    </div>
  </header>

  <!-- 主内容区 -->
  <main class="container mx-auto px-4 py-8">
    <!-- 控制面板 -->
    <section class="mb-8">
      <div class="bg-white rounded-xl p-6 card-shadow">
        <h2 class="text-xl font-bold mb-6 flex items-center">
          <i class="fa fa-sliders text-primary mr-2"></i>参数设置
        </h2>
        
        <div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-6">
          <!-- 权重和偏置设置 -->
          <div class="space-y-4">
            <div>
              <label for="weight" class="block text-sm font-medium text-gray-700 mb-1">权重 (w)</label>
              <div class="flex items-center">
                <input type="number" id="weight" value="0.5" step="0.01" min="-10" max="10" 
                  class="w-full px-3 py-2 border border-gray-300 rounded-md input-focus">
                <button id="resetWeight" class="ml-2 p-2 text-gray-500 hover:text-primary transition-colors">
                  <i class="fa fa-refresh"></i>
                </button>
              </div>
            </div>
            
            <div>
              <label for="bias" class="block text-sm font-medium text-gray-700 mb-1">偏置 (b)</label>
              <div class="flex items-center">
                <input type="number" id="bias" value="0.0" step="0.01" min="-10" max="10" 
                  class="w-full px-3 py-2 border border-gray-300 rounded-md input-focus">
                <button id="resetBias" class="ml-2 p-2 text-gray-500 hover:text-primary transition-colors">
                  <i class="fa fa-refresh"></i>
                </button>
              </div>
            </div>
            
            <div>
              <label for="learningRate" class="block text-sm font-medium text-gray-700 mb-1">学习率</label>
              <div class="flex items-center">
                <input type="number" id="learningRate" value="0.1" step="0.01" min="0.001" max="1" 
                  class="w-full px-3 py-2 border border-gray-300 rounded-md input-focus">
                <button id="resetLearningRate" class="ml-2 p-2 text-gray-500 hover:text-primary transition-colors">
                  <i class="fa fa-refresh"></i>
                </button>
              </div>
            </div>
          </div>
          
          <!-- 文件上传 -->
          <div>
            <label class="block text-sm font-medium text-gray-700 mb-1">训练数据上传</label>
            <div class="border-2 border-dashed border-gray-300 rounded-lg p-6 text-center hover:border-primary transition-colors cursor-pointer" id="fileDropArea">
              <i class="fa fa-file-excel-o text-4xl text-gray-400 mb-2"></i>
              <p class="text-gray-600">拖放Excel文件到这里,或</p>
              <button type="button" class="mt-2 px-4 py-2 bg-primary text-white rounded-md btn-hover">
                <i class="fa fa-upload mr-1"></i> 选择文件
              </button>
              <input type="file" id="fileInput" accept=".xlsx,.xls,.csv" class="hidden">
              <p id="fileStatus" class="mt-2 text-sm text-gray-500"></p>
            </div>
          </div>
          
          <!-- 训练控制 -->
          <div class="space-y-4">
            <div>
              <label for="epochs" class="block text-sm font-medium text-gray-700 mb-1">训练轮次</label>
              <input type="number" id="epochs" value="100" min="1" max="10000" 
                class="w-full px-3 py-2 border border-gray-300 rounded-md input-focus">
            </div>
            
            <div>
              <label class="block text-sm font-medium text-gray-700 mb-1">训练状态</label>
              <div class="flex items-center p-3 bg-gray-100 rounded-md">
                <span id="trainingStatus" class="flex-1 text-gray-700">就绪</span>
                <span id="lossValue" class="px-3 py-1 bg-gray-200 rounded-full text-sm font-medium">
                  损失: --
                </span>
              </div>
            </div>
            
            <div class="grid grid-cols-2 gap-3">
              <button id="trainBtn" class="px-4 py-3 bg-primary text-white rounded-md btn-hover flex items-center justify-center">
                <i class="fa fa-play mr-2"></i> 开始训练
              </button>
              <button id="stopBtn" class="px-4 py-3 bg-danger text-white rounded-md btn-hover flex items-center justify-center" disabled>
                <i class="fa fa-stop mr-2"></i> 停止训练
              </button>
            </div>
          </div>
        </div>
      </div>
    </section>
    
    <!-- 图表展示区 -->
    <section class="grid grid-cols-1 lg:grid-cols-2 gap-6 mb-8">
      <!-- 数据和预测可视化 -->
      <div class="bg-white rounded-xl p-6 card-shadow">
        <h2 class="text-xl font-bold mb-4 flex items-center">
          <i class="fa fa-area-chart text-primary mr-2"></i>数据与预测
        </h2>
        <div class="aspect-w-16 aspect-h-9">
          <canvas id="dataChart"></canvas>
        </div>
      </div>
      
      <!-- 损失函数可视化 -->
      <div class="bg-white rounded-xl p-6 card-shadow">
        <h2 class="text-xl font-bold mb-4 flex items-center">
          <i class="fa fa-line-chart text-secondary mr-2"></i>损失函数
        </h2>
        <div class="aspect-w-16 aspect-h-9">
          <canvas id="lossChart"></canvas>
        </div>
      </div>
    </section>
    
    <!-- 神经网络结构可视化 -->
    <section class="mb-8">
      <div class="bg-white rounded-xl p-6 card-shadow">
        <h2 class="text-xl font-bold mb-4 flex items-center">
          <i class="fa fa-sitemap text-warning mr-2"></i>神经网络结构
        </h2>
        <div class="aspect-w-16 aspect-h-6 bg-gray-50 rounded-lg overflow-hidden">
          <svg id="nnVisualization" viewBox="0 0 1000 300" class="w-full h-full">
            <!-- 神经网络结构将通过JS动态生成 -->
          </svg>
        </div>
      </div>
    </section>
    
    <!-- 预测工具 -->
    <section class="mb-8">
      <div class="bg-white rounded-xl p-6 card-shadow">
        <h2 class="text-xl font-bold mb-4 flex items-center">
          <i class="fa fa-calculator text-success mr-2"></i>预测工具
        </h2>
        
        <div class="grid grid-cols-1 md:grid-cols-3 gap-6">
          <div>
            <label for="predictionInput" class="block text-sm font-medium text-gray-700 mb-1">输入值 (x)</label>
            <input type="number" id="predictionInput" value="0.5" step="0.01" min="-10" max="10" 
              class="w-full px-3 py-2 border border-gray-300 rounded-md input-focus">
          </div>
          
          <div class="flex items-end">
            <button id="predictBtn" class="w-full px-4 py-2 bg-primary text-white rounded-md btn-hover">
              <i class="fa fa-magic mr-2"></i> 计算预测值
            </button>
          </div>
          
          <div>
            <label class="block text-sm font-medium text-gray-700 mb-1">预测结果 (y)</label>
            <div class="p-3 bg-gray-100 rounded-md text-lg font-medium" id="predictionResult">
              --
            </div>
          </div>
        </div>
      </div>
    </section>
    
    <!-- 功能说明 -->
    <section class="mb-8">
      <div class="bg-white rounded-xl p-6 card-shadow">
        <h2 class="text-xl font-bold mb-4 flex items-center">
          <i class="fa fa-info-circle text-primary mr-2"></i>关于本工具
        </h2>
        <p class="mb-4">这个神经网络可视化工具可以帮助你理解基本的神经网络原理。它具有以下功能:</p>
        <ul class="list-disc pl-5 mb-4 space-y-2">
          <li>支持通过Excel表格导入训练数据,表格需包含两列(x和y值)</li>
          <li>可自定义设置权重(w)、偏置(b)和学习率</li>
          <li>实时可视化训练过程,包括数据点、预测值和预测曲线</li>
          <li>展示损失函数随训练轮次的变化</li>
          <li>可视化神经网络结构,直观展示输入层、输出层和参数</li>
          <li>提供预测工具,可输入x值计算预测结果</li>
        </ul>
        <p class="mb-4">你可以通过点击"开始训练"按钮来训练模型,也可以随时停止训练。如需重新开始,可以上传新的训练数据或重置参数。</p>
        <p class="text-primary font-medium">技术课程请看抖音号码938129762欢迎讨论和索要课件</p>
      </div>
    </section>
  </main>

  <!-- 页脚 -->
  <footer class="bg-dark text-white py-8">
    <div class="container mx-auto px-4">
      <div class="grid grid-cols-1 md:grid-cols-3 gap-8">
        <div>
          <h3 class="text-lg font-bold mb-4">神经网络可视化工具</h3>
          <p class="text-gray-400">一个直观展示神经网络训练过程和预测效果的工具,帮助理解基本神经网络原理。</p>
        </div>
        
        <div>
          <h3 class="text-lg font-bold mb-4">功能特性</h3>
          <ul class="space-y-2 text-gray-400">
            <li><i class="fa fa-check text-success mr-2"></i>Excel数据导入</li>
            <li><i class="fa fa-check text-success mr-2"></i>参数自定义设置</li>
            <li><i class="fa fa-check text-success mr-2"></i>实时可视化训练过程</li>
            <li><i class="fa fa-check text-success mr-2"></i>神经网络结构展示</li>
          </ul>
        </div>
        
        <div>
          <h3 class="text-lg font-bold mb-4">相关资源</h3>
          <ul class="space-y-2 text-gray-400">
            <li><a href="#" class="hover:text-white transition-colors"><i class="fa fa-book mr-2"></i>使用文档</a></li>
            <li><a href="#" class="hover:text-white transition-colors"><i class="fa fa-github mr-2"></i>GitHub仓库</a></li>
            <li><a href="#" class="hover:text-white transition-colors"><i class="fa fa-question-circle mr-2"></i>常见问题</a></li>
            <li><a href="https://www.douyin.com/user/938129762" target="_blank" class="hover:text-white transition-colors"><i class="fa fa-play-circle mr-2"></i>抖音教学课程</a></li>
          </ul>
        </div>
      </div>
      
      <div class="border-t border-gray-800 mt-8 pt-8 text-center text-gray-500">
        <p>&copy; 2025 神经网络可视化工具 | 保留所有权利</p>
      </div>
    </div>
  </footer>

  <script>
    // 全局变量
    let trainingData = [];
    let predictions = [];
    let lossHistory = [];
    let isTraining = false;
    let trainingInterval;
    
    // 神经网络参数
    let w = 0.5;
    let b = 0.0;
    let learningRate = 0.1;
    
    // DOM元素
    const weightInput = document.getElementById('weight');
    const biasInput = document.getElementById('bias');
    const learningRateInput = document.getElementById('learningRate');
    const fileInput = document.getElementById('fileInput');
    const fileDropArea = document.getElementById('fileDropArea');
    const fileStatus = document.getElementById('fileStatus');
    const trainBtn = document.getElementById('trainBtn');
    const stopBtn = document.getElementById('stopBtn');
    const epochsInput = document.getElementById('epochs');
    const trainingStatus = document.getElementById('trainingStatus');
    const lossValue = document.getElementById('lossValue');
    const predictionInput = document.getElementById('predictionInput');
    const predictBtn = document.getElementById('predictBtn');
    const predictionResult = document.getElementById('predictionResult');
    const resetWeightBtn = document.getElementById('resetWeight');
    const resetBiasBtn = document.getElementById('resetBias');
    const resetLearningRateBtn = document.getElementById('resetLearningRate');
    
    // 初始化图表
    const dataChart = new Chart(
      document.getElementById('dataChart'),
      {
        type: 'scatter',
        data: {
          datasets: [
            {
              label: '实际数据',
              data: [],
              backgroundColor: 'rgba(22, 93, 255, 0.7)',
              pointRadius: 6,
              pointHoverRadius: 8
            },
            {
              label: '预测值',
              data: [],
              backgroundColor: 'rgba(245, 63, 63, 0.7)',
              pointRadius: 6,
              pointHoverRadius: 8
            },
            {
              label: '预测曲线',
              data: [],
              borderColor: 'rgba(255, 125, 0, 1)',
              borderWidth: 3,
              pointRadius: 0,
              fill: false,
              type: 'line',
              tension: 0.4
            }
          ]
        },
        options: {
          responsive: true,
          maintainAspectRatio: false,
          scales: {
            x: {
              title: {
                display: true,
                text: '输入值 (x)',
                font: {
                  size: 14
                }
              },
              grid: {
                color: 'rgba(0, 0, 0, 0.05)'
              }
            },
            y: {
              title: {
                display: true,
                text: '输出值 (y)',
                font: {
                  size: 14
                }
              },
              grid: {
                color: 'rgba(0, 0, 0, 0.05)'
              }
            }
          },
          plugins: {
            legend: {
              position: 'top',
            },
            tooltip: {
              mode: 'index',
              intersect: false
            }
          },
          animation: {
            duration: 500
          }
        }
      }
    );
    
    const lossChart = new Chart(
      document.getElementById('lossChart'),
      {
        type: 'line',
        data: {
          labels: [],
          datasets: [{
            label: '损失值',
            data: [],
            borderColor: 'rgba(123, 97, 255, 1)',
            backgroundColor: 'rgba(123, 97, 255, 0.1)',
            borderWidth: 2,
            fill: true,
            tension: 0.3
          }]
        },
        options: {
          responsive: true,
          maintainAspectRatio: false,
          scales: {
            x: {
              title: {
                display: true,
                text: '训练轮次',
                font: {
                  size: 14
                }
              },
              grid: {
                color: 'rgba(0, 0, 0, 0.05)'
              }
            },
            y: {
              title: {
                display: true,
                text: '损失值',
                font: {
                  size: 14
                }
              },
              grid: {
                color: 'rgba(0, 0, 0, 0.05)'
              },
              ticks: {
                callback: function(value) {
                  return value.toFixed(4);
                }
              }
            }
          },
          plugins: {
            legend: {
              position: 'top',
            },
            tooltip: {
              callbacks: {
                label: function(context) {
                  return `损失值: ${context.raw.toFixed(6)}`;
                }
              }
            }
          },
          animation: {
            duration: 500
          }
        }
      }
    );
    
    // 初始化神经网络可视化
    function initNNVisualization() {
      const svg = document.getElementById('nnVisualization');
      svg.innerHTML = '';
      
      // 定义节点大小和间距
      const nodeRadius = 25;
      const layerSpacing = 200;
      const nodeSpacing = 80;
      
      // 输入层
      const inputLayerX = 150;
      const inputLayerY = 150;
      
      // 输出层
      const outputLayerX = inputLayerX + layerSpacing;
      const outputLayerY = 150;
      
      // 创建输入节点
      const inputNode = document.createElementNS('http://www.w3.org/2000/svg', 'circle');
      inputNode.setAttribute('cx', inputLayerX);
      inputNode.setAttribute('cy', inputLayerY);
      inputNode.setAttribute('r', nodeRadius);
      inputNode.setAttribute('fill', '#E6F7FF');
      inputNode.setAttribute('stroke', '#165DFF');
      inputNode.setAttribute('stroke-width', '2');
      svg.appendChild(inputNode);
      
      // 输入节点标签
      const inputLabel = document.createElementNS('http://www.w3.org/2000/svg', 'text');
      inputLabel.setAttribute('x', inputLayerX);
      inputLabel.setAttribute('y', inputLayerY + 5);
      inputLabel.setAttribute('text-anchor', 'middle');
      inputLabel.setAttribute('font-size', '14');
      inputLabel.setAttribute('fill', '#1D2129');
      inputLabel.textContent = 'x';
      svg.appendChild(inputLabel);
      
      // 输入标签
      const inputLayerLabel = document.createElementNS('http://www.w3.org/2000/svg', 'text');
      inputLayerLabel.setAttribute('x', inputLayerX);
      inputLayerLabel.setAttribute('y', inputLayerY + nodeRadius + 30);
      inputLayerLabel.setAttribute('text-anchor', 'middle');
      inputLayerLabel.setAttribute('font-size', '16');
      inputLayerLabel.setAttribute('font-weight', 'bold');
      inputLayerLabel.setAttribute('fill', '#1D2129');
      inputLayerLabel.textContent = '输入层';
      svg.appendChild(inputLayerLabel);
      
      // 创建输出节点
      const outputNode = document.createElementNS('http://www.w3.org/2000/svg', 'circle');
      outputNode.setAttribute('cx', outputLayerX);
      outputNode.setAttribute('cy', outputLayerY);
      outputNode.setAttribute('r', nodeRadius);
      outputNode.setAttribute('fill', '#FFF2E8');
      outputNode.setAttribute('stroke', '#FF7D00');
      outputNode.setAttribute('stroke-width', '2');
      svg.appendChild(outputNode);
      
      // 输出节点标签
      const outputLabel = document.createElementNS('http://www.w3.org/2000/svg', 'text');
      outputLabel.setAttribute('x', outputLayerX);
      outputLabel.setAttribute('y', outputLayerY + 5);
      outputLabel.setAttribute('text-anchor', 'middle');
      outputLabel.setAttribute('font-size', '14');
      outputLabel.setAttribute('fill', '#1D2129');
      outputLabel.textContent = 'y';
      svg.appendChild(outputLabel);
      
      // 输出标签
      const outputLayerLabel = document.createElementNS('http://www.w3.org/2000/svg', 'text');
      outputLayerLabel.setAttribute('x', outputLayerX);
      outputLayerLabel.setAttribute('y', outputLayerY + nodeRadius + 30);
      outputLayerLabel.setAttribute('text-anchor', 'middle');
      outputLayerLabel.setAttribute('font-size', '16');
      outputLayerLabel.setAttribute('font-weight', 'bold');
      outputLayerLabel.setAttribute('fill', '#1D2129');
      outputLayerLabel.textContent = '输出层';
      svg.appendChild(outputLayerLabel);
      
      // 创建权重连接
      const connection = document.createElementNS('http://www.w3.org/2000/svg', 'line');
      connection.setAttribute('x1', inputLayerX + nodeRadius);
      connection.setAttribute('y1', inputLayerY);
      connection.setAttribute('x2', outputLayerX - nodeRadius);
      connection.setAttribute('y2', outputLayerY);
      connection.setAttribute('stroke', '#86909C');
      connection.setAttribute('stroke-width', '2');
      connection.setAttribute('marker-end', 'url(#arrowhead)');
      svg.appendChild(connection);
      
      // 创建箭头标记
      const defs = document.createElementNS('http://www.w3.org/2000/svg', 'defs');
      svg.appendChild(defs);
      
      const marker = document.createElementNS('http://www.w3.org/2000/svg', 'marker');
      marker.setAttribute('id', 'arrowhead');
      marker.setAttribute('markerWidth', '10');
      marker.setAttribute('markerHeight', '7');
      marker.setAttribute('refX', '9');
      marker.setAttribute('refY', '3.5');
      marker.setAttribute('orient', 'auto');
      defs.appendChild(marker);
      
      const polygon = document.createElementNS('http://www.w3.org/2000/svg', 'polygon');
      polygon.setAttribute('points', '0 0, 10 3.5, 0 7');
      polygon.setAttribute('fill', '#86909C');
      marker.appendChild(polygon);
      
      // 权重标签
      const weightLabel = document.createElementNS('http://www.w3.org/2000/svg', 'text');
      weightLabel.setAttribute('x', (inputLayerX + outputLayerX) / 2);
      weightLabel.setAttribute('y', outputLayerY - 20);
      weightLabel.setAttribute('text-anchor', 'middle');
      weightLabel.setAttribute('font-size', '14');
      weightLabel.setAttribute('fill', '#1D2129');
      weightLabel.textContent = `w = ${w.toFixed(2)}`;
      svg.appendChild(weightLabel);
      
      // 偏置标签
      const biasLabel = document.createElementNS('http://www.w3.org/2000/svg', 'text');
      biasLabel.setAttribute('x', (inputLayerX + outputLayerX) / 2);
      biasLabel.setAttribute('y', outputLayerY + 30);
      biasLabel.setAttribute('text-anchor', 'middle');
      biasLabel.setAttribute('font-size', '14');
      biasLabel.setAttribute('fill', '#1D2129');
      biasLabel.textContent = `b = ${b.toFixed(2)}`;
      svg.appendChild(biasLabel);
      
      // 激活函数标签
      const activationLabel = document.createElementNS('http://www.w3.org/2000/svg', 'text');
      activationLabel.setAttribute('x', outputLayerX);
      activationLabel.setAttribute('y', outputLayerY - nodeRadius - 10);
      activationLabel.setAttribute('text-anchor', 'middle');
      activationLabel.setAttribute('font-size', '14');
      activationLabel.setAttribute('fill', '#1D2129');
      activationLabel.textContent = 'Sigmoid';
      svg.appendChild(activationLabel);
    }
    
    // 激活函数:Sigmoid
    function sigmoid(x) {
      return 1 / (1 + Math.exp(-x));
    }
    
    // 前向传播
    function forwardPropagation(x) {
      return sigmoid(w * x + b);
    }
    
    // 计算损失(均方误差)
    function calculateLoss() {
      if (trainingData.length === 0) return 0;
      
      let totalLoss = 0;
      for (let i = 0; i < trainingData.length; i++) {
        const x = trainingData[i][0];
        const y = trainingData[i][1];
        const prediction = forwardPropagation(x);
        totalLoss += Math.pow(prediction - y, 2);
      }
      return totalLoss / trainingData.length;
    }
    
    // 反向传播和参数更新
    function backPropagation() {
      if (trainingData.length === 0) return;
      
      let dw = 0;
      let db = 0;
      
      for (let i = 0; i < trainingData.length; i++) {
        const x = trainingData[i][0];
        const y = trainingData[i][1];
        
        // 前向传播
        const z = w * x + b;
        const a = sigmoid(z);
        
        // 计算梯度
        const error = a - y;
        const sigmoidDerivative = a * (1 - a);
        
        dw += error * sigmoidDerivative * x;
        db += error * sigmoidDerivative;
      }
      
      // 更新参数
      w -= (learningRate * dw) / trainingData.length;
      b -= (learningRate * db) / trainingData.length;
      
      // 更新UI中的参数
      weightInput.value = w.toFixed(4);
      biasInput.value = b.toFixed(4);
    }
    
    // 更新预测数据
    function updatePredictions() {
      predictions = [];
      const curvePoints = [];
      
      for (let i = 0; i < trainingData.length; i++) {
        const x = trainingData[i][0];
        const prediction = forwardPropagation(x);
        predictions.push({x: x, y: prediction});
      }
      
      // 生成预测曲线的点
      const minX = Math.min(...trainingData.map(point => point[0])) - 1;
      const maxX = Math.max(...trainingData.map(point => point[0])) + 1;
      
      for (let x = minX; x <= maxX; x += 0.1) {
        curvePoints.push({x: x, y: forwardPropagation(x)});
      }
      
      // 更新图表
      dataChart.data.datasets[1].data = predictions;
      dataChart.data.datasets[2].data = curvePoints;
      dataChart.update();
      
      // 更新神经网络可视化
      initNNVisualization();
    }
    
    // 训练模型
    function trainModel(epochs) {
      if (trainingData.length === 0) {
        alert('请先上传训练数据!');
        return;
      }
      
      isTraining = true;
      trainBtn.disabled = true;
      stopBtn.disabled = false;
      trainingStatus.textContent = '训练中...';
      
      let epoch = 0;
      
      // 使用setInterval进行分步训练,避免UI阻塞
      trainingInterval = setInterval(() => {
        if (epoch >= epochs || !isTraining) {
          clearInterval(trainingInterval);
          isTraining = false;
          trainBtn.disabled = false;
          stopBtn.disabled = true;
          trainingStatus.textContent = '训练完成';
          
          // 计算最终损失
          const loss = calculateLoss();
          lossValue.textContent = `损失: ${loss.toFixed(6)}`;
          
          return;
        }
        
        // 执行一轮训练
        backPropagation();
        const loss = calculateLoss();
        lossHistory.push(loss);
        
        // 每10个epoch更新一次图表,减少性能消耗
        if (epoch % 10 === 0) {
          updatePredictions();
          
          lossChart.data.labels = lossHistory.map((_, index) => index);
          lossChart.data.datasets[0].data = lossHistory;
          lossChart.update();
          
          lossValue.textContent = `损失: ${loss.toFixed(6)}`;
        }
        
        epoch++;
      }, 10);
    }
    
    // 停止训练
    function stopTraining() {
      clearInterval(trainingInterval);
      isTraining = false;
      trainBtn.disabled = false;
      stopBtn.disabled = true;
      trainingStatus.textContent = '训练已停止';
    }
    
    // 预测单个值
    function predictValue(x) {
      const prediction = forwardPropagation(x);
      predictionResult.textContent = prediction.toFixed(6);
      
      // 在图表上添加预测点
      dataChart.data.datasets.push({
        label: '当前预测',
        data: [{x: x, y: prediction}],
        backgroundColor: 'rgba(0, 180, 42, 1)',
        pointRadius: 8,
        pointHoverRadius: 10,
        showLine: false
      });
      dataChart.update();
      
      // 2秒后移除临时预测点
      setTimeout(() => {
        dataChart.data.datasets.pop();
        dataChart.update();
      }, 2000);
    }
    
    // 解析Excel文件
    function parseExcelFile(file) {
      const reader = new FileReader();
      
      reader.onload = function(e) {
        try {
          const data = new Uint8Array(e.target.result);
          const workbook = XLSX.read(data, {type: 'array'});
          
          // 获取第一个工作表
          const worksheetName = workbook.SheetNames[0];
          const worksheet = workbook.Sheets[worksheetName];
          
          // 转换为JSON
          const jsonData = XLSX.utils.sheet_to_json(worksheet, {header: 1});
          
          // 验证数据格式
          if (jsonData.length < 2) {
            throw new Error('数据行数不足,至少需要一行表头和一行数据');
          }
          
          // 清除旧数据
          trainingData = [];
          
          // 从第二行开始解析数据(跳过表头)
          for (let i = 1; i < jsonData.length; i++) {
            const row = jsonData[i];
            if (row.length >= 2) {
              const x = parseFloat(row[0]);
              const y = parseFloat(row[1]);
              
              if (!isNaN(x) && !isNaN(y)) {
                trainingData.push([x, y]);
              }
            }
          }
          
          if (trainingData.length === 0) {
            throw new Error('未能解析到有效的数据');
          }
          
          fileStatus.textContent = `已加载 ${trainingData.length} 条数据`;
          
          // 更新图表
          dataChart.data.datasets[0].data = trainingData.map(point => ({x: point[0], y: point[1]}));
          dataChart.update();
          
          // 重置损失图表
          lossHistory = [];
          lossChart.data.labels = [];
          lossChart.data.datasets[0].data = [];
          lossChart.update();
          
          // 启用训练按钮
          trainBtn.disabled = false;
          
        } catch (error) {
          fileStatus.textContent = '文件解析错误';
          console.error('Excel解析错误:', error);
          alert(`解析文件时出错: ${error.message}`);
        }
      };
      
      reader.onerror = function() {
        fileStatus.textContent = '文件读取错误';
        alert('读取文件时出错');
      };
      
      reader.readAsArrayBuffer(file);
    }
    
    // 事件监听器
    fileDropArea.addEventListener('click', () => {
      fileInput.click();
    });
    
    fileInput.addEventListener('change', (e) => {
      if (e.target.files.length > 0) {
        const file = e.target.files[0];
        fileStatus.textContent = `正在加载: ${file.name}`;
        parseExcelFile(file);
      }
    });
    
    // 拖放功能
    fileDropArea.addEventListener('dragover', (e) => {
      e.preventDefault();
      fileDropArea.classList.add('border-primary');
      fileDropArea.classList.add('bg-primary/5');
    });
    
    fileDropArea.addEventListener('dragleave', () => {
      fileDropArea.classList.remove('border-primary');
      fileDropArea.classList.remove('bg-primary/5');
    });
    
    fileDropArea.addEventListener('drop', (e) => {
      e.preventDefault();
      fileDropArea.classList.remove('border-primary');
      fileDropArea.classList.remove('bg-primary/5');
      
      if (e.dataTransfer.files.length > 0) {
        const file = e.dataTransfer.files[0];
        if (file.name.endsWith('.xlsx') || file.name.endsWith('.xls') || file.name.endsWith('.csv')) {
          fileStatus.textContent = `正在加载: ${file.name}`;
          parseExcelFile(file);
        } else {
          fileStatus.textContent = '请上传Excel文件(.xlsx, .xls, .csv)';
          alert('请上传Excel文件(.xlsx, .xls, .csv)');
        }
      }
    });
    
    trainBtn.addEventListener('click', () => {
      const epochs = parseInt(epochsInput.value, 10) || 100;
      trainModel(epochs);
    });
    
    stopBtn.addEventListener('click', stopTraining);
    
    predictBtn.addEventListener('click', () => {
      const x = parseFloat(predictionInput.value);
      if (!isNaN(x)) {
        predictValue(x);
      } else {
        alert('请输入有效的数值');
      }
    });
    
    // 参数变化事件
    weightInput.addEventListener('change', (e) => {
      w = parseFloat(e.target.value);
      updatePredictions();
    });
    
    biasInput.addEventListener('change', (e) => {
      b = parseFloat(e.target.value);
      updatePredictions();
    });
    
    learningRateInput.addEventListener('change', (e) => {
      learningRate = parseFloat(e.target.value);
    });
    
    // 重置按钮事件
    resetWeightBtn.addEventListener('click', () => {
      w = 0.5;
      weightInput.value = w;
      updatePredictions();
    });
    
    resetBiasBtn.addEventListener('click', () => {
      b = 0.0;
      biasInput.value = b;
      updatePredictions();
    });
    
    resetLearningRateBtn.addEventListener('click', () => {
      learningRate = 0.1;
      learningRateInput.value = learningRate;
    });
    
    // 初始化
    function init() {
      // 禁用训练按钮直到数据加载完成
      trainBtn.disabled = true;
      
      // 初始化神经网络可视化
      initNNVisualization();
      
      // 加载SheetJS库
      const script = document.createElement('script');
      script.src = 'https://cdn.jsdelivr.net/npm/xlsx@0.18.5/dist/xlsx.full.min.js';
      script.onload = function() {
        fileStatus.textContent = '请上传Excel训练数据';
      };
      document.body.appendChild(script);
    }
    
    // 启动应用
    init();
  </script>
</body>
</html>
    

 通过以上完整的参数迭代更新和预测过程,我们能够利用构建好的神经网络准确预测车厘子 的价格。如果在实际操作中遇到问题,或者想要尝试不同的初始参数和数据,欢迎随时交流。 上述内容清晰展示了模型迭代优化与预测的全流程。若你对某部分公式推导、Excel 操作细 节还有疑问,或者想尝试不同设定,随时告诉我。联系抖音号码938129762,索要资料、课 件等。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐