LSTM-KNN融合模型:让AI既有记忆又会"查字典"

在这里插入图片描述

引言:预测股价,你需要两个"大脑"

假设你是一位股票交易员,每天早上醒来要预测今天的股价走势。你会怎么做?

方法一:看趋势
你翻开过去三个月的K线图,发现"每次跌破20日均线后,通常会反弹",这就是记忆规律

方法二:找相似
你发现今天的市场情况(成交量暴增、外资流入)和去年6月某天极其相似,那天之后连涨三天,这就是查找历史相似案例

如果把这两种方法结合起来呢?这就是今天要讲的LSTM-KNN融合模型

  • LSTM:负责记忆长期规律(像人的"左脑")
  • KNN:负责查找相似案例(像人的"右脑")

第一部分:LSTM——时间的"记忆大师"

为什么需要LSTM?

想象你在读一本侦探小说,第1章提到"凶手左手有疤",读到第15章时,这个线索突然变得关键。普通的神经网络就像"金鱼记忆",读完前面章节就忘了,但LSTM能记住重要信息

在金融市场中:

  • 上周美联储加息(重要信息)
  • 昨天某只股票微涨0.3%(次要信息)
  • 三个月前的季报利好(可能仍有影响)

LSTM会自动判断哪些信息该记住、哪些该忘记。

LSTM的"三道门"机制

把LSTM想象成一个保险箱,有三道门控制信息流动:

1. 遗忘门(Forget Gate)

作用:决定丢弃哪些过时信息

生活类比:你的大脑会忘记"上周三午餐吃了什么",但记得"第一次约会的餐厅"。

数学表达

f_t = σ(W_f · [h_{t-1}, x_t] + b_f)
  • σ 是sigmoid函数,输出0到1之间的数
  • 输出接近0:彻底忘记这条信息
  • 输出接近1:完全保留这条信息

实例:假设模型记住了"这只股票处于上升趋势",但突然出现暴跌,遗忘门会把"上升趋势"这个记忆的权重降低。


2. 输入门(Input Gate)

作用:决定接受多少新信息

生活类比:听到"今晚有暴雨预警"(重要信息),你会全力记住;听到"路边有只猫"(无关信息),你可能忽略。

数学表达

i_t = σ(W_i · [h_{t-1}, x_t] + b_i)
C̃_t = tanh(W_C · [h_{t-1}, x_t] + b_C)
  • i_t:新信息的重要性评分
  • C̃_t:候选的新记忆内容

实例:当看到"公司发布超预期财报"这个消息时,输入门会打开,让这条重要信息进入记忆。


3. 输出门(Output Gate)

作用:决定输出什么预测结果

生活类比:考试时,你脑子里有很多知识,但只输出与题目相关的答案。

数学表达

o_t = σ(W_o · [h_{t-1}, x_t] + b_o)
h_t = o_t * tanh(C_t)

实例:综合记忆中的所有信息(历史趋势、成交量、财报等),输出"明天股价上涨概率70%"。


LSTM的完整工作流程

用一个5天股价预测的例子说明:

日期 收盘价 LSTM处理过程
Day 1 100元 记忆:起始价格100元
Day 2 102元 遗忘门:保留"上涨趋势";输入门:记住"涨2%"
Day 3 98元 遗忘门:降低"上涨趋势"权重;输入门:记住"下跌4%"
Day 4 99元 输出门:综合判断,倾向"小幅震荡"
Day 5 预测:根据前4天的记忆,预测Day 5价格

第二部分:KNN——历史的"相似度搜索引擎"

KNN的核心思想

类比:你想知道今晚约会该穿什么衣服,你会怎么做?

  1. 回想过去类似的约会场景(西餐厅、看电影、爬山…)
  2. 找出最相似的3次约会(都是去高档西餐厅)
  3. 看看那3次你穿了什么(2次穿西装、1次穿休闲装)
  4. 投票决定:今晚穿西装!

KNN的逻辑完全一样,只是把"相似约会"换成"相似股价走势"。


