밑바닥부터 시작하는 딥러닝3 - STEP 11

2021. 3. 15. 17:54개인 공부 공간/딥러닝

Function 클래스 수정

기존에 구현한 Function 클래스는 입력과 출력이 모두 하나씩인 경우만 고려했었다. 하지만 함수에 따라 입력이 여러개일수도 있고 출력이 여러개일수도 있다.

11-1.PNG

이러한 가변적인 입출력 길이를 고려하기 위해 변수들을 리스트나 튜플에 넣어 처리하도록 기존의 Function 클래스를 수정하였다.

 

class Function:
  def __call__(self, inputs):
    xs = [x.data for x in inputs]
    ys = self.forward(xs)
    outputs = [Variable(as_array(y)) for y in ys]

    for output in outputs:
      output.set_creator(self)
    self.inputs = inputs
    self.outputs = outputs
    return outputs

  def forward(self, xs):
    raise NotImplementedError

  def bacwark(self, gys):
    raise NotImplementedError

 

여러 변수를 동시에 처리할 수 있도록 변수를 리스트에 담아서 처리한다는 점을 제외하고는 기존의 Function 클래스와 기능은 동일하다.

Add 클래스 구현

두 개의 입력을 받는 Add 클래스를 구현하였다.

 

class Add(Function):
  def forward(self, xs):
    x0, x1 = xs
    y = x0 + x1
    return (y,)

 

위의 코드에서 주의할 점은 return 값이 리스트(또는 튜플) 형태여야 한다는 점이다.

 

xs = [Variable(np.array(2)), Variable(np.array(3))] # 리스트로 준비
f = Add()
ys = f(xs)
y = ys[0]
print(y.data)

 

실행 결과

5

출처:- 사이토 고키, 『밑바닥부터 시작하는 딥러닝3』, 개앞맵시, 한빛미디어(2020)

 

출처: https://privatedevelopnote.tistory.com/81 [개인노트]