import copy
import sys
import pickle
import glob
import cv2
import numpy as np
from PIL import Image
import io
import torch
import torch.utils.data as data
from . import transforms as preprocess
class DetectionDataset(data.Dataset):
    '''The base class for the detection 2d dataset.

    It contains the data pipeline which is defined by :meth:`_init_transform`.
    DetectionDataset is the base class and does not contain the actual data;
    derived classes need to fill the annotations into ``self.db``.
    '''

    def __init__(self, cfg, is_train, transform=None):
        '''Store the configuration and build the preprocessing pipeline.

        Args:
            cfg: configuration object; the attributes ``IMAGE_SIZE``,
                ``PREPROC`` and ``PICKLE`` are read from it.
            is_train: selects the training pipeline (with augmentation) when
                truthy, otherwise the evaluation pipeline.
            transform: stored on the instance, but see the note below.
        '''
        super().__init__()
        self.is_train = is_train
        self.image_size = cfg.IMAGE_SIZE
        # self.num_classes = cfg.NUM_CLASSES
        # self.classes_names = cfg.CLASSES_NAME
        self.preproc_param = cfg.PREPROC
        self.using_pickle = cfg.PICKLE
        self.transform = transform
        # Annotation records; filled in by derived classes.
        self.db = []
        # Encoded image bytes, indexed like ``self.db``; only read when
        # ``using_pickle`` is set.
        self.img_db = []
        # NOTE(review): this unconditionally replaces ``self.transform``, so
        # the ``transform`` argument is effectively ignored. Kept as-is so
        # existing callers see no change in behaviour.
        self._init_transform()

    def _init_transform(self):
        '''Build ``self.transform`` for training or evaluation.

        Both pipelines start with ``ConvertFromInts`` and end with the same
        resize / tensor conversion / normalisation steps; the training
        pipeline inserts the augmentation steps in between.
        '''
        steps = [preprocess.ConvertFromInts()]
        if self.is_train:
            steps += [
                preprocess.ToAbsoluteCoords(),
                preprocess.RandomSampleCrop(scale=self.preproc_param.CROP_SCALE,
                                            num_attempts=self.preproc_param.CROP_ATTEMPTS),
                preprocess.RandomMirror(),
                # preprocess.PhotometricDistort(hue_delta=self.preproc_param.HUE_DELTA,
                #                               bri_delta=self.preproc_param.BRI_DELTA,
                #                               contrast_range=self.preproc_param.CONTRAST_RANGE,
                #                               saturation_range=self.preproc_param.SATURATION_RANGE),
                preprocess.Expand(mean=self.preproc_param.MEAN,
                                  max_expand_ratio=self.preproc_param.MAX_EXPAND_RATIO),
                preprocess.ToPercentCoords(),
            ]
        # Steps shared by the training and evaluation pipelines.
        steps += [
            preprocess.Resize(tuple(self.image_size)),
            preprocess.ToAbsoluteCoords(),
            preprocess.ToTensor(),
            # preprocess.ToGPU(),
            preprocess.Normalize(mean=self.preproc_param.MEAN, std=self.preproc_param.STD),
            preprocess.ToXYWH(),
        ]
        self.transform = preprocess.Compose(steps)

    def _get_db(self):
        '''Hook for derived classes; the base class has no data to load.'''
        raise NotImplementedError

    def __len__(self):
        '''Return the number of annotation records in ``self.db``.'''
        return len(self.db)

    def __getitem__(self, index):
        '''Fetch the image and annotation from ``self.db[index]``.

        The data in ``self.db`` can be described as::

            db[index] = {
                'image': 'Absolute Path',
                'boxes': np.ndarray,
                'labels': np.ndarray}

        Args:
            index: index into ``self.db`` (and ``self.img_db`` when the
                pickled image store is used).

        Returns:
            A tuple ``(image, target)``. ``image`` is whatever the transform
            pipeline returns (documented upstream as a ``(c, h, w)`` tensor).
            ``target`` is the boxes with the labels appended as the last
            column; presumably ``(n, 5)`` with columns 0-3 holding the box as
            x, y, w, h in absolute coordinates and column 4 the label.

        Raises:
            ValueError: if the image file cannot be read from disk.
        '''
        # Deep copy so the transforms never mutate the stored annotations.
        db_rec = copy.deepcopy(self.db[index])

        # read the image
        if self.using_pickle:
            # Decode from the in-memory encoded bytes. Forcing RGB keeps
            # grayscale / palette / CMYK images at three channels, like the
            # default behaviour of cv2.imread in the other branch.
            encoded_image = self.img_db[index]
            image = np.array(Image.open(io.BytesIO(encoded_image)).convert('RGB'))
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        else:
            image_file = db_rec['image']
            image = cv2.imread(image_file)
            # cv2.imread signals failure by returning None; check before any
            # further processing so the caller gets a readable error.
            if image is None:
                raise ValueError('Fail to read {}'.format(image_file))
            # NOTE(review): cv2.imread already returns BGR, so this channel
            # swap yields RGB here while the pickle branch yields BGR.
            # Confirm which channel order the Normalize mean/std expect.
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        boxes = db_rec['boxes']
        labels = db_rec['labels']

        # preprocess
        image, boxes, labels = self.transform(image, boxes, labels)
        return image, np.concatenate((boxes, labels[:, None]), axis=1)

    def reorder_data(self, db, cfg_joints_name, ds_joints_name):
        ''' reorder the db based on the cfg_joints_name

        Not implemented yet: the index mapping is computed, but applying it
        to ``db`` is still missing, so this always raises.

        :meta private:
        '''
        # Position of every configured name inside the dataset names,
        # -1 when the dataset does not provide it.
        order = np.array([
            ds_joints_name.index(cfg_name) if cfg_name in ds_joints_name else -1
            for cfg_name in cfg_joints_name
        ])
        raise NotImplementedError

    def saving_pickle(self, pickle_path):
        '''Read every image file listed in ``self.db`` and pickle the raw bytes.

        Args:
            pickle_path: destination file for the pickled list of encoded
                images (one ``bytes`` entry per record, in ``self.db`` order).

        :meta private:
        '''
        total = len(self.db)
        img_db = []
        for count, db_rec in enumerate(self.db, start=1):
            # 1-based counter so the last update reads N/N.
            sys.stdout.write('\rLoading Image: {}/{}'.format(count, total))
            sys.stdout.flush()
            # load bytes from file
            with open(db_rec['image'], 'rb') as f:
                img_db.append(f.read())

        # serialize
        sys.stdout.write('\rSaving img_db ({}) to {}\n'.format(total, pickle_path))
        with open(pickle_path, 'wb') as handle:
            pickle.dump(img_db, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def loading_pickle(self, pickle_path):
        '''Load and return the object stored in ``pickle_path``.

        Args:
            pickle_path: file previously written by :meth:`saving_pickle`.

        Returns:
            The unpickled object (a list of encoded image bytes when the file
            was produced by :meth:`saving_pickle`).

        :meta private:
        '''
        sys.stdout.write('\rLoading Pickle from {}\n'.format(pickle_path))
        # SECURITY: unpickling can execute arbitrary code; only load files
        # that come from a trusted source.
        with open(pickle_path, 'rb') as handle:
            return pickle.load(handle)