KNN的四步工作流程

步骤1:构造特征向量

把每一天的市场状态用数字表示:

今天的特征 = [
    过去5天平均涨跌幅,
    成交量变化率,
    RSI指标,
    MACD值,
    外资流入量
]

例如:今天 = [+2.3%, +15%, 65, 0.8, +500]

步骤2:计算距离

欧式距离衡量"今天"和"历史某一天"有多像:

距离 = √[(今天涨跌幅 - 历史某天涨跌幅)² + (今天成交量 - 历史某天成交量)² + ...]

直观理解

  • 距离 = 0:完全相同(不可能)
  • 距离 < 10:非常相似
  • 距离 > 100:差异很大

步骤3:找到K个最近邻

假设K=5,我们找出历史上距离最小的5天:

历史日期 距离 次日涨跌
2023-03-15 8.2 +3.1%
2023-07-22 9.5 +2.8%
2023-09-10 11.3 -1.2%
2023-11-05 12.1 +1.5%
2024-01-18 13.7 +2.3%

步骤4:加权预测

给距离越近的样本更高的权重:

预测值 = (3.1%/8.2 + 2.8%/9.5 + ... ) / (1/8.2 + 1/9.5 + ...)
      ≈ +2.1%

结论:KNN预测明天股价上涨约2.1%


第三部分:LSTM-KNN融合——"1+1>2"的智慧

为什么要融合?

单独使用LSTM或KNN都有局限性:

模型 优势 劣势
LSTM 能捕捉长期趋势,理解"最近三个月在涨" 对突发事件反应慢(如突然暴跌)
KNN 能快速找到相似历史场景,适应突变 忽略时间顺序,可能找到不相关的案例

融合的目标:让LSTM提供"大趋势判断",让KNN提供"局部相似性修正"。


融合方法一:加权平均(Late Fusion)

最简单直接的方法:

最终预测 = α × LSTM预测 + (1-α) × KNN预测

实战案例

LSTM预测明天涨2.5%
KNN预测明天涨1.8%

# 设置权重:LSTM占60%,KNN占40%
α = 0.6
最终预测 = 0.6 × 2.5% + 0.4 × 1.8% = 2.22%

权重如何选择?

  • 市场平稳期:α = 0.7(更信任LSTM的趋势判断)
  • 市场剧烈波动期:α = 0.4(更信任KNN找相似案例)

融合方法二:级联融合(Stacking)

让第三个模型学习"如何组合LSTM和KNN":

步骤1:LSTM输出 → [2.5%, 置信度0.8]
步骤2:KNN输出  → [1.8%, 置信度0.9]
步骤3:元模型(如逻辑回归)学习:
       - 当两者预测一致时 → 高权重给高置信度的
       - 当两者预测矛盾时 → 降低整体置信度
步骤4:输出最终预测

优势:元模型能自动学习在什么情况下该信任谁。


第四部分:实战代码框架

数据准备

import numpy as np
import pandas as pd
from sklearn.neighbors import KNeighborsRegressor
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense

# 加载股票数据
data = pd.read_csv('stock_data.csv')
features = ['close', 'volume', 'ma5', 'ma20', 'rsi']
X = data[features].values
y = data['next_day_return'].values  # 次日涨跌幅

# 划分训练集和测试集
split = int(len(X) * 0.8)
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]

LSTM模型

# 重塑数据为LSTM所需格式 [样本数, 时间步, 特征数]
X_train_lstm = X_train.reshape((X_train.shape[0], 1, X_train.shape[1]))
X_test_lstm = X_test.reshape((X_test.shape[0], 1, X_test.shape[1]))

# 构建LSTM模型
model_lstm = Sequential([
    LSTM(50, activation='relu', input_shape=(1, X_train.shape[1])),
    Dense(25, activation='relu'),
    Dense(1)  # 输出:预测的涨跌幅
])

model_lstm.compile(optimizer='adam', loss='mse')
model_lstm.fit(X_train_lstm, y_train, epochs=50, batch_size=32, verbose=0)

