/*
 *  Copyright 2008-2013 NVIDIA Corporation
 *  Modifications Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved.
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

#define THRUST_ENABLE_FUTURE_RAW_DATA_MEMBER

#include <thrust/detail/config.h>

#if THRUST_CPP_DIALECT >= 2017

#  include <thrust/async/copy.h>
#  include <thrust/async/reduce.h>
#  include <thrust/device_make_unique.h>
#  include <thrust/device_vector.h>
#  include <thrust/host_vector.h>

#  include "test_param_fixtures.hpp"
#  include "test_real_assertions.hpp"
#  include "test_utils.hpp"

TESTS_DEFINE(AsyncReduceIntoTests, NumericalTestsParams);

THRUST_SUPPRESS_DEPRECATED_PUSH

// User-defined binary addition functor. The `_custom_plus` invokers pass it to
// `reduce_into`/`reduce` instead of `thrust::plus<T>`, so that a functor not
// provided by Thrust is exercised as the reduction operator.
template <typename T>
struct custom_plus
{
  THRUST_HOST_DEVICE T operator()(T a, T b) const
  {
    return a + b;
  }
};

// Defines a class template `NAME<T>` whose call operator
// `operator()(first, last, output)` returns
// `::thrust::async::reduce_into(__VA_ARGS__)`. Inside `__VA_ARGS__` the
// call-operator parameters are visible as `first`, `last` and `output`, and the
// template parameter as `T`.
//
//   MEMBERS  - declarations spliced verbatim into the class body.
//   CTOR     - statements executed by the default constructor.
//   DTOR     - statements executed by the destructor. Destructors are
//              implicitly noexcept, so if DTOR throws, std::terminate is called.
//   VALIDATE - statements executed by `validate_event(e)`, where `e` is the
//              object returned by the call operator.
//
// Because CTOR/DTOR may acquire and release a resource held in MEMBERS, copy
// construction and copy assignment are deleted. Otherwise a copy would run
// DTOR twice on the same resource.
//
// Note: the macro body uses line continuations, so it must not contain `//`
// comments; keep documentation outside the #define.
#  define DEFINE_STATEFUL_ASYNC_REDUCE_INTO_INVOKER(NAME, MEMBERS, CTOR, DTOR, VALIDATE, ...) \
    template <typename T>                                                                     \
    struct NAME                                                                               \
    {                                                                                         \
      MEMBERS                                                                                 \
                                                                                              \
      NAME()                                                                                  \
      {                                                                                       \
        CTOR                                                                                  \
      }                                                                                       \
                                                                                              \
      ~NAME()                                                                                 \
      {                                                                                       \
        DTOR                                                                                  \
      }                                                                                       \
                                                                                              \
      NAME(NAME const&)            = delete;                                                  \
      NAME& operator=(NAME const&) = delete;                                                  \
                                                                                              \
      template <typename Event>                                                               \
      void validate_event(Event& e)                                                           \
      {                                                                                       \
        THRUST_UNUSED_VAR(e);                                                                 \
        VALIDATE                                                                              \
      }                                                                                       \
                                                                                              \
      template <typename ForwardIt, typename Sentinel, typename OutputIt>                     \
      THRUST_HOST auto operator()(ForwardIt&& first, Sentinel&& last, OutputIt&& output)      \
        THRUST_DECLTYPE_RETURNS(::thrust::async::reduce_into(__VA_ARGS__))                    \
    };                                                                                        \
    /**/

// Convenience form of DEFINE_STATEFUL_ASYNC_REDUCE_INTO_INVOKER for invokers
// that need no state. It supplies no extra members, empty constructor and
// destructor bodies, and a `validate_event` that performs no checks.
// `__VA_ARGS__` is the argument list forwarded to `::thrust::async::reduce_into`.
#  define DEFINE_ASYNC_REDUCE_INTO_INVOKER(NAME, ...) \
    DEFINE_STATEFUL_ASYNC_REDUCE_INTO_INVOKER(        \
      NAME,                                           \
      THRUST_PP_EMPTY(), /* members */                \
      THRUST_PP_EMPTY(), /* constructor */            \
      THRUST_PP_EMPTY(), /* destructor */             \
      THRUST_PP_EMPTY(), /* validate_event */         \
      __VA_ARGS__)                                    \
    /**/

