打印模型中的参数个数
import torch
import torch.nn as nn
假设你已经定义了一个模型,例如:
class MyModel(nn.Module):
def init(self):
super(MyModel, self).init()
self.fc1 = nn.Linear(10, 50)
self.fc2 = nn.Linear(50, 1)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
创建一个模型实例
model = MyModel()
计算并打印模型的参数个数
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f’模型的参数个数: {count_parameters(model)}’)