밑바닥부터 시작하는 딥러닝3 - STEP 16

2021. 3. 15. 18:06개인 공부 공간/딥러닝

세대 추가

밑바닥부터 시작하는 딥러닝3 - STEP 15 포스트에서 설명한 기존 코드의 문제점을 해결하기 위해 Variable 클래스와 Function 클래스에 몇 번째 세대의 함수, 변수 인지 나타내는 generation 인스턴스 변수를 추가했습니다.

 

class Variable:
  def __init__(self, data):
    if data is not None:
      if not isinstance(data, np.ndarray):
        raise TypeError('{}은(는) 지원하지 않습니다.'.format(type(data)))

    self.data = data
    self.grad = None
    self.creator = None
    self.generation = 0 # 세대 수를 기록하는 변수

  def set_creator(self, func):
    self.creator = func
    self.generation=func.generation + 1 # 세대를 기록한다(부모 세대 + 1)
  ....

 

generation 을 0으로 초기 설정한 후 set_creator 메서드가 호출될 때 부모 함수의 세대에 1을 더해준 값을 설정합니다.

 

class Function(object):
  def __call__(self, *inputs):
    xs = [x.data for x in inputs]
    ys = self.forward(*xs)
    if not isinstance(ys, tuple):
      ys = (ys,)    
    outputs = [Variable(as_array(y)) for y in ys]

    self.generation = max([x.generation for x in inputs])
    for output in outputs:
      output.set_creator(self)
    self.inputs = inputs
    self.outputs = outputs

    return outputs if len(outputs) > 1 else outputs[0]

 

세대 순으로 꺼내기

코드 수정을 통해 순전파 과정에서 각 함수와 변수들에게 세대가 설정됩니다.

16-3.PNG

이제는 역전파 과정에서 함수를 올바른 순서인 D, B, C, A(D, C, B, A) 순으로 꺼낼 수 있게 됩니다. 이를 위해 Variable 클래스의 backward 메서드에도 수정이 필요합니다.

Variable 클래스의 backward

class Variable:
  ....

  def backward(self):
    if self.grad is None:
      self.grad = np.ones_like(self.data)

    funcs = []
    seen_set = set()

    def add_func(f): # 세대 순으로 정렬하는 역할
      if f not in seen_set:
        funcs.append(f)
        seen_set.add(f)
        funcs.sort(key=lambda x: x.generation)

    add_func(self.creator)

    while funcs:
      f = funcs.pop()
      gys = [output.grad for output in f.outputs]
      gxs = f.backward(*gys)
      if not isinstance(gxs, tuple):
        gxs = (gxs,)

      for x, gx in zip(f.inputs, gxs):
        if x.grad is None:
          x.grad = gx
        else:
          x.grad = x.grad + gx

        if x.creator is not None:
          add_func(x.creator)

 

add_func 함수를 새로 정의하였고 이를 이용해 함수 리스트를 세대 순으로 정렬하는 기능을 구현 했습니다.

동작 확인

x = Variable(np.array(2.0))
a = square(x)
y = add(square(a), square(a))
y.backward()

print(y.data)
print(x.grad)

 

실행 결과

32.0
64.0

기존에는 올바른 계산이 불가능 했었던 계산 그래프의 미분 과정까지 정확하게 구했습니다.


출처:- 사이토 고키, 『밑바닥부터 시작하는 딥러닝3』, 개앞맵시, 한빛미디어(2020)

출처: https://privatedevelopnote.tistory.com/81 [개인노트]