Upload HELPER_NEW.ipynb
Browse files- HELPER_NEW.ipynb +348 -0
HELPER_NEW.ipynb
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"#### pickle file checking for AUPRC random lead"
|
| 8 |
+
]
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"cell_type": "code",
|
| 12 |
+
"execution_count": 6,
|
| 13 |
+
"metadata": {},
|
| 14 |
+
"outputs": [
|
| 15 |
+
{
|
| 16 |
+
"name": "stdout",
|
| 17 |
+
"output_type": "stream",
|
| 18 |
+
"text": [
|
| 19 |
+
"<class 'list'> 50\n",
|
| 20 |
+
"<class 'dict'>\n",
|
| 21 |
+
"----------------------\n",
|
| 22 |
+
"epoch: \n",
|
| 23 |
+
"model: \n",
|
| 24 |
+
"train_auprc: \n",
|
| 25 |
+
"valid_auprc: \n",
|
| 26 |
+
"valid_targets: \n",
|
| 27 |
+
"valid_outputs: \n",
|
| 28 |
+
"-----------------------\n",
|
| 29 |
+
"-----------------------\n",
|
| 30 |
+
"[0.20795198881124255, 0.2924131615408049, 0.31194815399388126, 0.357671229080611, 0.3907590012977773, 0.39197022751675975, 0.39688932315376796, 0.41098642756821824, 0.4280303875603716, 0.4251116328825386, 0.41492397254078656, 0.44119503399957305, 0.42866565608661766, 0.42155615910506705, 0.4352771610735857, 0.4355309812927433, 0.4575302940022513, 0.4621060999031488, 0.4615244295921646, 0.4347042141353311, 0.4843673460502776, 0.49216570578173724, 0.49284316077316226, 0.4976730562122618, 0.4981241668777771, 0.4985906269863735, 0.5023674118168958, 0.5039947051779108, 0.5025596400291938, 0.501332454384853, 0.5017141509761979, 0.5033696471830942, 0.5035807094153067, 0.5044712423289812, 0.49912591150498187, 0.5036493639939076, 0.5073756144905568, 0.5066738446153692, 0.5041024684427422, 0.5061074251973712, 0.5079663458037375, 0.5080434717076571, 0.5071731389137064, 0.5066158069067092, 0.5059333249321385, 0.5078252460128987, 0.5081895157894929, 0.5079278975582764, 0.5073543066159428, 0.5078677916025073]\n",
|
| 31 |
+
"0.5081895157894929 46\n"
|
| 32 |
+
]
|
| 33 |
+
}
|
| 34 |
+
],
|
| 35 |
+
"source": [
|
| 36 |
+
"import pickle\n",
|
| 37 |
+
"import torch\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"address = \"./model_output/model_group5/PROGRESS.pickle\"\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"with open(address, 'rb') as file:\n",
|
| 42 |
+
" data = pickle.load(file)\n",
|
| 43 |
+
"\n",
|
| 44 |
+
"print(type(data), len(data))\n",
|
| 45 |
+
"# print(data[0])\n",
|
| 46 |
+
"print(type(data[1]))\n",
|
| 47 |
+
"print(\"----------------------\")\n",
|
| 48 |
+
"for key, _ in data[1].items():\n",
|
| 49 |
+
" print(f\"{key}: \")\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"print(\"-----------------------\")\n",
|
| 52 |
+
"AUPRC_list = []\n",
|
| 53 |
+
"for i in range(len(data)):\n",
|
| 54 |
+
" AUPRC_list.append(data[i][\"valid_auprc\"])\n",
|
| 55 |
+
"\n",
|
| 56 |
+
"print(\"-----------------------\") \n",
|
| 57 |
+
"print(AUPRC_list)\n",
|
| 58 |
+
"(\"-----------------------\")\n",
|
| 59 |
+
"largest_number = max(AUPRC_list)\n",
|
| 60 |
+
"index = AUPRC_list.index(largest_number)\n",
|
| 61 |
+
"print(largest_number, index)"
|
| 62 |
+
]
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"cell_type": "markdown",
|
| 66 |
+
"metadata": {},
|
| 67 |
+
"source": [
|
| 68 |
+
"group#1\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"[0.24092945522182005, 0.3139675367502194, 0.3062163369752217, 0.32297163568130305, 0.3672050308180419, 0.3801609216698969, 0.3915211363523951, 0.4034875773118736, 0.41721359538446234, 0.41755420607909477, 0.4101699028342543, 0.42683222688245664, 0.4338339272938271, 0.4432706404963518, 0.4451886249025738, 0.4436839678451211, 0.46470292201596697, 0.4619959382638624, 0.4389299870874322, 0.4537386141609928, 0.4880276013143086, 0.48964141469390005, 0.49214694908533474, 0.49336784163926267, 0.4978899412041259, 0.4960868620495151, 0.4949812567178974, 0.49875221067947606, 0.4959535547710648, 0.49723019893878023, 0.49849758106937503, 0.5005045769636993, 0.4968324354226746, 0.4985954057932132, 0.4985684464062525, 0.4948398218890804, 0.5003443438290083, 0.49804674478254773, 0.5015115944170082, 0.5043099513157541, 0.5022930844045073, 0.502102123403741, 0.5025587387783707, 0.5026322695878688, 0.5028108420912678, 0.501853319716798, 0.5044486284061104, 0.5043333679462079, 0.503047975296802, 0.5021477867974229]\n",
|
| 71 |
+
"\n",
|
| 72 |
+
"0.5044486284061104, index: 46\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"group #2\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"[0.24668762844932296, 0.31123092790061574, 0.35728718371921886, 0.37858993755415526, 0.38325613445804607, 0.38183540019756823, 0.40688905625255206, 0.4050292403852287, 0.4103841963804383, 0.4288343036036706, 0.4293594683280219, 0.44373329349811874, 0.44694196761428867, 0.44516332505161516, 0.4570591656299683, 0.44925142278910385, 0.45783436251651694, 0.4512008966459152, 0.4628860929136446, 0.46190128250293605, 0.4891415053038087, 0.4933325648723347, 0.49795793473520533, 0.4989478549566136, 0.507199717375493, 0.5031777644234027, 0.5048360591023886, 0.5026344145441939, 0.5070084702134143, 0.50851780828997, 0.5013767142024679, 0.5077028354409389, 0.5073222030725629, 0.5103865617070087, 0.5070321372047399, 0.5069057373554984, 0.5054984338086199, 0.5052088211513525, 0.5085875776438461, 0.5015018579996042, 0.507983738986951, 0.506001318616706, 0.5078548999343991, 0.5084694227173217, 0.5081644743764611, 0.5070537320211395, 0.5072728550164887, 0.5084469401746737, 0.5081580384861908, 0.5092361778552277]\n",
|
| 77 |
+
"\n",
|
| 78 |
+
"0.5103865617070087, index: 33\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"group #3\n",
|
| 81 |
+
"[0.20546938178065813, 0.31056285598824596, 0.3521164077944065, 0.36566363279169545, 0.3649970330628938, 0.3816742095036071, 0.408841252427171, 0.4192963362391232, 0.419725128897165, 0.4009845215509139, 0.4221866024862177, 0.4383579336817017, 0.41634488480301257, 0.4394011015343916, 0.42674958918677536, 0.4484833626141604, 0.43733868299572076, 0.42813204282903494, 0.44362467579095183, 0.4525213211300688, 0.47993303563958817, 0.48221178536835363, 0.4832912567732829, 0.485964752652683, 0.4894140885779246, 0.49081305081555826, 0.4835906970652839, 0.4881328848995447, 0.49108874994886303, 0.49205732309554323, 0.4918174541861535, 0.49104602501641953, 0.49033495002806987, 0.49255438103140303, 0.4982302563540638, 0.4919847023325378, 0.49138268849817107, 0.49216471663752714, 0.49367968532436873, 0.49558690171904884, 0.4952242601993453, 0.49709259551176815, 0.4969043181087201, 0.49722348299821856, 0.49599951407363857, 0.49572421827303714, 0.49551046935516674, 0.4969339282495756, 0.49522481850002315, 0.4956301125397299]\n",
|
| 82 |
+
"\n",
|
| 83 |
+
"0.4982302563540638, index: 34\n",
|
| 84 |
+
"\n",
|
| 85 |
+
"group #4\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"[0.16705442847351432, 0.2811237847091236, 0.3227277423619332, 0.3459164670019608, 0.3433205542817934, 0.38953865811323535, 0.40093825754134493, 0.4042482476980622, 0.4179255247142833, 0.42026119275049384, 0.415263850960453, 0.4326573070148512, 0.4284856196846552, 0.455811861263988, 0.44742754829379755, 0.4428520431746461, 0.4288860834282809, 0.43801462440444205, 0.441347802107846, 0.4560878428908129, 0.47952984096244766, 0.4859939647185739, 0.48291741623601653, 0.4863560035613435, 0.4879069301596515, 0.49283878286572264, 0.4925634321692941, 0.49296767067476266, 0.4925321693215088, 0.4930295366233496, 0.4927986378984127, 0.49612537918838245, 0.4992350455119594, 0.4951830005033058, 0.49014993853897326, 0.4924448141210762, 0.4945801109607605, 0.4971188401719394, 0.49753234729288465, 0.49315691206981155, 0.4963229926370793, 0.49660539254449804, 0.49752930191373473, 0.4983978705842285, 0.498218560630721, 0.49778016282127696, 0.4980937334749714, 0.4982398417549309, 0.49825272820647715, 0.4978916971990578]\n",
|
| 88 |
+
"\n",
|
| 89 |
+
"0.4992350455119594, index: 32\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"group #5\n",
|
| 92 |
+
"[0.20795198881124255, 0.2924131615408049, 0.31194815399388126, 0.357671229080611, 0.3907590012977773, 0.39197022751675975, 0.39688932315376796, 0.41098642756821824, 0.4280303875603716, 0.4251116328825386, 0.41492397254078656, 0.44119503399957305, 0.42866565608661766, 0.42155615910506705, 0.4352771610735857, 0.4355309812927433, 0.4575302940022513, 0.4621060999031488, 0.4615244295921646, 0.4347042141353311, 0.4843673460502776, 0.49216570578173724, 0.49284316077316226, 0.4976730562122618, 0.4981241668777771, 0.4985906269863735, 0.5023674118168958, 0.5039947051779108, 0.5025596400291938, 0.501332454384853, 0.5017141509761979, 0.5033696471830942, 0.5035807094153067, 0.5044712423289812, 0.49912591150498187, 0.5036493639939076, 0.5073756144905568, 0.5066738446153692, 0.5041024684427422, 0.5061074251973712, 0.5079663458037375, 0.5080434717076571, 0.5071731389137064, 0.5066158069067092, 0.5059333249321385, 0.5078252460128987, 0.5081895157894929, 0.5079278975582764, 0.5073543066159428, 0.5078677916025073]\n",
|
| 93 |
+
"\n",
|
| 94 |
+
"0.5081895157894929, index: 46\n",
|
| 95 |
+
"\n"
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
"cell_type": "code",
|
| 100 |
+
"execution_count": 1,
|
| 101 |
+
"metadata": {},
|
| 102 |
+
"outputs": [
|
| 103 |
+
{
|
| 104 |
+
"name": "stdout",
|
| 105 |
+
"output_type": "stream",
|
| 106 |
+
"text": [
|
| 107 |
+
"A0004.hea\n"
|
| 108 |
+
]
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"name": "stderr",
|
| 112 |
+
"output_type": "stream",
|
| 113 |
+
"text": [
|
| 114 |
+
"100%|██████████| 17651/17651 [00:02<00:00, 6737.50it/s]\n"
|
| 115 |
+
]
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"name": "stdout",
|
| 119 |
+
"output_type": "stream",
|
| 120 |
+
"text": [
|
| 121 |
+
"164889003 1051.0\n",
|
| 122 |
+
"164890007 1675.0\n",
|
| 123 |
+
"6374002 103.0\n",
|
| 124 |
+
"426627000 59.0\n",
|
| 125 |
+
"733534002 299.0\n",
|
| 126 |
+
"713427006 963.0\n",
|
| 127 |
+
"270492004 706.0\n",
|
| 128 |
+
"713426002 372.0\n",
|
| 129 |
+
"39732003 1526.0\n",
|
| 130 |
+
"445118002 437.0\n",
|
| 131 |
+
"164947007 74.0\n",
|
| 132 |
+
"251146004 320.0\n",
|
| 133 |
+
"111975006 382.0\n",
|
| 134 |
+
"698252002 354.0\n",
|
| 135 |
+
"426783006 5794.0\n",
|
| 136 |
+
"284470004 653.0\n",
|
| 137 |
+
"10370003 296.0\n",
|
| 138 |
+
"365413008 120.0\n",
|
| 139 |
+
"427172004 387.0\n",
|
| 140 |
+
"164917005 415.0\n",
|
| 141 |
+
"47665007 256.0\n",
|
| 142 |
+
"427393009 758.0\n",
|
| 143 |
+
"426177001 3784.0\n",
|
| 144 |
+
"427084000 1932.0\n",
|
| 145 |
+
"164934002 2343.0\n",
|
| 146 |
+
"59931005 798.0\n",
|
| 147 |
+
"dtype: float64\n"
|
| 148 |
+
]
|
| 149 |
+
},
|
| 150 |
+
{
|
| 151 |
+
"name": "stderr",
|
| 152 |
+
"output_type": "stream",
|
| 153 |
+
"text": [
|
| 154 |
+
" 0%| | 0/138 [00:00<?, ?it/s]"
|
| 155 |
+
]
|
| 156 |
+
}
|
| 157 |
+
],
|
| 158 |
+
"source": [
|
| 159 |
+
"import torch\n",
|
| 160 |
+
"import numpy as np\n",
|
| 161 |
+
"from tqdm import tqdm\n",
|
| 162 |
+
"from sklearn.metrics import average_precision_score, roc_auc_score, f1_score\n",
|
| 163 |
+
"import pandas as pd\n",
|
| 164 |
+
"from dataset import dataset\n",
|
| 165 |
+
"from torch.utils.data import DataLoader\n",
|
| 166 |
+
"from model import NN\n",
|
| 167 |
+
"import pickle\n",
|
| 168 |
+
"\n",
|
| 169 |
+
"\n",
|
| 170 |
+
"DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
|
| 171 |
+
"DEVICE = 'cpu'\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"address = \"./model_output/model_group5/PROGRESS.pickle\"\n",
|
| 174 |
+
"\n",
|
| 175 |
+
"with open(address, 'rb') as file:\n",
|
| 176 |
+
" data = pickle.load(file)\n",
|
| 177 |
+
"\n",
|
| 178 |
+
"new_state_dict = data[46]['model']\n",
|
| 179 |
+
"\n",
|
| 180 |
+
"def collate(batch):\n",
|
| 181 |
+
"\n",
|
| 182 |
+
" ch = batch[0][0].shape[0]\n",
|
| 183 |
+
" maxL = 8192\n",
|
| 184 |
+
" X = np.zeros((len(batch), ch, maxL))\n",
|
| 185 |
+
" \n",
|
| 186 |
+
" for i in range(len(batch)):\n",
|
| 187 |
+
" X[i, :, -batch[i][0].shape[-1]:] = batch[i][0]\n",
|
| 188 |
+
" \n",
|
| 189 |
+
" t = np.array([b[1] for b in batch])\n",
|
| 190 |
+
" l = np.concatenate([b[2].reshape(1,12) for b in batch], axis=0)\n",
|
| 191 |
+
"\n",
|
| 192 |
+
" X = torch.from_numpy(X)\n",
|
| 193 |
+
" t = torch.from_numpy(t)\n",
|
| 194 |
+
" l = torch.from_numpy(l)\n",
|
| 195 |
+
" return X, t, l\n",
|
| 196 |
+
"\n",
|
| 197 |
+
"def valid_part(model, dataset):\n",
|
| 198 |
+
" targets = []\n",
|
| 199 |
+
" outputs = []\n",
|
| 200 |
+
" model.eval()\n",
|
| 201 |
+
" with torch.no_grad():\n",
|
| 202 |
+
" for i, (x, t, l) in enumerate(tqdm(dataset)):\n",
|
| 203 |
+
" x = x.unsqueeze(2).float().to(DEVICE)\n",
|
| 204 |
+
" t = t.to(DEVICE)\n",
|
| 205 |
+
" l = l.float().to(DEVICE)\n",
|
| 206 |
+
"\n",
|
| 207 |
+
" y,p = model(x, l)\n",
|
| 208 |
+
" #p = torch.sigmoid(y)\n",
|
| 209 |
+
"\n",
|
| 210 |
+
" targets.append(t.data.cpu().numpy())\n",
|
| 211 |
+
" outputs.append(p.data.cpu().numpy())\n",
|
| 212 |
+
" \n",
|
| 213 |
+
" targets = np.concatenate(targets, axis=0)\n",
|
| 214 |
+
" outputs = np.concatenate(outputs, axis=0)\n",
|
| 215 |
+
" auprc = average_precision_score(y_true=targets, y_score=outputs)\n",
|
| 216 |
+
" auroc = roc_auc_score(targets, outputs)\n",
|
| 217 |
+
"\n",
|
| 218 |
+
" outputs_f1 = np.array([[(1 if prob > 0.5 else 0) for prob in probs] for probs in np.array(outputs)])\n",
|
| 219 |
+
" f1 = f1_score(targets, outputs_f1, average='weighted')\n",
|
| 220 |
+
" print(\"This is the auroc of testing:\", auroc)\n",
|
| 221 |
+
" print(\"This is the f1 of testing:\", f1)\n",
|
| 222 |
+
"\n",
|
| 223 |
+
" return auprc, targets, outputs, auroc, f1\n",
|
| 224 |
+
"\n",
|
| 225 |
+
"file_address = \"../collection_of_all_datasets/\"\n",
|
| 226 |
+
"data_directory = \"./csv-file/training_validation_testing/group5\"\n",
|
| 227 |
+
"\n",
|
| 228 |
+
"############ testing area #########################\n",
|
| 229 |
+
"the_testing_address = data_directory + \"/testing_group\"+data_directory[-1]+\".csv\"\n",
|
| 230 |
+
"df = pd.read_csv(the_testing_address)\n",
|
| 231 |
+
"print(df['Name'][0])\n",
|
| 232 |
+
"\n",
|
| 233 |
+
"testing_header_files=[]\n",
|
| 234 |
+
"\n",
|
| 235 |
+
"for i in range(len(df['Name'])):\n",
|
| 236 |
+
" each_header_file = file_address + df['Name'][i]\n",
|
| 237 |
+
" testing_header_files.append(each_header_file)\n",
|
| 238 |
+
" \n",
|
| 239 |
+
"test_dataset = dataset(testing_header_files)\n",
|
| 240 |
+
"print(test_dataset.summary('pandas'))\n",
|
| 241 |
+
" \n",
|
| 242 |
+
"\n",
|
| 243 |
+
"test_dataset.num_leads = 12\n",
|
| 244 |
+
"test_dataset.sample = True\n",
|
| 245 |
+
"###################################################\n",
|
| 246 |
+
"valid = DataLoader(dataset=test_dataset,\n",
|
| 247 |
+
" batch_size=128,\n",
|
| 248 |
+
" shuffle=False,\n",
|
| 249 |
+
" num_workers=8,\n",
|
| 250 |
+
" collate_fn=collate,\n",
|
| 251 |
+
" pin_memory=True,\n",
|
| 252 |
+
" drop_last=False)\n",
|
| 253 |
+
"\n",
|
| 254 |
+
"model = NN(nOUT=26).to(DEVICE)\n",
|
| 255 |
+
"model.load_state_dict(new_state_dict)\n",
|
| 256 |
+
"\n",
|
| 257 |
+
"auprc, targets, outputs, auroc, f1 = valid_part(model, valid)\n",
|
| 258 |
+
"print(\"============================================\")\n",
|
| 259 |
+
"print(\"This is the auprc:\", auprc)\n",
|
| 260 |
+
"print(\"This is the auroc:\", auroc)\n",
|
| 261 |
+
"print(\"This is the f1: \", f1)\n"
|
| 262 |
+
]
|
| 263 |
+
},
|
| 264 |
+
{
|
| 265 |
+
"cell_type": "markdown",
|
| 266 |
+
"metadata": {},
|
| 267 |
+
"source": [
|
| 268 |
+
"#### AUPRC Checking for the 12-lead"
|
| 269 |
+
]
|
| 270 |
+
},
|
| 271 |
+
{
|
| 272 |
+
"cell_type": "code",
|
| 273 |
+
"execution_count": 4,
|
| 274 |
+
"metadata": {},
|
| 275 |
+
"outputs": [
|
| 276 |
+
{
|
| 277 |
+
"name": "stdout",
|
| 278 |
+
"output_type": "stream",
|
| 279 |
+
"text": [
|
| 280 |
+
"<class 'list'> 70\n",
|
| 281 |
+
"<class 'dict'>\n",
|
| 282 |
+
"----------------------\n",
|
| 283 |
+
"epoch: \n",
|
| 284 |
+
"model: \n",
|
| 285 |
+
"train_auprc: \n",
|
| 286 |
+
"valid_auprc: \n",
|
| 287 |
+
"valid_auroc: \n",
|
| 288 |
+
"valid_targets: \n",
|
| 289 |
+
"valid_outputs: \n",
|
| 290 |
+
"-----------------------\n",
|
| 291 |
+
"-----------------------\n",
|
| 292 |
+
"[0.22930097130178695, 0.32832962591991394, 0.35747821563458165, 0.40334068352552854, 0.4143247499932358, 0.4029527044033194, 0.43389495514664134, 0.43227590210447153, 0.42763094099224175, 0.44957633208074327, 0.4597494213037991, 0.4526247117906813, 0.43743989538814887, 0.45792376450331546, 0.451313159856729, 0.4639358454505829, 0.47232763045243853, 0.4606242649328537, 0.4639758325284892, 0.46652311062551555, 0.5042587688515305, 0.5066928376264266, 0.5095729769831645, 0.5080036767983908, 0.5127449760336349, 0.5125109269447715, 0.5122839697295523, 0.5136472386495577, 0.511262098976581, 0.5164922389762391, 0.5140972681257835, 0.5178319592806367, 0.5164327367393361, 0.5148920196758733, 0.5163729160832938, 0.5160267502885856, 0.5145344650062585, 0.5168218047546651, 0.5167044163613791, 0.516398763630853, 0.519743273529755, 0.5180094428789471, 0.5187746854102783, 0.5189446261317721, 0.5191071800725848, 0.5190530337013738, 0.5181188209013248, 0.5198348158610016, 0.5201651447934246, 0.5200725282638343, 0.5189072362353084, 0.5188762698996898, 0.5176415639383447, 0.5190125547904473, 0.5182098122362689, 0.5199563616134137, 0.5200210859305623, 0.5213200600204746, 0.5191351758073407, 0.518985947976628, 0.5195132598638488, 0.519068429782278, 0.5185359984682572, 0.5184963313155649, 0.5190214730287106, 0.5187674089918028, 0.5193985132112929, 0.5195558026974484, 0.5183350858698199, 0.5191250549273707]\n",
|
| 293 |
+
"0.5213200600204746 57\n"
|
| 294 |
+
]
|
| 295 |
+
}
|
| 296 |
+
],
|
| 297 |
+
"source": [
|
| 298 |
+
"import pickle\n",
|
| 299 |
+
"import torch\n",
|
| 300 |
+
"\n",
|
| 301 |
+
"address = \"./model_output_12_lead/model_group4/PROGRESS.pickle\"\n",
|
| 302 |
+
"\n",
|
| 303 |
+
"with open(address, 'rb') as file:\n",
|
| 304 |
+
" data = pickle.load(file)\n",
|
| 305 |
+
"\n",
|
| 306 |
+
"print(type(data), len(data))\n",
|
| 307 |
+
"\n",
|
| 308 |
+
"print(type(data[1]))\n",
|
| 309 |
+
"print(\"----------------------\")\n",
|
| 310 |
+
"for key, _ in data[1].items():\n",
|
| 311 |
+
" print(f\"{key}: \")\n",
|
| 312 |
+
"\n",
|
| 313 |
+
"print(\"-----------------------\")\n",
|
| 314 |
+
"AUPRC_list = []\n",
|
| 315 |
+
"for i in range(len(data)):\n",
|
| 316 |
+
" AUPRC_list.append(data[i][\"valid_auprc\"])\n",
|
| 317 |
+
"\n",
|
| 318 |
+
"print(\"-----------------------\") \n",
|
| 319 |
+
"print(AUPRC_list)\n",
|
| 320 |
+
"(\"-----------------------\")\n",
|
| 321 |
+
"largest_number = max(AUPRC_list)\n",
|
| 322 |
+
"index = AUPRC_list.index(largest_number)\n",
|
| 323 |
+
"print(largest_number, index)"
|
| 324 |
+
]
|
| 325 |
+
}
|
| 326 |
+
],
|
| 327 |
+
"metadata": {
|
| 328 |
+
"kernelspec": {
|
| 329 |
+
"display_name": "testing",
|
| 330 |
+
"language": "python",
|
| 331 |
+
"name": "python3"
|
| 332 |
+
},
|
| 333 |
+
"language_info": {
|
| 334 |
+
"codemirror_mode": {
|
| 335 |
+
"name": "ipython",
|
| 336 |
+
"version": 3
|
| 337 |
+
},
|
| 338 |
+
"file_extension": ".py",
|
| 339 |
+
"mimetype": "text/x-python",
|
| 340 |
+
"name": "python",
|
| 341 |
+
"nbconvert_exporter": "python",
|
| 342 |
+
"pygments_lexer": "ipython3",
|
| 343 |
+
"version": "3.10.14"
|
| 344 |
+
}
|
| 345 |
+
},
|
| 346 |
+
"nbformat": 4,
|
| 347 |
+
"nbformat_minor": 2
|
| 348 |
+
}
|