본문 바로가기
Study/Python

Matplotlib로 subplot을 생성하고, 겹치지 않게 하기

by 개발새-발 2021. 7. 17.
반응형

matplotlib를 사용하면서 subplot을 만들어야 하는 상황이 생긴다. subplot(), add_subplot(), subplots() 이 세 개의 함수은 subplot을 만드는데 쓰인다. 사용법의 차이를 알아보기 위하여 아래 x로부터 생성된 y1, y2, y3를 각기 다른 방법으로 subplot에 그려보았다.

import numpy as np
import matplotlib.pyplot as plt
x = np.arange(0,5,0.1)
y1 = np.cos(x)
y2 = np.exp(x)
y3 = y1 * y2
plt.plot(x,y1)
plt.show()

plt.plot(x,y2)
plt.show()

plt.plot(x,y3)
plt.show()

 

 

 

Subplot 만들기

pyplot.subplot()

subplot() 은 현재 조작 중인 figure에 subplot을 만들어준다. matplotlib는 figure 객체를 생성하여 그 객체를 조작할 수 있도록 하는 Object Oriented(객체지향) 인터페이스를 제공하는 동시에 MATLAB과 유사하게 동작하는 state-machine에 기반한 인터페이스도 제공한다. subplot은 state-machine에 기반한 인터페이스를 사용한다. 그렇기에 plt.subplot() 이후에 쓰인 코드는 plt.subplot()이 나타내는 plot을 조작하는데 쓰이게 된다.

plt.subplot은 다음과 같이 사용한다.

plt.subplot(nrows,ncols,index)

nrows행의 수, ncols의 수, index위치를 나타낸다. 좌측 상단에서 1로 시작하며 오른쪽으로 세어나간다. 해당 행을 다 센 경우 바로 아래의 행부터 이어서 센다.

만약 만들고자 하는 subplot의 개수가 9개가 넘지 않아 nrows, ncols, index가 모두 한자리 수라면 이 수들을 이어서 세 자릿수의 정수 하나로 만들어 plt.subplot()에서 사용할 수도 있다.

# if nrows = 3, ncols = 1, index = 2 -> 312
plt.subplot(312)

아래 코드는 위에서 생성한 y1, y2, y3를 plt.subplot()을 이용해 세 개의 subplot에서 그린 예제이다.

plt.subplot(311)
plt.plot(x,y1)
plt.title('y = cos(x)')

plt.subplot(312)
plt.plot(x,y2)
plt.title('y = exp(x)')

plt.subplot(3,1,3)
plt.plot(x,y3)
plt.title('y = cos(x)*exp(y)')

plt.show()

title과 subplot이 겹쳐 보인다. 이를 해결하기 위한 방법은 나중에 확인할 것임으로 일단 무시하자. 이번엔 가로, 세로로 각각 2등분 하여 좌측 상단에 하나, 우측에 두 개를 그려보자.

plt.subplot(221)
plt.plot(x,y1)

plt.subplot(222)
plt.plot(x,y2)

plt.subplot(2,2,4)
plt.plot(x,y3)

plt.show()

Figure.add_subplot()

Figure.add_subplot는 지정한 figure에 직접 subplot을 만들어준다. 만든 subplot은 객체로 반환되며, 이 객체를 조작하면 만들어진 subplot을 조작할 수 있다. 이외의 사용법은 pyplot.subplot()과 같다. pyplot.subplot()이 state-machine 기반의 인터페이스였다면, Figure.add_subplot은 Object Oriented 인터페이스라고 생각하면 된다.

fig = plt.figure()

ax1 = fig.add_subplot(311)
ax1.plot(x,y1)
ax1.set_title('y = cos(x)')

ax2 = fig.add_subplot(312)
ax2.plot(x,y2)
ax2.set_title('y = exp(x)')

ax3 = fig.add_subplot(3,1,3)
ax3.plot(x,y3)
ax3.set_title('y = cos(x)*exp(y)')

plt.show()

fig = plt.figure()

ax1 = fig.add_subplot(221)
ax1.plot(x,y1)

ax2 = fig.add_subplot(222)
ax2.plot(x,y2)

