打印模型中的参数个数

import torch
import torch.nn as nn

假设你已经定义了一个模型,例如:

class MyModel(nn.Module):
def init(self):
super(MyModel, self).init()
self.fc1 = nn.Linear(10, 50)
self.fc2 = nn.Linear(50, 1)

def forward(self, x):
    x = torch.relu(self.fc1(x))
    x = self.fc2(x)
    return x

创建一个模型实例

model = MyModel()

计算并打印模型的参数个数

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f’模型的参数个数: {count_parameters(model)}’)