반응형
Pytorch로 선형회귀 구현하기¶
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.optim import SGD
데이터 생성¶
y = 2*x + 10에서 약간의 오차를 준 데이터를 사용하였다.
# y = x*2 + 10
data_x = np.random.randint(1,20,20)
data_y = np.array([1*np.random.randn() + x*2 + 10 for x in data_x])
data_x = torch.FloatTensor(data_x).unsqueeze(dim=1)
data_y = torch.FloatTensor(data_y).unsqueeze(dim=1)
plt.plot(data_x,data_y,'bo')
[<matplotlib.lines.Line2D at 0x7fe85d089e90>]
모델 생성¶
nn.Linear을 사용하였다. 초기화시에 입력해야 하는 값은 입력의 수(in_features), 출력의 수(out_features),bias의 사용 여부이다. 이 코드에선 bias=True로 설정해주었는데, Linear에선 이 옵션의 기본값이 True이기 때문에 따로 옵션을 부여하지 않아도 bias=True가 적용된다.
커스텀 모듈 생성¶
학습모델¶
class LinearRegression(nn.Module):
def __init__(self):
super(LinearRegression,self).__init__()
self.linear = nn.Linear(1,1,bias=True)
def forward(self,x):
return self.linear(x)
RMSE 오차¶
class RMSELoss(nn.Module):
def __init__(self):
super(RMSELoss,self).__init__()
self.mse = nn.MSELoss()
self.eps = 1e-7
def forward(self,y,y_hat):
return torch.sqrt(self.mse(y,y_hat) + self.eps)
학습¶
RMSE(Root Mean Square Error)오차와 확률적 경사강하(SGD)를 사용하여 학습을 진행하였다.
model = LinearRegression()
lossfn = RMSELoss()
lr = 0.01
optim = SGD(model.parameters(),lr=lr)
EPOCHS = 5000
model.train()
for epoch in range(EPOCHS):
y_hat = model(data_x)
loss = lossfn(data_y,y_hat)
optim.zero_grad()
loss.backward()
optim.step()
예측¶
pred = []
model.eval()
with torch.no_grad():
for i in range(21):
pred.append(model(torch.tensor([i],dtype=torch.float32)))
적색 실선이 모델이 예측한 값을 나타낸다.
plt.plot(data_x,data_y,'bo')
plt.plot(range(len(pred)),pred,'-r')
[<matplotlib.lines.Line2D at 0x7fe85c39b750>]
for params in model.parameters():
print(params)
Parameter containing: tensor([[1.9398]], requires_grad=True) Parameter containing: tensor([10.4748], requires_grad=True)
반응형
'Study > Python' 카테고리의 다른 글
Python에서 Glob으로 파일 혹은 폴더의 경로 불러오기 (0) | 2021.03.26 |
---|---|
Pytorch Tensor(텐서) 만들기 (0) | 2021.03.23 |
Kaggle에서 Pytorch로 간단한 Mnist 숫자 분류기 만들기 (0) | 2021.03.16 |
.py를 .ipynb으로, 또 그 반대로 (3) | 2020.10.17 |
Python에서 c,c++ 코드 사용하기 (0) | 2020.10.04 |
댓글