Custom model training

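Below is an example of how to use BenchNIRS to train a custom convolutional neural network (CNN) on one of the datasets (here, the n-back task of the Shin 2018 dataset). The script then tests whether the accuracy across folds is significantly above chance level.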

import datetime
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from scipy import stats

from benchnirs.load import load_dataset
from benchnirs.process import process_epochs
from benchnirs.learn import deep_learn


# Dataset location and task configuration
DATA_PATH = '../../data/dataset_shin_2018/'  # root folder of the dataset
CLASSES = ['0-back', '2-back', '3-back']  # n-back conditions to discriminate
CONFIDENCE = 0.05  # significance level (alpha) for the tests, i.e. 95 %

# Channel indices grouped by region of interest; HbO and HbR channels of the
# same region are listed under separate keys. Insertion order matters: the
# channel pick list is built by concatenating these lists in order.
ROIS = {
    'Right PFC HbO': [9, 10, 19, 20, 21, 22, 23],
    'Right PFC HbR': [45, 46, 55, 56, 57, 58, 59],
    'Left PFC HbO': [0, 1, 2, 3, 4, 5, 6],
    'Left PFC HbR': [36, 37, 38, 39, 40, 41, 42],
    'Central PFC HbO': [7, 8],
    'Central PFC HbR': [43, 44],
}

# Timestamp the run so every execution gets its own results folder
start_time = datetime.datetime.now()
date = f'{start_time:%Y_%m_%d_%H%M}'
out_folder = './results/custom_shin_nb_' + date


class CustomCNN(nn.Module):
    """Small 1-D CNN: two temporal conv/pool stages, then three dense layers.

    Args:
        n_classes: number of output units (one logit per class).

    Input is expected as ``(batch, 32, time)`` (``conv1`` has 32 input
    channels). ``fc1`` takes exactly 28 features: the 4 output channels
    of ``conv2`` times 7 pooled time steps. A 400-sample time axis gives
    7 pooled steps, for example. Other input lengths will fail at
    ``fc1`` unless they also flatten to 28. ``forward`` returns raw
    logits (no softmax).
    """

    def __init__(self, n_classes):
        super().__init__()
        # Feature extractor: convolutions slide along the time axis
        self.conv1 = nn.Conv1d(32, 12, kernel_size=16, stride=4)
        self.pool1 = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(12, 4, kernel_size=8, stride=3)
        self.pool2 = nn.MaxPool1d(2)
        # Classifier head; 28 = 4 channels x 7 pooled time steps
        self.fc1 = nn.Linear(28, 16)
        self.do1 = nn.Dropout(0.2)
        self.fc2 = nn.Linear(16, 8)
        self.do2 = nn.Dropout(0.2)
        self.fc3 = nn.Linear(8, n_classes)

    def forward(self, x):
        n_samples = x.size(0)
        # Two conv -> ReLU -> max-pool stages
        for conv, pool in ((self.conv1, self.pool1),
                           (self.conv2, self.pool2)):
            x = pool(F.relu(conv(x)))
        # Flatten channels x time into one feature vector per sample
        x = x.view(n_samples, -1)
        # Two dense -> ReLU -> dropout stages, then the output layer
        for dense, drop in ((self.fc1, self.do1), (self.fc2, self.do2)):
            x = drop(F.relu(dense(x)))
        return self.fc3(x)


# Create the output folder; exist_ok avoids the check-then-create race
os.makedirs(out_folder, exist_ok=True)
print(f'Main output folder: {out_folder}/')

print(f'Number of GPUs: {torch.cuda.device_count()}')

print('=====\nshin_2018_nb\n=====')

# Load and preprocess data (band-pass range, baseline window and TDDR flag
# are passed to load_dataset)
epochs = load_dataset('shin_2018_nb', DATA_PATH, bandpass=[0.01, 0.5],
                      baseline=(-2, 0), tddr=True)
# Restrict the epochs to the ROI channels (32 indices in total, matching
# the 32 input channels of CustomCNN.conv1)
ch_picks = [ch for group in ROIS.values() for ch in group]
epochs.pick(ch_picks)
epochs_lab = epochs[CLASSES]

# Run models
nirs, labels, groups = process_epochs(epochs_lab, tmax=39.9, sort=True)
print(nirs.shape)
accuracies, hps, additional_metrics = deep_learn(
    CustomCNN, nirs, labels, groups, normalize=(0, 2),
    output_folder=out_folder)

# Write results: one row per fold, hyperparameters in a quoted field
with open(f'{out_folder}/results.csv', 'w', encoding='utf-8') as w:
    w.write('dataset;model;fold;accuracy;hyperparameters\n')
    for fold, accuracy in enumerate(accuracies):
        # Double embedded quotes so the quoted field stays valid CSV
        hp = str(hps[fold]).replace('"', '""')
        w.write(f'shin_2018_nb;CNN;{fold+1};{accuracy};"{hp}"\n')

# Expected accuracy of uniform random guessing among the classes (1/3 for
# the three n-back conditions); the tests below compare against it.
chance_level = 1 / len(CLASSES)

print(f'Average accuracy: {np.mean(accuracies)}')
# The Shapiro-Wilk normality test selects a parametric or non-parametric
# one-sided test of "accuracy > chance level".
_, p_shap = stats.shapiro(accuracies)
print(f'Shapiro p-value: {p_shap}')
if p_shap > CONFIDENCE:
    # Normality not rejected: one-sample t-test against chance
    s_tt, p_tt = stats.ttest_1samp(accuracies, chance_level,
                                   alternative='greater')
    print(f't-test = {s_tt} (p-value = {p_tt})')
else:
    # Normality rejected: Wilcoxon signed-rank test on differences to chance
    s_wilcox, p_wilcox = stats.wilcoxon(
        np.asarray(accuracies) - chance_level, alternative='greater')
    print(f'Wilcoxon = {s_wilcox} (p-value = {p_wilcox})')


end_time = datetime.datetime.now()
elapsed_time = end_time - start_time
print(f'===\nElapsed time: {elapsed_time}')