// Defines a class template `NAME<T>` whose call operator `operator()(first, last)`
// returns the result of the blocking `::thrust::reduce(__VA_ARGS__)`. Inside
// `__VA_ARGS__` the call-operator parameters are visible as `first` and `last`.
// The tests use it to compute the reference value that each asynchronous
// `reduce_into` result is compared against.
#  define DEFINE_SYNC_REDUCE_INVOKER(NAME, ...)        \
    template <typename T>                              \
    struct NAME                                        \
    {                                                  \
      template <typename ForwardIt, typename Sentinel> \
      THRUST_HOST auto operator()(ForwardIt&& first,   \
                                  Sentinel&& last)     \
        THRUST_RETURNS(::thrust::reduce(__VA_ARGS__))  \
    };                                                 \
    /**/

// Invokers used by the typed tests below. Each group defines one asynchronous
// invoker per execution-policy variant, plus the synchronous reference invoker
// it is compared against:
//   (no suffix)            - no execution policy argument
//   _device                - thrust::device
//   _device_allocator      - thrust::device(thrust::device_allocator<void>{})
//   _device_on             - thrust::device bound to an invoker-owned stream
//   _device_allocator_on   - the allocator policy bound to an invoker-owned stream
//
// NOTE(review): in the `_init*` groups, the async and sync invokers each evaluate
// `random_integer<T>()` independently, and the test requires their results to
// be equal. This assumes `random_integer<T>()` returns the same value on every
// call -- confirm in test_utils.hpp.

// Defines a stream-owning invoker on top of DEFINE_STATEFUL_ASYNC_REDUCE_INTO_INVOKER:
//   - the constructor creates a non-blocking device stream in `stream_`;
//   - the destructor destroys it;
//   - `validate_event(e)` asserts that `e.stream().native_handle()` equals `stream_`;
//   - `__VA_ARGS__` (which may refer to `stream_`) is forwarded to
//     `thrust::async::reduce_into`.
// This stream plumbing is shared by every `_on` invoker below.
#  define DEFINE_STREAM_ASYNC_REDUCE_INTO_INVOKER(NAME, ...)                                                         \
    DEFINE_STATEFUL_ASYNC_REDUCE_INTO_INVOKER(                                                                       \
      NAME,                                                                                                          \
      SPECIALIZE_DEVICE_RESOURCE_NAME(Stream_t) stream_;,                                                            \
      thrust::THRUST_DEVICE_BACKEND_DETAIL::throw_on_error(SPECIALIZE_DEVICE_RESOURCE_NAME(StreamCreateWithFlags)(   \
        &stream_, SPECIALIZE_DEVICE_RESOURCE_NAME(StreamNonBlocking)));,                                             \
      thrust::THRUST_DEVICE_BACKEND_DETAIL::throw_on_error(SPECIALIZE_DEVICE_RESOURCE_NAME(StreamDestroy)(stream_));, \
      ASSERT_EQ_QUIET(stream_, e.stream().native_handle());,                                                         \
      __VA_ARGS__)                                                                                                   \
    /**/

// Group 1: default init value and default operator.
DEFINE_ASYNC_REDUCE_INTO_INVOKER(reduce_into_async_invoker, THRUST_FWD(first), THRUST_FWD(last), THRUST_FWD(output));
DEFINE_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_device, thrust::device, THRUST_FWD(first), THRUST_FWD(last), THRUST_FWD(output));
DEFINE_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_device_allocator,
  thrust::device(thrust::device_allocator<void>{}),
  THRUST_FWD(first),
  THRUST_FWD(last),
  THRUST_FWD(output));
