본문 바로가기
Study/Python

Python unittest module 사용

by 개발새-발 2021. 12. 23.
반응형

프로그램을 개발할 때 Unit test를 수행하는 경우가 많다. 처음 프로그래밍을 시작하였을 때에는 오류가 생겼을 때 프로그램 중간에 관련 출력문을 추가하여 함수의 출력이나 결과를 확인하는 경우가 많다. 하지만, 프로그램이 살짝만 복잡해져도 어느 함수에서 오류가 발생하였는지 추적하는 것이 힘들어지게 된다. Unit test를 하는 경우 어느 함수 혹은 메소드가 통과하여야 하는 테스트들을 작성해 놓은 뒤 테스트를 수행하게 된다. 테스트 케이스를 잘 작성해 놓은 경우 테스트 결과를 확인하는 것으로 어느 함수가 경우에 대해 동작하고, 제대로 동작하지 않는지를 볼 수 있다. python에서는 이를 unittest 모듈을 사용하여 쉽게 수행할 수 있다.

unittest 모듈 간단하게 사용하기

unittest.TestCase 의 간단한 사용

unittest 모듈에서 Unit test를 쉽게 하기 위하여 관련 class를 제공한다. 그중 하나가, TestCase인데 우리는 이를 사용할 것이다. 우선, unittest.TestCase의 subclass를 만들어 주고, 그 클래스에 Test method들을 삽입하여야 한다. 이때 Test method들의 이름test로 시작하여야 한다. 이경우 아래와 같은 꼴이 될 것이다.

import unittest

class MyTestClass(unittest.TestCase):
    def test_function1(self):
        # do test!
        self.assertEqual(1,1)
        self.assertFalse(4 > 5)

    def test_function2(self):
        # do test!
        self.assertTrue("BRAVO".isupper())
        self.assertEqual("ALPHA".lower(),"alpha")

if __name__ == "__main__":
    unittest.main() # need when run this file directly

Test method의 작성

unittest.TestCase의 subclass를 생성한 후 Test method들이 그 class 내부에 있어야 unittest에 의해 쉽게 테스트가 가능하다. 위에서 말하였듯이 수행하고자 하는 test method들의 이름은 test로 시작하여야 한다. 또, 이 Test method 들에는 원하는 결과를 얻지 못하였을 때 assert를 수행하여 테스트를 통과하지 못했음을 알 수 있도록 하는 메소드들이 존재한다. 위 코드에서 assertEqualassertTrue, assertFalse가 대표적인 예시이다. assertEqual(a,b)a == b를 수행하여 True 이면 통과, False라면 테스트에 실패하게 된다. assertTrue(a)bool(a)가 True일 때 통과하고, False일 때 테스트에 실패한다. 이것 말고도 unittest.TestCase서 제공하는 다른 assert 메소드들이 존재한다.

Method 통과 조건
assertEqual(a, b) a == b
assertNotEqual(a, b) a != b
assertTrue(x) bool(x) is True
assertFalse(x) bool(x) is False
assertIs(a, b) a is b
assertIsNot(a, b) a is not b
assertIsNone(x) x is None
assertIsNotNone(x) x is not None
assertIn(a, b) a in b
assertNotIn(a, b) a not in b
assertIsInstance(a, b) isinstance(a, b)
assertNotIsInstance(a, b) ot isinstance(a, b)
assertAlmostEqual(a, b) round(a-b, 7) == 0
assertNotAlmostEqual(a, b) round(a-b, 7) != 0
assertGreater(a, b) a > b
assertGreaterEqual(a, b) a >= b
assertLess(a, b) a < b
assertLessEqual(a, b) a <= b

이외에도 특정 작업을 진행하는 동안 예외가 발생하는지 확인하는 메소드와 list, dictionary 와 같은 자료형에 대해 내용이 같음을 비교하는 메소드들도 존재한다.

테스트 실행하기

Test를 작성한 파일에 아래 코드가 존재하면 단순히 다른 파이썬 코드를 실행하는 것처럼 실행하면 테스트 결과를 볼 수 있다.

if __name__ == "__main__":
    unittest.main() # need when run this file directly

아까 작성하였던 Test code를 실행하고 얻은 결과는 아래와 같다.

> python test.py
..
----------------------------------------------------------------------
Ran 2 tests in 0.000s

OK

