Saving and Loading Models

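  • save/load weights: the most lightweight option. Only the network parameters are saved, so it suits cases where you know the network structure and have the source code to rebuild it.
  • save/load entire model: saves the complete model, so it can be restored without the code that defined it.
  • saved_model: a general-purpose exchange format that other environments, such as C++, can read.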

save/load weights

# Save
model.save_weights('./checkpoint/my_checkpoint.ckpt')
# Restore: rebuild the same architecture first, then load the weights into it
model = create_model()
model.load_weights('./checkpoint/my_checkpoint.ckpt')
loss, acc = model.evaluate(test_images, test_labels)
# ... (earlier code omitted)
network.evaluate(ds_val)

network.save_weights('weights.ckpt')
print('saved weights.')

del network  # delete the trained network

# Rebuild the same architecture before loading the saved weights
network = Sequential([layers.Dense(256, activation='relu'),
                      layers.Dense(128, activation='relu'),
                      layers.Dense(64, activation='relu'),
                      layers.Dense(32, activation='relu'),
                      layers.Dense(10)])
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
network.load_weights('weights.ckpt')
print('loaded weights!')
network.evaluate(ds_val)
79/79 [==============================] - 0s 6ms/step - loss: 0.1275 - accuracy: 0.9676
saved weights.
loaded weights!
79/79 [==============================] - 1s 7ms/step - loss: 0.1275 - accuracy: 0.9676
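
The identical loss and accuracy above suggest the weights were restored exactly. A minimal, self-contained sketch of the same round trip follows. The tiny model, the random input, and the create_model() body are hypothetical stand-ins for the article's omitted code, and the '.weights.h5' file name is an assumption chosen because newer Keras releases require that suffix while older tf.keras versions accept it as HDF5.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def create_model():
    # Hypothetical stand-in for the article's create_model(): the same
    # architecture has to be rebuilt in code before weights can be loaded.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation='relu'),
        tf.keras.layers.Dense(3),
    ])


x = np.random.rand(5, 4).astype('float32')

model = create_model()
before = model.predict(x, verbose=0)

# Assumed file name; only the weights are written, not the architecture.
path = os.path.join(tempfile.mkdtemp(), 'demo.weights.h5')
model.save_weights(path)

restored = create_model()      # fresh model with new random weights
restored.load_weights(path)    # overwrite them with the saved values
after = restored.predict(x, verbose=0)

weights_match = np.allclose(before, after)
print('predictions identical after reload:', weights_match)
```

Comparing predictions on a fixed input is a quick check that does not need a labelled validation set.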

save/load entire model

network.save('model.h5')
print('save entire model')
del network

print('load entire model')
network = tf.keras.models.load_model('model.h5')
network.evaluate(x_val, y_val)
# ... (earlier code omitted)
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2)

network.evaluate(ds_val)

network.save('model.h5')
print('saved total model.')
del network

print('loaded model from file.')
# compile=False skips restoring the training configuration, so compile again
network = tf.keras.models.load_model('model.h5', compile=False)
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

network.evaluate(ds_val)
79/79 [==============================] - 0s 6ms/step - loss: 0.1268 - accuracy: 0.9645
saved total model.
loaded model from file.
79/79 [==============================] - 1s 7ms/step - loss: 0.1268 - accuracy: 0.9645
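
Unlike the weights-only approach, loading an entire model does not require the code that built the network. The sketch below illustrates this with a small hypothetical model and random input (neither comes from the article), saving to a temporary 'model.h5' and loading with compile=False, which restores the architecture and weights but not the optimizer, loss, or metrics.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical small network standing in for the article's model.
network = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation='relu'),
    tf.keras.layers.Dense(3),
])

x = np.random.rand(5, 4).astype('float32')
before = network.predict(x, verbose=0)

path = os.path.join(tempfile.mkdtemp(), 'model.h5')
network.save(path)   # architecture and weights in a single file
del network          # no layer definitions are needed from here on

# compile=False: only architecture and weights are restored.
restored = tf.keras.models.load_model(path, compile=False)

# Compile again before calling evaluate() or fit() on the restored model.
restored.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

after = restored.predict(x, verbose=0)
model_match = np.allclose(before, after)
print('predictions identical after reload:', model_match)
```

Inference with predict() works without compiling; the compile step only matters for training and evaluation.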

saved_model

This is the format typically used to deploy models in industry.

tf.saved_model.save(network, 'save')  # creates a new 'save' directory at this relative path

imported = tf.saved_model.load('save')
f = imported.signatures["serving_default"]
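
The snippet above stops once the serving function f is retrieved. As a rough illustration of how such a function might be called, the sketch below swaps the Keras network for a tiny tf.Module so that it stays self-contained; the Scale class, its serve method, and the input name "x" and output key "y" are all hypothetical and chosen only for this example.

```python
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical module that multiplies its input by a stored variable."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    # A fixed input signature lets the function be exported directly.
    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="x")])
    def serve(self, x):
        return {"y": self.w * x}


export_dir = tempfile.mkdtemp()
module = Scale()

# Passing a single function registers it under the "serving_default" key.
tf.saved_model.save(module, export_dir, signatures=module.serve)

imported = tf.saved_model.load(export_dir)
f = imported.signatures["serving_default"]

# Signature functions take keyword arguments and return a dict of tensors.
result = f(x=tf.constant([1.0, 2.0, 3.0]))
print(result["y"].numpy())
```

The loaded object no longer depends on the Python class definition, which is what makes the format usable from other runtimes.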
  • Copyright notice: unless otherwise stated, all articles on this blog are licensed under the Apache License 2.0. Please credit the source when reposting.
