Preprocessing Image Data with NPP for TensorRT Model Inference

The NVIDIA TensorRT (TRT) library is a high-performance deep learning inference engine that delivers low latency and high throughput for inference applications. It allows users to convert models from other popular frameworks such as PyTorch or TensorFlow. However, TensorRT accepts float32 rather than uint8 as the input data type, while uint8 is the most common format for image data. As a result, when deploying image-based tasks with TensorRT, the images always need to be converted from uint8 to float32 on the CPU, and the float32 image data then has to be transferred to the GPU before the TRT engine can run inference. When the image size is large, this preprocessing stage is noticeably slow. In this blog, we introduce the NVIDIA NPP library to speed up this preprocessing step.

Experiments

This blog uses the environment shown below and a detection model trained with ssds.pytorch for the experiments and evaluation. The model uses ResNet18 as the feature extractor and a YOLOv3 detection head. It has already been converted to a TRT engine with a 1x3x736x1280 input and int8 computation precision (a minimal sketch of loading such a prebuilt engine follows the environment list).

SYS: Ubuntu 18.04
GPU: T4
GCC: 7.5
CMake: 3.16.6
CUDA: 10.2
CUDNN: 7.6.5
TensorRT: 7.0
OpenCV: 4.3.0/3.4.10
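
For reference, the sketch below shows one way such a prebuilt engine can be deserialized with the TensorRT 7 C++ API. It is a minimal sketch, not the exact loader used in these experiments; the engine file name model.trt and the logger class are assumptions.

// Minimal sketch (assumption: the engine was serialized to "model.trt" by a separate build step).
#include <NvInfer.h>
#include <fstream>
#include <iostream>
#include <iterator>
#include <vector>

// Simple logger required by the TensorRT runtime.
class Logger : public nvinfer1::ILogger
{
    void log(Severity severity, const char* msg) override
    {
        if (severity <= Severity::kWARNING)
            std::cout << msg << std::endl;
    }
};

int main()
{
    Logger logger;

    // Read the serialized engine from disk.
    std::ifstream file("model.trt", std::ios::binary);
    std::vector<char> blob((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());

    // Deserialize the engine and create an execution context.
    nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);
    nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(blob.data(), blob.size(), nullptr);
    nvinfer1::IExecutionContext* context = engine->createExecutionContext();

    std::cout << "bindings: " << engine->getNbBindings() << std::endl;

    // TensorRT 7 objects are released with destroy().
    context->destroy();
    engine->destroy();
    runtime->destroy();
    return 0;
}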

It should be noted that when a model is converted to a TRT engine, TRT selects different kernel functions and their parameters according to the GPU architecture in order to optimize the inference speed. Therefore, the same GPU architecture has to be used for TRT engine generation and execution. Even for engines generated on different GPU models with the same architecture, the inference speed is slightly degraded when the engine runs on a GPU other than the one it was built on. For example, although the 2080 Ti and the T4 both have compute capability 7.5, when we run inference on the T4, the engine generated on the 2080 Ti is 3 to 10% slower than the engine generated on the T4.

CPU image preprocessing

In some deep learning frameworks, the input data type and format, as well as the preprocessing ops, can be specified in the inference graph. For example, when freezing the weights into a frozen graph in TensorFlow, the data type accepted by the model at inference time can be specified through tf.placeholder(dtype=tf.uint8, shape=input_shape, name='image_tensor'). This way of preprocessing is not supported by TensorRT. TensorRT does support multiple data types, and the data types of the input and output ops are determined by the converted onnx/uff file. However, when the input and output data type of the onnx/uff model is changed to a type other than float32, in most cases it cannot be successfully converted into a TRT inference engine.
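
To illustrate the point, the sketch below queries the data types of the input and output bindings of a deserialized engine with the TensorRT 7 C++ API; for the model in this blog the input binding is expected to report kFLOAT. It assumes engine is a valid nvinfer1::ICudaEngine*, as in the loading sketch above.

// Minimal sketch: print each binding's name, direction, and data type
// (assumption: `engine` is a valid nvinfer1::ICudaEngine*).
#include <NvInfer.h>
#include <iostream>

void printBindingTypes(const nvinfer1::ICudaEngine* engine)
{
    for (int i = 0; i < engine->getNbBindings(); ++i)
    {
        const bool isInput = engine->bindingIsInput(i);
        const nvinfer1::DataType type = engine->getBindingDataType(i);
        std::cout << (isInput ? "input  " : "output ") << engine->getBindingName(i)
                  << " dtype=" << static_cast<int>(type)  // kFLOAT=0, kHALF=1, kINT8=2, kINT32=3
                  << std::endl;
    }
}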

This data type limitation in TensorRT is even more unfriendly to computer vision models. Images and videos in computer vision tasks are usually stored as uint8 data ([0, 255]), which is not supported by TensorRT as an input type. The images therefore must be converted to float before running TensorRT inference. In some tasks the input resolution is large, such as 4K or 8K, and it is slow both to convert the data from uint8 to float on the CPU and to copy it from CPU memory to GPU memory. In some cases, the time cost of preprocessing and transmission becomes the bottleneck in model deployment.
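
As a rough illustration of the transfer cost, the snippet below computes the host-to-device copy size for the 1x3x736x1280 input used in this blog and for a hypothetical 4K frame, in uint8 versus float32; the float32 buffer is always four times larger.

// Minimal sketch: compare uint8 vs float32 buffer sizes for two image resolutions.
#include <cstdio>

int main()
{
    const long long shapes[2][3] = {{736, 1280, 3}, {2160, 3840, 3}};  // blog input, 4K frame
    for (const auto& s : shapes)
    {
        const long long pixels = s[0] * s[1] * s[2];
        std::printf("%lldx%lldx%lld: uint8 %.2f MB, float32 %.2f MB\n",
                    s[0], s[1], s[2], pixels / 1e6, pixels * sizeof(float) / 1e6);
    }
    return 0;
}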

Most TRT projects on GitHub use the official TensorRT example to preprocess the image data on the CPU, while I prefer to use OpenCV functions to preprocess the images. The code for these two methods is shown below.

Code: Preprocess data on the CPU

TensorRT official preprocessing code

bool SampleUffSSD::processInput(const samplesCommon::BufferManager& buffers)
{
    const int inputC = mInputDims.d[0];
    const int inputH = mInputDims.d[1];
    const int inputW = mInputDims.d[2];
    const int batchSize = mParams.batchSize;

    // Available images
    std::vector<std::string> imageList = {"dog.ppm", "bus.ppm"};
    mPPMs.resize(batchSize);
    assert(mPPMs.size() <= imageList.size());
    for (int i = 0; i < batchSize; ++i)
    {
        readPPMFile(locateFile(imageList[i], mParams.dataDirs), mPPMs[i]);
    }

    // Host memory for input buffer
    float* hostDataBuffer = static_cast<float*>(buffers.getHostBuffer(mParams.inputTensorNames[0]));
    for (int i = 0, volImg = inputC * inputH * inputW; i < mParams.batchSize; ++i)
    {
        for (int c = 0; c < inputC; ++c)
        {
            // The color image to input should be in BGR order
            for (unsigned j = 0, volChl = inputH * inputW; j < volChl; ++j)
            {
                // Convert uint8 HWC pixels to float CHW and scale to [-1, 1]
                hostDataBuffer[i * volImg + c * volChl + j]
                    = (2.0 / 255.0) * float(mPPMs[i].buffer[j * inputC + c]) - 1.0;
            }
        }
    }

    return true;
}

OpenCV preprocessing code

int imageToTensor(const std::vector<cv::Mat>& images, float* tensor, const int batch_size, const float alpha, const float beta)
{
    const size_t height = images[0].rows;
    const size_t width = images[0].cols;
    const size_t channels = images[0].channels();
    const size_t stridesCv[3] = {width * channels, channels, 1};
    const size_t strides[4] = {height * width * channels, height * width, width, 1};

#pragma omp parallel for num_threads(c_numOmpThread) schedule(static, 1)
    for (int b = 0; b < batch_size; b++)
    {
        // Convert uint8 HWC to float with y = alpha * x + beta
        cv::Mat image_f;
        images[b].convertTo(image_f, CV_32F, alpha, beta);

        // Split the interleaved channels directly into the planar (CHW) output tensor
        std::vector<cv::Mat> split_channels = {
            cv::Mat(images[b].size(), CV_32FC1, tensor + b * strides[0]),
            cv::Mat(images[b].size(), CV_32FC1, tensor + b * strides[0] + strides[1]),
            cv::Mat(images[b].size(), CV_32FC1, tensor + b * strides[0] + 2 * strides[1]),
        };
        cv::split(image_f, split_channels);
    }
    return batch_size * height * width * channels;
}
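
For context, a minimal calling sketch for this CPU path is shown below: the batch is packed into a contiguous host buffer by imageToTensor with alpha = 1/255 and beta = 0, and then copied to the GPU input buffer with cudaMemcpy. The buffer names and image path are assumptions, not the exact benchmark code.

// Minimal usage sketch for the CPU preprocessing path (assumed names and paths).
#include <cuda_runtime.h>
#include <opencv2/opencv.hpp>
#include <vector>

int imageToTensor(const std::vector<cv::Mat>& images, float* tensor, const int batch_size,
                  const float alpha, const float beta);  // defined above

int main()
{
    const int batch_size = 1;
    std::vector<cv::Mat> images = {cv::imread("test.jpg")};  // assumed 736x1280 BGR image
    const size_t count = images[0].total() * images[0].channels();

    // Host-side staging buffer holding the normalized NCHW tensor.
    std::vector<float> host_tensor(batch_size * count);
    imageToTensor(images, host_tensor.data(), batch_size, 1.0f / 255.0f, 0.0f);  // Image2Float

    // Copy the float tensor to the GPU buffer bound to the TRT input.
    float* device_input = nullptr;
    cudaMalloc((void**)&device_input, host_tensor.size() * sizeof(float));
    cudaMemcpy(device_input, host_tensor.data(), host_tensor.size() * sizeof(float),
               cudaMemcpyHostToDevice);  // Copy2GPU

    // ... run TRT inference with device_input as the input binding ...
    cudaFree(device_input);
    return 0;
}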

The time cost of the OpenCV preprocessing method (ms)

GPU (Precision)    Image2Float    Copy2GPU    Inference    GPU2CPU
T4 (int8)          2.53026        0.935451    2.56143      0.0210528

As shown in the example preprocessing code above, on the CPU the data is first converted to float and normalized to [0, 1], and the layout is permuted from NHWC to NCHW. The float data is then transferred to GPU memory for TRT model inference. The time cost table shows that the image preprocessing and transfer time for this model (about 3.5 ms) is actually greater than the model inference time (about 2.6 ms). In this case, the model deployed on the GPU is not efficient and still has room to speed up.

GPU Image Preprocessing by NPP

As mentioned, there are two reasons the CPU image preprocessing is slow: converting the image from uint8 to float32 on the CPU is inefficient, and since float32 data is 4 times larger than uint8 data, transferring the float data from CPU memory to GPU memory takes longer. A simple way to speed this up is therefore to transfer the uint8 data to the GPU and let the GPU complete the conversion from uint8 to float32. These steps can be done easily and efficiently with NPP.

NVIDIA NPP is a CUDA library for GPU-accelerated 2D image and signal processing. It contains multiple submodules that allow users to efficiently perform image computations on the GPU, such as data type conversion, color or geometric transformations, and so on. In this example, the NPPC (core), NPPIDEI (data exchange and initialization), and NPPIAL (arithmetic and logic) modules are used in the image preprocessing to convert the data type from uint8 to float32, reorder the color channels, and apply the normalization. The code is shown as follows.

Code: Preprocess data on the GPU

NPP preprocessing code

int imageToTensorGPUFloat(const std::vector<cv::Mat>& images, void* gpu_images, void* tensor, const int batch_size, const float alpha)
{
    const int height = images[0].rows;
    const int width = images[0].cols;
    const size_t channels = images[0].channels();
    const size_t stride = height * width * channels;   // elements per image
    const size_t stride_s = width * channels;          // row step in elements (equals bytes for uint8)
    const int dstOrder[3] = {2, 1, 0};                 // swap the color channel order (e.g. BGR <-> RGB)
    Npp32f scale[3] = {alpha, alpha, alpha};
    NppiSize dstSize = {width, height};

#pragma omp parallel for num_threads(c_numOmpThread) schedule(static, 1)
    for (int b = 0; b < batch_size; b++)
    {
        // Copy the uint8 image to the GPU, swap the channel order, convert to float32, and scale by alpha.
        // The output pointer is offset by b * stride so each image in the batch writes to its own slice.
        cudaMemcpy((Npp8u*)gpu_images + b * stride, images[b].data, stride, cudaMemcpyHostToDevice);
        nppiSwapChannels_8u_C3IR((Npp8u*)gpu_images + b * stride, stride_s, dstSize, dstOrder);
        nppiConvert_8u32f_C3R((Npp8u*)gpu_images + b * stride, stride_s, (Npp32f*)tensor + b * stride, stride_s * sizeof(float), dstSize);
        nppiMulC_32f_C3IR(scale, (Npp32f*)tensor + b * stride, stride_s * sizeof(float), dstSize);
    }
    return batch_size * stride;
}
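
A hypothetical allocation and calling sketch is shown below: gpu_images holds the uint8 batch and tensor is the float32 buffer that would be bound as the TRT input. The buffer names and the 1/255 scale are assumptions for illustration.

// Minimal usage sketch for the NPP preprocessing path (assumed buffer names and sizes).
#include <cuda_runtime.h>
#include <npp.h>
#include <opencv2/opencv.hpp>
#include <vector>

int imageToTensorGPUFloat(const std::vector<cv::Mat>& images, void* gpu_images, void* tensor,
                          const int batch_size, const float alpha);  // defined above

int main()
{
    const int batch_size = 1;
    std::vector<cv::Mat> images = {cv::imread("test.jpg")};  // assumed 736x1280 BGR image
    const size_t count = images[0].total() * images[0].channels();

    // Scratch buffer for the uint8 batch and the float32 buffer used as the TRT input binding.
    void* gpu_images = nullptr;
    void* tensor = nullptr;
    cudaMalloc(&gpu_images, batch_size * count * sizeof(Npp8u));
    cudaMalloc(&tensor, batch_size * count * sizeof(Npp32f));

    imageToTensorGPUFloat(images, gpu_images, tensor, batch_size, 1.0f / 255.0f);
    cudaDeviceSynchronize();  // make sure the NPP calls have finished before inference

    // ... run TRT inference with tensor as the input binding ...
    cudaFree(gpu_images);
    cudaFree(tensor);
    return 0;
}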

NPP preprocessing code (without normalization and channel permutation)

int imageToTensorGPUFloat(const std::vector<cv::Mat>& images, void* gpu_images, void* tensor, const int batch_size)
{
    const int height = images[0].rows;
    const int width = images[0].cols;
    const size_t channels = images[0].channels();
    const size_t stride = height * width * channels;   // elements per image
    NppiSize dstSize = {width, height};

#pragma omp parallel for num_threads(c_numOmpThread) schedule(static, 1)
    for (int b = 0; b < batch_size; b++)
    {
        // Copy the uint8 image to the GPU and convert it to float32 directly into its slice of the output tensor.
        cudaMemcpy((Npp8u*)gpu_images + b * stride, images[b].data, stride, cudaMemcpyHostToDevice);
        nppiConvert_8u32f_C3R((Npp8u*)gpu_images + b * stride, width * channels,
                              (Npp32f*)tensor + b * stride, width * channels * sizeof(float), dstSize);
    }
    return batch_size * stride;
}
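
After the conversion, the float32 buffer can be used directly as the TRT input binding, so no extra device-to-device copy is needed. The sketch below shows one way this could look with the TensorRT 7 C++ API; context, tensor, and the output buffers are assumed to already exist, and the binding indices are looked up by assumed tensor names.

// Minimal inference sketch: the NPP output buffer is passed directly as the input binding.
// Assumptions: `context` is a valid IExecutionContext*, `tensor` was filled by
// imageToTensorGPUFloat, and "input"/"output" are the model's tensor names.
#include <NvInfer.h>
#include <cuda_runtime.h>

void runInference(nvinfer1::IExecutionContext* context, void* tensor, void* gpu_output,
                  float* host_output, size_t output_bytes, cudaStream_t stream)
{
    const nvinfer1::ICudaEngine& engine = context->getEngine();
    const int inputIndex = engine.getBindingIndex("input");    // assumed tensor name
    const int outputIndex = engine.getBindingIndex("output");  // assumed tensor name

    void* bindings[2] = {nullptr, nullptr};
    bindings[inputIndex] = tensor;       // float32 buffer produced by the NPP preprocessing
    bindings[outputIndex] = gpu_output;  // device buffer for the detection outputs

    // Enqueue inference and copy the result back to the host (the GPU2CPU step in the tables).
    context->enqueueV2(bindings, stream, nullptr);
    cudaMemcpyAsync(host_output, gpu_output, output_bytes, cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);
}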

The time cost of the NPP preprocessing method (without normalization and channel permutation) (ms)

GPU (Precision)    Image2GPU2Float    Inference    GPU2CPU
T4 (int8)          0.532469           3.07869      0.0208867

As shown in the example code above, for preprocessing on the GPU the uint8 data is first transferred to GPU memory; the channel order is then swapped, and the uint8 data is converted to float and normalized to [0, 1]. The normalized data is written directly into the GPU memory reserved for the TRT model input. Since elementwise operations and channel permutes are performed efficiently inside the TRT model, the normalization and channel conversion can be moved from the preprocessing into the model as operations, which is why the timing above uses the version without them. Compared with CPU image preprocessing, the image preprocessing time is reduced from about 3.5 ms to about 0.5 ms, the total running time of the whole pipeline is reduced from about 6 ms to 3.5 ms, and the throughput rises from about 166 to about 285 frames per second (fps), an overall speedup of roughly 1.7x.

It should be noted that, due to the long conversion time of the TRT model, this blog only tests the execution speed with a batch size of 1. If large batches are used at deployment time and the GPU preprocessing becomes slow, the overhead may come from issuing many small copies and kernel launches, one per image. In that case, it is better to copy the entire batch of images to GPU memory in one transfer and then run the data type conversion over the whole batch, as sketched below. Another way to speed up the process is to handle each image sample in its own CUDA stream.
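
A minimal sketch of the batched variant is shown below, under the assumption that all images have the same size and that the batch is staged in one contiguous host buffer: the whole batch is copied with a single cudaMemcpyAsync and converted with a single nppiConvert_8u32f_C3R call by treating the contiguous batch as one tall image (ROI height = batch_size * height). The function and parameter names are assumptions for illustration.

// Minimal sketch: preprocess a whole batch with one copy and one NPP call
// (assumption: all images share the same HxWx3 uint8 layout).
#include <cuda_runtime.h>
#include <npp.h>
#include <opencv2/opencv.hpp>
#include <cstring>
#include <vector>

int imageBatchToTensorGPUFloat(const std::vector<cv::Mat>& images, Npp8u* host_staging,
                               Npp8u* gpu_images, Npp32f* tensor, cudaStream_t stream)
{
    const int height = images[0].rows;
    const int width = images[0].cols;
    const int channels = images[0].channels();
    const size_t image_bytes = (size_t)height * width * channels;   // uint8: 1 byte per element
    const int batch_size = (int)images.size();

    // Stage the batch into one contiguous (ideally pinned) host buffer.
    for (int b = 0; b < batch_size; b++)
        std::memcpy(host_staging + b * image_bytes, images[b].data, image_bytes);

    // One host-to-device copy for the whole batch.
    cudaMemcpyAsync(gpu_images, host_staging, batch_size * image_bytes, cudaMemcpyHostToDevice, stream);
    cudaStreamSynchronize(stream);  // the NPP call below runs on its own (default) stream

    // Treat the contiguous batch as a single image of height batch_size * height.
    NppiSize roi = {width, batch_size * height};
    nppiConvert_8u32f_C3R(gpu_images, width * channels,
                          tensor, width * channels * (int)sizeof(float), roi);

    return batch_size * height * width * channels;
}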

Reference