위 코드가 존재하지 않은 경우에는 명령어에 다음과 같이 입력함으로 테스트를 수행할 수 있다. 아래 코드에서 tmp/test.py는 테스트 코드가 적힌 파일의 경로이다. 출력 결과는 위와 같다.

> python -m unittest tmp/test.py

테스트의 좀 더 자세한 정보를 보기 위해서는 v 옵션을 추가해주어야 한다.

> python -m unittest tmp/test.py -v 
test_function1 (tmp.test.MyTestClass) ... ok
test_function2 (tmp.test.MyTestClass) ... ok

----------------------------------------------------------------------
Ran 2 tests in 0.001s

OK

경로가 아닌 모듈을 입력해 주는 것으로도 테스트를 수행할 수 있다.

> python -m unittest tmp.test            
..
----------------------------------------------------------------------
Ran 2 tests in 0.001s

OK

모듈의 특정 테스트 클래스에 대해서만 테스트를 수행할 수도 있다.

> python -m unittest tmp.test.MyTestClass
..
----------------------------------------------------------------------
Ran 2 tests in 0.000s

OK

원하는 경우 테스트 클래스의 특정 테스트 함수만 수행하는 것도 가능하다.

> python -m unittest tmp.test.MyTestClass.test_function1
.
----------------------------------------------------------------------
Ran 1 test in 0.000s

OK

이번에는 실패하는 경우를 보기 위하여 test_function2를 아래와 같이 수정하였다.

def test_function2(self):
    # do test!
    self.assertTrue("BRAVO".isupper())
    self.assertEqual("ALPHA".lower(),"alpha")
    self.assertEqual("gaMma".lower(),"gamMMA")
    self.assertTrue(1000 < 1)

이후 실행하면 아래와 같은 결과를 볼 수 있다.

> python -m unittest tmp/test.py
.F
======================================================================
FAIL: test_function2 (__main__.MyTestClass)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "somewhatdirectory\test.py", line 13, in test_function2
    self.assertEqual("gaMma".lower(),"gamMMA")
AssertionError: 'gamma' != 'gamMMA'
- gamma
+ gamMMA


----------------------------------------------------------------------
Ran 2 tests in 0.001s

FAILED (failures=1)

test_function2 가 self.assertEqual("gaMma".lower(),"gamMMA") 에 의해 실패함을 볼 수 있다. self.assertTrue(1000 < 1)도 실행하면 실패하는 코드이지만 이미 이 test method의 앞에서 assert가 발생하였기 때문에 이 코드는 이 test method 내에서 실행되지 않는다.

unittest.TestCase 사용

앞에서 unittest.TestCase를 간단하게 사용하는 법을 알아보았다. unittest.TestCase에서 제공하는 기능은 다양한데, 이 중에서 약간만 더 알아보도록 하자.

setUp(), tearDown()

setUp()tearDown()은 각각 Test method이 실행되기 직전과 실행된 직후에 실행되는 메소드이다. 아래 예시를 보자.

import unittest

class MyTestClass(unittest.TestCase):
    def setUp(self):
        print("setUp() called.")

    def tearDown(self):
        print("tearDown() called.")

    def test_func1(self):
        print("test_func1() called")

    def test_func2(self):
        print("test_func2() called")

테스트 결과의 출력은 아래와 같다.

setUp() called.
test_func1() called
tearDown() called.
.setUp() called.
test_func2() called
tearDown() called.
.
----------------------------------------------------------------------
Ran 2 tests in 0.001s

OK

test_func1test_func2가 실행되기 전에 setUp()이 실행되고 test method들이 실행된 이후에는 tearDown()이 실행되는 것을 볼 수 있다.

이를 통해 각 Test method들이 시작하기 전과 수행된 후에 진행되어야 할 공통적인 작업들을 간단하게 묶어서 처리할 수 있다. 예를 들어 아래와 같이 작성된 class인 Layer에 대해 테스트를 수행하고 싶다고 하자. 매 테스트 전에 Layer의 오브젝트인 layerwb값은 초기값으로 설정되어야 한다. 그렇지 않으면 test_forward2 에서 변경한 w값이 test_forward3에도 적용이 되기 때문이다. layer을 초기화해주는 작업을 setUp()에 작성하여 주면 매 test method마다 이를 작성하는 수고를 덜 수 있다.

# layer.py
class Layer():
    def __init__(self,w=1,b=1):
        self.w = w;
        self.b = b;

    def forward(self,x):
        return self.w*x + self.b
