论文解读(MAML)《Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks》

news/发布时间2024/5/9 12:07:43

Note:[ wechat:Y466551 | 可加勿骚扰,付费咨询 ]

论文信息

论文标题:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
论文作者:Chelsea Finn、Pieter Abbeel、Sergey Levine
论文来源:2017 
论文地址:download 
论文代码:download
视屏讲解:click

1-摘要

  我们提出了一种与模型无关的元学习算法,在这个意义上,它与任何经过梯度下降训练的模型兼容,并适用于各种不同的学习问题,包括分类、回归和强化学习。元学习的目标是在各种学习任务上训练一个模型,这样它就可以只使用少量的训练样本来解决新的学习任务。在我们的方法中,模型的参数被明确地训练,这样少量的梯度步长和来自新任务的少量训练数据将在该任务上产生良好的泛化性能。实际上,我们的方法训练的模型易于微调。我们证明了这种方法在两个低镜头图像分类基准上取得了最先进的性能,在少镜头回归上产生了良好的结果,并加速了使用神经网络策略对策略梯度强化学习的微调。

2-方法

  

  代码:

def maml_train(model, support_images, support_labels, query_images, query_labels, inner_step, args, optimizer, is_train=True):meta_loss = []meta_acc = []for support_image, support_label, query_image, query_label in zip(support_images, support_labels, query_images, query_labels):fast_weights = collections.OrderedDict(model.named_parameters())for _ in range(inner_step):  #inner_step = 1# Update weightsupport_logit = model.functional_forward(support_image, fast_weights)support_loss = nn.CrossEntropyLoss().cuda()(support_logit, support_label)grads = torch.autograd.grad(support_loss, fast_weights.values(), create_graph=True)fast_weights = collections.OrderedDict((name, param - args.inner_lr * grad) for ((name, param), grad) in zip(fast_weights.items(), grads))# Use trained weight to get query lossquery_logit = model.functional_forward(query_image, fast_weights)query_prediction = torch.max(query_logit, dim=1)[1]query_loss = nn.CrossEntropyLoss().cuda()(query_logit, query_label)query_acc = torch.eq(query_label, query_prediction).sum() / len(query_label)meta_loss.append(query_loss)meta_acc.append(query_acc.data.cpu().numpy())# Zero the gradient
    optimizer.zero_grad()meta_loss = torch.stack(meta_loss).mean()meta_acc = np.mean(meta_acc)if is_train:meta_loss.backward()optimizer.step()return meta_loss, meta_acc

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ulsteruni.cn/article/21822403.html

如若内容造成侵权/违法违规/事实不符,请联系编程大学网进行投诉反馈email:xxxxxxxx@qq.com,一经查实,立即删除!

相关文章

Fastbin attackDouble free和Unsortbin leak的综合使用

Fastbin attack&&Double free和Unsortbin leak的综合使用✅ 今天做一个综合题目,包括利用Fastbin attack实现多指针指向一个地址,以及利用Unsortbin leak泄露libc基地址和修改__malloc_hook地址为one_gadget 题目是buuctf上面的一道题目,题目链接 https://buuoj.cn/…

python学习思维导图分享

python 本文包含了我的一些python学习的笔记和思维导图 第一部分:python基础导图下载链接 第二部分:函数及其他文件操作导图下载链接 第三部分:类及网络编程导图下载链接 第四部分:mysql导图下载链接

微机结构

微型计算机结构 总体来说,微型计算机的结构是采用总线结构实现相互之间的信息传递。CPU和存储器通过总线相互连接,I/O设备通过I/O接口连接在总线上。 总线是计算机各部件之间传输数据的通道,有三类总线分别是:数据总线、地址总线和控制总线(反馈)。主要特性有:公共性、分…

京东web端h5st—4.7逆向分析

声明 本文章中所有内容仅供学习交流,抓包内容、敏感网址、数据接口均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关,若有侵权,请联系我立即删除! 目标网站 aHR0cHM6Ly93d3cuamQuY29tLw== 分析流程了解h5st 看了sha256相关加密算法逻辑b…

Games 101: 旋转矩阵

旋转矩阵 本文主要介绍了旋转矩阵的推导,分为两种方式:旋转坐标 旋转坐标轴 以下坐标系都是右手坐标系旋转坐标 已知坐标点\(A(x_a,y_a)\), 旋转\(\theta\)角后变为坐标点\(B(x_b,y_b)\),求解旋转矩阵.\[{\large \begin{align*} \begin{split} x_a &=r_a \cdot cos(\alp…

2024-04-27:用go语言,在一个下标从 1 开始的 8 x 8 棋盘上,有三个棋子,分别是白色车、白色象和黑色皇后。 给定这三个棋子的位置,请计算出要捕获黑色皇后所需的最少移动次数。 需要注意

2024-04-27:用go语言,在一个下标从 1 开始的 8 x 8 棋盘上,有三个棋子,分别是白色车、白色象和黑色皇后。 给定这三个棋子的位置,请计算出要捕获黑色皇后所需的最少移动次数。 需要注意的是,白色车可以垂直或水平移动,而白色象可以沿对角线移动,它们不能跳过其他棋子。…