# 预测
lstm_pred = model_lstm.predict(X_test_lstm).flatten()

KNN模型

# KNN不需要重塑数据
model_knn = KNeighborsRegressor(n_neighbors=5, weights='distance')
model_knn.fit(X_train, y_train)

# 预测
knn_pred = model_knn.predict(X_test)

融合预测

# 方法一:简单加权
alpha = 0.6
fusion_pred = alpha * lstm_pred + (1 - alpha) * knn_pred

# 方法二:动态权重(根据历史表现调整)
from sklearn.linear_model import LinearRegression

# 用验证集训练元模型
meta_model = LinearRegression()
meta_features = np.column_stack([lstm_pred[:100], knn_pred[:100]])
meta_model.fit(meta_features, y_test[:100])

# 最终预测
final_pred = meta_model.predict(np.column_stack([lstm_pred, knn_pred]))

第五部分:性能评估与优化技巧

评估指标

from sklearn.metrics import mean_squared_error, mean_absolute_error

# 均方误差(MSE):越小越好
mse_lstm = mean_squared_error(y_test, lstm_pred)
mse_knn = mean_squared_error(y_test, knn_pred)
mse_fusion = mean_squared_error(y_test, fusion_pred)

print(f"LSTM MSE: {mse_lstm:.4f}")
print(f"KNN MSE: {mse_knn:.4f}")
print(f"融合模型 MSE: {mse_fusion:.4f}")

# 方向准确率:预测涨跌方向是否正确
def direction_accuracy(y_true, y_pred):
    return np.mean((y_true > 0) == (y_pred > 0))

print(f"融合模型方向准确率: {direction_accuracy(y_test, fusion_pred):.2%}")

优化技巧

1. 特征工程
# 添加更多技术指标
data['momentum'] = data['close'].pct_change(periods=10)
data['volatility'] = data['close'].rolling(20).std()
data['volume_ma'] = data['volume'].rolling(10).mean()
2. 超参数调优
# KNN的K值选择
from sklearn.model_selection import GridSearchCV

param_grid = {'n_neighbors': [3, 5, 7, 10, 15]}
grid_search = GridSearchCV(KNeighborsRegressor(), param_grid, cv=5)
grid_search.fit(X_train, y_train)
best_k = grid_search.best_params_['n_neighbors']
3. 集成多个时间窗口
# 短期LSTM(看10天)+ 长期LSTM(看60天)+ KNN
lstm_short = build_lstm(window=10)
lstm_long = build_lstm(window=60)
fusion_pred = 0.3*lstm_short + 0.3*lstm_long + 0.4*knn_pred

总结:何时使用LSTM-KNN融合模型?

适用场景

✅ 金融市场预测(股票、期货、外汇)
✅ 能源负荷预测(电力需求)
✅ 网络流量预测
✅ 销售额预测

不适用场景

❌ 数据量太少(<1000条)
❌ 数据完全随机无规律
❌ 需要实时毫秒级预测


关键要点回顾

  1. LSTM:善于记忆长期模式,通过三道门(遗忘、输入、输出)管理信息
  2. KNN:善于查找历史相似案例,通过距离计算找到K个最近邻
  3. 融合策略:加权平均简单有效,级联融合更智能但复杂
  4. 实战建议:先单独调优LSTM和KNN,再调整融合权重

延伸阅读

  • Hochreiter & Schmidhuber (1997) - LSTM原始论文
  • “Hands-On Machine Learning” by Aurélien Géron
  • 金融时间序列分析:ARIMA vs LSTM对比
  • 注意力机制在时间序列预测中的应用

作者寄语
预测未来从来不是精确科学,LSTM-KNN也不是"圣杯"。模型只是工具,真正的智慧在于理解市场背后的逻辑,结合多种方法做出理性决策。祝你在时间序列预测的道路上越走越远!


如果这篇文章对你有帮助,欢迎点赞、收藏、转发!有问题欢迎在评论区讨论~

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