
How to Train ConvNeXt in PyTorch on a Custom Dataset


Last Updated on 18/02/2026 by Eran Feit

ConvNeXt has become one of the most practical “modern CNN” choices when you want strong accuracy without giving up the speed and simplicity that make convolutional networks so useful in real projects. This article is about training ConvNeXt in PyTorch on a custom dataset—the kind you actually have in day-to-day work: folders of images organized by class names.

If you’ve ever fine-tuned a pre-trained model and felt unsure about the right pipeline—how to load data cleanly, how to apply augmentations without breaking labels, how to map class names to IDs, and how to know whether the model is truly learning—this guide is built for that exact moment. You’ll walk away with a workflow you can reuse for any classification problem, not just musical instruments.

The article shows how to take a ConvNeXt checkpoint from Hugging Face, connect it to a folder-based dataset with load_dataset("imagefolder"), and prepare images with practical transforms like random resized crops, flips, and normalization. You’ll also build PyTorch DataLoaders with a proper collate_fn, so batching works smoothly and training stays stable.

By the end, you’ll have a complete training loop using AdamW, plus tools that make the results meaningful: saving the best checkpoint, reloading it for inference on a single image, and evaluating performance with a confusion matrix to spot which classes are getting mixed up. That combination turns “I trained a model” into “I trust this model, and I can improve it.”

Training ConvNeXt on a Custom Dataset in PyTorch

Training ConvNeXt on a custom dataset is really about building a repeatable pipeline that turns raw images into a model you can evaluate, save, and use again later. The goal isn’t only to make the loss go down—it’s to create a training setup that remains reliable when your dataset changes, when you add new classes, or when you move from a small experiment to a bigger project.

A big part of the target is data consistency. Real-world datasets come in different resolutions, lighting conditions, and backgrounds, and that variability can confuse a model if you don’t handle it carefully. That’s why the preprocessing step matters: resizing/cropping to the input size ConvNeXt expects, normalizing with the same mean and standard deviation the model was trained with, and applying simple augmentations that help the network generalize instead of memorizing.

At a high level, the training flow is straightforward: load images from folders, convert class names into a stable label mapping, transform each image into pixel_values, and feed batches into ConvNeXt. From there, you fine-tune the classification head for your number of classes, optimize with AdamW, and track accuracy so you can tell whether improvements are real. Early stopping and checkpoint saving act like guardrails—preventing wasted epochs when learning stalls and preserving the best version of your model.

Finally, the payoff comes from evaluation that matches how you’ll use the model. Single-image inference gives you a realistic check of predictions, while a confusion matrix reveals patterns you can’t see from accuracy alone—like two visually similar instruments that ConvNeXt often confuses. That insight helps you make smarter next steps, whether it’s adjusting augmentations, balancing your dataset, or collecting a few more examples for the toughest classes.

ConvNeXt image classification

Building a Complete ConvNeXt Training Pipeline on Your Own Dataset

This tutorial is focused on the code-first goal: taking a folder-based custom dataset and turning it into a working ConvNeXt image classifier you can train, validate, save, reload, and evaluate in a repeatable way. Instead of skipping to a “magic trainer,” the code shows the full pipeline step by step so you understand what each block is doing and how to tweak it for your own projects.

At a high level, the target of the code is simple: fine-tune a pre-trained facebook/convnext-base-224-22k model so it can recognize your custom classes (in this tutorial, 30 musical instruments). That means replacing the original ImageNet-style classification head with a new head that matches your label set, then training on your own images until the model adapts to your domain.

To make training stable and realistic, the code builds a clean preprocessing and batching setup. Images are resized and randomly cropped to the expected input size, augmented with flips, converted to tensors, and normalized using the same mean and standard deviation the ConvNeXt checkpoint was trained with. Then a custom collate_fn stacks images into batches and packs labels into tensors, which is essential for smooth DataLoader behavior and consistent GPU training.

The training loop is designed around practical results, not just running epochs. It uses AdamW to update the weights, tracks loss and accuracy, and adds early stopping so training stops when validation loss stops improving. On top of that, the code saves the best checkpoint, which makes your final model reproducible—if you rerun training or continue fine-tuning later, you can always start from the best version you already found.

Finally, the tutorial connects training to real usage. It reloads the saved checkpoint, runs inference on a single image, and displays the predicted class so you can sanity-check outcomes visually. Then it adds a confusion matrix step to evaluate the full test set in a way that reveals which classes are being confused—exactly the kind of insight you need to improve the dataset, adjust augmentations, or refine your training strategy.

Link to the video tutorial here

You can download the code for the tutorial here or here

My Blog

You can follow my blog here.

Link for Medium users here

Want to get started with Computer Vision or take your skills to the next level?

Great Interactive Course: “Deep Learning for Images with PyTorch” here

If you’re just beginning, I recommend this step-by-step course designed to introduce you to the foundations of Computer Vision – Complete Computer Vision Bootcamp With PyTorch & TensorFlow

If you’re already experienced and looking for more advanced techniques, check out this deep-dive course – Modern Computer Vision GPT, PyTorch, Keras, OpenCV4


How to Train ConvNeXt in PyTorch on a Custom Dataset

This tutorial is about fine-tuning ConvNeXt in PyTorch on a custom dataset that lives in simple folders, so you can train a strong image classifier without designing a model from scratch.
You’ll take a pretrained facebook/convnext-base-224-22k checkpoint and adapt it to 30 musical instrument classes, using a clean Hugging Face + PyTorch workflow that scales to other datasets later.

If you’ve ever trained a model that looks “good” during training but fails on new images, this guide focuses on the practical pieces that usually make the difference: consistent preprocessing, realistic augmentations, correct label mapping, and saving the best checkpoint instead of “the last epoch.”
ConvNeXt is a great fit for this because it keeps CNN efficiency while borrowing modern design ideas that improved performance across many vision tasks.

You’ll also build the kind of pipeline you can reuse: load image folders with load_dataset, transform images with torchvision, feed batches with a DataLoader, fine-tune with AdamW, and stop early when validation stops improving.
By the end, you’ll be able to reload your best ConvNeXt checkpoint, run single-image inference, and interpret results with a confusion matrix.


Set up a clean ConvNeXt workspace (so training stays stable)

A ConvNeXt fine-tune usually fails for boring reasons: mismatched CUDA builds, mixed library versions, or installing packages into the wrong environment.
This section keeps things predictable by creating a dedicated Conda environment and pinning versions that work well together for PyTorch + Transformers training.

You’re also setting yourself up for repeatability.
When you revisit the project later to train on a new custom dataset, you want the same code to behave the same way, without chasing dependency issues.

Once this is done, everything else becomes a “model and data” problem instead of a “my machine is broken” problem.
That’s the fastest path to learning, debugging, and publishing reliable results.

### Create a new Conda environment for this project.
conda create -n ConvNeXt python=3.11

### Activate the environment so all installs happen inside it.
conda activate ConvNeXt

### Check your CUDA compiler version to install a compatible PyTorch build.
nvcc --version

### Install PyTorch 2.5.0 with CUDA 12.4 support.
conda install pytorch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 pytorch-cuda=12.4 -c pytorch -c nvidia

### Install Hugging Face Transformers with Torch extras for training vision models.
pip install transformers[torch]==4.46.1

### Install Hugging Face Datasets for loading folder datasets easily.
pip install datasets==3.2.0

### Install OpenCV for image handling when you want it.
pip install opencv-python==4.10.0.84

### Install Matplotlib for visualization during sanity checks and inference.
pip install matplotlib==3.10.0

### Install scikit-learn for evaluation tools like the confusion matrix.
pip install scikit-learn==1.6.0
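Before moving on, a quick optional check (a minimal sketch, run inside the activated ConvNeXt environment, not part of the original tutorial steps) confirms that the installed PyTorch build can actually see your GPU, so you don't discover a CPU-only install halfway through a long fine-tune.

### Verify the installed PyTorch version and CUDA visibility.
import torch

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))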

Want the exact dataset so your results match mine?

If you want to reproduce the same training flow and compare your results to mine, I can share the dataset structure and what I used in this tutorial.
Send me an email and mention “30 Musical Instruments CNN dataset” so I know what you’re requesting.

🖥️ Email: feitgemel@gmail.com


Short summary: You now have a clean environment that can fine-tune ConvNeXt in PyTorch without dependency chaos.


Turn your image folders into a real dataset with labels you can trust

ConvNeXt training is only as good as the dataset structure feeding it.
In this step, you load images from train, valid, and test folders using load_dataset("imagefolder"), which automatically assigns integer labels based on folder names.

You also do a small but important sanity check: display one image, inspect the dataset features, and print label names.
That prevents silent mistakes like swapped folders, missing classes, or a dataset that loads but doesn’t match what you think it contains.

Finally, you build id2label and label2id mappings.
This is what makes your ConvNeXt predictions readable and makes saved checkpoints consistent when you reload the model later.

### Import torch for tensor operations and GPU acceleration.
import torch
### Import load_dataset to build datasets directly from folder structures.
from datasets import load_dataset
### Import Matplotlib for quick visual checks of images.
import matplotlib.pyplot as plt
### Import os for building file paths safely.
import os

### Define the data files dictionary with multiple image formats.
data_files = {
    "train": os.path.join("D:/Data-Sets-Image-Classification/30 Musical Instruments/train", "**", "*.*"),
    "validation": os.path.join("D:/Data-Sets-Image-Classification/30 Musical Instruments/valid", "**", "*.*"),  # The folder on disk is named 'valid'
    "test": os.path.join("D:/Data-Sets-Image-Classification/30 Musical Instruments/test", "**", "*.*")
}

### Load the dataset from image folders into train, validation, and test splits.
dataset = load_dataset(
    "imagefolder",
    data_files=data_files,
    split={
        "train": "train",
        "validation": "validation",  # The split name stays "validation" because it is the standard term
        "test": "test"
    }
)

### Print the dataset object so you can confirm it loaded correctly.
print("dataset : ")
print(dataset)

### Print dataset details so you can confirm counts, features, and labels.
print("\nDataset information:")
for split_name, split_dataset in dataset.items():
    print(f"\n{split_name} split:")
    print(f"Number of samples: {len(split_dataset)}")
    print(f"Features: {split_dataset.features}")
    if 'label' in split_dataset.features:
        print(f"Labels: {split_dataset.features['label'].names}")

print("=================================================")

### Extract the first sample so you can sanity-check image and label values.
first_sample_image = dataset["train"][0]
print("Keys in first sample:", first_sample_image.keys())

### Read image and label from the first sample.
first_image = first_sample_image["image"]
first_label = first_sample_image["label"]
print("First image type:", type(first_image))

### Display the first image to verify the dataset is correct visually.
plt.imshow(first_image)
plt.axis("off")  # Turn off axis labels for better visualization
plt.title("First Image from Dataset : " + str(first_label))
plt.show()

print("=================================================")

### Extract label names so your model can map integers back to class names.
labels = dataset["train"].features["label"].names
print("--->>> Labels - list of the 30 classes :")
print(labels)
print("=======")

### Build id2label so predictions can be displayed as readable class names.
id2label = {k: v for k, v in enumerate(labels)}
print("id2label:")
print(id2label)
print("=======")

### Build label2id so the model configuration stays consistent.
label2id = {v: k for k, v in enumerate(labels)}
print("label2id:")
print(label2id)

### Print the label name of the first image for a quick sanity check.
print("Label name of the first image : " + id2label[first_label])
print("=================================================")
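If you suspect some classes have far fewer images than others, a short optional check (a minimal sketch building on the dataset and id2label defined above, not part of the original tutorial code) counts how many training images each class has, which is often the first clue when a few classes underperform later.

### Count images per class in the training split to spot imbalance.
from collections import Counter

train_label_counts = Counter(dataset["train"]["label"])
for label_id, count in sorted(train_label_counts.items()):
    print(f"{id2label[label_id]:<25} {count}")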

Short summary: You validated the dataset, confirmed labels, and created mappings that make ConvNeXt training and inference readable.


Make ConvNeXt-ready inputs with augmentations and DataLoaders

ConvNeXt expects images to be processed consistently, especially when it comes to normalization and input size.
This section loads the correct image processor for the pretrained checkpoint and uses it to build a transform pipeline that matches the model’s training expectations.

You also add practical augmentations like random resized crops and horizontal flips.
These augmentations matter because they teach ConvNeXt to handle small viewpoint changes instead of memorizing a fixed framing.

Finally, you wrap everything in PyTorch DataLoaders so training happens in batches.
That’s what makes training efficient, stable, and compatible with GPU acceleration.

### Import AutoImageProcessor so preprocessing matches the pretrained ConvNeXt checkpoint.
from transformers import AutoImageProcessor

### Load the image processor for the ConvNeXt model you plan to fine-tune.
image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k")

### Print the processor so you can see expected size, mean, and std.
print("============== image_processor :")
print(image_processor)

### Import torchvision transforms to build an augmentation pipeline.
from torchvision.transforms import (
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    ToTensor,
)

### Create a normalization layer using the processor's mean and std.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

### Build the training transform pipeline to resize, augment, tensorize, and normalize.
transform = Compose(
    [
        RandomResizedCrop(image_processor.size["shortest_edge"]),
        RandomHorizontalFlip(),
        ToTensor(),   # convert from Pillow to PyTorch tensor
        normalize     # use the normalize defined earlier
    ]
)

### Define a transform function that converts images to model-ready pixel_values.
def train_transforms(examples):
    examples["pixel_values"] = [transform(image.convert("RGB")) for image in examples["image"]]
    return examples

### Print the dataset again so you can compare before and after transform.
print("*********** Dataset : ")
print(dataset)

### Attach transforms to the dataset so items yield pixel_values during loading.
processed_dataset = dataset.with_transform(train_transforms)

### Inspect one transformed item to confirm pixel_values exist.
print("*************** processed_dataset[train][0] :")
print(processed_dataset["train"][0])

### Confirm the tensor shape matches ConvNeXt expectations.
print("*************** processed_dataset[train][0][pixel_values].shape :")
print(processed_dataset["train"][0]["pixel_values"].shape)

### Print keys to confirm pixel_values and label are present.
print("*************** processed_dataset[train][0].keys :")
print(processed_dataset["train"][0].keys())

### Import DataLoader to batch data efficiently during training.
from torch.utils.data import DataLoader

### Define a collate function to stack tensors and build label batches.
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

### Create a shuffled train DataLoader for stochastic gradient training.
train_dataloader = DataLoader(processed_dataset["train"], collate_fn=collate_fn, batch_size=8, shuffle=True)

### Create a validation DataLoader for evaluation runs.
val_dataloader = DataLoader(processed_dataset["validation"], collate_fn=collate_fn, batch_size=8, shuffle=False)

### Pull one batch to confirm shapes match what the model expects.
batch = next(iter(train_dataloader))
for k, v in batch.items():
    print(k, v.shape)
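One thing to be aware of: the pipeline above applies the same random crop and flip to every split. A common alternative, not used in this tutorial's code, is a deterministic transform for validation and test so evaluation numbers don't fluctuate between runs. A minimal sketch, assuming the same image_processor stats defined above:

### Deterministic evaluation transform: resize and center-crop instead of random crop/flip.
from torchvision.transforms import CenterCrop, Resize

eval_transform = Compose(
    [
        Resize(image_processor.size["shortest_edge"]),
        CenterCrop(image_processor.size["shortest_edge"]),
        ToTensor(),
        normalize,
    ]
)

### Apply it the same way train_transforms is applied, just without augmentation.
def eval_transforms(examples):
    examples["pixel_values"] = [eval_transform(image.convert("RGB")) for image in examples["image"]]
    return examples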

Short summary: You built ConvNeXt-compatible preprocessing, added augmentations, and created DataLoaders that feed batches into training.


Fine-tune ConvNeXt with AdamW, checkpoints, and early stopping

This is where the model starts adapting to your custom dataset.
You load the pretrained ConvNeXt checkpoint, replace the classification head to match your number of classes, and train with AdamW.

The training loop tracks both loss and accuracy, which helps you spot common issues early.
For example, if loss drops but accuracy stays flat, your labels, transforms, or mappings may not be correct.

You also save the best checkpoint based on validation loss and stop early when improvements stall.
That’s how you avoid “training forever” and end up with the best-performing version of ConvNeXt for your dataset.

### Import the ConvNeXt image classification model wrapper from Transformers.
from transformers import AutoModelForImageClassification

### Load the pretrained ConvNeXt model and adapt the head to your label count.
model = AutoModelForImageClassification.from_pretrained("facebook/convnext-base-224-22k",
                                                        id2label=id2label,
                                                        label2id=label2id,
                                                        ignore_mismatched_sizes=True)

### Import tqdm for clean progress bars during training.
from tqdm import tqdm
### Import os for checkpoint directory creation and path handling.
import os

### Define a directory to store model checkpoints.
save_dir = "d:/temp/models/ConvNext-30 Musical Instruments/checkpoints/"
### Create the directory if it does not exist.
os.makedirs(save_dir, exist_ok=True)

### Create an AdamW optimizer for stable fine-tuning.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

### Choose GPU if available for faster training.
device = "cuda" if torch.cuda.is_available() else "cpu"
### Move the model to the selected device.
model.to(device)

### Put the model into training mode before starting.
model.train()

### Track the best validation loss so you know when to save a checkpoint.
best_loss = float("inf")
### Track epochs without improvement for early stopping.
epochs_without_improvement = 0
### Stop if validation does not improve for this many epochs.
patience = 10
### Set a maximum epoch limit so training cannot run forever.
max_epochs = 100
### Define where to save the best model weights.
best_model_path = os.path.join(save_dir, "best_model.pth")

### Run the training loop across epochs.
for epoch in range(max_epochs):
    print(f"Epoch {epoch + 1}/{max_epochs}")
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    ### Reload the best weights each epoch so training continues from the best-known state.
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path))
        print("Loaded best model weights for fine-tuning.")

    ### Ensure the model is in train mode for gradient updates.
    model.train()
    for batch in tqdm(train_dataloader, desc="Training"):
        ### Move the current batch to the same device as the model.
        batch = {k: v.to(device) for k, v in batch.items()}

        ### Clear old gradients so they do not accumulate across steps.
        optimizer.zero_grad()
        ### Run a forward pass and compute loss using labels.
        outputs = model(pixel_values=batch["pixel_values"], labels=batch["labels"])
        loss, logits = outputs.loss, outputs.logits
        ### Backpropagate loss to compute gradients.
        loss.backward()
        ### Update model weights with AdamW.
        optimizer.step()

        ### Accumulate loss so you can compute average loss later.
        train_loss += loss.item()
        ### Track how many samples were seen for accuracy.
        train_total += batch["labels"].shape[0]
        ### Count correct predictions for accuracy calculation.
        train_correct += (logits.argmax(-1) == batch["labels"]).sum().item()

    ### Compute epoch-level training accuracy.
    train_accuracy = train_correct / train_total
    ### Compute average training loss across batches.
    avg_train_loss = train_loss / len(train_dataloader)
    print(f"Train Loss: {avg_train_loss:.4f} | Train Accuracy: {train_accuracy:.4f}")

    ### Switch to evaluation mode for validation.
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    ### Disable gradients during validation for speed and correctness.
    with torch.no_grad():
        for batch in tqdm(val_dataloader, desc="Validation"):
            ### Move validation batch to the device.
            batch = {k: v.to(device) for k, v in batch.items()}
            ### Run the model on validation data to compute validation loss.
            outputs = model(pixel_values=batch["pixel_values"], labels=batch["labels"])
            loss, logits = outputs.loss, outputs.logits

            ### Accumulate validation loss across batches.
            val_loss += loss.item()
            ### Track sample count for validation accuracy.
            val_total += batch["labels"].shape[0]
            ### Count correct validation predictions.
            val_correct += (logits.argmax(-1) == batch["labels"]).sum().item()

    ### Compute validation accuracy for this epoch.
    val_accuracy = val_correct / val_total
    ### Compute average validation loss for checkpoint decisions.
    avg_val_loss = val_loss / len(val_dataloader)
    print(f"Validation Loss: {avg_val_loss:.4f} | Validation Accuracy: {val_accuracy:.4f}")

    ### Save the model if validation loss improved.
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        epochs_without_improvement = 0
        torch.save(model.state_dict(), best_model_path)
        print(f"New best model saved with Validation Loss: {best_loss:.4f}")
    else:
        epochs_without_improvement += 1
        print(f"No improvement. Patience count: {epochs_without_improvement}/{patience}")

    ### Stop training if validation has not improved for too long.
    if epochs_without_improvement >= patience:
        print(f"Early stopping triggered after {patience} epochs without improvement.")
        break
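The loop above stores the best weights as a raw state_dict file. An optional extra step, shown here as a sketch rather than part of the original training code, is to also export the best model in Hugging Face format so the config with id2label travels with the weights; the export_dir name below is a hypothetical folder choice.

### After training, reload the best raw weights and export them in Hugging Face format.
model.load_state_dict(torch.load(best_model_path))
export_dir = os.path.join(save_dir, "best_model_hf")  # hypothetical export folder
model.save_pretrained(export_dir)
image_processor.save_pretrained(export_dir)

### Later, both can be reloaded in one line each.
# model = AutoModelForImageClassification.from_pretrained(export_dir)
# image_processor = AutoImageProcessor.from_pretrained(export_dir)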

Short summary: You fine-tuned ConvNeXt with AdamW, saved the best checkpoint, and stopped early to avoid overfitting and wasted compute.


Reload your best ConvNeXt checkpoint and predict one image

Training is only useful if you can reliably reuse the model later.
This section reloads your saved checkpoint, moves the model to GPU or CPU, and runs inference on a single test image.

Single-image inference is the fastest way to sanity-check real behavior.
If the model predicts the correct class on several varied examples, it’s usually a good sign your pipeline is healthy.

You also visualize the prediction directly on the image title.
That makes it easy to capture screenshots for your blog post and confirm results without digging through logs.

### Import torch for model loading and device placement.
import torch
### Import load_dataset to rebuild the dataset object and label mappings.
from datasets import load_dataset
### Import Matplotlib for displaying the prediction result.
import matplotlib.pyplot as plt
### Import PIL Image for reliable image conversion to RGB.
from PIL import Image
### Import os for building file paths.
import os

### Define the data files dictionary with multiple image formats.
data_files = {
    "train": os.path.join("D:/Data-Sets-Image-Classification/30 Musical Instruments/train", "**", "*.*"),
    "validation": os.path.join("D:/Data-Sets-Image-Classification/30 Musical Instruments/valid", "**", "*.*"),  # The folder on disk is named 'valid'
    "test": os.path.join("D:/Data-Sets-Image-Classification/30 Musical Instruments/test", "**", "*.*")
}

### Load the dataset again so labels match the original training setup.
dataset = load_dataset(
    "imagefolder",
    data_files=data_files,
    split={
        "train": "train",
        "validation": "validation",
        "test": "test"
    }
)
### Print the dataset to confirm it loaded successfully.
print("dataset : ")
print(dataset)

### Extract label names so predictions map to readable class names.
labels = dataset["train"].features["label"].names
print("--->>> Labels - list of the 30 classes :")
print(labels)
print("=======")

### Build id2label to decode predictions.
id2label = {k: v for k, v in enumerate(labels)}
print("id2label:")
print(id2label)
print("=======")
### Build label2id to keep model config aligned.
label2id = {v: k for k, v in enumerate(labels)}
print("label2id:")
print(label2id)

### Choose GPU if available for faster inference.
device = "cuda" if torch.cuda.is_available() else "cpu"

### Load the image processor so preprocessing matches ConvNeXt expectations.
from transformers import AutoImageProcessor
image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k")

### Import torchvision transforms used to preprocess the inference image.
from torchvision.transforms import (
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    ToTensor,
)

### Create normalization using the pretrained processor stats.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

### Define preprocessing transforms for the input image.
transform = Compose(
    [
        RandomResizedCrop(image_processor.size["shortest_edge"]),
        RandomHorizontalFlip(),
        ToTensor(),   # convert from Pillow to PyTorch tensor
        normalize     # use the normalize defined earlier
    ]
)

### Import the model class so you can rebuild ConvNeXt for your label set.
from transformers import AutoModelForImageClassification

### Initialize ConvNeXt with your label mappings so outputs decode correctly.
model = AutoModelForImageClassification.from_pretrained("facebook/convnext-base-224-22k",
                                                        id2label=id2label,
                                                        label2id=label2id,
                                                        ignore_mismatched_sizes=True)

### Define the checkpoint path that stores the best saved weights.
checkpoint_path = "D:/Temp/Models/ConvNext-30 Musical Instruments/checkpoints/best_model.pth"

### Load the saved weights from disk.
state_dict = torch.load(checkpoint_path)

### Apply the weights to the model.
model.load_state_dict(state_dict)

### Move the model to the inference device.
model.to(device)

### Switch to evaluation mode for deterministic inference.
model.eval()

### Define the path to a single test image.
new_image_path = "D:/Data-Sets-Image-Classification/30 Musical Instruments/test/clarinet/4.jpg"

### Read the image with Matplotlib so it is easy to display later.
image = plt.imread(new_image_path)
### Convert the image to RGB, transform it, add a batch dimension, and move it to the device.
input_image = transform(Image.fromarray(image).convert("RGB")).unsqueeze(0).to(device)

### Run inference without gradients.
with torch.no_grad():
    outputs = model(pixel_values=input_image)
    logits = outputs.logits
    predicted_label_id = logits.argmax(-1).item()
    predicted_label = id2label[predicted_label_id]

### Display the image and show the predicted label as the title.
plt.imshow(image)
plt.title(f"Predicted Label: {predicted_label}")
plt.axis("off")
plt.show()
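If you want more than the single best class, a small optional extension (a sketch building on the logits and id2label variables above) converts the logits to probabilities and prints the top five candidates, which is handy when two instruments look similar.

### Turn logits into probabilities and list the five most likely classes.
probs = torch.softmax(logits, dim=-1)[0]
top_probs, top_ids = torch.topk(probs, k=5)
for p, idx in zip(top_probs.tolist(), top_ids.tolist()):
    print(f"{id2label[idx]:<25} {p:.3f}")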

Short summary: You reloaded the best ConvNeXt checkpoint and verified real-world behavior with a single-image prediction.


Use a confusion matrix to see what ConvNeXt confuses

Accuracy alone hides the interesting failures.
A confusion matrix shows which instrument classes ConvNeXt mixes up, which is usually the fastest way to decide what to fix next.

This is especially useful with 30 classes, where “overall accuracy” can look fine while a few classes are consistently misclassified.
The confusion matrix makes those patterns visible immediately.

Once you know the confusion pairs, you can respond with targeted improvements.
That might mean adding more data for specific classes, adjusting augmentations, or checking if certain classes are visually too similar in your dataset.

### Import torch for model loading and inference on batches of images.
import torch
### Import load_dataset to iterate through the test split cleanly.
from datasets import load_dataset
### Import transforms to preprocess test images consistently.
from torchvision.transforms import Compose, Normalize, RandomHorizontalFlip, RandomResizedCrop, ToTensor
### Import image processor and model to rebuild ConvNeXt for evaluation.
from transformers import AutoImageProcessor, AutoModelForImageClassification
### Import Matplotlib for plotting the confusion matrix.
import matplotlib.pyplot as plt
### Import PIL Image for reading images from file paths.
from PIL import Image
### Import os for file path handling.
import os
### Import numpy for label and matrix utilities.
import numpy as np
### Import confusion matrix tools from scikit-learn.
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

### Define the data files dictionary with multiple image formats.
data_files = {
    "train": os.path.join("D:/Data-Sets-Image-Classification/30 Musical Instruments/train", "**", "*.*"),
    "validation": os.path.join("D:/Data-Sets-Image-Classification/30 Musical Instruments/valid", "**", "*.*"),
    "test": os.path.join("D:/Data-Sets-Image-Classification/30 Musical Instruments/test", "**", "*.*")
}

### Load the dataset splits so you can iterate through test images.
dataset = load_dataset(
    "imagefolder",
    data_files=data_files,
    split={
        "train": "train",
        "validation": "validation",
        "test": "test"
    }
)

### Extract the label list so confusion matrix axes are readable.
labels = dataset["train"].features["label"].names
### Build id2label mapping for consistent decoding.
id2label = {k: v for k, v in enumerate(labels)}
### Build label2id mapping for consistent configuration.
label2id = {v: k for k, v in enumerate(labels)}

### Choose GPU if available for faster evaluation.
device = "cuda" if torch.cuda.is_available() else "cpu"

### Load the image processor to match ConvNeXt preprocessing.
image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k")
### Build normalization based on the processor statistics.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
### Build preprocessing transforms for evaluation images.
transform = Compose([
    RandomResizedCrop(image_processor.size["shortest_edge"]),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize
])

### Initialize ConvNeXt with the correct label mappings.
model = AutoModelForImageClassification.from_pretrained(
    "facebook/convnext-base-224-22k",
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True
)
### Define the checkpoint path for the best saved model.
checkpoint_path = "D:/Temp/Models/ConvNext-30 Musical Instruments/checkpoints/best_model.pth"
### Load weights from disk.
state_dict = torch.load(checkpoint_path)
### Apply weights to the model.
model.load_state_dict(state_dict)
### Move model to the selected device.
model.to(device)
### Switch to eval mode for deterministic inference.
model.eval()

### Create lists to store true labels and predicted labels.
true_labels = []
predicted_labels = []

### Iterate through every test item to build confusion matrix inputs.
for item in dataset["test"]:
    image_path = item["image"].filename
    true_label = item["label"]

    ### Load the image and ensure RGB format.
    image = Image.open(image_path).convert("RGB")
    ### Preprocess the image and add a batch dimension.
    input_image = transform(image).unsqueeze(0).to(device)

    ### Run inference without gradients.
    with torch.no_grad():
        outputs = model(pixel_values=input_image)
        logits = outputs.logits
        predicted_label_id = logits.argmax(-1).item()

    ### Store labels for confusion matrix computation.
    true_labels.append(true_label)
    predicted_labels.append(predicted_label_id)

### Compute the confusion matrix across all classes.
cm = confusion_matrix(true_labels, predicted_labels, labels=list(range(len(labels))))
### Create a display helper with class names on axes.
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)

### Plot a larger confusion matrix so labels remain readable.
fig, ax = plt.subplots(figsize=(12, 12))
disp.plot(cmap=plt.cm.Blues, xticks_rotation=45, ax=ax)
plt.title("Confusion Matrix")
plt.show()
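Once cm is computed, a short optional extension (a sketch building on the cm and labels variables above, not part of the original script) lists the most frequent off-diagonal confusions, so you immediately know which true/predicted pairs deserve attention first.

### List the most common off-diagonal confusions from the matrix above.
confusions = []
for i in range(len(labels)):
    for j in range(len(labels)):
        if i != j and cm[i, j] > 0:
            confusions.append((cm[i, j], labels[i], labels[j]))

### Show the ten worst true -> predicted pairs.
for count, true_name, pred_name in sorted(confusions, reverse=True)[:10]:
    print(f"{true_name} predicted as {pred_name}: {count} times")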

Short summary: You evaluated ConvNeXt beyond accuracy and used a confusion matrix to identify exactly which classes need attention.


FAQ

What is ConvNeXt, and why use it for a custom dataset?

ConvNeXt is a modern CNN that borrows strong design ideas from Vision Transformers while staying efficient. It fine-tunes well on custom datasets with a straightforward PyTorch pipeline.

Why use facebook/convnext-base-224-22k for fine-tuning?

It starts from pretrained visual features learned on large-scale data. Fine-tuning adapts that knowledge to your new labels faster than training from scratch.

Do I need a specific folder structure for load_dataset imagefolder?

Yes. Each class gets its own folder, named after the class, with that class's images inside it. Separate train, valid, and test folders let you evaluate ConvNeXt fairly, as in the layout sketched below.
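For example, a layout like this (class names are illustrative) works directly with load_dataset("imagefolder"):

30 Musical Instruments/
    train/
        clarinet/
            001.jpg
            002.jpg
        violin/
            001.jpg
    valid/
        clarinet/
        violin/
    test/
        clarinet/
        violin/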

Why normalize images using the ConvNeXt processor stats?

Normalization keeps input distributions aligned with pretraining. That usually improves stability, convergence speed, and final accuracy.

What does ignore_mismatched_sizes=True fix?

It allows the model to swap the pretrained classification head for a new head sized to your classes. This is the simplest way to fine-tune ConvNeXt for new labels.

Why use RandomResizedCrop and RandomHorizontalFlip?

They add realistic variation so ConvNeXt doesn’t memorize exact framing. This often boosts generalization on real-world test images.

Why is AdamW a good default for fine-tuning?

AdamW applies weight decay in a cleaner way than classic Adam. It’s widely used for modern fine-tuning because it often stabilizes training.

What’s the point of checkpoints and early stopping?

Checkpoints keep the best model instead of the last model. Early stopping prevents wasted epochs once validation performance stops improving.

How does a confusion matrix help with 30 classes?

It shows exactly which instruments ConvNeXt confuses. That makes it easier to target data fixes and training tweaks where they matter.

What should I check first if predictions look wrong?

Verify label mappings and confirm preprocessing matches the ConvNeXt processor stats. Then test several images manually to confirm the pipeline end-to-end.
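As a concrete starting point, here is a minimal sketch of those checks, using the labels, label2id, id2label, model, and image_processor variables defined earlier in the tutorial:

### Confirm the label mappings round-trip cleanly.
assert all(label2id[id2label[i]] == i for i in range(len(labels)))

### Print the model's own mapping next to yours so you can compare them by eye.
print(model.config.id2label)
print(id2label)

### Confirm the normalization stats match the pretrained processor.
print(image_processor.image_mean, image_processor.image_std)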

Conclusion

Fine-tuning ConvNeXt in PyTorch on a custom dataset is one of the most practical ways to get strong image classification results without building a model from scratch.
Once your dataset is structured into folders and your preprocessing matches the pretrained checkpoint, ConvNeXt becomes a reliable base you can reuse across many projects.

The real win in this tutorial is the full workflow: load data cleanly, apply realistic augmentations, train with AdamW, save the best checkpoint, and stop early when validation stops improving.
That combination usually produces a model you can actually trust on new images, not just one that looks good in training logs.

Finally, the confusion matrix gives you a clear path for iteration.
Instead of guessing why accuracy stalls, you can see which classes are confused and decide whether you need more images, better class separation, or small changes in augmentation and training settings.

Connect :

☕ Buy me a coffee — https://ko-fi.com/eranfeit

🖥️ Email : feitgemel@gmail.com

🌐 https://eranfeit.net

🤝 Fiverr : https://www.fiverr.com/s/mB3Pbb

Enjoy,

Eran
