PyTorch 모형의 총 모수 수 확인
PyTorch 모델의 총 매개 변수 수를 계산하려면 어떻게 해야 합니까?과 비슷한 것.model.count_params()
케라스에서
PyTorch에는 Keras처럼 총 매개변수 수를 계산하는 기능이 없지만 각 매개변수 그룹에 대한 요소 수를 합계할 수 있습니다.
pytorch_total_params = sum(p.numel() for p in model.parameters())
훈련 가능한 파라미터만 계산하려는 경우:
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
PyTorch 포럼의 이 답변에서 영감을 얻은 답변입니다.
Keras와 같은 각 레이어의 매개 변수 개수를 가져오기 위해 PyTorch는 매개 변수 이름과 매개 변수 자체의 반복기를 모두 반환합니다.예:
from prettytable import PrettyTable
def count_parameters(model):
table = PrettyTable(["Modules", "Parameters"])
total_params = 0
for name, parameter in model.named_parameters():
if not parameter.requires_grad: continue
params = parameter.numel()
table.add_row([name, params])
total_params+=params
print(table)
print(f"Total Trainable Params: {total_params}")
return total_params
count_parameters(net)
출력 예:
+-------------------+------------+
| Modules | Parameters |
+-------------------+------------+
| embeddings.weight | 922866 |
| conv1.weight | 1048576 |
| conv1.bias | 1024 |
| bn1.weight | 1024 |
| bn1.bias | 1024 |
| conv2.weight | 2097152 |
| conv2.bias | 1024 |
| bn2.weight | 1024 |
| bn2.bias | 1024 |
| conv3.weight | 2097152 |
| conv3.bias | 1024 |
| bn3.weight | 1024 |
| bn3.bias | 1024 |
| lin1.weight | 50331648 |
| lin1.bias | 512 |
| lin2.weight | 265728 |
| lin2.bias | 519 |
+-------------------+------------+
Total Trainable Params: 56773369
공유 매개 변수를 이중으로 계산하지 않으려면 다음과 같이 사용합니다.
sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
다음은 훈련 불가능한 매개 변수를 선택적으로 필터링할 수 있는 보다 자세한 구현입니다.
def numel(m: torch.nn.Module, only_trainable: bool = False):
"""
Returns the total number of parameters used by `m` (only counting
shared parameters once); if `only_trainable` is True, then only
includes parameters with `requires_grad = True`
"""
parameters = list(m.parameters())
if only_trainable:
parameters = [p for p in parameters if p.requires_grad]
unique = {p.data_ptr(): p for p in parameters}.values()
return sum(p.numel() for p in unique)
사용할 수 있습니다.torchsummary
같은 일을 하는 것.두 줄의 코드일 뿐입니다.
from torchsummary import summary
print(summary(model, (input_shape)))
모델을 인스턴스화하지 않고 각 레이어의 가중치 및 편향 수를 계산하려는 경우 원시 파일을 로드하고 결과에 대해 반복하면 됩니다.collections.OrderedDict
예:
import torch
tensor_dict = torch.load('model.dat', map_location='cpu') # OrderedDict
tensor_list = list(tensor_dict.items())
for layer_tensor_name, tensor in tensor_list:
print('Layer {}: {} elements'.format(layer_tensor_name, torch.numel(tensor)))
당신은 다음과 같은 것을 얻을 것입니다.
conv1.weight: 312
conv1.bias: 26
batch_norm1.weight: 26
batch_norm1.bias: 26
batch_norm1.running_mean: 26
batch_norm1.running_var: 26
conv2.weight: 2340
conv2.bias: 10
batch_norm2.weight: 10
batch_norm2.bias: 10
batch_norm2.running_mean: 10
batch_norm2.running_var: 10
fcs.layers.0.weight: 135200
fcs.layers.0.bias: 260
fcs.layers.1.weight: 33800
fcs.layers.1.bias: 130
fcs.batch_norm_layers.0.weight: 260
fcs.batch_norm_layers.0.bias: 260
fcs.batch_norm_layers.0.running_mean: 260
fcs.batch_norm_layers.0.running_var: 260
다른 가능한 해결책은 다음과 같습니다.
def model_summary(model):
print("model_summary")
print()
print("Layer_name"+"\t"*7+"Number of Parameters")
print("="*100)
model_parameters = [layer for layer in model.parameters() if layer.requires_grad]
layer_name = [child for child in model.children()]
j = 0
total_params = 0
print("\t"*10)
for i in layer_name:
print()
param = 0
try:
bias = (i.bias is not None)
except:
bias = False
if not bias:
param =model_parameters[j].numel()+model_parameters[j+1].numel()
j = j+2
else:
param =model_parameters[j].numel()
j = j+1
print(str(i)+"\t"*3+str(param))
total_params+=param
print("="*100)
print(f"Total Params:{total_params}")
model_summary(net)
그러면 아래와 유사한 출력이 표시됩니다.
model_summary
Layer_name Number of Parameters
====================================================================================================
Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1)) 60
Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1)) 880
Linear(in_features=576, out_features=120, bias=True) 69240
Linear(in_features=120, out_features=84, bias=True) 10164
Linear(in_features=84, out_features=10, bias=True) 850
====================================================================================================
Total Params:81194
텐서의 반복 가능한 값을 텐서로 변환한 다음 다음과 결합하는 내장 유틸리티 기능이 내장되어 있습니다.
torch.nn.utils.parameters_to_vector(model.parameters()).numel()
또는 명명된 가져오기를 사용하여 더 짧게(from torch.nn.utils import parameters_to_vector
):
parameters_to_vector(model.parameters()).numel()
연결할 수 있는 최종 답변:
def count_number_of_parameters(model: nn.Module, only_trainable: bool = True) -> int:
"""
Counts the number of trainable params. If all params, specify only_trainable = False.
Ref:
- https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9?u=brando_miranda
- https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model/62764464#62764464
:return:
"""
if only_trainable:
num_params: int = sum(p.numel() for p in model.parameters() if p.requires_grad)
else: # counts trainable and none-traibale
num_params: int = sum(p.numel() for p in model.parameters() if p)
assert num_params > 0, f'Err: {num_params=}'
return int(num_params)
서로 다른 매개 변수가 메모리를 공유하는 경우, 사용하는 응답을 포함하여 모든 응답이 완전히 해결되지 않습니다.numel
,PrettyTable
그리고..data_ptr
@teechert는 정확하게 동일한 텐서를 가리키는 두 개의 다른 매개 변수가 있는 경우를 처리하는 훌륭한 답변을 제공했습니다.하지만 한 매개변수가 다른 매개변수의 한 조각이라면 어떨까요?비록 그들이 약간의 기억을 공유하겠지만, 사용..data_ptr()
순진하게도 다른 결과를 도출할 것입니다. 따라서 여전히 그의 접근 방식을 사용하는 초과 계산이 있을 것입니다.
모든 텐서의 항목이 동일한 항목을 가리키지 않도록 주의해야 합니다.이 작업은 집합 이해를 사용하여 수행할 수 있습니다.
훈련할 수 없는 매개 변수 포함:
len({e.data_ptr() for p in model.parameters() for e in p.view(-1)})
훈련할 수 없는 매개 변수 무시:
len({e.data_ptr() for p in model.parameters() if p.requires_grad for e in p.view(-1)})
텐서가 메모리를 공유할 수 있다면 고유 텐서의 수를 세어보는 것은 어떨까요?이것은 까다로운 인터뷰 문제처럼 들리지만 UnionFind 데이터 구조를 사용하면 쉽습니다!당신이 원하지 않으면pip install
이 파일을 문자 그대로 복사하여 대체할 수 있습니다.
모델을 이 함수에 전달하면 일부 매개 변수가 다른 매개 변수의 슬라이스인 경우에도 메모리 공유가 초과되지 않습니다.
def num_parameters(model, show_only_trainable):
from UnionFind import UnionFind
u = UnionFind()
for p in model.parameters():
if not show_only_trainable or p.requires_grad:
u.union(*[e.data_ptr() for e in p.view(-1)])
print(f'Number of parameters: {len(u)}')
print(f'Number of tensors: {u.num_connected_components}')
이 코드는 다른 기술을 사용할 때의 문제와 위의 기능을 사용하여 해결하는 방법을 보여줍니다.
>>> import torch.nn as nn
>>> import torch
>>> torch.manual_seed(0)
>>>
>>> # This layer is not trainable
>>> frozen_layer = nn.Linear(out_features=3, in_features=4, bias=False)
>>> for p in frozen_layer.parameters():
... p.requires_grad = False
...
>>>
>>> # There are 4*2 + 3*4 = 20 total parameters
>>> # There are 4*2 = 8 trainable parameters
>>> model = nn.Sequential(
... nn.Linear(out_features=4, in_features=2, bias=False),
... nn.ReLU(),
... frozen_layer,
... nn.Sigmoid()
... )
>>>
>>> # Parameters seem properly accounted for so far
>>> sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
20
>>> sum(dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values())
8
>>>
>>> # Add a new Parameter that is an arbitrary slice of an existing Parameter.
>>> # NOTE that slice syntax `[]` and wrapping with `nn.Parameter()` do
>>> # NOT copy the data, but merely point to part of existing tensor.
>>> model.newparam = nn.Parameter(next(model.parameters())[0:2, 1:2])
>>>
>>> params = list(model.parameters())
>>>
>>> # Notice that both appear the same. Do they share memory?
>>> # `params[0]` is `model.newparam`. `params[1]` is tensor that `params[0]` was sliced from.
>>>
>>> params[0]
Parameter containing:
tensor([[-0.4683],
[ 0.0262]], requires_grad=True)
>>>
>>> params[1][0:2, 1:2]
tensor([[-0.4683],
[ 0.0262]], grad_fn=<SliceBackward0>)
>>>
>>> with torch.no_grad():
... params[0][0, 0] = 1.2345
...
>>>
>>> # Both have changed, proving that they DO share memory.
>>>
>>> params[0]
Parameter containing:
tensor([[1.2345],
[0.0262]], requires_grad=True)
>>>
>>> params[1][0:2, 1:2]
tensor([[1.2345],
[0.0262]], grad_fn=<SliceBackward0>)
>>>
>>> # WRONG - the number of parameters "appears" to have increased by 2 (because of `model.newparam`).
>>> sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
22
>>> sum(dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values())
10
>>>
>>> # CORRECT - this discounts all shared parameters
>>> len({e.data_ptr() for p in model.parameters() for e in p.view(-1)})
20
>>> len({e.data_ptr() for p in model.parameters() if p.requires_grad for e in p.view(-1)})
8
>>>
>>> # To count unique tensors, we can use this function.
>>> # It utilizes the UnionFind data structure which can be dropped in directly from here:
>>> # https://gist.github.com/timgianitsos/0878a0b241cb5d0ad8b16ebc2b14322a
>>> def num_parameters(model, show_only_trainable):
... from UnionFind import UnionFind
... u = UnionFind()
... for p in model.parameters():
... if not show_only_trainable or p.requires_grad:
... u.union(*[e.data_ptr() for e in p.view(-1)])
... print(f'Number of parameters: {len(u)}')
... print(f'Number of tensors: {u.num_connected_components}')
...
>>>
>>> # Notice that the problem has been fixed
>>> num_parameters(model, show_only_trainable=False)
Number of parameters: 20
Number of tensors: 2
>>> num_parameters(model, show_only_trainable=True)
Number of parameters: 8
Number of tensors: 1
내 기능을 사용하는 것을 기억하십시오.num_parameters()
함수는 모든 텐서의 모든 항목을 루프해야 하기 때문에 다른 솔루션보다 실행하는 데 더 오래 걸립니다. 2200만 매개 변수 모델의 Mac CPU에서 약 2분입니다.연속적인 메모리 주소에 대한 데이터 포인터가 동일한 일정한 양만큼 다르다는 사실을 활용하면 훨씬 더 빨리 만들어질 수 있습니다(예: 텐서가 다음과 같은 경우 4바이트).torch.float32
하지만 이를 위해서는 텐서의dtype
그리고.stride
2천만 개 이상의 매개 변수에 대해 몇 분을 기다릴 의사가 있다면 아마도 과잉 살상이 될 것입니다.
@fábio-perez가 언급했듯이, PyTorch에는 이러한 내장 기능이 없습니다.
하지만, 저는 이것이 동일한 결과를 얻기 위한 작고 깔끔한 방법이라는 것을 알게 되었습니다.
num_of_parameters = sum(map(torch.numel, model.parameters()))
언급URL : https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
'programing' 카테고리의 다른 글
Oracle like 절에서 밑줄이 작동하지 않습니다. (0) | 2023.06.30 |
---|---|
Git 이등분 오류 실행 취소 (0) | 2023.06.30 |
문자열 열의 각 행에서 주어진 문자의 발생 횟수를 계산하는 방법은 무엇입니까? (0) | 2023.06.30 |
Firebase 프로젝트에서 앱을 삭제/제거하려면 어떻게 해야 합니까? (0) | 2023.06.30 |
Oracle 패키지 수준 변수의 범위 (0) | 2023.06.30 |