使用libtorch训练一个异或逻辑门

news/发布时间2024/4/28 5:29:36

本文以一个例子介绍如何使用libtorch创建一个包含多层神经元的感知机,训练识别异或逻辑。即${ z = x \text{^} y }$。本例的测试环境是VS2017和libtorch1.13.1。从本例可以学到如何复用网络结构,如下方的LinearSigImpl类的写法。该测试网络结构如下图。一个线性层2输入3输出,一个Sigmoid激活函数3输入3输出,一个线性输出层:

头文件代码如下:

class LinearSigImpl : public torch::nn::Module
{
public:LinearSigImpl(int intput_features, int output_features);torch::Tensor forward(torch::Tensor x);private:torch::nn::Linear ln;torch::nn::Sigmoid bn;
};TORCH_MODULE(LinearSig);class Mlp : public torch::nn::Module
{
public:Mlp(int in_features, int out_features);torch::Tensor forward(torch::Tensor x);private:LinearSig ln1;torch::nn::Linear output;
};

CPP文件:

LinearSigImpl::LinearSigImpl(int in_features, int out_features) : ln(nullptr), bn(nullptr)
{ln = register_module("ln", torch::nn::Linear(in_features, out_features));bn = register_module("bn", torch::nn::Sigmoid());
}torch::Tensor LinearSigImpl::forward(torch::Tensor x)
{x = ln->forward(x);x = bn->forward(x);return x;
}Mlp::Mlp(int in_features, int out_features) : ln1(nullptr), output(nullptr)
{ln1 = register_module("ln1", LinearSig(in_features, 3));output = register_module("output", torch::nn::Linear(3, out_features));
}torch::Tensor Mlp::forward(torch::Tensor x)
{x = ln1->forward(x);x = output->forward(x);return x;
}int main()
{Mlp linear(2, 1);/* 30个样本。在这里是一行一个样本 */at::Tensor b = torch::rand({ 30, 2 });at::Tensor c = torch::zeros({ 30, 1 });for (int i = 0; i < 30; i++){b[i][0] = (b[i][0] >= 0.5f);b[i][1] = (b[i][1] >= 0.5f);c[i] = b[i][0].item().toBool() ^ b[i][1].item().toBool();}//cout << b << endl;//cout << c << endl;/* 训练过程 */torch::optim::SGD optim(linear.parameters(), torch::optim::SGDOptions(0.01));torch::nn::MSELoss lossFunc;linear.train();for (int i = 0; i < 50000; i++){torch::Tensor predict = linear.forward(b);torch::Tensor loss = lossFunc(predict, c);optim.zero_grad();loss.backward();optim.step();if (i % 2000 == 0){/* 每2000次循环输出一次损失函数值 */cout << "LOOP:" << i << ",LOSS=" << loss.item() << endl;}}/* 非线性的网络就不输出网络参数了 *//* 太过玄学,输出也看不懂 *//* 做个测试 */at::Tensor x = torch::tensor({ { 1.0f, 0.0f }, { 0.0f, 1.0f }, { 1.0f, 1.0f }, { 0.0f, 0.0f} });at::Tensor y = linear.forward(x);cout << "输出为[1100]=" << y;/* 看看能不能泛化 */x = torch::tensor({ { 0.9f, 0.1f }, { 0.01f, 0.2f } });y = linear.forward(x);cout << "输出为[10]=" << y;return 0;
}

控制台输出如下。如果把0.5作为01分界线,从输出上看网络是有一定的泛化能力的。当然每次运行输出数字都不同,绝大多数泛化结果都正确:

LOOP:0,LOSS=1.56625
LOOP:2000,LOSS=0.222816
LOOP:4000,LOSS=0.220547
LOOP:6000,LOSS=0.218447
LOOP:8000,LOSS=0.215877
LOOP:10000,LOSS=0.212481
LOOP:12000,LOSS=0.207645
LOOP:14000,LOSS=0.199905
LOOP:16000,LOSS=0.187244
LOOP:18000,LOSS=0.168875
LOOP:20000,LOSS=0.145476
LOOP:22000,LOSS=0.118073
LOOP:24000,LOSS=0.087523
LOOP:26000,LOSS=0.0554768
LOOP:28000,LOSS=0.0280211
LOOP:30000,LOSS=0.0109953
LOOP:32000,LOSS=0.00348786
LOOP:34000,LOSS=0.000959343
LOOP:36000,LOSS=0.000243072
LOOP:38000,LOSS=5.89887e-05
LOOP:40000,LOSS=1.40228e-05
LOOP:42000,LOSS=3.3041e-06
LOOP:44000,LOSS=7.82167e-07
LOOP:46000,LOSS=1.85229e-07
LOOP:48000,LOSS=4.43763e-08
输出为[1100]= 0.99991.00000.00020.0001
[ CPUFloatType{4,1} ]输出为[10]= 0.99990.4588
[ CPUFloatType{2,1} ]

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ulsteruni.cn/article/34716184.html

如若内容造成侵权/违法违规/事实不符,请联系编程大学网进行投诉反馈email:xxxxxxxx@qq.com,一经查实,立即删除!

相关文章

Python爬取免费IP代理时,无法解析到数据

大家好,我是Python进阶者。 一、前言 前几天在Python最强王者交流群【ZXS】问了一个Python网络爬虫实战问题。问题如下: 我这里遇到一个问题:【爬取免费IP代理时,无法解析到数据】, 我通过 xpath,css定位到了元素,但是在运行时返回空列表,请问我该怎么解决呀 以下是解析数…

04_C++字符串

标准库类型string 1.定义和初始化 初始化:拷贝初始化和直接初始化 2.string对象上的操作 3.读写string 使用getline#include<iostream> #include<string> using namespace std;int main() {string s1;//读取一整行while (getline(cin,s1)){if (!s1.empty()) {c…

数据库大型应用——笔记2 50道mysql练习题

复健了一下mysql,练习内容是mysql50题目。(算法也有在写啦,前几天还被数论折磨)一.开始前数据库中的表的各种信息1.1表名与字段–1.学生表      Student(s_id,s_name,s_birth,s_sex) –学生编号,学生姓名, 出生年月,学生性别      –2.课程表      Cours…

关于钉钉直播回放视频下载若干方法的总结

钉钉直播回放视频下载的基本步骤分为两步,第一步获取m3u8链接或文件,第二步使用m3u8链接或文件下载合并钉钉视频。根据钉钉客户端、版本的不同,以及使用获取m3u8方式的而不同,我总结了三种下载钉钉直播回放视频的方法,具体如下: 获取m3u8链接的几种方式 Fiddler+vconsle抓…

Ubuntu源哪个速度快?镜像站速度比拼!

Ubuntu镜像站网速比拼 先放结论:科大>腾讯云>清华 实验环境 宽带规格:广州联通,带宽1000M。 测试方法:使用vmware workstation 17安装Ubuntu 23.10虚拟机,打上快照。依次切换用不同的镜像源进行更新(sudo apt update && sudo apt upgrade),单个源更新过程中…

实验1 C语言开发环境使用和数据类型、运算符、表达式

task1点击查看代码 #include <stdio.h>int main() {printf(" o\n");printf("<H>\n");printf("I I\n");printf(" o\n");printf("<H>\n");printf("I I\n");system("pause");return 0; …