博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
6.PyTorch实现逻辑回归(多分类)
阅读量:4203 次
发布时间:2019-05-26

本文共 3447 字,大约阅读时间需要 11 分钟。

1 准备数据

import torchimport matplotlib.pyplot as plt
import numpy as npxy = np.loadtxt('./资料/data/diabetes.csv.gz', delimiter=',', dtype=np.float32)# 第一个‘:’是指读取所有行,第二个‘:’是指从第一列开始,最后一列不要x_data = torch.from_numpy(xy[:,:-1])# [-1] 最后得到的是个矩阵y_data = torch.from_numpy(xy[:, [-1]])

2 构建模型

class Model(torch.nn.Module):    def __init__(self):        super(Model, self).__init__()        self.linear1 = torch.nn.Linear(8, 6) # 输入数据x的特征是8维,x有8个特征        self.linear2 = torch.nn.Linear(6, 4) # 6个输入,4个输出        self.linear3 = torch.nn.Linear(4, 1) # 4个输入,1个输出        self.sigmoid = torch.nn.Sigmoid() # 最后一层,将输出结构映射到sigmoid函数中     def forward(self, x):        x = self.sigmoid(self.linear1(x))        x = self.sigmoid(self.linear2(x))        x = self.sigmoid(self.linear3(x)) # y hat        return x  model = Model()

3 模型训练

# construct loss and optimizer# criterion = torch.nn.BCELoss(size_average = True)criterion = torch.nn.BCELoss(reduction='mean')  optimizer = torch.optim.SGD(model.parameters(), lr=0.1) epoch_list = []loss_list = []# training cycle forward, backward, updatefor epoch in range(100):    # 1. Forward    y_pred = model(x_data)    loss = criterion(y_pred, y_data)    print(epoch, loss.item())    epoch_list.append(epoch)    loss_list.append(loss.item())        # 2. Backward    optimizer.zero_grad()    loss.backward()        # 3. Update    optimizer.step()
0 0.68533229827880861 0.68120205402374272 0.67751443386077883 0.67422044277191164 0.67127674818038945 0.6686449050903326 0.66629064083099377 0.66418379545211798 0.66229748725891119 0.660607874393463110 0.659093797206878711 0.657736361026763912 0.656518995761871313 0.655426681041717514 0.654446125030517615 0.653565704822540316 0.652774751186370817 0.652064085006713918 0.651425182819366519 0.650850653648376520 0.650333881378173821 0.64986890554428122 0.649450302124023423 0.649073481559753424 0.648734092712402325 0.648428380489349426 0.648152828216552727 0.647904515266418528 0.647680699825286929 0.647478818893432630 0.647296726703643831 0.647132456302642832 0.646984279155731233 0.646850526332855234 0.646729707717895535 0.646620631217956536 0.646522223949432437 0.646433234214782738 0.646352827548980739 0.646280229091644340 0.646214604377746641 0.646155238151550342 0.646101534366607743 0.646053016185760544 0.64600908756256145 0.645969390869140646 0.645933389663696347 0.645900845527648948 0.645871341228485149 0.645844697952270550 0.645820498466491751 0.645798563957214452 0.645778715610504253 0.645760655403137254 0.645744323730468855 0.645729482173919756 0.645716011524200457 0.645703852176666358 0.645692706108093359 0.645682573318481460 0.64567339420318661 0.645665049552917562 0.645657360553741563 0.645650446414947564 0.645644128322601365 0.645638346672058166 0.645633041858673167 0.645628213882446368 0.645623743534088169 0.645619750022888270 0.645615994930267371 0.645612597465515172 0.64560943841934273 0.645606577396392874 0.645603895187377975 0.645601391792297476 0.645599186420440777 0.645597100257873578 0.645595073699951279 0.645593225955963180 0.645591616630554281 0.645590007305145382 0.645588457584381183 0.645587027072906584 0.645585715770721485 0.645584464073181286 0.645583331584930487 0.645582199096679788 0.645581126213073789 0.645580172538757390 0.645579159259796191 0.645578265190124592 0.645577430725097793 0.64557653665542694 0.64557576179504495 0.645574867725372396 0.64557415246963597 0.645573437213897798 0.645572662353515699 0.6455719470977783

转载地址:http://wfali.baihongyu.com/

你可能感兴趣的文章
网站加载代码
查看>>
php图像处理函数大全(缩放、剪裁、缩放、翻转、旋转、透明、锐化的实例总结)
查看>>
magento url中 uenc 一坨编码 base64
查看>>
强大的jQuery焦点图无缝滚动走马灯特效插件cxScroll
查看>>
Yii2.0 数据库查询
查看>>
yii2 db 操作
查看>>
mongodb group 有条件的过滤组合个数。
查看>>
yii2 用命令行操作web下的controller
查看>>
yii2 console的使用
查看>>
关于mongodb的 数组分组 array group
查看>>
MongoDB新的数据统计框架介绍
查看>>
mongodb fulltextsearch 关于语言的设置选项
查看>>
mongodb 增加全文检索索引
查看>>
symfony
查看>>
yourls 短连接 安装
查看>>
yii2 php namespace 引入第三方非namespace库文件时候,报错:Class not found 的解决
查看>>
softlayer 端口开放
查看>>
操作1:mongodb安装
查看>>
操作2:mongodb使用语法
查看>>
如何给分类增加一个属性(后台)
查看>>