...

How To Actually Use MobileNetV3 for a Fish Classifier

Classify images with MobileNetV3 and transfer learning

This is a transfer learning tutorial for image classification with TensorFlow. It leverages the pre-trained MobileNetV3 model to improve the accuracy of image classification tasks.

By employing transfer learning with MobileNetV3 in TensorFlow, image classification models can achieve improved performance with reduced training time and computational resources.

We’ll go step-by-step through:

  • Splitting a fish dataset for training & validation 
  • Applying transfer learning with MobileNetV3-Large 
  • Training a custom image classifier using TensorFlow
  • Predicting new fish images using OpenCV 
  • Visualizing results with confidence scores

👉 Watch the full tutorial here : https://youtu.be/12GvOHNc5DI

You can download the code here : https://ko-fi.com/s/4d8e51bca8

You can download the dataset here : https://www.kaggle.com/datasets/crowww/a-large-scale-fish-dataset

You can find more tutorials, and join my newsletter here : https://eranfeit.net/


Installation

The code is based on these components :

# ─────────────────────────────
# Environment requirements
# ─────────────────────────────
# pip install tensorflow==2.10         # deep-learning framework
# pip install numpy                    # numerical arrays
# pip install opencv-python            # image I/O + basic CV
# Python version : 3.9.16
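
Before running the tutorial scripts, a minimal sanity-check sketch (assuming the packages above are installed) can confirm the environment :

# Quick environment check - prints the installed library versions
# (assumes tensorflow, numpy and opencv-python are installed as listed above)
import tensorflow as tf
import numpy as np
import cv2

print("TensorFlow :", tf.__version__)   # expected : 2.10.x
print("NumPy      :", np.__version__)
print("OpenCV     :", cv2.__version__)
print("GPU devices:", tf.config.list_physical_devices("GPU"))  # empty list means CPU-only training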


Part 1 : Prepare the data

This Python script prepares a fish image dataset for training and validation, suitable for use with TensorFlow or any deep learning model. It works with the Kaggle dataset “A Large Scale Fish Dataset,” where each fish species is stored in its own subfolder.

The script performs the following steps:

  1. Scans all class folders (each representing a fish species).
  2. Randomly splits the images in each class into a training set (default 85%) and a validation set (15%).
  3. Creates a new folder structure with `/train` and `/validate` subdirectories, each containing one folder per class.
  4. Copies the respective image files into their destination folders, skipping files with 0 size.

This structure is ideal for training a fish classification model with Keras utilities such as `ImageDataGenerator` and its `flow_from_directory()` method.

It ensures your dataset is clean, balanced, and organized before starting the model training phase.

# Import required libraries
import os          # For file system operations
import random      # For random sampling of files
import shutil      # For copying files between folders

# Define the percentage of data to be used for training
splitsize = .85

# Initialize list to store class/category names
categories = []

# Path to the original dataset where each subfolder is a fish class
source_folder = "E:/Data-sets/A Large Scale Fish Dataset/Fish_Dataset/Fish_Dataset"

# List all the subfolders (fish species names)
folders = os.listdir(source_folder)
print(folders)

# Filter only directories and store as categories
for subfolder in folders:
    if os.path.isdir(source_folder + "/" + subfolder):
        categories.append(subfolder)

# Sort category names alphabetically for consistency
categories.sort()
print(categories)

# Define the path to store the processed dataset
target_folder = "E:/Data-sets/A Large Scale Fish Dataset/Fish_Dataset/dataset_for_model"

# Check if the output folder exists, and create it if not
existDataSetPath = os.path.exists(target_folder)
if existDataSetPath == False:
    os.mkdir(target_folder)

# ------------------------------
# Function to split data into train and validation sets
# ------------------------------
def split_data(SOURCE, TRAINING, VALIDATION, SPLIT_SIZE):
    files = []  # List to store non-empty image files

    # Loop through all files in the source folder
    for filename in os.listdir(SOURCE):
        file = SOURCE + filename
        print(file)

        # Filter out empty files
        if os.path.getsize(file) > 0:
            files.append(filename)
        else:
            print(filename + " is 0 length, ignoring it ....")

    print(len(files))  # Print number of valid images

    # Shuffle and split the files
    trainingLength = int(len(files) * SPLIT_SIZE)
    shuffleSet = random.sample(files, len(files))
    trainingSet = shuffleSet[0:trainingLength]
    validSet = shuffleSet[trainingLength:]

    # Copy training files
    for filename in trainingSet:
        thisFile = SOURCE + filename
        destination = TRAINING + filename
        shutil.copyfile(thisFile, destination)

    # Copy validation files
    for filename in validSet:
        thisFile = SOURCE + filename
        destination = VALIDATION + filename
        shutil.copyfile(thisFile, destination)

# Define the training folder path
trainPath = target_folder + "/train"
print(trainPath)

# Define the validation folder path
validatePath = target_folder + "/validate"

# Create training folder if it doesn't exist
existsDataSetPath = os.path.exists(trainPath)
print(existsDataSetPath)
if not existsDataSetPath:
    os.mkdir(trainPath)

# Create validation folder if it doesn't exist
existsDataSetPath = os.path.exists(validatePath)
if existsDataSetPath == False:
    os.mkdir(validatePath)

# ------------------------------
# Loop through all categories and split their data
# ------------------------------
for category in categories:
    # Define paths for this category in train and validate folders
    trainDestPath = trainPath + "/" + category
    validateDestPath = validatePath + "/" + category

    print(trainDestPath)

    # Create class folder inside train if it doesn't exist
    if os.path.exists(trainDestPath) == False:
        os.mkdir(trainDestPath)

    # Create class folder inside validate if it doesn't exist
    if os.path.exists(validateDestPath) == False:
        os.mkdir(validateDestPath)

    # Define source path for current class
    sourcePath = source_folder + "/" + category + "/"
    trainDestPath = trainDestPath + "/"
    validateDestPath = validateDestPath + "/"

    # Print copy operation log
    print("Copy from : " + sourcePath + " to : " + trainDestPath + " and " + validateDestPath)

    # Split and copy the data
    split_data(sourcePath, trainDestPath, validateDestPath, splitsize)

You can download the code here : https://ko-fi.com/s/4d8e51bca8


Part 2 : Build The Model

This script builds a fish species classifier using Transfer Learning with the MobileNetV3-Large architecture pre-trained on ImageNet.

The pipeline includes:

  1. Data Augmentation using ImageDataGenerator for both training and validation sets, with random rotation, shifting, and brightness changes to improve generalization.
  2. Feature Extraction using MobileNetV3Large as the base model, excluding its top classification layers.
  3. Custom Classification Head with several fully connected layers and a softmax output for multi-class prediction across 9 fish species.
  4. Freezing Pretrained Layers to retain knowledge from ImageNet and train only the final layers.
  5. Compilation & Training with the Adam optimizer and categorical crossentropy loss.
  6. Model Saving to a .h5 file for future inference or fine-tuning.

This model is lightweight, efficient, and optimized for deployment on resource-constrained devices while achieving high accuracy.

# Import necessary modules for building the model
from tensorflow.keras import Model
from tensorflow.keras.applications import MobileNetV3Large
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam

# Define paths to the train and validation datasets
trainPath = "E:/Data-sets/A Large Scale Fish Dataset/Fish_Dataset/dataset_for_model/train"
ValidPath = "E:/Data-sets/A Large Scale Fish Dataset/Fish_Dataset/dataset_for_model/validate"

# Create data generators with augmentation for the training set
trainGenerator = ImageDataGenerator(
    rotation_range=15,                    # Random rotation
    width_shift_range=0.1,                # Horizontal shift
    height_shift_range=0.1,               # Vertical shift
    brightness_range=(0, 0.2)             # Random brightness
).flow_from_directory(
    trainPath,
    target_size=(320, 320),               # Resize images to 320x320
    batch_size=32                         # Batch size for training
)

# Create validation generator with the same augmentations (optional for validation)
ValidGenerator = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    brightness_range=(0, 0.2)
).flow_from_directory(
    ValidPath,
    target_size=(320, 320),
    batch_size=32
)

# ------------------------------
# Load the MobileNetV3Large model as a base
# ------------------------------
baseModel = MobileNetV3Large(
    weights="imagenet",            # Use pretrained ImageNet weights
    include_top=False              # Exclude the default classifier head
)

# ------------------------------
# Add custom classification layers on top of the base model
# ------------------------------
x = baseModel.output
x = GlobalAveragePooling2D()(x)       # Convert feature maps to a flat vector
x = Dense(512, activation='relu')(x)  # First dense layer
x = Dense(256, activation='relu')(x)  # Second dense layer
x = Dense(128, activation='relu')(x)  # Third dense layer

# Final classification layer - softmax for 9 classes (fish species)
predictionLayer = Dense(9, activation='softmax')(x)

# Create the final model by connecting input and output
model = Model(inputs=baseModel.input, outputs=predictionLayer)

# Print a summary of the model architecture
print(model.summary())

# ------------------------------
# Freeze all layers in the base model except the last few
# ------------------------------
for layer in model.layers[:-5]:
    layer.trainable = False  # Freeze early layers to keep pretrained weights

# ------------------------------
# Compile the model
# ------------------------------
optimizer = Adam(learning_rate=0.0001)   # Use Adam optimizer with a low learning rate
model.compile(
    loss="categorical_crossentropy",     # Suitable for multi-class classification
    optimizer=optimizer,
    metrics=['accuracy']                 # Track accuracy during training
)

# ------------------------------
# Train the model using the data generators
# ------------------------------
model.fit(
    trainGenerator,
    validation_data=ValidGenerator,
    epochs=5                             # Train for 5 epochs
)

# ------------------------------
# Save the trained model to disk
# ------------------------------
modelSavedPath = "E:/Data-sets/A Large Scale Fish Dataset/Fish_Dataset/dataset_for_model/FishV3.h5"
model.save(modelSavedPath)
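
Because the model is saved to FishV3.h5 for later inference or fine-tuning, here is an optional sketch (assuming the training script above has already produced the file) that reloads it and evaluates it on the validation folder :

# Optional : reload the saved model and evaluate it on the validation set
# (assumes FishV3.h5 was created by the training script above)
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

modelSavedPath = "E:/Data-sets/A Large Scale Fish Dataset/Fish_Dataset/dataset_for_model/FishV3.h5"
ValidPath = "E:/Data-sets/A Large Scale Fish Dataset/Fish_Dataset/dataset_for_model/validate"

model = tf.keras.models.load_model(modelSavedPath)

# A plain generator (no augmentation) is enough for evaluation
evalGenerator = ImageDataGenerator().flow_from_directory(
    ValidPath,
    target_size=(320, 320),
    batch_size=32,
    shuffle=False
)

loss, accuracy = model.evaluate(evalGenerator)
print("Validation loss :", loss)
print("Validation accuracy :", accuracy)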

You can download the code here : https://ko-fi.com/s/4d8e51bca8


Part 3 : Test The model

This is our test image :

[Test image : Sea-Bass-test.jpg]

This script loads a trained deep learning model based on MobileNetV3 and classifies a given fish image into one of the predefined fish species. It is part of a transfer learning pipeline for fish recognition.

Here’s what the script does:

  1. Reads the list of fish species from the training dataset folders to determine label names.
  2. Loads a previously saved `.h5` Keras model.
  3. Defines a `classify_image()` function that:
    — Opens and resizes the image.
    — Converts it into a TensorFlow-compatible array.
    — Predicts the class using the model.
    — Maps the prediction to the corresponding class name.
  4. Loads a test image and runs classification.
  5. Displays the image using OpenCV, overlaid with the predicted class name.

This allows for fast, real-time image classification using a lightweight model, and can be extended to video streams or deployment on mobile devices.

# Import required libraries
import os
from tensorflow.keras.preprocessing import image     # For image preprocessing
from PIL import Image                                 # For opening and resizing images
import numpy as np
import tensorflow as tf                               # TensorFlow to load the model
import cv2                                            # OpenCV to show image with prediction

# ------------------------------
# Load class labels from the training directory
# ------------------------------
categories = os.listdir("E:/Data-sets/A Large Scale Fish Dataset/Fish_Dataset/dataset_for_model/train")
categories.sort()     # Ensure labels are in the same order as during training
print(categories)

# ------------------------------
# Load the trained MobileNetV3 model
# ------------------------------
modelSavedPath = "E:/Data-sets/A Large Scale Fish Dataset/Fish_Dataset/dataset_for_model/FishV3.h5"
model = tf.keras.models.load_model(modelSavedPath)

# ------------------------------
# Define a function to classify a single image
# ------------------------------
def classify_image(imageFile):
    # Open the image using PIL
    img = Image.open(imageFile)
    img.load()

    # Resize to match model input size (320x320)
    # Note : Image.LANCZOS is the modern name for the deprecated Image.ANTIALIAS filter
    img = img.resize((320, 320), Image.LANCZOS)

    # Convert the image to a NumPy array
    x = image.img_to_array(img)

    # Add batch dimension: (1, 320, 320, 3)
    x = np.expand_dims(x, axis=0)
    print(x.shape)  # Debug: check input shape

    # Run the prediction
    pred = model.predict(x)
    print(pred)     # Debug: print raw output

    # Get index of highest confidence class
    categoryValue = np.argmax(pred, axis=1)[0]
    print(categoryValue)  # Debug: predicted class index

    # Map the index to the class name
    result = categories[categoryValue]
    return result

# ------------------------------
# Path to a test image for classification
# ------------------------------
img_path = "Best-image-classification-models/Classify-Images-Transfer-Learning-MobileNet-V3/Sea-Bass-test.jpg"

# Classify the image and print the result
resultText = classify_image(img_path)
print(resultText)

# ------------------------------
# Display the image with prediction using OpenCV
# ------------------------------
img = cv2.imread(img_path)

# Put the predicted label on the image
img = cv2.putText(img, resultText, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)

# Show the image in a window
cv2.imshow('img', img)
cv2.waitKey(0)
cv2.destroyAllWindows()
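
As mentioned above, the same logic can be extended to video streams. Below is a hedged sketch of that idea; the camera index and per-frame preprocessing are illustrative assumptions, not part of the original tutorial :

# Optional : classify frames from a video stream with the same trained model
# (illustrative sketch - camera index 0 and frame preprocessing are assumptions)
import os
import cv2
import numpy as np
import tensorflow as tf

categories = sorted(os.listdir("E:/Data-sets/A Large Scale Fish Dataset/Fish_Dataset/dataset_for_model/train"))
model = tf.keras.models.load_model("E:/Data-sets/A Large Scale Fish Dataset/Fish_Dataset/dataset_for_model/FishV3.h5")

cap = cv2.VideoCapture(0)               # 0 = default camera; a video file path also works
while True:
    ret, frame = cap.read()
    if not ret:
        break

    # Resize and convert BGR (OpenCV) to RGB to match the training input
    rgb = cv2.cvtColor(cv2.resize(frame, (320, 320)), cv2.COLOR_BGR2RGB)
    x = np.expand_dims(rgb.astype("float32"), axis=0)

    # Predict the class of the current frame
    pred = model.predict(x, verbose=0)
    label = categories[int(np.argmax(pred, axis=1)[0])]

    # Overlay the predicted label and show the frame
    cv2.putText(frame, label, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
    cv2.imshow("video", frame)
    if cv2.waitKey(1) & 0xFF == ord("q"):   # Press q to quit
        break

cap.release()
cv2.destroyAllWindows()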

You can download the code here : https://ko-fi.com/s/4d8e51bca8


Connect :

☕ Buy me a coffee — https://ko-fi.com/eranfeit

🖥️ Email : feitgemel@gmail.com

🌐 https://eranfeit.net

🤝 Fiverr : https://www.fiverr.com/s/mB3Pbb

Enjoy,

Eran
