nanoGPT-Tutorial-CN/Lecture/l3.ipynb

557 lines
22 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 在上一节中我们简单地实现了一个十分十分基础甚至有些简陋的GPT同时起生成效果看起来也有很大的提升空间\n",
"- 这一节中,我们将对通过一系列的推导来向大家引入可以增强性能的自注意力机制\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 自注意力机制怎么增强性能\n",
"\n",
"在此之前nn.Embedding致力于将现有的编码转化为其对应的下一位的编码\n",
"\n",
"但是一个很重要的点是其忽略了现有的编码中彼此之间的联系,\n",
"\n",
"如果可以利用好这份联系,使得每个字的编码可以**相互通信**,是从能产生更好的性能呢"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 8, 2])"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 此时我们以一个真实的情况为例,通过随机生成一些数据来代表当前的真实情况\n",
"torch.manual_seed(42) # 设置固定的种子,使得结果可以及时的复现\n",
"B,T,C = 4,8,2 # batch, time, channels\n",
"x = torch.randn(B,T,C)\n",
"x.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 最简单的通信方式,第五个编码可以很简单地与收到前面四个编码的平均影响,虽然这种通信的方式听起来也十分地弱"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[1.9269, 1.4873]])\n",
"这是上面的均值累加:\n",
"tensor([1.9269, 1.4873])\n",
"----------------------\n",
"tensor([[ 1.9269, 1.4873],\n",
" [ 0.9007, -2.1055]])\n",
"这是上面的均值累加:\n",
"tensor([ 1.4138, -0.3091])\n",
"----------------------\n",
"tensor([[ 1.9269, 1.4873],\n",
" [ 0.9007, -2.1055],\n",
" [ 0.6784, -1.2345]])\n",
"这是上面的均值累加:\n",
"tensor([ 1.1687, -0.6176])\n",
"----------------------\n",
"tensor([[ 1.9269, 1.4873],\n",
" [ 0.9007, -2.1055],\n",
" [ 0.6784, -1.2345],\n",
" [-0.0431, -1.6047]])\n",
"这是上面的均值累加:\n",
"tensor([ 0.8657, -0.8644])\n",
"----------------------\n",
"tensor([[ 1.9269, 1.4873],\n",
" [ 0.9007, -2.1055],\n",
" [ 0.6784, -1.2345],\n",
" [-0.0431, -1.6047],\n",
" [-0.7521, 1.6487]])\n",
"这是上面的均值累加:\n",
"tensor([ 0.5422, -0.3617])\n",
"----------------------\n",
"tensor([[ 1.9269, 1.4873],\n",
" [ 0.9007, -2.1055],\n",
" [ 0.6784, -1.2345],\n",
" [-0.0431, -1.6047],\n",
" [-0.7521, 1.6487],\n",
" [-0.3925, -1.4036]])\n",
"这是上面的均值累加:\n",
"tensor([ 0.3864, -0.5354])\n",
"----------------------\n",
"tensor([[ 1.9269, 1.4873],\n",
" [ 0.9007, -2.1055],\n",
" [ 0.6784, -1.2345],\n",
" [-0.0431, -1.6047],\n",
" [-0.7521, 1.6487],\n",
" [-0.3925, -1.4036],\n",
" [-0.7279, -0.5594]])\n",
"这是上面的均值累加:\n",
"tensor([ 0.2272, -0.5388])\n",
"----------------------\n",
"tensor([[ 1.9269, 1.4873],\n",
" [ 0.9007, -2.1055],\n",
" [ 0.6784, -1.2345],\n",
" [-0.0431, -1.6047],\n",
" [-0.7521, 1.6487],\n",
" [-0.3925, -1.4036],\n",
" [-0.7279, -0.5594],\n",
" [-0.7688, 0.7624]])\n",
"这是上面的均值累加:\n",
"tensor([ 0.1027, -0.3762])\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n",
"这是上面的均值累加:\n",
"----------------------\n"
]
}
],
"source": [
"# 第一种方式为了循环和聚合我们使用torch.mean函数进行操作\n",
"\n",
"xbow = torch.zeros((B,T,C))\n",
"\n",
"for b in range(B): # 遍历所有的batch\n",
" for t in range(T):\n",
" # 使用切片操作x[b,:t+1]来获取x在第b个批次中前t+1个时间步的所有元素得到一个形状为(t,C)的张量xprev。\n",
" xprev = x[b,:t+1] # (t,C)\n",
" # 使用torch.mean(xprev, 0)来计算xprev在第一个维度dim=0上的平均值。这个操作会返回一个形状为(C,)的张量它的每个元素是xprev在对应列上的元素的平均值。\n",
" xbow[b,t] = torch.mean(xprev, 0) \n",
" \n",
" # print(b) if b==0 else None\n",
" # print(t) if b==0 else None\n",
" print(xprev) if b==0 else None # 前面的累加\n",
" print('这是上面的均值累加:')\n",
" print(xbow[b,t]) if b==0 else None\n",
" print(\"----------------------\")\n",
"\n",
"# 虽然这样是可以行的,但是这里的实现方式明显可以进一步改进"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"a=\n",
"tensor([[1.0000, 0.0000, 0.0000],\n",
" [0.5000, 0.5000, 0.0000],\n",
" [0.3333, 0.3333, 0.3333]])\n",
"--\n",
"b=\n",
"tensor([[0., 1.],\n",
" [3., 0.],\n",
" [1., 1.]])\n",
"--\n",
"c=\n",
"tensor([[0.0000, 1.0000],\n",
" [1.5000, 0.5000],\n",
" [1.3333, 0.6667]])\n"
]
}
],
"source": [
"# 方法2使用矩阵乘法\n",
"\n",
"a = torch.tril(torch.ones(3, 3)) # 创建一个下三角的矩阵的函数\n",
"# torch.sum 计算张量a在每一横行上的值\n",
"a = a / torch.sum(a, 1, keepdim=True) # 这一步相当于是在\n",
"b = torch.randint(0,10,(3,2)).float()\n",
"c = a @ b\n",
"print('a=')\n",
"print(a)\n",
"print('--')\n",
"print('b=')\n",
"print(b)\n",
"print('--')\n",
"print('c=')\n",
"print(c)\n",
"\n",
"# 在这个过程中,实现了字符根据权重得到最终的结果"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* 这个时候可以揭晓,我们所想要平均的一直是**权重**"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"weight = torch.tril(torch.ones(T, T))\n",
"weight = weight / weight.sum(1, keepdim=True) # 直接是在第一个维度上进行加和\n",
"weight\n",
"# 而在这个例子中的b其实是x\n",
"xbow2 = weight @ x # (B,T,T) @ (B,T,C) ------> (B,T,C)\n",
"torch.allclose(xbow,xbow2) # 这个是用于检测两个张量是否在一定的容忍度内是相等的"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"结果为`True`,所以说明这样几行就解决了上面这个循环要做的事情"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* 然而,这里还有一种更为巧妙的方式可以实现"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"# 第三种方式使用softmax\n",
"trils = torch.tril(torch.ones(T,T))\n",
"\n",
"weight = torch.zeros((T,T)) # 构造一个全为0的向量\n",
"weight = weight.masked_fill(trils == 0,float('-inf')) # 使所有tril为0的位置都变为无穷大\n",
"# 然后我们选择在每行的维度上去使用sotfmax\n",
"weight = F.softmax(weight,dim=-1)\n",
"\n",
"xbow3 = weight @ x\n",
"\n",
"torch.allclose(xbow,xbow3) # 这个是用于检测两个张量是否在一定的容忍度内是相等的\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 8, 16])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 第4种方式 : 自注意力方式\n",
"torch.manual_seed(1337)\n",
"B,T,C = 4,8,32 # batch, time, channels\n",
"x = torch.randn(B,T,C)\n",
"\n",
"# 一个简单的单头注意力机制示范\n",
"head_size = 16\n",
"key = nn.Linear(C, head_size, bias=False)\n",
"query = nn.Linear(C, head_size, bias=False)\n",
"value = nn.Linear(C, head_size, bias=False)\n",
"k = key(x) # (B, T, 16)\n",
"q = query(x) # (B, T, 16)\n",
"wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)\n",
"\n",
"tril = torch.tril(torch.ones(T, T))\n",
"wei = wei.masked_fill(tril == 0, float('-inf'))\n",
"wei = F.softmax(wei, dim=-1)\n",
"\n",
"v = value(x)\n",
"out = wei @ v\n",
"\n",
"out.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[-1.5713e-01, 8.8009e-01, 1.6152e-01, -7.8239e-01, -1.4289e-01,\n",
" 7.4676e-01, 1.0068e-01, -5.2395e-01, -8.8726e-01, 1.9068e-01,\n",
" 1.7616e-01, -5.9426e-01, -4.8124e-01, -4.8598e-01, 2.8623e-01,\n",
" 5.7099e-01],\n",
" [ 6.7643e-01, -5.4770e-01, -2.4780e-01, 3.1430e-01, -1.2799e-01,\n",
" -2.9521e-01, -4.2962e-01, -1.0891e-01, -4.9282e-02, 7.2679e-01,\n",
" 7.1296e-01, -1.1639e-01, 3.2665e-01, 3.4315e-01, -7.0975e-02,\n",
" 1.2716e+00],\n",
" [ 4.8227e-01, -1.0688e-01, -4.0555e-01, 1.7696e-01, 1.5811e-01,\n",
" -1.6967e-01, 1.6217e-02, 2.1509e-02, -2.4903e-01, -3.7725e-01,\n",
" 2.7867e-01, 1.6295e-01, -2.8951e-01, -6.7610e-02, -1.4162e-01,\n",
" 1.2194e+00],\n",
" [ 1.9708e-01, 2.8561e-01, -1.3028e-01, -2.6552e-01, 6.6781e-02,\n",
" 1.9535e-01, 2.8073e-02, -2.4511e-01, -4.6466e-01, 6.9287e-02,\n",
" 1.5284e-01, -2.0324e-01, -2.4789e-01, -1.6213e-01, 1.9474e-01,\n",
" 7.6778e-01],\n",
" [ 2.5104e-01, 7.3457e-01, 5.9385e-01, 2.5159e-01, 2.6064e-01,\n",
" 7.5820e-01, 5.5947e-01, 3.5387e-01, -5.9338e-01, -1.0807e+00,\n",
" -3.1110e-01, -2.7809e-01, -9.0541e-01, 1.3181e-01, -1.3818e-01,\n",
" 6.3715e-01],\n",
" [ 3.4277e-01, 4.9605e-01, 4.7248e-01, 3.0277e-01, 1.8440e-01,\n",
" 5.8144e-01, 3.8245e-01, 2.9521e-01, -4.8969e-01, -7.7051e-01,\n",
" -1.1721e-01, -2.5412e-01, -6.8921e-01, 1.9795e-01, -1.5135e-01,\n",
" 7.6659e-01],\n",
" [ 1.8658e-01, -9.6351e-02, -1.4300e-01, 3.0587e-01, 8.3441e-02,\n",
" -6.8646e-03, -2.0472e-01, -1.5350e-01, -7.6250e-02, 3.2689e-01,\n",
" 3.0896e-01, 7.6626e-02, 9.9243e-02, 1.6560e-01, 1.9745e-01,\n",
" 7.6248e-01],\n",
" [ 1.3013e-01, -3.2832e-02, -4.9645e-01, 2.8652e-01, 2.7042e-01,\n",
" -2.6357e-01, -7.3756e-02, 3.7857e-01, 7.4580e-02, 3.3827e-02,\n",
" 1.4695e-02, 3.1937e-01, 2.9926e-01, -1.6530e-01, -3.8630e-02,\n",
" 3.3748e-01]],\n",
"\n",
" [[-1.3254e+00, 1.1236e+00, 2.2927e-01, -2.9970e-01, -7.6267e-03,\n",
" 7.9364e-01, 8.9581e-01, 3.9650e-01, -6.6613e-01, -2.1844e-01,\n",
" -1.3539e+00, 4.1245e-01, 9.6011e-01, -1.0805e+00, -3.9751e-01,\n",
" -4.4439e-01],\n",
" [-3.8338e-01, -1.9659e-01, 8.8455e-02, 1.8560e-01, -8.7010e-02,\n",
" 1.3239e-01, 3.0841e-01, -2.4350e-01, -1.9396e-01, -1.7634e-02,\n",
" 4.8439e-01, 5.4210e-01, -2.0407e-02, -4.2467e-01, -2.3463e-01,\n",
" -4.6465e-01],\n",
" [-1.1100e+00, 3.2334e-01, 4.7054e-01, -6.3595e-02, 2.5443e-01,\n",
" 1.5352e-01, 2.5186e-01, 2.6286e-01, 2.7916e-01, -3.1662e-03,\n",
" -3.2881e-02, 4.8191e-01, 7.4431e-01, -1.9921e-01, 2.7134e-01,\n",
" -8.5871e-02],\n",
" [-9.7190e-01, 4.6124e-01, 4.2349e-01, -1.7230e-02, 1.5847e-01,\n",
" 4.1175e-01, 4.0764e-01, 2.4982e-01, -5.0322e-02, 4.1514e-03,\n",
" -3.9853e-01, 4.3551e-01, 7.0285e-01, -4.3081e-01, 2.6684e-02,\n",
" -2.0169e-01],\n",
" [ 3.3586e-01, -8.5915e-02, 9.3660e-01, 7.7311e-01, 1.8037e-01,\n",
" 8.2853e-01, -6.9183e-02, 2.8814e-01, 1.1734e-01, 6.8448e-01,\n",
" -5.8500e-02, 1.2726e-01, 2.9780e-01, 1.9324e-01, 1.5655e-01,\n",
" -9.3004e-03],\n",
" [ 1.6984e-01, 3.0993e-02, 8.1557e-01, 6.1679e-01, 1.0429e-01,\n",
" 7.4573e-01, 2.3072e-02, 3.0572e-01, 5.8163e-02, 5.7122e-01,\n",
" -4.5275e-02, 1.5051e-01, 3.2901e-01, 5.6984e-02, 1.0311e-01,\n",
" -9.9174e-02],\n",
" [ 4.6496e-02, 1.5765e-01, 3.9760e-01, 1.7619e-01, -2.1168e-01,\n",
" 2.3365e-01, -6.2083e-02, 2.1726e-01, -7.8725e-03, 4.5389e-01,\n",
" 3.4349e-01, -5.5631e-02, 3.3726e-01, -3.7591e-01, -1.0140e-02,\n",
" -4.5806e-01],\n",
" [-5.3896e-01, 7.5555e-01, 3.3034e-01, -1.5849e-01, -2.6740e-01,\n",
" 4.3495e-01, 3.7772e-01, 5.5794e-01, -1.8369e-01, 1.5938e-01,\n",
" -2.1042e-01, 5.5790e-02, 6.3184e-01, -6.4884e-01, -9.6084e-02,\n",
" -5.0751e-01]],\n",
"\n",
" [[ 6.8925e-02, 1.2248e+00, -4.1194e-01, -1.7046e-01, -6.9224e-01,\n",
" -2.9201e-01, 1.2704e+00, -6.8596e-01, 4.3798e-01, -2.6366e-01,\n",
" 1.1528e-01, 1.1676e+00, -7.2138e-01, -1.2308e+00, 8.3821e-01,\n",
" -5.5987e-01],\n",
" [-4.6375e-01, 6.3807e-01, -1.5842e-01, -1.3309e-01, -5.9402e-01,\n",
" -5.0374e-01, 2.3289e-01, -3.2126e-01, 4.5781e-01, -1.8590e-01,\n",
" 1.9215e-01, 3.7566e-01, -3.5905e-01, -7.7262e-01, 3.5036e-01,\n",
" 6.9694e-02],\n",
" [-6.4044e-01, 1.3831e-01, -6.1007e-02, -1.1112e-01, -4.5228e-01,\n",
" -6.2271e-01, -1.7030e-01, -2.4949e-01, 5.0670e-01, -9.6444e-02,\n",
" 4.8315e-01, 9.4986e-02, -2.9810e-01, -3.6538e-01, 3.9458e-01,\n",
" 4.1512e-01],\n",
" [-6.7193e-01, 1.2516e-01, 7.3386e-02, -1.3198e-01, -1.7880e-01,\n",
" -5.6740e-01, -6.8226e-01, 5.0844e-02, 3.3051e-01, 7.8242e-02,\n",
" 6.8022e-02, -2.4041e-01, -6.6864e-02, -1.8411e-01, -5.3514e-02,\n",
" 4.5113e-01],\n",
" [-1.4270e-02, 1.0195e+00, -3.4792e-01, -1.6421e-01, -5.5846e-01,\n",
" -3.2457e-01, 9.9404e-01, -5.6891e-01, 4.0097e-01, -1.8123e-01,\n",
" 1.1856e-01, 9.8704e-01, -6.4057e-01, -1.0320e+00, 7.3320e-01,\n",
" -4.3167e-01],\n",
" [-6.3858e-01, -7.6533e-02, -3.6510e-01, 1.7782e-01, -6.5426e-02,\n",
" -3.5158e-01, 7.9591e-02, 1.7384e-01, 3.6676e-01, -4.2302e-02,\n",
" 2.4923e-01, 4.8239e-01, -2.1295e-01, -2.9492e-01, 3.4749e-01,\n",
" -1.7111e-01],\n",
" [-2.2366e-01, -5.5317e-02, -1.8296e-01, 2.4258e-01, 2.5357e-01,\n",
" -1.6154e-01, -2.3908e-01, 3.3243e-01, 1.0304e-01, 2.6067e-01,\n",
" -5.0670e-02, 3.6947e-01, -4.9856e-02, 1.1197e-01, 1.1752e-01,\n",
" -2.5078e-01],\n",
" [-2.4821e-01, 1.4845e-01, -3.5033e-01, 1.7102e-01, 1.6613e-01,\n",
" -2.0643e-01, 8.6633e-02, 8.8414e-02, 2.1188e-01, 2.5805e-01,\n",
" 5.5145e-02, 4.2668e-01, -2.0443e-01, -1.7372e-01, 3.8899e-01,\n",
" 5.1725e-02]],\n",
"\n",
" [[ 9.7183e-02, 5.7301e-02, -1.0468e-01, -4.6654e-02, -1.4006e-01,\n",
" -8.4126e-01, -1.3625e-01, -6.7465e-01, -2.1541e-01, 1.0993e+00,\n",
" 2.3427e-01, 3.2605e-02, -1.8521e-01, 1.4780e-01, -6.1045e-01,\n",
" 1.5391e+00],\n",
" [ 1.9305e-01, -2.1031e-01, -3.4658e-01, 2.0567e-01, -1.7798e-01,\n",
" -7.4604e-01, -6.4427e-01, -6.9183e-01, -2.0558e-01, 7.0413e-01,\n",
" 2.3632e-01, 9.8797e-04, -1.7015e-01, 1.1203e-01, -7.1064e-01,\n",
" 1.2431e+00],\n",
" [ 2.9114e-01, -4.8343e-01, -5.9254e-01, 4.6477e-01, -2.1832e-01,\n",
" -6.4460e-01, -1.1627e+00, -7.0993e-01, -1.9703e-01, 2.9262e-01,\n",
" 2.3669e-01, -3.1050e-02, -1.5471e-01, 7.7153e-02, -8.1137e-01,\n",
" 9.3578e-01],\n",
" [ 1.7549e-01, -3.4260e-02, -2.0523e-01, 2.7644e-02, -2.1312e-01,\n",
" -5.6022e-01, -3.5273e-01, -6.2722e-01, -3.0037e-01, 4.6061e-01,\n",
" 1.5004e-01, 1.9040e-02, -1.4646e-01, 1.7220e-01, -6.2559e-01,\n",
" 1.0722e+00],\n",
" [ 1.7354e-01, -1.7962e-01, -2.7874e-01, -1.0590e-01, -1.2952e-01,\n",
" -3.5086e-01, -5.5830e-01, -3.8638e-01, -2.9719e-01, 3.3368e-02,\n",
" 1.7392e-01, 5.5898e-02, -7.2007e-02, 1.3182e-02, -6.6710e-01,\n",
" 5.4229e-01],\n",
" [ 2.4678e-01, -4.7274e-01, -5.2827e-01, 3.1212e-01, -1.7528e-01,\n",
" -4.8636e-01, -1.1223e+00, -5.4196e-01, -2.0142e-01, 4.0103e-02,\n",
" 2.2231e-01, -2.9381e-02, -9.4354e-02, 2.6374e-02, -7.8726e-01,\n",
" 6.2836e-01],\n",
" [-3.9784e-01, 2.5915e-01, 5.0358e-01, -4.6864e-01, -2.2024e-02,\n",
" -3.2242e-01, -1.2578e-01, 1.0634e-01, 1.3618e-01, 1.7780e-01,\n",
" 1.0391e-01, -6.2540e-01, 3.8904e-01, 3.3690e-01, -5.5140e-01,\n",
" 5.2246e-01],\n",
" [-3.5927e-01, 3.3935e-02, -2.9863e-02, -1.5019e-01, -6.0354e-03,\n",
" -6.5733e-02, -3.9659e-01, -6.0435e-02, -5.7551e-01, -2.9157e-01,\n",
" 1.4899e-01, -7.5002e-02, 7.3228e-02, -4.7413e-02, -6.4394e-01,\n",
" 2.8560e-01]]], grad_fn=<UnsafeViewBackward0>)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}