우선 scatter와 scatter_ 메소드를 비교하면 다음과 같다:
torch.scatter는 out-of-place 버전, torch.scatter_는 in-place 버전이다. 즉, torch.scatter는 메서드를 실행시키면 즉시 scatter가 tensor에 적용된다.
여기서 각 argument들이 의미하는 바는 다음과 같다.
dim: scatter 할 기준이 되는 축. '0'이면 행방향 (아래 방향), '1'이면 열방향 (오른쪽 방향).
index (LongTensor): 흩뿌릴 element들의 index. 즉, 어떤 숫자를 어떤 규칙으로 옮길지 결정하는 tensor.
src: 어떤 숫자들이 옮겨지는지 그 후보를 담은 소스 tensor.
그렇다면 scatter가 어떤 효과를 내는지 예시를 통해 이해해보자. 아래는 공식 document의 예시이다.
src = torch.arange(1, 11).reshape((2, 5))
src
>> tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
index = torch.tensor([[0, 1, 2, 0]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
>> tensor([[1, 0, 0, 4, 0],
[0, 2, 0, 0, 0],
[0, 0, 3, 0, 0]])
dim=0 으로 설정되었기 때문에 index의 element인 숫자들이 하는 역할은 "어떤 행으로 보낼지"에 해당한다. "어떤 열로 보낼지"는 해당 element 들의 열이 된다.
index가 [0, 1, 2, 0]인 1x4 array이기 때문에 src에서는 그에 해당하는 숫자들인 1,2,3,4만 이사를 보내어지게 된다. 따라서 index의 명령에 따라 1,2,3,4들이 이사가 보내지게 되면 위와 같은 결과를 얻게 된다.
'Coding' 카테고리의 다른 글
[Python] HackerRank - Capitalize! (0) | 2023.07.23 |
---|---|
[Python] torch.nn.CrossEntropyLoss 에서 ignore_index (0) | 2023.04.18 |
[Python] string split과 rsplit method 차이 (0) | 2023.04.16 |
[Python] 클래스 상속(Class inheritance) 그리고 Pytorch 모델에서의 해석. (1) | 2023.02.24 |
torch.nn.utils.prune에서 L1Unstructured와 l1_unstructured의 차이 (0) | 2023.01.25 |