
Build an Image Classifier with Vision Transformer

Last Updated on 11/10/2025 by Eran Feit

🧩 Introduction

Understanding How Vision Transformers Work in Image Classification

In this tutorial, we’ll dive into how to use the Vision Transformer (ViT) — a model that has changed how computers “see” images.
We’ll not only walk through a working Python example step-by-step, but also explain what makes the Vision Transformer image classification approach so effective.

What is a Vision Transformer (ViT)?

The Vision Transformer (ViT) is a model introduced by Google Research that applies the same transformer architecture used in Natural Language Processing (NLP) to computer vision.
Instead of processing the entire image at once (like CNNs do), ViT splits the image into small fixed-size patches — think of them like “tokens” in a sentence.
Each patch is then converted into an embedding, and the model learns to understand global relationships between all patches using self-attention mechanisms.
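
To make the patch idea concrete, here is a minimal arithmetic sketch (plain Python, written for this post rather than taken from the original script) for the `google/vit-base-patch16-224` checkpoint used later: a 224×224 RGB image cut into 16×16 patches gives 196 patch "tokens", each carrying 768 raw pixel values before the linear projection into the transformer's embedding space.

```python
# Patch arithmetic for ViT-Base/16 at 224x224 input (a back-of-the-envelope sketch)
image_size = 224        # input height/width expected by google/vit-base-patch16-224
patch_size = 16         # each patch is 16x16 pixels
channels = 3            # RGB

patches_per_side = image_size // patch_size            # 14
num_patches = patches_per_side ** 2                    # 196 patch "tokens"
values_per_patch = patch_size * patch_size * channels  # 768 raw pixel values per patch

print(num_patches, values_per_patch)  # 196 768
```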

Below is a placeholder for the ViT architecture diagram, showing how an image is divided into patches and processed through transformer layers:

[ViT architecture diagram]

The Vision Transformer is particularly powerful for image classification tasks because it can capture both local and global context in an image without relying on convolutional filters.
When trained or fine-tuned properly, ViTs often outperform CNNs on large datasets, especially when using transfer learning via pretrained models from Hugging Face Transformers.
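
Since we will rely on exactly this kind of pretrained checkpoint below, here is a hedged sketch of what transfer learning typically looks like with Hugging Face Transformers. The three class labels are placeholders invented for illustration, not part of this tutorial's dataset.

```python
from transformers import ViTForImageClassification

# Hypothetical 3-class problem: load the pretrained backbone and swap the
# 1000-class ImageNet head for a fresh one sized to our labels.
id2label = {0: "cat", 1: "dog", 2: "bird"}   # placeholder labels
label2id = {name: idx for idx, name in id2label.items()}

model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,  # the classifier head changes shape, so its pretrained weights are skipped
)

# From here you would fine-tune on your own images, e.g. with the Trainer API or a plain PyTorch loop.
```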

The main steps:

  • Read an image and convert it to the correct color space
  • Use ViTFeatureExtractor (for resizing and normalizing)
  • Load a pretrained ViTForImageClassification
  • Pass the image through the model to produce logits
  • Apply argmax to the logits to pick the predicted class index
  • Look up the class name, overlay it on the image, and display

We’ll divide the code into two parts: preprocessing + model setup, and inference + display.


Let’s dive into the ViT tutorial


You can watch the tutorial here : https://youtu.be/zGydLt2-ubQ

You can find the full code here : https://ko-fi.com/s/00de8ae6a8

You can follow my blog here : https://eranfeit.net/blog/


Preprocessing, Model Load & Setup

Here we prepare the image and get the model and feature extractor ready for inference.
This section ensures your input is in the correct format for ViT.

We will :

  • Import all required libraries
  • Load the input image using OpenCV
  • Convert it to the correct color format (RGB)
  • Initialize the ViT feature extractor to handle resizing, normalization, and tensor conversion
  • Load the pretrained ViT model for image classification

This prepares everything the model needs before we run inference.

This will be our test image :

Test image
```python
from transformers import ViTFeatureExtractor, ViTForImageClassification
import cv2

# Read the test image with OpenCV (loaded as BGR)
originalImage = cv2.imread("Best-image-classification-models/Visual-Langauge-Models/Vision Transformer-Vit/Dori.jpg")

### convert image from BGR (OpenCV) to RGB, because the model expects RGB
img = cv2.cvtColor(originalImage, cv2.COLOR_BGR2RGB)

# Create the feature extractor (for tasks like resizing, normalizing pixels, and preparing the image for the model)
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

# Load the pretrained ViT model for classification
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
```
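
If you want to verify the preprocessing before running inference, a quick shape check like the sketch below (an optional addition, not part of the original script) confirms that the extractor produces a single 3×224×224 tensor. Note that newer releases of Transformers expose the same functionality as `ViTImageProcessor`, so you may see a deprecation warning for `ViTFeatureExtractor`.

```python
# Optional sanity check: inspect what the feature extractor produces
preview = feature_extractor(images=img, return_tensors="pt")
print(preview["pixel_values"].shape)  # expected: torch.Size([1, 3, 224, 224])
```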

Summary:

  • We import necessary classes
  • We read the image with OpenCV, then convert to RGB
  • We instantiate the ViTFeatureExtractor, which handles resizing, normalization, and tensor conversion
  • We load the pretrained ViT classification model

Running Inference and Displaying Results

In this part, we actually run the image through the model, interpret the output, and display the result.

We will :

  • Feed the processed image into the model
  • Extract the logits (raw prediction scores)
  • Use argmax to select the most probable class index
  • Retrieve the class name using model.config.id2label
  • Overlay the prediction on the original image and display the result with OpenCV
```python
# extract the features
### prepare the image tensor (PyTorch) from the raw image via the extractor
inputs = feature_extractor(images=img, return_tensors="pt")

# pass inputs through the model to get outputs
outputs = model(**inputs)

# the logits hold the raw scores for each class
logits = outputs.logits

# pick the class index with the highest logit
predicted_class_idx = logits.argmax(-1).item()
print(predicted_class_idx)  # e.g. 155

className = model.config.id2label[predicted_class_idx]
print("Predicted class : " + className)

# overlay the predicted class name onto the original image
originalImage = cv2.putText(originalImage, className, (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 0), 3)

cv2.imshow("img", originalImage)
cv2.waitKey(0)
```
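
If you also want probabilities or the top few guesses instead of a single argmax, a small extension like the sketch below (an addition to the original script) applies softmax to the logits and prints the five highest-scoring labels.

```python
import torch

# Optional extension: convert logits to probabilities and list the top-5 predictions
probs = torch.softmax(logits, dim=-1)
top_probs, top_idxs = probs[0].topk(5)
for p, i in zip(top_probs.tolist(), top_idxs.tolist()):
    print(f"{model.config.id2label[i]}: {p:.3f}")
```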

Summary:

  • We call the feature extractor to get a PyTorch tensor input
  • Passing **inputs into the model returns an output object
  • We extract logits (raw class scores)
  • We pick the index of the highest logit (argmax)
  • We convert the index to a human-readable class name via model.config.id2label
  • We print the predicted index and class name
  • We overlay the label on the image and display it using OpenCV

Conclusion

In this blog post we walked through a real Python script for Vision Transformer image classification using Hugging Face.
We saw how to load an image, transform it properly, feed it into a pretrained ViT model, get logits, choose a class, and display the result.
You now understand each step, and can adapt this pipeline for your own dataset and images.

FAQ: Key Concepts, Tips & Common Issues

What is a Vision Transformer (ViT)?

A Vision Transformer (ViT) applies a transformer architecture to images by splitting them into patches treated like tokens for classification.

Why use ViT instead of CNNs?

ViTs are good at capturing long-range relationships between image patches and often outperform CNNs when pretrained on large datasets.

What does ViTFeatureExtractor do?

`ViTFeatureExtractor` resizes, normalizes, and converts an image into the correct tensor format expected by the model.

How should the input format look?

The input should be a PyTorch tensor, correctly normalized and shaped as expected by the ViT model.

How do you convert index to class name?

Use `model.config.id2label[predicted_index]` to map the highest-scoring index to its class label.

Why use `logits.argmax(-1)`?

Logits are raw class scores; `argmax(-1)` picks the index with the highest score as the predicted class.

What if the prediction is incorrect?

The model is pretrained on general datasets; for domain-specific images, consider fine-tuning the ViT model.

Can I process multiple images at once?

Yes — pass a list of images into the extractor and model to run batch inference.
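
A minimal batch-inference sketch (the file names are placeholders) could look like this:

```python
# Read several images, convert each to RGB, and run them through the model in one batch
paths = ["image1.jpg", "image2.jpg"]  # placeholder file names
batch = [cv2.cvtColor(cv2.imread(p), cv2.COLOR_BGR2RGB) for p in paths]

inputs = feature_extractor(images=batch, return_tensors="pt")  # pixel_values: [2, 3, 224, 224]
logits = model(**inputs).logits                                # [2, 1000]
for idx in logits.argmax(-1).tolist():
    print(model.config.id2label[idx])
```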

Is GPU necessary?

GPU speeds up inference, but small-scale tasks like this can run on CPU (though slower).
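
If a GPU is available, a small adjustment like this sketch moves the model and inputs onto it (and falls back to CPU otherwise):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
inputs = {k: v.to(device) for k, v in inputs.items()}  # move the pixel_values tensor as well
logits = model(**inputs).logits
```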

What errors often arise?

Common mistakes include wrong image shape, BGR vs RGB mismatch, or using inconsistent model/feature extractor versions.


Connect :

☕ Buy me a coffee — https://ko-fi.com/eranfeit

🖥️ Email : feitgemel@gmail.com

🌐 https://eranfeit.net

🤝 Fiverr : https://www.fiverr.com/s/mB3Pbb

Enjoy,

Eran