DEFINE_STREAM_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_device_on,
  thrust::device.on(stream_),
  THRUST_FWD(first),
  THRUST_FWD(last),
  THRUST_FWD(output));
DEFINE_STREAM_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_device_allocator_on,
  thrust::device(thrust::device_allocator<void>{}).on(stream_),
  THRUST_FWD(first),
  THRUST_FWD(last),
  THRUST_FWD(output));

DEFINE_SYNC_REDUCE_INVOKER(reduce_sync_invoker, THRUST_FWD(first), THRUST_FWD(last));

// Group 2: explicit init value `random_integer<T>()`, default operator.
DEFINE_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_init, THRUST_FWD(first), THRUST_FWD(last), THRUST_FWD(output), random_integer<T>());
DEFINE_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_device_init,
  thrust::device,
  THRUST_FWD(first),
  THRUST_FWD(last),
  THRUST_FWD(output),
  random_integer<T>());
DEFINE_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_device_allocator_init,
  thrust::device(thrust::device_allocator<void>{}),
  THRUST_FWD(first),
  THRUST_FWD(last),
  THRUST_FWD(output),
  random_integer<T>());
DEFINE_STREAM_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_device_on_init,
  thrust::device.on(stream_),
  THRUST_FWD(first),
  THRUST_FWD(last),
  THRUST_FWD(output),
  random_integer<T>());
DEFINE_STREAM_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_device_allocator_on_init,
  thrust::device(thrust::device_allocator<void>{}).on(stream_),
  THRUST_FWD(first),
  THRUST_FWD(last),
  THRUST_FWD(output),
  random_integer<T>());

DEFINE_SYNC_REDUCE_INVOKER(reduce_sync_invoker_init, THRUST_FWD(first), THRUST_FWD(last), random_integer<T>());

// Group 3: explicit init value and `thrust::plus<T>` as the operator.
DEFINE_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_init_plus,
  THRUST_FWD(first),
  THRUST_FWD(last),
  THRUST_FWD(output),
  random_integer<T>(),
  thrust::plus<T>());
DEFINE_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_device_init_plus,
  thrust::device,
  THRUST_FWD(first),
  THRUST_FWD(last),
  THRUST_FWD(output),
  random_integer<T>(),
  thrust::plus<T>());
DEFINE_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_device_allocator_init_plus,
  thrust::device(thrust::device_allocator<void>{}),
  THRUST_FWD(first),
  THRUST_FWD(last),
  THRUST_FWD(output),
  random_integer<T>(),
  thrust::plus<T>());
DEFINE_STREAM_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_device_on_init_plus,
  thrust::device.on(stream_),
  THRUST_FWD(first),
  THRUST_FWD(last),
  THRUST_FWD(output),
  random_integer<T>(),
  thrust::plus<T>());
DEFINE_STREAM_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_device_allocator_on_init_plus,
  thrust::device(thrust::device_allocator<void>{}).on(stream_),
  THRUST_FWD(first),
  THRUST_FWD(last),
  THRUST_FWD(output),
  random_integer<T>(),
  thrust::plus<T>());

DEFINE_SYNC_REDUCE_INVOKER(
  reduce_sync_invoker_init_plus, THRUST_FWD(first), THRUST_FWD(last), random_integer<T>(), thrust::plus<T>());

// Group 4: explicit init value and the user-defined `custom_plus<T>` operator.
DEFINE_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_init_custom_plus,
  THRUST_FWD(first),
  THRUST_FWD(last),
  THRUST_FWD(output),
  random_integer<T>(),
  custom_plus<T>());
DEFINE_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_device_init_custom_plus,
  thrust::device,
  THRUST_FWD(first),
  THRUST_FWD(last),
  THRUST_FWD(output),
  random_integer<T>(),
  custom_plus<T>());
