This commit is contained in:
明硕 2024-05-17 14:49:49 +08:00
parent 917bf2c2d6
commit 1a1482f0f8
1 changed files with 186 additions and 25 deletions

View File

@ -10,7 +10,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@ -34,7 +34,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 2,
"metadata": {},
"outputs": [
{
@ -43,7 +43,7 @@
"torch.Size([4, 8, 2])"
]
},
"execution_count": 17,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
@ -65,7 +65,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@ -207,7 +207,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@ -260,7 +260,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@ -269,7 +269,7 @@
"True"
]
},
"execution_count": 24,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@ -299,7 +299,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@ -308,7 +308,7 @@
"True"
]
},
"execution_count": 33,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -330,38 +330,199 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],\n",
" [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],\n",
" [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],\n",
" [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])"
"torch.Size([4, 8, 16])"
]
},
"execution_count": 32,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# version 4: 自注意力的实现\n",
"# 第4种方式 : 自注意力方式\n",
"torch.manual_seed(1337)\n",
"B,T,C = 4,8,32 # batch, time, channels\n",
"x = torch.randn(B,T,C)\n",
"\n"
"\n",
"# 一个简单的单头注意力机制示范\n",
"head_size = 16\n",
"key = nn.Linear(C, head_size, bias=False)\n",
"query = nn.Linear(C, head_size, bias=False)\n",
"value = nn.Linear(C, head_size, bias=False)\n",
"k = key(x) # (B, T, 16)\n",
"q = query(x) # (B, T, 16)\n",
"wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)\n",
"\n",
"tril = torch.tril(torch.ones(T, T))\n",
"wei = wei.masked_fill(tril == 0, float('-inf'))\n",
"wei = F.softmax(wei, dim=-1)\n",
"\n",
"v = value(x)\n",
"out = wei @ v\n",
"\n",
"out.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[-1.5713e-01, 8.8009e-01, 1.6152e-01, -7.8239e-01, -1.4289e-01,\n",
" 7.4676e-01, 1.0068e-01, -5.2395e-01, -8.8726e-01, 1.9068e-01,\n",
" 1.7616e-01, -5.9426e-01, -4.8124e-01, -4.8598e-01, 2.8623e-01,\n",
" 5.7099e-01],\n",
" [ 6.7643e-01, -5.4770e-01, -2.4780e-01, 3.1430e-01, -1.2799e-01,\n",
" -2.9521e-01, -4.2962e-01, -1.0891e-01, -4.9282e-02, 7.2679e-01,\n",
" 7.1296e-01, -1.1639e-01, 3.2665e-01, 3.4315e-01, -7.0975e-02,\n",
" 1.2716e+00],\n",
" [ 4.8227e-01, -1.0688e-01, -4.0555e-01, 1.7696e-01, 1.5811e-01,\n",
" -1.6967e-01, 1.6217e-02, 2.1509e-02, -2.4903e-01, -3.7725e-01,\n",
" 2.7867e-01, 1.6295e-01, -2.8951e-01, -6.7610e-02, -1.4162e-01,\n",
" 1.2194e+00],\n",
" [ 1.9708e-01, 2.8561e-01, -1.3028e-01, -2.6552e-01, 6.6781e-02,\n",
" 1.9535e-01, 2.8073e-02, -2.4511e-01, -4.6466e-01, 6.9287e-02,\n",
" 1.5284e-01, -2.0324e-01, -2.4789e-01, -1.6213e-01, 1.9474e-01,\n",
" 7.6778e-01],\n",
" [ 2.5104e-01, 7.3457e-01, 5.9385e-01, 2.5159e-01, 2.6064e-01,\n",
" 7.5820e-01, 5.5947e-01, 3.5387e-01, -5.9338e-01, -1.0807e+00,\n",
" -3.1110e-01, -2.7809e-01, -9.0541e-01, 1.3181e-01, -1.3818e-01,\n",
" 6.3715e-01],\n",
" [ 3.4277e-01, 4.9605e-01, 4.7248e-01, 3.0277e-01, 1.8440e-01,\n",
" 5.8144e-01, 3.8245e-01, 2.9521e-01, -4.8969e-01, -7.7051e-01,\n",
" -1.1721e-01, -2.5412e-01, -6.8921e-01, 1.9795e-01, -1.5135e-01,\n",
" 7.6659e-01],\n",
" [ 1.8658e-01, -9.6351e-02, -1.4300e-01, 3.0587e-01, 8.3441e-02,\n",
" -6.8646e-03, -2.0472e-01, -1.5350e-01, -7.6250e-02, 3.2689e-01,\n",
" 3.0896e-01, 7.6626e-02, 9.9243e-02, 1.6560e-01, 1.9745e-01,\n",
" 7.6248e-01],\n",
" [ 1.3013e-01, -3.2832e-02, -4.9645e-01, 2.8652e-01, 2.7042e-01,\n",
" -2.6357e-01, -7.3756e-02, 3.7857e-01, 7.4580e-02, 3.3827e-02,\n",
" 1.4695e-02, 3.1937e-01, 2.9926e-01, -1.6530e-01, -3.8630e-02,\n",
" 3.3748e-01]],\n",
"\n",
" [[-1.3254e+00, 1.1236e+00, 2.2927e-01, -2.9970e-01, -7.6267e-03,\n",
" 7.9364e-01, 8.9581e-01, 3.9650e-01, -6.6613e-01, -2.1844e-01,\n",
" -1.3539e+00, 4.1245e-01, 9.6011e-01, -1.0805e+00, -3.9751e-01,\n",
" -4.4439e-01],\n",
" [-3.8338e-01, -1.9659e-01, 8.8455e-02, 1.8560e-01, -8.7010e-02,\n",
" 1.3239e-01, 3.0841e-01, -2.4350e-01, -1.9396e-01, -1.7634e-02,\n",
" 4.8439e-01, 5.4210e-01, -2.0407e-02, -4.2467e-01, -2.3463e-01,\n",
" -4.6465e-01],\n",
" [-1.1100e+00, 3.2334e-01, 4.7054e-01, -6.3595e-02, 2.5443e-01,\n",
" 1.5352e-01, 2.5186e-01, 2.6286e-01, 2.7916e-01, -3.1662e-03,\n",
" -3.2881e-02, 4.8191e-01, 7.4431e-01, -1.9921e-01, 2.7134e-01,\n",
" -8.5871e-02],\n",
" [-9.7190e-01, 4.6124e-01, 4.2349e-01, -1.7230e-02, 1.5847e-01,\n",
" 4.1175e-01, 4.0764e-01, 2.4982e-01, -5.0322e-02, 4.1514e-03,\n",
" -3.9853e-01, 4.3551e-01, 7.0285e-01, -4.3081e-01, 2.6684e-02,\n",
" -2.0169e-01],\n",
" [ 3.3586e-01, -8.5915e-02, 9.3660e-01, 7.7311e-01, 1.8037e-01,\n",
" 8.2853e-01, -6.9183e-02, 2.8814e-01, 1.1734e-01, 6.8448e-01,\n",
" -5.8500e-02, 1.2726e-01, 2.9780e-01, 1.9324e-01, 1.5655e-01,\n",
" -9.3004e-03],\n",
" [ 1.6984e-01, 3.0993e-02, 8.1557e-01, 6.1679e-01, 1.0429e-01,\n",
" 7.4573e-01, 2.3072e-02, 3.0572e-01, 5.8163e-02, 5.7122e-01,\n",
" -4.5275e-02, 1.5051e-01, 3.2901e-01, 5.6984e-02, 1.0311e-01,\n",
" -9.9174e-02],\n",
" [ 4.6496e-02, 1.5765e-01, 3.9760e-01, 1.7619e-01, -2.1168e-01,\n",
" 2.3365e-01, -6.2083e-02, 2.1726e-01, -7.8725e-03, 4.5389e-01,\n",
" 3.4349e-01, -5.5631e-02, 3.3726e-01, -3.7591e-01, -1.0140e-02,\n",
" -4.5806e-01],\n",
" [-5.3896e-01, 7.5555e-01, 3.3034e-01, -1.5849e-01, -2.6740e-01,\n",
" 4.3495e-01, 3.7772e-01, 5.5794e-01, -1.8369e-01, 1.5938e-01,\n",
" -2.1042e-01, 5.5790e-02, 6.3184e-01, -6.4884e-01, -9.6084e-02,\n",
" -5.0751e-01]],\n",
"\n",
" [[ 6.8925e-02, 1.2248e+00, -4.1194e-01, -1.7046e-01, -6.9224e-01,\n",
" -2.9201e-01, 1.2704e+00, -6.8596e-01, 4.3798e-01, -2.6366e-01,\n",
" 1.1528e-01, 1.1676e+00, -7.2138e-01, -1.2308e+00, 8.3821e-01,\n",
" -5.5987e-01],\n",
" [-4.6375e-01, 6.3807e-01, -1.5842e-01, -1.3309e-01, -5.9402e-01,\n",
" -5.0374e-01, 2.3289e-01, -3.2126e-01, 4.5781e-01, -1.8590e-01,\n",
" 1.9215e-01, 3.7566e-01, -3.5905e-01, -7.7262e-01, 3.5036e-01,\n",
" 6.9694e-02],\n",
" [-6.4044e-01, 1.3831e-01, -6.1007e-02, -1.1112e-01, -4.5228e-01,\n",
" -6.2271e-01, -1.7030e-01, -2.4949e-01, 5.0670e-01, -9.6444e-02,\n",
" 4.8315e-01, 9.4986e-02, -2.9810e-01, -3.6538e-01, 3.9458e-01,\n",
" 4.1512e-01],\n",
" [-6.7193e-01, 1.2516e-01, 7.3386e-02, -1.3198e-01, -1.7880e-01,\n",
" -5.6740e-01, -6.8226e-01, 5.0844e-02, 3.3051e-01, 7.8242e-02,\n",
" 6.8022e-02, -2.4041e-01, -6.6864e-02, -1.8411e-01, -5.3514e-02,\n",
" 4.5113e-01],\n",
" [-1.4270e-02, 1.0195e+00, -3.4792e-01, -1.6421e-01, -5.5846e-01,\n",
" -3.2457e-01, 9.9404e-01, -5.6891e-01, 4.0097e-01, -1.8123e-01,\n",
" 1.1856e-01, 9.8704e-01, -6.4057e-01, -1.0320e+00, 7.3320e-01,\n",
" -4.3167e-01],\n",
" [-6.3858e-01, -7.6533e-02, -3.6510e-01, 1.7782e-01, -6.5426e-02,\n",
" -3.5158e-01, 7.9591e-02, 1.7384e-01, 3.6676e-01, -4.2302e-02,\n",
" 2.4923e-01, 4.8239e-01, -2.1295e-01, -2.9492e-01, 3.4749e-01,\n",
" -1.7111e-01],\n",
" [-2.2366e-01, -5.5317e-02, -1.8296e-01, 2.4258e-01, 2.5357e-01,\n",
" -1.6154e-01, -2.3908e-01, 3.3243e-01, 1.0304e-01, 2.6067e-01,\n",
" -5.0670e-02, 3.6947e-01, -4.9856e-02, 1.1197e-01, 1.1752e-01,\n",
" -2.5078e-01],\n",
" [-2.4821e-01, 1.4845e-01, -3.5033e-01, 1.7102e-01, 1.6613e-01,\n",
" -2.0643e-01, 8.6633e-02, 8.8414e-02, 2.1188e-01, 2.5805e-01,\n",
" 5.5145e-02, 4.2668e-01, -2.0443e-01, -1.7372e-01, 3.8899e-01,\n",
" 5.1725e-02]],\n",
"\n",
" [[ 9.7183e-02, 5.7301e-02, -1.0468e-01, -4.6654e-02, -1.4006e-01,\n",
" -8.4126e-01, -1.3625e-01, -6.7465e-01, -2.1541e-01, 1.0993e+00,\n",
" 2.3427e-01, 3.2605e-02, -1.8521e-01, 1.4780e-01, -6.1045e-01,\n",
" 1.5391e+00],\n",
" [ 1.9305e-01, -2.1031e-01, -3.4658e-01, 2.0567e-01, -1.7798e-01,\n",
" -7.4604e-01, -6.4427e-01, -6.9183e-01, -2.0558e-01, 7.0413e-01,\n",
" 2.3632e-01, 9.8797e-04, -1.7015e-01, 1.1203e-01, -7.1064e-01,\n",
" 1.2431e+00],\n",
" [ 2.9114e-01, -4.8343e-01, -5.9254e-01, 4.6477e-01, -2.1832e-01,\n",
" -6.4460e-01, -1.1627e+00, -7.0993e-01, -1.9703e-01, 2.9262e-01,\n",
" 2.3669e-01, -3.1050e-02, -1.5471e-01, 7.7153e-02, -8.1137e-01,\n",
" 9.3578e-01],\n",
" [ 1.7549e-01, -3.4260e-02, -2.0523e-01, 2.7644e-02, -2.1312e-01,\n",
" -5.6022e-01, -3.5273e-01, -6.2722e-01, -3.0037e-01, 4.6061e-01,\n",
" 1.5004e-01, 1.9040e-02, -1.4646e-01, 1.7220e-01, -6.2559e-01,\n",
" 1.0722e+00],\n",
" [ 1.7354e-01, -1.7962e-01, -2.7874e-01, -1.0590e-01, -1.2952e-01,\n",
" -3.5086e-01, -5.5830e-01, -3.8638e-01, -2.9719e-01, 3.3368e-02,\n",
" 1.7392e-01, 5.5898e-02, -7.2007e-02, 1.3182e-02, -6.6710e-01,\n",
" 5.4229e-01],\n",
" [ 2.4678e-01, -4.7274e-01, -5.2827e-01, 3.1212e-01, -1.7528e-01,\n",
" -4.8636e-01, -1.1223e+00, -5.4196e-01, -2.0142e-01, 4.0103e-02,\n",
" 2.2231e-01, -2.9381e-02, -9.4354e-02, 2.6374e-02, -7.8726e-01,\n",
" 6.2836e-01],\n",
" [-3.9784e-01, 2.5915e-01, 5.0358e-01, -4.6864e-01, -2.2024e-02,\n",
" -3.2242e-01, -1.2578e-01, 1.0634e-01, 1.3618e-01, 1.7780e-01,\n",
" 1.0391e-01, -6.2540e-01, 3.8904e-01, 3.3690e-01, -5.5140e-01,\n",
" 5.2246e-01],\n",
" [-3.5927e-01, 3.3935e-02, -2.9863e-02, -1.5019e-01, -6.0354e-03,\n",
" -6.5733e-02, -3.9659e-01, -6.0435e-02, -5.7551e-01, -2.9157e-01,\n",
" 1.4899e-01, -7.5002e-02, 7.3228e-02, -4.7413e-02, -6.4394e-01,\n",
" 2.8560e-01]]], grad_fn=<UnsafeViewBackward0>)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out[0]"
]
},
{
"cell_type": "code",