convolution.hpp 24.9 KB
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_CONVOLUTION_HPP
#define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_CONVOLUTION_HPP

#include "cudnn.hpp"
#include "activation.hpp"

#include "../pointer.hpp"
#include "../workspace.hpp"

#include <cudnn.h>

#include <cstddef>
#include <array>
#include <algorithm>
#include <vector>
#include <type_traits>
#include <iterator>

namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn {

    /** describes convolution filters
     *
     * Owns a `cudnnFilterDescriptor_t`. The type is move-only; a default-constructed or
     * moved-from object holds a null descriptor.
     *
     * @tparam  T   type of elements in the kernels
     */
    template <class T>
    class FilterDescriptor {
    public:
        FilterDescriptor() noexcept : descriptor{ nullptr } { }
        FilterDescriptor(const FilterDescriptor&) = delete;
        FilterDescriptor(FilterDescriptor&& other) noexcept
            : descriptor{ other.descriptor } {
            other.descriptor = nullptr;
        }

        /** constructs a filter descriptor from the filter dimensions provided in \p shape
         *
         * Shape dimensions:
         * 0: number of filters
         * 1: number of input feature maps
         * 2..n: kernel dimensions
         *
         * Exception Guarantee: Strong
         */
        template <class SequenceContainer, typename = decltype(std::begin(std::declval<SequenceContainer>()))>
        FilterDescriptor(const SequenceContainer& shape) {
            constructor(shape.begin(), shape.end());
        }

        /** constructs a filter descriptor from the filter dimensions provided in [begin, end)
         *
         * Shape dimensions:
         * 0: number of filters
         * 1: number of input feature maps
         * 2..n: kernel dimensions
         *
         * Exception Guarantee: Strong
         */
        template <class ForwardItr, typename = typename std::enable_if<!std::is_integral<ForwardItr>::value, void>::type> // TODO is_iterator
        FilterDescriptor(ForwardItr begin, ForwardItr end) {
            constructor(begin, end);
        }

        /** constructs a filter descriptor from the filter dimensions provided as arguments
         *
         * Shape dimensions:
         * 0: number of filters
         * 1: number of input feature maps
         * 2..n: kernel dimensions
         *
         * Exception Guarantee: Strong
         */
        template <class ...Sizes>
        FilterDescriptor(Sizes ...sizes) {
            static_assert(sizeof...(Sizes) >= 3, "filter descriptors must have at least three dimensions");
            static_assert(sizeof...(Sizes) <= CUDNN_DIM_MAX, "required rank exceeds maximum supported rank");
            std::array<int, sizeof...(Sizes)> dims = { static_cast<int>(sizes)... };
            constructor(std::begin(dims), std::end(dims));
        }

        ~FilterDescriptor() noexcept {
            if (descriptor != nullptr) {
                /* cudnnDestroyFilterDescriptor will not fail for a valid descriptor object */
                CUDA4DNN_CHECK_CUDNN(cudnnDestroyFilterDescriptor(descriptor));
            }
        }

        FilterDescriptor& operator=(const FilterDescriptor&) = delete;

        /** takes ownership of \p other's descriptor, releasing the one currently held
         *
         * Self-assignment is a no-op. \p other is left holding a null descriptor.
         */
        FilterDescriptor& operator=(FilterDescriptor&& other) noexcept {
            if (this != &other) {
                /* release our current descriptor first; overwriting it would leak it */
                if (descriptor != nullptr) {
                    /* cudnnDestroyFilterDescriptor will not fail for a valid descriptor object */
                    CUDA4DNN_CHECK_CUDNN(cudnnDestroyFilterDescriptor(descriptor));
                }
                descriptor = other.descriptor;
                other.descriptor = nullptr;
            }
            return *this;
        }

        cudnnFilterDescriptor_t get() const noexcept { return descriptor; }

    private:
        /* creates and configures `descriptor` from the dimensions in [start, end);
         * on failure the freshly created descriptor is destroyed before rethrowing
         */
        template <class ForwardItr>
        void constructor(ForwardItr start, ForwardItr end) {
            CV_Assert(start != end);
            CV_Assert(std::distance(start, end) >= 3);
            CV_Assert(std::distance(start, end) <= CUDNN_DIM_MAX);

            CUDA4DNN_CHECK_CUDNN(cudnnCreateFilterDescriptor(&descriptor));
            try {
                const auto rank = std::distance(start, end);
                if (rank == 4) {
                    std::array<int, 4> dims;
                    std::copy(start, end, std::begin(dims));
                    CUDA4DNN_CHECK_CUDNN(
                        cudnnSetFilter4dDescriptor(
                            descriptor,
                            detail::get_data_type<T>(), CUDNN_TENSOR_NCHW,
                            dims[0], dims[1], dims[2], dims[3]
                        )
                    );
                } else {
                    std::vector<int> dims(start, end);
                    CUDA4DNN_CHECK_CUDNN(
                        cudnnSetFilterNdDescriptor(
                            descriptor,
                            detail::get_data_type<T>(), CUDNN_TENSOR_NCHW,
                            static_cast<int>(dims.size()), dims.data()
                        )
                    );
                }
            } catch (...) {
                /* cudnnDestroyFilterDescriptor will not fail for a valid descriptor object */
                CUDA4DNN_CHECK_CUDNN(cudnnDestroyFilterDescriptor(descriptor));
                throw;
            }
        }

        cudnnFilterDescriptor_t descriptor;
    };

    /** describes a convolution operation
     *
     * @tparam  T   type of element participating in convolution
     */
    template <class T>
    class ConvolutionDescriptor {
    public:
        ConvolutionDescriptor() noexcept : descriptor{ nullptr } { }
        ConvolutionDescriptor(const ConvolutionDescriptor&) = delete;
        ConvolutionDescriptor(ConvolutionDescriptor&& other) noexcept
            : descriptor{ other.descriptor } {
            other.descriptor = nullptr;
        }

        /** constructs a convolution descriptor
         *
         * Pre-conditions:
         * - \p zero_padding, \p stride and \p dilation must have the same size
         *
         * The length of the containers is interpreted as the order of the convolution.
         *
         * Exception Guarantee: Strong
         */
        template <class SequenceContainer, typename = decltype(std::begin(std::declval<SequenceContainer>()))>
        ConvolutionDescriptor(
            const SequenceContainer& zero_padding,
            const SequenceContainer& stride,
            const SequenceContainer& dilation,
            std::size_t group_count)
        {
            constructor(zero_padding, stride, dilation, group_count);
        }

        ~ConvolutionDescriptor() noexcept {
            if (descriptor != nullptr) {
                /* cudnnDestroyConvolutionDescriptor will not fail for a valid descriptor object */
                CUDA4DNN_CHECK_CUDNN(cudnnDestroyConvolutionDescriptor(descriptor));
            }
        }

        ConvolutionDescriptor& operator=(const ConvolutionDescriptor&) = delete;
        ConvolutionDescriptor& operator=(ConvolutionDescriptor&& other) noexcept {
            descriptor = other.descriptor;
            other.descriptor = nullptr;
            return *this;
        };

        cudnnConvolutionDescriptor_t get() const noexcept { return descriptor; }

    private:
        template <class SequenceContainer>
        void constructor(
            const SequenceContainer& zero_padding,
            const SequenceContainer& stride,
            const SequenceContainer& dilation,
            std::size_t group_count)
        {
            CV_Assert(zero_padding.size() == stride.size());
            CV_Assert(zero_padding.size() == dilation.size());

            CUDA4DNN_CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&descriptor));
            try {
                const auto rank = zero_padding.size();
                if (rank == 2) {
                    CUDA4DNN_CHECK_CUDNN(
                        cudnnSetConvolution2dDescriptor(
                            descriptor,
                            zero_padding[0], zero_padding[1],
                            stride[0], stride[1],
                            dilation[0], dilation[1],
                            CUDNN_CROSS_CORRELATION,
                            detail::get_data_type<T>()
                        )
                    );
                } else {
                    std::vector<int> ipadding(std::begin(zero_padding), std::end(zero_padding));
                    std::vector<int> istride(std::begin(stride), std::end(stride));
                    std::vector<int> idilation(std::begin(dilation), std::end(dilation));
                    CUDA4DNN_CHECK_CUDNN(
                        cudnnSetConvolutionNdDescriptor(
                            descriptor,
                            rank, ipadding.data(), istride.data(), idilation.data(),
                            CUDNN_CROSS_CORRELATION,
                            detail::get_data_type<T>()
                        )
                    );
                }
                CUDA4DNN_CHECK_CUDNN(cudnnSetConvolutionGroupCount(descriptor, group_count));

#if CUDNN_MAJOR >= 8
                /* cuDNN 7 and below use FMA math by default. cuDNN 8 includes TF32 Tensor Ops
                 * in the default setting. TF32 convolutions have lower precision than FP32.
                 * Hence, we set the math type to CUDNN_FMA_MATH to reproduce old behavior.
                 */
                CUDA4DNN_CHECK_CUDNN(cudnnSetConvolutionMathType(descriptor, CUDNN_FMA_MATH));
#endif

                if (std::is_same<T, half>::value)
                    CUDA4DNN_CHECK_CUDNN(cudnnSetConvolutionMathType(descriptor, CUDNN_TENSOR_OP_MATH));
            } catch (...) {
                /* cudnnDestroyConvolutionDescriptor will not fail for a valid descriptor object */
                CUDA4DNN_CHECK_CUDNN(cudnnDestroyConvolutionDescriptor(descriptor));
                throw;
            }
        }

        cudnnConvolutionDescriptor_t descriptor;
    };

    /** wrapper around a convolution algorithm
     *
     * @tparam  T   type of elements being convolved
     */
    template <class T>
    class ConvolutionAlgorithm {
    public:
        ConvolutionAlgorithm() noexcept : workspace_size{ 0 } { }
        ConvolutionAlgorithm(ConvolutionAlgorithm&) = default;
        ConvolutionAlgorithm(ConvolutionAlgorithm&&) = default;

        /** selects a good algorithm for convolution for given configuration
         *
         * Exception Guarantee: Strong
         */
        ConvolutionAlgorithm(
            const Handle& handle,
            const ConvolutionDescriptor<T>& convDesc,
            const FilterDescriptor<T>& filterDesc,
            const TensorDescriptor<T>& inputDesc,
            const TensorDescriptor<T>& outputDesc)
        {
#if CUDNN_MAJOR >= 8
            int requestedAlgoCount = 0, returnedAlgoCount = 0;
            CUDA4DNN_CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithmMaxCount(handle.get(), &requestedAlgoCount));
            std::vector<cudnnConvolutionFwdAlgoPerf_t> results(requestedAlgoCount);
            CUDA4DNN_CHECK_CUDNN(
                cudnnGetConvolutionForwardAlgorithm_v7(
                    handle.get(),
                    inputDesc.get(), filterDesc.get(), convDesc.get(), outputDesc.get(),
                    requestedAlgoCount,
                    &returnedAlgoCount,
                    &results[0]
                )
            );

            size_t free_memory, total_memory;
            CUDA4DNN_CHECK_CUDA(cudaMemGetInfo(&free_memory, &total_memory));

            bool found_conv_algorithm = false;
            for (int i = 0; i < returnedAlgoCount; i++)
            {
                if (results[i].status == CUDNN_STATUS_SUCCESS &&
                    results[i].algo != CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
                    results[i].memory < free_memory)
                {
                    found_conv_algorithm = true;
                    algo = results[i].algo;
                    workspace_size = results[i].memory;
                    break;
                }
            }

            if (!found_conv_algorithm)
                CV_Error (cv::Error::GpuApiCallError, "cuDNN did not return a suitable algorithm for convolution.");
#else
            CUDA4DNN_CHECK_CUDNN(
                cudnnGetConvolutionForwardAlgorithm(
                    handle.get(),
                    inputDesc.get(), filterDesc.get(), convDesc.get(), outputDesc.get(),
                    CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
                    0, /* no memory limit */
                    &algo
                )
            );

            CUDA4DNN_CHECK_CUDNN(
                cudnnGetConvolutionForwardWorkspaceSize(
                    handle.get(),
                    inputDesc.get(), filterDesc.get(), convDesc.get(), outputDesc.get(),
                    algo, &workspace_size
                )
            );
#endif
        }

        ConvolutionAlgorithm& operator=(const ConvolutionAlgorithm&) = default;
        ConvolutionAlgorithm& operator=(ConvolutionAlgorithm&& other) = default;

        cudnnConvolutionFwdAlgo_t get() const noexcept { return algo; }

        /** number of bytes of workspace memory required by the algorithm */
        std::size_t get_workspace_size() const noexcept { return workspace_size; }

    private:
        cudnnConvolutionFwdAlgo_t algo;
        std::size_t workspace_size;
    };

    /** gives the shape of the output tensor of convolution
     *
     * Exception Guarantee: Basic
     */
    template <class T>
    void getConvolutionForwardOutputDim(
        const ConvolutionDescriptor<T>& convDesc,
        const FilterDescriptor<T>& filterDesc,
        const TensorDescriptor<T>& inputDesc,
        std::vector<int>& output)
    {
        output.clear();
        output.resize(CUDNN_DIM_MAX); /* we use `output` to hold temporaries */

        std::vector<int> temp(CUDNN_DIM_MAX);
        cudnnDataType_t tempDataType;
        CUDA4DNN_CHECK_CUDNN(
            cudnnGetTensorNdDescriptor(
                inputDesc.get(),
                CUDNN_DIM_MAX + 1, /* according to docs, this is what we do to get the rank */
                &tempDataType,
                output.data(),
                temp.data(),
                temp.data()
            )
        );

        const auto rank = output[0];
        output.resize(rank);
        CUDA4DNN_CHECK_CUDNN(
            cudnnGetConvolutionNdForwardOutputDim(
                convDesc.get(), inputDesc.get(), filterDesc.get(), rank, output.data()
            )
        );
    }

    /** @brief performs convolution
     *
     * Runs `cudnnConvolutionForward`, computing
     * dstValue = alpha * result + beta * priorDstValue
     *
     * @tparam          T           convolution element type (must be `half` or `float`)
     *
     * @param           handle      valid cuDNN Handle (asserted to be non-empty)
     * @param           convDesc    convolution description
     * @param           convAlgo    algorithm to use for convolution
     * @param           workspace   workspace memory which meets the requirements of \p convAlgo
     * @param           filterDesc  filter descriptor
     * @param[in]       filterPtr   pointer to device memory containing the filters
     * @param           inputDesc   tensor descriptor describing the input
     * @param[in]       inputPtr    pointer to input tensor in device memory
     * @param           alpha       result scale factor
     * @param           beta        previous value scale factor
     * @param           outputDesc  tensor descriptor describing the output
     * @param[out]      outputPtr   pointer to output tensor in device memory
     *
     * Exception Guarantee: Basic
     */
    template <class T>
    void convolve(
        const Handle& handle,
        const ConvolutionDescriptor<T>& convDesc,
        const ConvolutionAlgorithm<T>& convAlgo,
        WorkspaceInstance workspace,
        const FilterDescriptor<T>& filterDesc,
        DevicePtr<const T> filterPtr,
        const TensorDescriptor<T>& inputDesc,
        DevicePtr<const T> inputPtr,
        T alpha, T beta,
        const TensorDescriptor<T>& outputDesc,
        DevicePtr<T> outputPtr)
    {
        CV_Assert(handle);

        /* scratch memory handed to cuDNN for the selected algorithm */
        void* const scratch = static_cast<void*>(workspace.get());
        const auto scratch_bytes = workspace.size_in_bytes();

        CUDA4DNN_CHECK_CUDNN(cudnnConvolutionForward(
            handle.get(),
            &alpha,
            inputDesc.get(), inputPtr.get(),
            filterDesc.get(), filterPtr.get(),
            convDesc.get(), convAlgo.get(),
            scratch, scratch_bytes,
            &beta,
            outputDesc.get(), outputPtr.get()));
    }

    /* fp16 specialization of `convolve`: cuDNN expects the scaling factors for
     * half-precision data to be supplied as `float`, so alpha/beta are widened first.
     */
    template <> inline
    void convolve(
        const Handle& handle,
        const ConvolutionDescriptor<half>& convDesc,
        const ConvolutionAlgorithm<half>& convAlgo,
        WorkspaceInstance workspace,
        const FilterDescriptor<half>& filterDesc,
        DevicePtr<const half> filterPtr,
        const TensorDescriptor<half>& inputDesc,
        DevicePtr<const half> inputPtr,
        half alpha, half beta,
        const TensorDescriptor<half>& outputDesc,
        DevicePtr<half> outputPtr)
    {
        CV_Assert(handle);

        const float result_scale = alpha;
        const float prior_scale = beta;
        void* const scratch = static_cast<void*>(workspace.get());
        const auto scratch_bytes = workspace.size_in_bytes();

        CUDA4DNN_CHECK_CUDNN(cudnnConvolutionForward(
            handle.get(),
            &result_scale,
            inputDesc.get(), inputPtr.get(),
            filterDesc.get(), filterPtr.get(),
            convDesc.get(), convAlgo.get(),
            scratch, scratch_bytes,
            &prior_scale,
            outputDesc.get(), outputPtr.get()));
    }

    /** @brief performs convolution, bias addition and activation simultaneously
     *
     * Runs `cudnnConvolutionBiasActivationForward`, computing
     * dstValue = act(alpha * conv(input) + bias)
     *
     * The output tensor is also passed as cuDNN's `z` operand, with a zero scale factor,
     * so its prior contents do not contribute to the result.
     *
     * @tparam          T           convolution element type (must be `half` or `float`)
     *
     * @param           handle      valid cuDNN Handle (asserted to be non-empty)
     * @param           convDesc    convolution description
     * @param           convAlgo    algorithm to use for convolution
     * @param           workspace   workspace memory which meets the requirements of \p convAlgo
     * @param           filterDesc  filter descriptor
     * @param[in]       filterPtr   pointer to device memory containing the filters
     * @param           alpha       convolution scale factor
     * @param           inputDesc   tensor descriptor describing the input
     * @param[in]       inputPtr    pointer to input tensor in device memory
     * @param           biasDesc    tensor descriptor describing the bias
     * @param[in]       biasPtr     pointer to bias tensor in device memory
     * @param           actDesc     activation descriptor
     * @param           outputDesc  tensor descriptor describing the output
     * @param[out]      outputPtr   pointer to output tensor in device memory
     *
     * Exception Guarantee: Basic
     */
    template <class T>
    void convolve_with_bias_activation(
        const Handle& handle,
        T alpha,
        const ConvolutionDescriptor<T>& convDesc,
        const ConvolutionAlgorithm<T>& convAlgo,
        WorkspaceInstance workspace,
        const FilterDescriptor<T>& filterDesc,
        DevicePtr<const T> filterPtr,
        const TensorDescriptor<T>& inputDesc,
        DevicePtr<const T> inputPtr,
        const TensorDescriptor<T>& biasDesc,
        DevicePtr<const T> biasPtr,
        const ActivationDescriptor& actDesc,
        const TensorDescriptor<T>& outputDesc,
        DevicePtr<T> outputPtr)
    {
        CV_Assert(handle);

        /* scale for the `z` operand (the output itself); zero discards its prior contents */
        T z_scale = static_cast<T>(0.0);
        void* const scratch = static_cast<void*>(workspace.get());
        const auto scratch_bytes = workspace.size_in_bytes();

        CUDA4DNN_CHECK_CUDNN(cudnnConvolutionBiasActivationForward(
            handle.get(),
            &alpha,
            inputDesc.get(), inputPtr.get(),
            filterDesc.get(), filterPtr.get(),
            convDesc.get(), convAlgo.get(),
            scratch, scratch_bytes,
            &z_scale,
            outputDesc.get(), outputPtr.get(),
            biasDesc.get(), biasPtr.get(),
            actDesc.get(),
            outputDesc.get(), outputPtr.get()));
    }

    /* fp16 specialization of `convolve_with_bias_activation`: the scaling factors are
     * passed to cuDNN as `float` for half-precision data.
     */
    template <> inline
    void convolve_with_bias_activation(
        const Handle& handle,
        half alpha,
        const ConvolutionDescriptor<half>& convDesc,
        const ConvolutionAlgorithm<half>& convAlgo,
        WorkspaceInstance workspace,
        const FilterDescriptor<half>& filterDesc,
        DevicePtr<const half> filterPtr,
        const TensorDescriptor<half>& inputDesc,
        DevicePtr<const half> inputPtr,
        const TensorDescriptor<half>& biasDesc,
        DevicePtr<const half> biasPtr,
        const ActivationDescriptor& actDesc,
        const TensorDescriptor<half>& outputDesc,
        DevicePtr<half> outputPtr)
    {
        CV_Assert(handle);

        const float conv_scale = alpha;
        /* the output doubles as the `z` operand; a zero scale discards its prior contents */
        const float z_scale = 0.0;
        void* const scratch = static_cast<void*>(workspace.get());
        const auto scratch_bytes = workspace.size_in_bytes();

        CUDA4DNN_CHECK_CUDNN(cudnnConvolutionBiasActivationForward(
            handle.get(),
            &conv_scale,
            inputDesc.get(), inputPtr.get(),
            filterDesc.get(), filterPtr.get(),
            convDesc.get(), convAlgo.get(),
            scratch, scratch_bytes,
            &z_scale,
            outputDesc.get(), outputPtr.get(),
            biasDesc.get(), biasPtr.get(),
            actDesc.get(),
            outputDesc.get(), outputPtr.get()));
    }

    /** @brief performs convolution, bias addition, eltwise addition and activation simultaneously
     *
     * Runs `cudnnConvolutionBiasActivationForward` with the eltwise tensor as the `z` operand,
     * computing
     * dstValue = act(alpha1 * conv(input) + bias + alpha2 * eltwise)
     *
     * @tparam          T           convolution element type (must be `half` or `float`)
     *
     * @param           handle      valid cuDNN Handle (asserted to be non-empty)
     * @param           convDesc    convolution description
     * @param           convAlgo    algorithm to use for convolution
     * @param           workspace   workspace memory which meets the requirements of \p convAlgo
     * @param           filterDesc  filter descriptor
     * @param[in]       filterPtr   pointer to device memory containing the filters
     * @param           alpha1      convolution scale factor
     * @param           inputDesc   tensor descriptor describing the input
     * @param[in]       inputPtr    pointer to input tensor in device memory
     * @param           biasDesc    tensor descriptor describing the bias
     * @param[in]       biasPtr     pointer to bias tensor in device memory
     * @param           alpha2      eltwise scale factor
     * @param           eltwiseDesc tensor descriptor describing the eltwise tensor
     * @param[in]       eltwisePtr  pointer to the eltwise tensor in device memory
     * @param           actDesc     activation descriptor
     * @param           outputDesc  tensor descriptor describing the output
     * @param[out]      outputPtr   pointer to output tensor in device memory
     *
     * Exception Guarantee: Basic
     */
    template <class T>
    void convolve_with_bias_eltwise_activation(
        const Handle& handle,
        T alpha1,
        const ConvolutionDescriptor<T>& convDesc,
        const ConvolutionAlgorithm<T>& convAlgo,
        WorkspaceInstance workspace,
        const FilterDescriptor<T>& filterDesc,
        DevicePtr<const T> filterPtr,
        const TensorDescriptor<T>& inputDesc,
        DevicePtr<const T> inputPtr,
        const TensorDescriptor<T>& biasDesc,
        DevicePtr<const T> biasPtr,
        T alpha2,
        const TensorDescriptor<T>& eltwiseDesc,
        DevicePtr<const T> eltwisePtr,
        const ActivationDescriptor& actDesc,
        const TensorDescriptor<T>& outputDesc,
        DevicePtr<T> outputPtr)
    {
        CV_Assert(handle);

        void* const scratch = static_cast<void*>(workspace.get());
        const auto scratch_bytes = workspace.size_in_bytes();

        CUDA4DNN_CHECK_CUDNN(cudnnConvolutionBiasActivationForward(
            handle.get(),
            &alpha1,
            inputDesc.get(), inputPtr.get(),
            filterDesc.get(), filterPtr.get(),
            convDesc.get(), convAlgo.get(),
            scratch, scratch_bytes,
            &alpha2,
            eltwiseDesc.get(), eltwisePtr.get(),
            biasDesc.get(), biasPtr.get(),
            actDesc.get(),
            outputDesc.get(), outputPtr.get()));
    }

    /* fp16 specialization of `convolve_with_bias_eltwise_activation`: both scaling factors
     * are passed to cuDNN as `float` for half-precision data.
     */
    template <> inline
    void convolve_with_bias_eltwise_activation(
        const Handle& handle,
        half alpha1,
        const ConvolutionDescriptor<half>& convDesc,
        const ConvolutionAlgorithm<half>& convAlgo,
        WorkspaceInstance workspace,
        const FilterDescriptor<half>& filterDesc,
        DevicePtr<const half> filterPtr,
        const TensorDescriptor<half>& inputDesc,
        DevicePtr<const half> inputPtr,
        const TensorDescriptor<half>& biasDesc,
        DevicePtr<const half> biasPtr,
        half alpha2,
        const TensorDescriptor<half>& eltwiseDesc,
        DevicePtr<const half> eltwisePtr,
        const ActivationDescriptor& actDesc,
        const TensorDescriptor<half>& outputDesc,
        DevicePtr<half> outputPtr)
    {
        CV_Assert(handle);

        const float conv_scale = alpha1;
        const float eltwise_scale = alpha2;
        void* const scratch = static_cast<void*>(workspace.get());
        const auto scratch_bytes = workspace.size_in_bytes();

        CUDA4DNN_CHECK_CUDNN(cudnnConvolutionBiasActivationForward(
            handle.get(),
            &conv_scale,
            inputDesc.get(), inputPtr.get(),
            filterDesc.get(), filterPtr.get(),
            convDesc.get(), convAlgo.get(),
            scratch, scratch_bytes,
            &eltwise_scale,
            eltwiseDesc.get(), eltwisePtr.get(),
            biasDesc.get(), biasPtr.get(),
            actDesc.get(),
            outputDesc.get(), outputPtr.get()));
    }

}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */

#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_CONVOLUTION_HPP */