Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions libcudacxx/include/cuda/__tma/make_tma_descriptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,9 @@ __get_tensor_sizes(const ::DLTensor& __tensor, int __rank, ::CUtensorMapDataType
[[maybe_unused]] int64_t __cumulative_size = 1;
if (__input_strides == nullptr)
{
# if DLPACK_MAJOR_VERSION > 1 || (DLPACK_MAJOR_VERSION == 1 && DLPACK_MINOR_VERSION >= 2)
_CCCL_THROW(::std::invalid_argument{"__tensor.strides=nullptr is not supported for DLPack v1.2 and later"});
# else
for (int __i = 0; __i < __rank - 1; ++__i)
{
// TODO(fbusato): check mul overflow
Expand All @@ -428,6 +431,7 @@ __get_tensor_sizes(const ::DLTensor& __tensor, int __rank, ::CUtensorMapDataType
__output_strides[__i] = __stride_bytes;
}
return __output_strides;
# endif // DLPACK_MAJOR_VERSION > 1 || (DLPACK_MAJOR_VERSION == 1 && DLPACK_MINOR_VERSION >= 2)
}
// TMA ignores the innermost stride (always 1).
for (int __i = __rank - 2; __i >= 0; --__i)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,6 @@ bool test_strides()
// stride is 0
strides_storage[0] = 0;
unused(cuda::make_tma_descriptor(tensor, box_sizes));
// stride is nullptr
tensor.strides = nullptr;
unused(cuda::make_tma_descriptor(tensor, box_sizes));
return true;
}

Expand Down
Loading