
Linear Model

Giovanna


2024-07-08

| x | y |
|---|---|
| 1 | 2 |
| 2 | 4 |
| 3 | 6 |
| 4 | ? |

What would be the best model for the data?

Start with the simplest model: a linear model.

Linear Model:

$$\hat y = x \times \omega + b$$

Since the training data pass through the origin, the bias can be dropped, which simplifies the model to:

$$\hat y = x \times \omega$$

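Next, we need a way to evaluate the model, that is, to measure the gap between the predicted values and the true values.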

Training Loss (Error):

$$loss = (\hat y - y)^2 = (x \times \omega - y)^2$$
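As a quick worked check of the per-sample loss, take $\omega = 3$ and the training pair $(x, y) = (2, 4)$:

$$loss = (2 \times 3 - 4)^2 = 2^2 = 4$$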

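Use exhaustive search (trying every candidate value on a grid) to find the $\omega$ that gives the lowest average loss over the training set.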

Mean Squared Error (MSE):

$$cost = \frac{1}{N}\sum_{n=1}^{N}(\hat y_n - y_n)^2$$
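Here is a minimal NumPy sketch of the MSE for the bias-free model. It assumes the three training pairs from the table, and the function name `mse` is illustrative rather than taken from the original code. At $\omega = 3$ it reproduces the hand calculation $(1 + 4 + 9)/3 = 14/3$. At $\omega = 2$ every point is fitted exactly.

```python
import numpy as np

# Training set from the table above
x_data = np.array([1.0, 2.0, 3.0])
y_data = np.array([2.0, 4.0, 6.0])


def mse(w: float, x: np.ndarray, y: np.ndarray) -> float:
    """Mean squared error of the bias-free model y_hat = x * w."""
    y_hat = x * w
    return float(np.mean((y_hat - y) ** 2))


print(mse(3.0, x_data, y_data))  # (1 + 4 + 9) / 3
print(mse(2.0, x_data, y_data))  # perfect fit, MSE is 0
```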

Full code:

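```python
import numpy as np
import matplotlib.pyplot as plt

# Training set (lists keep x and y paired in order; sets are unordered)
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]


# Forward pass: compute y_pred with the current global weight w
def forward(x):
    return x * w


# Squared error for a single sample
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2


w_list = []
mse_list = []

# Enumerate candidate weights w = 0.0, 0.1, ..., 4.0
for w in np.arange(0.0, 4.1, 0.1):
    print('w=', w)
    l_sum = 0
    for x_val, y_val in zip(x_data, y_data):
        y_pred_val = forward(x_val)
        loss_val = loss(x_val, y_val)
        l_sum += loss_val
        print('\t', x_val, y_val, y_pred_val, loss_val)
    print('MSE=', l_sum / len(x_data))
    w_list.append(w)
    mse_list.append(l_sum / len(x_data))

# Visualize MSE as a function of w
plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()
```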

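Output: for each $w$, the script prints the per-sample predictions and losses followed by the MSE. It then plots MSE against $w$.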
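The loop above can also be written with NumPy broadcasting. This version evaluates every candidate $w$ at once and then picks the minimizer with `np.argmin`. It is a sketch of an alternative, not the course's code, and the names `w_grid` and `mse_per_w` are hypothetical. It also uses the best $w$ to fill in the `?` for $x = 4$ from the table.

```python
import numpy as np

x_data = np.array([1.0, 2.0, 3.0])
y_data = np.array([2.0, 4.0, 6.0])

# Same candidate grid as the loop version: 0.0, 0.1, ..., 4.0
w_grid = np.arange(0.0, 4.1, 0.1)

# Broadcasting: (num_w, 1) * (1, N) -> (num_w, N) matrix of predictions
y_pred = w_grid[:, None] * x_data[None, :]
mse_per_w = np.mean((y_pred - y_data[None, :]) ** 2, axis=1)

# Pick the candidate with the lowest MSE
best_idx = int(np.argmin(mse_per_w))
best_w = w_grid[best_idx]
print('best w =', best_w, 'MSE =', mse_per_w[best_idx])

# Use the best w to predict y for x = 4
print('prediction for x = 4:', 4.0 * best_w)
```

On this grid the minimum lands at $w = 2.0$ with an MSE of 0, so the predicted $y$ for $x = 4$ is 8.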

Exercise: include the bias $b$.
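With the bias kept, the MSE becomes a function of both parameters. The exhaustive search therefore has to scan a two-dimensional grid of $(\omega, b)$ pairs:

$$cost(\omega, b) = \frac{1}{N}\sum_{n=1}^{N}(x_n \times \omega + b - y_n)^2$$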

Code:

```python
import numpy as np
import matplotlib.pyplot as plt

# Training set
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]


# Forward pass with bias: y_pred = x * w + b
def forward(x, w, b):
    return x * w + b


# Squared error for a single sample
def loss(w, b, x, y):
    y_pred = forward(x, w, b)
    return (y_pred - y) ** 2


w_list = []
b_list = []
mse_list = []

# Exhaustive search over a 2-D grid of (w, b) pairs
for w in np.arange(0.0, 4.1, 0.1):
    for b in np.arange(-2.0, 2.1, 0.1):
        print('w=', w, 'b=', b)
        l_sum = 0
        for x_val, y_val in zip(x_data, y_data):
            y_pred_val = forward(x_val, w, b)
            loss_val = loss(w, b, x_val, y_val)
            l_sum += loss_val
            print('\t', x_val, y_val, y_pred_val, loss_val)
        mse = l_sum / len(x_data)
        print('MSE=', mse)
        w_list.append(w)
        b_list.append(b)
        mse_list.append(mse)

# Convert to arrays for plotting
w_list = np.array(w_list)
b_list = np.array(b_list)
mse_list = np.array(mse_list)

# Plot the MSE surface over the (w, b) grid
ax = plt.figure().add_subplot(projection='3d')
ax.plot_trisurf(w_list, b_list, mse_list, cmap='viridis')

ax.set_xlabel('Weight')
ax.set_ylabel('Bias')
ax.set_zlabel('MSE')

plt.show()
```

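Output: the script prints the per-sample losses and the MSE for every $(w, b)$ pair. It then draws the MSE surface over the $(w, b)$ grid as a 3D plot.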
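The nested loops can likewise be replaced by a broadcast over a `np.meshgrid` of the two parameters. The best pair is then located with `np.argmin` plus `np.unravel_index`. This is a sketch of an alternative, not the original exercise solution, and the names `W`, `B` and `mse` are illustrative. The best $b$ may print as a tiny value such as `1e-15` rather than exactly 0, because `np.arange` accumulates floating-point steps.

```python
import numpy as np

x_data = np.array([1.0, 2.0, 3.0])
y_data = np.array([2.0, 4.0, 6.0])

# Same grids as the nested loops: w in [0, 4], b in [-2, 2], step 0.1
w_grid = np.arange(0.0, 4.1, 0.1)
b_grid = np.arange(-2.0, 2.1, 0.1)
W, B = np.meshgrid(w_grid, b_grid, indexing='ij')  # shape (num_w, num_b)

# Add a trailing sample axis: (num_w, num_b, 1) broadcast against (N,)
y_pred = W[..., None] * x_data + B[..., None]
mse = np.mean((y_pred - y_data) ** 2, axis=-1)  # shape (num_w, num_b)

# Locate the grid cell with the lowest MSE
i, j = np.unravel_index(np.argmin(mse), mse.shape)
print('best w =', W[i, j], 'best b =', B[i, j], 'MSE =', mse[i, j])
```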