# test.py
import unittest
from layer import Layer

class LayerTestCase(unittest.TestCase):
    def setUp(self):
        self.layer = Layer()

    def test_forward1(self):
        result = self.layer.forward(1)
        self.assertEqual(result,2)

    def test_forward2(self):
        self.layer.w = 5

        result = self.layer.forward(1)
        self.assertEqual(result,6)

    def test_forward3(self):
        self.layer.b = 10

        result = self.layer.forward(1)
        self.assertEqual(result,11)

subTest()

반목문을 사용하여 test를 수행하는 경우가 존재한다. 반복문에 의해 테스트를 수행하다가 오류가 발생하면 처음 오류가 발생한 부분에서 테스트가 실패로 끝나게 된다. 예를 들어 아래와 같은 테스트 코드와 그 결과가 있다고 하자.

import unittest

class MyTestCase(unittest.TestCase):
    def test_func(self):
        for i in range(5):
            self.assertEqual(i%2,1)
F
======================================================================
FAIL: test_func (tmp.test3.MyTestCase)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "\tmp\test3.py", line 6, in test_func
    self.assertEqual(i%2,1)
AssertionError: 0 != 1

----------------------------------------------------------------------
Ran 1 test in 0.001s

FAILED (failures=1)

이 테스트는 i가 짝수일 때 실패하게 된다. i0 이상 4 이하의 수인 경우들에 대해 self.assertEqual(i%2,1)가 수행되기 때문에 이 테스트는 실패하게 된다. 우리는 어느 i에 대하여 테스트가 실패하였는지 알고 싶지만 테스트 결과의 출력은 아래와 같이 나오기 때문에 이를 파악하기가 어렵다.

아래와 같이 subTest()를 사용하는 경우 어느 i에 대하여 오류가 발생하였는지 알 수 있다.

import unittest

class MyTestCase(unittest.TestCase):
    def test_func(self):
        for i in range(5):
            with self.subTest(i=i): 
                self.assertEqual(i%2,1)
======================================================================
FAIL: test_func (tmp.test3.MyTestCase) (i=0)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "\tmp\test3.py", line 7, in test_func
    self.assertEqual(i%2,1)
AssertionError: 0 != 1

======================================================================
FAIL: test_func (tmp.test3.MyTestCase) (i=2)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "\tmp\test3.py", line 7, in test_func
    self.assertEqual(i%2,1)
AssertionError: 0 != 1

======================================================================
FAIL: test_func (tmp.test3.MyTestCase) (i=4)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "\tmp\test3.py", line 7, in test_func
    self.assertEqual(i%2,1)
AssertionError: 0 != 1

----------------------------------------------------------------------
Ran 1 test in 0.000s

FAILED (failures=3)

정수가 아닌 다른 데이터 타입에도 적용이 가능하다.

import unittest

class MyTestCase(unittest.TestCase):
    def test_func(self):
        strs = ["hello","World"]

        for s in strs:
            with self.subTest(s=s): 
                self.assertTrue(s.islower())
======================================================================
FAIL: test_func (tmp.test3.MyTestCase) (s='World')
----------------------------------------------------------------------
Traceback (most recent call last):
  File "tmp\test3.py", line 9, in test_func
    self.assertTrue(s.islower())
AssertionError: False is not true

----------------------------------------------------------------------
Ran 1 test in 0.000s

FAILED (failures=1)

assertRaises(exception, callable, *args, **kwds)

callable이 특정 인자에 의해 실행되는 경우 exception이 발생하는지 확인한다. 나눗셈을 하는 div 함수에 b 값이 0일 때 ZeroDivisionError이 발생하는지 확인한다고 하자.

import unittest

def div(a,b):
    if(b == 0):
        raise ZeroDivisionError
    return a/b

class MyTestCase(unittest.TestCase):
    def test_func(self):
        self.assertRaises(ZeroDivisionError,div,10,0)
.
----------------------------------------------------------------------
Ran 1 test in 0.000s

OK

테스트가 통과함을 볼 수 있다.

단순히 Exception이 발생하는 것을 보는 것이 아니라 Exception 내부의 정보가 중요한 경우 assertRaises(exception) 꼴로 사용할 수 있다. 이 경우, context manager의 형태로 사용하여야 하며, context manager에서 발생한 Exception object를 받을 수 있다.

import unittest

