Image Classification#

This example demonstrates the Transformer, a deep learning architecture that, alongside CNNs, can be used for computer vision. The task is to predict the name of a dish from a photo.

Overview of Transformer#

The Transformer is a deep learning model originally designed for natural language processing (NLP) tasks such as machine translation and large language models like ChatGPT. Transformers are now also applied to computer vision. A vision Transformer first splits the image into patches and converts each patch into a vector that encodes both its features and its position in the image. These vectors are then processed by a model called the encoder. Inside the encoder, a component called attention helps the model learn which parts of the image deserve the most weight. The result is passed through further neural network layers, which learn to separate images based on the extracted features and finally output the class of the input image.
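The two ideas in the paragraph above, splitting an image into patch vectors and letting attention weigh them against each other, can be sketched in plain Python. This is a toy illustration only: it assumes a 4x4 single-channel "image", uses identity projections (query = key = value), and omits the learned projections and position embeddings that a real vision Transformer uses. The function names are hypothetical.

```python
import math


def split_into_patches(image, patch):
    """Split a 2D grid (list of rows) into flattened patch vectors, row-major."""
    height, width = len(image), len(image[0])
    patches = []
    for top in range(0, height, patch):
        for left in range(0, width, patch):
            vector = [
                image[r][c]
                for r in range(top, top + patch)
                for c in range(left, left + patch)
            ]
            patches.append(vector)
    return patches


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(tokens):
    """Single-head self-attention with identity projections (Q = K = V = tokens)."""
    dim = len(tokens[0])
    outputs = []
    for query in tokens:
        # Scaled dot-product score of this token against every token.
        scores = [
            sum(q * k for q, k in zip(query, key)) / math.sqrt(dim)
            for key in tokens
        ]
        weights = softmax(scores)
        # Weighted average of all tokens, one output vector per input token.
        outputs.append(
            [sum(w * tok[i] for w, tok in zip(weights, tokens)) for i in range(dim)]
        )
    return outputs


# A 4x4 toy image whose pixel values are 0..15.
image = [[r * 4 + c for c in range(4)] for r in range(4)]
patches = split_into_patches(image, 2)  # four 2x2 patches, each flattened to 4 values
attended = self_attention(patches)
```

Each output vector in `attended` is a weighted mix of all patch vectors, which is how attention lets every patch draw on information from the rest of the image.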

Further reading: Vision Transformer

image.png

Thai Food Classification with Huggingface’s Transformers#


In this notebook we load a Thai food dataset, build the datasets, and fine-tune the Swin Transformer tiny model to classify images of Thai dishes. The dataset is described as covering 50 dishes, although the label list used below contains 48 classes.

Further reading: huggingface datasets

!pip install datasets
!pip install git+https://github.com/huggingface/transformers
!pip install gradio
!pip install transformers[torch]

Download Datasets#

  • Download the FoodyDudy dataset from GitHub.

  • Then use load_dataset("imagefolder", ...) to read the data into a Dataset object.

  • Load the feature extractor of the model we will fine-tune, which is used to resize the images.
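The "imagefolder" loader relies on the directory layout: the split and the class label are taken from the folder names in each image path. The sketch below only illustrates that convention with a path of the form used by this repository (images/test/03/0289.jpg). The helper name is hypothetical and is not part of the datasets library.

```python
from pathlib import PurePosixPath


def split_and_label(path, root="FoodyDudy/images"):
    """Return (split, label) for a path shaped like <root>/<split>/<label>/<file>."""
    relative = PurePosixPath(path).relative_to(root)
    split, label = relative.parts[0], relative.parts[1]
    return split, label


example_path = "FoodyDudy/images/test/03/0289.jpg"
split, label = split_and_label(example_path)
print(split, label)
```

Here the label is the two-digit folder name, which is why an extra dictionary is needed later to turn it into a dish name.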

!git clone https://github.com/GemmyTheGeek/FoodyDudy.git

List all the dish names and show their ids.

food_list = [
    'green_curry', 'tepo_curry', 'liang_curry', 'taohoo_moosup', 'mara_yadsai',
    'masaman', 'orange_curry', 'cashew_chicken', 'omelette', 'sunny_side_up',
    'palo_egg', 'sil_egg', 'nun_banana', 'kua_gai', 'cabbage_fish_sauce',
    'river_prawn', 'shrimp_ob_woonsen', 'kanom_krok', 'mango_sticky_rice', 'kao_kamoo',
    'kao_klook_kapi', 'kaosoi', 'kao_pad', 'kao_pad_shrimp', 'chicken_rice',
    'kao_mok_gai', 'tom_ka_gai', 'tom_yum_kung', 'tod_mun', 'poh_pia',
    'pak_boong_fai_daeng', 'padthai', 'pad_krapao', 'pad_si_ew', 'pad_fakthong',
    'eggplant_stirfry', 'pad_hoi_lai', 'foithong', 'panaeng', 'yum_tua_ploo',
    'yum_woonsen', 'larb_moo', 'pumpkin_custard', 'sakoo_sai_moo', 'somtam',
    'moopoing','satay', 'hor_mok'
]

# Build a dictionary mapping zero-padded ids to dish names
id2food = {str(i).zfill(2): f for i, f in enumerate(food_list)}
id2food
{'00': 'green_curry',
 '01': 'tepo_curry',
 '02': 'liang_curry',
 '03': 'taohoo_moosup',
 '04': 'mara_yadsai',
 '05': 'masaman',
 '06': 'orange_curry',
 '07': 'cashew_chicken',
 '08': 'omelette',
 '09': 'sunny_side_up',
 '10': 'palo_egg',
 '11': 'sil_egg',
 '12': 'nun_banana',
 '13': 'kua_gai',
 '14': 'cabbage_fish_sauce',
 '15': 'river_prawn',
 '16': 'shrimp_ob_woonsen',
 '17': 'kanom_krok',
 '18': 'mango_sticky_rice',
 '19': 'kao_kamoo',
 '20': 'kao_klook_kapi',
 '21': 'kaosoi',
 '22': 'kao_pad',
 '23': 'kao_pad_shrimp',
 '24': 'chicken_rice',
 '25': 'kao_mok_gai',
 '26': 'tom_ka_gai',
 '27': 'tom_yum_kung',
 '28': 'tod_mun',
 '29': 'poh_pia',
 '30': 'pak_boong_fai_daeng',
 '31': 'padthai',
 '32': 'pad_krapao',
 '33': 'pad_si_ew',
 '34': 'pad_fakthong',
 '35': 'eggplant_stirfry',
 '36': 'pad_hoi_lai',
 '37': 'foithong',
 '38': 'panaeng',
 '39': 'yum_tua_ploo',
 '40': 'yum_woonsen',
 '41': 'larb_moo',
 '42': 'pumpkin_custard',
 '43': 'sakoo_sai_moo',
 '44': 'somtam',
 '45': 'moopoing',
 '46': 'satay',
 '47': 'hor_mok'}
from datasets import load_dataset, load_metric

Load the dataset and the metric used during training with load_dataset and load_metric.

dataset = load_dataset("imagefolder", data_dir="FoodyDudy/images")
accuracy = load_metric("accuracy")
WARNING:datasets.builder:Found cached dataset imagefolder (/root/.cache/huggingface/datasets/imagefolder/default-2d8f66fe2f49d199/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)
<ipython-input-6-87cdaa45b993>:2: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate
  accuracy = load_metric("accuracy")

Our dataset consists of 11,520 training images, plus 1,440 images each for validation and test.

dataset
DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 11520
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1440
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 1440
    })
})
from transformers import AutoFeatureExtractor

# Load the feature extractor that belongs to the pretrained model
model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
feat_size = tuple(feature_extractor.size.values())
feature_extractor
/usr/local/lib/python3.10/dist-packages/transformers/models/vit/feature_extraction_vit.py:28: FutureWarning: The class ViTFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use ViTImageProcessor instead.
  warnings.warn(
ViTFeatureExtractor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "feature_extractor_type": "ViTFeatureExtractor",
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "ViTFeatureExtractor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}
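The configuration printed above tells us how raw pixels are turned into model inputs: each 0..255 value is rescaled by rescale_factor (1/255) and then normalized per channel with image_mean and image_std. A minimal sketch of that arithmetic for a single RGB pixel, using the values from the printed configuration (the function name is hypothetical):

```python
# Values copied from the feature extractor configuration shown above.
IMAGE_MEAN = [0.485, 0.456, 0.406]
IMAGE_STD = [0.229, 0.224, 0.225]
RESCALE_FACTOR = 1 / 255  # printed as 0.00392156862745098


def normalize_pixel(rgb):
    """Rescale a 0..255 RGB pixel to 0..1, then normalize each channel."""
    return [
        (value * RESCALE_FACTOR - mean) / std
        for value, mean, std in zip(rgb, IMAGE_MEAN, IMAGE_STD)
    ]


white = normalize_pixel([255, 255, 255])
black = normalize_pixel([0, 0, 0])
print(white, black)
```

The torchvision pipeline used later in this notebook does the same thing in two steps: ToTensor performs the rescaling and Normalize applies the mean and standard deviation.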

Next we create preprocess_train and preprocess_val to preprocess the images in each batch. The images are stored under the key image, and the preprocessed result is stored under the key pixel_values. So that the model's performance is not limited by the particular set of images we have, we also transform the training images to make them more varied. In the example figure, the images differ in color, viewpoint, and noise, yet every one of them is still a parrot; we want the model to be robust to these factors (see Illustration of Transform for more).

[Figure: the same parrot image under different augmentations (1_kqKzGryy0qYjEXbR3kyNOg.png)]
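To make the idea of augmentation concrete, here is a dependency-free sketch of two of the transforms discussed here, a random horizontal flip and a random brightness change, applied to a tiny grayscale grid. It is a conceptual stand-in and not the torchvision implementation. The function names and the 0.25 brightness range are illustrative assumptions.

```python
import random


def horizontal_flip(image):
    """Mirror each row of a 2D grid of pixel values."""
    return [row[::-1] for row in image]


def adjust_brightness(image, factor):
    """Scale every pixel by `factor`, clamping to the 0..255 range."""
    return [[min(255, max(0, round(p * factor))) for p in row] for row in image]


def augment(image, rng, flip_prob=0.5, brightness=0.25):
    """Randomly flip the image, then randomly brighten or darken it."""
    if rng.random() < flip_prob:
        image = horizontal_flip(image)
    factor = rng.uniform(1 - brightness, 1 + brightness)
    return adjust_brightness(image, factor)


rng = random.Random(0)  # fixed seed so the sketch is repeatable
image = [[10, 20, 30], [40, 50, 60]]
augmented = [augment(image, rng) for _ in range(3)]
```

Every call produces a slightly different version of the same picture, so over many epochs the model rarely sees exactly the same pixels twice.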

from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    RandomAffine,
    ColorJitter,
    Resize,
    ToTensor,
)

normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

# Data augmentation: list the operations applied to each image, e.g.
# Resize, RandomResizedCrop (random rescale and crop), normalization, and so on
train_transforms = Compose([
    Resize(feat_size),
    RandomResizedCrop(feat_size, scale=(0.8, 1.2)),
    ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

val_transforms = Compose([
    Resize(feat_size),
    ToTensor(),
    normalize,
])


def preprocess_train(example_batch):
    example_batch["pixel_values"] = [
        # Convert each image to RGB, apply the transforms,
        # and store the result under the key pixel_values
        train_transforms(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch

def preprocess_val(example_batch):
    example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
    return example_batch

Apply the preprocessing functions defined above as on-the-fly transforms on the datasets.

dataset["train"].set_transform(preprocess_train)
dataset["validation"].set_transform(preprocess_val)
labels = dataset["train"].features["label"].names
label2id, id2label = dict(), dict()
# Build mappings between label names and integer ids, e.g. 1 <-> '01'
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label
id2label[2]
'02'
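The folder names are zero-padded ('00', '01', ...), and that padding is what keeps the integer class index aligned with the numeric id. The sketch below assumes class names are ordered by sorting the folder names as strings; under that assumption, unpadded names would end up in a different order.

```python
# Hypothetical folder names for 12 classes, with and without zero padding.
unpadded = [str(i) for i in range(12)]
padded = [str(i).zfill(2) for i in range(12)]

sorted_unpadded = sorted(unpadded)  # string sort puts '10' and '11' before '2'
sorted_padded = sorted(padded)      # string sort matches numeric order

# With padded names, class index i maps back to the folder named str(i).zfill(2).
id2label = dict(enumerate(sorted_padded))
print(sorted_unpadded[:4], id2label[2])
```

With padded names, index 2 maps to '02', which id2food then resolves to 'liang_curry'.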

Fine-tune Our Model#

  • Load the model from the Hugging Face Hub.

  • Create the training arguments.

  • Then train and save the model.

from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# Choose a model from the Hugging Face Hub
model = AutoModelForImageClassification.from_pretrained(
    # start from the pretrained checkpoint
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes = True, # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)
Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-tiny-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([48, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([48]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
# Name the model and set the batch size
model_name = model_checkpoint.split("/")[-1]
batch_size = 32
# Training arguments passed to the Trainer
args = TrainingArguments(
    f"{model_name}-finetuned-eurosat",
    remove_unused_columns=False,
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)
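These arguments determine how many optimizer steps the run takes. The arithmetic below is a rough sketch that assumes a single device (for example one Colab GPU) and 11,520 training images. The warmup count is an approximation of what warmup_ratio implies, not a statement of how the library rounds it.

```python
import math

train_examples = 11520       # size of the training split
per_device_batch_size = 32   # per_device_train_batch_size
gradient_accumulation = 4    # gradient_accumulation_steps
num_epochs = 3               # num_train_epochs
warmup_ratio = 0.1
num_devices = 1              # assumption: a single GPU

# One optimizer step consumes this many images.
effective_batch_size = per_device_batch_size * gradient_accumulation * num_devices

steps_per_epoch = math.ceil(train_examples / effective_batch_size)
total_steps = steps_per_epoch * num_epochs
approx_warmup_steps = round(total_steps * warmup_ratio)

print(effective_batch_size, steps_per_epoch, total_steps, approx_warmup_steps)
```

The result, 90 steps per epoch and 270 steps in total, agrees with the 270/270 progress counter in the training log.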
import torch
import numpy as np

def collate_fn(examples):
    # Collate a list of examples into a batch: stack the images and the labels into tensors
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

def compute_metrics(eval_pred):
    # Compute the evaluation metric (accuracy) from the model's logits
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)
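What compute_metrics does can be reproduced without any library: take the index of the largest logit in each row as the prediction, then report the fraction of predictions that match the labels. The sketch below uses made-up logits for a three-class problem. The helper names are hypothetical.

```python
def argmax(values):
    """Index of the largest value in a list."""
    return max(range(len(values)), key=lambda i: values[i])


def compute_accuracy(logits, labels):
    """Fraction of rows whose highest-scoring class equals the true label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}


# Made-up logits: 4 examples, 3 classes.
logits = [
    [0.1, 2.0, -1.0],
    [1.5, 0.2, 0.3],
    [0.0, 0.1, 0.9],
    [0.3, 0.2, 0.1],
]
labels = [1, 0, 2, 1]
metrics = compute_accuracy(logits, labels)
print(metrics)
```

The returned dictionary has the same shape as the one the accuracy metric produces, which is why metric_for_best_model="accuracy" can refer to it by key.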
trainer = Trainer(
    model,
    args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)
trainer.train()
/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
[270/270 14:31, Epoch 3/3]
Epoch  Training Loss  Validation Loss  Accuracy
1      0.976400       0.627662         0.838889
2      0.432000       0.277169         0.919444
3      0.312500       0.223152         0.940972

TrainOutput(global_step=270, training_loss=1.0941412766774496, metrics={'train_runtime': 886.0021, 'train_samples_per_second': 39.007, 'train_steps_per_second': 0.305, 'total_flos': 8.601271252077773e+17, 'train_loss': 1.0941412766774496, 'epoch': 3.0})
trainer.save_model(f"trained/{model_name}")
# alternatively use trainer.push_to_hub instead
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
[45/45 00:18]
***** eval metrics *****
  epoch                   =        3.0
  eval_accuracy           =      0.941
  eval_loss               =     0.2232
  eval_runtime            = 0:00:19.29
  eval_samples_per_second =     74.646
  eval_steps_per_second   =      2.333

Prediction: Using Feature Extractor and Model#

Using the trained model for prediction involves AutoFeatureExtractor and AutoModelForImageClassification. The model can be loaded either from the folder where training saved it or from the Hugging Face Hub. In this example we load it from the folder where we saved the model.

We then read an image and convert it into the features the model expects before passing it to the model. The model's output can be used just like the output of any ordinary PyTorch model.

import requests
import torch
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
from PIL import Image

# read trained model from a folder (please double-check if you point to the correct path)
model_name = "./trained/swin-tiny-patch4-window7-224/"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
url = "https://github.com/GemmyTheGeek/FoodyDudy/raw/main/images/test/03/0289.jpg"
image = Image.open(requests.get(url, stream=True).raw)
encoding = feature_extractor(image.convert("RGB"), return_tensors="pt")
image
with torch.no_grad():
    outputs = model(**encoding)
    logits = outputs.logits
pred_idx = logits.argmax(-1).item()
print("Predicted class:", id2food[model.config.id2label[pred_idx]])
Predicted class: taohoo_moosup

Prediction: Using Pipeline API#

In addition, transformers provides pipeline, where we pick a pipeline type such as image-classification, which makes prediction more convenient.

from transformers import pipeline

pipe = pipeline("image-classification", model_name)
pipe(image)
[{'score': 0.9978020787239075, 'label': '03'},
 {'score': 0.0016931634163483977, 'label': '04'},
 {'score': 0.0001949473371496424, 'label': '10'},
 {'score': 8.342483488377184e-05, 'label': '02'},
 {'score': 7.1116337494459e-05, 'label': '26'}]
[{"score": l["score"], "label": id2food[l["label"]]} for l in pipe(image)]
[{'score': 0.9978020787239075, 'label': 'taohoo_moosup'},
 {'score': 0.0016931634163483977, 'label': 'mara_yadsai'},
 {'score': 0.0001949473371496424, 'label': 'palo_egg'},
 {'score': 8.342483488377184e-05, 'label': 'liang_curry'},
 {'score': 7.1116337494459e-05, 'label': 'tom_ka_gai'}]
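The scores in this output behave like probabilities. A plausible reading, stated here as an assumption rather than a description of the pipeline's internals, is that the logits are passed through a softmax and the top five classes are kept. The sketch below reproduces that idea with made-up logits for three classes. The function name is hypothetical.

```python
import math


def top_k(logits, id2label, k=5):
    """Softmax the logits and return the k most probable classes, best first."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # subtract the max for stability
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [{"score": probs[i], "label": id2label[i]} for i in ranked]


# Made-up logits for a three-class model.
id2label = {0: "00", 1: "01", 2: "02"}
results = top_k([1.0, 3.0, 2.0], id2label, k=2)
print(results)
```

The returned list has the same structure as the pipeline output, so the id2food lookup can be applied to it in the same way.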

Create Gradio Application for Prediction#

Finally, we can put all of the code together and deploy it as a Gradio application. We only need to write:

  • an inference function, which uses id2food to turn the predicted class into a dish name

  • the input, an image component: gr.inputs.Image()

  • the output, the predicted labels: gr.outputs.Label(num_top_classes=5)

  • gr.Interface, which ties everything together

import gradio as gr


food_list = [
    'green_curry', 'tepo_curry', 'liang_curry', 'taohoo_moosup', 'mara_yadsai',
    'masaman', 'orange_curry', 'cashew_chicken', 'omelette', 'sunny_side_up',
    'palo_egg', 'sil_egg', 'nun_banana', 'kua_gai', 'cabbage_fish_sauce',
    'river_prawn', 'shrimp_ob_woonsen', 'kanom_krok', 'mango_sticky_rice', 'kao_kamoo',
    'kao_klook_kapi', 'kaosoi', 'kao_pad', 'kao_pad_shrimp', 'chicken_rice',
    'kao_mok_gai', 'tom_ka_gai', 'tom_yum_kung', 'tod_mun', 'poh_pia',
    'pak_boong_fai_daeng', 'padthai', 'pad_krapao', 'pad_si_ew', 'pad_fakthong',
    'eggplant_stirfry', 'pad_hoi_lai', 'foithong', 'panaeng', 'yum_tua_ploo',
    'yum_woonsen', 'larb_moo', 'pumpkin_custard', 'sakoo_sai_moo', 'somtam',
    'moopoing','satay', 'hor_mok'
]
id2food = {str(i).zfill(2): f for i, f in enumerate(food_list)}


def inference(gr_input):
    """Inference function from gradio input."""
    image = Image.fromarray(gr_input.astype("uint8"), "RGB")
    predictions = pipe(image)
    predictions = {id2food[l["label"]]: l["score"] for l in predictions}
    return predictions
inputs = gr.inputs.Image()
outputs = gr.outputs.Label(num_top_classes=5)

interface = gr.Interface(
    fn=inference, inputs=inputs, outputs=outputs, interpretation="default",
).launch(debug=True)
/usr/local/lib/python3.10/dist-packages/gradio/inputs.py:259: UserWarning: Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/gradio/inputs.py:262: UserWarning: `optional` parameter is deprecated, and it has no effect
  super().__init__(
/usr/local/lib/python3.10/dist-packages/gradio/outputs.py:197: UserWarning: Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/gradio/outputs.py:200: UserWarning: The 'type' parameter has been deprecated. Use the Number component instead.
  super().__init__(num_top_classes=num_top_classes, type=type, label=label)
Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Note: opening Chrome Inspector may crash demo inside Colab notebooks.

To create a public link, set `share=True` in `launch()`.
Keyboard interruption in main thread... closing server.
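The conversion step inside the inference function can be checked on its own, without starting the Gradio server. The sketch below replaces the real pipeline with a hypothetical stub (fake_pipe) that returns two entries shaped like the pipeline output shown earlier. Only the dictionary conversion is exercised.

```python
def to_label_scores(predictions, id2food):
    """Convert pipeline-style predictions into a {dish name: score} mapping."""
    return {id2food[p["label"]]: p["score"] for p in predictions}


def fake_pipe(image):
    """Hypothetical stand-in for the image-classification pipeline."""
    return [
        {"score": 0.9978, "label": "03"},
        {"score": 0.0017, "label": "04"},
    ]


# Only the two ids used by the stub are needed for this check.
id2food = {"03": "taohoo_moosup", "04": "mara_yadsai"}
scores = to_label_scores(fake_pipe(None), id2food)
print(scores)
```

A mapping from label name to confidence is the format the Label output component displays, which is why the inference function returns a dictionary and not a list.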

Tutorial code prepared by: Dr. Titipat Achakulvisut (ดร. ฐิติพัทธ อัชชะกุลวิสุทธิ์)