🎬 HoRain 云小助手个人主页

⛺️生活的理想,就是为了理想的生活!


⛳️ 推荐

前些天发现了一个超棒的服务器购买网站,性价比超高,大内存超划算!忍不住分享一下给大家。点击跳转到网站。

目录

⛳️ 推荐

1️⃣ 保存与加载整个模型

​​方法一:HDF5 格式 (.h5)​​

​​方法二:SavedModel 格式 (TensorFlow 默认格式)​​

2️⃣ 仅保存与加载模型权重

3️⃣ 仅保存与加载模型结构

4️⃣ 转换为 TensorFlow Lite 格式用于移动端部署

5️⃣ 转换为 ONNX 格式用于跨框架交换

💡 模型保存与加载的最佳实践


TensorFlow 提供了多种灵活的方式来保存和加载模型,以满足不同场景的需求。下面我将这些方法整合到一个表格中,方便你快速了解核心的保存与加载方式、适用场景以及关键操作:

保存内容 核心方法/格式 适用场景 关键操作 (保存) 关键操作 (加载)
​整个模型​ ​HDF5 格式​​ (.h5) 需要​​单独文件​​保存​​所有信息​​(结构、权重、优化器状态等),便于分享和归档 model.save('model.h5') model = tf.keras.models.load_model('model.h5')
​SavedModel 格式​​ (目录) ​TensorFlow 生态推荐​​,适合​​生产部署​​、跨平台(服务器、移动端、Web) model.save('saved_model/') model = tf.keras.models.load_model('saved_model/')
​仅模型权重​ ​HDF5 格式​​ (.h5) ​仅需保存训练好的参数​​,模型结构代码已存在且固定 model.save_weights('weights.h5') model.load_weights('weights.h5') (需先构建​​相同结构​​的模型)
​Checkpoint 格式​​ (.ckpt) ​训练过程中定期保存​​权重,防止意外中断后可​​从检查点恢复训练​ model.save_weights('checkpoint.ckpt') model.load_weights('checkpoint.ckpt') (需先构建​​相同结构​​的模型)
​仅模型结构​ ​JSON 字符串​ 需要​​单独保存模型架构​​信息(无权重) json_string = model.to_json() model = tf.keras.models.model_from_json(json_string)
​生产与移动端部署​ ​TensorFlow Lite​​ (.tflite) 将模型部署到​​移动设备​​、​​嵌入式系统​​或 ​​IoT 设备​ converter = tf.lite.TFLiteConverter.from_keras_model(model) 使用 TFLite 解释器加载 (tf.lite.Interpreter)
​跨框架交换​ ​ONNX 格式​​ (.onnx) 需要在​​不同深度学习框架​​(如 PyTorch, MXNet)之间转换和共享模型 使用 tf2onnx.convert 工具进行转换 使用 ONNX Runtime 等支持 ONNX 的引擎进行加载和推理

下面是这些方法的详细说明和代码示例:

1️⃣ 保存与加载整个模型

保存整个模型是最简单直接的方式,它会将​​模型架构、权重、训练配置(如优化器及其状态)、甚至损失和指标函数​​一并保存。这对于​​模型归档、分享或需要从中断处继续训练​​非常方便。

​方法一:HDF5 格式 (.h5)​
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import numpy as np

