
SageMaker Hyper-Parameter Optimization: classify heartbeat anomalies from stethoscope audio

Reading Time: 10 minutes

Note: the code is available in the form of a Jupyter notebook here on Github. SageMaker needs a separate, so-called entry point script to train an MXNet model. You can find mine here.

In this post, I will cover two topics which recently tickled my curiosity:

  • Audio Deep Learning classification
  • Amazon SageMaker’s Hyper-Parameter Optimization (HPO)

The context

Both topics are equally interesting to me for various reasons. As for the first, not long ago, I stumbled upon a couple of posts (here and here) which showed how to turn an audio classification task into an image modeling challenge. The trick consists in calculating the spectrogram of the WAV file, i.e. a 2D visualization of how the frequencies and amplitudes of the audio signal evolve over time, plotting it and then running it through a standard CNN. I was mind-blown by how neat and effective the solution was. This needed a try.

As for SageMaker’s HPO functionality, my interest came from the fact that in my previous life as an Amazonian, I had the pleasure to use the underlying library at work when it was not part of AWS yet. Our Core ML team developed it for internal use, before realizing it worked so well that external customers could benefit from it too. Given my past positive experience, I couldn’t wait to give it a shot again.

Given these premises, I figured it could be a good idea to build an image classifier to address an audio classification task and to tune it with SageMaker’s HPO. Two birds with one stone.

Good. Let’s head to the data now. Two Kaggle challenges immediately grabbed my attention: cats VS dogs and heartbeat anomalies. Despite the doubtless cuteness of the former (who does not like pets?), I finally opted for the latter. The reason is simple: I am convinced that the domain AI can disrupt the most, providing by far the biggest impact to society, is healthcare. Concretely, this means simplifying the daily job of all kinds of doctors, by submitting patients’ data to computers instead of humans. With the automation of tedious tasks, we can have less tired physicians, focused on the cases which really matter, ultimately saving more lives. On top of this, the heart-related domain is of paramount importance. Cardiovascular diseases (CVDs) top the list of human killers almost anywhere across the world. According to the WHO, CVDs claim 17.9 M lives each year. That’s an estimated 31% of global deaths. Astounding.

Considering all of the above, the choice against pets was easily made. The Heartbeat Sounds dataset consists of stethoscope recordings of human heartbeats. Those are WAV files, each labeled with one of several tags. I decided to simplify my life and focus on two categories only: normal VS murmur. I won’t pretend to know what a heartbeat is supposed to sound like. There are a number of resources (here and here just to scratch the surface) which you can go through to get some domain knowledge. In a nutshell, heart sounds are associated with the noise produced by blood flowing through the heart. Specifically, by the valves opening up and shutting down to let the blood in and out. In a normal patient, those tones are regular and clearly distinguishable.

Example of a normal heartbeat from the dataset

When it comes to murmurs instead, as MedicTests puts it


A murmur is simply the sound of turbulent blood flowing through an incompetent valve. Sometimes hardening of the valve (stenosis) causes it to be unable to fully open or close, so blood is able to backflow against it. This is called regurgitation. It sounds like a miniature version of putting your thumb over the water hose.

https://medictests.com/heart-tones/
Example of a murmur from the dataset

From audio to images

To recap, we are facing a binary classification task. We are asked to build a model to differentiate a normal heartbeat from a murmur, simply by listening to the stethoscope recording. How are we supposed to do that? I had my dose of unstructured data in the past, from images to text, but audio is a little new to me. Technically it is nothing more than a time series, so I could probably just feed the signal to an RNN. There is actually a better way, though. What if we could extract features in the form of an image from the audio file? That would mean turning the audio classification challenge into a standard image one, with all the benefits that come with it: robust and well-known architectures, plus the golden pre-trained networks for transfer learning. The idea seems quite odd but, as we will see later, it turns out to be a winning shot.

The most common strategy to get a picture out of an audio track is to calculate and plot its spectrogram. There are several flavors of those, depending on how they are computed, even though, at the end of the day, the core concept is always the same. A spectrogram is a plot of the audio wave frequencies across time. The color of the plot shows the power of the signal. The closer to red, the higher the power. So basically it is a way to summarize the info contained in the audio wave in a single, condensed plot.

Below you can find two clips from the dataset, the first being a murmur, the second a normal heartbeat. Both audio files are presented alongside their waveplot representation (just a plot of the raw array over time) and their spectrogram (log scale on the y-frequency axis). As you can see, all the action occurs below 2 kHz, so the pictures for the Deep Learning pipeline will be cut at that threshold.

M. Audio clip of a murmur from the dataset
Waveplot of the murmur in audio clip M
Spectrogram of the murmur in audio clip M
N. Audio clip of a normal heartbeat from the dataset
Waveplot of the normal heartbeat in audio clip N
Spectrogram of the normal heartbeat in audio clip N
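
Plots like these take only a few lines of librosa and matplotlib to produce. Here is a minimal sketch (the file name is just a placeholder); the actual preprocessing used for the training pipeline follows in the next section.

import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np

y, sr = librosa.load("murmur_sample.wav")  # placeholder file name

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 6))

# waveplot: the raw amplitude array plotted over time
ax1.plot(np.arange(len(y)) / sr, y)
ax1.set(title="Waveplot", xlabel="time (s)")

# spectrogram: short-time Fourier transform, converted to dB, log-scaled frequency axis
D = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max)
librosa.display.specshow(D, sr=sr, x_axis="time", y_axis="log", ax=ax2)
ax2.set(title="Log-frequency spectrogram")

plt.tight_layout()
plt.show()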

Therefore, the idea is to compute the log-scaled spectrogram for each audio recording, filter it at 2 kHz and then save it as JPEG. Those are the images we will feed to the CNN.

The following is the key piece of code to achieve what was just explained. The main function is `save_spectrograms`, which takes a `pd.DataFrame` with the dataset in the form of filenames plus labels and some metadata (and yes, of course, we use librosa to handle audio files in python!).

import os
import shutil

import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np

def get_melspectrogram(y, sr):
    # Mel spectrogram restricted to the 10 Hz - 2 kHz band, converted to dB
    mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, fmin=10, fmax=2000, power=1.0)
    mel_spec_db = librosa.amplitude_to_db(mel_spec, ref=np.max)
    return mel_spec_db

def produce_spect_image(spec, sr, name):
    # plot the spectrogram with no axes and save it to disk as an image
    fig, ax = plt.subplots(figsize=(15, 5))
    librosa.display.specshow(spec, sr=sr, x_axis='time', y_axis='log', ax=ax)
    ax.axis('off')
    plt.savefig(name, bbox_inches=None, pad_inches=0)
    plt.close(fig)

########################################
# EACH AUDIO ARRAY IS PADDED WITH ITSELF
# UNTIL IT REACHES THE LENGTH OF THE LONGEST
# TO HAVE ALL FILES WITH THE SAME LENGTH
########################################
def repeat_to_length(y, length):
    # tile the signal with itself, then truncate to the target length
    return np.tile(y, int(np.ceil(length / len(y))))[:length]

def load_and_pad(fname):
    y, sr = librosa.load(fname)
    y = repeat_to_length(y, MAX_LEN)  # MAX_LEN: length (in samples) of the longest clip
    return y, sr

def save_spectrograms(df):

    directory = os.path.join(PATH, "spects")  # PATH: root folder of the dataset
    if os.path.exists(directory):
        shutil.rmtree(directory)
    os.makedirs(directory)

    for i, row in df.iterrows():
        fname = os.path.join(PATH, row.fname)
        image_name = os.path.join(directory, row.image_names)

        y, sr = load_and_pad(fname)
        spec = get_melspectrogram(y, sr)
        produce_spect_image(spec, sr, image_name)

    mxn = df.copy().reset_index()
    mxn["is_murmur"] = mxn.label.apply(lambda x: int(x == "murmur"))

    return mxn[["index", "is_murmur", "image_names"]]

`save_spectrograms` saves the images to disk and returns a `pd.DataFrame` with 3 columns: the index, the binary label (`is_murmur`) and the spectrogram’s image file name (`image_names`). This result needs to be appropriately transformed into a shape an MXNet pipeline can ingest, i.e. we have to package the images in .rec files. For more info on how to do that, refer to this previous post of mine, section Creating the dataset to feed to Gluon.
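
For reference, here is a minimal sketch of what the packaging can look like: write a tab-separated .lst file out of the DataFrame returned by `save_spectrograms`, then let MXNet’s im2rec tool turn it into a .rec/.idx pair. The train/valid DataFrames and the exact im2rec flags below are illustrative; the linked post walks through the full procedure.

def write_lst(df, lst_path):
    # MXNet .lst format: one image per line -> <integer index>\t<label>\t<relative image path>
    with open(lst_path, "w") as f:
        for _, row in df.iterrows():
            f.write(f"{row['index']}\t{row.is_murmur}\t{row.image_names}\n")

# write_lst(train_df, "train.lst")   # train_df / valid_df: 80/20 split of save_spectrograms' output
# write_lst(valid_df, "valid.lst")
#
# then, from the shell, im2rec.py (shipped with the mxnet repo under tools/) builds the .rec/.idx pairs:
#   python im2rec.py train spects/ --resize 224 --quality 95
#   python im2rec.py valid spects/ --resize 224 --quality 95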

After following the necessary steps, we create 6 files: a .rec, a .lst and an .idx for each of the training and validation sets (obtained by an 80/20 split of the dataset). Here are 12 spectrograms from a random training batch. This is the kind of image the CNN is going to look at.

A note on the dataset: the original collection was composed of 351 normal heartbeat recordings plus 129 murmurs. To address the class imbalance, I duplicated the latter to obtain 258 anomalous clips. Up-sampling the minority class generally works a lot better than down-sampling the majority one.
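
For completeness, the duplication itself is a pandas one-liner. A minimal sketch, with `df` being the same labels DataFrame used in the preprocessing snippet above:

import pandas as pd

# up-sample the minority class: append a copy of the 129 murmur rows (129 -> 258)
murmurs = df[df.label == "murmur"]
df_balanced = pd.concat([df, murmurs], ignore_index=True)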

Deep Learning in SageMaker

We have set the context and figured out the data. Time to dive into some Neural Networks. To do that, we’ll be using Amazon SageMaker. I already played around with the service previously here. At the time I used the image classification built-in algorithm, which basically translates into calling the SageMaker SDK and passing it the S3 location of the data. That’s about it. You don’t get to write any code. The API is quite powerful, as there are lots of hyper-parameters to potentially tune. Nevertheless, everything boils down to an API call and AWS takes care of the rest in a simple and effective way.

This time I wanted to move a step forward and decided to provide an actual training script (entry point), which allows for more flexibility than the built-in solutions. Let’s see what this looks like in practice.

The training script

As stated above, if you need an intermediate level of flexibility between built-in algos and fully customizable Bring-Your-Own (BYO) solutions, SageMaker’s wrapper around a training script is the way to go. This piece of code is invoked internally by AWS on the ML EC2 instance we select in the wrapper’s constructor. As such, SageMaker expects it to contain some key elements. Specifically, as the original tutorial suggests, two pieces are mandatory (at least for the training part):

  • train() function that
    1. takes in hyperparameters. Those are the ones we will fine-tune with an HPO job later on. Lines 19-27
    2. defines our neural net architecture. Line 44
    3. builds the data iterators, taking care of image preprocessing and augmentation. Lines 47-49
    4. trains our network. Lines 70-111
  • save() function that saves our trained network as an MXNet model.

The rest of the script consists of helper functions, defined to make the whole thing more readable. You can look mine up here.
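
To make the structure more concrete, here is a rough skeleton of what such an entry point can look like. This is a sketch based on the (legacy) SageMaker MXNet script interface, not the actual heartsound_mxnet.py: the architecture, channel names and helpers are illustrative.

import os

import mxnet as mx
from mxnet import gluon, init
from mxnet.gluon.model_zoo import vision

def train(hyperparameters, channel_input_dirs, num_gpus, **kwargs):
    ctx = mx.gpu() if num_gpus > 0 else mx.cpu()

    # 1. hyper-parameters: the knobs the HPO job will play with later on
    batch_size = hyperparameters.get('batch_size', 16)
    epochs = hyperparameters.get('epochs', 5)
    learning_rate = hyperparameters.get('learning_rate', 0.01)

    # 2. network: a pre-trained ResNet with a fresh 2-class output layer
    net = vision.resnet34_v1(pretrained=True, ctx=ctx)
    with net.name_scope():
        net.output = gluon.nn.Dense(2)
    net.output.initialize(init.Xavier(), ctx=ctx)

    # 3. data iterators built from the .rec files uploaded to S3
    data_dir = channel_input_dirs['training']
    train_iter = mx.io.ImageRecordIter(path_imgrec=os.path.join(data_dir, 'train.rec'),
                                       data_shape=(3, 224, 224),
                                       batch_size=batch_size,
                                       shuffle=True)

    # 4. the usual Gluon training loop (loss, trainer, metric logging) goes here
    # ...
    return net

def save(net, model_dir):
    # serialize the trained network so SageMaker can package and host it later
    net.save_parameters(os.path.join(model_dir, 'model.params'))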

In the following sections, I will deep dive into the 2 main steps I followed: train a baseline ResNet and then use SageMaker’s HPO to improve the model’s performance. The code is largely inspired by this notebook, part of the official AWS SageMaker’s examples.

Training a CNN

The first thing to do is to get a baseline. How good is a model trained with an average learning rate, momentum, weight decay, etc.? To answer that question, I kicked off a SageMaker training job with the entry point’s default hyper-parameters.

All the hard work is hidden within the training script, therefore this can be achieved with as few lines of code as the following:

import sagemaker
from sagemaker.mxnet import MXNet

sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()

m = MXNet('heartsound_mxnet.py',  # THIS IS THE TRAINING (ENTRY-POINT) SCRIPT
          role=role,
          train_instance_count=1,
          py_version='py3',
          train_instance_type='ml.p2.xlarge',
          hyperparameters={'batch_size': 16,
                           'epochs': 5})

m.fit("s3://heartbeat-sounds")

The above spins up an ml.p2.xlarge EC2 instance and runs the heartsound_mxnet.py script on the data stored in the s3://heartbeat-sounds S3 bucket. The bucket must contain the valid.rec and train.rec files generated before (there is no strict naming convention here; just make sure to call the right .rec files in the entry point script). As you can see, we don’t pass any hyper-parameters other than batch_size and epochs, which means all the others default to the values set in the training script,

learning_rate = hyperparameters.get('learning_rate', 0.01)
momentum = hyperparameters.get('momentum', 0.9)
wd = hyperparameters.get('wd', 0.0001)
arch = hyperparameters.get('arch', 'resnet34_v1')
freeze = hyperparameters.get('freeze', False)
opt = hyperparameters.get('opt', 'sgd')

i.e. a pre-trained ResNet34 which we re-train entirely using SGD, with a quite high learning rate and standard momentum and weight decay. This quick and dirty approach reaches very good results, with a validation accuracy of 78% after 4 epochs.


Screenshot from the output of the SageMaker training job in Jupyter

Can we do better than a baseline trained for 5 epochs?

Bayesian Hyper-Parameter Optimization

HPO is a well-known problem. If we had infinite computing power and time, we would re-train the network with every possible combination of the model’s hyper-parameters. We obviously cannot do that, though. Random search is generally the default strategy when it comes to HPO. We define ranges of interest and then probe our function on N random combinations of parameters. This is fine, but we can do better.

Enter Bayesian optimization. If we know that in a specific region of the hyper-parameters space the scoring metric is bad, it makes no sense to waste time and resources testing anything close to that region. In a Bayesian framework, the values of the hyper-parameters change proactively according to the data gathered after every test. The algorithm smartly moves out of stagnant regions and focuses on more promising areas. So the inputs to a Bayesian HPO are the HP ranges and the number of available tries. The algorithm does the rest. There are a number of nice open-source python libraries doing this (Hyperopt is one of those). SageMaker offers this functionality as well. Let’s see if we can squeeze something more than 78% accuracy out of our heartbeat sound classifier.

Once again, the SageMaker python SDK makes all of that a piece of cake:

from sagemaker.tuner import IntegerParameter, CategoricalParameter, ContinuousParameter, HyperparameterTuner

mt = MXNet('heartsound_mxnet.py',
           role=role,
           train_instance_count=1,
           py_version='py3',
           train_instance_type='ml.p2.xlarge',
           hyperparameters={'batch_size': 16})

# WHICH HYPER-PARAMETERS TO TUNE AND WHICH RANGES TO CONSIDER
hyperparameter_ranges = {'learning_rate': ContinuousParameter(0.0001, 0.1),
                         'momentum': ContinuousParameter(0., 0.99),
                         'wd': ContinuousParameter(0., 0.001),
                         'epochs': IntegerParameter(5, 10),
                         'opt': CategoricalParameter(['sgd', 'adam'])}

# THE SCORING METRIC TO MAXIMIZE
objective_metric_name = 'Validation-accuracy'
metric_definitions = [{'Name': 'Validation-accuracy',
                       'Regex': 'validation: accuracy=([0-9\\.]+)'}]

tuner = HyperparameterTuner(mt,
                            objective_metric_name,
                            hyperparameter_ranges,
                            metric_definitions,
                            max_jobs=15)  # OUR BUDGET: 15 SHOTS MAX!

tuner.fit("s3://heartbeat-sounds")

tuner.fit() triggers the tuning job. Nothing happens within the Jupyter instance. The AWS console is our friend here. Below you can see 2 screenshots from the same console page (the second is obtained by scrolling down the first). As you can see, once completed, SageMaker ran 15 different training jobs (this was our budget) in a little more than an hour.

If we sort them by descending objective metric value (i.e. validation accuracy), we notice that one of the 15 hyper-parameter combinations managed to achieve an astounding 96.4% accuracy. This is quite an improvement compared to the 78% of the baseline!

What is the winning hyper-parameter combination? We can click on the specific training job within the AWS console and look it up, or grab the results of all the experiments in Jupyter via the python SDK.

bayes_metrics = sagemaker.HyperparameterTuningJobAnalytics(tuner._current_job_name).dataframe()
bayes_metrics.sort_values(['FinalObjectiveValue'], ascending=False, inplace=True)
bayes_metrics.rename(index=str, columns={"FinalObjectiveValue": "validation_accuracy"}, inplace=True)
bayes_metrics[["validation_accuracy", "epochs", "learning_rate", "momentum", "opt", "wd"]]

which returns the following pd.DataFrame. The first row hosts the top-performing model. Looks like Adam wins over SGD!
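
If all we are after is the name of the winning training job (to inspect it in the console, or to deploy the corresponding model), the tuner object should expose it directly. A quick sketch relying on the SDK’s best_training_job() helper:

# name of the training job which scored the highest validation accuracy
best_job_name = tuner.best_training_job()
print(best_job_name)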

Conclusions

We managed to successfully train and tune a Deep Neural Network to recognize normal heartbeat sounds from murmurs. To do that:

  • We processed each audio recording and turned it into an image by plotting its spectrogram.
  • We then trained a baseline CNN which reached 78% accuracy on the validation set.
  • We eventually fed the ResNet to a SageMaker Hyper-Parameter Optimization job and managed to squeeze 96% accuracy out of our heartbeat classifier!
