# Source code for libauc.datasets.chexpert

import numpy as np
import torch 
from torch.utils.data import Dataset
import torchvision.transforms as tfs
import cv2
from PIL import Image
import pandas as pd

class CheXpert(Dataset):
    r"""
    Dataset wrapper around a CheXpert label CSV and its image folder.

    The constructor loads the CSV, optionally keeps only frontal views,
    optionally duplicates positive rows of selected columns (upsampling),
    imputes uncertain (``-1``) and missing labels, optionally shuffles the
    rows, and computes per-class label counts and imbalance ratios.
    ``__getitem__`` reads one grayscale image, optionally augments it,
    converts it to a normalized 3-channel ``float32`` array in CHW layout and
    returns it together with its label(s).

    Args:
        csv_path: path of the label CSV passed to ``pandas.read_csv``.
        image_root_path: prefix concatenated in front of every ``Path`` entry;
            must not be empty.
        image_size: side length (pixels) of the square output image.
        class_index: position inside ``train_cols`` of the single target
            column, or ``-1`` to return all ``train_cols`` (multi-label mode).
        use_frontal: keep only rows whose ``Frontal/Lateral`` is ``Frontal``.
        use_upsampling: append a copy of the rows that are positive (``== 1``)
            for each column in ``upsampling_cols``.
        flip_label: replace ``0`` by ``-1`` in the table (ignored in
            multi-label mode).
        shuffle: shuffle the rows with ``numpy`` seeded by ``seed``.
        seed: seed given to ``np.random.seed`` before shuffling.
        verbose: print dataset statistics.
        transforms: callable applied to the PIL image in ``'train'`` mode;
            when ``None`` the built-in :meth:`image_augmentation` is used.
        upsampling_cols: list of column names to upsample.
        train_cols: column names used as labels.
        return_index: when true ``__getitem__`` also returns the sample index.
        mode: augmentation is only applied when this equals ``'train'``.

    Reference:
        .. [1] Yuan, Zhuoning, Yan, Yan, Sonka, Milan, and Yang, Tianbao.
           "Large-scale robust deep auc maximization: A new surrogate loss and empirical studies on medical image classification."
           Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021.
           https://arxiv.org/abs/2012.03173
    """

    def __init__(self,
                 csv_path,
                 image_root_path='',
                 image_size=320,
                 class_index=0,
                 use_frontal=True,
                 use_upsampling=True,
                 flip_label=False,
                 shuffle=True,
                 seed=123,
                 verbose=False,
                 transforms=None,
                 upsampling_cols=['Cardiomegaly', 'Consolidation'],
                 train_cols=['Cardiomegaly', 'Edema', 'Consolidation', 'Atelectasis', 'Pleural Effusion'],
                 return_index=False,
                 mode='train'):
        # Work on a private copy: the default list object is shared between
        # calls, and it ends up stored on the instance as ``select_cols``.
        train_cols = list(train_cols)

        # load data from csv
        df = pd.read_csv(csv_path)
        # The prefixes are literal strings (they contain '.'), so match them
        # literally instead of as regular expressions.
        for prefix in ('CheXpert-v1.0-small/', 'CheXpert-v1.0/'):
            df['Path'] = df['Path'].str.replace(prefix, '', regex=False)

        if use_frontal:
            # .copy() so the column assignments below act on an independent frame.
            df = df[df['Frontal/Lateral'] == 'Frontal'].copy()

        # upsample selected cols
        if use_upsampling:
            assert isinstance(upsampling_cols, list), 'Input should be list!'
            sampled_df_list = []
            for col in upsampling_cols:
                print('Upsampling %s...' % col)
                sampled_df_list.append(df[df[col] == 1])
            df = pd.concat([df] + sampled_df_list, axis=0)

        # impute missing values
        # Assign the cleaned column back instead of using chained
        # ``inplace=True`` calls, which operate on a temporary Series and do
        # not reliably update the frame (no-op under pandas Copy-on-Write).
        uncertain_to_positive = ('Edema', 'Atelectasis')
        uncertain_to_negative = ('Cardiomegaly', 'Consolidation', 'Pleural Effusion',
                                 'No Finding', 'Enlarged Cardiomediastinum', 'Lung Opacity',
                                 'Lung Lesion', 'Pneumonia', 'Pneumothorax', 'Pleural Other',
                                 'Fracture', 'Support Devices')
        for col in train_cols:
            column = df[col]
            if col in uncertain_to_positive:
                column = column.replace(-1, 1)
            elif col in uncertain_to_negative:
                column = column.replace(-1, 0)
            # Unknown columns only get their missing values filled.
            df[col] = column.fillna(0)

        self._num_images = len(df)

        # 0 --> -1
        if flip_label and class_index != -1:  # In multi-class mode we disable this option!
            df = df.replace(0, -1)

        # shuffle data
        if shuffle:
            data_index = list(range(self._num_images))
            # NOTE: seeds numpy's global RNG, exactly as the original code did.
            np.random.seed(seed)
            np.random.shuffle(data_index)
            df = df.iloc[data_index]

        self.df = df

        assert image_root_path != '', 'You need to pass the correct location for the dataset!'

        if class_index == -1:  # multi-label: one counts dict per column position
            if verbose:
                print('Multi-label mode: True, Number of classes: [%d]' % len(train_cols))
                print('-' * 30)
            self.select_cols = train_cols
            self.value_counts_dict = {
                class_key: df[select_col].value_counts().to_dict()
                for class_key, select_col in enumerate(train_cols)
            }
        else:
            self.select_cols = [train_cols[class_index]]  # this var determines the number of classes
            self.value_counts_dict = df[self.select_cols[0]].value_counts().to_dict()

        self.mode = mode
        self.class_index = class_index
        self.image_size = image_size
        self.transforms = transforms
        self.return_index = return_index
        self._images_list = [image_root_path + path for path in df['Path'].tolist()]

        if class_index != -1:
            self.targets = df[train_cols].values[:, class_index].tolist()
        else:
            self.targets = df[train_cols].values.tolist()

        if class_index != -1:
            # After flipping, negatives are stored as -1 instead of 0.
            negative_key = -1 if flip_label else 0
            counts = self.value_counts_dict
            # A class with no positives (or no negatives) has no entry in the
            # value counts; record an explicit zero instead of raising KeyError.
            counts.setdefault(1, 0)
            counts.setdefault(negative_key, 0)
            self.imratio = self._positive_ratio(counts[1], counts[negative_key])
            if verbose:
                print('-' * 30)
                print('Found %s images in total, %s positive images, %s negative images' % (
                    self._num_images, counts[1], counts[negative_key]))
                print('%s(C%s): imbalance ratio is %.4f' % (self.select_cols[0], class_index, self.imratio))
                print('-' * 30)
        else:
            imratio_list = []
            for class_key, select_col in enumerate(train_cols):
                counts = self.value_counts_dict[class_key]
                # Fill in the missing side with 0 without touching the count
                # that is present (the previous code zeroed the positive count
                # when there were no negatives).
                counts.setdefault(0, 0)
                counts.setdefault(1, 0)
                imratio = self._positive_ratio(counts[1], counts[0])
                imratio_list.append(imratio)
                if verbose:
                    print('Found %s images in total, %s positive images, %s negative images' % (
                        self._num_images, counts[1], counts[0]))
                    print('%s(C%s): imbalance ratio is %.4f' % (select_col, class_key, imratio))
                    print()
            self.imratio = np.mean(imratio_list)
            self.imratio_list = imratio_list

    @staticmethod
    def _positive_ratio(num_positive, num_negative):
        """Return positives / (positives + negatives), or 0.0 when both are zero."""
        total = num_positive + num_negative
        if total == 0:
            return 0.0
        return num_positive / total

    @property
    def class_counts(self):
        """Label value counts: one dict in single-label mode, a dict of dicts keyed by class position otherwise."""
        return self.value_counts_dict

    @property
    def imbalance_ratio(self):
        """Fraction of positive samples (mean over classes in multi-label mode)."""
        return self.imratio

    @property
    def num_classes(self):
        """Number of label columns returned per sample."""
        return len(self.select_cols)

    @property
    def data_size(self):
        """Number of rows in the dataset after filtering and upsampling."""
        return self._num_images

    def image_augmentation(self, image):
        """Apply the default random affine augmentation to ``image`` and return the result."""
        # Per the original note, `fill` is the newer name of the `fillcolor` argument.
        img_aug = tfs.Compose([tfs.RandomAffine(degrees=(-15, 15),
                                                translate=(0.05, 0.05),
                                                scale=(0.95, 1.05),
                                                fill=128)])
        image = img_aug(image)
        return image

    def __len__(self):
        return self._num_images

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)`` or ``(image, label, idx)``.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        image_path = self._images_list[idx]
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError('Could not read image: %s' % image_path)
        image = Image.fromarray(image)

        if self.mode == 'train':
            if self.transforms is None:
                image = self.image_augmentation(image)
            else:
                image = self.transforms(image)

        image = np.array(image)
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        # resize and normalize; e.g., ToTensor()
        image = cv2.resize(image, dsize=(self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        image = image / 255.0
        # Per-channel mean/std (presumably the ImageNet statistics -- confirm).
        mean = np.array([[[0.485, 0.456, 0.406]]])
        std = np.array([[[0.229, 0.224, 0.225]]])
        image = (image - mean) / std
        image = image.transpose((2, 0, 1)).astype(np.float32)  # HWC -> CHW

        # Single- and multi-label targets are both returned as a flat float32 vector.
        label = np.array(self.targets[idx]).reshape(-1).astype(np.float32)

        if self.return_index:
            return image, label, idx
        return image, label
if __name__ == '__main__':
    # Manual smoke run: build the train/validation datasets from a local
    # CheXpert-small checkout and wrap each of them in a DataLoader.
    root = '../chexpert/dataset/CheXpert-v1.0-small/'

    # Options shared by both dataset splits.
    dataset_options = dict(image_root_path=root, use_frontal=True, image_size=320, class_index=0)
    traindSet = CheXpert(csv_path=root + 'train.csv', use_upsampling=True, mode='train', **dataset_options)
    testSet = CheXpert(csv_path=root + 'valid.csv', use_upsampling=False, mode='valid', **dataset_options)

    # Options shared by both loaders; only shuffling/drop_last differ.
    loader_options = dict(batch_size=32, num_workers=2)
    trainloader = torch.utils.data.DataLoader(traindSet, drop_last=True, shuffle=True, **loader_options)
    testloader = torch.utils.data.DataLoader(testSet, drop_last=False, shuffle=False, **loader_options)