DEFINE_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_device_allocator_init_custom_plus,
  thrust::device(thrust::device_allocator<void>{}),
  THRUST_FWD(first),
  THRUST_FWD(last),
  THRUST_FWD(output),
  random_integer<T>(),
  custom_plus<T>());
DEFINE_STREAM_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_device_on_init_custom_plus,
  thrust::device.on(stream_),
  THRUST_FWD(first),
  THRUST_FWD(last),
  THRUST_FWD(output),
  random_integer<T>(),
  custom_plus<T>());
DEFINE_STREAM_ASYNC_REDUCE_INTO_INVOKER(
  reduce_into_async_invoker_device_allocator_on_init_custom_plus,
  thrust::device(thrust::device_allocator<void>{}).on(stream_),
  THRUST_FWD(first),
  THRUST_FWD(last),
  THRUST_FWD(output),
  random_integer<T>(),
  custom_plus<T>());

DEFINE_SYNC_REDUCE_INVOKER(
  reduce_sync_invoker_init_custom_plus, THRUST_FWD(first), THRUST_FWD(last), random_integer<T>(), custom_plus<T>());

///////////////////////////////////////////////////////////////////////////////

// Checks that AsyncReduceIntoInvoker<T> agrees with SyncReduceIntoInvoker<T>
// for every input size returned by get_sizes().
//
// For each size, the same random input is copied into four device vectors, and
// each copy is reduced asynchronously into its own device-allocated scalar. The
// host copy is then reduced synchronously, possibly while the asynchronous
// reductions are still running. Once all four events have completed, every
// device scalar must equal the host result.
//
// NOTE(review): results are compared with exact equality. For floating-point T
// this presumes the host and device reductions produce identical sums -- confirm.
template <typename T, template <typename> class AsyncReduceIntoInvoker, template <typename> class SyncReduceIntoInvoker>
THRUST_HOST void test_async_reduce_into()
{
  for (auto size : get_sizes())
  {
    SCOPED_TRACE(testing::Message() << "with size = " << size);

    thrust::host_vector<T> host_in(random_integers<T>(size));

    // Four independent device copies of the input.
    thrust::device_vector<T> dev_in_a(host_in);
    thrust::device_vector<T> dev_in_b(host_in);
    thrust::device_vector<T> dev_in_c(host_in);
    thrust::device_vector<T> dev_in_d(host_in);

    // One device-allocated result slot per asynchronous reduction.
    auto dev_out_a = thrust::device_make_unique<T>();
    auto dev_out_b = thrust::device_make_unique<T>();
    auto dev_out_c = thrust::device_make_unique<T>();
    auto dev_out_d = thrust::device_make_unique<T>();

    auto const out_a = dev_out_a.get();
    auto const out_b = dev_out_b.get();
    auto const out_c = dev_out_c.get();
    auto const out_d = dev_out_d.get();

    AsyncReduceIntoInvoker<T> reduce_async;
    SyncReduceIntoInvoker<T> reduce_sync;

    ASSERT_EQ(host_in, dev_in_a);
    ASSERT_EQ(host_in, dev_in_b);
    ASSERT_EQ(host_in, dev_in_c);
    ASSERT_EQ(host_in, dev_in_d);

    // Launch all four reductions before waiting on any of them.
    auto fut_a = reduce_async(dev_in_a.begin(), dev_in_a.end(), out_a);
    auto fut_b = reduce_async(dev_in_b.begin(), dev_in_b.end(), out_b);
    auto fut_c = reduce_async(dev_in_c.begin(), dev_in_c.end(), out_c);
    auto fut_d = reduce_async(dev_in_d.begin(), dev_in_d.end(), out_d);

    reduce_async.validate_event(fut_a);
    reduce_async.validate_event(fut_b);
    reduce_async.validate_event(fut_c);
    reduce_async.validate_event(fut_d);

    // Host-side reference result; this may overlap with the device reductions
    // launched above.
    auto const expected = reduce_sync(host_in.begin(), host_in.end());

    TEST_EVENT_WAIT(thrust::when_all(fut_a, fut_b, fut_c, fut_d));

    ASSERT_EQ(expected, *out_a);
    ASSERT_EQ(expected, *out_b);
    ASSERT_EQ(expected, *out_c);
    ASSERT_EQ(expected, *out_d);
  }
}