# 1. 构建并训练一个简单模型
model = Sequential([
    Dense(64, activation='relu', input_shape=(10,)),
    Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# 生成虚拟数据并训练
x_train = np.random.random((1000, 10))
y_train = np.random.randint(2, size=(1000, 1))
model.fit(x_train, y_train, epochs=5)

# 2. 保存整个模型为 HDF5 格式
model.save('my_complete_model.h5')  # 保存为 .h5 文件

# 3. 加载整个模型
loaded_model_from_h5 = tf.keras.models.load_model('my_complete_model.h5')

# 4. 使用加载的模型进行预测
predictions = loaded_model_from_h5.predict(x_train[:5])
print(predictions)
​方法二:SavedModel 格式 (TensorFlow 默认格式)​

SavedModel 是 TensorFlow ​​首选的部署和生产格式​​。它会创建一个包含模型架构、权重及 TensorFlow 计算图的目录。

# 保存为 SavedModel 格式 (会创建一个目录)
model.save('my_savedmodel_directory')

# 加载 SavedModel 格式的模型
loaded_model_from_savedmodel = tf.keras.models.load_model('my_savedmodel_directory')

# 使用加载的模型
predictions = loaded_model_from_savedmodel.predict(x_train[:5])

​HDF5 与 SavedModel 的选择​​:

  • 如果需要​​单个文件​​便于管理或与某些传统工具兼容,可选择 HDF5。
  • 对于​​生产环境、跨平台部署​​或使用 ​​TensorFlow Serving​​,​​SavedModel 是更推荐和更强大的选择​​。

2️⃣ 仅保存与加载模型权重

如果只需要保存训练好的参数,或者打算在​​相同架构的模型间共享权重​​,可以只保存权重。​​加载权重时,必须预先构建一个与保存时完全相同的模型架构​​。

# 保存模型权重
model.save_weights('my_model_weights.h5')  # 默认为 HDF5 格式
# 或 model.save_weights('my_checkpoint.ckpt') # 也可使用 .ckpt 扩展名

# 构建一个与之前架构完全相同的模型
new_model = Sequential([
    Dense(64, activation='relu', input_shape=(10,)),
    Dense(1, activation='sigmoid')
])
new_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# 将保存的权重加载到新模型上
new_model.load_weights('my_model_weights.h5')

# 使用加载了权重的模型进行预测
new_predictions = new_model.predict(x_train[:5])

​用途​​:​​迁移学习​​中常用此方法加载预训练权重;也用于​​训练过程中定期保存检查点(Checkpoint)​​,防止训练意外中断后可以从最近的检查点恢复,而不是从头开始。

# 在训练过程中使用回调函数定期保存检查点
from tensorflow.keras.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(filepath='training_checkpoints/ckpt-{epoch:02d}',
                                      save_weights_only=True, # 只保存权重
                                      save_freq='epoch') # 每个epoch结束后保存

model.fit(x_train, y_train, epochs=10, callbacks=[checkpoint_callback])

# 恢复模型并从指定检查点加载权重
latest_checkpoint = 'training_checkpoints/ckpt-05' # 假设从第5个epoch的检查点恢复
model.load_weights(latest_checkpoint)

3️⃣ 仅保存与加载模型结构

如果只想​​保存模型的架构​​(层、配置),而不包含权重和训练配置,可以将其序列化为 JSON 或 YAML 字符串。

# 将模型架构序列化为 JSON 字符串
model_json = model.to_json()
# 也可以保存到文件
with open('model_architecture.json', 'w') as json_file:
    json_file.write(model_json)

# 从 JSON 字符串重建模型架构 (此时模型尚未加载权重)
reconstructed_model_from_json = tf.keras.models.model_from_json(model_json)
# 或者从文件读取
with open('model_architecture.json', 'r') as json_file:
    loaded_json_string = json_file.read()
reconstructed_model_from_json = tf.keras.models.model_from_json(loaded_json_string)

# 注意:重建的模型需要编译,并且如果需要使用,必须加载相应的权重
reconstructed_model_from_json.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# reconstructed_model_from_json.load_weights('weights.h5') # 然后再加载权重

4️⃣ 转换为 TensorFlow Lite 格式用于移动端部署

要将模型部署到手机、嵌入式设备等资源受限的环境,需要将其转换为 TensorFlow Lite 格式。

# 转换模型
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# 保存转换后的模型
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

# 加载和运行 TFLite 模型(通常在目标设备上进行)
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()

# 获取输入输出张量详情
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# 准备输入数据(需要符合模型期望的形状和类型)
input_data = np.array(np.random.random_sample(input_details[0]['shape']), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)

# 运行推理
interpreter.invoke()

# 获取输出结果
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)

5️⃣ 转换为 ONNX 格式用于跨框架交换

ONNX 是一种开放的模型格式,支持在不同深度学习框架(如 PyTorch, TensorFlow, MXNet)之间转换模型。

# 安装转换工具 tf2onnx
# pip install tf2onnx

# 使用命令行工具将 SavedModel 转换为 ONNX
# python -m tf2onnx.convert --saved-model saved_model --output model.onnx

💡 模型保存与加载的最佳实践

  1. ​版本控制​​:保存模型时,建议在文件名中加入​​版本号或日期​​(如 model_v1.2.h5),便于管理不同迭代版本的模型。
  2. ​记录环境信息​​:记录下训练和保存模型时使用的 ​​TensorFlow、Python 及主要依赖库的版本​​,以便在加载时复现相同的环境,避免兼容性问题。
  3. ​验证加载的模型​​:加载模型后,最好用一些测试数据验证其预测结果是否与保存前一致。
  4. ​自定义对象处理​​:如果模型包含​​自定义层、损失函数或指标​​,在加载时需要通过 custom_objects 参数将它们提供给 load_model 函数。
    # 假设模型包含一个名为 CustomLayer 的自定义层
    loaded_model_with_custom_layer = tf.keras.models.load_model(
        'model_with_custom_layer.h5',
        custom_objects={'CustomLayer': CustomLayer}
    )
  5. ​生产环境部署​​:
    • 对于 ​​TensorFlow Serving​​,使用 ​​SavedModel​​ 格式是标准做法。
    • 对于​​移动端或嵌入式设备​​,使用 ​​TensorFlow Lite​​ 格式并进行可能的量化以优化性能和体积。
    • 考虑使用 ​​ONNX​​ 格式 if you need to serve your model in environments that might use different runtimes or hardware accelerators that have better support for ONNX.

希望这份详细的指南能帮助你更好地理解和运用 TensorFlow 的模型保存与加载功能!

❤️❤️❤️本人水平有限,如有纰漏,欢迎各位大佬评论批评指正!😄😄😄

💘💘💘如果觉得这篇文对你有帮助的话,也请给个点赞、收藏下吧,非常感谢!👍 👍 👍

🔥🔥🔥Stay Hungry Stay Foolish 道阻且长,行则将至,让我们一起加油吧!🌙🌙🌙

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