Pytorch 系列 - 1. 曲线拟合
1. 曲线拟合问题
考虑以下曲线拟合问题
- 给定一组数据(满足某种 未知映射,可能还会附加一些噪声)
- 利用
pytorch搭建神经网络,基于给定数据,尝试重构 未知映射 - 利用
pytorch进行神经网络参数学习,将迭代过程中每次曲线拟合效果动态的展示出来(与原始给定数据一起展示)
1 | |
2. pytorch 保存与加载神经网络模型
2.1 两种保存神经网络模型的方式
- 保存完整模型(不推荐):
modeltorch.save(model, "path/filename.pth")- 保存模型的所有信息,包括结构和参数。
- 这种方式文件较大,不建议在生产环境中使用。
- 保存模型参数(推荐):
model.state_dict()torch.save(model.state_dict(), "path/filename.pth")- 只保存模型的参数(权重和偏置),而不保存模型的结构。
- 这种方式更轻量级,适用于部署场景。
- 加载模型时,需要先定义模型结构,然后加载参数到模型中。
- 两种模式保存的文件类型都是
.pth- 什么是
.pth?.pth文件本质上就是torch.save存的二进制文件,通常作为模型/参数/checkpoint的容器。- 常见神经网络模型存储文件格式的还有
.pt、.pth.tar,本质一样。 - 存储原理:
torch.save()=Python 的 pickle+PyTorch 的 tensor - 由于是二进制文件,只能用
PyTorch正确读取,不是人类可读文本。
.pth里可以保存什么?- 任何可以被
pickle序列化的Python 对象 - 模型参数(推荐)
- 整个模型对象(不太推荐,只做演示)
- 训练 checkpoint(强烈推荐) 可以包含
- 模型参数
- 优化器参数(动量等)
- 当前 epoch、学习率、损失
- 想记录的其他信息(超参数设置、随机种子等)
1
2
3
4
5
6
7
8
9
10torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": loss_value,
"extra_info": "anything you want",
},
"ckpt_epoch_100.pth"
)
- 任何可以被
.pth是如何保存数据的?torch.save(obj, path/filename.pth)会把 obj 通过Python 的 pickle序列化;- 对于
tensor,PyTorch会把真实数值存成高效的二进制块(包含dtype、shape等元数据) - 所有这些被写入一个二进制文件,就是
.pth文件 .pth文件 不能用记事本直接看;会是乱码torch.load()本质是反序列化,有安全风险:不要随便加载来路不明的.pth(因为pickle理论上可执行任意代码)
- 什么是
2.2 两种加载神经网络模型的方式
如果保存
model就用1
torch.save(model, "path/filename.pth")torch.load()加载1
2
3
4
5
6
7
8
9
10
11
12
13
14
15# 1. 加载 pth 文件
model = torch.load("path/filename.pth")
model.eval()
# 2. 准备与之前一致的数据,用于展示(x -> y)
np.random.seed(0)
n_samples = 80
x = np.linspace(-3, 3, n_samples)
y = true_func(x) + np.random.normal(scale=3.0, size=n_samples)
# 3. 神经网络计算 (x_tensor -> y_pred)
with torch.no_grad():
x_plot = np.linspace(-3, 3, 200)
x_tensor = torch.from_numpy(x_plot).float().view(-1, 1)
y_pred = model(x_tensor).numpy()如果保存
model.state_dict()就用1
torch.save(model.state_dict(), "path/filename.pth")model.load_state_dict()加载参数名示例:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34# 0. 与训练时保持一致的网络定义
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Linear(1, 64),
torch.nn.Tanh(),
torch.nn.Linear(64, 64),
torch.nn.Tanh(),
torch.nn.Linear(64, 1),
)
def forward(self, x):
return self.net(x)
# 1. 选择要加载的 pth 文件
checkpoint = torch.load("path/filename")
# 2. 重建同结构网络 & 加载参数
model = Net()
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
# 3. 准备与之前一致的数据,用于展示(x -> y)
np.random.seed(0)
n_samples = 80
x = np.linspace(-3, 3, n_samples)
y = true_func(x) + np.random.normal(scale=3.0, size=n_samples)
# 4. 神经网络计算 (x_tensor -> y_pred)
with torch.no_grad():
x_plot = np.linspace(-3, 3, 200)
x_tensor = torch.from_numpy(x_plot).float().view(-1, 1)
y_pred = model(x_tensor).numpy()net.0.weight,net.0.bias如果保存
checkpoint就这样加载1
2
3
4
5
6
7
8
9torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": loss,
},
"ckpt_epoch_200.pth"
)1
2
3
4
5
6
7
8
9
10
11
12
13ckpt = torch.load("ckpt_epoch_200.pth", map_location="cpu")
print(type(ckpt)) # dict
print(ckpt.keys()) # dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'])
print("epoch:", ckpt["epoch"])
print("loss:", ckpt["loss"])
# 查看模型参数 keys
for name, param in ckpt["model_state_dict"].items():
print(name, param.shape)
break
3. pytorch 的保存与加载案例
保存
1 | |
加载 model.state_dict() 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57import torch
import matplotlib.pyplot as plt
import numpy as np
import os
# 与训练时保持一致的网络定义
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Linear(1, 64),
torch.nn.Tanh(),
torch.nn.Linear(64, 64),
torch.nn.Tanh(),
torch.nn.Linear(64, 1),
)
def forward(self, x):
return self.net(x)
def true_func(x):
return 1.5 * x**2 - 2 * x + 1
# 1. 选择要加载的 epoch
ckpt_dir = "checkpoints"
target_epoch = 200 # 任意你训练时保存过的,比如 10, 50, 100, 200, ...
ckpt_path = os.path.join(ckpt_dir, f"ckpt_epoch_{target_epoch}.pth")
checkpoint = torch.load(ckpt_path)
# 2. 重建同结构网络 & 加载参数
model = Net()
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
# 3. 准备与之前一致的数据(用于展示)
np.random.seed(0)
n_samples = 80
x = np.linspace(-3, 3, n_samples)
y = true_func(x) + np.random.normal(scale=3.0, size=n_samples)
# 4. 画出该 epoch 的拟合效果
with torch.no_grad():
x_plot = np.linspace(-3, 3, 200)
x_tensor = torch.from_numpy(x_plot).float().view(-1, 1)
y_pred = model(x_tensor).numpy()
y_true = true_func(x_plot)
plt.scatter(x, y, s=15, label="Noisy data")
plt.plot(x_plot, y_true, "--", label="True function")
plt.plot(x_plot, y_pred, label=f"Loaded model (epoch={target_epoch})")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.title("Fit of Loaded Checkpoint Model")
plt.show()
加载 model
1 | |