Source code for ssds.dataset.dali_tfrecord

import torch
import os
import math
import ctypes
from subprocess import call
from glob import glob

import nvidia.dali.ops as ops
import nvidia.dali.types as types
import nvidia.dali.tfrecord as tfrec
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.plugin.pytorch as dalitorch

from .dali_dataiterator import DaliDataset, DaliPipeline

class TFRecordPipeline(Pipeline, DaliPipeline):
    """DALI pipeline that reads detection samples from TFRecord files.

    Currently the easiest way for using dali to process the dataset is using
    TFRecord Files.

    Args:
        tfrecords: sequence of TFRecord file paths (it is iterated more than
            once). For each path an index file named ``<path>_idx`` is
            expected; missing ones are generated by invoking the
            ``tfrecord2idx`` command.
        batch_size: forwarded to ``Pipeline.__init__`` as ``batch_size``.
        target_size: forwarded to ``DaliPipeline.__init__``.
        preproc_param: forwarded to ``DaliPipeline.__init__``.
        num_threads: forwarded to ``Pipeline.__init__`` as both
            ``num_threads`` and ``prefetch_queue_depth``.
        num_shards: forwarded to the TFRecord reader as ``num_shards``.
        device_ids: forwarded to ``Pipeline.__init__`` as ``device_id``.
        training: forwarded to ``DaliPipeline.__init__`` and used as the
            reader's ``random_shuffle`` flag.
    """

    def __init__(self, tfrecords, batch_size, target_size, preproc_param,
                 num_threads, num_shards, device_ids, training=False):
        Pipeline.__init__(self, batch_size=batch_size, num_threads=num_threads,
                          device_id=device_ids, prefetch_queue_depth=num_threads,
                          seed=42, exec_async=False, exec_pipelined=False)
        DaliPipeline.__init__(self, target_size=target_size,
                              preproc_param=preproc_param, training=training)

        # Every TFRecord file needs a companion index file; create the ones
        # that do not exist yet.
        tfrecords_idx = [tfrecord + "_idx" for tfrecord in tfrecords]
        for tfrecord, tfrecord_idx in zip(tfrecords, tfrecords_idx):
            if os.path.exists(tfrecord_idx):
                continue
            # NOTE: the exit status of tfrecord2idx is not checked here; if the
            # tool fails, opening the index file below raises instead.
            call(["tfrecord2idx", tfrecord, tfrecord_idx])

        # The dataset length is the total number of lines over all index
        # files. Count them by streaming each file inside a context manager so
        # the handle is closed promptly instead of being leaked.
        total_lines = 0
        for index_file in tfrecords_idx:
            with open(index_file) as handle:
                total_lines += sum(1 for _ in handle)
        self.length = total_lines

        self.input = ops.TFRecordReader(
            path=tfrecords,
            index_path=tfrecords_idx,
            features={
                'image/height': tfrec.FixedLenFeature([1], tfrec.int64, -1),
                'image/width': tfrec.FixedLenFeature([1], tfrec.int64, -1),
                'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""),
                'image/format': tfrec.FixedLenFeature((), tfrec.string, ""),
                'image/object/bbox/xmin': tfrec.VarLenFeature(tfrec.float32, 0.0),
                'image/object/bbox/ymin': tfrec.VarLenFeature(tfrec.float32, 0.0),
                'image/object/bbox/xmax': tfrec.VarLenFeature(tfrec.float32, 0.0),
                'image/object/bbox/ymax': tfrec.VarLenFeature(tfrec.float32, 0.0),
                'image/object/class/text': tfrec.FixedLenFeature([], tfrec.string, ''),
                'image/object/class/label': tfrec.VarLenFeature(tfrec.int64, -1),
            },
            num_shards=num_shards,
            random_shuffle=training)
        self.training = training

        # Concatenate the four coordinate tensors, reshape to 4 rows and
        # transpose, so that each output row is [l, t, r, b]
        # (assumes each input is a 1-D tensor of equal length -- TODO confirm).
        # Original note: [l*w,t*h,r*w,b*h], [l,t,r,b]
        self.cat = dalitorch.TorchPythonFunction(
            function=lambda l, t, r, b: torch.cat([l, t, r, b]).view(4, -1).permute(1, 0))
        self.cast = ops.Cast(dtype=types.DALIDataType.INT32)

    def define_graph(self):
        """Build the processing graph from the TFRecord reader outputs.

        Reads the encoded image, assembles the bounding boxes from the four
        ``image/object/bbox/*`` features, casts the class labels to INT32 and
        hands all three to ``self.predefined_graph`` (not defined in this
        class; presumably provided by ``DaliPipeline``).

        Returns:
            Whatever ``self.predefined_graph(images, bboxes, labels)`` returns.
        """
        inputs = self.input()
        images = inputs["image/encoded"]
        bboxes = self.cat(inputs["image/object/bbox/xmin"],
                          inputs["image/object/bbox/ymin"],
                          inputs["image/object/bbox/xmax"],
                          inputs["image/object/bbox/ymax"])
        labels = self.cast(inputs["image/object/class/label"])
        return self.predefined_graph(images, bboxes, labels)

    def __len__(self):
        """Return the total line count of all index files, computed in ``__init__``."""
        return self.length
class DaliTFRecord(DaliDataset):
    """Data loader for data parallel using Dali for TFRecord files.

    Collects the TFRecord files matched by the ``image_sets`` glob patterns
    under ``dataset_dir`` and drives them through a ``TFRecordPipeline``.
    """

    def __init__(self, cfg, dataset_dir, image_sets, batch_size, training=False):
        """Resolve the TFRecord files and build the initial pipeline.

        Args:
            cfg: forwarded unchanged to ``DaliDataset.__init__``.
            dataset_dir: directory that each entry of ``image_sets`` is joined to.
            image_sets: iterable of glob patterns, relative to ``dataset_dir``.
            batch_size: forwarded unchanged to ``DaliDataset.__init__``.
            training: forwarded unchanged to ``DaliDataset.__init__``.
        """
        super(DaliTFRecord, self).__init__(cfg, dataset_dir, image_sets, batch_size, training)

        # Expand each pattern in turn, keeping the order of image_sets.
        matched_files = []
        for pattern in image_sets:
            matched_files.extend(glob(os.path.join(dataset_dir, pattern)))
        self.tfrecords = matched_files

        # self.pipeline_args is not set in this class; presumably it is
        # prepared by DaliDataset.__init__ -- verify against the base class.
        self.pipe = TFRecordPipeline(tfrecords=self.tfrecords, **self.pipeline_args)
        self.pipe.build()

    def reset_size(self, batch_size, target_size):
        """Rebuild the pipeline with a new batch size and target size.

        Args:
            batch_size: new value stored on the instance and in
                ``self.pipeline_args["batch_size"]``.
            target_size: new value stored on the instance and in
                ``self.pipeline_args["target_size"]``.
        """
        self.batch_size = batch_size
        self.target_size = target_size
        for key, value in (("batch_size", batch_size), ("target_size", target_size)):
            self.pipeline_args[key] = value

        # Drop the old pipeline before constructing its replacement.
        del self.pipe
        self.pipe = TFRecordPipeline(tfrecords=self.tfrecords, **self.pipeline_args)
        self.pipe.build()