{"cells":[{"cell_type":"markdown","metadata":{"id":"6VRMQUoFAhpm"},"source":["# Concentration Prediction with EEG\n","\n","An example of using a neural network.\n","In the previous chapter we covered the basics of neural networks (NN). In this notebook we build an NN that predicts, from EEG signals and other participant data, which participants were confused by a lesson."]},{"cell_type":"markdown","metadata":{"id":"s3II4uXpjqtf"},"source":["## Install required library and download dataset\n","Use the Kaggle API to download the dataset."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":4519,"status":"ok","timestamp":1688405975527,"user":{"displayName":"NOPARIN SMERWONG","userId":"16832560020600244695"},"user_tz":-420},"id":"7pARR1pxBNxP","outputId":"9203d0bf-b059-4fb3-869f-5644806fae79"},"outputs":[{"name":"stdout","output_type":"stream","text":["Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (1.5.13)\n","Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.10/dist-packages (from kaggle) (1.16.0)\n","Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from kaggle) (2023.5.7)\n","Requirement already satisfied: python-dateutil in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.8.2)\n","Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.27.1)\n","Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from kaggle) (4.65.0)\n","Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle) (8.0.1)\n","Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle) (1.26.16)\n","Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle) (1.3)\n","Requirement already satisfied: charset-normalizer~=2.0.0 in 
/usr/local/lib/python3.10/dist-packages (from requests->kaggle) (2.0.12)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.4)\n"]}],"source":["!pip install kaggle"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":1558,"status":"ok","timestamp":1688405981867,"user":{"displayName":"NOPARIN SMERWONG","userId":"16832560020600244695"},"user_tz":-420},"id":"re79gQL_BRjj","outputId":"9e6b401e-2865-4918-a9ef-3d609a7e79d1"},"outputs":[{"name":"stdout","output_type":"stream","text":["cp: cannot stat 'kaggle.json': No such file or directory\n","chmod: cannot access '/root/.kaggle/kaggle.json': No such file or directory\n","Traceback (most recent call last):\n"," File \"/usr/local/bin/kaggle\", line 5, in \n"," from kaggle.cli import main\n"," File \"/usr/local/lib/python3.10/dist-packages/kaggle/__init__.py\", line 23, in \n"," api.authenticate()\n"," File \"/usr/local/lib/python3.10/dist-packages/kaggle/api/kaggle_api_extended.py\", line 164, in authenticate\n"," raise IOError('Could not find {}. Make sure it\\'s located in'\n","OSError: Could not find kaggle.json. Make sure it's located in /root/.kaggle. 
Or use the environment method.\n","unzip: cannot find or open confused-eeg.zip, confused-eeg.zip.zip or confused-eeg.zip.ZIP.\n"]}],"source":["# Upload kaggle.json to the working directory first; get it from https://www.kaggle.com/settings (open that page and click \"Create New Token\" in the API section)\n","# Create the config directory in case it does not exist yet\n","!mkdir -p /root/.kaggle\n","!cp kaggle.json /root/.kaggle/\n","!chmod 600 /root/.kaggle/kaggle.json\n","!kaggle datasets download -d wanghaohan/confused-eeg\n","!unzip confused-eeg.zip"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"G7Z_QHJaAhpq"},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","import numpy as np\n","import pandas as pd\n","import seaborn as sns\n","import matplotlib.pyplot as plt\n","from torch.utils.data import Dataset, DataLoader"]},{"cell_type":"markdown","metadata":{"id":"lIGNup3vkHjB"},"source":["Read the EEG data first. This file contains:\n"," - SubjectID, VideoID: IDs of the participant (student) and of the video\n"," - Attention: attention level\n"," - Mediation: meditation level (the dataset spells the column `Mediation`)\n"," - Raw: raw signal\n"," - EEG in each frequency band (Delta, Theta, Alpha1, Alpha2, Beta1, Beta2, Gamma1, Gamma2)\n"," - predefinedlabel: the confusion level expected in advance (not used in this chapter); 0 = not confused, 1 = confused\n"," - user-definedlabeln: the confusion level reported by the student after watching; 0 = not confused, 1 = confused\n"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":287},"executionInfo":{"elapsed":10,"status":"ok","timestamp":1688406484473,"user":{"displayName":"NOPARIN SMERWONG","userId":"16832560020600244695"},"user_tz":-420},"id":"E4N-AGMhAhps","outputId":"0d65ce0b-6552-4e4e-a6f7-a675bf5df61f"},"outputs":[{"data":{"text/html":["\n","
\n"," "],"text/plain":[" SubjectID VideoID Attention Mediation Raw Delta Theta \\\n","0 0.0 0.0 56.0 43.0 278.0 301963.0 90612.0 \n","1 0.0 0.0 40.0 35.0 -50.0 73787.0 28083.0 \n","2 0.0 0.0 47.0 48.0 101.0 758353.0 383745.0 \n","3 0.0 0.0 47.0 57.0 -5.0 2012240.0 129350.0 \n","4 0.0 0.0 44.0 53.0 -8.0 1005145.0 354328.0 \n","\n"," Alpha1 Alpha2 Beta1 Beta2 Gamma1 Gamma2 predefinedlabel \\\n","0 33735.0 23991.0 27946.0 45097.0 33228.0 8293.0 0.0 \n","1 1439.0 2240.0 2746.0 3687.0 5293.0 2740.0 0.0 \n","2 201999.0 62107.0 36293.0 130536.0 57243.0 25354.0 0.0 \n","3 61236.0 17084.0 11488.0 62462.0 49960.0 33932.0 0.0 \n","4 37102.0 88881.0 45307.0 99603.0 44790.0 29749.0 0.0 \n","\n"," user-definedlabeln \n","0 0.0 \n","1 0.0 \n","2 0.0 \n","3 0.0 \n","4 0.0 "]},"execution_count":5,"metadata":{},"output_type":"execute_result"}],"source":["eeg_df = pd.read_csv('/EEG_data.csv')\n","eeg_df.head()"]},{"cell_type":"markdown","metadata":{"id":"NKS4jG6FClIb"},"source":["## Add demographic data"]},{"cell_type":"markdown","metadata":{"id":"ECu6sAdqlNIx"},"source":["The other file holds information about the participants:\n","- student ID\n","- gender\n","- ethnicity\n","- age"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"fwXKbqxFCcq0"},"outputs":[],"source":["dem_df = pd.read_csv('/demographic_info.csv')"]},{"cell_type":"markdown","metadata":{"id":"o7PwMohNCzX3"},"source":["## Preprocess data\n","Overall, we merge the two tables and then apply some basic preparation before trying the data with a neural network."]},{"cell_type":"markdown","metadata":{"id":"52c-1yLfmBVm"},"source":["Rename the ID column with `.rename()` so that it is easier to work with."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"G77HGM7XCkar"},"outputs":[],"source":["dem_df = dem_df.rename(columns = {'subject ID' : 
'SubjectID'})"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":363},"executionInfo":{"elapsed":737,"status":"ok","timestamp":1688406509464,"user":{"displayName":"NOPARIN SMERWONG","userId":"16832560020600244695"},"user_tz":-420},"id":"7UmGgglJCrsm","outputId":"538fce3c-cba2-4db0-b4f1-5b760487b83e"},"outputs":[{"data":{"text/html":["\n","
\n"," "],"text/plain":[" SubjectID age ethnicity gender\n","0 0.0 25 Han Chinese M\n","1 1.0 24 Han Chinese M\n","2 2.0 31 English M\n","3 3.0 28 Han Chinese F\n","4 4.0 24 Bengali M\n","5 5.0 24 Han Chinese M\n","6 6.0 24 Han Chinese M\n","7 7.0 25 Han Chinese M\n","8 8.0 25 Han Chinese M\n","9 9.0 24 Han Chinese F"]},"execution_count":8,"metadata":{},"output_type":"execute_result"}],"source":["dem_df['SubjectID'] = dem_df['SubjectID'].astype(np.float64)\n","dem_df"]},{"cell_type":"markdown","metadata":{"id":"vB48m888flCb"},"source":["## Merge EEG with demographics\n","Use `.merge(dem_df, how='inner', on='SubjectID')`. An inner merge keeps only the rows whose key (`SubjectID`) appears in both tables."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"YbU69OxRCvIv"},"outputs":[],"source":["eeg_df = eeg_df.merge(dem_df, how = 'inner', on = 'SubjectID')"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":678},"executionInfo":{"elapsed":537,"status":"ok","timestamp":1688406528853,"user":{"displayName":"NOPARIN SMERWONG","userId":"16832560020600244695"},"user_tz":-420},"id":"kGVvMYhFa0r3","outputId":"35de6e6f-c149-45c6-b112-b421c242b785"},"outputs":[{"data":{"text/html":["\n","
\n"," "],"text/plain":[" SubjectID VideoID Attention Mediation Raw Delta Theta \\\n","0 0.0 0.0 56.0 43.0 278.0 301963.0 90612.0 \n","1 0.0 0.0 40.0 35.0 -50.0 73787.0 28083.0 \n","2 0.0 0.0 47.0 48.0 101.0 758353.0 383745.0 \n","3 0.0 0.0 47.0 57.0 -5.0 2012240.0 129350.0 \n","4 0.0 0.0 44.0 53.0 -8.0 1005145.0 354328.0 \n","... ... ... ... ... ... ... ... \n","12806 9.0 9.0 64.0 38.0 -39.0 127574.0 9951.0 \n","12807 9.0 9.0 61.0 35.0 -275.0 323061.0 797464.0 \n","12808 9.0 9.0 60.0 29.0 -426.0 680989.0 154296.0 \n","12809 9.0 9.0 60.0 29.0 -84.0 366269.0 27346.0 \n","12810 9.0 9.0 64.0 29.0 -49.0 1164555.0 1184366.0 \n","\n"," Alpha1 Alpha2 Beta1 Beta2 Gamma1 Gamma2 \\\n","0 33735.0 23991.0 27946.0 45097.0 33228.0 8293.0 \n","1 1439.0 2240.0 2746.0 3687.0 5293.0 2740.0 \n","2 201999.0 62107.0 36293.0 130536.0 57243.0 25354.0 \n","3 61236.0 17084.0 11488.0 62462.0 49960.0 33932.0 \n","4 37102.0 88881.0 45307.0 99603.0 44790.0 29749.0 \n","... ... ... ... ... ... ... \n","12806 709.0 21732.0 3872.0 39728.0 2598.0 960.0 \n","12807 153171.0 145805.0 39829.0 571280.0 36574.0 10010.0 \n","12808 40068.0 39122.0 10966.0 26975.0 20427.0 2024.0 \n","12809 11444.0 9932.0 1939.0 3283.0 12323.0 1764.0 \n","12810 50014.0 124208.0 10634.0 445383.0 22133.0 4482.0 \n","\n"," predefinedlabel user-definedlabeln age ethnicity gender \n","0 0.0 0.0 25 Han Chinese M \n","1 0.0 0.0 25 Han Chinese M \n","2 0.0 0.0 25 Han Chinese M \n","3 0.0 0.0 25 Han Chinese M \n","4 0.0 0.0 25 Han Chinese M \n","... ... ... ... ... ... 
\n","12806 1.0 0.0 24 Han Chinese F \n","12807 1.0 0.0 24 Han Chinese F \n","12808 1.0 0.0 24 Han Chinese F \n","12809 1.0 0.0 24 Han Chinese F \n","12810 1.0 0.0 24 Han Chinese F \n","\n","[12811 rows x 18 columns]"]},"execution_count":10,"metadata":{},"output_type":"execute_result"}],"source":["eeg_df"]},{"cell_type":"markdown","metadata":{"id":"BIIJGG6dC19Q"},"source":["## Convert to one-hot encoding\n","Convert categorical data to numbers with `pd.get_dummies()`, which creates one 0/1 column per category.\\\n","EX: Bengali, English, Han Chinese -> (1,0,0), (0,1,0), (0,0,1) \\\n","F/M -> (1,0), (0,1)\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"o1siKFnMC9a4"},"outputs":[],"source":["eeg_df = pd.get_dummies(eeg_df)"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":317},"executionInfo":{"elapsed":738,"status":"ok","timestamp":1688406543864,"user":{"displayName":"NOPARIN SMERWONG","userId":"16832560020600244695"},"user_tz":-420},"id":"gdyPCbO7a75P","outputId":"507d138d-4d7b-47c7-e912-601bfd6ca890"},"outputs":[{"data":{"text/html":["\n","
\n"," "],"text/plain":[" SubjectID VideoID Attention Mediation Raw Delta Theta \\\n","0 0.0 0.0 56.0 43.0 278.0 301963.0 90612.0 \n","1 0.0 0.0 40.0 35.0 -50.0 73787.0 28083.0 \n","2 0.0 0.0 47.0 48.0 101.0 758353.0 383745.0 \n","3 0.0 0.0 47.0 57.0 -5.0 2012240.0 129350.0 \n","4 0.0 0.0 44.0 53.0 -8.0 1005145.0 354328.0 \n","\n"," Alpha1 Alpha2 Beta1 ... Gamma1 Gamma2 predefinedlabel \\\n","0 33735.0 23991.0 27946.0 ... 33228.0 8293.0 0.0 \n","1 1439.0 2240.0 2746.0 ... 5293.0 2740.0 0.0 \n","2 201999.0 62107.0 36293.0 ... 57243.0 25354.0 0.0 \n","3 61236.0 17084.0 11488.0 ... 49960.0 33932.0 0.0 \n","4 37102.0 88881.0 45307.0 ... 44790.0 29749.0 0.0 \n","\n"," user-definedlabeln age ethnicity_Bengali ethnicity_English \\\n","0 0.0 25 0 0 \n","1 0.0 25 0 0 \n","2 0.0 25 0 0 \n","3 0.0 25 0 0 \n","4 0.0 25 0 0 \n","\n"," ethnicity_Han Chinese gender_F gender_M \n","0 1 0 1 \n","1 1 0 1 \n","2 1 0 1 \n","3 1 0 1 \n","4 1 0 1 \n","\n","[5 rows x 21 columns]"]},"execution_count":12,"metadata":{},"output_type":"execute_result"}],"source":["eeg_df.head()"]},{"cell_type":"markdown","metadata":{"id":"VKIY_B6BDdHG"},"source":["## Data cleaning\n","Drop the columns we do not need with `.drop()`, such as `SubjectID` and `VideoID`. We want to measure understanding from the student's own signals, and because the video is the stimulus that causes the confusion, keeping these identifiers could let the model guess the label from the two ID features instead of from the student's data.\n","\n","We also drop `predefinedlabel`, which is not needed, and `gender_F`, which is a redundant column left over from one-hot encoding (it is fully determined by `gender_M`)."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"kWmUYYzGDgnY"},"outputs":[],"source":["eeg_df = eeg_df.drop(['SubjectID', 'VideoID', 'predefinedlabel', ' gender_F'], axis = 1)"]},
{"cell_type":"markdown","metadata":{"id":"8HEdtPqtEarW"},"source":["Some rows have `Mediation` and `Attention` equal to 0, which the dataset author describes as an error in the discussion, so we keep only the rows where both values are greater than 0."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"jQbw6PRzDjfI"},"outputs":[],"source":["eeg_df = eeg_df[eeg_df['Attention'] > 0.0]\n","eeg_df = eeg_df[eeg_df['Mediation'] > 0.0]"]},{"cell_type":"markdown","metadata":{"id":"Gkmu9rF9wlwr"},"source":["The label takes only the values 0 and 1."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":527,"status":"ok","timestamp":1688406572027,"user":{"displayName":"NOPARIN SMERWONG","userId":"16832560020600244695"},"user_tz":-420},"id":"rsSPiD8_EM2w","outputId":"c9402298-1079-42b5-b9f3-28b36d2120b3"},"outputs":[{"data":{"text/plain":["array([0., 1.])"]},"execution_count":15,"metadata":{},"output_type":"execute_result"}],"source":["eeg_df['user-definedlabeln'].unique()"]},{"cell_type":"markdown","metadata":{"id":"ORWiuBiuEekn"},"source":["## Get the arrays from dataset"]},{"cell_type":"markdown","metadata":{"id":"QZld3sW2wp_X"},"source":["Separate the column we want to predict from the features."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"H-vHXno2EhRW"},"outputs":[],"source":["X = np.array(eeg_df.drop(['user-definedlabeln'], axis = 1))\n","y = np.array(eeg_df['user-definedlabeln'])"]},{"cell_type":"markdown","metadata":{"id":"gFpGT6FqEpdP"},"source":["## Data preprocessing"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":551,"status":"ok","timestamp":1688406579829,"user":{"displayName":"NOPARIN SMERWONG","userId":"16832560020600244695"},"user_tz":-420},"id":"7wUo2i6ZEvLa","outputId":"11212379-a1d8-419d-97a9-6a4caf70f632"},"outputs":[{"name":"stdout","output_type":"stream","text":["-2048.0\n","3964663.0\n"]}],"source":["print(X.min())\n","print(X.max())"]},
{"cell_type":"markdown","metadata":{"id":"WoGkB_JZEwwj"},"source":["The feature values differ widely because each feature is on its own scale, so we use `StandardScaler` to bring every feature onto the same scale (zero mean and unit variance per feature).\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"ZkvCnJgcEq_B"},"outputs":[],"source":["from sklearn.preprocessing import StandardScaler\n","X = StandardScaler().fit_transform(X)"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":2,"status":"ok","timestamp":1688406585502,"user":{"displayName":"NOPARIN SMERWONG","userId":"16832560020600244695"},"user_tz":-420},"id":"vbaRIxj4KiV7","outputId":"eac917f4-3f86-43c5-c90e-20130ce1b04c"},"outputs":[{"name":"stdout","output_type":"stream","text":["-15.829161155386538\n","29.216116594451524\n"]}],"source":["print(X.min())\n","print(X.max())"]},{"cell_type":"markdown","metadata":{"id":"l1U83FWtKnjQ"},"source":["## Split data"]},{"cell_type":"markdown","metadata":{"id":"jDwbdHznyxVG"},"source":["Split the data into training and test sets with `train_test_split`."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"O-5-zEdXKsad"},"outputs":[],"source":["from sklearn.model_selection import train_test_split\n","X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 1)"]},{"cell_type":"markdown","metadata":{"id":"ibB42cX2y3-I"},"source":["Check the shapes of X and y."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":2,"status":"ok","timestamp":1688406591053,"user":{"displayName":"NOPARIN SMERWONG","userId":"16832560020600244695"},"user_tz":-420},"id":"BQI76JmpKvC5","outputId":"2deb2fb2-3353-4ca9-97d9-f36669b91403"},"outputs":[{"name":"stdout","output_type":"stream","text":["(9110, 16)\n","(9110,)\n"]}],"source":["print(X_train.shape)\n","print(y_train.shape)"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":2,"status":"ok","timestamp":1688406593370,"user":{"displayName":"NOPARIN 
SMERWONG","userId":"16832560020600244695"},"user_tz":-420},"id":"NEYjavqxKxDv","outputId":"3e78ec05-6e91-4e83-8a65-70eb311bfd10"},"outputs":[{"name":"stdout","output_type":"stream","text":["(2278, 16)\n","(2278,)\n"]}],"source":["print(X_test.shape)\n","print(y_test.shape)"]},{"cell_type":"markdown","metadata":{"id":"VyU01Wj8y7XP"},"source":["Build our NN as we learned in the previous chapter, starting with a `Dataset` that serves the data to it."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"FVn4j4tlPzA8"},"outputs":[],"source":["class StudentDataset(Dataset):\n","    def __init__(self, X, y):\n","        # convert the NumPy arrays to tensors\n","        self.X = torch.tensor(X, dtype=torch.float32)\n","        self.y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1)\n","\n","    def __len__(self):\n","        return self.y.shape[0]\n","\n","    def __getitem__(self, index):\n","        features = self.X[index]\n","        label = self.y[index]\n","\n","        return features, label"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"uvquvOBLRBg4"},"outputs":[],"source":["# create the DataLoader used for training\n","train_dataset = StudentDataset(X_train, y_train)\n","train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)"]},{"cell_type":"markdown","metadata":{"id":"3ERKyj7HK2N9"},"source":["## Create a model"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"vHTAlisfK7Od"},"outputs":[],"source":["class EEGNet(nn.Module):\n","    # a neural network with 2 Linear layers\n","    def __init__(self, input_size):\n","        super().__init__()\n","        self.fc1 = nn.Linear(input_size, 4)\n","        self.fc2 = nn.Linear(4, 1)\n","\n","    # forward propagation\n","    def forward(self, x):\n","        x = F.relu(self.fc1(x))\n","        x = self.fc2(x)\n","        x = torch.sigmoid(x)\n","        return x"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":705,"status":"ok","timestamp":1688406623211,"user":{"displayName":"NOPARIN 
SMERWONG","userId":"16832560020600244695"},"user_tz":-420},"id":"GbHdMiIQLjsq","outputId":"59c95b58-f533-45ab-d5be-5d6076c8fb41"},"outputs":[{"data":{"text/plain":["EEGNet(\n"," (fc1): Linear(in_features=16, out_features=4, bias=True)\n"," (fc2): Linear(in_features=4, out_features=1, bias=True)\n",")"]},"execution_count":26,"metadata":{},"output_type":"execute_result"}],"source":["model = EEGNet(16)\n","model"]},{"cell_type":"markdown","metadata":{"id":"WDr5RF_Rzo3z"},"source":["We train the model for 10 epochs with a for loop.\\\n","For each batch we:\n","1. Clear the gradients left from the previous iteration (`optimizer.zero_grad()`)\n","2. Pass the inputs through the model\n","3. Compute the loss with `criterion(outputs, labels)`, which compares the outputs with the labels\n","4. Compute the gradients with `loss.backward()` and update the parameters with `optimizer.step()`"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"WszzAZboNedi"},"outputs":[],"source":["criterion = nn.BCELoss() # binary cross entropy\n","optimizer = optim.Adam(model.parameters(), lr=0.001)"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":2091,"status":"ok","timestamp":1688406634138,"user":{"displayName":"NOPARIN SMERWONG","userId":"16832560020600244695"},"user_tz":-420},"id":"MWwiKqosOEJr","outputId":"5fbd7f1f-9076-47ff-c08e-96931e573470"},"outputs":[{"name":"stdout","output_type":"stream","text":["[epoch: 1 ] loss: 0.720\n","[epoch: 2 ] loss: 0.688\n","[epoch: 3 ] loss: 0.674\n","[epoch: 4 ] loss: 0.687\n","[epoch: 5 ] loss: 0.654\n","[epoch: 6 ] loss: 0.698\n","[epoch: 7 ] loss: 0.648\n","[epoch: 8 ] loss: 0.665\n","[epoch: 9 ] loss: 0.656\n","[epoch: 10 ] loss: 0.622\n","Finished Training\n"]}],"source":["for epoch in range(10):\n","    running_loss = 0.0\n","    for i, data in enumerate(train_loader, 0):\n","        inputs, labels = data\n","        optimizer.zero_grad()\n","        outputs = model(inputs)\n","        loss = criterion(outputs, labels)\n","        loss.backward()\n","        optimizer.step()\n","        running_loss += loss.item()\n","        # an epoch has fewer than 1000 batches, so this fires only at i == 0\n","        # and prints the loss of the first batch of each epoch\n","        if i % 1000 == 0:\n","            print(f\"[epoch: {epoch + 1} ] loss: {running_loss / 1:.3f}\")\n","            running_loss = 0.0\n","\n","print(\"Finished Training\")"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"ckdFDGigWjWu"},"outputs":[],"source":["from sklearn.metrics import accuracy_score"]},{"cell_type":"markdown","metadata":{"id":"GShOJpldzsZ7"},"source":["Check the accuracy of the model's predictions on the test set. `torch.no_grad()` tells PyTorch not to track gradients during the computation. Before comparing with the labels, the predicted probabilities must be passed through `round()` so that they become 0 or 1."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":2,"status":"ok","timestamp":1688406644363,"user":{"displayName":"NOPARIN SMERWONG","userId":"16832560020600244695"},"user_tz":-420},"id":"XDZ6CZuNUH00","outputId":"c86837f9-c7c1-47d3-efe8-0836212e06cb"},"outputs":[{"name":"stdout","output_type":"stream","text":["Accuracy 0.6198419666374012\n"]}],"source":["with torch.no_grad():\n","    y_pred = model(torch.tensor(X_test, dtype=torch.float32))\n","    accuracy = accuracy_score(y_test, y_pred.round())\n","    print(f\"Accuracy {accuracy}\")"]},{"cell_type":"markdown","metadata":{"id":"nWj8y9hIU1NN"},"source":["This example is meant to show the basic process of writing and using a neural network. In real use, you need to design the architecture and tune the hyperparameters to fit the application, and preparing the data before training (preprocessing) is an equally necessary step.\n","\n","

\n","**Code in this tutorial prepared by**: นาย กรวิชญ์ โชตยาภา"]}],"metadata":{"accelerator":"GPU","colab":{"gpuType":"T4","provenance":[{"file_id":"1T8QDR3EJDDPLC0iu-6AdKjTmHhjzCpl0","timestamp":1688399462796}]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.11"},"orig_nbformat":4},"nbformat":4,"nbformat_minor":0}