nn.embedding()其实是NLP中常用的词嵌入层,在实现词嵌入的过程中embedding层的权重用于随机初始化词的向量,该embedding层的权重参数在后续训练时会不断更新调整,并被优化。
nn.embedding:这是一个矩阵类,该开始时里面初始化了一个随机矩阵,矩阵的长是字典的大小,宽是用来表示字典中每个元素的属性向量,向量的维度根据你想要表示的元素的复杂度而定。类实例化之后可以根据字典中元素的下标来查找元素对应的向量。
因为输入的句子长度不一,有的长有的短。长了截断,不够长补齐(我文中用’'填充,然后在nn.embedding层将其补0,也就是用它来表示无意义的词,这样在后面的max-pooling层也就自然而然会把其过滤掉,这样就不用担心他会影响识别。)
这里说一下它的用法:
nn.embedding()的主要3个参数:
第一个参数num_embeddings是指词表大小
第二个参数embedding_dim是指你需要用多少维来表示一个符号
第三个参数pading_idx即需要用0填充的符号在词表中的位置,如下,输出中后面两个’'都有被填充为了0.
import torch import torch.nn as nn #词表 word_to_id = {'hello':0, '<PAD>':1,
'world':2} embeds = nn.Embedding(len(word_to_id), 4,padding_idx=word_to_id[
'<PAD>']) text = 'hello world <PAD> <PAD>' hello_idx = torch.LongTensor([
word_to_id[i] for i in text.split()]) #词嵌入得到词向量 hello_embed = embeds(hello_idx)
print(hello_embed)
从以下输出可以看到,每行代表句子中一个单词的词嵌入向量,句子中的每个单词都有4维度,最后两个0向量是时用来填充补齐的没意义。
所以embedding层其实相当于将前面用索引编码的句子表示乘上embedding层的可训练权重得到的就是词嵌入的结果
输出:
tensor([[-1.1436, 1.4588, -1.2755, 0.0077],
[-0.9600, -1.9986, -1.1087, -0.1520],
[ 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=)
你也可以使用nn.Embedding.from_pretrained()加载预训练好的模型,如word2vec,glove等,在训练的过程中也可以边训练,边更新词向量,加快模型的收敛。本文用的只是简单的nn.embedding()嘿嘿~
然后具体使用 nn.embedding() 时,写在初始化搭建网络里,如下:
class Network(nn.Module): def __init__(self): super(TextCNN, self).__init__(
nvocab,embed) self.filter_sizes = (2, 3, 4) self.embed = embed self.num_filters
= 256 self.dropout = 0.5 self.num_classes = num_classes self.n_vocab = nvocab
#通过padding_idx将<PAD>字符填充为0,因为他没意义哦,后面max-pooling自然而然会把他过滤掉哦 self.embedding = nn.
Embedding(self.n_vocab, self.embed, padding_idx=word2idx['<PAD>']) self.convs =
nn.ModuleList( [nn.Conv2d(1, self.num_filters, (k, self.embed)) for k in self.
filter_sizes]) self.dropout = nn.Dropout(self.dropout) self.fc = nn.Linear(self.
num_filters* len(self.filter_sizes), self.num_classes) def conv_and_pool(self, x
, conv): x = F.relu(conv(x)).squeeze(3) x = F.max_pool1d(x, x.size(2)).squeeze(2
) return x def forward(self, x): out = self.embedding(x) out = out.unsqueeze(1)
out= torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1) out =
self.dropout(out) out = self.fc(out) return out