torch.nn.utils.prune에 보면 L1 norm + unstructured의 방식으로 prune을 하는 방식이 두 가지나 있다.
torch.nn.utils.prune.L1Unstructured
torch.nn.utils.prune.l1_unstructured
두 방식에는 어떤 차이가 있는 걸까?
1. L1Unstructured
공식 문서에서도 볼 수 있듯이 class로 구성되어 있다.
그리고 class method들로 apply, apply_mask, prune이 있다.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
conv1 = nn.Conv2d(2, 3, 3)
Output:
Parameter containing:
tensor([[[[ 0.0153, 0.0414, -0.2232],
[ 0.1617, -0.2074, -0.0986],
[ 0.2223, 0.0491, 0.0721]],
[[-0.0087, 0.2122, 0.1332],
[-0.2099, 0.1056, 0.1358],
[ 0.1959, 0.2099, -0.0390]]],
[[[ 0.1034, 0.0649, -0.0065],
[-0.1312, 0.0311, 0.1861],
[ 0.2300, 0.1445, -0.0308]],
[[ 0.0283, 0.1499, -0.0785],
[ 0.1997, -0.1511, -0.1992],
[ 0.1865, 0.0456, 0.1187]]],
[[[ 0.0304, -0.0616, -0.1559],
[ 0.0567, -0.2042, -0.0018],
[ 0.0784, -0.1900, 0.1878]],
[[-0.0987, -0.0172, 0.0495],
[ 0.1921, -0.1494, 0.1218],
[-0.1642, 0.1676, 0.0038]]]], requires_grad=True)
먼저 class를 선언한다. argument에는 얼마나 prune할지 결정하는 값인 amount를 넣어주는데, 그 값이 float라면 비율을, int라면 하위 몇개일지를 의미한다.
p = prune.L1Unstructured(amount = 0.3)
prune을 하고싶다면 class method 중 하나인 apply를 사용한다. argument로는 module, module 내에서 prune을 적용 할 parameter 이름, amount를 받는다 (여기서 amount를 또 입력해야 한다). prune할 양은 위의 amount가 아닌 여기 amount에서 결정된다.
p.apply(conv1, "weight", amount=0.3)
conv1을 뽑아보면
conv1.weight
Output:
tensor([[[[ 0.2278, -0.0000, -0.2026],
[-0.2079, 0.0000, -0.0000],
[-0.2257, 0.0000, -0.1631]],
[[ 0.0000, 0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000],
[-0.1659, 0.0000, -0.0000]]],
[[[-0.0000, -0.2001, -0.2164],
[-0.0000, -0.0000, -0.0000],
[ 0.0000, 0.0000, -0.0000]],
[[ 0.0000, 0.0000, 0.1864],
[ 0.1945, -0.1769, 0.0000],
[-0.0000, 0.0000, 0.0000]]],
[[[-0.1898, 0.0000, -0.0000],
[-0.0000, -0.0000, -0.1996],
[ 0.0000, -0.0000, -0.1728]],
[[-0.2223, -0.0000, -0.0000],
[ 0.0000, 0.1870, 0.0000],
[ 0.0000, -0.0000, 0.0000]]]], grad_fn=<MulBackward0>)
원하는 만큼 prune이 된 것을 확인할 수 있다.
참고로 tensor를 prune하고 싶다면 세 번째 class method인 prune을 사용하면 된다. 가령,
p.prune(conv1.weight)
2. l1_unstructured
Class가 아니다. 따라서 class method가 없다.
개인적인 생각으로는, 위에서 설명한 L1Unstructured 의 class method인 apply와 같은 것 같다.
prune(conv1, 'weight', amount = 0.3) # in-place이다.
3. Miscellaneous
L1Unstructured의 apply와 l1_unstructured 모두 기존 nn module에 변화를 준다.
conv1 = nn.Conv2d(2, 3, 3)
conv1.state_dict()
Output:
OrderedDict([('weight',
tensor([[[[-0.0943, 0.1803, -0.0691],
[-0.1613, 0.1080, 0.0792],
[ 0.1101, 0.0647, 0.2262]],
[[ 0.1114, 0.0392, 0.1380],
[ 0.1541, -0.1448, -0.2170],
[ 0.0793, 0.1666, -0.0386]]],
[[[ 0.0568, -0.0235, 0.1814],
[ 0.0451, -0.1229, 0.0975],
[ 0.2066, 0.0371, -0.2349]],
[[ 0.0647, -0.1927, -0.0462],
[-0.1879, -0.1358, 0.1719],
[-0.2349, -0.0773, -0.0520]]],
[[[-0.1819, -0.1023, 0.0018],
[ 0.0110, -0.0644, -0.1239],
[ 0.0686, -0.1545, -0.0176]],
[[ 0.2112, -0.2080, 0.0557],
[ 0.1711, 0.0080, -0.0178],
[ 0.1075, -0.1634, 0.2167]]]])),
('bias', tensor([ 0.1189, 0.1156, -0.1101]))])
"weight"와 "bias"만 있었던 state_dict를 보면,
prune.l1_unstructured(conv1, "weight", amount=0.3)
conv1.state_dict()
Output:
OrderedDict([('bias', tensor([-0.1330, -0.2045, 0.1568])),
('weight_orig',
tensor([[[[ 0.2125, 0.2319, 0.2275],
[-0.0676, -0.0997, 0.0108],
[ 0.0776, -0.1288, 0.1419]],
[[-0.1979, -0.1434, -0.1636],
[ 0.0579, 0.0700, 0.1911],
[-0.0104, 0.1895, -0.0509]]],
[[[ 0.1324, -0.1949, 0.1590],
[ 0.0195, 0.1843, 0.1687],
[-0.2035, 0.1025, -0.0536]],
[[ 0.0484, 0.2060, 0.1205],
[ 0.0989, -0.0038, -0.1013],
[-0.2163, -0.0479, -0.1959]]],
[[[ 0.1532, -0.1301, -0.0177],
[ 0.1751, 0.0899, 0.2255],
[-0.1137, -0.1474, 0.1261]],
[[ 0.1966, 0.0114, 0.1857],
[ 0.0055, 0.0889, 0.1683],
[-0.2031, -0.1496, 0.0010]]]])),
('weight_mask',
tensor([[[[1., 1., 1.],
[0., 1., 0.],
[0., 1., 1.]],
[[1., 1., 1.],
[0., 0., 1.],
[0., 1., 0.]]],
[[[1., 1., 1.],
[0., 1., 1.],
[1., 1., 0.]],
[[0., 1., 1.],
[1., 0., 1.],
[1., 0., 1.]]],
[[[1., 1., 0.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 0., 1.],
[0., 1., 1.],
[1., 1., 0.]]]]))])
"weight"는 "weight_orig"로 바뀌었고 "weight_mask"가 추가된 것을 확인할 수 있다.
"weight_mask"는 module의 buffer로 저장되며 module.named_buffers()에서 확인할 수 있다.
따라서, conv1.weight 의 output은 "weight_orig"와 "weight_mask"의 elemental-wise multiplication임을 유추할 수 있다. 또한 공식 문서에는 다음과 같이 설명되어 있다.
The pruning techniques implemented in torch.nn.utils.prune compute the pruned version of the weight (by combining the mask with the original parameter) and store them in the attribute weight. Note, this is no longer a parameter of the module, it is now simply an attribute.
'Coding' 카테고리의 다른 글
[Python] HackerRank - Capitalize! (0) | 2023.07.23 |
---|---|
[Python] torch.nn.CrossEntropyLoss 에서 ignore_index (0) | 2023.04.18 |
[Python] torch.scatter_ 이해하기 (0) | 2023.04.18 |
[Python] string split과 rsplit method 차이 (0) | 2023.04.16 |
[Python] 클래스 상속(Class inheritance) 그리고 Pytorch 모델에서의 해석. (1) | 2023.02.24 |