데이터분석/통계

파이썬 다항회귀분석

씩씩한 IT블로그 2022. 9. 11. 18:28
반응형

다항회귀분석

차수가 2차수 이상인 다항회귀분석 모델을 만든다.

 

코드 설명

1. 학습데이터 준비

# 데이터 생성
x_train = sorted(6*rd.rand(100,1)-3)
y_train = 0.5 * x**2 + x + 2 + rd.randn(100,1)

 

2. 학습데이터 다차항으로 transform하기

파이썬에서 다항회귀를 하기 위해선 다항데이터를 만들어 주어야 한다.

즉 aX => bX+cX^2으로 만들어야 한다.

degree 파라미터를 이용해 몇차항으로 만들것인지 판단한며 include_bias를 이용하여 bias를 만들껀지 판단한다.

코드는 다음과 같다.

# 데이터 만들기
quadratic = PolynomialFeatures(degree=2, include_bias=False) # 다항회귀용 객체 생성(차수 결정)
x_quad = quadratic.fit_transform(x_train) # x -> x+x^2
print(x_train)
print(x_quad)
[array([-2.94137492]), array([-2.93027676]), array([-2.8131955]), array([-2.81048313]), array([-2.80484292]), array([-2.80078135]), array([-2.76228705]), array([-2.6881723]), array([-2.6378875]), array([-2.61474638]), array([-2.56531867]), array([-2.52852902]), array([-2.51931092]), array([-2.51022011]), array([-2.47280101]), array([-2.37259296]), array([-2.37041893]), array([-2.29615087]), array([-2.24888525]), array([-2.21253073]), array([-2.20370548]), array([-2.14502674]), array([-2.04273192]), array([-1.99510451]), array([-1.90729251]), array([-1.88684691]), array([-1.79909853]), array([-1.79685164]), array([-1.72137178]), array([-1.68471608]), array([-1.62027236]), array([-1.5477299]), array([-1.5173806]), array([-1.38603957]), array([-1.25002453]), array([-1.13782825]), array([-1.04123215]), array([-0.97862525]), array([-0.96329884]), array([-0.89623841]), array([-0.78884194]), array([-0.64503866]), array([-0.61846034]), array([-0.41816463]), array([-0.28609029]), array([-0.26867911]), array([-0.16074056]), array([-0.12036699]), array([-0.11695832]), array([-0.11086589]), array([-0.05043272]), array([-0.0108257]), array([-0.00667108]), array([0.10626986]), array([0.15841597]), array([0.24777801]), array([0.28555426]), array([0.3578921]), array([0.40524509]), array([0.49182841]), array([0.60012329]), array([0.69990189]), array([0.70893441]), array([0.75732193]), array([0.8185815]), array([0.86146617]), array([0.91000095]), array([0.95224665]), array([1.11339172]), array([1.13678981]), array([1.16989132]), array([1.21478982]), array([1.23885436]), array([1.25436312]), array([1.28452159]), array([1.38953713]), array([1.56614395]), array([1.62553107]), array([1.65958112]), array([1.67910593]), array([1.73819104]), array([1.77922375]), array([1.7853546]), array([1.80352019]), array([1.8103685]), array([1.88895068]), array([1.90378662]), array([2.06869937]), array([2.16073438]), array([2.1861495]), array([2.21173901]), array([2.22581474]), array([2.2307014]), array([2.36459522]), array([2.36889994]), array([2.52508709]), array([2.61727348]), array([2.80916057]), array([2.85519153]), array([2.91997556])]
[[-2.94137492e+00  8.65168639e+00]
 [-2.93027676e+00  8.58652189e+00]
 [-2.81319550e+00  7.91406894e+00]
 .
 .
 .
 ]

 

3. 학습 및 bias와 계수 확인하기

# 학습
pr = LinearRegression()
pr.fit(x_quad,y_train)
print("bias :",pr.intercept_, "계수 :",pr.coef_)
bias : [2.57786245] 계수 : [[0.09402803 0.07152709]]

 

4. 시각화

print(f"poly model : y = {pr.intercept_[0]} + {pr.coef_[0][0]}x + {pr.coef_[0][1]}x**2")
plt.plot(x_train,y_train,"b.",label="raw data")
plt.plot(x_train,pr.predict(x_quad),"r-",linewidth=2,label="poly")

plt.legend()
plt.axis([-3,3,0,10])
plt.show()

 

전체 코드(비교를 위해 선형회귀 추가)

# 데이터 생성
x_train = sorted(6*rd.rand(100,1)-3)
y_train = 0.5 * x**2 + x + 2 + rd.randn(100,1)

# [다항회귀]
# 데이터 만들기
quadratic = PolynomialFeatures(degree=2, include_bias=False) # 다항회귀용 객체 생성(차수 결정)
x_quad = quadratic.fit_transform(x_train) # x -> x+x^2
# print(x_train)
# print(x_quad)
# 학습
pr = LinearRegression()
pr.fit(x_quad,y_train)
print("bias :",pr.intercept_, "계수 :",pr.coef_)

# [단순회귀]
lr = LinearRegression()
lr.fit(x_train,y_train)
print("bias :",lr.intercept_, "계수 :",lr.coef_)

# [그래프 시각화]
print(f"poly model : y = {pr.intercept_[0]} + {pr.coef_[0][0]}x + {pr.coef_[0][1]}x**2")
print(f"linear model : y = {lr.intercept_[0]} + {lr.coef_[0][0]}x")
plt.plot(x_train,y_train,"b.",label="raw data")
plt.plot(x_train,pr.predict(x_quad),"r-",linewidth=2,label="poly")
plt.plot(x_train,lr.predict(x_train),"g-",linewidth=2,label="linear")

plt.legend()
plt.axis([-3,3,0,10])
plt.show()

반응형