HoRain云--TensorFlow模型保存与加载全攻略
TensorFlow模型保存与加载指南 本文详细介绍了TensorFlow中多种模型保存与加载方式: 完整模型保存:支持HDF5(.h5)和SavedModel格式,适用于模型归档和部署 仅保存权重:适用于迁移学习和训练检查点 仅保存结构:可序列化为JSON/YAML格式 移动端部署:转换为TensorFlow Lite(.tflite)格式 跨框架交换:支持转换为ONNX格式 最佳实践建议包括版
🎬 HoRain 云小助手:个人主页
⛺️生活的理想,就是为了理想的生活!
⛳️ 推荐
前些天发现了一个超棒的服务器购买网站,性价比超高,大内存超划算!忍不住分享一下给大家。点击跳转到网站。
目录
方法二:SavedModel 格式 (TensorFlow 默认格式)
4️⃣ 转换为 TensorFlow Lite 格式用于移动端部署
TensorFlow 提供了多种灵活的方式来保存和加载模型,以满足不同场景的需求。下面我将这些方法整合到一个表格中,方便你快速了解核心的保存与加载方式、适用场景以及关键操作:
保存内容 | 核心方法/格式 | 适用场景 | 关键操作 (保存) | 关键操作 (加载) |
---|---|---|---|---|
整个模型 | HDF5 格式 (.h5 ) |
需要单独文件保存所有信息(结构、权重、优化器状态等),便于分享和归档 | model.save('model.h5') |
model = tf.keras.models.load_model('model.h5') |
SavedModel 格式 (目录) | TensorFlow 生态推荐,适合生产部署、跨平台(服务器、移动端、Web) | model.save('saved_model/') |
model = tf.keras.models.load_model('saved_model/') |
|
仅模型权重 | HDF5 格式 (.h5 ) |
仅需保存训练好的参数,模型结构代码已存在且固定 | model.save_weights('weights.h5') |
model.load_weights('weights.h5') (需先构建相同结构的模型) |
Checkpoint 格式 (.ckpt ) |
训练过程中定期保存权重,防止意外中断后可从检查点恢复训练 | model.save_weights('checkpoint.ckpt') |
model.load_weights('checkpoint.ckpt') (需先构建相同结构的模型) |
|
仅模型结构 | JSON 字符串 | 需要单独保存模型架构信息(无权重) | json_string = model.to_json() |
model = tf.keras.models.model_from_json(json_string) |
生产与移动端部署 | TensorFlow Lite (.tflite ) |
将模型部署到移动设备、嵌入式系统或 IoT 设备 | converter = tf.lite.TFLiteConverter.from_keras_model(model) |
使用 TFLite 解释器加载 (tf.lite.Interpreter ) |
跨框架交换 | ONNX 格式 (.onnx ) |
需要在不同深度学习框架(如 PyTorch, MXNet)之间转换和共享模型 | 使用 tf2onnx.convert 工具进行转换 |
使用 ONNX Runtime 等支持 ONNX 的引擎进行加载和推理 |
下面是这些方法的详细说明和代码示例:
1️⃣ 保存与加载整个模型
保存整个模型是最简单直接的方式,它会将模型架构、权重、训练配置(如优化器及其状态)、甚至损失和指标函数一并保存。这对于模型归档、分享或需要从中断处继续训练非常方便。
方法一:HDF5 格式 (.h5
)
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import numpy as np
# 1. 构建并训练一个简单模型
model = Sequential([
Dense(64, activation='relu', input_shape=(10,)),
Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# 生成虚拟数据并训练
x_train = np.random.random((1000, 10))
y_train = np.random.randint(2, size=(1000, 1))
model.fit(x_train, y_train, epochs=5)
# 2. 保存整个模型为 HDF5 格式
model.save('my_complete_model.h5') # 保存为 .h5 文件
# 3. 加载整个模型
loaded_model_from_h5 = tf.keras.models.load_model('my_complete_model.h5')
# 4. 使用加载的模型进行预测
predictions = loaded_model_from_h5.predict(x_train[:5])
print(predictions)
方法二:SavedModel 格式 (TensorFlow 默认格式)
SavedModel 是 TensorFlow 首选的部署和生产格式。它会创建一个包含模型架构、权重及 TensorFlow 计算图的目录。
# 保存为 SavedModel 格式 (会创建一个目录)
model.save('my_savedmodel_directory')
# 加载 SavedModel 格式的模型
loaded_model_from_savedmodel = tf.keras.models.load_model('my_savedmodel_directory')
# 使用加载的模型
predictions = loaded_model_from_savedmodel.predict(x_train[:5])
HDF5 与 SavedModel 的选择:
- 如果需要单个文件便于管理或与某些传统工具兼容,可选择 HDF5。
- 对于生产环境、跨平台部署或使用 TensorFlow Serving,SavedModel 是更推荐和更强大的选择。
2️⃣ 仅保存与加载模型权重
如果只需要保存训练好的参数,或者打算在相同架构的模型间共享权重,可以只保存权重。加载权重时,必须预先构建一个与保存时完全相同的模型架构。
# 保存模型权重
model.save_weights('my_model_weights.h5') # 默认为 HDF5 格式
# 或 model.save_weights('my_checkpoint.ckpt') # 也可使用 .ckpt 扩展名
# 构建一个与之前架构完全相同的模型
new_model = Sequential([
Dense(64, activation='relu', input_shape=(10,)),
Dense(1, activation='sigmoid')
])
new_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# 将保存的权重加载到新模型上
new_model.load_weights('my_model_weights.h5')
# 使用加载了权重的模型进行预测
new_predictions = new_model.predict(x_train[:5])
用途:迁移学习中常用此方法加载预训练权重;也用于训练过程中定期保存检查点(Checkpoint),防止训练意外中断后可以从最近的检查点恢复,而不是从头开始。
# 在训练过程中使用回调函数定期保存检查点
from tensorflow.keras.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(filepath='training_checkpoints/ckpt-{epoch:02d}',
save_weights_only=True, # 只保存权重
save_freq='epoch') # 每个epoch结束后保存
model.fit(x_train, y_train, epochs=10, callbacks=[checkpoint_callback])
# 恢复模型并从指定检查点加载权重
latest_checkpoint = 'training_checkpoints/ckpt-05' # 假设从第5个epoch的检查点恢复
model.load_weights(latest_checkpoint)
3️⃣ 仅保存与加载模型结构
如果只想保存模型的架构(层、配置),而不包含权重和训练配置,可以将其序列化为 JSON 或 YAML 字符串。
# 将模型架构序列化为 JSON 字符串
model_json = model.to_json()
# 也可以保存到文件
with open('model_architecture.json', 'w') as json_file:
json_file.write(model_json)
# 从 JSON 字符串重建模型架构 (此时模型尚未加载权重)
reconstructed_model_from_json = tf.keras.models.model_from_json(model_json)
# 或者从文件读取
with open('model_architecture.json', 'r') as json_file:
loaded_json_string = json_file.read()
reconstructed_model_from_json = tf.keras.models.model_from_json(loaded_json_string)
# 注意:重建的模型需要编译,并且如果需要使用,必须加载相应的权重
reconstructed_model_from_json.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# reconstructed_model_from_json.load_weights('weights.h5') # 然后再加载权重
4️⃣ 转换为 TensorFlow Lite 格式用于移动端部署
要将模型部署到手机、嵌入式设备等资源受限的环境,需要将其转换为 TensorFlow Lite 格式。
# 转换模型
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# 保存转换后的模型
with open('model.tflite', 'wb') as f:
f.write(tflite_model)
# 加载和运行 TFLite 模型(通常在目标设备上进行)
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()
# 获取输入输出张量详情
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# 准备输入数据(需要符合模型期望的形状和类型)
input_data = np.array(np.random.random_sample(input_details[0]['shape']), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
# 运行推理
interpreter.invoke()
# 获取输出结果
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
5️⃣ 转换为 ONNX 格式用于跨框架交换
ONNX 是一种开放的模型格式,支持在不同深度学习框架(如 PyTorch, TensorFlow, MXNet)之间转换模型。
# 安装转换工具 tf2onnx
# pip install tf2onnx
# 使用命令行工具将 SavedModel 转换为 ONNX
# python -m tf2onnx.convert --saved-model saved_model --output model.onnx
💡 模型保存与加载的最佳实践
- 版本控制:保存模型时,建议在文件名中加入版本号或日期(如
model_v1.2.h5
),便于管理不同迭代版本的模型。 - 记录环境信息:记录下训练和保存模型时使用的 TensorFlow、Python 及主要依赖库的版本,以便在加载时复现相同的环境,避免兼容性问题。
- 验证加载的模型:加载模型后,最好用一些测试数据验证其预测结果是否与保存前一致。
- 自定义对象处理:如果模型包含自定义层、损失函数或指标,在加载时需要通过
custom_objects
参数将它们提供给load_model
函数。# 假设模型包含一个名为 CustomLayer 的自定义层 loaded_model_with_custom_layer = tf.keras.models.load_model( 'model_with_custom_layer.h5', custom_objects={'CustomLayer': CustomLayer} )
- 生产环境部署:
- 对于 TensorFlow Serving,使用 SavedModel 格式是标准做法。
- 对于移动端或嵌入式设备,使用 TensorFlow Lite 格式并进行可能的量化以优化性能和体积。
- 考虑使用 ONNX 格式 if you need to serve your model in environments that might use different runtimes or hardware accelerators that have better support for ONNX.
希望这份详细的指南能帮助你更好地理解和运用 TensorFlow 的模型保存与加载功能!
❤️❤️❤️本人水平有限,如有纰漏,欢迎各位大佬评论批评指正!😄😄😄
💘💘💘如果觉得这篇文对你有帮助的话,也请给个点赞、收藏下吧,非常感谢!👍 👍 👍
🔥🔥🔥Stay Hungry Stay Foolish 道阻且长,行则将至,让我们一起加油吧!🌙🌙🌙
更多推荐
所有评论(0)