def div(a,b):
    if(b == 0):
        raise ZeroDivisionError("err in div")
    return a/b

class MyTestCase(unittest.TestCase):
    def test_func(self):
        with self.assertRaises(ZeroDivisionError) as cm:
            div(5,0)

        err = cm.exception # get exception
        self.assertEqual(err.args[0],"err in div")
.
----------------------------------------------------------------------
Ran 1 test in 0.000s

OK

fail(msg=None)

fail을 사용하면 테스트를 실패하게 한다.

import unittest

class MyTestCase(unittest.TestCase):
    def test_func1(self):
        self.assertTrue(True)
        self.fail()
        self.assertTrue(True)

    def test_func2(self):
        self.fail("FAIL~~")
FF
======================================================================
FAIL: test_func1 (tmp.test5.MyTestCase)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "tmp\test5.py", line 6, in test_func1
    self.fail()
AssertionError: None

======================================================================
FAIL: test_func2 (tmp.test5.MyTestCase)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "tmp\test5.py", line 10, in test_func2
    self.fail("FAIL~~")
AssertionError: FAIL~~

----------------------------------------------------------------------
Ran 2 tests in 0.001s

FAILED (failures=2)

skip, skipIf, skipUnless, skipTest

원하는 test를 skip 할 수 있다. 특정 OS에서만 작동하는 Test case 이거나 특정 라이브러리의 버전이 얼마 이상 혹은 이하일 때만 수행 가능한 테스트 케이스의 경우 유용하다. skip, skipIf, skipUnlessdecorator 형식으로 사용한다. 각 인자에서 reason은 skip이 될 경우 출력할 skip의 이유이고, conditionskipIf, skipUnless에서 skip을 할지 안 할지의 여부를 판단하기 위한 조건이다.

  • skip(reason) : 무조건 Skip
  • skipIf(condition, reason) : Condition이 True 일 때 skip
  • skipUnless(condition,reason) : Condition이 False일 때 skip
import unittest
import sys

class MyTestCase(unittest.TestCase):

    @unittest.skip("Must skip")
    def test_func1(self):
        pass

    @unittest.skipIf(True,"skipIf")
    def test_func2(self):
        pass

    @unittest.skipIf(False,"skipIf")
    def test_func3(self):
        pass

    @unittest.skipUnless(False,"skipUnless")
    def test_func4(self):
        pass

    @unittest.skipUnless(True,"skipUnless")
    def test_func5(self):
        pass

    @unittest.skipUnless(sys.platform.startswith("win"),"Need windows")
    def test_func6(self):
        pass

더 많은 정보를 보기위해 v옵션과 함께 테스트를 진행한 결과이다. Windows에서 진행하였다.

> python  -m unittest tmp/test6.py -v 
test_func1 (tmp.test6.MyTestCase) ... skipped 'Must skip'
test_func2 (tmp.test6.MyTestCase) ... skipped 'skipIf'
test_func3 (tmp.test6.MyTestCase) ... ok
test_func4 (tmp.test6.MyTestCase) ... skipped 'skipUnless'
test_func5 (tmp.test6.MyTestCase) ... ok
test_func6 (tmp.test6.MyTestCase) ... ok

----------------------------------------------------------------------
Ran 6 tests in 0.001s

OK (skipped=3)

setUp 혹은 test method 에서 skipTest(reason)을 호출하게 되면 그 테스트는 스킵된다. 이는 unittest.SkipTestraise 하는 경우에도 동일하다.

Unit test를 위한 폴더 구성

테스트 코드를 다른 코드와 함께 넣어서 테스트를 수행하고 개발을 진행하는 것보다는 테스트 코드는 따로 분리하는 것이 깔끔하다. 그래서 나는 주로 프로젝트 내부에 src 폴더에 주요 코드를 넣고, test 폴더에 테스트 파일들을 작성한다. 이 경우 프로젝트의 구조는 아래와 유사한 꼴을 가진다.

projectroot
│
├───src
│   │   __init__.py
│   │   codes1.py
│   │   codes2.py
│
└───test
    │   test1.py
    │   test2.py
    │   __init__.py

위와 같이 폴더가 구성되어 있다면 project root에서 아래와 같이 명령어를 입력하는 것으로 작성한 모든 테스트들을 수행할 수 있다.

> python -m unittest

또, 주로 vs code를 사용한다면, vs code의 테스트 창을 사용할 수 있게 된다.

반응형

댓글