// This file is part of OpenCV project. // It is subject to the license terms in the LICENSE file found in the top-level directory // of this distribution and at http://opencv.org/license.html. #ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_SOFTMAX_HPP #define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_SOFTMAX_HPP #include "cudnn.hpp" #include "../pointer.hpp" #include namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn { /** @brief computes softmax (or log softmax) * * @tparam T element type (must be `half` or `float`) * * @param handle valid cuDNN handle * @param outputDesc tensor descriptor for A * @param[out] output pointer to tensor in device memory * @param inputDesc tensor descriptor for C * @param[in] input pointer to tensor in device memory * @param log apply log on probabilities * * Exception Guarantee: Basic */ template void softmax(const cudnn::Handle& handle, const TensorDescriptor& outputDesc, DevicePtr output, const TensorDescriptor& inputDesc, DevicePtr input, bool log) { T alpha = 1.0, beta = 0.0; cudnnSoftmaxAlgorithm_t algo = log ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE; CUDA4DNN_CHECK_CUDNN( cudnnSoftmaxForward( handle.get(), algo, CUDNN_SOFTMAX_MODE_CHANNEL, &alpha, inputDesc.get(), input.get(), &beta, outputDesc.get(), output.get() ) ); } template <> inline void softmax(const cudnn::Handle& handle, const TensorDescriptor& outputDesc, DevicePtr output, const TensorDescriptor& inputDesc, DevicePtr input, bool log) { /* we specalize for fp16 as the scaling factors must be provided as `float` */ float alpha = 1.0, beta = 0.0; cudnnSoftmaxAlgorithm_t algo = log ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE; CUDA4DNN_CHECK_CUDNN( cudnnSoftmaxForward( handle.get(), algo, CUDNN_SOFTMAX_MODE_CHANNEL, &alpha, inputDesc.get(), input.get(), &beta, outputDesc.get(), output.get() ) ); } }}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */ #endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_SOFTMAX_HPP */