{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"# `torch.nn.Embedding`\n",
|
||
"\n",
|
||
"> ‘我一直对Embedding有一层抽象的,模糊的认识’\n",
|
||
"\n",
|
||
"参考[我最爱的b站up主的内容](https://www.bilibili.com/video/BV1wm4y187Cr/?spm_id_from=333.337.search-card.all.click&vd_source=32f9de072b771f1cd307ca15ecf84087)\n",
|
||
"\n",
|
||
"## embedding的基础概念\n",
|
||
"\n",
|
||
"`embedding`是将词向量中的词映射为固定长度的词向量的技术,可以将one_hot出来的高维度的稀疏的向量转化成低维的连续的向量\n",
|
||
"\n",
|
||
"![直观显示词与词之间的关系](../assets/lecture-pic/Embedding1.png)\n",
|
||
"\n",
|
||
"## 首先明白embedding的计算过程\n",
|
||
"\n",
|
||
"- embedding module 的前向过程是一个索引(查表)的过程\n",
|
||
" - 表的形式是一个matrix (也即 embedding.weight,learnabel parameters)\n",
|
||
" - matrix.shape:(v,h)\n",
|
||
" - v:vocabulary size\n",
|
||
" - h:hidden dimension\n",
|
||
"\n",
|
||
" - 具体的索引的过程,是通过onehot+矩阵乘法的形式实现的\n",
|
||
" - input.shape:(b,s)\n",
|
||
" - b: batch size\n",
|
||
" - s: seq len \n",
|
||
" - embedding(input)=>(b,s,h)\n",
|
||
" - **这其中关键的问题就是(b,s)和(v,h)怎么变成了(b,s,h)**"
|
||
]
|
||
},
|
||
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.4682, -1.2143,  0.1752],\n",
       "         [ 1.5085,  0.4936,  0.3845],\n",
       "         [-1.1064,  1.0143,  0.4442],\n",
       "         [ 0.6037,  0.6854,  0.3562]],\n",
       "\n",
       "        [[-1.1064,  1.0143,  0.4442],\n",
       "         [ 1.0134,  0.2836, -0.6358],\n",
       "         [ 1.5085,  0.4936,  0.3845],\n",
       "         [ 1.5759, -0.5384, -0.0649]]], grad_fn=<EmbeddingBackward0>)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "embedding = nn.Embedding(10, 3)  # weight matrix of shape (v, h) = (10, 3)\n",
    "# a batch of 2 samples of 4 indices each, i.e. (b, s) = (2, 4)\n",
    "input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])\n",
    "embedding(input)  # each index is replaced by its row: shape (b, s, h) = (2, 4, 3)"
   ]
  },
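  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The lookup claim is easy to verify directly (a small sketch of my own, not from the video): fancy-indexing `embedding.weight` with the indices from the previous cell yields exactly the same tensor as calling the module.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the forward pass is literally a row lookup into the weight matrix:\n",
    "# weight[input] gathers row input[i, j] for every position (i, j)\n",
    "torch.equal(embedding(input), embedding.weight[input])  # True"
   ]
  },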
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"## One-Hot 矩阵乘法\n",
|
||
"\n",
|
||
"目前`one_hot`可以很方便地在`torch.nn.functional`中进行调用,对于一个[batchsize,seqlength]的tensor,one_hot向量可以十分方便的将其转化为[batchsize,seqlength,numclasses],此时,再与[numclasses,h]进行相乘,从而得到最终的[b,s,v]\n",
|
||
"\n",
|
||
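  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of this equivalence (my own check, not from the video): one-hot encode the indices with `F.one_hot`, multiply by `embedding.weight`, and compare against the direct lookup.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "v, h = 10, 3                     # vocabulary size, hidden dimension\n",
    "embedding = nn.Embedding(v, h)\n",
    "idx = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])  # (b, s) = (2, 4)\n",
    "\n",
    "one_hot = F.one_hot(idx, num_classes=v).float()  # (b, s) -> (b, s, v)\n",
    "via_matmul = one_hot @ embedding.weight          # (b, s, v) @ (v, h) -> (b, s, h)\n",
    "\n",
    "torch.allclose(via_matmul, embedding(idx))       # True"
   ]
  },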
"## 参数padding_idx\n",
|
||
"\n",
|
||
"这个参数的作用是指定某个位置的梯度不进行更新,但是为什么不进行更新,以及在哪个位置不进行更新我还没搞明白...."
|
||
]
|
||
},
|
||
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[ 1.0000,  1.0000,  1.0000],\n",
       "        [-0.0073,  0.8613, -0.4185],\n",
       "        [-0.1206, -0.8382,  0.4391]], requires_grad=True)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# example with padding_idx: index 0 is treated as the padding token\n",
    "embedding = nn.Embedding(10, 3, padding_idx=0)\n",
    "input = torch.LongTensor([[0, 2, 0, 5]])\n",
    "embedding(input)  # positions holding index 0 map to the all-zero row\n",
    "\n",
    "# example of changing the `pad` vector manually\n",
    "padding_idx = 0\n",
    "embedding = nn.Embedding(3, 3, padding_idx=padding_idx)\n",
    "embedding.weight\n",
    "with torch.no_grad():\n",
    "    embedding.weight[padding_idx] = torch.ones(3)\n",
    "embedding.weight  # the padding row is now all ones"
   ]
  },
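  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of the gradient behavior (my own sketch): after a backward pass, `weight.grad` is all zeros in the `padding_idx` row and non-zero in the rows that were actually looked up.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding = nn.Embedding(10, 3, padding_idx=0)\n",
    "idx = torch.LongTensor([[0, 2, 0, 5]])\n",
    "embedding(idx).sum().backward()\n",
    "print(embedding.weight.grad[0])  # all zeros: the padding row gets no gradient\n",
    "print(embedding.weight.grad[2])  # non-zero: a looked-up row accumulates gradient"
   ]
  },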
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## About max_norm\n",
    "\n",
    "If `max_norm` is set, any embedding vector whose norm exceeds `max_norm` is renormalized to have norm `max_norm`. Note that this rescaling happens during the forward pass and modifies `embedding.weight` in place, and only for the rows that are actually looked up.\n"
   ]
  },
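  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small demonstration (my own sketch): inflate the weights on purpose, then look a few rows up; the looked-up rows are renormalized in place so their norms are at most `max_norm`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding = nn.Embedding(10, 3, max_norm=1.0)\n",
    "with torch.no_grad():\n",
    "    embedding.weight *= 5.0           # push every row's norm well above 1\n",
    "idx = torch.LongTensor([1, 2, 3])\n",
    "out = embedding(idx)                  # the lookup triggers the renormalization\n",
    "print(out.norm(dim=-1))               # each norm is at most 1.0\n",
    "print(embedding.weight.norm(dim=-1))  # only rows 1, 2, 3 were clamped in place"
   ]
  }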
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}