Coding

[Python] torch.scatter_ 이해하기

Belter 2023. 4. 18. 00:48

우선 scatter와 scatter_ 메소드를 비교하면 다음과 같다:

torch.scatter
torch.scatter_

torch.scatter는 out-of-place 버전, torch.scatter_는 in-place 버전이다. 즉, torch.scatter는 메서드를 실행시키면 즉시 scatter가 tensor에 적용된다.

 

여기서 각 argument들이 의미하는 바는 다음과 같다.

dim: scatter 할 기준이 되는 축. '0'이면 행방향 (아래 방향), '1'이면 열방향 (오른쪽 방향).

index (LongTensor): 흩뿌릴 element들의 index. 즉, 어떤 숫자를 어떤 규칙으로 옮길지 결정하는 tensor.  

src: 어떤 숫자들이 옮겨지는지 그 후보를 담은 소스 tensor.

 

 

그렇다면 scatter가 어떤 효과를 내는지 예시를 통해 이해해보자. 아래는 공식 document의 예시이다.

src = torch.arange(1, 11).reshape((2, 5))
src
>> tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
        
        
index = torch.tensor([[0, 1, 2, 0]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
>> tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])

dim=0 으로 설정되었기 때문에 index의 element인 숫자들이 하는 역할은 "어떤 행으로 보낼지"에 해당한다. "어떤 열로 보낼지"는 해당 element 들의 열이 된다.

index가 [0, 1, 2, 0]인 1x4 array이기 때문에 src에서는 그에 해당하는 숫자들인 1,2,3,4만 이사를 보내어지게 된다. 따라서 index의 명령에 따라 1,2,3,4들이 이사가 보내지게 되면 위와 같은 결과를 얻게 된다.