tf.function bug

그동안 장대한 포스트만 쓰다가 이번엔 짧은 버그 리포트 하나를 써보려 한다.

Tensorflow에서 @tf.function이라는 유용한 데코레이터가 있다.

그런데 이걸 쓰다가 계속 WARNING이 떠서 이거 왜 이러나… 찾아보고 테스트 하다가 원인을 발견한 것을 기록하여 같은 실수를 막고자 한다.

AutoGraph could not transform

어떤 함수에 @tf.function를 붙여서 실행을 해보니 “AutoGraph could not transform‘로 시작하는 경고 메시지가 시작할 때 뿐만 아니라 학습 중에도 계속 발생했다.

@tf.function 안의 다른 함수들은 문제가 없는데 유독 특정 함수만 계속 경고가 떴다. 코드를 바꿔가면서 실험을 해보다가 아래와 같은 간단한 테스트 코드를 만들었다.

import tensorflow as tf

class TestTfBug:
    def test_func_in_class(self):
        a = tf.random.uniform((8, 100, 100, 3))
        b = tf.random.uniform((8, 100, 100, 3))
        c = 0.5 * tf.reduce_mean(tf.abs(a), axis=[1, 2, 3]) + \
            0.5 * tf.reduce_mean(tf.abs(b), axis=[1, 2, 3])
        return c

def test_func_just_func():
    a = tf.random.uniform((8, 100, 100, 3))
    b = tf.random.uniform((8, 100, 100, 3))
    c = 0.5 * tf.reduce_mean(tf.abs(a), axis=[1, 2, 3]) + \
        0.5 * tf.reduce_mean(tf.abs(b), axis=[1, 2, 3])
    return c

# This function results in WARNING:tensorflow:AutoGraph could not transform ~~
@tf.function
def test_tf_bug_class():
    result = TestTfBug().test_func_in_class()

# This function has no problem
@tf.function
def test_tf_bug_func():
    result = test_func_just_func()

if __name__ == "__main__":
    test_tf_bug_class()
    test_tf_bug_func()

위에서 실행한 두 함수 중에 하나씩만 번갈아가며 실행하면 어느쪽에서 경고가 나오는지 알 수 있다. test_tf_bug_class()를 실행하면 다음과 같은 경고가 뜬다.

WARNING:tensorflow:AutoGraph could not transform <bound method TestTfBug.test_func_in_class of <__main__.TestTfBug object at 0x7f9da0181a50» and will run it as-is. Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, export AUTOGRAPH_VERBOSITY=10) and attach the full output. Cause: expected exactly one node node, found [<gast.gast.FunctionDef object at 0x7f9da0176190>, <gast.gast.Return object at 0x7f9da01767d0>]

해결 방법

아니 똑같은 함수를 실행했는데 그냥 함수에서는 문제가 없고 클래스 메소드에서만 문제가 생긴하는게 말이 되나?

그래서 이걸 텐서플로 이슈에 올려볼까 했더니 역시 이미 이슈에 있는 문제였다.

https://github.com/tensorflow/tensorflow/issues/35765

https://github.com/tensorflow/tensorflow/issues/35810

이것은 \를 이용해 코드 라인을 나눌때 생기는 문제였다. 만약 코드를 다음과 같이 고치면 문제가 없다.

import tensorflow as tf

class TestTfBug:
    def test_func_in_class(self):
        a = tf.random.uniform((8, 100, 100, 3))
        b = tf.random.uniform((8, 100, 100, 3))
        # No backslash here!!
        c = 0.5 * tf.reduce_mean(tf.abs(a), axis=[1, 2, 3]) + 0.5 * tf.reduce_mean(tf.abs(b), axis=[1, 2, 3])
        return c

# No problem here
@tf.function
def test_tf_bug_class():
    result = TestTfBug().test_func_in_class()

if __name__ == "__main__":
    test_tf_bug_class()

문제의 원인을 좀 따져본다면 @tf.function이라는게 eager execution으로 돌아가는 코드를 그래프 모드에서 실행할 수 있는 코드로 변환해 준다. 다음 튜토리얼에서 코드 생성에 대한 내용을 볼 수 있다.

https://www.tensorflow.org/guide/function#use_python_control_flow

그래프 모드에서 if ~ else ~ 같은게 있으면 static graph를 만들 수 없기 때문에 똑같은 기능을 그래프 모드에서 실행할 수 있는 코드를 자동으로 생성하는 것이다. 그런데 그 코드 생성 과정에서 저 backslash(\)가 문제를 일으키는 듯 하다. 코드 생성에서 backslash는 생각 못 했나보다.

암튼 backslash만 안 쓰면 문제 해결!!

뭐… 이슈에 있기 때문에 아마 몇 달 지나서 새 버전 나오면 없어질 문제인것 같다.

- 끝 -