
Skin Lesion Image Classification with Deep Convolutional Neural Networks



Introduction

In this post I'll show how to do skin lesion image classification with deep neural networks. The image classifier is trained on the HAM10000 dataset, the same problem as Task 3 of the International Skin Imaging Collaboration (ISIC) 2018 challenge.

The solution in this post is mainly based on a few web posts and on methods from the ISIC 2018 leaderboard.

The classification network was tested with pretrained ResNet and DenseNet variants and implemented with PyTorch. The model with the highest mean of recalls (0.9369 on the 20% test set) is an ensemble of ImageNet-pretrained and fine-tuned DenseNet161 + ResNet152.

# Confusion matrix of the model with the best recall
from IPython.display import Image

Image('test_results/ensemble_dense161_6_res152_4_lesion/confusion_matrix.png', width=900)


Below I walk through how I approached this problem.

A look at the data

Before diving into the models and metrics, let's first take a look at the dataset.

import pandas as pd

df = pd.read_csv('data/HAM10000_metadata.csv', index_col='image_id')
df.head()
lesion_id dx dx_type age sex localization
image_id
ISIC_0027419 HAM_0000118 bkl histo 80.0 male scalp
ISIC_0025030 HAM_0000118 bkl histo 80.0 male scalp
ISIC_0026769 HAM_0002730 bkl histo 80.0 male scalp
ISIC_0025661 HAM_0002730 bkl histo 80.0 male scalp
ISIC_0031633 HAM_0001466 bkl histo 75.0 male ear
import seaborn as sns

sns.countplot(df['dx'])

[Figure: count plot of the dx classes]

pd.DataFrame({'counts':df['dx'].value_counts(), 'percent': df['dx'].value_counts() / len(df)})
counts percent
nv 6705 0.669496
mel 1113 0.111133
bkl 1099 0.109735
bcc 514 0.051323
akiec 327 0.032651
vasc 142 0.014179
df 115 0.011483

We can see that the classes are very imbalanced. The class melanocytic nevi (nv) makes up about 67% of the dataset, while the smallest class (df) accounts for only about 1%.

When we organize the rows by lesion_id, we can see that many lesions have more than one image. The description of HAM10000 says the dataset includes lesions with multiple images, which can be tracked via the lesion_id column.

dfr = df.reset_index().set_index(['lesion_id', 'image_id'])
dfr.head(10)
dx dx_type age sex localization
lesion_id image_id
HAM_0000118 ISIC_0027419 bkl histo 80.0 male scalp
ISIC_0025030 bkl histo 80.0 male scalp
HAM_0002730 ISIC_0026769 bkl histo 80.0 male scalp
ISIC_0025661 bkl histo 80.0 male scalp
HAM_0001466 ISIC_0031633 bkl histo 75.0 male ear
ISIC_0027850 bkl histo 75.0 male ear
HAM_0002761 ISIC_0029176 bkl histo 60.0 male face
ISIC_0029068 bkl histo 60.0 male face
HAM_0005132 ISIC_0025837 bkl histo 70.0 female back
ISIC_0025209 bkl histo 70.0 female back
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import rcParams

%matplotlib inline

# figure size in inches (optional)
rcParams['figure.figsize'] = 10, 5

def plot_by_lesion():
    # show all images of the first few lesions that have more than one image
    grouped = df.groupby('lesion_id')
    for i, (lesion_id, group) in enumerate(grouped):
        cnt = len(group.index)
        if cnt > 1:
            fig, axes = plt.subplots(1, cnt)
            for ax, name in zip(axes, group.index):
                img = mpimg.imread(f'data/{name}.jpg')
                ax.imshow(img)
        if i > 4:
            break

plot_by_lesion()

[Figures: examples of lesions that each have multiple images]

We can see that the multiple images capture the same lesion with differences in color, scale and orientation.

Now let’s count the images by lesion_id.

df['lesion_id'].nunique()
7470
cnt = df.groupby('dx')['lesion_id'].nunique()
per = df.groupby('dx')['lesion_id'].nunique().div(df['lesion_id'].nunique())
pd.DataFrame({'counts':cnt, 'percent': per})
counts percent
dx
akiec 228 0.030522
bcc 327 0.043775
bkl 727 0.097323
df 73 0.009772
mel 614 0.082195
nv 5403 0.723293
vasc 98 0.013119

Now we see there are 7,470 unique lesions, and the class imbalance gets even worse when we group by lesion_id. We need to keep this in mind for later steps such as sampling and choosing evaluation metrics.

Metrics

For classification problems, the commonly used metrics are precision/recall/F-measures, ROC AUC, accuracy score (ACC) and so on. But for this imbalanced dataset, we need to think more carefully about the choice of metric.

The ISIC 2018 challenge report mentions that

Use of balanced accuracy is critical to select the best unbiased classifier, rather than one that overfits to arbitrary dataset prevalence, as is the case with accuracy.

Based on this description, balanced accuracy is the class-wise mean of recall. recall_score(average='macro') in scikit-learn calculates exactly this score:

$$ \frac{1}{|L|} \sum_{l \in L} R \left( y_{l}, \hat{y}_{l} \right) $$

For more details on the Balanced Multiclass Accuracy, refer to the ISIC 2018 challenge documentation.

So we’ll use the balanced accuracy (BACC) or mean of recalls of the 7 classes as the main metric for this assignment.
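As a quick sanity check on some made-up labels (not from the dataset), the macro-averaged recall is simply the unweighted mean of the per-class recalls:

from sklearn.metrics import recall_score

# toy example with 3 classes; the labels below are hypothetical
y_true = ['nv', 'nv', 'nv', 'mel', 'mel', 'df']
y_pred = ['nv', 'nv', 'mel', 'mel', 'nv', 'df']

per_class = recall_score(y_true, y_pred, average=None, labels=['nv', 'mel', 'df'])
per_class                                       # array([0.667, 0.5, 1.0])
recall_score(y_true, y_pred, average='macro')   # 0.722..., equal to per_class.mean()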

The main reason is that this is a highly imbalanced dataset, which is a problem we need to handle carefully for this multiclass classification.

Moreover, since this is a medical diagnosis task, it's important to have a high true positive rate (i.e., to minimize false negatives), so it's better to focus more on recall than on precision.

We'll also use other metrics to gain more insight. A confusion matrix plot is a good way to show how the model performs for each class. Another metric well suited to imbalanced classification is the Matthews correlation coefficient (MCC), which ranges from −1 to 1:

$$ \mathrm{MCC} = \frac{tp \cdot tn - fp \cdot fn}{\sqrt{(tp+fp)(tp+fn)(tn+fp)(tn+fn)}} $$
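Both MCC and the raw confusion matrix are available in scikit-learn; continuing the toy labels above:

from sklearn.metrics import matthews_corrcoef, confusion_matrix

matthews_corrcoef(y_true, y_pred)                             # scalar in [-1, 1]
confusion_matrix(y_true, y_pred, labels=['nv', 'mel', 'df'])  # rows: true, columns: predicted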

Preprocess dataset

Sampling

Since the dataset is very imbalanced, even though we could use the mean of recalls and a loss function with class weights, it would still be troublesome to train the model on the under-represented minority classes. Those classes are also likely to be missing, or to have very few samples, in a subsample or split, especially when the split fraction is small. So we need to handle the train-validation-test sampling and splitting carefully.

Two methods were applied to deal with this problem, weighted (stratified) sampling and oversampling of the minority classes, under the assumption that new data follow an imbalanced distribution close to that of the labelled dataset.

For simplicity, I'll just use the first image of each lesion_id. The code snippet below processes the dataset with oversampling. The parameter over_rate controls how much the minority classes are oversampled.

import functools

exclude_set = []
weighted = True
imbalance_eq = True
remove_dup_img = True
over_rate = 4
train_fraction = 0.8
val_fraction = 0.2

meta_data = pd.read_csv('data/HAM10000_metadata.csv', index_col='image_id')

# for reproducibility, just keep 1st image of each lesion_id
if remove_dup_img:
    lesion_ids = []
    for index, row in meta_data.iterrows():
        if row['lesion_id'] not in lesion_ids:
            lesion_ids.append(row['lesion_id'])
        else:
            meta_data = meta_data.drop(index=index)

if len(exclude_set) > 0:
    meta_data = meta_data.drop(index=exclude_set)

image_ids = meta_data.index.tolist()
num_images = len(image_ids)
num_train_ids = int(num_images * train_fraction)
num_val_ids = int(num_images * val_fraction)

# sampling based on the distribution of classes
if weighted:
    size_total = num_train_ids + num_val_ids
    df_c = meta_data['dx'].astype('category').value_counts()
    weights = df_c / len(meta_data)

    def sampling(df, replace=False, total=size_total):
        return df.sample(n=int(weights[df.name] * total), replace=replace)

    train_val = meta_data.groupby('dx', as_index=False).apply(
        sampling).reset_index(0, drop=True)

    train_sampling = functools.partial(sampling, total=num_train_ids)

    train = train_val.groupby('dx', as_index=False).apply(
        train_sampling).reset_index(0, drop=True)
    val = train_val.drop(index=train.index)

    if imbalance_eq:
        # oversample the minority classes: duplicate each class roughly
        # 1 / (class weight * over_rate) times
        bal_rate = 1 / weights / over_rate
        for k, v in bal_rate.to_dict().items():
            if v > 2:
                train = pd.concat(
                    [train] + [train.loc[train['dx'] == k, :]] * int(v))


sns.countplot(train['dx'])

[Figure: class counts of the oversampled training set]

The fractions used to split the entire set are given by train_fraction and val_fraction above.

Loss function

The cross-entropy loss function will be used. For this multiclass classification problem, I don't have a good reason to prefer another loss function over cross entropy.

Whether the loss criterion should also be weighted according to the imbalanced classes depends, I think, on how we sample the training and validation sets.
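For reference, here is a minimal sketch of what a class-weighted criterion could look like in PyTorch, using inverse-frequency weights derived from the per-image class counts shown earlier (alphabetical class order akiec, bcc, bkl, df, mel, nv, vasc is assumed); this is an illustration, not the exact training setup used here:

import torch
import torch.nn as nn

# per-class image counts from the HAM10000 metadata (akiec, bcc, bkl, df, mel, nv, vasc)
class_counts = torch.tensor([327., 514., 1099., 115., 1113., 6705., 142.])

# inverse-frequency weights, normalized so they sum to the number of classes
class_weights = class_counts.sum() / (len(class_counts) * class_counts)

criterion = nn.CrossEntropyLoss(weight=class_weights)
# loss = criterion(logits, targets)  # logits: (batch, 7), targets: (batch,) of class indices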

Data transformation

Some data transformations are applied to all input images, following the torchvision documentation for pre-trained models (since we will use these pre-trained models). It says height and width are expected to be at least 224, so we resize all input images to 224x224 to save some computation. We also normalize the images with the same mean and std as given in the documentation.

# Transformation for validation and testing sets
from torchvision import transforms

transforms_val = transforms.Compose([
     transforms.Resize(224),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])
])

Since we oversampled the minority classes in the training set, we should apply some random transformations, such as random horizontal/vertical flips, rotation and color jitter (saturation jitter is not used since I thought it might affect the appearance of the lesion area).

# Transformation for training set
transforms_train = transforms.Compose([
     transforms.Resize(300),
     transforms.RandomHorizontalFlip(0.5),
     transforms.RandomVerticalFlip(0.5),
     transforms.RandomRotation(20),
     transforms.RandomResizedCrop(224),
     transforms.ColorJitter(brightness=0.1, contrast=0.1, hue=0.1),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])
])

Models

When choosing a neural network model for a given task, we need to consider several factors, e.g., the performance (accuracy) of the model, how well the architecture suits the task (for a pre-trained model, the pretraining dataset should be similar to the new task's dataset), and computational efficiency.

For this assignment, we need to make a tradeoff between performance and computational efficiency, since I don't have many compute resources.

Pre-trained models

Glancing through the methods on the ISIC 2018 leaderboard, most of the top entries used ensembles of pre-trained models such as ResNet, DenseNet, Inception and so on. There are also methods that achieved a high rank with a single pre-trained model.

Also, as the cs231n notes state, people now rarely train a network from scratch, due to insufficient data and the cost of training. It's common to start from a model pre-trained on a very large dataset and use it as an initialization or as a fixed feature extractor for the new task.

Therefore, we'll start from a pre-trained model. As suggested in this post, ResNet-34 would be a good choice to start with. So the initial plan was to use the pre-trained ResNet model as a fixed feature extractor to see how it performs, and then to fine-tune the weights of some layers.

However, after a few preliminary short-run experiments, I found it was slow to train and the metrics didn't show much potential for improvement.

The Keras documentation has a table listing stats such as accuracy and number of parameters for some widely used models. As a tradeoff between accuracy and trainability (number of parameters), I started to focus on DenseNet161 and ResNet152.

After some preliminary experiments using the networks as fixed feature extractors, which didn't give encouraging results, I decided to fine-tune the whole network.
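As a rough sketch (using torchvision model definitions, not the exact training code from my experiments), fine-tuning only requires replacing the final classification layer with a 7-class head; using the network as a fixed feature extractor additionally freezes the backbone:

import torch.nn as nn
from torchvision import models

num_classes = 7  # akiec, bcc, bkl, df, mel, nv, vasc

# ResNet152: swap the final fully connected layer for a 7-class head
resnet = models.resnet152(pretrained=True)
resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)

# DenseNet161: the classifier attribute plays the same role
densenet = models.densenet161(pretrained=True)
densenet.classifier = nn.Linear(densenet.classifier.in_features, num_classes)

# Fixed-feature-extractor variant: freeze the backbone before replacing the head
# for p in resnet.parameters():
#     p.requires_grad = False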

Experiments

The experiments were stopped early when I thought they might stop improving. Though the models would possibly keep improving with longer training, time is precious.

Training and validation

DenseNet161

For the DenseNet161, the best validation mean of recalls is about 0.6845.

import torch

best_dense161_lesion = torch.load(
    'experiments/dense161_eq3_exclutest_lesion_v1/model_best.pth.tar',
    map_location=torch.device('cpu'))
recall_val_dense161_lesion = best_dense161_lesion['metrics']['recall']
recall_val_dense161_lesion
0.684509306993473
Image('experiments/dense161_eq3_exclutest_lesion_v1/confusion_matrix.png', width=900)


ResNet152

For the ResNet152, the best validation mean of recalls is about 0.7202.

best_res152_lesion = torch.load(
    'experiments/res152_eq3_exclutest_lesion_v1/model_best.pth.tar',
    map_location=torch.device('cpu'))
recall_val_res152_lesion = best_res152_lesion['metrics']['recall']
recall_val_res152_lesion
0.7202260074093291
Image('experiments/res152_eq3_exclutest_lesion_v1/recall.png', width=700)


import numpy as np
import matplotlib.pyplot as plt

def plot_metric(train_loss, test_loss, name, plot_type='loss'):
    epochs = range(len(train_loss))

    f = plt.figure()
    plt.title(f"{name} {plot_type} plot")
    plt.xlabel("epoch")
    plt.ylabel(f"{plot_type}")
    plt.grid(True)
    plt.plot(epochs, train_loss, 'b', marker='o', label=f'train {plot_type}')
    plt.plot(epochs, test_loss, 'r', marker='o', label=f'val {plot_type}')
    plt.legend()

train_losses, test_losses = np.load(
    'experiments/res152_eq3_exclutest_lesion_v1/final_results.npy')
plot_metric(train_losses, test_losses, 'ResNet152_lesion', 'loss')

[Figure: training and validation loss curves for ResNet152]

# ResNet152 val
Image('experiments/res152_eq3_exclutest_lesion_v1/confusion_matrix.png', width=900)


As we can see from the results, neither the DenseNet161 nor the ResNet152 performs satisfactorily, with only around 0.7 mean of recalls. In any case, let's look at how they perform on the test set.

Metrics on the test set

import os
import json

import numpy as np
import pandas as pd
from sklearn import metrics

test_results_path = 'test_results'
# model_ids =['dense161','res101','res152']
result_paths = [d for d in os.listdir(
    test_results_path) if not d.startswith('.')]
result_paths = [d for d in result_paths if 'lesion' in d]
# print(result_paths)

model_metrics = {}
for i in result_paths:
    fp = os.path.join(test_results_path, i, 'metrics_results.json')
    y_true = np.load(os.path.join(test_results_path, i, 'val_true.npy'))
    y_pred = np.load(os.path.join(test_results_path, i, 'val_pred.npy'))
    with open(fp) as f:
        rec = json.load(f)
        rec['f1'] = metrics.f1_score(y_true, y_pred, average='macro')
        rec['mcc'] = metrics.matthews_corrcoef(y_true, y_pred)
    model_metrics[i] = rec

df_results_lesion = pd.read_json(json.dumps(model_metrics), orient='index').drop(
    columns=['bacc']).sort_values(by='recall', ascending=False)
df_results_lesion['acc'] = df_results_lesion['acc'] / 100
df_results_lesion = df_results_lesion[['recall', 'prec', 'f1', 'mcc', 'acc']]
df_results_lesion.columns = ['Recall', 'Precision', 'F1', 'MCC', 'ACC']
df_results_lesion.loc[['dense161_lesion','res152_lesion']].round(4)
Recall Precision F1 MCC ACC
dense161_lesion 0.9105 0.8085 0.8504 0.8210 0.9111
res152_lesion 0.8594 0.7542 0.7971 0.7072 0.8465

Surprisingly, both models achieve a much higher mean of recalls on the test set, jumping from around 0.7 to 0.9105 and 0.8594.

I also tested ensembles of the trained models, with different weights on each model (though without a grid search).
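A minimal sketch of one common way to combine the models is a weighted average of their softmax outputs; the helper name below and the 0.6/0.4 weights suggested by the run name ensemble_dense161_6_res152_4_lesion are assumptions on my part:

import torch
import torch.nn.functional as F

def ensemble_predict(models, images, weights):
    """Weighted average of the models' softmax outputs (hypothetical helper)."""
    probs = None
    for model, w in zip(models, weights):
        model.eval()
        with torch.no_grad():
            p = F.softmax(model(images), dim=1)
        probs = w * p if probs is None else probs + w * p
    return probs.argmax(dim=1)

# e.g. DenseNet161 weighted 0.6 and ResNet152 weighted 0.4
# preds = ensemble_predict([densenet, resnet], batch_images, weights=[0.6, 0.4])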

Besides picking the model with the highest mean of recalls, I also used the DenseNet161 checkpoint with the highest MCC score during validation.

The results are also surprisingly good.

df_results_lesion.round(4)
Recall Precision F1 MCC ACC
ensemble_dense161_6_res152_4_lesion 0.9369 0.7939 0.8558 0.8067 0.9018
ensemble_dense161_4_res152_6_lesion 0.9206 0.8132 0.8610 0.7808 0.8884
dense161_lesion 0.9105 0.8085 0.8504 0.8210 0.9111
dense161_lesion_mcc 0.9095 0.8540 0.8789 0.8236 0.9144
ensemble_dense161_res152_lesion 0.9055 0.8052 0.8491 0.7931 0.8960
res152_lesion 0.8594 0.7542 0.7971 0.7072 0.8465
Image('test_results/dense161_lesion/confusion_matrix.png', width=900)


Image('test_results/dense161_lesion_mcc/confusion_matrix.png', width=900)


Image('test_results/ensemble_dense161_6_res152_4_lesion/confusion_matrix.png', width=900)


Image('test_results/res152_lesion/confusion_matrix.png', width=900)


However, I still wonder why the test metrics are so much higher than the validation metrics (even higher than the heavily tuned top entries in the ISIC 2018 challenge).

I've thought about, and discussed with others, possible flaws in my solution. However, I couldn't find a likely problem that would cause such a high mean of recalls on the test set. There is no leakage of information from the training set to the test set, as I grouped and split the datasets according to lesion_id.
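For reference, a leakage-free split grouped by lesion_id can also be done directly with scikit-learn's GroupShuffleSplit; this is a sketch of the general approach, not the exact splitting code used above:

from sklearn.model_selection import GroupShuffleSplit

splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(splitter.split(df, groups=df['lesion_id']))
train_df, test_df = df.iloc[train_idx], df.iloc[test_idx]

# no lesion appears in both subsets
assert set(train_df['lesion_id']).isdisjoint(set(test_df['lesion_id']))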

You're very welcome to contact me if you have any ideas.

Discussion of limitations and possible improvements

References not mentioned yet