// Default init value and default operator, across every policy variant: none,
// thrust::device, device allocator, and the stream-bound (`_on`) forms of the
// latter two. Each is checked against the plain synchronous thrust::reduce.
TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type, reduce_into_async_invoker, reduce_sync_invoker>();
}

TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_policy)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type, reduce_into_async_invoker_device, reduce_sync_invoker>();
}

TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_policy_allocator)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type,
                         reduce_into_async_invoker_device_allocator,
                         reduce_sync_invoker>();
}

TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_policy_on)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type, reduce_into_async_invoker_device_on, reduce_sync_invoker>();
}

TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_policy_allocator_on)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type,
                         reduce_into_async_invoker_device_allocator_on,
                         reduce_sync_invoker>();
}

// Explicit init value (`random_integer<T>()`) with the default operator, across
// every policy variant, checked against the synchronous reduce with the same init.
TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_init)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type, reduce_into_async_invoker_init, reduce_sync_invoker_init>();
}

TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_policy_init)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type,
                         reduce_into_async_invoker_device_init,
                         reduce_sync_invoker_init>();
}

TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_policy_allocator_init)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type,
                         reduce_into_async_invoker_device_allocator_init,
                         reduce_sync_invoker_init>();
}

TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_policy_on_init)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type,
                         reduce_into_async_invoker_device_on_init,
                         reduce_sync_invoker_init>();
}

TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_policy_allocator_on_init)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type,
                         reduce_into_async_invoker_device_allocator_on_init,
                         reduce_sync_invoker_init>();
}

// Explicit init value with `thrust::plus<T>` as the operator, across every
// policy variant, checked against the synchronous reduce with the same arguments.
TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_init_plus)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type,
                         reduce_into_async_invoker_init_plus,
                         reduce_sync_invoker_init_plus>();
}

TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_policy_init_plus)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type,
                         reduce_into_async_invoker_device_init_plus,
                         reduce_sync_invoker_init_plus>();
}

TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_policy_allocator_init_plus)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type,
                         reduce_into_async_invoker_device_allocator_init_plus,
                         reduce_sync_invoker_init_plus>();
}

TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_policy_on_init_plus)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type,
                         reduce_into_async_invoker_device_on_init_plus,
                         reduce_sync_invoker_init_plus>();
}

TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_policy_allocator_on_init_plus)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type,
                         reduce_into_async_invoker_device_allocator_on_init_plus,
                         reduce_sync_invoker_init_plus>();
}

// Explicit init value with the user-defined `custom_plus<T>` operator, across
// every policy variant, checked against the synchronous reduce with the same
// arguments.
TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_init_custom_plus)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type,
                         reduce_into_async_invoker_init_custom_plus,
                         reduce_sync_invoker_init_custom_plus>();
}

TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_policy_init_custom_plus)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type,
                         reduce_into_async_invoker_device_init_custom_plus,
                         reduce_sync_invoker_init_custom_plus>();
}

TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_policy_allocator_init_custom_plus)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type,
                         reduce_into_async_invoker_device_allocator_init_custom_plus,
                         reduce_sync_invoker_init_custom_plus>();
}

TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_policy_on_init_custom_plus)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type,
                         reduce_into_async_invoker_device_on_init_custom_plus,
                         reduce_sync_invoker_init_custom_plus>();
}

TYPED_TEST(AsyncReduceIntoTests, test_async_reduce_into_policy_allocator_on_init_custom_plus)
{
  SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
  test_async_reduce_into<typename TestFixture::input_type,
                         reduce_into_async_invoker_device_allocator_on_init_custom_plus,
                         reduce_sync_invoker_init_custom_plus>();
}

#endif
