常见模型结构梳理

下面我将以[问题定义]_[数据集]_[模型]的格式来梳理常见的模型架构。重在梳理模型,不详细阐述任务过程

 

一、多分类_CIFAR10_CNN

数据集load我们会拿到:x=n*32*32*3     y=n*1   (x为32*32的图片,通道数为3表示为彩色图片,其中y是类型的索引id)

通常我们会先将y转成onehot编码,y=n*m (m来分类的总数)

 

 

 

二、多分类_baby_RNN

数据集load我们会拿到三部分:故事、问题、答案。(story, query, answer)

任务目标:我们需要根据故事+问题来进行答案的预测。

那么输入就有2个值,分别为故事和问题,输出则是答案 : x1=(n*552)    x2=(n*5)     y=(n*36)

第一步:首先我们需要从story, query, answer统计出词汇表索引字典:word_idx (key:词 value:索引)

第二步:根据word_idx 构造x1、x2 、y

每个样本的story, query长度不一样怎么处理:统一为最长的那个长度,(此处story为552, query为5,answer就是1),长度不足的用0来补全:masking

answer长度为1是对应词的索引,需要转换成onehot编码以便输出的时候用softmax激活即可。(这里词汇表的大小是35,所以answer进行onehot后是len+1,因为字典是从1开始的)

 

 

 

三、二分类_imdb

这是一个影评情感正负分类问题,数据集=影评文本,labels=pos/neg

对于这个二分类问题,我这边梳理3个框架模型:Bi+lstm、cnn、cnn+lstm

这个数据集load直接就是对应的词索引矩阵,只需要自行截取补0固定一个长度即可!二分类就无需onehot了哈!pig

我们这边截取文本长度为maxlen=100,max_features=20000(获取数据集时记得传参否则会报错)

x=(n*100)  y=(n*1)

①Bi+lstm

 

②cnn (不展示了)

③cnn+lstm

 

四、多分类_20NewsGroup

数据集18000篇新闻文章,一共涉及到20种话题。可以用sklearn获取数据集 data = fetch_20newsgroups(subset='train')

data.data是1W多个文本list
data.target是1W多个分类索引
data.target_names是20个分类名称

【这个任务主要聊引入外部训练好的词向量进行下一步训练模型】

没错,拿到文本数据
【第一步】将data转成词索引矩阵,将target转成onehot编码,我们这里将每篇文章截取1000长度
x=(n*1000)    y=(n*20)

【第二步】构造词向量矩阵(v是文章词汇表总长,d是选择的词向量维度,我这用的glove_50d,d=50)
embedding_matrix=(v*d)

 

引用外部词向量同样也是使用Embedding层实现,通过embeddings_initializer配置上矩阵,然后trainable=False

 

 

 

 

2

发表评论

电子邮件地址不会被公开。 必填项已用*标注

微信扫一扫

微信扫一扫

微信扫一扫,分享到朋友圈

常见模型结构梳理
嘿!有什么能帮到您的吗?
返回顶部

显示

忘记密码?

显示

显示

获取验证码

Close