Image Segmentation Using TensorFlow
Image segmentation is a computer vision technique that partitions an image into meaningful regions. Instead of treating the picture as a whole, it examines every pixel and decides which category that pixel belongs to, such as the pet, the pet's outline or the background. The main goal is to assign a label to every pixel so that similar pixels are grouped together. This way, a computer knows not just what is in the image but exactly where it is.
- In regular classification, the model assigns a single label to the whole picture (like "cat" or "dog").
- In object detection, the model draws bounding boxes around the things it finds.
- In segmentation, the model traces the exact shape of each object, pixel by pixel, as the toy example below illustrates.
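To make the per-pixel idea concrete, here is a toy 4x4 label mask (purely illustrative; the class ids here are arbitrary but mirror the three-class pet/background/outline split used later in this tutorial):
Python
import numpy as np

# One class id per pixel: 0 = pet, 1 = background, 2 = outline (illustrative ids)
toy_mask = np.array([
    [1, 2, 2, 1],
    [2, 0, 0, 2],
    [2, 0, 0, 2],
    [1, 2, 2, 1],
])
print(toy_mask)  # every pixel carries exactly one label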
Step-by-Step Image Segmentation
Let's walk through image segmentation using TensorFlow, step by step.
Step 1: Import Libraries
We import the required libraries:
- numpy: For fast array calculations.
- matplotlib.pyplot: To visualize images and masks.
- tensorflow: The main deep learning framework.
- tensorflow_datasets: Provides ready-to-use datasets such as Oxford-IIIT Pet.
- keras: Streamlines model building.
Python
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
Step 2: Load the Dataset
We load and split the Oxford-IIIT Pet data.
Python
dataset, info = tfds.load('oxford_iiit_pet:4.*.*', with_info=True)
Output:
Loading the Oxford-IIIT dataset

Step 3: Set Constants
We set the constants that will be used throughout:
- BATCH_SIZE and BUFFER_SIZE control training efficiency and shuffling randomness.
- width and height standardize images to 224 x 224, the input size VGG16 was trained on.
Python
BATCH_SIZE = 64
BUFFER_SIZE = 1000
width, height = 224, 224
TRAIN_LENGTH = info.splits['train'].num_examples
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE
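As a quick sanity check (not part of the original listing), you can print what these constants work out to; the counts assume the standard TFDS split of the Oxford-IIIT Pet dataset:
Python
print("Training examples:", TRAIN_LENGTH)   # 3680 for the standard split
print("Steps per epoch:", STEPS_PER_EPOCH)  # 3680 // 64 = 57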
Step 4: Data Preprocessing and Augmentation
We define the preprocessing and augmentation applied to each sample:
- Image pixels are converted to float and scaled to the 0-1 range.
- Mask labels are shifted down by one so classes start from zero, for correct class indexing.
- Images and masks are resized to 224 x 224; the mask uses nearest-neighbour interpolation so its labels stay whole numbers.
- A random horizontal flip adds variety for robust training; load_test_ds applies the same steps without the flip, so evaluation data is never augmented.
Python
def normalize(input_image, input_mask):
    # Scale pixels to [0, 1]; shift mask labels {1, 2, 3} down to {0, 1, 2}
    img = tf.cast(input_image, dtype=tf.float32) / 255.0
    input_mask -= 1
    return img, input_mask

@tf.function
def load_train_ds(example):
    img = tf.image.resize(example['image'], (width, height))
    # Nearest-neighbour resizing keeps the mask labels integral
    mask = tf.image.resize(example['segmentation_mask'],
                           (width, height), method='nearest')
    if tf.random.uniform(()) > 0.5:  # random horizontal flip (augmentation)
        img = tf.image.flip_left_right(img)
        mask = tf.image.flip_left_right(mask)
    img, mask = normalize(img, mask)
    return img, mask

@tf.function
def load_test_ds(example):
    # Same preprocessing, but no augmentation for evaluation data
    img = tf.image.resize(example['image'], (width, height))
    mask = tf.image.resize(example['segmentation_mask'],
                           (width, height), method='nearest')
    return normalize(img, mask)
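To verify the preprocessing (a quick check, not in the original listing), pull a single example through load_train_ds and inspect the value ranges:
Python
for example in dataset['train'].take(1):
    img, mask = load_train_ds(example)
    # Expect (224, 224, 3) with pixel values inside [0, 1]
    print(img.shape, float(tf.reduce_min(img)), float(tf.reduce_max(img)))
    print(np.unique(mask.numpy()))  # class ids 0, 1, 2 after the -1 shift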
Step 5: Build Data Pipelines
We prepare the data pipelines:
- map: Applies preprocessing to each sample (in parallel, via num_parallel_calls).
- cache, shuffle, batch, repeat, prefetch: Optimize data loading and training throughput.
Python
train = dataset['train'].map(
    load_train_ds, num_parallel_calls=tf.data.AUTOTUNE)
test = dataset['test'].map(load_test_ds)

train_ds = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_ds = train_ds.prefetch(buffer_size=tf.data.AUTOTUNE)
test_ds = test.batch(BATCH_SIZE)
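A quick peek at one batch (illustrative, not in the original article) confirms the shapes the model will see:
Python
imgs, masks = next(iter(train_ds))
print(imgs.shape)   # (64, 224, 224, 3)
print(masks.shape)  # (64, 224, 224, 1)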
Step 6: Visualize the Data
We visualize the input image, the ground-truth mask and (later) the predicted mask side by side for easy comparison.
Python
def display_images(display_list):
    plt.figure(figsize=(15, 15))
    titles = ['Input Image', 'True Mask', 'Predicted Mask']
    for i, image in enumerate(display_list):
        plt.subplot(1, len(display_list), i + 1)
        plt.title(titles[i])
        plt.imshow(keras.preprocessing.image.array_to_img(image))
        plt.axis('off')
    plt.show()

for img, mask in train.take(1):
    display_images([img, mask])
Output:
Input Image

Step 7: Model Construction
We build the model from a VGG16 encoder and an FCN-8-style decoder:
- The pre-trained VGG16 acts as the feature extractor.
- Only the essential intermediate pooling layers are exposed as outputs.
- Freezing the VGG16 weights keeps transfer learning stable.
- The decoder upsamples the deep features, merges them with skip connections and produces pixel-wise class probabilities.
Python
base_model = keras.applications.vgg16.VGG16(
    include_top=False, input_shape=(width, height, 3))
layer_names = ['block1_pool', 'block2_pool',
               'block3_pool', 'block4_pool', 'block5_pool']
base_model_outputs = [base_model.get_layer(name).output
                      for name in layer_names]
base_model.trainable = False  # freeze the pre-trained encoder
VGG_16 = keras.Model(inputs=base_model.input, outputs=base_model_outputs)

def fcn8_decoder(convs, n_classes):
    f1, f2, f3, f4, p5 = convs
    n = 4096
    # Fully convolutional replacement for VGG's dense layers
    c6 = keras.layers.Conv2D(n, (7, 7), activation='relu', padding='same')(p5)
    c7 = keras.layers.Conv2D(n, (1, 1), activation='relu', padding='same')(c6)
    f5 = c7
    # Upsample 7x7 -> 14x14 and fuse with the block4_pool skip connection
    o = keras.layers.Conv2DTranspose(
        n_classes, (4, 4), strides=(2, 2), use_bias=False)(f5)
    o = keras.layers.Cropping2D((1, 1))(o)
    o2 = keras.layers.Conv2D(
        n_classes, (1, 1), activation='relu', padding='same')(f4)
    o = keras.layers.Add()([o, o2])
    # Upsample 14x14 -> 28x28 and fuse with the block3_pool skip connection
    o = keras.layers.Conv2DTranspose(
        n_classes, (4, 4), strides=(2, 2), use_bias=False)(o)
    o = keras.layers.Cropping2D((1, 1))(o)
    o2 = keras.layers.Conv2D(
        n_classes, (1, 1), activation='relu', padding='same')(f3)
    o = keras.layers.Add()([o, o2])
    # Final x8 upsampling: 28x28 -> 224x224, one score per class per pixel
    o = keras.layers.Conv2DTranspose(
        n_classes, (8, 8), strides=(8, 8), use_bias=False)(o)
    o = keras.layers.Activation('softmax')(o)
    return o
Output:
Building the Model

Step 8: Build and Compile Segmentation Model
We now connect the encoder and decoder into a single trainable segmentation network and compile it.
Python
def segmentation_model():
    inputs = keras.layers.Input(shape=(width, height, 3))
    convs = VGG_16(inputs)
    outputs = fcn8_decoder(convs, 3)
    return keras.Model(inputs, outputs)

model = segmentation_model()
model.compile(
    optimizer=keras.optimizers.Adam(),
    # The decoder ends in a softmax, so the outputs are probabilities, not logits
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy']
)
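Before training, it is worth confirming that the network really maps a 224 x 224 image to a per-pixel, 3-class prediction (a quick check, not part of the original listing):
Python
print(model.output_shape)  # (None, 224, 224, 3): one probability vector per pixel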
Step 9: Train the Model
We train the model for 15 epochs, validating on a subset of the test set after each epoch.
Python
EPOCHS = 15
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples // BATCH_SIZE // VAL_SUBSPLITS

model_history = model.fit(
    train_ds, epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=test_ds,
    validation_steps=VALIDATION_STEPS
)
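fit returns a history object holding the per-epoch metrics; a minimal sketch for plotting the loss curves once training completes (illustrative only):
Python
plt.plot(model_history.history['loss'], label='training loss')
plt.plot(model_history.history['val_loss'], label='validation loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()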
Output:
Training the Model

Step 10: Predict and Visualize the Results
The model makes predictions and we visualize them:
- create_mask converts the per-pixel class probabilities into a simple single-channel mask for visualization.
- show_predictions displays results for sample images to verify segmentation quality.
Python
def create_mask(pred_mask):
    # Pick the most likely class for every pixel and keep a channel axis
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    return pred_mask[0]

def show_predictions(dataset=None, num=1):
    for image, mask in dataset.take(num):
        pred_mask = model.predict(image)
        display_images([image[0], mask[0], create_mask(pred_mask)])

show_predictions(test_ds)  # display a sample prediction from the test set
Output:
Prediction

Step 11: Compute Segmentation Metrics
We measure segmentation performance with two standard overlap metrics, computed per class: Intersection over Union (IoU = intersection / union of the predicted and true regions) and the Dice score (2 x intersection / sum of the two areas). Both are critical indicators of segmentation success.
Python
def compute_metrics(y_true, y_pred):
    class_wise_iou, class_wise_dice_score = [], []
    smooth = 1e-5  # avoids division by zero when a class is absent
    for i in range(3):
        intersection = np.sum((y_pred == i) & (y_true == i))
        y_true_area = np.sum(y_true == i)
        y_pred_area = np.sum(y_pred == i)
        combined_area = y_true_area + y_pred_area
        # union = combined_area - intersection
        iou = (intersection + smooth) / (combined_area - intersection + smooth)
        dice = 2 * (intersection + smooth) / (combined_area + smooth)
        class_wise_iou.append(iou)
        class_wise_dice_score.append(dice)
    return class_wise_iou, class_wise_dice_score
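A minimal sketch of how one might call compute_metrics on a single test batch (the batch-handling details here are an assumption, not the original code):
Python
imgs, true_masks = next(iter(test_ds))
pred_masks = tf.argmax(model.predict(imgs), axis=-1)[..., tf.newaxis]
iou, dice = compute_metrics(true_masks.numpy(), pred_masks.numpy())
for i, (iou_i, dice_i) in enumerate(zip(iou, dice)):
    print(f"class {i}: IoU = {iou_i:.3f}, Dice = {dice_i:.3f}")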
Output:
IoU and Dice Score

We used TensorFlow and the Oxford-IIIT Pet dataset to build a deep learning segmentation model that assigns a class label to every pixel, accurately separating pet images into distinct regions. Through a step-by-step pipeline covering data preparation, model design with a VGG16 encoder and FCN-style decoder, training and evaluation, we showed how raw image data can be turned into detailed, pixel-level segmentations, with both clear visual results and reliable quantitative metrics for assessing model performance.