本文共 3447 字,大约阅读时间需要 11 分钟。
import torchimport matplotlib.pyplot as plt
import numpy as npxy = np.loadtxt('./资料/data/diabetes.csv.gz', delimiter=',', dtype=np.float32)# 第一个‘:’是指读取所有行,第二个‘:’是指从第一列开始,最后一列不要x_data = torch.from_numpy(xy[:,:-1])# [-1] 最后得到的是个矩阵y_data = torch.from_numpy(xy[:, [-1]])
class Model(torch.nn.Module): def __init__(self): super(Model, self).__init__() self.linear1 = torch.nn.Linear(8, 6) # 输入数据x的特征是8维,x有8个特征 self.linear2 = torch.nn.Linear(6, 4) # 6个输入,4个输出 self.linear3 = torch.nn.Linear(4, 1) # 4个输入,1个输出 self.sigmoid = torch.nn.Sigmoid() # 最后一层,将输出结构映射到sigmoid函数中 def forward(self, x): x = self.sigmoid(self.linear1(x)) x = self.sigmoid(self.linear2(x)) x = self.sigmoid(self.linear3(x)) # y hat return x model = Model()
# construct loss and optimizer# criterion = torch.nn.BCELoss(size_average = True)criterion = torch.nn.BCELoss(reduction='mean') optimizer = torch.optim.SGD(model.parameters(), lr=0.1) epoch_list = []loss_list = []# training cycle forward, backward, updatefor epoch in range(100): # 1. Forward y_pred = model(x_data) loss = criterion(y_pred, y_data) print(epoch, loss.item()) epoch_list.append(epoch) loss_list.append(loss.item()) # 2. Backward optimizer.zero_grad() loss.backward() # 3. Update optimizer.step()
0 0.68533229827880861 0.68120205402374272 0.67751443386077883 0.67422044277191164 0.67127674818038945 0.6686449050903326 0.66629064083099377 0.66418379545211798 0.66229748725891119 0.660607874393463110 0.659093797206878711 0.657736361026763912 0.656518995761871313 0.655426681041717514 0.654446125030517615 0.653565704822540316 0.652774751186370817 0.652064085006713918 0.651425182819366519 0.650850653648376520 0.650333881378173821 0.64986890554428122 0.649450302124023423 0.649073481559753424 0.648734092712402325 0.648428380489349426 0.648152828216552727 0.647904515266418528 0.647680699825286929 0.647478818893432630 0.647296726703643831 0.647132456302642832 0.646984279155731233 0.646850526332855234 0.646729707717895535 0.646620631217956536 0.646522223949432437 0.646433234214782738 0.646352827548980739 0.646280229091644340 0.646214604377746641 0.646155238151550342 0.646101534366607743 0.646053016185760544 0.64600908756256145 0.645969390869140646 0.645933389663696347 0.645900845527648948 0.645871341228485149 0.645844697952270550 0.645820498466491751 0.645798563957214452 0.645778715610504253 0.645760655403137254 0.645744323730468855 0.645729482173919756 0.645716011524200457 0.645703852176666358 0.645692706108093359 0.645682573318481460 0.64567339420318661 0.645665049552917562 0.645657360553741563 0.645650446414947564 0.645644128322601365 0.645638346672058166 0.645633041858673167 0.645628213882446368 0.645623743534088169 0.645619750022888270 0.645615994930267371 0.645612597465515172 0.64560943841934273 0.645606577396392874 0.645603895187377975 0.645601391792297476 0.645599186420440777 0.645597100257873578 0.645595073699951279 0.645593225955963180 0.645591616630554281 0.645590007305145382 0.645588457584381183 0.645587027072906584 0.645585715770721485 0.645584464073181286 0.645583331584930487 0.645582199096679788 0.645581126213073789 0.645580172538757390 0.645579159259796191 0.645578265190124592 0.645577430725097793 0.64557653665542694 0.64557576179504495 0.645574867725372396 0.64557415246963597 0.645573437213897798 0.645572662353515699 0.6455719470977783
转载地址:http://wfali.baihongyu.com/