ax3 = fig.add_subplot(2,2,4)
ax3.plot(x,y3)

plt.show()

pyplot.subplots()

행과 열의 수를 입력하면 figure에 subplot들을 만들어 반환한다. 반환된 subplot들을 조작하여 원하는 그래프를 그릴 수 있다. 간단한 사용법은 아래와 같다.

plt.subplot(nrows,ncols)

nrowsncols는 각각 의 수이다. 추가로, sharex, sharey를 지정하여 각각의 subplot들의 각각의 x, y 축들이 같게 할 수 있다.

fig, axs = plt.subplots(3,1)

axs[0].plot(x,y1)
axs[0].set_title('cos(x)')

axs[1].plot(x,y2)
axs[1].set_title('exp(x)')

axs[2].plot(x,y3)
axs[2].set_title('exp(x) * cos(x)')

plt.show()

sharey를 지정하여 각 subplot들의 보이는 y축의 범위가 같도록 하였다. y1의 최댓값은 1이었으므로 거의 직선처럼 보이게 된다.

fig, axs = plt.subplots(3,1,sharey=True)

axs[0].plot(x,y1)
axs[1].plot(x,y2)
axs[2].plot(x,y3)

plt.show()

이미 subplot들이 다 만들어진 상태임으로 좌측 상단에 하나, 우측에 상단, 하단에 각각 하나씩 3개의 그래프만 조작하였더라도 좌측 하단에 아무것도 그려지지 않은 subplot이 보이게 된다.

fig, axes = plt.subplots(2,2)

axes[0,0].plot(x,y1)
axes[0,1].plot(x,y2)
axes[1,1].plot(x,y3)

plt.show()

Subplot들의 간격 조정하기

subplots_adjust()

subplot을 만들었는데, title이 다른 subplot에 겹쳐 보이는 현상이 존재한다. title 외에도 xlabel, ylabel도 겹쳐 보일 수 있다. 이경우 우리는 subplot들 간의 간격을 조정하여 해결할 수 있다. 이 간격을 조정하기 위하여 subplots_adjust()를 사용하였다. state-machine 인터페이스와 Object-oriented 인터페이스 둘 다 사용이 가능하다.

# state-machine based interface
plt.subplots_adjust(~~)

# object oriented based interface
fig.subplots_adjust(~~)

조정 가능한 항목은 다음과 같다.

  • wspace, hspace : subplot의 가로, 세로 간격을 조정한다.
  • top, bottom, left, right : 위치를 조정할 때 쓰인다.
# subplot_adjust with plt.subplots
fig, axs = plt.subplots(3,1)
fig.subplots_adjust(hspace=1)

axs[0].plot(x,y1)
axs[0].set_title('cos(x)')

axs[1].plot(x,y2)
axs[1].set_title('exp(x)')

axs[2].plot(x,y3)
axs[2].set_title('exp(x) * cos(x)')

plt.show()

tight_layout()

자동으로 subplot들을 둘러싸고 있는 여백을 조정한다. 특별한 값을 지정해주지 않고도 간편하게 사용 가능하다.

fig, axs = plt.subplots(3,1)
fig.tight_layout()

axs[0].plot(x,y1)
axs[0].set_title('cos(x)')

axs[1].plot(x,y2)
axs[1].set_title('exp(x)')

axs[2].plot(x,y3)
axs[2].set_title('exp(x) * cos(x)')

plt.show()

constrained_layout = True

tight_layout과 비슷하다. plt.subplots() 혹은 plt.figure()에 인자로 넣어주면 된다. 다음과 같이 사용해줄 수 있다.

# plt.subplots
fig, axs = plt.subplots(3,1,constrained_layout=True)

# plt.figure
fig = plt.figure(constrained_layout = True)
fig, axs = plt.subplots(3,1,constrained_layout=True)

axs[0].plot(x,y1)
axs[0].set_title('cos(x)')

axs[1].plot(x,y2)
axs[1].set_title('exp(x)')

axs[2].plot(x,y3)
axs[2].set_title('exp(x) * cos(x)')

plt.show()

반응형

댓글