CUTLASS 3.5.0 (#1411)

Vijay Thakkar 2024-03-19 17:51:04 -04:00 committed by GitHub
parent ffa34e7075
commit 629f4653c3
468 changed files with 48730 additions and 7253 deletions


@@ -1,28 +1,45 @@
# NVIDIA CUTLASS Changelog
## 3.5 (2024-03-18)
- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp)
+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md).
+ Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](./include/cutlass/conv/convnd_problem_shape.hpp).
+ Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms
+ [CUTLASS profiler support](./python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API.
+ NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until 3.7 release. Your feedback is welcome on the design!
- Support for [Ada (SM89) FP8 tensor cores via the 2.x API](./examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer.
- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_gemm/README.md) in CuTe and CUTLASS 3.x
+ Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs.
+ Implementation of a coarse-grained sparse gather/scatter kernel achieving peak performance on Ampere-class tensor cores.
- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cute/03_tensor.md), [MMA atoms](./media/docs/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial) (a small `cute::Tensor` sketch follows this list).
- Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337).
- Updates and bugfixes from the community (thanks!)
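
A minimal, illustrative `cute::Tensor` sketch for the documentation item above (not taken from the release notes; it assumes the usual `<cute/tensor.hpp>` entry point and a C++17 host build, e.g. with nvcc):

```cpp
// Illustrative sketch only: a Tensor is an engine (here, a raw pointer)
// composed with a Layout. Assumes <cute/tensor.hpp> as the single include.
#include <cstdio>
#include <cute/tensor.hpp>

int main() {
  using namespace cute;

  float data[4 * 8];
  for (int i = 0; i < 32; ++i) { data[i] = float(i); }

  // (4,8) layout with row-major strides: element (m,n) lives at offset m*8 + n.
  auto layout = make_layout(make_shape(Int<4>{}, Int<8>{}),
                            make_stride(Int<8>{}, Int<1>{}));

  auto tensor = make_tensor(&data[0], layout);   // non-owning view over data

  std::printf("tensor(2,3) = %f\n", double(tensor(2, 3)));  // offset 2*8 + 3 == 19
  return 0;
}
```
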
## [3.4.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.1) (2024-02-14)
- Statically available [CUTLASS Version macros](./include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side (a short usage sketch follows this list).
- Improvements for Hopper [Group-GEMMs](./examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMMs](./examples/56_hopper_ptr_array_batched_gemm).
- Updates and bugfixes from the community (thanks!).
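
A short compile-time gating sketch for the version-macro item above (illustrative only; it assumes `cutlass/version.h` exposes `CUTLASS_MAJOR`, `CUTLASS_MINOR`, and `CUTLASS_PATCH` as in recent releases):

```cpp
// Illustrative sketch: select a code path based on the CUTLASS release.
// Assumes cutlass/version.h defines CUTLASS_MAJOR / CUTLASS_MINOR / CUTLASS_PATCH.
#include <cstdio>
#include "cutlass/version.h"

int main() {
#if (CUTLASS_MAJOR > 3) || (CUTLASS_MAJOR == 3 && CUTLASS_MINOR >= 4)
  std::printf("CUTLASS %d.%d.%d detected: version macros are available\n",
              CUTLASS_MAJOR, CUTLASS_MINOR, CUTLASS_PATCH);
  // ... use the newer API surface here ...
#else
  std::printf("Older CUTLASS release: fall back to the previous API path\n");
#endif
  return 0;
}
```
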
## [3.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0) (2024-01-12)
* Expanded [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors.
* Performance improvements to [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm)
* Beta release of [Pointer-Array Batched GEMMs](./examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above).
* Beta release of [Group-GEMM](./examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above).
* [Ampere Sparse GEMM](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) now supports Epilogue Visitor Tree (EVT).
* NamedBarriers usability improvements; the list of [ReservedNamedBarriers](./include/cutlass/arch/barrier.h) has been officially released.
* Improved [CuTe documentation](./media/docs/cute/), including greater clarity and depth in [Quickstart](./media/docs/cute/00_quickstart.md), [CuTe Layout](./media/docs/cute/01_layout.md), and [CuTe Layout Algebra](./media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](./test/unit/cute/core/) are also improved (a small layout-algebra sketch follows this list).
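
A small layout-algebra sketch for the CuTe documentation item above (illustrative only, not from the release notes; it assumes `cute::composition` and `cute::make_layout` behave as described in the linked docs):

```cpp
// Illustrative sketch: functional composition of two CuTe layouts, C = A o B,
// so that C(i,j) == A(B(i,j)). Assumes <cute/tensor.hpp> as the single include.
#include <cstdio>
#include <cute/tensor.hpp>

int main() {
  using namespace cute;

  // A: compact column-major (4,8) layout; it enumerates offsets 0..31,
  // so it acts as the identity function on its 32-element domain.
  auto A = make_layout(make_shape(Int<4>{}, Int<8>{}));

  // B: an (8,4) layout with strides (4,1), i.e. a row-major view of the same extent.
  auto B = make_layout(make_shape(Int<8>{}, Int<4>{}),
                       make_stride(Int<4>{}, Int<1>{}));

  auto C = composition(A, B);                    // C(i,j) == A(B(i,j))

  print(C);                                      // prints the composed shape:stride
  std::printf("\nC(2,3) = %d\n", int(C(2, 3)));  // == 2*4 + 3 == 11
  return 0;
}
```
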
## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3.0) (2023-10-31)
* [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.
* [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operand B {fp16, bf16} x {s8, u8}, and upcast on operand A {s8, u8} x {fp16, bf16}.
* [Copy Async based Hopper GEMMs](./test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors.
* Kernel schedules and Builder support for mixed precision and Copy Async GEMMs with < 16B aligned input tensors.
* Profiler support for lower-aligned Hopper GEMMs.
* Performance Improvements to [Scatter-Gather Hopper Example](./examples/52_hopper_gather_scatter_fusion).
* Sub-Byte type fixes and improvements.
* EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](./include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details.
* Support for backprop fusions including drelu, dgelu, and dbias.
* Support for void-C kernels and SM80 mixed-input GEMMs in the CUTLASS Python interface
@@ -34,7 +51,7 @@
* SM80 EVT support in C++ and Python.
* Other SM90 epilogue improvements.
* Splitting CUTLASS library into smaller units based on operation, arch and datatypes. See [1105](https://github.com/NVIDIA/cutlass/discussions/1105) for details.
* Making `tools/library/scripts` packageable - `tools/library/scripts` is now moving to `python/cutlass_library`. See the Python [README](./python/README.md) for details.
* SM90 TF32 kernel improvements for all layouts.
* SM90 rasterization direction support in the CUTLASS profiler.
* Improvement for CUTLASS profiler build times.
@@ -42,65 +59,65 @@
## [3.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.0) (2023-08-03)
* New warp-specialized persistent FP8 GEMM kernel [kernel schedules](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](./include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](./examples/54_hopper_fp8_warp_specialized_gemm). FP8 GEMMs come with a fast accumulation mode. When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results will not periodically be promoted to a higher precision.
* New [Epilogue Visitor Tree (EVT)](./examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allow user-defined, customized epilogue fusion patterns without having to write a new epilogue.
* [Stream-K](./include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release.
* Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](./include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp).
* Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
* [Hopper GEMM+Permute](./examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue.
* New CUTLASS 2D Convolution Python interface. New [example](./examples/python/03_basic_conv2d.ipynb) here.
* Support for Windows (MSVC) builds. Tested with Visual Studio 2019 v16.11.27 on Windows 10.0.
* Optimal performance using [**CUDA 12.2u1**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14)
* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](./python/README.md) and new [examples](./examples/python).
* New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
* Support for [fused epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such as Bias, ReLU, and GELU, using the new efficient epilogues.
* New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
* New [*warp-specialized persistent cooperative*](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that allows for larger tile sizes and improves performance on Hopper.
* An [example](./examples/51_hopper_gett) showcasing GEMM-Like Tensor-Tensor Contraction (GETT) capability on Hopper.
* Epilogue builders. Similar to mainloop builders (see [example 49](./examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization.
* Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler.
* Performance optimizations for the [*warp-specialized persistent ping-pong*](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel.
* Changes to the [GEMM API 3.x](./media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
* [FMHA Backward Pass](./examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
* [Streamk GEMM with Broadcast](./examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
* [Batched B2B GEMM](./examples/13_two_tensor_op_fusion) can now run multiple Back-to-Back GEMMs with the same problem size in parallel.
* [Batched Strided GEMV](test/unit/gemm/device/gemv.cu) supports both row-major and column-major input matrices.
* [Permute + GEMM fusion](./examples/39_gemm_permute) can now fuse a Permute with the following GEMM. Previously, only fusing GEMM with a Permute in the epilogue was supported.
* [Row Broadcast](./include/cutlass/epilogue/threadblock/predicated_tile_iterator_row_broadcast.h) can be fused in the epilogue.
* The GitHub branch is renamed from `master` to `main` in this release.
* Optimal performance using [**CUDA 12.1**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
## [3.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.0.0) (2023-01-23)
* [CuTe](./media/docs/cute/00_quickstart.md), a [new core library and backend](./include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors.
* [A new conceptual operation hierarchy](./media/docs/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](./media/docs/gemm_api_3x.md).
* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](./include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](./include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](./media/docs/cutlass_3x_backwards_compatibility.md).
* Updates to [Functionality](./media/docs/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3.
* Updates to [Compatibility](./README.md#compatibility) Section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures and [Target Architecture](./README.md#Target-Architecture).
* New warp-specialized GEMM [kernel schedules](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [mainloops](./include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters.
* Extensions to CUTLASS profiler to support threadblock cluster shapes in library and profiler tile configurations.
* [CUTLASS library integration](./tools/library/src/gemm_operation_3x.hpp) for 3.x API kernels built through the new `CollectiveBuilder` API, enabling CUTLASS profiler.
* Support for [Hopper GEMMs](./examples/48_hopper_warp_specialized_gemm) through the new 3.0 API with CuTe-based exposure of the Hopper [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor) and [WGMMA Tensor Core](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) features.
* Set of examples that demonstrate the usage of the new 3.0 API to easily build GEMM kernels targeting Hopper: examples [48](./examples/48_hopper_warp_specialized_gemm), [49](./examples/49_hopper_gemm_schedules_with_collective_builder), and [50](./examples/50_hopper_gemm_with_epilogue_swizzle).
## [2.11.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.11.0) (2022-11-19)
* [Stream-K](./examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
* [Fused multi-head attention Kernel](./examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for the fixed sequence length, and the other one uses group GEMM for the variable sequence length. Both versions just need one kernel.
* [Dual GEMM](./examples/45_dual_gemm), which can fuse A x B and A x C into one kernel. The two GEMMs have no producer-consumer dependency.
* Hopper improves [double precision matrix multiplication](./test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8.
* [BLAS3](./test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions using Hopper's new double-precision matrix multiplication instructions.
* [ELL Block Sparse GEMM](./examples/43_ell_block_sparse_gemm), which uses an [ELL matrix](https://developer.nvidia.com/blog/accelerating-matrix-multiplication-with-block-sparse-format-and-nvidia-tensor-cores/) to describe the sparsity of the A matrix. B and output matrices are still dense. The block size can be arbitrary.
* Optimized [Group Conv](./examples/42_ampere_tensorop_group_conv) for SingleGroup mode, which requires that the output channel per group is a multiple of Threadblock tile N.
* [Optimized DepthWise Conv](./examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu). Two new modes are added
* [kOptimized](./test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - uses direct convolution to compute instead of implicit GEMM.
* The restrictions are: 1) the input channel, output channel, and group count must each be a multiple of (128 / sizeof(input element)), e.g. 64 for fp16; 2) the input filter size must match the template parameter configuration.
* [kFixedStrideDilation](./test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - puts the stride and dilation into template parameters to further improve performance. In this mode, the kernel keeps some inputs resident in registers to squeeze out more performance, so large filter/stride/dilation values are not recommended.
* The restrictions are: 1) the input channel, output channel, and group count must each be a multiple of (128 / sizeof(input element)); 2) the input filter size, stride, and dilation must match the template parameter configuration.
* [Scripts](./examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMMs. Its implementation was discussed in a GTC'22 Spring [talk](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41606/).
* [FP8 data type definition](./include/cutlass/float8.h) and [conversion routines](./include/cutlass/numeric_conversion.h#L1274-2115) (a small conversion sketch follows this section).
* Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers).
* **Deprecation announcement:** CUTLASS plans to deprecate the following:
@@ -109,54 +126,54 @@
* CUDA 10.2
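
For the FP8 item in the 2.11.0 list above, a small round-trip sketch (illustrative only; it assumes `cutlass::float_e4m3_t` from `cutlass/float8.h` and the `NumericConverter` functor from `cutlass/numeric_conversion.h`):

```cpp
// Illustrative FP8 round-trip: fp32 -> e4m3 -> fp32 on the host.
// Assumes cutlass::float_e4m3_t and cutlass::NumericConverter as named below.
#include <cstdio>
#include "cutlass/float8.h"
#include "cutlass/numeric_conversion.h"

int main() {
  float x = 1.375f;

  cutlass::NumericConverter<cutlass::float_e4m3_t, float> to_fp8;   // fp32 -> e4m3
  cutlass::NumericConverter<float, cutlass::float_e4m3_t> to_fp32;  // e4m3 -> fp32

  cutlass::float_e4m3_t y = to_fp8(x);
  float z = to_fp32(y);

  std::printf("fp32 %f -> e4m3 -> fp32 %f\n", double(x), double(z));
  return 0;
}
```
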
## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23)
* [CUTLASS Python](./examples/40_cutlass_py) now supports GEMM, CONV, Group GEMM for different data types as well as different epilogue flavours.
* Optimizations for CUTLASS's [Grouped GEMM](./examples/24_gemm_grouped/gemm_grouped.cu) kernel. The threadblock scheduling is improved, and some computation can be moved to the host side where applicable. [Grouped Syr2k](./examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added, too.
* Optimizations for [GEMM+Softmax](./examples/35_gemm_softmax). All the reduction computation is fused into the preceding GEMM, and more template arguments are provided to fine-tune the performance.
* [Grouped GEMM for Multihead Attention](./examples/41_multi_head_attention). This general grouped-GEMM-based MHA does not require all GEMMs to share the same sequence length, which makes it especially useful for natural language processing.
* [GEMM + Layer norm fusion for Ampere](./examples/37_gemm_layernorm_gemm_fusion/) splits the layernorm into two parts, each of which can be fused into the preceding or following GEMM. In addition to using the square sum to compute the layernorm variance, [Shift-K](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data) is provided in case the square sum raises numerical issues.
* [GEMM Epilogue Permutation Fusion](./examples/39_gemm_permute) can apply a user-provided permutation layout mapping in the GEMM epilogue.
* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized one. The restrictions are: 1) the input and output channel counts must be multiples of the group count; 2) split-K is not supported. The implementation has 2 modes:
* kSingleGroup: output channel per group is multiple of Threadblock tile N.
* kMultipleGroup: Threadblock tile N is multiple of output channel per group.
* [Depthwise separable convolution](test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution, which is also Analytical for now. The restrictions are: 1) SIMT only; 2) no split-K; 3) the input channel count equals the output channel count equals the group count.
* Standalone [Layernorm](./tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](./tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels.
* [Back-to-back GEMM/CONV](./examples/13_two_tensor_op_fusion) relaxes the requirement that the first GEMM K dimension needs to be a multiple of the Threadblock Tile K dimension.
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
## [2.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.9.0) (2022-04-21)
* [First layer Convolution kernels](./test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
* [Few channels](./include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities
* [Fixed channels](./include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size
* [Unit tests](./test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu)
* [Python-based instance emitter](./python/cutlass_library/generator.py) in the CUTLASS Library and support in the Profiler
* [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores
* Supported types: f32, cf32, f64, cf64, tf32x3, complex tf32x3
* [HERK](./test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](./python/cutlass_library/rank_k_operation.py)
* [SYRK](./test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu) with [emitter](./python/cutlass_library/rank_k_operation.py)
* [SYMM](./test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu) with [emitter](./python/cutlass_library/symm_operation.py)
* [TRMM](./test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu) with [emitter](./python/cutlass_library/trmm_operation.py)
* [Unit tests](./test/unit/gemm/device/testbed_rank_k_universal.h)
* [CUTLASS Python](./examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
* [Python-based runtime](./tools/library/scripts/rt.py) interoperable with existing emitters
* [GEMM + Softmax example](./examples/35_gemm_softmax)
* [Gather and Scatter Fusion with GEMM](./examples/36_gather_scatter_fusion) can gather inputs and scatter outputs based on index vectors in the same GEMM kernel.
* It can select random rows in a row major matrix.
* It can select random columns in a column major matrix.
* [Back-to-back GEMM/CONV](./examples/13_two_tensor_op_fusion) fully supports buffering the first GEMM/CONV results in shared memory for the second one to use. This can eliminate register spills when the tile size is large. Additionally, bias vector add is supported in the first GEMM/CONV.
* Supported kernels: GEMM and CONV.
* Supported types: fp16 and int8.
* Supported architectures: Turing and Ampere.
* [Transposed Convolution](./examples/34_transposed_conv2d) (a.k.a Deconvolution) support which reuses Dgrad implementation.
* [Utility functions](./tools/util/include/cutlass/util) that can pad NHWC and convert between NCHW and NHWC.
* [Small alignment implicit gemm](https://github.com/NVIDIA/cutlass/issues/242) support for Fprop/Dgrad/Wgrad, so that padding is no longer required in order to use tensor cores in these kernels.
* Epilogue enhancement:
* Eliminate bank conflicts in int8 tensor core kernels.
* Half2 usage if epilogue compute type is fp16.
* More activation functions: Silu, Hardswish, Leaky Relu.
* New elementwise fusion pattern for [residual block](./include/cutlass/epilogue/thread/linear_combination_residual_block.h).
* [Group GEMM](./examples/24_gemm_grouped) threadblock count calculation fix, which helps launch the intended number of threadblocks to fully occupy the GPU.
* [Parallel GEMM splitk](https://github.com/NVIDIA/cutlass/pull/277) support in the CUTLASS profiler.
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
@@ -166,17 +183,17 @@
* **TF32x3:** emulated single-precision using Tensor Cores
* 45+ TFLOPs on NVIDIA A100
* [GEMM SDK example](./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu) (real)
* [COMPLEX GEMM SDK example](./examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu) (complex)
* [Implicit GEMM Convolution SDK example](./examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu)
* **Mainloop fusion for Convolution:** convolution with fused per-channel scale-bias-relu
* [Conv Fprop SDK example](./examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu)
* [Conv WGrad SDK example](./examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu)
* [cutlass::conv::device::ImplicitGemmConvolutionFusion](./include/cutlass/conv/device/implicit_gemm_convolution_fusion.h)
* **Grouped GEMM:** similar to batched GEMM with distinct problem size per group
* [SDK example](./examples/24_gemm_grouped) with performance comparison with Batched Strided GEMM
* [cutlass::gemm::device::GemmGrouped](./include/cutlass/gemm/device/gemm_grouped.h)
* [Implicit GEMM Convolution fusion](./examples/13_two_tensor_op_fusion/) supports staging the 1st convolution's output accumulator in shared memory on Turing. This allows more flexible warp tile sizes and reduces register pressure.
* Optimal performance using [**CUDA 11.5**](https://developer.nvidia.com/cuda-downloads)
* Updates from the community (thanks!)
@@ -186,11 +203,11 @@
* CUDA 10.2
## [2.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.7.0) (2021-09-24)
* Mainloop fusion for GEMM: [summation over A or B](./examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu)
* [Strided DGRAD (optimized iterators)](./include/cutlass/conv/kernel/default_conv2d_dgrad.h)
* [Half-precision GELU_taylor activation functions](./include/cutlass/epilogue/thread/activation.h#L196)
* Use these when the accumulation and epilogue compute types are all `cutlass::half_t` (a small sketch follows this section's list)
* Tuning and bug fixes to [fused GEMM + GEMM example](./examples/13_two_tensor_op_fusion/)
* Support for smaller than 128b aligned Convolutions: [see examples](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu#L272)
* Caching of results to accelerate Convolution [unit tests](test/unit/conv/device/cache_testbed_output.h)
* Can be enabled or disabled by running `cmake .. -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=OFF`
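
For the half-precision GELU_taylor item above, a small host-side sketch (illustrative only; it assumes `GELU_taylor<T>` in `cutlass/epilogue/thread/activation.h` remains a plain functor with `operator()(T)`):

```cpp
// Illustrative sketch: evaluate the tanh-approximation GELU in fp16.
// Assumes cutlass::epilogue::thread::GELU_taylor<T> is a functor with operator()(T).
#include <cstdio>
#include "cutlass/half.h"
#include "cutlass/epilogue/thread/activation.h"

int main() {
  cutlass::epilogue::thread::GELU_taylor<cutlass::half_t> gelu;

  cutlass::half_t x(0.5f);
  cutlass::half_t y = gelu(x);   // GELU evaluated entirely in half precision

  std::printf("gelu(%f) ~= %f\n", double(float(x)), double(float(y)));
  return 0;
}
```
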
@@ -205,24 +222,24 @@
## [2.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.0) (2021-07-22)
* Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit)
* Adopt the new L2 prefetch feature in [cp.async](./include/cutlass/arch/memory.h) and [global load](./include/cutlass/arch/memory_sm80.h)
* Fused operators with GEMM and Convolution
* [Fused broadcast in epilogue](test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu)
* [Fused partial reduction in epilogue](./test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu)
* 64b tensor strides and leading dimensions support for GEMMs
* Affine rank=2 matrix layouts
* Row stride and column stride for matrices using [cutlass::layout::AffineRank2](./include/cutlass/layout/matrix.h)
* Support [FP64 tensor core](./examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) and SIMT GEMM.
* [Batched GEMV](./test/unit/gemm/device/gemv.cu) preview implementation
* [New strided Dgrad](test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation
* Accelerates over previous implementation by cutting down redundant math by 4x
* Support using new `Dy` and `w` analytic iterators and existing `cutlass::conv::device::ImplicitGemmConvolution` interface
* Quaternion-valued GEMM and Convolution in single- and double-precision (targeting CUDA Cores)
* Updates to [quaternion.h](./include/cutlass/quaternion.h) and [functional.h](./include/cutlass/functional.h)
* SDK Example for [GEMM](./examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](./examples/22_quaternion_conv/quaternion_conv.cu)
* [Unit tests for GEMM](./test/unit/gemm/device/simt_qgemm_nn_sm50.cu) and [Convolution](./test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu)
* Many improvements to the epilogue.
* Provide an [option](./include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations
* Performance improvement for FP16 tensor core kernels
* Bug fixes
* Enhanced Clang support; the combination of Clang 13 and CUDA 11.4 can build and run kernels targeting Pascal and Ampere.
@@ -234,14 +251,14 @@
## [2.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.5.0) (2021-02-26)
* Tensor reductions
* _m_-to-_n_ reductions of tensors with affine layout
* [Specializations](./test/unit/reduction/device/tensor_reduce_contiguous.cu) for reductions including contiguous dimension
* [Specializations](./test/unit/reduction/device/tensor_reduce_strided.cu) for reductions excluding contiguous dimension
* Custom reduction functors such as `cutlass::logical_and`
* Large tensor support, up to 2^63 elements (however, each dimension is limited to an extent of 2^31)
* Optimizations for 3-D convolution
* [Optimized tile iterators](./include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h) using precomputed delta table for 3-D convolution
* Full coverage of [forward](test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) and [backwards](test/unit/conv/device/conv3d_dgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) passes for 3D convolution
* [Fused Convolution+Convolution example](./examples/13_two_tensor_op_fusion/README.md)
* Corrections and bug fixes reported by the CUTLASS community
* Thank you for filing these issues!
@@ -256,7 +273,7 @@
* Global memory iterators supporting Fprop, Dgrad, and Wgrad
* `MmaMultistage` for implicit GEMM convolution for NVIDIA Ampere architecture
* `MmaPipeline` for implicit GEMM convolution for NVIDIA Volta and Turing architectures
* [Documentation](./media/docs/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation
## [2.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.3.0) (2020-09-23)
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
@@ -264,13 +281,13 @@
* Direct access to Sparse Tensor Cores and maximum performance via [`mma.sp.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
* Fast SGEMM targeting GeForce RTX 30-series CUDA Cores
* Minor Features:
* [Activation functions](./include/cutlass/epilogue/thread/activation.h) such as [GeLU](./include/cutlass/epilogue/thread/linear_combination_gelu.h) and [Sigmoid](./include/cutlass/epilogue/thread/linear_combination_sigmoid.h)
* Small [matrix](./include/cutlass/matrix.h) and [quaternion](./include/cutlass/quaternion.h) template classes in device code
* [Floating-point constants](./include/cutlass/constants.h) (a small sketch follows this section's list)
* NVIDIA Ampere GPU Architecture examples and documentation:
* [Tensor Float 32](./examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and
* [Sparse Tensor Cores](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu)
* Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/gemm_api.md#efficient-epilogue)
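
For the floating-point constants item above, a tiny sketch (illustrative only; it assumes `cutlass/constants.h` exposes function templates such as `cutlass::constants::pi<T>()`):

```cpp
// Illustrative sketch: typed mathematical constants usable in host and device code.
// Assumes cutlass::constants::pi<T>() and cutlass::constants::root_two<T>() exist
// with these names; treat them as placeholders if the header differs.
#include <cstdio>
#include "cutlass/constants.h"

int main() {
  float  pi = cutlass::constants::pi<float>();
  double r2 = cutlass::constants::root_two<double>();

  std::printf("pi = %f, sqrt(2) = %f\n", double(pi), r2);
  return 0;
}
```
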
## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08)
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
@@ -290,11 +307,11 @@
* Disabled F16C by default for compatibility - enable on cmake command line with `-DCUTLASS_ENABLE_F16C=ON`
## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06)
* BLAS-style host-side API added to [CUTLASS Library](./media/docs/quickstart.md#cutlass-library)
* API to launch compiled kernel instances for GEMM and planar complex GEMM
* Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores
* Computes complex matrix products on matrices stored as disjoint real and imaginary parts
* [SDK Examples of Planar Complex GEMMs](./examples/10_planar_complex/planar_complex.cu)
* Minor enhancements and bug fixes
## [2.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.0.0) (2019-11-19)
@@ -304,10 +321,10 @@
* Encapsulated functionality embodying modern C++11 programming techniques
* Optimized containers and data types for efficient, generic, portable device code
* Updates to:
* [Quick start guide](./media/docs/quickstart.md)
* [Documentation](./README.md#documentation)
* [Utilities](./media/docs/utilities.md)
* [CUTLASS Profiler](./media/docs/profiler.md)
* Native Turing Tensor Cores
* Efficient GEMM kernels targeting Turing Tensor Cores
* Mixed-precision floating point, 8-bit integer, 4-bit integer, and binarized operands


@@ -67,14 +67,13 @@ elseif (CUDA_VERSION VERSION_LESS 11.4)
message(WARNING "CUTLASS ${CUTLASS_VERSION} support for CUDA ${CUDA_VERSION} is deprecated, please use CUDA 11.8 or higher.")
endif()
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3)
message(FATAL_ERROR "GCC version must be at least 7.3!")
endif()
if (CUDA_COMPILER MATCHES "[Cc]lang" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0)
message(FATAL_ERROR "Clang 7.0+ required for GPU compilation")
endif()
find_package(Doxygen QUIET)
################################################################################
@@ -168,6 +167,7 @@ endif()
include(GNUInstallDirs)
link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)
link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64)
###################################################################################################
#


@@ -1,6 +1,6 @@
![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS")
[README](./README.md#documentation) > **Contributors**
# CUTLASS Developers and Contributors


@@ -326,6 +326,14 @@ function(cutlass_add_library NAME)
cxx_std_11
)
get_target_property(TARGET_TYPE ${NAME} TYPE)
if (TARGET_TYPE MATCHES "SHARED")
set_target_properties(${NAME} PROPERTIES CUDA_RUNTIME_LIBRARY Shared)
elseif(TARGET_TYPE MATCHES "STATIC")
set_target_properties(${NAME} PROPERTIES CUDA_RUNTIME_LIBRARY Static)
endif()
if(__EXPORT_NAME)
add_library(nvidia::cutlass::${__EXPORT_NAME} ALIAS ${NAME})
set_target_properties(${NAME} PROPERTIES EXPORT_NAME ${__EXPORT_NAME})
@@ -336,10 +344,19 @@ endfunction()
function(cutlass_add_executable NAME)
set(options)
set(oneValueArgs CUDA_RUNTIME_LIBRARY)
set(multiValueArgs)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
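# CUDA_RUNTIME_LIBRARY is optional: it defaults to Shared (set just below) and must be one of None, Shared, or Static.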
if (NOT DEFINED __CUDA_RUNTIME_LIBRARY)
set(__CUDA_RUNTIME_LIBRARY Shared)
endif()
set(__CUDA_RUNTIME_LIBRARY_ALLOWED None Shared Static)
if (NOT __CUDA_RUNTIME_LIBRARY IN_LIST __CUDA_RUNTIME_LIBRARY_ALLOWED)
message(FATAL_ERROR "CUDA_RUNTIME_LIBRARY value '${__CUDA_RUNTIME_LIBRARY}' is not in allowed list of '${__CUDA_RUNTIME_LIBRARY_ALLOWED}'")
endif()
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang")
@@ -359,6 +376,8 @@ function(cutlass_add_executable NAME)
cxx_std_11
)
set_target_properties(${NAME} PROPERTIES CUDA_RUNTIME_LIBRARY ${__CUDA_RUNTIME_LIBRARY})
endfunction()
function(cutlass_target_sources NAME)


@@ -4,6 +4,7 @@
- ["A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library"](https://arxiv.org/abs/2312.11918). Ganesh Bikshandi, Jay Shah. _arXiv_, December 2023.
- ["Benchmarking GPU Tensor Cores on General Matrix Multiplication Kernels through CUTLASS"](https://www.mdpi.com/2076-3417/13/24/13022). Xuanteng Huang, Xianwei Zhang, Panfei Yang, Nong Xiao. _Journal of Applied Sciences_, December 2023.
- ["A Speed Odyssey for Deployable Quantization of LLMs"](https://arxiv.org/abs/2311.09550). Qingyuan Li, Ran Meng, Yiduo Li, Bo Zhang, Liang Li, Yifan Lu, Xiangxiang Chu, Yerui Sun, Yuchen Xie. _arXiv_, November 2023.


@@ -1,8 +1,8 @@
![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# CUTLASS 3.5
_CUTLASS 3.5 - March 2024_
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
@@ -19,16 +19,16 @@ mixed-precision computations, providing specialized data-movement and
multiply-accumulate abstractions for half-precision floating
point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
single-precision floating point (FP32),
[FP32 emulation via tensor core instruction](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
[FP32 emulation via tensor core instruction](./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
double-precision floating
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
CUTLASS demonstrates warp-synchronous matrix multiply operations
targeting the programmable, high-throughput _Tensor Cores_ implemented by
NVIDIA's Volta, Turing, Ampere, and Hopper architectures.
See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
See the [Quick Start Guide](./media/docs/quickstart.md) to get started quickly.
See the [functionality listing](/media/docs/functionality.md) for the list of operations
See the [functionality listing](./media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.
CUTLASS 3.0 introduced a new core library, CuTe, to describe and manipulate tensors of threads and data.
@ -37,25 +37,27 @@ CuTe is a collection of C++ CUDA template abstractions for defining and operatin
The core abstractions of CuTe are hierarchically multidimensional layouts which can be composed with data arrays to represent tensors. The representation of layouts is powerful enough to represent nearly everything we need to implement efficient dense linear algebra. Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning.
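To make the layout algebra above concrete, here is a small host-side sketch (illustrative only and not part of this commit; the shapes and tile sizes are arbitrary) that builds a static layout, attaches storage to form a tensor, and tiles it:

```cpp
#include <cute/tensor.hpp>
#include <cstdio>

int main() {
  using namespace cute;

  // An 8x16 row-major layout: shape (8,16) with strides (16,1).
  auto layout = make_layout(make_shape(Int<8>{}, Int<16>{}),
                            make_stride(Int<16>{}, Int<1>{}));

  // Compose the layout with storage to form an owning host-side tensor.
  auto tensor = make_tensor<float>(layout);
  clear(tensor);

  // Tile into 4x4 blocks: mode 0 indexes within a tile, mode 1 indexes across tiles.
  auto tiled = zipped_divide(tensor, make_shape(Int<4>{}, Int<4>{}));

  print(layout);         std::printf("\n");   // (_8,_16):(_16,_1)
  print(tiled.layout()); std::printf("\n");   // ((_4,_4),(...)) after tiling
  return 0;
}
```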
CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design
and improves code composability and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](/media/docs/cute/00_quickstart.md).
and improves code composability and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](./media/docs/cute/00_quickstart.md).
In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM, thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly optimized GEMM components.
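For intuition, the short sketch below (illustrative only; the struct and function here are not CUTLASS APIs) shows how a forward-propagation convolution over NHWC activations and KRSC filters maps onto GEMM extents under the implicit GEMM formulation:

```cpp
#include <cstdio>

// Illustrative only: N,P,Q are the batch and output spatial extents, K the number of
// filters, R,S the filter spatial extents, and C the input channels. P and Q are
// assumed to have been derived from padding/stride/dilation in the usual way.
struct ConvProblem { int N, P, Q, K, R, S, C; };
struct GemmExtents { int M, N, K; };

GemmExtents implicit_gemm_fprop_extents(ConvProblem p) {
  // Each output pixel (n, p, q) becomes a GEMM row, each filter k a column, and the
  // reduction runs over the (r, s, c) footprint gathered implicitly (im2col).
  return { p.N * p.P * p.Q, p.K, p.R * p.S * p.C };
}

int main() {
  GemmExtents e = implicit_gemm_fprop_extents({8, 56, 56, 64, 3, 3, 64});
  std::printf("Implicit GEMM extents: %d x %d x %d\n", e.M, e.N, e.K);
  return 0;
}
```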
# What's New in CUTLASS 3.4
# What's New in CUTLASS 3.5
CUTLASS 3.4.1 is an update to CUTLASS adding:
- Statically available [CUTLASS Version macros](/include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side.
- Improvements for Hopper [Group-GEMM](/examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMM](/examples/56_hopper_ptr_array_batched_gemm).
- Updates and bugfixes from the community (thanks!).
CUTLASS 3.5 is an update to CUTLASS adding:
CUTLASS 3.4.0 is an update to CUTLASS adding:
- Improved [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) supporting {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors tuned for optimal performance on Hopper H100.
- Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) utilizing TMA and Hopper H100 tensor cores now available. (Requires CUDA 12.3 or above)
- Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm) - commonly used in optimization of Mixture-Of-Expert models, is now available on Hopper GPUs taking advantage of TMA and Hopper H100 tensor cores. (Requires CUDA 12.3 or above)
- [Ampere Sparse GEMM](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now.
- Improvements to NamedBarriers including details of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) used within the CUTLASS library.
- Improved [CuTe documentation](/media/docs/cute/) including improved clarity and depth of [Quickstart](/media/docs/cute/00_quickstart.md), [CuTe Layout](/media/docs/cute/01_layout.md), and [CuTe Layout Algebra](/media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](/test/unit/cute/core/) also improved.
- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp)
+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md).
+ Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](./include/cutlass/conv/convnd_problem_shape.hpp).
+ Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms
+ [CUTLASS profiler support](./python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API.
+ NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until 3.7 release. Your feedback is welcome on the design!
- Support for [Ada (SM89) FP8 tensor cores via the 2.x API](./examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer.
- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_gemm/README.md) in CuTe and CUTLASS 3.x
+ Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs.
+ Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores.
- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cute/03_tensor.md), [MMA atoms](./media/docs/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial).
- Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337).
- Updates and bugfixes from the community (thanks!)
Minimum requirements:
@ -98,7 +100,7 @@ as shown in the above figure. Tensor Core operations are implemented using CUDA
# Compatibility
CUTLASS requires a C++17 host compiler and
performs best when built with the [**CUDA 12.3.2 Toolkit**](https://developer.nvidia.com/cuda-downloads).
performs best when built with the [**CUDA 12.4 Toolkit**](https://developer.nvidia.com/cuda-downloads).
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2, CUDA 12.3.1 and CUDA 12.3.2.
## Operating Systems
@ -142,28 +144,28 @@ The target architecture information is passed on to CUTLASS via the cmake flag `
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
```
Please refer to the [functionality documentation](media/docs/functionality.md) for details on which kernels require which target architectures.
Please refer to the [functionality documentation](./media/docs/functionality.md) for details on which kernels require which target architectures.
# Documentation
CUTLASS is described in the following documents and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).
- [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS
- [Functionality](/media/docs/functionality.md) - summarizes functionality available in CUTLASS
- [Efficient GEMM in CUDA](media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
- [CUTLASS 3.x Design](media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
- [GEMM API 3.x](media/docs/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
- [GEMM API 2.x](media/docs/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
- [Implicit GEMM Convolution](media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
- [Code Organization](media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project
- [Terminology](media/docs/terminology.md) - describes terms used in the code
- [Programming Guidelines](media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
- [Fundamental types](media/docs/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
- [Layouts](media/docs/layout.md) - describes layouts of matrices and tensors in memory
- [Tile Iterators](media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
- [CUTLASS Profiler](media/docs/profiler.md) - command-line driven profiling application
- [CUTLASS Utilities](media/docs/utilities.md) - additional templates used to facilate rapid development
- [Quick Start Guide](./media/docs/quickstart.md) - build and run CUTLASS
- [Functionality](./media/docs/functionality.md) - summarizes functionality available in CUTLASS
- [Efficient GEMM in CUDA](./media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
- [CUTLASS 3.x Design](./media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
- [GEMM API 3.x](./media/docs/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
- [GEMM API 2.x](./media/docs/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
- [Implicit GEMM Convolution](./media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
- [Code Organization](./media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project
- [Terminology](./media/docs/terminology.md) - describes terms used in the code
- [Programming Guidelines](./media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
- [Fundamental types](./media/docs/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
- [Layouts](./media/docs/layout.md) - describes layouts of matrices and tensors in memory
- [Tile Iterators](./media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
- [CUTLASS Profiler](./media/docs/profiler.md) - command-line driven profiling application
- [CUTLASS Utilities](./media/docs/utilities.md) - additional templates used to facilitate rapid development
# Resources
We have also described the structure of an efficient GEMM in our talk at the
@ -182,7 +184,7 @@ projects. Client applications should target CUTLASS's `include/` directory in th
paths.
CUTLASS unit tests, examples, and utilities can be built with CMake.
The minimum version of CMake is given in the [Quickstart guide](media/docs/quickstart.md).
The minimum version of CMake is given in the [Quickstart guide](./media/docs/quickstart.md).
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
on your system.
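As a minimal sketch of such a client application (assuming the long-standing single-precision device-level GEMM with default template parameters, not any kernel introduced in this release), a translation unit compiled with `nvcc` only needs `include/` on its include path:

```cpp
#include "cutlass/gemm/device/gemm.h"

// Single-precision GEMM with column-major A, B, and C using default kernel parameters.
using Gemm = cutlass::gemm::device::Gemm<
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor>;

// Computes C <- alpha * A * B + beta * C on device pointers.
cutlass::Status run_sgemm(int M, int N, int K,
                          float alpha, float const *A, int lda,
                          float const *B, int ldb,
                          float beta, float *C, int ldc) {
  Gemm gemm_op;
  Gemm::Arguments args({M, N, K}, {A, lda}, {B, ldb}, {C, ldc}, {C, ldc}, {alpha, beta});
  return gemm_op(args);   // launches the kernel on the default stream
}
```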
@ -227,7 +229,7 @@ CUTLASS is arranged as a header-only library along with Utilities, Tools, Exampl
and template concepts defined in the CUTLASS project.
A detailed explanation of the source code organization may be found in the
[CUTLASS documentation](media/docs/code_organization.md), but several main components are summarized below.
[CUTLASS documentation](./media/docs/code_organization.md), but several main components are summarized below.
## CUTLASS Template Library
@ -276,7 +278,7 @@ include/ # client applications should target this directory
### CUTLASS SDK Examples
[CUTLASS SDK examples](/examples) apply CUTLASS templates to implement basic computations.
[CUTLASS SDK examples](./examples) apply CUTLASS templates to implement basic computations.
### Tools
@ -301,7 +303,7 @@ tools/
The `test/unit/` directory consists of unit tests implemented with Google Test that demonstrate
basic usage of Core API components and complete tests of the CUTLASS GEMM computations.
Instructions for building and running the Unit tests are described in the [Quickstart guide](media/docs/quickstart.md).
Instructions for building and running the Unit tests are described in the [Quickstart guide](./media/docs/quickstart.md).
# Performance Profiling
@ -517,9 +519,9 @@ reference_device: Passed
## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler
- Please follow the links for more CMake examples on selectively compiling CUTLASS kernels:
- [GEMM CMake Examples](media/docs/quickstart.md#gemm-cmake-examples)
- [Implicit GEMM convolution CMake Examples](media/docs/quickstart.md#convolution-cmake-examples)
- [Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md)
- [GEMM CMake Examples](./media/docs/quickstart.md#gemm-cmake-examples)
- [Implicit GEMM convolution CMake Examples](./media/docs/quickstart.md#convolution-cmake-examples)
- [Further details about the CUTLASS Profiler are described here.](./media/docs/profiler.md)
# About

View File

@ -37,7 +37,7 @@ endif()
FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG v1.13.0
GIT_TAG v1.14.0
)
FetchContent_GetProperties(googletest)

View File

@ -260,7 +260,7 @@ private:
if (options.vectorize <= 2) return std::make_pair(false, -1);
// Boundary check.
if (i > elements.size() || (i + options.vectorize - 1) > elements.size())
if (i > int(elements.size()) || (i + options.vectorize - 1) > int(elements.size()))
return std::make_pair(false, -1);
// Check if either all elements are valid or invalid.

View File

@ -94,7 +94,7 @@ __global__ void copy(
typename Iterator::Fragment fragment;
for(int i = 0; i < fragment.size(); ++i) {
for(size_t i = 0; i < fragment.size(); ++i) {
fragment[i] = 0;
}

View File

@ -207,15 +207,15 @@ cudaError_t strided_batched_gemm_nn_reference(
cudaError_t result = cudaSuccess;
if (A.size() < lda * k * batch_count) {
if (A.size() < size_t(lda * k * batch_count)) {
std::cout << "the size of A is too small" << std::endl;
return cudaErrorInvalidValue;
}
if (B.size() < ldb * n) {
if (B.size() < size_t(ldb * n)) {
std::cout << "the size of B is too small" << std::endl;
return cudaErrorInvalidValue;
}
if (C.size() < ldc * n * batch_count) {
if (C.size() < size_t(ldc * n * batch_count)) {
std::cout << "the size of C is too small" << std::endl;
return cudaErrorInvalidValue;
}

View File

@ -102,7 +102,7 @@ struct B2bFusedGroupedGemmRun
if (dist_kind == cutlass::Distribution::Uniform) {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
view, seed, 1, -1, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {

View File

@ -157,35 +157,34 @@ struct B2bGemm {
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size_0;
GemmCoord problem_size_1;
typename B2bMma::IteratorA0::TensorRef ref_A0;
typename B2bMma::IteratorB0::TensorRef ref_B0;
typename Epilogue::OutputTileIterator::TensorRef ref_C0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0;
typename B2bMma::IteratorB1::TensorRef ref_B1;
typename Epilogue::OutputTileIterator::TensorRef ref_C1;
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
int64_t batch_stride_A0;
int64_t batch_stride_B0;
int64_t batch_stride_B1;
int64_t batch_stride_C1;
int64_t batch_stride_D1;
int64_t batch_stride_Bias0;
int64_t batch_stride_Scale0;
typename OutputOp0::Params epilogue0;
typename OutputOp1::Params epilogue1;
int batch_count;
GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
GemmCoord problem_size_0{0,0,0};
GemmCoord problem_size_1{0,0,0};
typename B2bMma::IteratorA0::TensorRef ref_A0{};
typename B2bMma::IteratorB0::TensorRef ref_B0{};
typename Epilogue::OutputTileIterator::TensorRef ref_C0{};
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{};
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{};
typename B2bMma::IteratorB1::TensorRef ref_B1{};
typename Epilogue::OutputTileIterator::TensorRef ref_C1{};
typename Epilogue::OutputTileIterator::TensorRef ref_D1{};
int64_t batch_stride_A0{0};
int64_t batch_stride_B0{0};
int64_t batch_stride_B1{0};
int64_t batch_stride_C1{0};
int64_t batch_stride_D1{0};
int64_t batch_stride_Bias0{0};
int64_t batch_stride_Scale0{0};
typename OutputOp0::Params epilogue0 {};
typename OutputOp1::Params epilogue1 {};
int batch_count{1};
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments() : mode(mode), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {}
Arguments() = default;
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
@ -285,47 +284,45 @@ struct B2bGemm {
/// Parameters structure
struct Params {
cutlass::gemm::GemmUniversalMode mode;
cutlass::gemm::GemmCoord problem_size_0;
cutlass::gemm::GemmCoord problem_size_1;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename B2bMma::IteratorA0::Params params_A0;
typename B2bMma::IteratorA0::TensorRef ref_A0;
typename B2bMma::IteratorB0::Params params_B0;
typename B2bMma::IteratorB0::TensorRef ref_B0;
typename Epilogue::OutputTileIterator::Params params_C0;
typename Epilogue::OutputTileIterator::TensorRef ref_C0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0;
typename B2bMma::IteratorB1::Params params_B1;
typename B2bMma::IteratorB1::TensorRef ref_B1;
typename Epilogue::OutputTileIterator::Params params_C1;
typename Epilogue::OutputTileIterator::TensorRef ref_C1;
typename Epilogue::OutputTileIterator::Params params_D1;
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
typename OutputOp0::Params output_op_0;
typename OutputOp1::Params output_op_1;
int64_t batch_stride_A0;
int64_t batch_stride_B0;
int64_t batch_stride_B1;
int64_t batch_stride_C1;
int64_t batch_stride_D1;
int64_t batch_stride_Bias0;
int64_t batch_stride_Scale0;
int *semaphore;
int gemm_k_iterations_0;
int gemm_k_size_0;
int gemm_k_iterations_1;
int gemm_k_size_1;
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
cutlass::gemm::GemmCoord problem_size_0{};
cutlass::gemm::GemmCoord problem_size_1{};
cutlass::gemm::GemmCoord grid_tiled_shape{};
int swizzle_log_tile{0};
typename B2bMma::IteratorA0::Params params_A0{};
typename B2bMma::IteratorA0::TensorRef ref_A0{};
typename B2bMma::IteratorB0::Params params_B0{};
typename B2bMma::IteratorB0::TensorRef ref_B0{};
typename Epilogue::OutputTileIterator::Params params_C0{};
typename Epilogue::OutputTileIterator::TensorRef ref_C0{};
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{};
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{};
typename B2bMma::IteratorB1::Params params_B1{};
typename B2bMma::IteratorB1::TensorRef ref_B1{};
typename Epilogue::OutputTileIterator::Params params_C1{};
typename Epilogue::OutputTileIterator::TensorRef ref_C1{};
typename Epilogue::OutputTileIterator::Params params_D1{};
typename Epilogue::OutputTileIterator::TensorRef ref_D1{};
typename OutputOp0::Params output_op_0{};
typename OutputOp1::Params output_op_1{};
int64_t batch_stride_A0{0};
int64_t batch_stride_B0{0};
int64_t batch_stride_B1{0};
int64_t batch_stride_C1{0};
int64_t batch_stride_D1{0};
int64_t batch_stride_Bias0{0};
int64_t batch_stride_Scale0{0};
int *semaphore = nullptr;
int gemm_k_iterations_0{0};
int gemm_k_size_0{0};
int gemm_k_iterations_1{0};
int gemm_k_size_1{0};
//
// Methods
//
CUTLASS_HOST_DEVICE
Params(): mode(mode), swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
gemm_k_iterations_1(0), gemm_k_size_1(0) { }
Params() = default;
CUTLASS_HOST_DEVICE
Params(

View File

@ -27,10 +27,14 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
set(TEST_STANDARD --m=1024 --n=1024 --k=1024)
set(TEST_LARGE_PERFCHECK --m=4096 --n=3456 --k=4096 --perf-check)
cutlass_example_add_executable(
23_ampere_gemm_operand_reduction_fusion
ampere_gemm_operand_reduction_fusion.cu
TEST_COMMAND_OPTIONS
TEST_STANDARD
TEST_LARGE_PERFCHECK
)

View File

@ -377,22 +377,22 @@ Result profile(Options const &options) {
cutlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1997,
ElementInputA(2),
ElementInputA(-2),
ElementInputA(1),
ElementInputA(-1),
0); // <- Fill tensor A on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
2003,
ElementInputB(2),
ElementInputB(-2),
ElementInputB(1),
ElementInputB(-1),
0); // <- Fill tensor B on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
2017,
ElementOutput(2),
ElementOutput(-2),
ElementOutput(1),
ElementOutput(-1),
0); // <- Fill matrix C on host with uniform-distribution random data
cutlass::reference::host::TensorFill(
tensor_d.host_view()); // <- fill matrix D on host with zeros

View File

@ -789,7 +789,7 @@ public:
problem_count_check += bin.second.size();
}
if (problem_count_check != this->problem_count()) {
if (problem_count_check != size_t(this->problem_count())) {
std::cout << "\n***\nERROR in BINNING LOGIC!\n***\n" << std::endl;
}

View File

@ -113,10 +113,10 @@ cudaError_t CutlassSsyrkNN(
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
5, // Stages
1, // AligmentA
1, // AlignmentA
false, // SplitKSerial
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::BlasMode::kSymmetric
>;
@ -149,7 +149,7 @@ cudaError_t CutlassSsyrkNN(
//
// Launch the CUTLASS SYRK kernel.
//
cutlass::Status status = syrk_operator(args);
//

View File

@ -456,7 +456,7 @@ struct Testbed {
bool verify_tensor(std::vector<Element> vector_Input, \
std::vector<Element> vector_Input_Ref) {
int64_t size = (vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size();
auto size = int64_t((vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size());
float abs_tol = options.tolerance;
float rel_tol = options.tolerance;

View File

@ -454,48 +454,48 @@ struct Testbed {
cutlass::reference::host::TensorFillRandomUniform(
tensor_A0.host_view(),
options.seed,
ElementInputA0(5),
ElementInputA0(-5),
ElementInputA0(4),
ElementInputA0(-4),
0
);
cutlass::reference::host::TensorFillRandomUniform(
tensor_B0.host_view(),
options.seed + 1,
ElementInputB0(5),
ElementInputB0(-5),
ElementInputB0(4),
ElementInputB0(-4),
0
);
cutlass::reference::host::TensorFillRandomUniform(
tensor_A1.host_view(),
options.seed + 2,
ElementInputA1(5),
ElementInputA1(-5),
ElementInputA1(4),
ElementInputA1(-4),
0
);
cutlass::reference::host::TensorFillRandomUniform(
tensor_Beta.host_view(),
options.seed + 3,
ElementInputScaleBias(5),
ElementInputScaleBias(-5),
ElementInputScaleBias(4),
ElementInputScaleBias(-4),
0
);
cutlass::reference::host::TensorFillRandomUniform(
tensor_Gamma.host_view(),
options.seed + 4,
ElementInputScaleBias(5),
ElementInputScaleBias(-5),
ElementInputScaleBias(4),
ElementInputScaleBias(-4),
0
);
cutlass::reference::host::TensorFillRandomUniform(
tensor_Shifted_K.host_view(),
options.seed + 5,
ElementOutput(5),
ElementOutput(-6),
ElementOutput(4),
ElementOutput(-5),
0
);

View File

@ -803,7 +803,7 @@ public:
// Use 'D' for the in/out workspace
this->block_D.copy_from_device(this->block_C.get());
for (int i = 0; i < this->options.problem_sizes.size(); ++i) {
for (size_t i = 0; i < this->options.problem_sizes.size(); ++i) {
cutlass::gemm::GemmCoord const & problem = this->options.problem_sizes[i];
int32_t batch_count = 1;
int64_t lda = this->lda_host.at(i);
@ -904,10 +904,10 @@ public:
// Run profiling loop
//
int last_stream_idx = 0;
size_t last_stream_idx = 0;
for (int iter = 0; iter < this->options.iterations; ++iter) {
for (int i = 0; i < this->options.problem_sizes.size(); ++i) {
for (size_t i = 0; i < this->options.problem_sizes.size(); ++i) {
cutlass::gemm::GemmCoord const & problem = this->options.problem_sizes[i];
int32_t batch_count = 1;
int64_t lda = this->lda_host.at(i);
@ -1146,7 +1146,7 @@ public:
);
// Initialize the Rank2K object
Rank2K rank2k;
Rank2K rank2k{};
size_t workspace_size = rank2k.get_workspace_size(args);
cutlass::DeviceAllocation<uint8_t> workspace(workspace_size);

View File

@ -40,7 +40,7 @@
// Nans & inf detection
#define NANCHECK(frag) \
{ \
for (int _i = 0; _i < frag.size(); ++_i) { \
for (size_t _i = 0; _i < frag.size(); ++_i) { \
assert(std::isfinite(float(frag[_i]))); \
assert(!std::isnan(float(frag[_i]))); \
} \
@ -147,7 +147,7 @@ constexpr __string_view __get_type_name() {
{ \
auto typeStr = __get_type_name<decltype(frag)>(); \
PRINT_B0_T0("printing %s (%s)", name, typeStr.data); \
for (int _start = 0; _start < frag.size(); _start += 8) { \
for (size_t _start = 0; _start < frag.size(); _start += 8) { \
PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \
} \
/*__syncthreads(); \

View File

@ -167,58 +167,39 @@ public:
// Data members
//
GemmCoord *problem_sizes0;
GemmCoord *problem_sizes1;
GemmCoord *problem_sizes0{nullptr};
GemmCoord *problem_sizes1{nullptr};
int problem_count;
int threadblock_count;
int problem_count{0};
int threadblock_count{0};
ElementQ ** ptr_Q;
ElementK ** ptr_K;
ElementP ** ptr_P;
ElementV ** ptr_V;
ElementO ** ptr_O;
ElementOAccum ** ptr_O_accum;
ElementQ ** ptr_Q{nullptr};
ElementK ** ptr_K{nullptr};
ElementP ** ptr_P{nullptr};
ElementV ** ptr_V{nullptr};
ElementO ** ptr_O{nullptr};
ElementOAccum ** ptr_O_accum{nullptr};
typename LayoutQ::Stride::LongIndex *ldq;
typename LayoutK::Stride::LongIndex *ldk;
typename LayoutP::Stride::LongIndex *ldv;
typename LayoutO::Stride::LongIndex *ldo;
// Scale
ElementAccumulator scale;
typename LayoutQ::Stride::LongIndex *ldq{nullptr};
typename LayoutK::Stride::LongIndex *ldk{nullptr};
typename LayoutP::Stride::LongIndex *ldv{nullptr};
typename LayoutO::Stride::LongIndex *ldo{nullptr};
// Whether causal masking is to be performed
bool causal;
bool causal{false};
// Scale
ElementAccumulator scale{0};
// Only used by device-level operator
GemmCoord *host_problem_sizes;
GemmCoord *host_problem_sizes{nullptr};
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments():
problem_count(0),
threadblock_count(0),
ptr_Q(nullptr),
ptr_K(nullptr),
ptr_P(nullptr),
ptr_V(nullptr),
ptr_O(nullptr),
ptr_O_accum(nullptr),
ldq(nullptr),
ldk(nullptr),
ldv(nullptr),
ldo(nullptr),
scale(0),
causal(false),
host_problem_sizes(nullptr)
{
}
/// Default ctor
Arguments() = default;
/// Ctor
CUTLASS_HOST_DEVICE

View File

@ -286,7 +286,7 @@ struct Options {
// Number of real-valued multiply-adds
int64_t fops = int64_t();
for (int i = 0; i < problem_sizes0.size(); ++i) {
for (size_t i = 0; i < problem_sizes0.size(); ++i) {
auto const& problem0 = problem_sizes0[i];
auto const& problem1 = problem_sizes1[i];
for (int row = 0; row < problem0.m(); ++row) {

View File

@ -340,7 +340,7 @@ struct Options {
// Number of real-valued multiply-adds
int64_t fops = int64_t();
for (int i = 0; i < problem_sizes0.size(); ++i) {
for (size_t i = 0; i < problem_sizes0.size(); ++i) {
auto const& problem0 = problem_sizes0[i];
auto const& problem1 = problem_sizes1[i];

View File

@ -244,11 +244,13 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
CUTLASS_DEVICE
bool set_prologue_done(bool value) {
prologue_done_ = value;
return true;
}
CUTLASS_DEVICE
bool set_zero_outside_bounds(bool value) {
zero_outside_bounds_ = value;
return true;
}
template <bool kLoadA = true, bool kLoadB = true>

View File

@ -1446,7 +1446,7 @@ struct AttentionBackwardKernel {
uint8_t lane_id) {
cutlass::Array<cutlass::uint1b_t, MatmulDOIVJ::Mma::FragmentC::kElements>
dropout_keep_mask_doivj;
dropout_keep_mask_doivj.fill(1);
dropout_keep_mask_doivj.fill(cutlass::uint1b_t{1});
const float dropout_scale =
kApplyDropout ? 1.0 / (1.0 - p.dropout_prob) : 1.0f;
@ -1744,7 +1744,7 @@ struct AttentionBackwardKernel {
[&](int accum_m) {},
[&](int accum_m /*q*/, int accum_n /*k*/, int idx) {
if (zij.at({accum_n, accum_m}) == scalar_t(0)) {
dropout_keep_mask_doivj[idx] = cutlass::uint1b_t(0);
dropout_keep_mask_doivj[idx] = cutlass::uint1b_t{0};
}
},
[&](int accum_m) {});

View File

@ -40,7 +40,6 @@
#include <cmath>
#include <vector>
#include "cutlass/bfloat16.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"

View File

@ -452,7 +452,7 @@ public:
// Determine SMEM requirements and waive if not satisfied
//
int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage));
size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage);
cudaDeviceProp properties;
int device_idx;
@ -509,7 +509,7 @@ public:
);
// Initialize the GEMM object
Gemm gemm;
Gemm gemm{};
result.status = gemm.initialize(args);

View File

@ -102,7 +102,8 @@ gett_kernel(
ElementB, StrideB, 128 / cutlass::sizeof_bits<ElementB>::value,
ElementAccumulator,
TileShape, Shape<_1,_2,_1>,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;

View File

@ -289,7 +289,8 @@ struct ExampleRunner
ElementAccumulator,
Shape<_128,_128,_64>,
Shape<_2,_2,_1>,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename EpilogueOpt::SharedStorage)>,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename EpilogueOpt::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;

View File

@ -39,6 +39,11 @@
#include "gather_tensor.hpp"
namespace cutlass {
///Forward declaration
struct CudaHostAdapter;
}
namespace cutlass::gemm::kernel {
///////////////////////////////////////////////////////////////////////////////
@ -143,10 +148,10 @@ public:
// Kernel entry point API
struct Params {
GemmUniversalMode mode;
ProblemShape problem_shape;
MainloopParams mainloop;
EpilogueParams epilogue;
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopParams mainloop{};
EpilogueParams epilogue{};
GatherA gather_A{};
GatherB gather_B{};
};
@ -191,14 +196,15 @@ public:
}
static
int
size_t
get_workspace_size(Arguments const& args) {
return 0;
}
static
cutlass::Status
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
CudaHostAdapter* cuda_adapter = nullptr) {
return Status::kSuccess;
}

View File

@ -39,7 +39,7 @@
#include "cutlass/epilogue/collective/detail.hpp"
#include "cute/tensor.hpp"
#include "cute/numeric/int.hpp"
#include "cute/numeric/numeric_types.hpp"
#include "gather_tensor.hpp"

View File

@ -393,7 +393,8 @@ private:
ElementB, StrideB, 128 / cutlass::sizeof_bits<ElementB>::value,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
@ -403,7 +404,8 @@ private:
ElementB, StrideBPermute, 128 / cutlass::sizeof_bits<ElementB>::value,
ElementAccumulator,
TileShapePermute, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpiloguePermute::SharedStorage)>,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpiloguePermute::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;

View File

@ -37,7 +37,7 @@
#include "cutlass/layout/matrix.h"
#include "cutlass/tensor_view.h"
#include "cutlass/fast_math.h"
#include "cute/numeric/uint128.hpp"
#include "cute/numeric/numeric_types.hpp"
namespace example
{

View File

@ -0,0 +1,34 @@
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
58_ada_fp8_gemm
ada_fp8_gemm.cu
)

View File

@ -0,0 +1,826 @@
/***************************************************************************************************
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Example of running an Ada FP8 GEMM.
In addition to using FP8 Tensor Core instructions, the Ada FP8 GEMM uses a distinct epilogue
that enables additional scaling of operands/outputs, storing a pre-activation-function output
tensor (called the "auxiliary" output), and computing the absolute maximum value of the
outputs.
Pseudocode for this epilogue is as follows:
Aux = ((alpha * scale_a * scale_b) * accumulator) + ((beta * scale_c) * source) + bias
D = activation(Aux)
if Aux is fp8 type:
abs_max_output = max( abs(aux) | (for every aux in Aux))
Aux = scale_aux * Aux
endif
if D is fp8 type:
abs_max_output = max( abs(d) | (for every d in D))
D = scale_d * D
endif
Parameter Aux is optionally stored to global memory
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm_complex.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/linear_combination_generic_with_scaling.h"
#include "cutlass/gemm/device/gemm_universal_with_absmax.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e4m3_t;
using ElementOutput = cutlass::float_e4m3_t;
using ElementAuxOutput = ElementOutput;
using ElementAccumulator = float;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
static int const kStages = 3;
static int const kAlignmentA = 16;
static int const kAlignmentB = 16;
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax<
cutlass::epilogue::thread::ReLu,
ElementOutput,
ElementAuxOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementAccumulator
>;
template <typename MathOperator>
using Gemm_ = cutlass::gemm::device::GemmUniversalWithAbsMax<
ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89,
cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages,
kAlignmentA, kAlignmentB, MathOperator
>;
using ElementAbsmax = typename EpilogueOutputOp::ElementAbsmax;
// Command line options parsing
struct Options {
bool help;
bool error;
bool reference_check;
cutlass::gemm::GemmCoord problem_size;
int iterations;
int warmup_iterations;
bool scale_A;
bool scale_B;
bool scale_C;
float alpha;
float beta;
Options():
help(false),
error(false),
reference_check(false),
iterations(20),
warmup_iterations(5),
scale_A(true),
scale_B(true),
scale_C(true),
alpha(1.f),
beta(0.f)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("iterations", iterations, 20);
cmd.get_cmd_line_argument("warmup_iterations", warmup_iterations, 5);
cmd.get_cmd_line_argument("reference-check", reference_check, false);
cmd.get_cmd_line_argument("scale-A", scale_A, true);
cmd.get_cmd_line_argument("scale-B", scale_B, true);
cmd.get_cmd_line_argument("scale-C", scale_C, true);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
int m, n, k;
cmd.get_cmd_line_argument("m", m, 1024);
cmd.get_cmd_line_argument("n", n, 1024);
cmd.get_cmd_line_argument("k", k, 1024);
problem_size = cutlass::gemm::GemmCoord{m, n, k};
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "58_ada_fp8_gemm\n\n"
<< " This example executes a GEMM using Ada FP8 Tensor Core operations. In addition to performing\n"
<< " a normal GEMM, the kernel performs the following operations:\n"
<< " Aux = ((alpha * scale_a * scale_b) * accumulator) + ((beta * scale_c) * source) + bias\n"
<< " D = activation(Aux)\n\n"
<< " if Aux is fp8:\n"
<< " abs_max_output = max( abs(aux) | (for every aux in Aux) )\n"
<< " Aux = scale_aux * Aux\n\n"
<< " if D is fp8 type:\n"
<< " abs_max_output = max( abs(d) | (for every d in D) )\n"
<< " D = scale_d * D\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M dimension of the GEMM\n"
<< " --n=<int> Sets the N dimension of the GEMM\n"
<< " --k=<int> Sets the K dimension of the GEMM\n"
<< " --scale-A=<bool> Whether to apply a scaling factor to operand A (default: true)\n"
<< " --scale-B=<bool> Whether to apply a scaling factor to operand B (default: true)\n"
<< " --scale-C=<bool> Whether to apply a scaling factor to operand C (default: true)\n"
<< " --iterations=<int> Number of profiling iterations to perform\n"
<< " --warmup-iterations=<int> Number of warmup iterations to perform\n"
<< " --reference-check=<bool> If true, performs reference check\n";
return out;
}
/// Compute performance in GFLOP/s
float gflops(float runtime_s) const {
// Two flops per multiply-add
return 2.0f * float(problem_size.product()) / float(1.0e9) / runtime_s;
}
};
/// Helper class to run the kernel
template <typename Gemm>
struct TestbedRunner {
using ElementAccumulator = typename Gemm::ElementAccumulator;
using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute;
using ElementScalingFactor = typename Gemm::EpilogueOutputOp::ElementScalingFactor;
static bool const kScaleAux = Gemm::EpilogueOutputOp::kIsScalingAndAmaxAuxOutputNeeded;
static bool const kScaleOutput = Gemm::EpilogueOutputOp::kIsScalingAndAmaxOutputNeeded;
/// Initialization
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
uint64_t seed;
cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A;
cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B;
cutlass::HostTensor<typename Gemm::ElementC, typename Gemm::LayoutC> tensor_C;
cutlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementAuxOutput, typename Gemm::LayoutC> tensor_Aux;
cutlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementOutput, typename Gemm::LayoutC> tensor_D;
cutlass::HostTensor<typename Gemm::ElementC, typename Gemm::LayoutC> tensor_Vector;
cutlass::HostTensor<ElementAccumulator, typename Gemm::LayoutC> tmp_D;
cutlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementOutput, typename Gemm::LayoutC> reference_D;
cutlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementAuxOutput, typename Gemm::LayoutC> reference_Aux;
cutlass::HostTensor<ElementScalingFactor, typename Gemm::LayoutC> scale_A;
cutlass::HostTensor<ElementScalingFactor, typename Gemm::LayoutC> scale_B;
cutlass::HostTensor<ElementScalingFactor, typename Gemm::LayoutC> scale_C;
cutlass::HostTensor<ElementScalingFactor, typename Gemm::LayoutC> scale_D;
cutlass::HostTensor<ElementScalingFactor, typename Gemm::LayoutC> scale_Aux;
cutlass::HostTensor<ElementAbsmax, typename Gemm::LayoutC> abs_max_Aux;
cutlass::HostTensor<ElementAbsmax, typename Gemm::LayoutC> abs_max_D;
cutlass::HostTensor<ElementAbsmax, typename Gemm::LayoutC> reference_abs_max_Aux;
cutlass::HostTensor<ElementAbsmax, typename Gemm::LayoutC> reference_abs_max_D;
//
// Methods
//
TestbedRunner(
bool scaleA = true,
bool scaleB = true,
bool scaleC = true,
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
/// Helper to initialize scaling factors
template <typename Element, typename Layout>
bool initialize_scale_factor(cutlass::TensorView<Element, Layout> view, uint64_t seed, int bits=0) {
cutlass::reference::host::TensorFillRandomUniform(view, seed, double(1.), double(0.), bits);
return true;
}
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<typename Gemm::ElementC>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
} else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
else {
std::cerr << "Not implemented";
return false;
}
return true;
}
/// Initializes data structures
void initialize(const Options& options) {
//
// Allocate the GEMM workspace
//
tensor_A.resize(options.problem_size.mk());
tensor_B.resize(options.problem_size.kn());
tensor_C.resize(options.problem_size.mn());
tensor_D.resize(options.problem_size.mn());
tensor_Vector.resize({1, options.problem_size.n()});
reference_D.resize(options.problem_size.mn(), false);
tmp_D.resize(options.problem_size.mn(), false);
initialize_tensor(tensor_A.host_view(), init_A, seed + 2019);
initialize_tensor(tensor_B.host_view(), init_B, seed + 2018);
initialize_tensor(tensor_C.host_view(), init_C, seed + 2017);
initialize_tensor(tensor_Vector.host_view(), init_C, seed + 2020);
// It is possible to randomly initialize to all zeros, so override this with non-zeros
// in the upper left corner of each operand.
cutlass::Coord<2> origin(0);
tensor_A.host_view().at(origin) = typename Gemm::ElementA(1);
tensor_B.host_view().at(origin) = typename Gemm::ElementB(1);
tensor_C.host_view().at(origin) = typename Gemm::ElementC(1);
tensor_Vector.host_view().at(origin) = typename Gemm::ElementC(1);
cutlass::reference::host::TensorFill(tensor_D.host_view());
cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_D.sync_device();
tensor_Vector.sync_device();
int scale_bits = 2;
if (options.scale_A) {
scale_A.resize({1, 1});
initialize_scale_factor(scale_A.host_view(), seed + 2021, scale_bits);
scale_A.sync_device();
}
if (options.scale_B) {
scale_B.resize({1, 1});
initialize_scale_factor(scale_B.host_view(), seed + 2022, scale_bits);
scale_B.sync_device();
}
if (options.scale_C) {
scale_C.resize({1, 1});
initialize_scale_factor(scale_C.host_view(), seed + 2023, scale_bits);
scale_C.sync_device();
}
if (kScaleOutput) {
scale_D.resize({1, 1});
initialize_scale_factor(scale_D.host_view(), seed + 2024, scale_bits);
scale_D.sync_device();
abs_max_D.resize({1, 1});
cutlass::reference::host::TensorFill(abs_max_D.host_view());
abs_max_D.sync_device();
reference_abs_max_D.resize({1, 1});
}
if (kScaleAux) {
tensor_Aux.resize(options.problem_size.mn());
cutlass::reference::host::TensorFill(tensor_Aux.host_view());
tensor_Aux.sync_device();
scale_Aux.resize({1, 1});
initialize_scale_factor(scale_Aux.host_view(), seed + 2025, scale_bits);
scale_Aux.sync_device();
abs_max_Aux.resize({1, 1});
cutlass::reference::host::TensorFill(abs_max_Aux.host_view());
abs_max_Aux.sync_device();
reference_Aux.resize(options.problem_size.mn(), false);
reference_abs_max_Aux.resize({1, 1});
}
}
/// Compares computed reference with device reference and outputs to a file if incorrect
bool compare_reference(const Options& options) {
tensor_D.sync_host();
bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view());
if (kScaleAux) {
tensor_Aux.sync_host();
abs_max_Aux.sync_host();
passed &= cutlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view());
passed &= cutlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view());
}
if (kScaleOutput) {
abs_max_D.sync_host();
passed &= cutlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view());
}
if (!passed) {
std::cerr << "Reference check failed" << std::endl;
std::string output_file = "testbed_with_amax_errors.txt";
std::ofstream file(output_file);
file
<< "problem: " << options.problem_size
<< ", alpha: " << options.alpha << ", beta: " << options.beta << "\n\n";
file
<< "A =\n" << tensor_A.host_view()
<< "\nB =\n" << tensor_B.host_view()
<< "\nC =\n" << tensor_C.host_view()
<< "\nVector =\n" << tensor_Vector.host_view()
<< "\nScaleA = " << scale_A.host_view()
<< "\nScaleB = " << scale_B.host_view()
<< "\nScaleC = " << scale_C.host_view()
<< "\nScaleD = " << scale_D.host_view()
<< "\nScaleAux = " << scale_Aux.host_view()
<< "\n\nReference D =\n" << reference_D.host_view()
<< "\nComputed D =\n" << tensor_D.host_view();
if (kScaleAux) {
file
<< "\n\nReference Aux =\n" << reference_Aux.host_view()
<< "\nComputed Aux =\n" << tensor_Aux.host_view()
<< "\n\nReference Absmax Aux = " << reference_abs_max_Aux.host_view()
<< "\nComputed Absmax Aux = " << abs_max_Aux.host_view();
}
if (kScaleOutput) {
file
<< "\n\nReference Absmax D = " << reference_abs_max_D.host_view()
<< "\nComputed Absmax D = " << abs_max_D.host_view();
}
std::cerr << "Dumped results to " << output_file << std::endl;
}
return passed;
}
/// Verifies the result is a GEMM
bool verify(const Options& options) {
cutlass::Coord<2> origin(0);
ElementCompute scaled_alpha = options.alpha;
if (options.scale_A) {
scaled_alpha *= scale_A.host_view().at(origin);
}
if (options.scale_B) {
scaled_alpha *= scale_B.host_view().at(origin);
}
ElementCompute scaled_beta = options.beta;
if (options.scale_C) {
scaled_beta *= scale_C.host_view().at(origin);
}
//
// Verify
//
cutlass::reference::host::GemmComplex<
typename Gemm::ElementA, typename Gemm::LayoutA,
typename Gemm::ElementB, typename Gemm::LayoutB,
typename Gemm::ElementC, typename Gemm::LayoutC,
ElementCompute, ElementAccumulator, ElementAccumulator
>(
options.problem_size,
scaled_alpha,
tensor_A.host_ref(),
Gemm::kTransformA,
tensor_B.host_ref(),
Gemm::kTransformB,
scaled_beta,
tensor_C.host_ref(),
tmp_D.host_ref(),
ElementAccumulator(0)
);
ElementCompute tmp_abs_max_Aux(0.);
ElementCompute tmp_abs_max_D(0.);
cutlass::NumericConverter<ElementCompute, typename Gemm::ElementC> cvt_c_to_compute;
cutlass::NumericConverter<ElementCompute, ElementAccumulator> cvt_accum_to_compute;
cutlass::NumericConverter<ElementAccumulator, ElementCompute> cvt_compute_to_accum;
cutlass::NumericConverter<typename Gemm::EpilogueOutputOp::ElementOutput, ElementCompute> cvt_compute_to_d;
cutlass::NumericConverter<typename Gemm::EpilogueOutputOp::ElementAuxOutput, ElementCompute> cvt_compute_to_aux;
cutlass::absolute_value_op<ElementCompute> abs;
cutlass::maximum_with_nan_propogation<ElementCompute> max;
cutlass::epilogue::thread::ReLu<ElementCompute> act;
ElementScalingFactor d_scale = kScaleOutput ? scale_D.host_view().at(origin) : ElementScalingFactor(1.);
for (int m = 0; m < options.problem_size.m(); ++m) {
for (int n = 0; n < options.problem_size.n(); ++n) {
ElementCompute intermediate = cvt_accum_to_compute(tmp_D.host_view().at({m, n}));
ElementCompute bias = cvt_c_to_compute(tensor_Vector.host_view().at({0, n}));
ElementCompute aux = intermediate + bias;
ElementCompute d = act(aux);
tmp_abs_max_Aux = max(abs(aux), tmp_abs_max_Aux);
tmp_abs_max_D = max(abs(d), tmp_abs_max_D);
reference_D.host_view().at({m, n}) = cvt_compute_to_d(d * d_scale);
if (kScaleAux) {
reference_Aux.host_view().at({m, n}) = cvt_compute_to_aux(aux * scale_Aux.host_view().at(origin));
}
}
}
if (kScaleAux) {
reference_abs_max_Aux.host_view().at(origin) = cvt_compute_to_accum(tmp_abs_max_Aux);
}
if (kScaleOutput) {
reference_abs_max_D.host_view().at(origin) = cvt_compute_to_accum(tmp_abs_max_D);
}
return compare_reference(options);
}
/// Returns true if the CUDA device is sufficient to execute the kernel.
bool sufficient() const {
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) {
std::cerr << "This example requires CUDA 12.4 or greater." << std::endl;
return false;
}
size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage);
cudaDeviceProp properties;
int device_idx;
cudaError_t result = cudaGetDevice(&device_idx);
if (result != cudaSuccess) {
std::cerr << "cudaGetDevice() failed with error: " << cudaGetErrorString(result) << std::endl;
return false;
}
result = cudaGetDeviceProperties(&properties, device_idx);
if (result != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() failed with error: " << cudaGetErrorString(result) << std::endl;
return false;
}
if (properties.major < 8 || (properties.major == 8 && properties.minor < 9)) {
std::cerr << "CUTLASS's Ada FP8 GEMM example requires a device of compute capability 89 or higher.\n" << std::endl;
return false;
}
if (properties.sharedMemPerBlockOptin < smem_size) {
std::cerr << "Insufficient shared memory. Need " << smem_size
<< ", but device only has " << properties.sharedMemPerBlockOptin << std::endl;
return false;
}
return true;
}
/// Executes one test
bool run(Options& options)
{
// Waive test if insufficient CUDA device
if (!sufficient()) {
std::cerr << "Insufficient resources to run the kernel." << std::endl;
return false;
}
this->initialize(options);
//
// Initialize the GEMM operator
//
typename Gemm::EpilogueOutputOp::Params::ActivationParams activation_params{
ElementCompute(options.alpha),
ElementCompute(options.beta)
};
typename Gemm::EpilogueOutputOp::Params epilogue_params{
activation_params,
scale_A.device_data(),
scale_B.device_data(),
scale_C.device_data(),
scale_D.device_data(),
scale_Aux.device_data(),
abs_max_Aux.device_data(),
abs_max_D.device_data()
};
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
options.problem_size,
/* batch_count = */ 1,
epilogue_params,
tensor_A.device_data(),
tensor_B.device_data(),
tensor_C.device_data(),
tensor_D.device_data(),
tensor_Aux.device_data(),
tensor_Vector.device_data(),
options.problem_size.m() * options.problem_size.k(),
options.problem_size.n() * options.problem_size.k(),
options.problem_size.m() * options.problem_size.n(),
options.problem_size.m() * options.problem_size.n(),
(int)options.problem_size.m(), // Batch stride vector
tensor_A.layout().stride(0),
tensor_B.layout().stride(0),
tensor_C.layout().stride(0),
tensor_D.layout().stride(0),
(int64_t)0 // Leading dimension of vector. This must be 0
};
Gemm gemm_op;
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
std::cerr << "Gemm::can_implement() failed" << std::endl;
return false;
}
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
status = gemm_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Gemm::initialize() failed" << std::endl;
return false;
}
//
// Run the GEMM
//
status = gemm_op();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Gemm::run() failed" << std::endl;
return false;
}
cudaError_t cuda_error = cudaDeviceSynchronize();
if (cuda_error != cudaSuccess) {
std::cerr << "CUDA error: " << cudaGetErrorString(cuda_error) << std::endl;
return false;
}
//
// Verify
//
bool passed = true;
if (options.reference_check) {
passed &= this->verify(options);
} else {
std::cout << "Skipped reference check" << std::endl;
}
//
// Warm up
//
for (int i = 0; i < options.warmup_iterations; ++i) {
gemm_op();
}
//
// Profile
//
cudaEvent_t events[2];
cudaError_t error;
for (auto & event : events) {
error = cudaEventCreate(&event);
if (error != cudaSuccess) {
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(error) << std::endl;
return false;
}
}
// Record an event at the start of a series of GEMM operations
error = cudaEventRecord(events[0]);
if (error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(error) << std::endl;
return false;
}
// Run profiling loop
for (int iter = 0; iter < options.iterations; ++iter) {
gemm_op();
}
// Record an event when the GEMM operations have been launched.
error = cudaEventRecord(events[1]);
if (error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(error) << std::endl;
return false;
}
// Wait for work on the device to complete.
error = cudaEventSynchronize(events[1]);
if (error != cudaSuccess) {
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(error) << std::endl;
return false;
}
// Measure elapsed runtime
float runtime_ms = 0;
error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
if (error != cudaSuccess) {
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(error) << std::endl;
return false;
}
// Compute average runtime and GFLOPs.
runtime_ms = runtime_ms / float(options.iterations);
float gflops = options.gflops(runtime_ms / 1000.0f);
std::cout << "Problem size: " << options.problem_size.m() << 'x' << options.problem_size.n() << 'x' << options.problem_size.k() << std::endl;
std::cout << "Runtime (ms): " << runtime_ms << std::endl;
std::cout << "GFLOPs/sec: " << gflops << std::endl;
// Cleanup
for (auto event : events) {
(void)cudaEventDestroy(event);
}
return passed;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const** argv) {
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4) ||
(props.major != 8 || props.minor != 9)) { // require an SM 8.9 (Ada) device
//
// This example requires an NVIDIA Ada-architecture GPU.
//
std::cout
<< "CUTLASS's FP8 SM89 example requires a GPU of NVIDIA's Ada architecture "
<< "and CUDA toolkit version 12.4 or later.\n";
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}
std::cout << "Running GEMM with staged accumulation (OpMultiplyAdd)" << std::endl;
std::cout << "=====================================================" << std::endl;
TestbedRunner<Gemm_<cutlass::arch::OpMultiplyAdd>> testbed_staged_accum;
bool passed = testbed_staged_accum.run(options);
if (passed) {
std::cout << "Passed" << std::endl;
} else {
std::cout << "Failed" << std::endl;
}
std::cout << "\nRunning GEMM with fast accumulation (OpMultiplyAddFastAccum)" << std::endl;
std::cout << "============================================================" << std::endl;
TestbedRunner<Gemm_<cutlass::arch::OpMultiplyAddFastAccum>> testbed_fast_accum;
passed = testbed_fast_accum.run(options);
if (passed) {
std::cout << "Passed" << std::endl;
} else {
std::cout << "Failed" << std::endl;
}
return 0;
}

View File

@ -0,0 +1,36 @@
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
59_ampere_gather_scatter_conv
ampere_gather_scatter_conv.cu
)
if (CUTLASS_ENABLE_OPENMP_TESTS AND OpenMP_CXX_FOUND)
target_link_libraries(59_ampere_gather_scatter_conv PRIVATE OpenMP::OpenMP_CXX)
endif()

View File

@ -0,0 +1,209 @@
# Example 59: Ampere gather/scatter convolution
CuTe and CUTLASS 3.x based Ampere convolution forward propagation kernel capable of operating on both affine and gather/scatter tensors.
Example executions:
```sh
./59_ampere_gather_scatter_conv
./59_ampere_gather_scatter_conv --n=108
./59_ampere_gather_scatter_conv --n=4096 --i=1
./59_ampere_gather_scatter_conv --n=1080 --i=1000
./59_ampere_gather_scatter_conv --n=131072 --i=1000 --no-check
```
This example demonstrates a few super cool features of CUTLASS and CuTe. It shows off
1. A dense conv 3D fprop kernel written as a single file ...
2. ... that leverages off-the-shelf CUTLASS collectives to show how custom kernels can use collectives ...
3. ... and uses the exact same templated kernel to also stamp out a gather/scatter 3D fprop conv ...
4. ... while getting near peak performance of the Ampere class tensor core on Ampere and Ada GPUs ...
5. ... by using static cute shapes and strides in case problem shapes are known at compile time.
## A dense conv 3D fprop kernel written in CUTLASS 3.x and CuTe
The most common strategy for implementing high performance convolution kernels on the GPU is to transform
the activation tensor in such a way that we can perform the computation as a GEMM. This is called the
image to column (im2col) transformation. [CUTLASS 2.x implementation of im2col based convolutions is
documented separately](../../media/docs/implicit_gemm_convolution.md), and here we consider a fresh approach for CuTe.
A 3D convolution has the following input tensors:
- Activation tensor (Act): `((N,(D,H,W)), (C,(1,1,1)))`
- Filter tensor (Flt): `( K, (C,(T,R,S)))`
- Output tensor (Out): `((N,(Z,P,Q)), K )`
Where
- N := number of images
- DHW := spatial dimensions of the activation tensor
- C := channel dimension of the activation tensor
- K := channel dimension of the filter and output tensor
- TRS := spoke dimensions of the filter tensor
- ZPQ := spatial dimensions of the output tensor
As is evident in the tensor shapes, these cannot be issued to a GEMM just yet, since there are no
logical M, N, and K modes into which we can group the tensor modes.
Notice that every spoke of the filter tensor (TRS) will be applied to some (offset) view of the
activation tensor, thus expanding the logical size of the activation tensor.
Additionally, a similar logical transform of the spatial dimensions can be encoded as a function of the
padding, dilations, traversal strides, and filter spokes. This gets us to our im2col transform:
The im2col transform affects the component shapes/strides of the activation tensor in the following way (a worked example follows the list):
- ZPQ Shape : changes DHW domain with formula `(1 + (DHW + pad - (((TRS-1) * dilation) + 1)) / traversal_stride)`
- TRS Shape : TRS domain instead of `(1,1,1)`
- ZPQ Strides : Original DHW strides get `elem_scale()`-ed by traversal strides DHW
- TRS Strides : Original DHW strides get `elem_scale()`-ed by dilation DHW
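For instance, plugging the static problem shape used later in this example into the ZPQ formula above (assuming zero padding, unit dilation, and unit traversal stride, which is what the static layouts in this example encode): with `DHW = (6,4,4)` and `TRS = (3,3,3)` we get `ZPQ = (1 + (6 - 3), 1 + (4 - 3), 1 + (4 - 3)) = (4,2,2)`, matching the output spatial extents hard-coded in `ampere_conv_kernel.h`.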
With this transform applied, we end up with a set of input and output tensors that
are logically consistent in their MNK dimensions, thus allowing us to dispatch to a GEMM.
- im2col activation layout: `((N,(Z,P,Q)), (C,(T,R,S)))`, the logical (M,K)
- filter layout: `( K, (C,(T,R,S)))`, the logical (N,K)
- output layout: `((N,(Z,P,Q)), K )`, the logical (M,N)
CuTe's layout representation and algebra make these folded tensors easy to represent and manipulate.
This is most evident in the reference check code used in this example:
```cpp
for (size_t logical_m = 0; logical_m < size<0>(mOutputRef); ++logical_m) {
for (size_t logical_n = 0; logical_n < size<1>(mOutputRef); ++logical_n) {
auto accumulator = float(0);
for (size_t logical_k = 0; logical_k < size<1>(mStencil); ++logical_k) {
accumulator += mStencil(logical_m, logical_k) * mActivation(logical_n, logical_k);
}
mOutputRef(logical_m, logical_n) = accumulator;
}
}
```
This reference loop succinctly demonstrates how the im2col transform allows us to implement convolutions
as GEMMs with special layout transformations on the input tensor.
Note: in the example kernel's implementation we treat activations as the B tensor
and filter as the A tensor, thus making their logical dimensions NK and MK respectively.
## Leveraging CUTLASS collectives off the shelf in a custom kernel
Now that we have transformed our problem in a way that allows us to dispatch to a GEMM,
we can reuse much of the machinery CUTLASS offers to implement this forward pass convolution
operator. CUTLASS decomposes these "moving parts" of GPU linear algebra into reusable,
modular software components abstracted by C++ template classes. This example
demonstrates how some of the lower layers of the hierarchy can be re-used for custom kernels
by writing a custom kernel for convolution that re-uses the Ampere/Ada GEMM collectives
from CUTLASS 3.
A kernel author is free to compose their custom components with any of the existing templates
in the CUTLASS hierarchy to leverage existing high performance implementations from the CUTLASS
team. In this example, we write a custom kernel layer and compose with an existing collective.
However, any of the CUTLASS kernels can be composed with bespoke collectives if the desired
customization is a mainloop or epilogue fusion without changes to the grid planning,
tile scheduling, load balancing, or thread marshalling.
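Concretely, the kernel in this example composes with the stock SM80 cp.async mainloop by instantiating `CollectiveMma` directly; the snippet below is condensed from `AmpereUnpredicatedFprop::operator()` in `ampere_conv_kernel.h` (shown later in this commit):
```cpp
// Condensed from ampere_conv_kernel.h: the custom conv kernel reuses the
// off-the-shelf SM80 cp.async GEMM mainloop collective from CUTLASS 3.
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
    cutlass::gemm::MainloopSm80CpAsyncUnpredicated<PIPE::value>,
    Shape<TileM, TileN, TileK>,
    ElementFlt, Underscore,   // A (filter) type; stride ignored, full cute::Tensors are passed in
    ElementAct, Underscore,   // B (activation) type; stride ignored
    TiledMma,
    GmemTiledCopyFlt, SmemLayoutAtomFlt, SmemCopyAtomFlt, cute::identity,
    GmemTiledCopyAct, SmemLayoutAtomAct, SmemCopyAtomAct, cute::identity>;

// The kernel then simply invokes the collective on CTA-local tiles:
//   CollectiveMainloop collective_mma;
//   collective_mma(accum, gA, gB, accum, k_tile_iter, k_tile_count,
//                  Underscore{}, threadIdx.x, smem_buf);
```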
## Implementing gather/scatter and dense convolution with the same kernel
By virtue of using CuTe and off-the-shelf CUTLASS collectives, the functionality and correctness of the
implemented kernel rely only on the logical consistency of the layouts of the
input and output tensors. This means that we can freely change how
the logical coordinates of the tensors map into the index space, and even how these dereferences
happen. [CUTLASS example 52](../52_hopper_gather_scatter_fusion/) demonstrates this by implementing a custom stride that
supports indexed indirection for tensor data accesses. This allows example 52
to implement a GEMM where inputs are gathered and the output is scattered based on an index buffer.
We re-use the same custom stride utilities in this example to implement a convolution kernel
that gathers along the NDHW dimensions of the activation tensor and scatters the output along the
NZPQ dimensions of the output tensor, treating the channel dimensions as the dense vectors.
Our dense affine im2col transformed activation tensor:
```cpp
// im2col transformed activation layout: ((nzpq), (ctrs)) => idx
auto xformed_act_layout = make_layout(
make_shape (make_shape ( N, Z, P, Q), make_shape ( C, T, R, S)),
make_stride(make_stride(D*H*W*C, H*W*C, W*C, C), make_stride(_1{}, H*W*C, W*C, C)));
```
now becomes a composed layout that uses `IndexedGather`:
```cpp
// Inner layout of the composition:
// ((nzpq), (csrt)) => (idx_buffer_idx, dense_offset)
auto EG = E<0>{}; // Gather basis (1,0) (idx_buffer_idx)
auto EC = E<1>{}; // Contiguous basis (0,1) (dense_offset)
auto xformed_act_logical_inner = make_layout(
make_shape (make_shape ( N, Z, P, Q), make_shape ( C, T, R, S)),
make_stride(make_stride(D*H*W*EG, H*W*EG, W*EG, EG), make_stride(EC, H*W*EG, W*EG, EG)));
// Outer layout of the composition:
// (idx_buffer_idx, dense_offset) => idx
// IndexedGather obtains idx by applying (gmem_base_ptr + gather_idx_buf[idx_buffer_idx] + dense_offset)
auto xformed_act_gather_outer = make_layout(
make_shape(_1{},_1{}),
make_stride(CustomStride{IndexedGather{gather_idx_buf}, C}, _1{}));
// Compose the inner and outer layouts
// ((nzpq), (ctrs)) => idx
auto xformed_act_composed_layout = composition(
xformed_act_gather_outer,
make_arithmetic_tuple(_0{}, _0{}),
xformed_act_logical_inner);
```
Here, we create a composed layout whose inner layout has the same logical MK shape as earlier,
but with an outer layout that uses the custom strides with an index buffer to access memory with
indirections. A custom stride requires two inputs to compute the index that a certain coordinate maps to:
the index buffer offset and the dense offset into the vector. This entails that our inner layout
(the one with the logical MK shape) has a rank-2 codomain `(idx_buffer_idx, dense_offset)`.
We can set up such a layout with scaled basis strides, which allow us to map a domain onto a
codomain with multiple orthogonal bases. The two codomain bases are the
index buffer offsets (rank-0 basis) and the dense vector offsets (rank-1 basis).
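As a minimal standalone sketch (hypothetical code, not part of the example sources, assuming the behavior of CuTe's scaled-basis strides used above), a layout whose strides are basis elements evaluates a coordinate to such a rank-2 codomain rather than to a single integer:
```cpp
// Hypothetical sketch: a layout with basis strides E<0>/E<1> maps a coordinate to a
// rank-2 arithmetic tuple (idx_buffer_idx, dense_offset) instead of a linear index.
#include <cute/tensor.hpp>

int main() {
  using namespace cute;
  auto layout = make_layout(make_shape (Int<4>{}, Int<8>{}),
                            make_stride(E<0>{},   E<1>{}));
  auto crd = layout(2, 5);   // 2 * E<0> + 5 * E<1>, i.e. the tuple (2,5)
  print(crd); print("\n");
  return 0;
}
```
The composed layouts above do exactly this, with the basis elements scaled by the im2col extents.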
A similar composed layout is set up for the output scatter tensor.
This tensor still has a logical MK shape and is backed by a CuTe layout, which means we can still
tile, partition, and otherwise manipulate it with CuTe's layout algebra in the same way we would any
other tensor. Substituting the activation tensor's affine layout for this gather layout requires
no changes to the implementation of the kernel whatsoever. Everything composes. This example
stamps out a dense 3D convolution as well as gather/scatter 3D convolution using the same kernel template,
with the only difference between them being the layouts of the input and output tensors.
Convolutions are just a special case of tensor contractions, and as [example 51](../51_hopper_gett)
demonstrates, the exact same collective used in this example can also be used to implement arbitrary GETTs.
Of course, this also means that the same kernel can implement gather/scatter GETTs as well!
This demonstrates the composition power of not just CuTe, but also CUTLASS 3's two level
micro kernel abstraction. A single highly tuned temporal micro-kernel (collective) can be implemented once
and applied to compute dense GETTs, gather/scatter GETTs, dense convolutions, and gather/scatter convolutions.
## Peak performance on Ampere and Ada GPUs by leveraging domain specific knowledge
Often, when implementing custom kernels, a user has domain knowledge that can be
exploited to deliver higher performance than would otherwise be possible through general-purpose kernels. In this example
we presume that the shape of each image (the DHWC dimensions) as well as the filter (TRS) is available
a priori and that the tile shape evenly divides the problem. The number of images (N) is still left as a runtime
parameter.
Knowing the extents of our tensors at compile time allows us to encode them as static cute shapes rather than
a dynamic problem shape, resulting in the elimination of most of the index computation instructions such as
expensive div/mods. Knowing that the problem shape is divisible by the tile shape allows us to use the
Ampere collective that does not perform predication on global memory loads, further reducing overheads
and allowing us to achieve near peak performance on RTX Ampere and Ada GPUs.
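For reference, this is how the static configuration is encoded in the example's kernel; the excerpt below is condensed from `AmpereUnpredicatedFprop` in `ampere_conv_kernel.h`:
```cpp
// Excerpt from ampere_conv_kernel.h: every conv extent except the image count N
// is a compile-time CuTe constant, so most index math folds away at compile time.
using D = _6;   using H = _4;   using W = _4;    // input spatial extents
using T = _3;   using R = _3;   using S = _3;    // filter spokes
using Z = _4;   using P = _2;   using Q = _2;    // output spatial extents
using C = _64;  using K = _128;                  // channel extents

// CTA tiler: only Tiler_N covers the dynamic image-count mode
using Tiler_K = decltype(cute::min(K{}, _128{}));
using Tiler_C = decltype(cute::min(C{}, _32{}));
using Tiler_N = _4;
using TileM   = Tiler_K;
using TileN   = Shape<Tiler_N, Z, P, Q>;
using TileK   = Shape<Tiler_C, _1, _1, _1>;
```
Only the image-count mode of the tile is dynamic, which is what allows the N extent to remain a runtime parameter.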
Running this example on an RTX 3080Ti prints the following performance numbers (some output culled for brevity):
```
$> ./examples/59_ampere_gather_scatter_conv/59_ampere_gather_scatter_conv --n=131072 --i=128 --no-check
Ampere convolution forward propagation kernel supporting both affine and gather/scatter tensors.
Allocating tensors ... done.
Initializing data ... done.
Initializing gather/scatter index buffers ... done.
Running dense fprop kernel
Conv TFLOP count = 0.927713
Conv dense perf: 31.027376ms | TFLOP/s = 29.899819
Running gather/scatter fprop kernel
Conv TFLOP count = 0.927713
Conv gather/scatter perf: 28.973721ms | TFLOP/s = 32.019117
```
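For reference, the reported FLOP count follows directly from the GEMM view of the problem: the logical GEMM extents are M x N x K = K x NZPQ x CTRS = 128 x (131072 * 4 * 2 * 2) x (64 * 3 * 3 * 3), so the dense fprop performs 2 * 128 * 2097152 * 1728 ≈ 0.9277 TFLOP, matching the `Conv TFLOP count` printed above; dividing by the measured runtime yields the reported TFLOP/s.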
With this in mind, this example kernel has the following limitations:
- This example kernel only supports a dynamic image count; all other conv problem shapes must be defined as `cute::Constant<>`s
- Problem shapes (including dynamic image count `N`) must be evenly divisible by the tile shape
- It does not perform fp32->tf32 numeric conversion; gmem inputs must already be rounded to tf32

View File

@ -0,0 +1,320 @@
/***************************************************************************************************
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/atom/copy_atom.hpp"
#include <random>
#include "cutlass/util/print_error.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_mma.hpp"
using namespace cute;
struct AmpereUnpredicatedFprop {
//
// Static config for conv problem shape
//
using D = _6;
using H = _4;
using W = _4;
using T = _3;
using R = _3;
using S = _3;
using Z = _4;
using P = _2;
using Q = _2;
using C = _64;
using K = _128;
// Tiler config
using Tiler_K = decltype(cute::min(K{}, _128{}));
using Tiler_C = decltype(cute::min(C{}, _32{}));
using Tiler_N = _4;
using TileM = Tiler_K;
using TileN = Shape<Tiler_N, Z, P, Q>;
using TileK = Shape<Tiler_C,_1,_1,_1>;
using PIPE = _3;
using TilerFlt = Shape<TileM, TileK>;
using TilerAct = Shape<TileN, TileK>;
using TilerOut = Shape<TileM, TileN>;
using TileSizeM = Int<size(TileM{})>;
using TileSizeN = Int<size(TileN{})>;
using TileSizeK = Int<size(TileK{})>;
static constexpr int Stages = PIPE::value;
using ElementFlt = tfloat32_t;
using ElementAct = tfloat32_t;
using ElementOut = float;
using TiledMma = TiledMMA<
MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>,
Layout<Shape<_2,_2,_1>>,
Tile<_32,_32,Underscore>>;
static constexpr int MaxThreadsPerBlock = size(TiledMma{});
static constexpr int MinBlocksPerMultiprocessor = 1;
union SharedStorage {
struct {
ElementFlt sAMatrix[size(TileM{}) * size(TileK{}) * size(PIPE{})];
ElementAct sBMatrix[size(TileN{}) * size(TileK{}) * size(PIPE{})];
} mainloop;
struct {
ElementOut sCMatrix[size(TileM{}) * size(TileN{})];
} epilogue;
};
//
// Stencil tensor
//
using GmemLayoutFlt = decltype(make_ordered_layout(
Shape< K, Shape< C, T, R, S>>{},
tuple<_4, tuple<_0,_3,_2,_1>>{}));
// We have 64 elements * 32b each in the major mode that we can vectorize
// Max vector size is 128b, so lay 16 threads along the major mode with a vector size of 4
// Rest along the minor mode
using GmemTiledCopyFlt = decltype(make_tiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, ElementFlt>{},
Layout<Shape <_16, _8>,
Stride< _8, _1>>{},
Layout<Shape < _1, _4>>{}));
// Following layout is also correct, but trades off dynamic strides in the slice for bank conflict free accesses
// using SmemLayoutFlt = decltype(
// composition(Swizzle<3,2,3>{},
// make_ordered_layout(
// Shape<TileSizeM,TileSizeK,PIPE>{},
// tuple< _1, _0, _2>{})));
using SmemLayoutAtomFlt = decltype(
composition(Swizzle<1,2,3>{},
Layout<Shape <_8,Shape <_4, _2>>,
Stride<_4,Stride<_1,_32>>>{}));
using SmemCopyAtomFlt = Copy_Atom<SM75_U32x4_LDSM_N, ElementFlt>;
//
// Activation tensor
//
// Activation tensor is major in the contraction mode, so vectorize that mode first
// Then lay out the rest of the threads along the other mode
using GmemTiledCopyAct = decltype(make_tiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, ElementAct>{},
Layout<Shape <_16, _8>,
Stride< _8, _1>>{},
Layout<Shape < _1, _4>>{}));
// Following layout is also correct, but trades off dynamic strides in the slice for bank conflict free accesses
// using SmemLayoutAct = decltype(
// composition(Swizzle<3,2,3>{},
// make_ordered_layout(
// Shape<TileSizeN,TileSizeK,PIPE>{},
// tuple< _1, _0, _2>{})));
using SmemLayoutAtomAct = decltype(
composition(Swizzle<1,2,3>{},
Layout<Shape <_8,Shape <_4, _2>>,
Stride<_4,Stride<_1,_32>>>{}));
using SmemCopyAtomAct = Copy_Atom<SM75_U32x4_LDSM_N, ElementAct>;
//
// Output tensor
//
using GmemTiledCopyOut = decltype(make_tiled_copy(
Copy_Atom<UniversalCopy<uint128_t>, ElementAct>{},
Layout<Shape <_8, _16>,
Stride<_1, _8>>{},
Layout<Shape <_4, _1>>{}));
using SmemCopyAtomOut = Copy_Atom<UniversalCopy<uint32_t>, ElementOut>;
// This can be optimized to make accesses BCF, but we use a col-major layout here to show off composability
using SmemLayoutOut = Layout<Shape<TileSizeM, TileSizeN>>;
//
// Conv functor
//
template <class EngineFlt, class TensorActivation, class TensorOutput>
void __device__
operator()(cute::Tensor<EngineFlt, GmemLayoutFlt> mFlt, // ( K, (C,T,R,S))
TensorActivation mAct, // ((N,Z,P,Q), (C,T,R,S))
TensorOutput mOut, // ( K, (N,Z,P,Q))
char* smem_buf) const {
using namespace cute;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveMma<
cutlass::gemm::MainloopSm80CpAsyncUnpredicated<PIPE::value>,
Shape<TileM,TileN,TileK>,
ElementFlt,
Underscore, // Ignore the stride, we are passing full cute::Tensor to operator()
ElementAct,
Underscore, // Ignore the stride, we are passing full cute::Tensor to operator()
TiledMma,
GmemTiledCopyFlt,
SmemLayoutAtomFlt,
SmemCopyAtomFlt,
cute::identity,
GmemTiledCopyAct,
SmemLayoutAtomAct,
SmemCopyAtomAct,
cute::identity>;
TiledMma tiled_mma;
Tensor accum = partition_fragment_C(tiled_mma, TilerOut{});
clear(accum);
// Set up tensors
// NOTE: blockIdx.x projects onto act-NDHW mode, y along the flt-K mode for the sake of higher dynamic range in NDHW
Tensor gA_mk = local_tile(mFlt, TilerFlt{}, make_coord(_,_)); // (BLK_M,BLK_K,m',k')
Tensor gB_nk = local_tile(mAct, TilerAct{}, make_coord(_,_)); // (BLK_N,BLK_K,n',_1)
Tensor gC_mn = local_tile(mOut, TilerOut{}, make_coord(_,_)); // (BLK_M,BLK_N,m',n')
// Compute m_coord and n_coord with their post-tiled shapes
auto m_coord = idx2crd(int(blockIdx.y), shape<2>(gA_mk));
auto n_coord = idx2crd(int(blockIdx.x), shape<2>(gB_nk));
Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k')
Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,_1)
Tensor gC = gC_mn(_,_,m_coord,n_coord); // (BLK_M,BLK_N)
auto k_tile_iter = cute::make_coord_iterator(size<2>(gA));
int k_tile_count = size<2>(gA);
CollectiveMainloop collective_mma;
collective_mma(
accum,
gA,
gB,
accum,
k_tile_iter, k_tile_count,
Underscore{}, // no residue since we do not support predication
threadIdx.x,
smem_buf);
//
// Epilogue
//
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
Tensor sC = make_tensor(make_smem_ptr(&storage.epilogue.sCMatrix[0]), SmemLayoutOut{});
auto smem_tiled_copy_C = make_tiled_copy_C(SmemCopyAtomOut{}, tiled_mma);
auto smem_thr_copy_C = smem_tiled_copy_C.get_slice(threadIdx.x);
auto tCrC = smem_thr_copy_C.retile_S(accum);
auto tCsC = smem_thr_copy_C.partition_D(sC);
copy(smem_tiled_copy_C, tCrC, tCsC);
__syncthreads();
GmemTiledCopyOut gmem_tiled_copy_C;
auto gmem_thr_copy_C = gmem_tiled_copy_C.get_slice(threadIdx.x);
auto tDsC = gmem_thr_copy_C.partition_S(sC);
auto tDgC = gmem_thr_copy_C.partition_D(gC);
copy(gmem_tiled_copy_C, tDsC, tDgC);
#if 0
if (thread0()) {
print("mAct = "); print(mAct); print('\n');
print("mFlt = "); print(mFlt); print('\n');
print("mOut = "); print(mOut); print('\n');
print("gA = "); print(gA); print('\n');
print("gB = "); print(gB); print('\n');
print("gC = "); print(gC); print('\n');
print("sA = "); print(sA.layout()); print('\n');
print("sB = "); print(sB.layout()); print('\n');
print("sC = "); print(sC.layout()); print('\n');
print("tAgA = "); print(tAgA.layout()); print('\n');
print("tBgB = "); print(tBgB.layout()); print('\n');
print("tAsA = "); print(tAsA.layout()); print('\n');
print("tBsB = "); print(tBsB.layout()); print('\n');
print("tCsA = "); print(tCsA.layout()); print('\n');
print("tCsB = "); print(tCsB.layout()); print('\n');
print("tCrC = "); print(tCrC.layout()); print('\n');
print("tCsC = "); print(tCsC.layout()); print('\n');
print("tDsC = "); print(tDsC.layout()); print('\n');
print("tDgC = "); print(tDgC.layout()); print('\n');
print("gmem tiled copy A = "); print(gmem_tiled_copy_A); print('\n');
print("gmem tiled copy B = "); print(gmem_tiled_copy_B); print('\n');
print("gmem tiled copy C = "); print(gmem_tiled_copy_C); print('\n');
print("k_tile_count = "); print(size<2>(gA)); print('\n');
print("k_tile_iter = "); print(*k_tile_iter); print('\n');
print("K_BLOCK_MAX = "); print(K_BLOCK_MAX); print('\n');
}
#endif
}
};
template <class TensorFlt, class TensorAct, class TensorOut>
inline int
fprop_reference(
TensorFlt mStencil, // Logical MK: ( K, (C,T,R,S))
TensorAct mActivation, // Logical NK: ((N,Z,P,Q), (C,T,R,S))
TensorOut mOutput, // Logical MN: ( K, (N,Z,P,Q))
TensorOut mOutputRef) {
int32_t N = size<1,0>(mOutputRef);
int32_t Z = size<1,1>(mOutputRef);
int32_t P = size<1,2>(mOutputRef);
int32_t Q = size<1,3>(mOutputRef);
int32_t T = size<1,1>(mStencil);
int32_t R = size<1,2>(mStencil);
int32_t S = size<1,3>(mStencil);
int32_t C = size<1,0>(mStencil);
size_t K = static_cast<size_t>(size<0>(mOutputRef));
size_t NZPQ = static_cast<size_t>(size<1>(mOutputRef));
size_t CTRS = static_cast<size_t>(size<1>(mStencil));
#if defined(_OPENMP)
#pragma omp parallel for
#endif
for (size_t logical_m = 0; logical_m < K; ++logical_m) {
for (size_t logical_n = 0; logical_n < NZPQ; ++logical_n) {
auto accumulator = float(0);
for (size_t logical_k = 0; logical_k < CTRS; ++logical_k) {
accumulator += mStencil(logical_m, logical_k) * mActivation(logical_n, logical_k);
}
mOutputRef(logical_m, logical_n) = accumulator;
}
}
return print_relative_error(mOutput, mOutputRef, /*print_verbose*/ false, /*print_error*/ true, /*error_margin*/ 0.01);
}

View File

@ -0,0 +1,392 @@
/***************************************************************************************************
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Example demonstrating a CuTe and CUTLASS 3.x based Ampere convolution forward propagation kernel
capable of operating on both affine and gather/scatter tensors.
This example demonstrates a few super cool features of CUTLASS and CuTe. It shows off
1. A dense conv 3D fprop kernel written as a single file ...
2. ... that leverages off-the-shelf CUTLASS collectives to show how custom kernels can use collectives ...
3. ... and uses the exact same templated kernel to also stamp out a gather/scatter 3D fprop conv ...
4. ... while getting near peak performance of the Ampere class tensor core on Ampere and Ada GPUs ...
5. ... by using static cute shapes and strides in case problem shapes are known at compile time.
Full documentation for this example can be found within the README.md file in this directory.
Example executions:
./59_ampere_gather_scatter_conv
./59_ampere_gather_scatter_conv --n=108
./59_ampere_gather_scatter_conv --n=4096 --i=1
./59_ampere_gather_scatter_conv --n=1080 --i=1000
./59_ampere_gather_scatter_conv --n=131072 --i=1000 --no-check
*/
#include <thrust/sequence.h>
#include <thrust/universal_vector.h>
#include "ampere_conv_kernel.h"
#include "gather_tensor.hpp"
#include "cutlass/util/command_line.h"
bool check_cuda_result(cudaError_t code, const char* file, int line) {
if (code == cudaSuccess) {
return true;
}
std::cerr << "CUDA error at (" << file << "," << line << ")\n\t" << unsigned(code) << " -- " << cudaGetErrorString(code) << "\n";
return false;
}
#define CHECK_CUDA(code) (check_cuda_result(code, __FILE__, __LINE__))
using namespace cute;
using example::IndexedGather;
using example::CustomStride;
template<class Operator, class FilterTensor, class ActivationTensor, class OutputTensor>
__global__
__launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor)
void kernel_entrypoint(FilterTensor mFlt, ActivationTensor mAct, OutputTensor mOut) {
extern __shared__ char smem_buf[];
Operator op;
op(mFlt, mAct, mOut, smem_buf);
}
int ampere_dense_conv_fprop(
int num_images,
float* activations,
float* filter,
float* output,
float* output_ref,
int num_iterations = 1,
bool do_ref_check = true) {
auto D = typename AmpereUnpredicatedFprop::D{};
auto H = typename AmpereUnpredicatedFprop::H{};
auto W = typename AmpereUnpredicatedFprop::W{};
auto Z = typename AmpereUnpredicatedFprop::Z{};
auto P = typename AmpereUnpredicatedFprop::P{};
auto Q = typename AmpereUnpredicatedFprop::Q{};
auto C = typename AmpereUnpredicatedFprop::C{};
auto K = typename AmpereUnpredicatedFprop::K{};
auto S = typename AmpereUnpredicatedFprop::S{};
auto R = typename AmpereUnpredicatedFprop::R{};
auto T = typename AmpereUnpredicatedFprop::T{};
int N = num_images; // dynamic
if (num_images % int(typename AmpereUnpredicatedFprop::Tiler_N{}) != 0) {
printf("ERROR: Input image count must be evenly divisible by CTA tiler N.\n");
return 1;
}
// Tensor Activation: (n,d,h,w,c)::(?,6,4,4,64):(6144,1024,256,64,1)
auto activation_layout = make_layout(
make_shape (make_shape ( N, D, H, W), make_shape ( C, _1{},_1{},_1{})),
make_stride(make_stride(D*H*W*C, H*W*C, W*C, C), make_stride(_1{}, _0{},_0{},_0{})));
auto xformed_act_layout = make_layout(
make_shape (make_shape(N, Z, P, Q), make_shape ( C, T, R, S)),
make_stride(stride<0>(activation_layout), make_stride(_1{}, H*W*C, W*C, C)));
// Tensor Filter : (k,t,r,s,c)::(128,3,3,3,64):(1728,576,192,64,1)
auto filter_layout = AmpereUnpredicatedFprop::GmemLayoutFlt{};
// Tensor Output : (n,z,p,q,k)::(?,4,2,2,128):(2048,512,256,128,1)
auto output_layout = make_ordered_layout(
make_shape( K, make_shape( N, Z, P, Q)),
make_tuple(_0{}, make_tuple(_4{},_3{},_2{},_1{})));
Tensor mActivation = make_tensor(make_gmem_ptr(activations), activation_layout);
Tensor mXformedAct = make_tensor(make_gmem_ptr(activations), xformed_act_layout);
Tensor mFilter = make_tensor(make_gmem_ptr(filter), filter_layout);
Tensor mOutput = make_tensor(make_gmem_ptr(output), output_layout); // (K, (N,Z,P,Q))
Tensor mOutputRef = make_tensor(make_gmem_ptr(output_ref), output_layout);
print("xformed act layout ((N,Z,P,Q), (C,T,R,S)) = "); print(xformed_act_layout); print("\n");
cudaEvent_t start, stop;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
constexpr size_t smem_size = sizeof(typename AmpereUnpredicatedFprop::SharedStorage);
Tensor gOutput_mn = zipped_divide(mOutput, typename AmpereUnpredicatedFprop::TilerOut{}); // ((BLK_M, BLK_N), (m', n'))
dim3 launch_grid {static_cast<uint32_t>(size<1,1>(gOutput_mn)), static_cast<uint32_t>(size<1,0>(gOutput_mn)), 1};
CHECK_CUDA(cudaFuncSetAttribute(
kernel_entrypoint<AmpereUnpredicatedFprop, decltype(mFilter), decltype(mXformedAct), decltype(mOutput)>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size));
CHECK_CUDA(cudaEventRecord(start));
for (int i = 0; i < num_iterations; ++i) {
kernel_entrypoint<AmpereUnpredicatedFprop, decltype(mFilter), decltype(mXformedAct), decltype(mOutput)>
<<<launch_grid, AmpereUnpredicatedFprop::MaxThreadsPerBlock, smem_size>>>(
mFilter, mXformedAct, mOutput);
}
CHECK_CUDA(cudaEventRecord(stop));
CHECK_CUDA(cudaEventSynchronize(stop));
float milliseconds = 0;
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
milliseconds /= float(num_iterations);
double tflop_count = (2 * double(size<0>(xformed_act_layout)) * double(size(filter_layout))) / double(1e12);
double tflops = tflop_count / (double(milliseconds) / double(1e3));
printf("Conv TFLOP count = %f\n", tflop_count);
printf("Conv dense perf: %fms | TFLOP/s = %f\n", milliseconds, tflops);
if (do_ref_check) {
printf("Running host reference check ...\n");
return fprop_reference(mFilter, mXformedAct, mOutput, mOutputRef);
}
else {
return 0;
}
}
int ampere_gather_scatter_conv_fprop(
int num_images,
float* activations,
uint32_t *gather_idx_buf,
float* filter,
float* output,
uint32_t *scatter_idx_buf,
int num_iterations = 1) {
auto D = typename AmpereUnpredicatedFprop::D{};
auto H = typename AmpereUnpredicatedFprop::H{};
auto W = typename AmpereUnpredicatedFprop::W{};
auto Z = typename AmpereUnpredicatedFprop::Z{};
auto P = typename AmpereUnpredicatedFprop::P{};
auto Q = typename AmpereUnpredicatedFprop::Q{};
auto C = typename AmpereUnpredicatedFprop::C{};
auto K = typename AmpereUnpredicatedFprop::K{};
auto S = typename AmpereUnpredicatedFprop::S{};
auto R = typename AmpereUnpredicatedFprop::R{};
auto T = typename AmpereUnpredicatedFprop::T{};
int N = num_images; // dynamic
if (N % int(typename AmpereUnpredicatedFprop::Tiler_N{}) != 0) {
printf("ERROR: Input image count must be evenly divisible by CTA tiler N. Got num_images = %d\n", N);
return 1;
}
// Tensor Filter : (k,t,r,s,c)::(128,3,3,3,64):(1728,576,192,64,1)
auto filter_layout = AmpereUnpredicatedFprop::GmemLayoutFlt{};
// Tensor Output : (n,z,p,q,k)::(?,4,2,2,128):(2048,512,256,128,1)
auto output_layout = make_ordered_layout(
make_shape( K, make_shape( N, Z, P, Q)),
make_tuple(_0{}, make_tuple(_4{},_3{},_2{},_1{})));
// Input gather layout
// inner_layout(make_coord((nzpq), (csrt))) => (idx_buffer_idx, dense_c_idx)
auto EG = E<0>{}; // Gather basis (1,0) (idx_buffer_idx)
auto EC = E<1>{}; // Contiguous basis (0,1) (dense_offset)
auto xformed_act_logical_inner = make_layout(
make_shape (make_shape ( N, Z, P, Q), make_shape ( C, T, R, S)),
make_stride(make_stride(D*H*W*EG, H*W*EG, W*EG, EG), make_stride(EC, H*W*EG, W*EG, EG)));
// outer_layout(make_coord(idx_buffer_idx, dense_c_idx)) => idx
// IndexedGather obtains idx by applying (gmem_base_ptr + gather_idx_buf[idx_buffer_idx] + dense_offset)
auto xformed_act_gather_outer = make_layout(
make_shape(_1{},_1{}),
make_stride(CustomStride{IndexedGather{gather_idx_buf}, C}, _1{}));
// Compose the inner and outer layouts
// gather_composed(make_coord((nzpq), (csrt))) => idx
auto xformed_act_composed_layout = composition(
xformed_act_gather_outer,
make_arithmetic_tuple(_0{}, _0{}),
xformed_act_logical_inner);
// Output scatter layout
auto out_basis_stride = make_stride(
E<1>{},
make_stride(Z*P*Q*E<0>{}, P*Q*E<0>{}, Q*E<0>{}, _1{}*E<0>{})); // -> (crd0, crd1)
auto out_basis_layout = make_layout(shape(output_layout), out_basis_stride);
auto out_scatter_layout = make_layout(
make_shape(_1{},_1{}),
make_stride(CustomStride{IndexedGather{scatter_idx_buf}, K}, _1{}));
auto out_composed_layout = composition(
out_scatter_layout,
make_arithmetic_tuple(_0{},_0{}),
out_basis_layout);
Tensor mXformedActGather = make_tensor(make_gmem_ptr(activations), xformed_act_composed_layout);
Tensor mFilter = make_tensor(make_gmem_ptr(filter), filter_layout);
Tensor mOutputScatter = make_tensor(make_gmem_ptr(output), out_composed_layout); // (K, (N,Z,P,Q))
Tensor gOutput_mn = zipped_divide(mOutputScatter, typename AmpereUnpredicatedFprop::TilerOut{}); // ((BLK_M, BLK_N), (m', n'))
dim3 launch_grid {static_cast<uint32_t>(size<1,1>(gOutput_mn)), static_cast<uint32_t>(size<1,0>(gOutput_mn)), 1};
constexpr size_t smem_size = sizeof(typename AmpereUnpredicatedFprop::SharedStorage);
print("xforemed gather layout ((N,Z,P,Q), (C,T,R,S)) = "); print(xformed_act_composed_layout); print("\n");
print("Output scatter layout ( K, (N,Z,P,Q)) = "); print(out_composed_layout); print("\n");
print("Filter layout ( K, (C,T,R,S)) = "); print(filter_layout); print("\n");
CHECK_CUDA(cudaFuncSetAttribute(
kernel_entrypoint<AmpereUnpredicatedFprop, decltype(mFilter), decltype(mXformedActGather), decltype(mOutputScatter)>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size));
cudaEvent_t start, stop;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start));
for (int i = 0; i < num_iterations; ++i) {
kernel_entrypoint<AmpereUnpredicatedFprop, decltype(mFilter), decltype(mXformedActGather), decltype(mOutputScatter)>
<<<launch_grid, AmpereUnpredicatedFprop::MaxThreadsPerBlock, smem_size>>>(
mFilter, mXformedActGather, mOutputScatter);
}
CHECK_CUDA(cudaEventRecord(stop));
CHECK_CUDA(cudaEventSynchronize(stop));
float milliseconds = 0;
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
milliseconds /= float(num_iterations);
double tflop_count = (2 * double(size<0>(xformed_act_logical_inner)) * double(size(filter_layout))) / double(1e12);
double tflops = tflop_count / (double(milliseconds) / double(1e3));
printf("Conv TFLOP count = %f\n", tflop_count);
printf("Conv gather/scatter perf: %fms | TFLOP/s = %f\n", milliseconds, tflops);
return 0;
}
int
main(int argc, char const** argv) {
cutlass::CommandLine cmd(argc, argv);
std::cout << "Ampere convolution forward propogation kernel supporting both affine and gather/scatter tensors.\n\n";
if (cmd.check_cmd_line_flag("help")) {
std::cout
<< "Options:\n"
"\t--n=<int> Sets the number of images for the input activation tensor (dataset size). Default = 131072.\n"
"\t--i=<int> Sets the benchmarking repetitions. Default = 128.\n"
"\t--nocheck If specified, skips the reference check for dense kernel.\n"
"\t--help Displays this help message and exits.\n";
return 0;
}
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (props.major < 8) {
std::cerr << "This example requires an Ampere GPU or newer.\n";
return 0;
}
int num_images = 4320;
cmd.get_cmd_line_argument("n", num_images, 4320);
int num_iterations = 128;
cmd.get_cmd_line_argument("i", num_iterations, 128);
bool do_host_ref_check = not cmd.check_cmd_line_flag("no-check");
auto D = typename AmpereUnpredicatedFprop::D{};
auto H = typename AmpereUnpredicatedFprop::H{};
auto W = typename AmpereUnpredicatedFprop::W{};
auto Z = typename AmpereUnpredicatedFprop::Z{};
auto P = typename AmpereUnpredicatedFprop::P{};
auto Q = typename AmpereUnpredicatedFprop::Q{};
auto C = typename AmpereUnpredicatedFprop::C{};
auto K = typename AmpereUnpredicatedFprop::K{};
auto activation_layout = make_layout(
make_shape (make_shape (num_images, D, H, W), make_shape ( C, _1{},_1{},_1{})),
make_stride(make_stride( D*H*W*C, H*W*C, W*C, C), make_stride(_1{}, _0{},_0{},_0{})));
auto filter_layout = typename AmpereUnpredicatedFprop::GmemLayoutFlt{};
auto output_layout = make_ordered_layout(
make_shape( K, make_shape(num_images, Z, P, Q)),
make_step (_0{}, make_step ( _4{},_3{},_2{},_1{})));
print("Filter layout ( K, (C,T,R,S)) = "); print(filter_layout); print("\n");
print("Activation layout ((N,D,H,W), (C,1,1,1)) = "); print(activation_layout); print("\n");
print("Output layout ( K, (N,Z,P,Q)) = "); print(output_layout); print("\n");
// allocate tensors
std::cout << "Allocating tensors ... ";
thrust::universal_vector<float> activation_data(size_t(cute::size(activation_layout)), float(0));
thrust::universal_vector<float> filter_data(size_t(cute::size(filter_layout)), float(0));
thrust::universal_vector<float> output_data(size_t(cute::size(output_layout)), float(0));
thrust::universal_vector<float> output_data_ref(size_t(cute::size(output_layout)), float(0));
std::cout << "done.\n";
// init tensors
std::cout << "Initializing data ... " << std::flush;
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> uniform_dist(-1.0, 1.0);
for (std::size_t i = 0; i < size_t(cute::size(activation_layout)); ++i) {
activation_data[i] = uniform_dist(gen);
}
for (std::size_t i = 0; i < size_t(cute::size(filter_layout)); ++i) {
filter_data[i] = uniform_dist(gen);
}
std::cout << "done.\n";
// set up index buffers for gather/scatter, fill with indirection indices in reversed order
std::cout << "Initializing gather/scatter index buffers ... ";
thrust::universal_vector<uint32_t> gather_idx_buf(size_t(size<0>(activation_layout)));
thrust::universal_vector<uint32_t> scatter_idx_buf(size_t(size<1>(output_layout)));
thrust::sequence(gather_idx_buf.rbegin(), gather_idx_buf.rend());
thrust::sequence(scatter_idx_buf.rbegin(), scatter_idx_buf.rend());
std::cout << "done.\n";
// launch dense
std::cout << "\nRunning dense fprop kernel\n";
int passed = ampere_dense_conv_fprop(
num_images,
activation_data.data().get(),
filter_data.data().get(),
output_data.data().get(),
output_data_ref.data().get(),
num_iterations,
do_host_ref_check);
// launch gather/scatter
std::cout << "\nRunning gather/scatter fprop kernel\n";
ampere_gather_scatter_conv_fprop(
num_images,
activation_data.data().get(),
gather_idx_buf.data().get(),
filter_data.data().get(),
output_data.data().get(),
scatter_idx_buf.data().get(),
num_iterations);
return passed;
}

View File

@ -138,6 +138,8 @@ foreach(EXAMPLE
55_hopper_mixed_dtype_gemm
56_hopper_ptr_array_batched_gemm
57_hopper_grouped_gemm
58_ada_fp8_gemm
59_ampere_gather_scatter_conv
)
add_subdirectory(${EXAMPLE})

View File

@ -32,6 +32,7 @@
#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cute/util/print.hpp"
namespace example {
@ -59,7 +60,7 @@ struct IndexedGather
CUTE_HOST_DEVICE friend
void
print(IndexedGather const &s) {
print("Indexed");
cute::print("Indexed");
}
Index const *indices_;
@ -81,9 +82,9 @@ struct StridedGather
CUTE_HOST_DEVICE friend
void
print(StridedGather const &s) {
print("Strided{");
cute::print("Strided{");
print(s.stride_);
print("}");
cute::print("}");
}
Stride stride_;
@ -109,11 +110,11 @@ struct CustomStride
CUTE_HOST_DEVICE friend
void
print(CustomStride const & s) {
print("Custom{");
cute::print("Custom{");
print(s.func_);
print(",");
cute::print(",");
print(s.stride_);
print("}");
cute::print("}");
}
template<class Div>

View File

@ -29,8 +29,23 @@
cutlass_example_add_executable(
sgemm_nt_1
sgemm_nt_1.cu
sgemm_1
sgemm_1.cu
)
cutlass_example_add_executable(
sgemm_2
sgemm_2.cu
)
cutlass_example_add_executable(
sgemm_sm70
sgemm_sm70.cu
)
cutlass_example_add_executable(
sgemm_sm80
sgemm_sm80.cu
)
cutlass_example_add_executable(

View File

@ -0,0 +1,469 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cstdlib>
#include <cstdio>
#include <cassert>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <cute/tensor.hpp>
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"
template <class ProblemShape, class CtaTiler,
class TA, class AStride, class ASmemLayout, class AThreadLayout,
class TB, class BStride, class BSmemLayout, class BThreadLayout,
class TC, class CStride, class CSmemLayout, class CThreadLayout,
class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(CThreadLayout{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
TA const* A, AStride dA, ASmemLayout sA_layout, AThreadLayout tA,
TB const* B, BStride dB, BSmemLayout sB_layout, BThreadLayout tB,
TC * C, CStride dC, CSmemLayout , CThreadLayout tC,
Alpha alpha, Beta beta)
{
using namespace cute;
// Preconditions
CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
static_assert(is_static<AThreadLayout>::value);
static_assert(is_static<BThreadLayout>::value);
static_assert(is_static<CThreadLayout>::value);
CUTE_STATIC_ASSERT_V(size(tA) == size(tB)); // NumThreads
CUTE_STATIC_ASSERT_V(size(tC) == size(tA)); // NumThreads
CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tA) == Int<0>{}); // BLK_M / THR_M
CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tA) == Int<0>{}); // BLK_K / THR_K
CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<0>(tB) == Int<0>{}); // BLK_N / THR_N
CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tB) == Int<0>{}); // BLK_K / THR_K
CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tC) == Int<0>{}); // BLK_M / THR_M
CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<1>(tC) == Int<0>{}); // BLK_N / THR_N
static_assert(is_static<ASmemLayout>::value);
static_assert(is_static<BSmemLayout>::value);
static_assert(is_static<CSmemLayout>::value);
CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA)); // dA strides for shape MK
CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB)); // dB strides for shape NK
CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN
//
// Full and Tiled Tensors
//
// Represent the full tensors
Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)
// Get the appropriate blocks for this thread block
auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
// Shared memory buffers
__shared__ TA smemA[cosize_v<ASmemLayout>];
__shared__ TB smemB[cosize_v<BSmemLayout>];
Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K)
Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K)
//
// Partition the copying of A and B tiles across the threads
//
// TUTORIAL: Example of simple raked partitioning of ThreadLayouts tA|tB over data A|B tiles
Tensor tAgA = local_partition(gA, tA, threadIdx.x); // (THR_M,THR_K,k)
Tensor tAsA = local_partition(sA, tA, threadIdx.x); // (THR_M,THR_K)
Tensor tBgB = local_partition(gB, tB, threadIdx.x); // (THR_N,THR_K,k)
Tensor tBsB = local_partition(sB, tB, threadIdx.x); // (THR_N,THR_K)
CUTE_STATIC_ASSERT_V(size<0>(tAgA) == size<0>(tAsA)); // THR_M
CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // THR_K
CUTE_STATIC_ASSERT_V(size<0>(tBgB) == size<0>(tBsB)); // THR_N
CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // THR_K
//
// Define A/B partitioning and C accumulators
//
// TUTORIAL: Example of partitioning via projections of a ThreadLayout tC
// Partition sA (M,K) by the rows of tC
Tensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{}); // (THR_M,BLK_K)
// Partition sB (N,K) by the cols of tC
Tensor tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{}); // (THR_N,BLK_K)
// Partition gC (M,N) by the tile of tC
Tensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{}); // (THR_M,THR_N)
// Allocate the accumulators -- same shape/layout as the partitioned data
Tensor tCrC = make_tensor_like(tCgC); // (THR_M,THR_N)
CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCgC)); // THR_M
CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCsA)); // THR_M
CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<1>(tCgC)); // THR_N
CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<0>(tCsB)); // THR_N
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCsB)); // BLK_K
// Clear the accumulators
clear(tCrC);
#if 0
if(thread0()) {
print(" mA : "); print( mA); print("\n");
print(" gA : "); print( gA); print("\n");
print(" sA : "); print( sA); print("\n");
print("tAgA : "); print(tAgA); print("\n");
print("tAsA : "); print(tAsA); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mB : "); print( mB); print("\n");
print(" gB : "); print( gB); print("\n");
print(" sB : "); print( sB); print("\n");
print("tBgB : "); print(tBgB); print("\n");
print("tBsB : "); print(tBsB); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mC : "); print( mC); print("\n");
print(" gC : "); print( gC); print("\n");
print("tCsA : "); print(tCsA); print("\n");
print("tCsB : "); print(tCsB); print("\n");
print("tCgC : "); print(tCgC); print("\n");
print("tCrC : "); print(tCrC); print("\n");
}
#endif
#if 1
// TUTORIAL: Example of a simple mainloop that read tiles of data into shared memory,
// and then computes on those tiles.
// copy(.) operates on the global and shared memory via the tA|tB partitioning
// gemm(.) operates on the shared and register memory via the tC partitioning
auto K_TILE_MAX = size<2>(tAgA);
for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
{
// Copy gmem to smem with tA|tB thread-partitioned tensors
copy(tAgA(_,_,k_tile), tAsA); // A (THR_M,THR_K) -> (THR_M,THR_K)
copy(tBgB(_,_,k_tile), tBsB); // B (THR_N,THR_K) -> (THR_N,THR_K)
// TUTORIAL: The above call to copy(tAgA(_,_,k_tile), tAsA) is equivalent to
// Tensor tAgAk = tAgA(_,_,k_tile);
// CUTE_UNROLL
// for (int i = 0; i < size(tAsA); ++i) {
// tAsA(i) = tAgAk(i);
// }
cp_async_fence(); // Label the end of (potential) cp.async instructions
cp_async_wait<0>(); // Sync on all (potential) cp.async instructions
__syncthreads(); // Wait for all threads to write to smem
// Compute gemm on tC thread-partitioned smem
gemm(tCsA, tCsB, tCrC); // (THR_M,THR_N) += (THR_M,BLK_K) * (THR_N,BLK_K)
// TUTORIAL: The above call to gemm(tCsA, tCsB, tCrC) is equivalent to
// CUTE_UNROLL
// for (int k = 0; k < size<1>(tCsA); ++k) {
// CUTE_UNROLL
// for (int m = 0; m < size<0>(tCrC); ++m) {
// CUTE_UNROLL
// for (int n = 0; n < size<1>(tCrC); ++n) {
// tCrC(m,n) += tCsA(m,k) * tCsB(n,k);
// }
// }
// }
__syncthreads(); // Wait for all threads to read from smem
}
#endif
//
// Epilogue
//
axpby(alpha, tCrC, beta, tCgC);
// TUTORIAL: The above call to axpby(alpha, tCrC, beta, tCgC) is equivalent to
// CUTE_UNROLL
// for (int i = 0; i < size(tCgC); ++i) {
// tCgC(i) = alpha * tCrC(i) + beta * tCgC(i);
// }
}
// Setup params for an NT GEMM
// Use m-major smem sA, n-major smem sB, and mn-major threads tA|tB
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
using namespace cute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define NT strides (mixed)
auto dA = make_stride(Int<1>{}, ldA); // (dM, dK)
auto dB = make_stride(Int<1>{}, ldB); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// Define the smem layouts (static)
auto sA = make_layout(make_shape(bM, bK)); // (m,k) -> smem_idx; m-major
auto sB = make_layout(make_shape(bN, bK)); // (n,k) -> smem_idx; n-major
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major
// Define the thread layouts (static)
auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{})); // (m,k) -> thr_idx
auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{})); // (n,k) -> thr_idx
auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); // (m,n) -> thr_idx
dim3 dimBlock(size(tC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, tA,
B, dB, sB, tB,
C, dC, sC, tC,
alpha, beta);
}
// Setup params for a TN GEMM
// Use k-major smem sA, k-major smem sB, and k-major threads tA|tB
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
using namespace cute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define TN strides (mixed)
auto dA = make_stride(ldA, Int<1>{}); // (dM, dK)
auto dB = make_stride(ldB, Int<1>{}); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// Define the smem layouts (static)
auto sA = make_layout(make_shape(bM,bK), LayoutRight{}); // (m,k) -> smem_idx; k-major
auto sB = make_layout(make_shape(bN,bK), LayoutRight{}); // (n,k) -> smem_idx; k-major
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major
// Define the thread layouts (static)
auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{}); // (m,k) -> thr_idx; k-major
auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{}); // (n,k) -> thr_idx; k-major
auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); // (m,n) -> thr_idx; m-major
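// TUTORIAL: The k-major thread layouts mirror the k-major gmem strides of A and B above,
// so the 8 threads that share an m|n coordinate read 8 contiguous gmem elements along k.
// tC stays m-major, so each thread still owns an 8x8 piece of the C tile.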
dim3 dimBlock(size(tC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, tA,
B, dB, sB, tB,
C, dC, sC, tC,
alpha, beta);
}
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
if (transA == 'N' && transB == 'T') {
return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
} else
if (transA == 'T' && transB == 'N') {
return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
}
assert(false && "Not implemented");
}
int main(int argc, char** argv)
{
int m = 5120;
if (argc >= 2)
sscanf(argv[1], "%d", &m);
int n = 5120;
if (argc >= 3)
sscanf(argv[2], "%d", &n);
int k = 4096;
if (argc >= 4)
sscanf(argv[3], "%d", &k);
char transA = 'N';
if (argc >= 5)
sscanf(argv[4], "%c", &transA);
char transB = 'T';
if (argc >= 6)
sscanf(argv[5], "%c", &transB);
using TA = float;
using TB = float;
using TC = float;
using TI = float;
TI alpha = 1.0;
TI beta = 0.0;
std::cout << "M = " << m << std::endl;
std::cout << "N = " << n << std::endl;
std::cout << "K = " << k << std::endl;
std::cout << "C = A^" << transA << " B^" << transB << std::endl;
cute::device_init(0);
thrust::host_vector<TA> h_A(m*k);
thrust::host_vector<TB> h_B(n*k);
thrust::host_vector<TC> h_C(m*n);
for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);
thrust::device_vector<TA> d_A = h_A;
thrust::device_vector<TB> d_B = h_B;
thrust::device_vector<TC> d_C = h_C;
double gflops = (2.0*m*n*k) * 1e-9;
const int timing_iterations = 100;
GPU_Clock timer;
int ldA = 0, ldB = 0, ldC = m;
if (transA == 'N') {
ldA = m;
} else if (transA == 'T') {
ldA = k;
} else {
assert(false);
}
if (transB == 'N') {
ldB = k;
} else if (transB == 'T') {
ldB = n;
} else {
assert(false);
}
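// These choices follow column-major storage: for 'N', A is M x K with ldA = m, while for 'T'
// it is stored K x M with ldA = k; likewise ldB = k for 'N', ldB = n for 'T', and ldC = m.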
// Run once
d_C = h_C;
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
CUTE_CHECK_LAST();
thrust::host_vector<TC> cute_result = d_C;
// Timing iterations
timer.start();
for (int i = 0; i < timing_iterations; ++i) {
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
}
double cute_time = timer.seconds() / timing_iterations;
CUTE_CHECK_LAST();
printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);
return 0;
}

View File

@ -0,0 +1,523 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cstdlib>
#include <cstdio>
#include <cassert>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <cute/tensor.hpp>
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"
template <class ProblemShape, class CtaTiler,
class TA, class AStride, class ASmemLayout, class TiledCopyA,
class TB, class BStride, class BSmemLayout, class TiledCopyB,
class TC, class CStride, class CSmemLayout, class TiledMma,
class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(TiledMma{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a,
TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b,
TC * C, CStride dC, CSmemLayout , TiledMma mma,
Alpha alpha, Beta beta)
{
using namespace cute;
// Preconditions
CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
static_assert(is_static<ASmemLayout>::value);
static_assert(is_static<BSmemLayout>::value);
static_assert(is_static<CSmemLayout>::value);
CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA)); // dA strides for shape MK
CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB)); // dB strides for shape NK
CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN
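// TUTORIAL: congruent() checks that each stride tuple has the same profile (rank and nesting)
// as the corresponding shape, e.g. dA must supply exactly one stride per mode of (M,K).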
//
// Full and Tiled Tensors
//
// Represent the full tensors
Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)
// Get the appropriate blocks for this thread block
auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
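// TUTORIAL: The Step<> arguments project out the tiler modes a tensor does not have:
// Step<_1, X,_1> applies only the (BLK_M,BLK_K) part of the tiler to mA, Step< X,_1,_1>
// the (BLK_N,BLK_K) part to mB, and Step<_1,_1, X> the (BLK_M,BLK_N) part to mC.
// Leaving the k coordinate as _ keeps all k-tiles, hence the trailing k mode of gA and gB.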
// Shared memory buffers
__shared__ TA smemA[cosize_v<ASmemLayout>];
__shared__ TB smemB[cosize_v<BSmemLayout>];
Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K)
Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K)
//
// Partition the copying of A and B tiles across the threads
//
// TUTORIAL: Example of partitioning via a TiledCopy
ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K)
// Allocate registers same shape/layout as partitioned data
Tensor tArA = make_fragment_like(tAsA); // (CPY,CPY_M,CPY_K)
ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K)
// Allocate registers same shape/layout as partitioned data
Tensor tBrB = make_fragment_like(tBsB); // (CPY,CPY_N,CPY_K)
CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tArA)); // CPY_M
CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tArA)); // CPY_K
CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBrB)); // CPY_N
CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K
CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBrB)); // CPY_K
// Copy gmem to rmem for k_tile=0
copy(copy_a, tAgA(_,_,_,0), tArA);
copy(copy_b, tBgB(_,_,_,0), tBrB);
//
// Define A/B partitioning and C accumulators
//
// TUTORIAL: Example of partitioning via a TiledMMA
ThrMMA thr_mma = mma.get_slice(threadIdx.x);
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)
Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
// Allocate the accumulators -- same size as the projected data
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V( shape(tCrC) == shape(tCgC)); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // MMA_K
// Clear the accumulators
clear(tCrC);
#if 0
if(thread0()) {
print(" mA : "); print( mA); print("\n");
print(" gA : "); print( gA); print("\n");
print(" sA : "); print( sA); print("\n");
print("tAgA : "); print(tAgA); print("\n");
print("tAsA : "); print(tAsA); print("\n");
print("tArA : "); print(tArA); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mB : "); print( mB); print("\n");
print(" gB : "); print( gB); print("\n");
print(" sB : "); print( sB); print("\n");
print("tBgB : "); print(tBgB); print("\n");
print("tBsB : "); print(tBsB); print("\n");
print("tArA : "); print(tArA); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mC : "); print( mC); print("\n");
print(" gC : "); print( gC); print("\n");
print("tCsA : "); print(tCsA); print("\n");
print("tCsB : "); print(tCsB); print("\n");
print("tCgC : "); print(tCgC); print("\n");
print("tCrC : "); print(tCrC); print("\n");
}
#endif
#if 1
// TUTORIAL: Example of an inner loop that pipelines compute with reads
// from global memory by staging through register and shared memory.
// Data is read from global to registers, then to shared via the TiledCopy partitions
// gemm(.) operates on the shared memory directly via the TiledMMA partitions
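// Per k_tile, the loop below (1) waits for all threads to finish reading smem, (2) stores
// the registers holding tile k_tile into smem, (3) issues the gmem->register loads for
// k_tile+1, and (4) computes the MMAs on the smem copy of tile k_tile, so the gmem reads
// for the next tile overlap with compute on the current one.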
auto K_TILE_MAX = size<3>(tAgA);
for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
{
// Copy rmem to smem with tA|tB thread-partitioned tensors
__syncthreads(); // Wait for all threads to consume smem
copy(tArA, tAsA);
copy(tBrB, tBsB);
__syncthreads(); // Wait for all threads to finish writing smem
// Copy gmem to rmem for k_tile+1 with tA|tB thread-partitioned tensors
int k_tile_next = (k_tile + 1 < K_TILE_MAX) ? k_tile + 1 : k_tile;
copy(copy_a, tAgA(_,_,_,k_tile_next), tArA);
copy(copy_b, tBgB(_,_,_,k_tile_next), tBrB);
// TUTORIAL: The above call to copy(copy_a, tAgA(_,_,_,k_tile_next), tArA) is equivalent to
// CUTE_UNROLL
// for (int m = 0; m < size<1>(tArA); ++m) {
// CUTE_UNROLL
// for (int k = 0; k < size<2>(tArA); ++k) {
// copy_a.call(tAgA(_,m,k,k_tile_next), tArA(_,m,k));
// }
// }
// Compute gemm on mma-partitioned smem
gemm(mma, tCsA, tCsB, tCrC);
// TUTORIAL: The above call to gemm(mma, tCsA, tCsB, tCrC) is equivalent to
// CUTE_UNROLL
// for (int k = 0; k < size<2>(tCsA); ++k) {
// CUTE_UNROLL
// for (int m = 0; m < size<1>(tCrC); ++m) {
// CUTE_UNROLL
// for (int n = 0; n < size<2>(tCrC); ++n) {
// mma.call(tCsA(_,m,k), tCsB(_,n,k), tCrC(_,m,n));
// }
// }
// }
}
#endif
//
// Epilogue
//
axpby(alpha, tCrC, beta, tCgC);
}
// Setup params for an NT GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
using namespace cute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define NT strides (mixed)
auto dA = make_stride(Int<1>{}, ldA); // (dM, dK)
auto dB = make_stride(Int<1>{}, ldB); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// Define the smem layouts (static)
auto sA = make_layout(make_shape(bM, bK)); // (m,k) -> smem_idx; m-major
auto sB = make_layout(make_shape(bN, bK)); // (n,k) -> smem_idx; n-major
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major
// Define the thread layouts (static)
// TUTORIAL: Construct TiledCopy with a particular Copy_Atom to use and
// define the partitioning pattern to apply.
// Each thread will (try to) copy 4x1 elements of type TA using a single 128-bit copy.
// Use 32x8 of these threads.
TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, TA>{},
Layout<Shape<_32,_8>>{}, // Thr layout 32x8 m-major
Layout<Shape< _4,_1>>{}); // Val layout 4x1 m-major
TiledCopy copyB = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, TB>{},
Layout<Shape<_32,_8>>{}, // Thr layout 32x8 n-major
Layout<Shape< _4,_1>>{}); // Val layout 4x1 n-major
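// TUTORIAL: Together, 32x8 threads times 4x1 values per thread cover a 128x8 tile per copy,
// exactly one (BLK_M,BLK_K) smem tile; with TA = float as in this example, each uint128_t
// atom moves 4 consecutive elements per instruction.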
// TUTORIAL: Construct TiledMMA with a particular MMA_Atom to use and
// define the partitioning pattern to apply.
// Use a 1x1x1 FMA on the types TC += TA * TB. Each atom requires a single thread.
// Reproduce that atom 16x16x1 times (m-major) across threads so that we use 256 threads.
TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
Layout<Shape<_16,_16,_1>>{}); // 16x16x1 UniversalFMA
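// TUTORIAL: The 16x16x1 arrangement of single-thread FMA atoms yields 256 threads; each
// thread accumulates a (128/16) x (128/16) = 8x8 sub-block of the 128x128 C tile and loops
// over the 8-deep K mode of the smem tiles.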
#if 0
print(copyA);
print(copyB);
print(mmaC);
#endif
#if 0
print_latex(copyA);
print_latex(copyB);
print_latex(mmaC);
#endif
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, copyA,
B, dB, sB, copyB,
C, dC, sC, mmaC,
alpha, beta);
}
// Setup params for a TN GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
using namespace cute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define TN strides (mixed)
auto dA = make_stride(ldA, Int<1>{}); // (dM, dK)
auto dB = make_stride(ldB, Int<1>{}); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// Define the smem layouts (static)
auto sA = make_layout(make_shape ( bM, bK),
make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
auto sB = make_layout(make_shape ( bN, bK),
make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx
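// TUTORIAL: The +1 padding (column stride bM+1 = 129 instead of 128) avoids shared-memory
// bank conflicts for the k-major copy below: with 4-byte elements and 32 banks, threads that
// share an m|n coordinate but differ in k write addresses 129 apart (129 % 32 = 1, distinct
// banks) rather than 128 apart (same bank).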
// TUTORIAL: Construct TiledCopy to define the Copy_Atom to use and the
// partitioning pattern to apply.
// Each thread will copy 1x1 elements of type TA.
// Use 32x8 of these threads arranged in k-major.
TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<TA>, TA>{},
Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
Layout<Shape< _1,_1>>{}); // Val layout 1x1
TiledCopy copyB = make_tiled_copy(Copy_Atom<UniversalCopy<TB>, TB>{},
Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
Layout<Shape< _1,_1>>{}); // Val layout 1x1
// TUTORIAL: Construct TiledMMA to define the MMA_Atom to use and the
// partitioning pattern to apply.
// Use a 1x1x1 FMA on the types TC += TA * TB. Each atom requires a single thread.
// Reproduce that atom 16x16x1 times (m-major) across threads so that we use 256 threads.
TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
Layout<Shape<_16,_16,_1>>{}); // 16x16x1 TiledMMA
#if 0
print(copyA);
print(copyB);
print(mmaC);
#endif
#if 0
print_latex(copyA);
print_latex(copyB);
print_latex(mmaC);
#endif
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, copyA,
B, dB, sB, copyB,
C, dC, sC, mmaC,
alpha, beta);
}
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
if (transA == 'N' && transB == 'T') {
return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
} else
if (transA == 'T' && transB == 'N') {
return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
}
assert(false && "Not implemented");
}
int main(int argc, char** argv)
{
int m = 5120;
if (argc >= 2)
sscanf(argv[1], "%d", &m);
int n = 5120;
if (argc >= 3)
sscanf(argv[2], "%d", &n);
int k = 4096;
if (argc >= 4)
sscanf(argv[3], "%d", &k);
char transA = 'N';
if (argc >= 5)
sscanf(argv[4], "%c", &transA);
char transB = 'T';
if (argc >= 6)
sscanf(argv[5], "%c", &transB);
using TA = float;
using TB = float;
using TC = float;
using TI = float;
TI alpha = 1.0;
TI beta = 0.0;
std::cout << "M = " << m << std::endl;
std::cout << "N = " << n << std::endl;
std::cout << "K = " << k << std::endl;
std::cout << "C = A^" << transA << " B^" << transB << std::endl;
cute::device_init(0);
thrust::host_vector<TA> h_A(m*k);
thrust::host_vector<TB> h_B(n*k);
thrust::host_vector<TC> h_C(m*n);
for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);
thrust::device_vector<TA> d_A = h_A;
thrust::device_vector<TB> d_B = h_B;
thrust::device_vector<TC> d_C = h_C;
double gflops = (2.0*m*n*k) * 1e-9;
const int timing_iterations = 100;
GPU_Clock timer;
int ldA = 0, ldB = 0, ldC = m;
if (transA == 'N') {
ldA = m;
} else if (transA == 'T') {
ldA = k;
} else {
assert(false);
}
if (transB == 'N') {
ldB = k;
} else if (transB == 'T') {
ldB = n;
} else {
assert(false);
}
// Run once
d_C = h_C;
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
CUTE_CHECK_LAST();
thrust::host_vector<TC> cute_result = d_C;
// Timing iterations
timer.start();
for (int i = 0; i < timing_iterations; ++i) {
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
}
double cute_time = timer.seconds() / timing_iterations;
CUTE_CHECK_LAST();
printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);
return 0;
}

View File

@ -1,426 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <cute/tensor.hpp>
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
# include "cutlass/util/cublas_wrappers.hpp"
#endif
#include "cutlass/util/helper_cuda.hpp"
template <class MShape, class NShape, class KShape,
class TA, class AStride, class ABlockLayout, class AThreadLayout,
class TB, class BStride, class BBlockLayout, class BThreadLayout,
class TC, class CStride, class CBlockLayout, class CThreadLayout,
class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(CThreadLayout{}))::value)
void
gemm_device(MShape M, NShape N, KShape K,
TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
TC * C, CStride dC, CBlockLayout , CThreadLayout tC,
Alpha alpha, Beta beta)
{
using namespace cute;
using X = Underscore;
// Preconditions
CUTE_STATIC_ASSERT(is_static<ABlockLayout>::value);
CUTE_STATIC_ASSERT(is_static<BBlockLayout>::value);
CUTE_STATIC_ASSERT(is_static<CBlockLayout>::value);
CUTE_STATIC_ASSERT(is_static<AThreadLayout>::value);
CUTE_STATIC_ASSERT(is_static<BThreadLayout>::value);
CUTE_STATIC_ASSERT(is_static<CThreadLayout>::value);
CUTE_STATIC_ASSERT_V(size(tA) == size(tC));
CUTE_STATIC_ASSERT_V(size(tB) == size(tC));
//CUTE_STATIC_ASSERT_V(shape<0>(blockA) == shape<0>(blockC)); // BLK_M
//CUTE_STATIC_ASSERT_V(shape<0>(blockB) == shape<1>(blockC)); // BLK_N
CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB)); // BLK_K
// Shared memory buffers
__shared__ TA smemA[cosize_v<ABlockLayout>];
__shared__ TB smemB[cosize_v<BBlockLayout>];
auto sA = make_tensor(make_smem_ptr(smemA), blockA); // (BLK_M,BLK_K)
auto sB = make_tensor(make_smem_ptr(smemB), blockB); // (BLK_N,BLK_K)
// Represent the full tensors
auto mA = make_tensor(make_gmem_ptr(A), make_shape(M,K), dA); // (M,K)
auto mB = make_tensor(make_gmem_ptr(B), make_shape(N,K), dB); // (N,K)
auto mC = make_tensor(make_gmem_ptr(C), make_shape(M,N), dC); // (M,N)
// Get the appropriate blocks for this thread block --
// potential for thread block locality
auto blk_shape = make_shape(size<0>(sA), size<0>(sB), size<1>(sB));// (BLK_M,BLK_N,BLK_K)
auto blk_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
auto gA = local_tile(mA, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
auto gB = local_tile(mB, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
auto gC = local_tile(mC, blk_shape, blk_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
//
// Partition the copying of A and B tiles across the threads
//
// TUTORIAL: Example of simple partitioning of A|B tiles over tA|tB
// Default is a raked partition, but can be changed with Step<X,Y> parameter
auto tAgA = local_partition(gA, tA, threadIdx.x); // (THR_M,THR_K,k)
auto tAsA = local_partition(sA, tA, threadIdx.x); // (THR_M,THR_K)
auto tBgB = local_partition(gB, tB, threadIdx.x); // (THR_N,THR_K,k)
auto tBsB = local_partition(sB, tB, threadIdx.x); // (THR_N,THR_K)
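// TUTORIAL: With the default raked partition, the 32x8 thread layout over the 128x8 A tile
// assigns thread (tm,tk) the elements (tm + 32*i, tk), i.e. a 4x1 strip strided by 32 in m.
// (A sketch of the default behavior; the Step<X,Y> parameter can select other patterns.)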
//
// Define C accumulators and A/B partitioning
//
// TUTORIAL: Example of partitioning via projections of tC
// Partition sA (M,K) by the rows of tC
auto tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{}); // (THR_M,BLK_K)
// Partition sB (N,K) by the cols of tC
auto tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{}); // (THR_N,BLK_K)
// Partition gC (M,N) by the tile of tC
auto tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{}); // (THR_M,THR_N)
// Allocate the accumulators -- same size as the projected data
auto tCrC = make_fragment_like(tCgC); // (THR_M,THR_N)
// Clear the accumulators
clear(tCrC);
#if 0
if(thread0()) {
print("mA\n");
print(mA.shape()); print("\n"); print(mA.stride());
print("\n\ngA\n");
print(gA.shape()); print("\n"); print(gA.stride());
print("\n\ntAgA\n");
print(tAgA.shape()); print("\n"); print(tAgA.stride());
print("\n\nsA\n");
print(sA.shape()); print("\n"); print(sA.stride());
print("\n\ntAsA\n");
print(tAsA.shape()); print("\n"); print(tAsA.stride());
print("\n\n");
}
#endif
#if 0
if(thread0()) {
print("mB\n");
print(mB.shape()); print("\n"); print(mB.stride());
print("\n\ngB\n");
print(gB.shape()); print("\n"); print(gB.stride());
print("\n\ntBgB\n");
print(tBgB.shape()); print("\n"); print(tBgB.stride());
print("\n\nsB\n");
print(sB.shape()); print("\n"); print(sB.stride());
print("\n\ntBsB\n");
print(tBsB.shape()); print("\n"); print(tBsB.stride());
print("\n\n");
}
#endif
#if 0
if(thread0()) {
print("mC\n");
print(mC.shape()); print("\n"); print(mC.stride());
print("\n\ngC\n");
print(gC.shape()); print("\n"); print(gC.stride());
print("\n\ntCsA\n");
print(tCsA.shape()); print("\n"); print(tCsA.stride());
print("\n\ntCsB\n");
print(tCsB.shape()); print("\n"); print(tCsB.stride());
print("\n\ntCgC\n");
print(tCgC.shape()); print("\n"); print(tCgC.stride());
print("\n\ntCrC\n");
print(tCrC.shape()); print("\n"); print(tCrC.stride());
print("\n\n");
}
#endif
#if 1
// TUTORIAL: Example of a very simple compute loop
// Data is read from global to shared memory via the tA|tB partitioning
// gemm(.) operates on the shared memory directly via the tC partitioning
auto k_max = size<2>(tAgA);
for (int k = 0; k < k_max; ++k)
{
// Copy gmem to smem
copy(tAgA(_,_,k), tAsA);
copy(tBgB(_,_,k), tBsB);
// In case copy uses cp.async, make sure that the cp.async
// instructions are ordered with respect to other cp.async
// instructions (fence), then wait on all the outstanding copy
// operations (wait<0>()). __syncthreads() alone does not do
// this.
//
// NOTE: cp_async_wait<0>() currently issues cp.async.wait_all.
// This is equivalent to cp.async.commit_group followed by
// cp.async.wait_group 0. This should make the first
// cp_async_fence() (which also issues cp.async.commit_group)
// redundant. The tutorial works as-is, so we'll leave the
// redundant fence in for now and study its removal later.
cp_async_fence();
cp_async_wait<0>();
__syncthreads();
// Compute gemm on smem
gemm(tCsA, tCsB, tCrC);
__syncthreads();
}
#endif
//
// Epilogue
//
axpby(alpha, tCrC, beta, tCgC);
}
template <typename TA, typename TB, typename TC,
typename Alpha, typename Beta>
void
gemm(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
using namespace cute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
// Define strides (mixed)
auto dA = make_stride(Int<1>{}, ldA);
auto dB = make_stride(Int<1>{}, ldB);
auto dC = make_stride(Int<1>{}, ldC);
// Define block sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
// Define the block layouts (static)
auto sA = make_layout(make_shape(bM,bK));
auto sB = make_layout(make_shape(bN,bK));
auto sC = make_layout(make_shape(bM,bN));
// Define the thread layouts (static)
auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));
auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));
auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));
dim3 dimBlock(size(tC));
dim3 dimGrid(ceil_div(size(M), size(bM)),
ceil_div(size(N), size(bN)));
gemm_device
<<< dimGrid, dimBlock, 0, stream >>>
(M, N, K,
A, dA, sA, tA,
B, dB, sB, tB,
C, dC, sC, tC,
alpha, beta);
}
#include <cstdlib>
#include <cstdio>
#include <cassert>
void test_gemm(int m, int n, int k)
{
cute::device_init(0);
std::cout << "M = " << m << std::endl;
std::cout << "N = " << n << std::endl;
std::cout << "K = " << k << std::endl;
using TA = float;
using TB = float;
using TC = float;
using TI = float;
thrust::host_vector<TA> h_A(m*k);
thrust::host_vector<TB> h_B(n*k);
thrust::host_vector<TC> h_C(m*n);
for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);
thrust::device_vector<TA> d_A = h_A;
thrust::device_vector<TB> d_B = h_B;
thrust::device_vector<TC> d_C = h_C;
TI alpha = 1.0;
TI beta = 0.0;
double gflops = (2.0*m*n*k) * 1e-9;
const int timing_iterations = 100;
GPU_Clock timer;
#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
//
// cuBLas
//
cublasHandle_t handle;
cublasCreate(&handle);
// Run once
d_C = h_C;
blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T,
m, n, k,
&alpha,
d_A.data().get(), m,
d_B.data().get(), n,
&beta,
d_C.data().get(), m);
CUTE_CHECK_LAST();
thrust::host_vector<TC> cublas_result = d_C;
// Timing iterations
timer.start();
for (int i = 0; i < timing_iterations; ++i) {
blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T,
m, n, k,
&alpha,
d_A.data().get(), m,
d_B.data().get(), n,
&beta,
d_C.data().get(), m);
}
double cublas_time = timer.seconds() / timing_iterations;
CUTE_CHECK_LAST();
printf("CUBLAS_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cublas_time, cublas_time*1000);
#else
std::cout << "Verification by comparison with cuBLAS is disabled, "
"either because the CMake option CUTLASS_ENABLE_CUBLAS "
"was explicitly set to OFF, or because CMake could not find cuBLAS. "
"If you would like to enable verification with cuBLAS, "
"please set the CMake option CUTLASS_ENABLE_CUBLAS to ON, "
"rerun CMake, and recompile this example.\n";
#endif // CUTLASS_ENABLE_CUBLAS
//
// CuTe
//
// Run once (and check)
d_C = h_C;
gemm(m, n, k,
alpha,
d_A.data().get(), m,
d_B.data().get(), n,
beta,
d_C.data().get(), m);
CUTE_CHECK_LAST();
thrust::host_vector<TC> cute_result = d_C;
// Timing iterations
timer.start();
for (int i = 0; i < timing_iterations; ++i) {
gemm(m, n, k,
alpha,
d_A.data().get(), m,
d_B.data().get(), n,
beta,
d_C.data().get(), m);
}
double cute_time = timer.seconds() / timing_iterations;
CUTE_CHECK_LAST();
printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);
#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
printf("Empirical Perf: %.1f%%\n", (cublas_time / cute_time) * 100);
auto host_matrix_to_const_column_major_cute_tensor =
[](const auto& X, int num_rows, int num_cols, int LDX) {
const auto shape = cute::Shape<int, int>{num_rows, num_cols};
const auto strides = cute::Stride<int, int>{1, LDX};
return cute::make_tensor(X.data(), cute::make_layout(shape, strides));
};
const auto A_view = host_matrix_to_const_column_major_cute_tensor(h_A, m, k, m);
// B^T is k x n, so B is n x k.
const auto B_view = host_matrix_to_const_column_major_cute_tensor(h_B, n, k, n);
const auto C_computed_view = host_matrix_to_const_column_major_cute_tensor(cute_result, m, n, m);
const auto C_expected_view = host_matrix_to_const_column_major_cute_tensor(cublas_result, m, n, m);
print_matrix_multiply_mollified_relative_error("float", A_view, B_view, C_computed_view, C_expected_view);
#endif // CUTLASS_ENABLE_CUBLAS
}
int main(int argc, char** argv)
{
int m = 5120;
if (argc >= 2)
sscanf(argv[1], "%d", &m);
int n = 5120;
if (argc >= 3)
sscanf(argv[2], "%d", &n);
int k = 4096;
if (argc >= 4)
sscanf(argv[3], "%d", &k);
test_gemm(m, n, k);
return 0;
}

View File

@ -0,0 +1,526 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cstdlib>
#include <cstdio>
#include <cassert>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <cute/tensor.hpp>
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"
template <class ProblemShape, class CtaTiler,
class TA, class AStride, class ASmemLayout, class TiledCopyA,
class TB, class BStride, class BSmemLayout, class TiledCopyB,
class TC, class CStride, class CSmemLayout, class TiledMma,
class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(TiledMma{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a,
TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b,
TC * C, CStride dC, CSmemLayout , TiledMma mma,
Alpha alpha, Beta beta)
{
using namespace cute;
// Preconditions
CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
static_assert(is_static<ASmemLayout>::value);
static_assert(is_static<BSmemLayout>::value);
static_assert(is_static<CSmemLayout>::value);
CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA)); // dA strides for shape MK
CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB)); // dB strides for shape NK
CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN
//
// Full and Tiled Tensors
//
// Represent the full tensors
Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)
// Get the appropriate blocks for this thread block
auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
// Shared memory buffers
__shared__ TA smemA[cosize_v<ASmemLayout>];
__shared__ TB smemB[cosize_v<BSmemLayout>];
Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K)
Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K)
//
// Partition the copying of A and B tiles across the threads
//
// TUTORIAL: Example of partitioning via a TiledCopy
ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K)
Tensor tArA = make_fragment_like(tAsA); // (CPY,CPY_M,CPY_K)
ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K)
Tensor tBrB = make_fragment_like(tBsB); // (CPY,CPY_N,CPY_K)
CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tArA)); // CPY_M
CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tArA)); // CPY_K
CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBrB)); // CPY_N
CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K
CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBrB)); // CPY_K
// Copy gmem to rmem for k_tile=0
copy(copy_a, tAgA(_,_,_,0), tArA);
copy(copy_b, tBgB(_,_,_,0), tBrB);
//
// Define A/B partitioning and C accumulators
//
// TUTORIAL: Example of partitioning via a TiledMMA
ThrMMA thr_mma = mma.get_slice(threadIdx.x);
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)
Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
// Allocate registers for pipelining
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
// Allocate the accumulators -- same size as the projected data
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V( shape(tCrA) == shape(tCsA)); // (MMA,MMA_M,MMA_K)
CUTE_STATIC_ASSERT_V( shape(tCrB) == shape(tCsB)); // (MMA,MMA_N,MMA_K)
CUTE_STATIC_ASSERT_V( shape(tCrC) == shape(tCgC)); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // MMA_K
// Clear the accumulators
clear(tCrC);
#if 0
if(thread0()) {
print(" mA : "); print( mA); print("\n");
print(" gA : "); print( gA); print("\n");
print(" sA : "); print( sA); print("\n");
print("tAgA : "); print(tAgA); print("\n");
print("tAsA : "); print(tAsA); print("\n");
print("tArA : "); print(tArA); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mB : "); print( mB); print("\n");
print(" gB : "); print( gB); print("\n");
print(" sB : "); print( sB); print("\n");
print("tBgB : "); print(tBgB); print("\n");
print("tBsB : "); print(tBsB); print("\n");
print("tArA : "); print(tArA); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mC : "); print( mC); print("\n");
print(" gC : "); print( gC); print("\n");
print("tCsA : "); print(tCsA); print("\n");
print("tCsB : "); print(tCsB); print("\n");
print("tCgC : "); print(tCgC); print("\n");
print("tCrC : "); print(tCrC); print("\n");
}
#endif
#if 1
// Copy rmem to smem
copy(tArA, tAsA);
copy(tBrB, tBsB);
__syncthreads();
//
// PIPELINED MAIN LOOP
// TUTORIAL: Example of a gemm loop that pipelines shared memory AND register memory
// Data is read from global to registers, then to shared via the tA|tB partitions
// Data is then copied from shared to registers in multiple waves via the tC partitions
// and gemm(.) operates on the current register wave
//
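// Per k_block the loop below does the following:
//   k_block == K_BLOCK_MAX-1 : sync and stage the registers holding the next gmem tile into smem
//   every k_block            : copy smem -> registers for k_block+1
//   k_block == 0             : issue gmem -> register loads for the next k_tile
//   every k_block            : gemm() on the registers of the current k_block
//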
// Load A, B shmem->regs for k_block=0
copy(tCsA(_,_,0), tCrA(_,_,0));
copy(tCsB(_,_,0), tCrB(_,_,0));
auto K_TILE_MAX = size<3>(tAgA);
auto K_BLOCK_MAX = size<2>(tCrA);
CUTE_NO_UNROLL
for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
{
// Pipeline the k-mode of the block registers
CUTE_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
{
if (k_block == K_BLOCK_MAX - 1)
{
// Copy rmem to smem
__syncthreads();
copy(tArA, tAsA);
copy(tBrB, tBsB);
__syncthreads();
}
// Copy smem to rmem for k_block+1
int k_block_next = (k_block + 1) % K_BLOCK_MAX;
copy(tCsA(_,_,k_block_next), tCrA(_,_,k_block_next));
copy(tCsB(_,_,k_block_next), tCrB(_,_,k_block_next));
if (k_block == 0)
{
// Copy gmem to rmem for k_tile+1
int k_tile_next = (k_tile + 1 < K_TILE_MAX) ? k_tile + 1 : k_tile;
copy(copy_a, tAgA(_,_,_,k_tile_next), tArA);
copy(copy_b, tBgB(_,_,_,k_tile_next), tBrB);
}
// Thread-level register gemm for k_block
gemm(mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
} // k_block
} // k_tile
#endif
//
// Epilogue
//
axpby(alpha, tCrC, beta, tCgC);
}
// Setup params for an NT GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
using namespace cute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define NT strides (mixed)
auto dA = make_stride(Int<1>{}, ldA); // (dM, dK)
auto dB = make_stride(Int<1>{}, ldB); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// Define the smem layouts (static)
auto sA = make_layout(make_shape(bM, bK)); // (m,k) -> smem_idx; m-major
auto sB = make_layout(make_shape(bN, bK)); // (n,k) -> smem_idx; n-major
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major
// Define the thread layouts (static)
TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, TA>{},
Layout<Shape<_32,_8>>{}, // Thr layout 32x8 m-major
Layout<Shape< _4,_1>>{}); // Val layout 4x1 m-major
TiledCopy copyB = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, TB>{},
Layout<Shape<_32,_8>>{}, // Thr layout 32x8 n-major
Layout<Shape< _4,_1>>{}); // Val layout 4x1 n-major
TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
Layout<Shape<_16,_16,_1>>{}); // 16x16x1 TiledMMA
#if 0
print(copyA);
print(copyB);
print(mmaC);
#endif
#if 0
print_latex(copyA);
print_latex(copyB);
print_latex(mmaC);
#endif
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, copyA,
B, dB, sB, copyB,
C, dC, sC, mmaC,
alpha, beta);
}
// Setup params for a TN GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
using namespace cute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define TN strides (mixed)
auto dA = make_stride(ldA, Int<1>{}); // (dM, dK)
auto dB = make_stride(ldB, Int<1>{}); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// Define the smem layouts (static)
auto sA = make_layout(make_shape ( bM, bK),
make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
auto sB = make_layout(make_shape ( bN, bK),
make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx
// Define the thread layouts (static)
TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<TA>, TA>{},
Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
Layout<Shape< _1,_1>>{}); // Val layout 1x1
TiledCopy copyB = make_tiled_copy(Copy_Atom<UniversalCopy<TB>, TB>{},
Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
Layout<Shape< _1,_1>>{}); // Val layout 1x1
TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
Layout<Shape<_16,_16,_1>>{}); // 16x16x1 TiledMMA
#if 0
print(copyA);
print(copyB);
print(mmaC);
#endif
#if 0
print_latex(copyA);
print_latex(copyB);
print_latex(mmaC);
#endif
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, copyA,
B, dB, sB, copyB,
C, dC, sC, mmaC,
alpha, beta);
}
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
if (transA == 'N' && transB == 'T') {
return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
} else
if (transA == 'T' && transB == 'N') {
return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
}
assert(false && "Not implemented");
}
int main(int argc, char** argv)
{
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (props.major < 7) {
std::cout << "This example requires an Volta GPU or newer (CC >= 70)" << std::endl;
// Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits.
return 0;
}
int m = 5120;
if (argc >= 2)
sscanf(argv[1], "%d", &m);
int n = 5120;
if (argc >= 3)
sscanf(argv[2], "%d", &n);
int k = 4096;
if (argc >= 4)
sscanf(argv[3], "%d", &k);
char transA = 'N';
if (argc >= 5)
sscanf(argv[4], "%c", &transA);
char transB = 'T';
if (argc >= 6)
sscanf(argv[5], "%c", &transB);
using TA = float;
using TB = float;
using TC = float;
using TI = float;
TI alpha = 1.0;
TI beta = 0.0;
std::cout << "M = " << m << std::endl;
std::cout << "N = " << n << std::endl;
std::cout << "K = " << k << std::endl;
std::cout << "C = A^" << transA << " B^" << transB << std::endl;
thrust::host_vector<TA> h_A(m*k);
thrust::host_vector<TB> h_B(n*k);
thrust::host_vector<TC> h_C(m*n);
for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);
thrust::device_vector<TA> d_A = h_A;
thrust::device_vector<TB> d_B = h_B;
thrust::device_vector<TC> d_C = h_C;
double gflops = (2.0*m*n*k) * 1e-9;
const int timing_iterations = 100;
GPU_Clock timer;
int ldA = 0, ldB = 0, ldC = m;
if (transA == 'N') {
ldA = m;
} else if (transA == 'T') {
ldA = k;
} else {
assert(false);
}
if (transB == 'N') {
ldB = k;
} else if (transB == 'T') {
ldB = n;
} else {
assert(false);
}
// Run once
d_C = h_C;
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
CUTE_CHECK_LAST();
thrust::host_vector<TC> cute_result = d_C;
// Timing iterations
timer.start();
for (int i = 0; i < timing_iterations; ++i) {
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
}
double cute_time = timer.seconds() / timing_iterations;
CUTE_CHECK_LAST();
printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);
return 0;
}

View File

@ -0,0 +1,567 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cstdlib>
#include <cstdio>
#include <cassert>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <cute/tensor.hpp>
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"
template <class ProblemShape, class CtaTiler,
class TA, class AStride, class ASmemLayout, class TiledCopyA,
class TB, class BStride, class BSmemLayout, class TiledCopyB,
class TC, class CStride, class CSmemLayout, class TiledMma,
class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(TiledMma{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a,
TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b,
TC * C, CStride dC, CSmemLayout , TiledMma mma,
Alpha alpha, Beta beta)
{
using namespace cute;
// Preconditions
CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
static_assert(is_static<ASmemLayout>::value);
static_assert(is_static<BSmemLayout>::value);
static_assert(is_static<CSmemLayout>::value);
CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA)); // dA strides for shape MK
CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB)); // dB strides for shape NK
CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN
//
// Full and Tiled Tensors
//
// Represent the full tensors
Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)
// Get the appropriate blocks for this thread block
auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
// Shared memory buffers
__shared__ TA smemA[cosize_v<ASmemLayout>];
__shared__ TB smemB[cosize_v<BSmemLayout>];
Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K,PIPE)
//
// Partition the copying of A and B tiles across the threads
//
ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K,PIPE)
ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K,PIPE)
CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K
//
// PREFETCH
//
auto K_PIPE_MAX = size<3>(tAsA);
// Total count of tiles
int k_tile_count = size<3>(tAgA);
// Current tile index in gmem to read from
int k_tile_next = 0;
// Start async loads for all pipes but the last
CUTE_UNROLL
for (int k_pipe = 0; k_pipe < K_PIPE_MAX-1; ++k_pipe) {
copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,k_pipe));
copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,k_pipe));
cp_async_fence();
--k_tile_count;
if (k_tile_count > 0) { ++k_tile_next; }
}
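// Each prefetch iteration commits one cp.async group via cp_async_fence(), so K_PIPE_MAX-1
// groups are now in flight; cp_async_wait<K_PIPE_MAX-2>() below blocks until the oldest of
// them (the k_tile = 0 load) has landed in smem.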
//
// Define A/B partitioning and C accumulators
//
ThrMMA thr_mma = mma.get_slice(threadIdx.x);
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
// Allocate registers for pipelining
Tensor tCrA = thr_mma.make_fragment_A(tCsA(_,_,_,0)); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.make_fragment_B(tCsB(_,_,_,0)); // (MMA,MMA_N,MMA_K)
// Allocate the accumulators -- same size as the projected data
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V( shape(tCrA) == shape(tCsA)); // (MMA,MMA_M,MMA_K)
CUTE_STATIC_ASSERT_V( shape(tCrB) == shape(tCsB)); // (MMA,MMA_N,MMA_K)
CUTE_STATIC_ASSERT_V( shape(tCrC) == shape(tCgC)); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // MMA_K
// Clear the accumulators
clear(tCrC);
#if 0
if(thread0()) {
print(" mA : "); print( mA); print("\n");
print(" gA : "); print( gA); print("\n");
print(" sA : "); print( sA); print("\n");
print("tAgA : "); print(tAgA); print("\n");
print("tAsA : "); print(tAsA); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mB : "); print( mB); print("\n");
print(" gB : "); print( gB); print("\n");
print(" sB : "); print( sB); print("\n");
print("tBgB : "); print(tBgB); print("\n");
print("tBsB : "); print(tBsB); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mC : "); print( mC); print("\n");
print(" gC : "); print( gC); print("\n");
print("tCsA : "); print(tCsA); print("\n");
print("tCsB : "); print(tCsB); print("\n");
print("tCgC : "); print(tCgC); print("\n");
print("tCrA : "); print(tCrA); print("\n");
print("tCrB : "); print(tCrB); print("\n");
print("tCrC : "); print(tCrC); print("\n");
}
#endif
#if 1
// Current pipe index in smem to read from
int smem_pipe_read = 0;
// Current pipe index in smem to write to
int smem_pipe_write = K_PIPE_MAX-1;
// Pipe slice
Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read);
Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read);
// Size of the register pipeline
auto K_BLOCK_MAX = size<2>(tCrA);
// PREFETCH register pipeline
if (K_BLOCK_MAX > 1) {
// Wait until our first prefetched tile is loaded in
cp_async_wait<K_PIPE_MAX-2>();
__syncthreads();
// Prefetch the first rmem from the first k-tile
copy(tCsA_p(_,_,Int<0>{}), tCrA(_,_,Int<0>{}));
copy(tCsB_p(_,_,Int<0>{}), tCrB(_,_,Int<0>{}));
}
//
// PIPELINED MAIN LOOP
// TUTORIAL: Example of a GEMM mainloop that pipelines the global-to-shared copies with
// SM80's cp.async instructions and an explicit multi-stage (PIPE) buffer in shared memory.
// Data is read from global(k_tile_next) to shared(smem_pipe_write).
// Data is read from shared(smem_pipe_read) to registers(k_block_next).
// Data is computed on registers(k_block).
//
// This allows all copies and compute to overlap:
// Copy from gmem->smem can overlap with copies from smem->rmem and compute on rmem.
// Copy from smem->rmem can overlap with compute on rmem.
//
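// A rough sketch of the schedule this produces, assuming K_PIPE_MAX = 3 (so two
// k-tiles were prefetched above). Within each k-tile:
//   k_block == 0             : issue cp.async for a future k-tile into the free smem stage
//   every k_block            : copy smem -> rmem for k_block+1, then FMA on k_block
//   k_block == K_BLOCK_MAX-1 : cp_async_wait + __syncthreads() guard the switch to the
//                              next smem read stage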
CUTE_NO_UNROLL
while (k_tile_count > -(K_PIPE_MAX-1))
{
CUTE_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
{
if (k_block == K_BLOCK_MAX - 1)
{
// Slice the smem_pipe_read smem
tCsA_p = tCsA(_,_,_,smem_pipe_read);
tCsB_p = tCsB(_,_,_,smem_pipe_read);
// Commit the smem for smem_pipe_read
cp_async_wait<K_PIPE_MAX-2>();
__syncthreads();
}
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static
copy(tCsA_p(_,_,k_block_next), tCrA(_,_,k_block_next));
copy(tCsB_p(_,_,k_block_next), tCrB(_,_,k_block_next));
// Copy gmem to smem before computing gemm on each k-pipe
if (k_block == 0)
{
copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,smem_pipe_write));
copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,smem_pipe_write));
cp_async_fence();
// Advance the gmem tile
--k_tile_count;
if (k_tile_count > 0) { ++k_tile_next; }
// Advance the smem pipe
smem_pipe_write = smem_pipe_read;
++smem_pipe_read;
smem_pipe_read = (smem_pipe_read == K_PIPE_MAX) ? 0 : smem_pipe_read;
}
// Thread-level register gemm for k_block
gemm(mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
}
}
#endif
//
// Epilogue
//
axpby(alpha, tCrC, beta, tCgC);
}
// Setup params for a NT GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
using namespace cute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define NT strides (mixed)
auto dA = make_stride(Int<1>{}, ldA); // (dM, dK)
auto dB = make_stride(Int<1>{}, ldB); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
auto bP = Int<3>{}; // Pipeline
// Define the smem layouts (static)
auto sA = make_layout(make_shape(bM, bK, bP)); // (m,k,p) -> smem_idx; m-major
auto sB = make_layout(make_shape(bN, bK, bP)); // (n,k,p) -> smem_idx; n-major
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major
// Define the thread layouts (static)
TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TA>{},
Layout<Shape<_32,_8>>{}, // Thr layout 32x8 m-major
Layout<Shape< _4,_1>>{});// Val layout 4x1 m-major
TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TB>{},
Layout<Shape<_32,_8>>{}, // Thr layout 32x8 n-major
Layout<Shape< _4,_1>>{});// Val layout 4x1 n-major
TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
Layout<Shape<_16,_16,_1>>{}); // 16x16x1 TiledMMA
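// Sanity check on the sizes above (assuming TA = TB = float, as in this example):
// the copies use 32x8 = 256 threads, each moving a 4x1 block of floats as one uint128_t,
// so one wave covers 128 x 8 elements -- exactly one (BLK_M, BLK_K) stage of sA (and sB).
// The 16x16x1 TiledMMA uses the same 256 threads, which is what dimBlock is set to below.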
#if 0
print(copyA);
print(copyB);
print(mmaC);
#endif
#if 0
print_latex(copyA);
print_latex(copyB);
print_latex(mmaC);
#endif
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, copyA,
B, dB, sB, copyB,
C, dC, sC, mmaC,
alpha, beta);
}
// Setup params for a TN GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
using namespace cute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define TN strides (mixed)
auto dA = make_stride(ldA, Int<1>{}); // (dM, dK)
auto dB = make_stride(ldB, Int<1>{}); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
auto bP = Int<3>{}; // Pipeline
// Define the smem layouts (static)
auto sA_atom = make_layout(make_shape ( bM, bK),
make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
auto sB_atom = make_layout(make_shape ( bN, bK),
make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
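// The +1 padding on the leading dimension (bM+1 / bN+1) staggers consecutive k-columns
// across shared-memory banks; this is a common trick to avoid bank conflicts when the
// k-major copies write into (and the MMA reads from) the m-/n-major smem tiles.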
auto sA = tile_to_shape(sA_atom, make_shape(bM, bK, bP));
auto sB = tile_to_shape(sB_atom, make_shape(bN, bK, bP));
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx
// Define the thread layouts (static)
TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<TA>, TA>{},
Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
Layout<Shape< _1,_1>>{}); // Val layout 1x1
TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<TB>, TB>{},
Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
Layout<Shape< _1,_1>>{}); // Val layout 1x1
TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
Layout<Shape<_16,_16,_1>>{}); // 16x16x1 TiledMMA
#if 0
print(copyA);
print(copyB);
print(mmaC);
#endif
#if 0
print_latex(copyA);
print_latex(copyB);
print_latex(mmaC);
#endif
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, copyA,
B, dB, sB, copyB,
C, dC, sC, mmaC,
alpha, beta);
}
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
if (transA == 'N' && transB == 'T') {
return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
} else
if (transA == 'T' && transB == 'N') {
return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
}
assert(false && "Not implemented");
}
int main(int argc, char** argv)
{
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (props.major < 8) {
std::cout << "This example requires an Ampere GPU or newer (CC >= 80)" << std::endl;
// Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits.
return 0;
}
int m = 5120;
if (argc >= 2)
sscanf(argv[1], "%d", &m);
int n = 5120;
if (argc >= 3)
sscanf(argv[2], "%d", &n);
int k = 4096;
if (argc >= 4)
sscanf(argv[3], "%d", &k);
char transA = 'N';
if (argc >= 5)
sscanf(argv[4], "%c", &transA);
char transB = 'T';
if (argc >= 6)
sscanf(argv[5], "%c", &transB);
using TA = float;
using TB = float;
using TC = float;
using TI = float;
TI alpha = 1.0;
TI beta = 0.0;
std::cout << "M = " << m << std::endl;
std::cout << "N = " << n << std::endl;
std::cout << "K = " << k << std::endl;
std::cout << "C = A^" << transA << " B^" << transB << std::endl;
thrust::host_vector<TA> h_A(m*k);
thrust::host_vector<TB> h_B(n*k);
thrust::host_vector<TC> h_C(m*n);
for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);
thrust::device_vector<TA> d_A = h_A;
thrust::device_vector<TB> d_B = h_B;
thrust::device_vector<TC> d_C = h_C;
double gflops = (2.0*m*n*k) * 1e-9;
const int timing_iterations = 100;
GPU_Clock timer;
int ldA = 0, ldB = 0, ldC = m;
if (transA == 'N') {
ldA = m;
} else if (transA == 'T') {
ldA = k;
} else {
assert(false);
}
if (transB == 'N') {
ldB = k;
} else if (transB == 'T') {
ldB = n;
} else {
assert(false);
}
// Run once
d_C = h_C;
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
CUTE_CHECK_LAST();
thrust::host_vector<TC> cute_result = d_C;
// Timing iterations
timer.start();
for (int i = 0; i < timing_iterations; ++i) {
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
}
double cute_time = timer.seconds() / timing_iterations;
CUTE_CHECK_LAST();
printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);
return 0;
}
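// For a quick correctness check of the default case (transA = 'N', transB = 'T'), a naive
// host reference could be compared against cute_result after the first run; a minimal
// sketch, using the buffers and leading dimensions defined above:
//
//   std::vector<TC> ref(m*n);
//   for (int i = 0; i < m; ++i)
//     for (int j = 0; j < n; ++j) {
//       TI acc = 0;
//       for (int p = 0; p < k; ++p) acc += h_A[i + p*ldA] * h_B[j + p*ldB];
//       ref[i + j*ldC] = alpha * acc + beta * h_C[i + j*ldC];
//     }
//   // ... then compare ref[idx] with cute_result[idx] up to a small tolerance.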


@ -67,7 +67,7 @@
//
// Uses local_partition() to partition a tile among threads arranged as (THR_M, THR_N).
template <class TensorS, class TensorD, class ThreadLayout>
__global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout)
__global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout)
{
using namespace cute;
@ -77,12 +77,13 @@ __global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout)
// Construct a partitioning of the tile among threads with the given thread arrangement.
// Concept: Tensor Layout Index
Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{}, threadIdx.x);
Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{}, threadIdx.x);
// Concept: Tensor ThrLayout ThrIndex
Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{}, threadIdx.x); // (ThrValM, ThrValN)
Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{}, threadIdx.x); // (ThrValM, ThrValN)
// Construct a register-backed Tensor with the same shape as each thread's partition
auto fragment = make_fragment_like(thr_tile_S);
// Use make_tensor to try to match the layout of thr_tile_S
Tensor fragment = make_tensor_like(thr_tile_S); // (ThrValM, ThrValN)
// Copy from GMEM to RMEM and from RMEM to GMEM
copy(thr_tile_S, fragment);
@ -95,17 +96,17 @@ __global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout)
/// has the precondition that pointers are aligned to the vector size.
///
template <class TensorS, class TensorD, class ThreadLayout, class VecLayout>
__global__ void copy_kernel_vectorized(TensorS S, TensorD D, ThreadLayout, VecLayout)
__global__ void copy_kernel_vectorized(TensorS S, TensorD D, ThreadLayout, VecLayout)
{
using namespace cute;
using Element = typename TensorS::value_type;
// Slice the tensors to obtain a view into each tile.
Tensor tile_S = S(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N)
Tensor tile_D = D(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N)
Tensor tile_S = S(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N)
Tensor tile_D = D(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N)
// Define `AccessType` which controls the size of the actual memory access.
using AccessType = cutlass::AlignedArray<Element, size(shape(VecLayout{}))>;
using AccessType = cutlass::AlignedArray<Element, size(VecLayout{})>;
// A copy atom corresponds to one hardware memory access.
using Atom = Copy_Atom<UniversalCopy<AccessType>, Element>;
@ -125,29 +126,18 @@ __global__ void copy_kernel_vectorized(TensorS S, TensorD D, ThreadLayout, VecLa
// Construct a Tensor corresponding to each thread's slice.
auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
Tensor thr_tile_S = thr_copy.partition_S(tile_S);
Tensor thr_tile_D = thr_copy.partition_D(tile_D);
Tensor thr_tile_S = thr_copy.partition_S(tile_S); // (CopyOp, CopyM, CopyN)
Tensor thr_tile_D = thr_copy.partition_D(tile_D); // (CopyOp, CopyM, CopyN)
// Construct a register-backed Tensor with the same shape as each thread's partition
auto fragment = make_fragment_like(thr_tile_D);
// Use make_fragment because the first mode is the instruction-local mode
Tensor fragment = make_fragment_like(thr_tile_D); // (CopyOp, CopyM, CopyN)
// Copy from GMEM to RMEM and from RMEM to GMEM
copy(tiled_copy, thr_tile_S, fragment);
copy(tiled_copy, fragment, thr_tile_D);
}
/// Helper to convert a shape to a dim3
template <class Shape>
dim3 shape_to_dim3(Shape shape)
{
using namespace cute;
CUTE_STATIC_ASSERT_V(rank(shape) <= Int<3>{});
auto result = append<3>(product_each(shape), 1u);
return dim3(get<0>(result), get<1>(result), get<2>(result));
}
/// Main function
int main(int argc, char** argv)
{
@ -161,13 +151,13 @@ int main(int argc, char** argv)
// Define a tensor shape with dynamic extents (m, n)
auto tensor_shape = make_shape(256, 512);
//
// Allocate and initialize
//
thrust::host_vector<Element> h_S(size(tensor_shape));
thrust::host_vector<Element> h_D(size(tensor_shape));
//
// Initialize
//
for (size_t i = 0; i < h_S.size(); ++i) {
h_S[i] = static_cast<Element>(i);
h_D[i] = Element{};
@ -180,33 +170,36 @@ int main(int argc, char** argv)
// Make tensors
//
Tensor tensor_S = make_tensor(make_gmem_ptr(d_S.data().get()), make_layout(tensor_shape));
Tensor tensor_D = make_tensor(make_gmem_ptr(d_D.data().get()), make_layout(tensor_shape));
Tensor tensor_S = make_tensor(make_gmem_ptr(thrust::raw_pointer_cast(d_S.data())), make_layout(tensor_shape));
Tensor tensor_D = make_tensor(make_gmem_ptr(thrust::raw_pointer_cast(d_D.data())), make_layout(tensor_shape));
//
// Partition
// Tile tensors
//
// Define a statically sized block (M, N).
//
// Note, by convention, capital letters are used to represent static modes.
auto block_shape = make_shape(Int<128>{}, Int<64>{});
if ((get<0>(tensor_shape) % get<0>(block_shape)) || (get<1>(tensor_shape) % get<1>(block_shape))) {
if ((size<0>(tensor_shape) % size<0>(block_shape)) || (size<1>(tensor_shape) % size<1>(block_shape))) {
std::cerr << "The tensor shape must be divisible by the block shape." << std::endl;
return -1;
}
// Equivalent check to the above
if (not weakly_compatible(block_shape, tensor_shape)) {
std::cerr << "Expected the tensors to be weakly compatible with the block_shape." << std::endl;
return -1;
}
// Tile the tensor (m, m) ==> ((M, N), m', n') where (M, N) is the static tile
// Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static tile
// shape, and modes (m', n') correspond to the number of tiles.
//
// These will be used to determine the CUDA kernel grid dimensinos.
Tensor tiled_tensor_S = tiled_divide(tensor_S, block_shape);
Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape);
//
// These will be used to determine the CUDA kernel grid dimensions.
Tensor tiled_tensor_S = tiled_divide(tensor_S, block_shape); // ((M, N), m', n')
Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape); // ((M, N), m', n')
// Thread arrangement
Layout thr_layout = make_layout(make_shape(Int<32>{}, Int< 8>{}));
Layout thr_layout = make_layout(make_shape(Int<32>{}, Int<8>{}));
// Vector dimensions
Layout vec_layout = make_layout(make_shape(Int<4>{}, Int<1>{}));
@ -215,16 +208,16 @@ int main(int argc, char** argv)
// Determine grid and block dimensions
//
dim3 gridDim = shape_to_dim3(select<1,2>(shape(tiled_tensor_D))); // Grid shape corresponds to modes m' and n'
dim3 blockDim(size(shape(thr_layout)));
dim3 gridDim (size<1>(tiled_tensor_D), size<2>(tiled_tensor_D)); // Grid shape corresponds to modes m' and n'
dim3 blockDim(size(thr_layout));
//
// Launch the kernel
//
copy_kernel_vectorized<<< gridDim, blockDim >>>(
tiled_tensor_S,
tiled_tensor_D,
thr_layout,
tiled_tensor_S,
tiled_tensor_D,
thr_layout,
vec_layout);
cudaError result = cudaDeviceSynchronize();


@ -33,6 +33,7 @@
#include <cute/config.hpp>
#include <cute/tensor.hpp>
#include <cute/tensor_predicate.hpp>
namespace cute
{
@ -43,15 +44,17 @@ namespace cute
template <class Alpha,
class XEngine, class XLayout,
class Beta,
class YEngine, class YLayout>
class YEngine, class YLayout,
class PrdTensor = TrivialPredTensor>
CUTE_HOST_DEVICE
void
axpby(Alpha const& alpha,
Tensor<XEngine, XLayout> const& x,
Beta const& beta,
Tensor<YEngine, YLayout> && y)
Tensor<YEngine, YLayout> && y,
PrdTensor const& p = {})
{
return axpby(alpha, x, beta, y);
return axpby(alpha, x, beta, y, p);
}
//
@ -60,13 +63,15 @@ axpby(Alpha const& alpha,
template <class Alpha,
class XEngine, class XLayout,
class Beta,
class YEngine, class YLayout>
class YEngine, class YLayout,
class PrdTensor = TrivialPredTensor>
CUTE_HOST_DEVICE
void
axpby(Alpha const& alpha,
Tensor<XEngine, XLayout> const& x,
Beta const& beta,
Tensor<YEngine, YLayout> & y)
Tensor<YEngine, YLayout> & y,
PrdTensor const& p = {})
{
auto isBetaZero = [&] () {
if constexpr (is_complex<Beta>::value) {
@ -81,7 +86,9 @@ axpby(Alpha const& alpha,
CUTE_UNROLL
for (int i = 0; i < size(x); ++i) {
y(i) = (isBetaZero ? alpha * x(i) : alpha * x(i) + beta * y(i));
if (p(i)) {
y(i) = (isBetaZero ? alpha * x(i) : alpha * x(i) + beta * y(i));
}
}
}
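// With the default TrivialPredTensor the loop degenerates to the usual unpredicated
// y = alpha * x + beta * y; passing a bool tensor congruent with x masks elements, e.g.
// a minimal sketch (tCpC is an illustrative predicate tensor, not defined in this file):
//
//   axpby(alpha, tCrC, beta, tCgC, tCpC);  // entries with tCpC(i) == false are left untouched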


@ -0,0 +1,196 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cute/algorithm/copy.hpp>
#include <cute/tensor.hpp>
#include <cute/tensor_predicate.hpp>
namespace cute
{
// cooperative_copy<NumThreads, MaxVecBits>(thr_idx, src, dst)
// Use NumThreads to copy src to dst with element vectorization up to MaxVecBits.
// @pre 0 <= @a tid < NumThreads
// @pre Tensors @a src and @a dst are aligned up to MaxVecBits.
//
template <uint32_t NumThreads, uint32_t MaxVecBits,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
cooperative_copy(uint32_t const& tid,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
// Assumes the shapes are static, can generalize
CUTE_STATIC_ASSERT_V(size(src) == size(dst));
// Assumes the types are the same, can generalize
static_assert(sizeof_bits_v<typename SrcEngine::value_type> == sizeof_bits_v<typename DstEngine::value_type>);
static_assert(MaxVecBits == sizeof_bits_v<typename SrcEngine::value_type> ||
MaxVecBits == 8 || MaxVecBits == 16 || MaxVecBits == 32 || MaxVecBits == 64 || MaxVecBits == 128,
"Expected MaxVecBits to be value size or 8 or 16 or 32 or 64 or 128 for alignment and performance.");
// Check that the tensors are likely shared across threads: either gmem or smem
static_assert((is_gmem<SrcEngine>::value || is_smem<SrcEngine>::value),
"cooperative_copy expects shared gmem or smem source tensor.");
static_assert((is_gmem<DstEngine>::value || is_smem<DstEngine>::value),
"cooperative_copy expects shared gmem or smem destination tensor.");
// Precondition on tid in DEBUG
assert(tid < NumThreads);
// Precondition on pointer alignment in DEBUG
assert(is_byte_aligned<max(MaxVecBits/8, 1u)>(raw_pointer_cast(src.data())));
assert(is_byte_aligned<max(MaxVecBits/8, 1u)>(raw_pointer_cast(dst.data())));
//
// Determine val+thr vectorization based on src/dst size and number of threads
// NOTE: This heuristic promotes parallelization over vectorization
//
constexpr int elem_bits = sizeof_bits_v<typename SrcEngine::value_type>;
// The number of elements that can be vectorized in values
constexpr int common_elem = decltype(max_common_vector(src, dst))::value;
constexpr int common_bits = common_elem * elem_bits;
constexpr int total_elem = decltype(size(src))::value;
constexpr int total_bits = total_elem * elem_bits;
static_assert(total_bits % NumThreads == 0);
constexpr int total_bits_per_thr = total_bits / NumThreads;
// If there are too many threads to allow a full elem copy, trunc the thrs and use elem_bits
constexpr int max_vec_bits_by_thr = cute::max(elem_bits, total_bits_per_thr);
// Cap the vectorization to the common bits, the max_vec_bits_by_thr, and the MaxVecBits
constexpr int vec_bits = cute::min(common_bits, max_vec_bits_by_thr, static_cast<int>(MaxVecBits));
// Convert back to number of elements, safe_div
static_assert((vec_bits % elem_bits) == 0);
constexpr int vec_elem = vec_bits / elem_bits;
// Use only part of threads if there's not enough work for all threads
constexpr int vec_thrs = (total_elem % (vec_elem * NumThreads) == 0)
? NumThreads
: (total_elem / vec_elem);
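// Worked example of the heuristic (illustrative numbers): NumThreads = 128 copying a
// 128x8 tile of 16-bit elements => total_bits = 16384 and total_bits_per_thr = 128;
// with contiguous 128-bit-aligned layouts and MaxVecBits = 128, vec_bits = 128,
// vec_elem = 8, and all 128 threads participate (vec_thrs = 128).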
// The common layout of the two tensors that can be vectorized over threads
// vidx -> coord
auto common_layout = max_common_layout(get_nonswizzle_portion(src.layout()),
get_nonswizzle_portion(dst.layout()));
// Scale up the common_layout to cover the entire tensors
// vidx -> coord
auto full_perm = tile_to_shape(make_layout(common_layout), size(src));
// Create the Tiler
// ((vid,tid),iter)
auto layout_vt = logical_divide(full_perm, Layout<Shape<Int<vec_elem>, Int<vec_thrs>>>{});
// Apply and slice
Tensor src_v = src.compose(layout_vt)(make_coord(_,tid),_);
Tensor dst_v = dst.compose(layout_vt)(make_coord(_,tid),_);
// Should account for vec_bits < 8 and/or vec_elem <= 1
// And also account for subbyte types, which could cause race conditions
// Want to ENFORCE sufficient vectorization in those cases
static_assert((vec_bits >= 8), "No support for subbyte copying");
using VecType = uint_bit_t<vec_bits>;
#if 0
if (thread0()) {
print(" "); print("NumThreads: "); print(NumThreads); print("\n");
print(" "); print("src: "); print(src); print("\n");
print(" "); print("dst: "); print(dst); print("\n");
print(" "); print("common_layout: "); print(common_layout); print("\n");
print(" "); print("full_perm: "); print(full_perm); print("\n");
print(" "); print("Used vector: "); print(vec_elem); print("\n");
print(" "); print("Used threads: "); print(vec_thrs); print("\n");
print(" "); print("layout_vt: "); print(layout_vt); print("\n");
print(" "); print("src.compose(layout_vt): "); print(src.compose(layout_vt)); print("\n");
print(" "); print("dst.compose(layout_vt): "); print(dst.compose(layout_vt)); print("\n");
print(" "); print("src_v: "); print(src_v); print("\n");
print(" "); print("dst_v: "); print(dst_v); print("\n");
print(" "); print("recast<VecType const>(src_v): "); print(recast<VecType const>(src_v)); print("\n");
print(" "); print("recast<VecType const>(dst_v): "); print(recast<VecType const>(dst_v)); print("\n");
}
#ifdef __CUDA_ARCH__
__syncthreads();
#endif
#endif
// If we're using all threads (static) or the tid is in range (dynamic)
if (vec_thrs >= NumThreads or tid < vec_thrs) {
return copy_if(TrivialPredTensor{}, recast<VecType const>(src_v), recast<VecType>(dst_v));
}
}
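// A minimal usage sketch (names, shapes, and the pointers gptr/sptr are illustrative):
// a full thread block stages a 128x8 gmem tile into smem with up to 128-bit vectors.
//
//   Tensor gA = make_tensor(make_gmem_ptr(gptr), make_layout(make_shape(Int<128>{}, Int<8>{})));
//   Tensor sA = make_tensor(make_smem_ptr(sptr), make_layout(make_shape(Int<128>{}, Int<8>{})));
//   cooperative_copy<128, 128>(threadIdx.x, gA, sA);
//   __syncthreads();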
template <uint32_t NumThreads,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
cooperative_copy(uint32_t const& tid,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
constexpr uint32_t MaxVecBits = sizeof_bits_v<typename SrcEngine::value_type>;
return cooperative_copy<NumThreads, MaxVecBits>(tid, src, dst);
}
// Accept mutable temporaries
template <uint32_t NumThreads,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
cooperative_copy(uint32_t const& tid,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return cooperative_copy<NumThreads>(tid, src, dst);
}
// Accept mutable temporaries
template <uint32_t NumThreads,
uint32_t MaxVecBits,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
cooperative_copy(uint32_t const& tid,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return cooperative_copy<NumThreads, MaxVecBits>(tid, src, dst);
}
} // end namespace cute


@ -0,0 +1,326 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/algorithm/functional.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cute/tensor.hpp>
namespace cute
{
//
// Collective Shared-Memory GEMMs
//
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
class ALoadTransformOp, class BLoadTransformOp,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
cooperative_gemm(ThrMMA<Args...> const& thr_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC,
ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */,
BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */)
{
CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM
CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN
CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK
using TypeA = typename TA::value_type;
using TypeB = typename TB::value_type;
using TypeC = typename TC::value_type;
static_assert(is_same_v<decay_t<invoke_result_t<ALoadTransformOp, TypeA>>, TypeA>,
"ALoadTransformOp functor must accept and return value of type TA::value_type");
static_assert(is_same_v<decay_t<invoke_result_t<BLoadTransformOp, TypeB>>, TypeB>,
"BLoadTransformOp functor must accept and return value of type TB::value_type");
// Original, static size of the problem
auto M = size<0>(sC);
auto N = size<1>(sC);
auto K = size<1>(sA);
// Block size of the compute tile
auto BLK_M = tile_size<0>(thr_mma);
auto BLK_N = tile_size<1>(thr_mma);
auto BLK_K = tile_size<2>(thr_mma);
// Compute the "residues"
auto m_residue = M - BLK_M * (ceil_div(M, BLK_M) - Int<1>{}); // (0,BLK_M]
auto n_residue = N - BLK_N * (ceil_div(N, BLK_N) - Int<1>{}); // (0,BLK_N]
auto k_residue = K - BLK_K * (ceil_div(K, BLK_K) ); // (-BLK_K,0]
// Shift the origin so k_residue is zeroth tile
sA.data() = &sA(0,k_residue);
sB.data() = &sB(0,k_residue);
#if 0
if (thread0()) {
printf("%d in BLK_M (%d)\n", int(m_residue), int(BLK_M));
printf("%d in BLK_N (%d)\n", int(n_residue), int(BLK_N));
printf("%d in BLK_K (%d)\n", int(k_residue), int(BLK_K));
}
#endif
//
// MMA Partitioning
//
// Round the layout extents up to BLK_X
Tensor rounded_sA = sA.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(K, BLK_K) * BLK_K));
Tensor rounded_sB = sB.compose(make_shape(ceil_div(N, BLK_N) * BLK_N, ceil_div(K, BLK_K) * BLK_K));
Tensor rounded_sC = sC.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(N, BLK_N) * BLK_N));
#if 0
if (thread0()) {
print("rounded_sA: "); print(rounded_sA); print("\n");
print("rounded_sB: "); print(rounded_sB); print("\n");
print("rounded_sC: "); print(rounded_sC); print("\n");
}
#endif
// Partition the sA and sB tiles across the threads for the MMA
Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(rounded_sB); // (MMA,MMA_N,MMA_K)
Tensor tCsC = thr_mma.partition_C(rounded_sC); // (MMA,MMA_M,MMA_N)
// Create register tensors for the MMA to operate on
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N)
#if 0
if (thread0()) {
print("tCsA: "); print(tCsA); print("\n");
print("tCsB: "); print(tCsB); print("\n");
print("tCsC: "); print(tCsC); print("\n");
print("tCrA: "); print(tCrA); print("\n");
print("tCrB: "); print(tCrB); print("\n");
print("tCrC: "); print(tCrC); print("\n");
}
#endif
//
// PREDICATION
//
// Allocate the preds for only the MMA-mode of tCsA and tCsB
Tensor tCpA = make_tensor<bool>(size<0>(tCsA));
Tensor tCpB = make_tensor<bool>(size<0>(tCsB));
// Create coordinate tensors on a single compute block for predication
Tensor cA = make_identity_tensor(make_shape(BLK_M, BLK_K)); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor cB = make_identity_tensor(make_shape(BLK_N, BLK_K)); // (BLK_N,BLK_K) -> (blk_n,blk_k)
// Repeat partitioning with thr_mma
Tensor tCcA = thr_mma.partition_A(cA); // (MMA,1,1) -> (blk_m,blk_k)
Tensor tCcB = thr_mma.partition_B(cB); // (MMA,1,1) -> (blk_n,blk_k)
// Populate the m and n predicates
CUTE_UNROLL
for (int i = 0; i < size(tCpA); ++i) {
tCpA(i) = elem_less(get<0>(tCcA(i)), m_residue);
}
CUTE_UNROLL
for (int i = 0; i < size(tCpB); ++i) {
tCpB(i) = elem_less(get<0>(tCcB(i)), n_residue);
}
#if 0
printf("Thr %d: A(%d,%d):%d B(%d,%d):%d\n",
threadIdx.x,
int(get<0>(tCcA(0))), int(get<1>(tCcA(0))), int(tCpA(0)),
int(get<0>(tCcB(0))), int(get<1>(tCcB(0))), int(tCpB(0)));
#endif
//
// PREFETCH k_block = 0 (with k-predication)
//
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsA); ++i) { // Copy MMA_I
if (k_residue == 0 || get<1>(tCcA(i)) >= -k_residue) { // k_block = 0, predicated on k
CUTE_UNROLL
for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M, predicated on m
tCrA(i,m,0) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,0)) : TypeA{};
}
}
}
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I
if (k_residue == 0 || get<1>(tCcB(i)) >= -k_residue) { // k_block = 0, predicated on k
CUTE_UNROLL
for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N, predicated on n
tCrB(i,n,0) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,0)) : TypeB{};
}
}
}
//
// MAINLOOP
//
// Clear accumulators
clear(tCrC);
constexpr int K_BLOCK_MAX = size<2>(tCrA);
CUTE_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
{
// static-if load the next k_block. No k-predication required on these loads.
if (k_block < K_BLOCK_MAX-1)
{
// Load the next k_block
int k_next = k_block + 1;
CUTE_UNROLL
for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsA); ++i) { // Copy_if MMA_I predicated on m
tCrA(i,m,k_next) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{};
}
}
CUTE_UNROLL
for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I predicated on n
tCrB(i,n,k_next) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{};
}
}
}
// GEMM on k_block in registers
gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
}
//
// Epilogue
//
Tensor cC = make_identity_tensor(make_shape(BLK_M, BLK_N)); // (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor tCcC = thr_mma.partition_C(cC); // (MMA, 1, 1) -> (blk_m,blk_n)
const bool isBetaZero = (beta == Beta{});
// Custom axpby_if for now
CUTE_UNROLL
for (int m = 0; m < size<1>(tCsC); ++m)
{
CUTE_UNROLL
for (int n = 0; n < size<2>(tCsC); ++n)
{
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsC); ++i)
{
if ((m_residue == BLK_M || m < size<1>(tCrC)-1 || get<0>(tCcC(i)) < m_residue) &&
(n_residue == BLK_N || n < size<2>(tCrC)-1 || get<1>(tCcC(i)) < n_residue))
{
tCsC(i,m,n) = isBetaZero ? alpha * static_cast<TypeC>(tCrC(i,m,n)) : alpha * static_cast<TypeC>(tCrC(i,m,n)) + beta * static_cast<TypeC>(tCsC(i,m,n));
}
}
}
}
}
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
cooperative_gemm(ThrMMA<Args...> const& thr_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC)
{
cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */);
}
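// A minimal usage sketch inside a kernel (sA/sB/sC are illustrative rank-2 smem tensors
// of shape (M,K), (N,K), (M,N); the TiledMMA mirrors the 16x16x1 UniversalFMA used in the
// SGEMM example above):
//
//   TiledMMA tiled_mma = make_tiled_mma(UniversalFMA<float,float,float>{},
//                                       Layout<Shape<_16,_16,_1>>{});
//   ThrMMA thr_mma = tiled_mma.get_slice(threadIdx.x);
//   cooperative_gemm(thr_mma, 1.0f, sA, sB, 0.0f, sC);  // sC(m,n) = sum_k sA(m,k) * sB(n,k)
//   __syncthreads();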
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
class ALoadTransformOp, class BLoadTransformOp,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(ThrMMA<Args...> const& thr_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC,
ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */,
BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */)
{
cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op);
}
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(ThrMMA<Args...> const& thr_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC)
{
cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */);
}
} // end namespace cute


@ -145,10 +145,10 @@ copy_if(PrdTensor const& pred,
namespace detail {
// Trait that detects if atom's traits has a member function with(bool)
template<typename, typename Enable = void>
template <class, class Enable = void>
constexpr bool has_with_bool = false;
template<typename T>
template <class T>
constexpr bool has_with_bool<T, cute::void_t<decltype(declval<typename T::Traits>().with(declval<bool>()))>> = true;
} // end namespace detail


@ -33,6 +33,7 @@
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/numeric/complex.hpp>
/** C++14 <functional> extensions */
@ -46,7 +47,7 @@ struct identity {
template <class T>
CUTE_HOST_DEVICE constexpr
decltype(auto) operator()(T&& arg) const {
return std::forward<T>(arg);
return static_cast<T&&>(arg);
}
};
@ -69,7 +70,7 @@ struct constant_fn {
template <class T> \
CUTE_HOST_DEVICE constexpr \
decltype(auto) operator()(T&& arg) const { \
return OP std::forward<T>(arg); \
return OP static_cast<T&&>(arg); \
} \
}
#define CUTE_RIGHT_UNARY_OP(NAME,OP) \
@ -77,7 +78,7 @@ struct constant_fn {
template <class T> \
CUTE_HOST_DEVICE constexpr \
decltype(auto) operator()(T&& arg) const { \
return std::forward<T>(arg) OP ; \
return static_cast<T&&>(arg) OP ; \
} \
}
#define CUTE_NAMED_UNARY_OP(NAME,OP) \
@ -85,7 +86,7 @@ struct constant_fn {
template <class T> \
CUTE_HOST_DEVICE constexpr \
decltype(auto) operator()(T&& arg) const { \
return OP (std::forward<T>(arg)); \
return OP (static_cast<T&&>(arg)); \
} \
}
@ -115,7 +116,7 @@ struct shift_right_const {
template <class T>
CUTE_HOST_DEVICE constexpr
decltype(auto) operator()(T&& arg) const {
return std::forward<T>(arg) >> Shift;
return static_cast<T&&>(arg) >> Shift;
}
};
@ -126,7 +127,7 @@ struct shift_left_const {
template <class T>
CUTE_HOST_DEVICE constexpr
decltype(auto) operator()(T&& arg) const {
return std::forward<T>(arg) << Shift;
return static_cast<T&&>(arg) << Shift;
}
};
@ -139,7 +140,7 @@ struct shift_left_const {
template <class T, class U> \
CUTE_HOST_DEVICE constexpr \
decltype(auto) operator()(T&& lhs, U&& rhs) const { \
return std::forward<T>(lhs) OP std::forward<U>(rhs); \
return static_cast<T&&>(lhs) OP static_cast<U&&>(rhs); \
} \
}
#define CUTE_NAMED_BINARY_OP(NAME,OP) \
@ -147,7 +148,7 @@ struct shift_left_const {
template <class T, class U> \
CUTE_HOST_DEVICE constexpr \
decltype(auto) operator()(T&& lhs, U&& rhs) const { \
return OP (std::forward<T>(lhs), std::forward<U>(rhs)); \
return OP (static_cast<T&&>(lhs), static_cast<U&&>(rhs)); \
} \
}
@ -273,7 +274,7 @@ struct bound_fn {
CUTE_HOST_DEVICE constexpr
decltype(auto)
operator()(T&& arg) {
return fn_(arg_, std::forward<T>(arg));
return fn_(arg_, static_cast<T&&>(arg));
}
Fn fn_;


@ -252,7 +252,7 @@ gemm(MMA_Atom<MMA> const& mma,
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});
gemm(mma,
make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N)
make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K)
@ -451,6 +451,7 @@ gemm(MMA_Atom<MMA> const& mma,
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});
gemm(mma,
make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N)
make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K)
@ -496,245 +497,4 @@ gemm(MMA_Atom<MMA> const& mma,
}
}
//
// Collective Shared-Memory GEMMs
//
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
class ALoadTransformOp, class BLoadTransformOp,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(ThrMMA<Args...> const& thr_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC,
ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */,
BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */)
{
CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM
CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN
CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK
using TypeA = typename TA::value_type;
using TypeB = typename TB::value_type;
using TypeC = typename TC::value_type;
static_assert(is_same_v<decay_t<invoke_result_t<ALoadTransformOp, TypeA>>, TypeA>,
"ALoadTransformOp functor must accept and return value of type TA::value_type");
static_assert(is_same_v<decay_t<invoke_result_t<BLoadTransformOp, TypeB>>, TypeB>,
"BLoadTransformOp functor must accept and return value of type TB::value_type");
// Original, static size of the problem
auto M = size<0>(sC);
auto N = size<1>(sC);
auto K = size<1>(sA);
// Block size of the compute tile
auto BLK_M = tile_size<0>(thr_mma);
auto BLK_N = tile_size<1>(thr_mma);
auto BLK_K = tile_size<2>(thr_mma);
// Compute the "residues"
auto m_residue = M - BLK_M * (ceil_div(M, BLK_M) - Int<1>{}); // (0,BLK_M]
auto n_residue = N - BLK_N * (ceil_div(N, BLK_N) - Int<1>{}); // (0,BLK_N]
auto k_residue = K - BLK_K * (ceil_div(K, BLK_K) ); // (-BLK_K,0]
// Shift the origin so k_residue is zeroth tile
sA.data() = &sA(0,k_residue);
sB.data() = &sB(0,k_residue);
#if 0
if (thread0()) {
printf("%d in BLK_M (%d)\n", int(m_residue), int(BLK_M));
printf("%d in BLK_N (%d)\n", int(n_residue), int(BLK_N));
printf("%d in BLK_K (%d)\n", int(k_residue), int(BLK_K));
}
#endif
//
// MMA Partitioning
//
// Round the layout extents up to BLK_X
Tensor rounded_sA = sA.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(K, BLK_K) * BLK_K));
Tensor rounded_sB = sB.compose(make_shape(ceil_div(N, BLK_N) * BLK_N, ceil_div(K, BLK_K) * BLK_K));
Tensor rounded_sC = sC.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(N, BLK_N) * BLK_N));
#if 0
if (thread0()) {
print(rounded_sA.layout()); print("\n");
print(rounded_sB.layout()); print("\n");
print(rounded_sC.layout()); print("\n");
}
#endif
// Partition the sA and sB tiles across the threads for the MMA
Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(rounded_sB); // (MMA,MMA_N,MMA_K)
Tensor tCsC = thr_mma.partition_C(rounded_sC); // (MMA,MMA_M,MMA_N)
// Create register tensors for the MMA to operate on
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N)
#if 0
if (thread0()) {
print(tCsA.layout()); print("\n");
print(tCsB.layout()); print("\n");
print(tCsC.layout()); print("\n");
print(tCrA.layout()); print("\n");
print(tCrB.layout()); print("\n");
print(tCrC.layout()); print("\n");
}
#endif
//
// PREDICATION
//
// Allocate the preds for only the MMA-mode of tCsA and tCsB
Tensor tCpA = make_tensor<bool>(size<0>(tCsA));
Tensor tCpB = make_tensor<bool>(size<0>(tCsB));
// Create coordinate tensors on a single compute block for predication
Tensor cA = make_identity_tensor(make_shape(BLK_M, BLK_K)); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor cB = make_identity_tensor(make_shape(BLK_N, BLK_K)); // (BLK_M,BLK_K) -> (blk_n,blk_k)
// Repeat partitioning with thr_mma
Tensor tCcA = thr_mma.partition_A(cA); // (MMA,1,1) -> (blk_m,blk_k)
Tensor tCcB = thr_mma.partition_B(cB); // (MMA,1,1) -> (blk_n,blk_k)
// Populate the m and n predicates
CUTE_UNROLL
for (int i = 0; i < size(tCpA); ++i) {
tCpA(i) = elem_less(get<0>(tCcA(i)), m_residue);
}
CUTE_UNROLL
for (int i = 0; i < size(tCpB); ++i) {
tCpB(i) = elem_less(get<0>(tCcB(i)), n_residue);
}
#if 0
printf("Thr %d: A(%d,%d):%d B(%d,%d):%d\n",
threadIdx.x,
int(get<0>(tCcA(0))), int(get<1>(tCcA(0))), int(tCpA(0)),
int(get<0>(tCcB(0))), int(get<1>(tCcB(0))), int(tCpB(0)));
#endif
//
// PREFETCH k_block = 0 (with k-predication)
//
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsA); ++i) { // Copy MMA_I
if (k_residue == 0 || get<1>(tCcA(i)) >= -k_residue) { // k_block = 0, predicated on k
CUTE_UNROLL
for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M, predicated on m
tCrA(i,m,0) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,0)) : TypeA{};
}
}
}
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I
if (k_residue == 0 || get<1>(tCcB(i)) >= -k_residue) { // k_block = 0, predicated on k
CUTE_UNROLL
for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N, predicated on n
tCrB(i,n,0) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,0)) : TypeB{};
}
}
}
//
// MAINLOOP
//
// Clear accumulators
clear(tCrC);
constexpr int K_BLOCK_MAX = size<2>(tCrA);
CUTE_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
{
// static-if load the next k_block. No k-predication required on these loads.
if (k_block < K_BLOCK_MAX-1)
{
// Load the next k_block
int k_next = k_block + 1;
CUTE_UNROLL
for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsA); ++i) { // Copy_if MMA_I predicated on m
tCrA(i,m,k_next) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{};
}
}
CUTE_UNROLL
for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I predicated on n
tCrB(i,n,k_next) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{};
}
}
}
// GEMM on k_block in registers
gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
}
//
// Epilogue
//
Tensor cC = make_identity_tensor(make_shape(BLK_M, BLK_N)); // (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor tCcC = thr_mma.partition_C(cC); // (MMA, 1, 1) -> (blk_m,blk_n)
const bool isBetaZero = (beta == Beta{});
// Custom axpby_if for now
CUTE_UNROLL
for (int m = 0; m < size<1>(tCsC); ++m)
{
CUTE_UNROLL
for (int n = 0; n < size<2>(tCsC); ++n)
{
CUTE_UNROLL
for (int i = 0; i < size<0>(tCsC); ++i)
{
if ((m_residue == BLK_M || m < size<1>(tCrC)-1 || get<0>(tCcC(i)) < m_residue) &&
(n_residue == BLK_N || n < size<2>(tCrC)-1 || get<1>(tCcC(i)) < n_residue))
{
tCsC(i,m,n) = isBetaZero ? alpha * tCrC(i,m,n) : alpha * tCrC(i,m,n) + beta * tCsC(i,m,n);
}
}
}
}
}
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(ThrMMA<Args...> const& thr_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC)
{
gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */);
}
} // end namespace cute


@ -0,0 +1,153 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/tensor.hpp>
#include <cute/atom/copy_atom.hpp>
namespace cute
{
//
// Prefetch global tensors into L2
//
template <uint32_t NumThreads, uint32_t FetchBytes = 64,
class GEngine, class GLayout>
CUTE_HOST_DEVICE
void
cooperative_prefetch(uint32_t const& tid,
Tensor<GEngine, GLayout> const& src)
{
static_assert(is_gmem<GEngine>::value, "Expected global tensor for prefetch");
constexpr int V = decltype(max_common_vector(src, src))::value;
if constexpr (V > 1) {
// L2 sector is 32B, default fetch granularity is 64B
using VecType = conditional_t<(V * sizeof_bits_v<typename GEngine::value_type>) < (FetchBytes * 8),
ArrayEngine<typename GEngine::value_type, V>,
uint8_t[FetchBytes] >;
Tensor src_v = recast<VecType const>(src);
CUTE_UNROLL
for (int i = tid; i < size(src_v); i += NumThreads) {
prefetch(raw_pointer_cast(&src_v(i)));
}
} else {
CUTE_UNROLL
for (int i = tid; i < size(src); i += NumThreads) {
prefetch(raw_pointer_cast(&src(i)));
}
}
}
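// A minimal usage sketch (gA_next is an illustrative gmem tensor): all 128 threads of a
// block warm the L2 with the next tile before the mainloop consumes it.
//
//   cooperative_prefetch<128>(threadIdx.x, gA_next);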
template <class GEngine, class GLayout>
CUTE_HOST_DEVICE
void
prefetch(Tensor<GEngine, GLayout> const& src)
{
return cooperative_prefetch<1>(0, src);
}
// Prefetch with copy atom
namespace detail {
template <class CopyOp, class = void>
constexpr bool has_prefetch = false;
template <class CopyOp>
constexpr bool has_prefetch<CopyOp, void_t<typename CopyOp::PREFETCH>> = true;
template <class CopyOp, class = void>
constexpr bool is_prefetch = false;
template <class CopyOp>
constexpr bool is_prefetch<CopyOp, void_t<typename CopyOp::PREFETCH>> = is_same_v<CopyOp, typename CopyOp::PREFETCH>;
} // end namespace detail
template <class CopyOp, class... CT_Args, class... CA_Args,
class GEngine, class GLayout>
CUTE_HOST_DEVICE
void
prefetch(Copy_Atom<Copy_Traits<CopyOp, CT_Args...>, CA_Args...> const& atom,
Tensor<GEngine, GLayout> const& src)
{
if constexpr (detail::has_prefetch<CopyOp>) {
using Prefetch_Traits = Copy_Traits<typename CopyOp::PREFETCH, CT_Args...>;
using Prefetch_Atom = Copy_Atom<Prefetch_Traits, CA_Args...>;
Prefetch_Atom prefetch_atom{atom};
auto& dst = const_cast<Tensor<GEngine, GLayout>&>(src); // dst is ignored for prefetch atoms
return copy(prefetch_atom, src, dst);
} else {
return prefetch(src);
}
}
#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
template <class... CT_Args,
class SrcEngine, class SrcLayout>
CUTE_HOST_DEVICE
void
prefetch(Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...> const& atom,
Tensor<SrcEngine, SrcLayout> const& src)
{
using SrcType = typename SrcEngine::value_type;
static_assert(is_gmem<SrcEngine>::value, "Expected global tensor for L2 prefetch");
auto tiler = max_common_layout(src, src);
constexpr int vec_elem = decltype(size(tiler))::value;
constexpr int vec_bits = vec_elem * sizeof_bits_v<SrcType>;
static_assert(vec_bits >= 128, "Expected at least 128-bits for BLKCP");
// Construct a new concrete Atom of the vector size
auto bulk_atom = Copy_Atom<Copy_Traits<SM90_BULK_COPY_G2S, Int<vec_bits>>, SrcType>{};
return prefetch(bulk_atom, logical_divide(src, tiler));
}
// Backwards-compat. Throw out any extra Copy_Atom args.
template <class... CT_Args, class... CA_Args,
class SrcEngine, class SrcLayout>
CUTE_HOST_DEVICE
void
prefetch(Copy_Atom<Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...>, CA_Args...> const& atom,
Tensor<SrcEngine, SrcLayout> const& src)
{
return prefetch(static_cast<Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...> const&>(atom), src);
}
#endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
} // end namespace cute


@ -50,7 +50,7 @@ for_each(Tensor<Engine,Layout> const& tensor, UnaryOp&& op)
{
CUTE_UNROLL
for (int i = 0; i < size(tensor); ++i) {
static_cast<UnaryOp&&>(op)(tensor(i));
op(tensor(i));
}
}
@ -61,7 +61,7 @@ for_each(Tensor<Engine,Layout>& tensor, UnaryOp&& op)
{
CUTE_UNROLL
for (int i = 0; i < size(tensor); ++i) {
static_cast<UnaryOp&&>(op)(tensor(i));
op(tensor(i));
}
}
@ -71,7 +71,7 @@ CUTE_HOST_DEVICE constexpr
void
for_each(Tensor<Engine,Layout>&& tensor, UnaryOp&& op)
{
return for_each(tensor, static_cast<UnaryOp&&>(op));
return for_each(tensor, op);
}
//
@ -86,7 +86,7 @@ transform(Tensor<Engine,Layout>& tensor, UnaryOp&& op)
{
CUTE_UNROLL
for (int i = 0; i < size(tensor); ++i) {
tensor(i) = static_cast<UnaryOp&&>(op)(tensor(i));
tensor(i) = op(tensor(i));
}
}
@ -96,27 +96,34 @@ CUTE_HOST_DEVICE constexpr
void
transform(Tensor<Engine,Layout>&& tensor, UnaryOp&& op)
{
return transform(tensor, std::forward<UnaryOp>(op));
return transform(tensor, op);
}
// Similar to std::transform transforms one tensors and assigns it to another
template <class EngineIn, class LayoutIn, class EngineOut, class LayoutOut, class UnaryOp>
template <class EngineIn, class LayoutIn,
class EngineOut, class LayoutOut,
class UnaryOp>
CUTE_HOST_DEVICE constexpr
void
transform(Tensor<EngineIn,LayoutIn>& tensor_in, Tensor<EngineOut,LayoutOut>& tensor_out, UnaryOp&& op)
transform(Tensor<EngineIn, LayoutIn > const& tensor_in,
Tensor<EngineOut,LayoutOut> & tensor_out,
UnaryOp&& op)
{
CUTE_UNROLL
for (int i = 0; i < size(tensor_in); ++i) {
tensor_out(i) = static_cast<UnaryOp&&>(op)(tensor_in(i));
tensor_out(i) = op(tensor_in(i));
}
}
// Accept mutable temporaries
template <class EngineIn, class LayoutIn,
class EngineOut, class LayoutOut, class UnaryOp>
class EngineOut, class LayoutOut,
class UnaryOp>
CUTE_HOST_DEVICE constexpr
void
transform(Tensor<EngineIn,LayoutIn>&& tensor_in, Tensor<EngineOut,LayoutOut>&& tensor_out, UnaryOp&& op)
transform(Tensor<EngineIn, LayoutIn > const& tensor_in,
Tensor<EngineOut,LayoutOut> && tensor_out,
UnaryOp&& op)
{
return transform(tensor_in, tensor_out, op);
}
@ -127,29 +134,31 @@ transform(Tensor<EngineIn,LayoutIn>&& tensor_in, Tensor<EngineOut,LayoutOut>&& t
// assigns it to tensor_out
template <class EngineIn1, class LayoutIn1,
class EngineIn2, class LayoutIn2,
class EngineOut, class LayoutOut, class BinaryOp>
class EngineOut, class LayoutOut,
class BinaryOp>
CUTE_HOST_DEVICE constexpr
void
transform(Tensor<EngineIn1,LayoutIn1>& tensor_in1,
Tensor<EngineIn2,LayoutIn2>& tensor_in2,
Tensor<EngineOut,LayoutOut>& tensor_out,
transform(Tensor<EngineIn1,LayoutIn1> const& tensor_in1,
Tensor<EngineIn2,LayoutIn2> const& tensor_in2,
Tensor<EngineOut,LayoutOut> & tensor_out,
BinaryOp&& op)
{
CUTE_UNROLL
for (int i = 0; i < size(tensor_in1); ++i) {
tensor_out(i) = static_cast<BinaryOp&&>(op)(tensor_in1(i), tensor_in2(i));
tensor_out(i) = op(tensor_in1(i), tensor_in2(i));
}
}
// Accept mutable temporaries
template <class EngineIn1, class LayoutIn1,
class EngineIn2, class LayoutIn2,
class EngineOut, class LayoutOut, class BinaryOp>
class EngineOut, class LayoutOut,
class BinaryOp>
CUTE_HOST_DEVICE constexpr
void
transform(Tensor<EngineIn1,LayoutIn1>&& tensor_in1,
Tensor<EngineIn2,LayoutIn2>&& tensor_in2,
Tensor<EngineOut,LayoutOut>&& tensor_out,
transform(Tensor<EngineIn1,LayoutIn1> const& tensor_in1,
Tensor<EngineIn2,LayoutIn2> const& tensor_in2,
Tensor<EngineOut,LayoutOut> && tensor_out,
BinaryOp&& op)
{
return transform(tensor_in1, tensor_in2, tensor_out, op);


@ -204,36 +204,6 @@ for_each_leaf(T&& t, F&& f)
CUTE_GCC_UNREACHABLE;
}
//
// For Sequence
// (s, t, f) => (f(t[s_0]),f(t[s_1]),...,f(t[s_n]))
//
namespace detail {
template <int... I, class F>
CUTE_HOST_DEVICE constexpr
void
for_sequence(seq<I...> const&, F&& f) {
(f(Int<I>{}), ...);
}
}; // end namespace detail
template <int... I, class T, class F>
CUTE_HOST_DEVICE constexpr
void
for_sequence(seq<I...> const& s, T&& t, F&& f) {
detail::for_sequence(s, [&](auto&& i){ f(get<remove_cvref_t<decltype(i)>::value>(static_cast<T&&>(t))); });
}
template <int I, class T, class F>
CUTE_HOST_DEVICE constexpr
void
for_sequence(T&& t, F&& f) {
for_sequence(make_seq<I>{}, static_cast<T&&>(t), static_cast<F&&>(f));
}
//
// Transform
// (t, f) => (f(t_0),f(t_1),...,f(t_n))
@ -551,15 +521,15 @@ take(T const& t)
template <int... I, class T>
CUTE_HOST_DEVICE constexpr
auto
select(T const & t)
select(T const& t)
{
return cute::make_tuple(get<I>(t)...);
}
template <class T, typename Indices>
template <class T, class Indices>
CUTE_HOST_DEVICE constexpr
auto
select(T const & t, Indices const & indices)
select(T const& t, Indices const& indices)
{
if constexpr (is_tuple<Indices>::value) {
return cute::transform(indices, [&t](auto i) { return select(t, i); });
@ -655,7 +625,7 @@ flatten(T const& t)
namespace detail {
template<class FlatTuple, class TargetProfile>
template <class FlatTuple, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
@ -680,7 +650,7 @@ unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple)
// @post congruent(@a result, @a target_profile)
// @post flatten(@a result) == @a flat_tuple
template<class FlatTuple, class TargetProfile>
template <class FlatTuple, class TargetProfile>
CUTE_HOST_DEVICE constexpr
auto
unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile)
@ -865,6 +835,7 @@ append(T const& a, X const& x)
CUTE_GCC_UNREACHABLE;
}
template <class T, class X>
CUTE_HOST_DEVICE constexpr
auto
@ -902,6 +873,7 @@ prepend(T const& a, X const& x)
CUTE_GCC_UNREACHABLE;
}
template <class T, class X>
CUTE_HOST_DEVICE constexpr
auto
@ -1105,14 +1077,13 @@ zip2_by(T const& t, TG const& guide)
/// @return A tuple of the elements of @c t in reverse order.
template <class T>
CUTE_HOST_DEVICE constexpr auto
reverse(T const& t) {
CUTE_HOST_DEVICE constexpr
auto
reverse(T const& t)
{
if constexpr (is_tuple<T>::value) {
return detail::apply(t, [] (auto const&... a) {
return cute::make_tuple(a...);
}, tuple_rseq<T>{});
}
else {
return detail::apply(t, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_rseq<T>{});
} else {
return t;
}
}

View File

@ -49,7 +49,7 @@ CUTE_DEVICE void cluster_arrive_relaxed()
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : : );
#else
CUTE_RUNTIME_ASSERT("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined");
CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined");
#endif
}
@ -58,7 +58,7 @@ CUTE_DEVICE void cluster_arrive()
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
asm volatile("barrier.cluster.arrive.aligned;\n" : : );
#else
CUTE_RUNTIME_ASSERT("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined");
CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined");
#endif
}
@ -67,7 +67,7 @@ CUTE_DEVICE void cluster_wait()
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
asm volatile("barrier.cluster.wait.aligned;\n" : : );
#else
CUTE_RUNTIME_ASSERT("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined");
CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined");
#endif
}
@ -77,7 +77,7 @@ CUTE_DEVICE void cluster_sync()
cluster_arrive();
cluster_wait();
#else
CUTE_RUNTIME_ASSERT("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined");
CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined");
#endif
}
@ -94,7 +94,7 @@ CUTE_DEVICE dim3 cluster_grid_dims()
// MSVC requires protecting use of gridDim with __CUDA_ARCH__.
return gridDim;
#elif defined(_MSC_VER)
CUTE_RUNTIME_ASSERT("cluster_grid_dims() can only be called on device");
CUTE_INVALID_CONTROL_PATH("cluster_grid_dims() can only be called on device");
return {0, 0, 0};
#else
return {0, 0, 0};
@ -114,7 +114,7 @@ CUTE_DEVICE dim3 cluster_id_in_grid()
// MSVC requires protecting use of blockIdx with __CUDA_ARCH__.
return blockIdx;
#elif defined(_MSC_VER)
CUTE_RUNTIME_ASSERT("cluster_id_in_grid() can only be called on device");
CUTE_INVALID_CONTROL_PATH("cluster_id_in_grid() can only be called on device");
return {0, 0, 0};
#else
return {0, 0, 0};

View File

@ -33,7 +33,7 @@
#include <cute/config.hpp>
#include <cute/arch/util.hpp>
#include <cute/numeric/int.hpp>
#include <cute/numeric/numeric_types.hpp>
namespace cute
{
@ -89,4 +89,17 @@ using AutoVectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<8>;
// Alias
using DefaultCopy = AutoVectorizingCopy;
//
// Global memory prefetch into L2
//
CUTE_HOST_DEVICE static void
prefetch(void const* gmem_ptr)
{
#if defined(__CUDA_ARCH__)
asm volatile("prefetch.global.L2 [%0];\n" : : "l"(gmem_ptr) : "memory");
#endif
}
} // end namespace cute
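// Usage sketch (added for illustration; not part of this header): issuing L2 prefetches for
// an upcoming global-memory tile from device code. `gmem_tile`, `n_elems`, and the stride of
// 32 floats (one 128-byte line) are hypothetical example values.
CUTE_DEVICE void
prefetch_tile_example(float const* gmem_tile, int n_elems)
{
  for (int i = 0; i < n_elems; i += 32) {
    cute::prefetch(gmem_tile + i);  // emits prefetch.global.L2 for the enclosing cache line
  }
}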

View File

@ -78,7 +78,7 @@ struct SM75_U32x1_LDSM_N
: "=r"(dst)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
}
};
@ -98,7 +98,7 @@ struct SM75_U32x2_LDSM_N
: "=r"(dst0), "=r"(dst1)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
}
};
@ -118,7 +118,7 @@ struct SM75_U32x4_LDSM_N
: "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
}
};
@ -138,7 +138,7 @@ struct SM75_U16x2_LDSM_T
: "=r"(dst)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
}
};
@ -158,7 +158,7 @@ struct SM75_U16x4_LDSM_T
: "=r"(dst0), "=r"(dst1)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
}
};
@ -178,7 +178,7 @@ struct SM75_U16x8_LDSM_T
: "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
}
};

View File

@ -64,7 +64,7 @@ struct SM80_CP_ASYNC_CACHEALWAYS
"l"(gmem_ptr),
"n"(sizeof(TS)));
#else
CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled");
CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled");
#endif
}
};
@ -91,7 +91,7 @@ struct SM80_CP_ASYNC_CACHEGLOBAL
"l"(gmem_ptr),
"n"(sizeof(TS)));
#else
CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled");
CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled");
#endif
}
};
@ -121,7 +121,7 @@ struct SM80_CP_ASYNC_CACHEALWAYS_ZFILL
"n"(sizeof(TS)),
"r"(src_size));
#else
CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled");
CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled");
#endif
}
};
@ -151,7 +151,7 @@ struct SM80_CP_ASYNC_CACHEGLOBAL_ZFILL
"n"(sizeof(TS)),
"r"(src_size));
#else
CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled");
CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled");
#endif
}
};

View File

@ -63,7 +63,7 @@ struct SM90_U32x1_STSM_N
:: "r"(smem_int_ptr),
"r"(src));
#else
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
}
};
@ -83,7 +83,7 @@ struct SM90_U32x2_STSM_N
:: "r"(smem_int_ptr),
"r"(src0), "r"(src1));
#else
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
}
};
@ -103,7 +103,7 @@ struct SM90_U32x4_STSM_N
:: "r"(smem_int_ptr),
"r"(src0), "r"(src1), "r"(src2), "r"(src3));
#else
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
}
};
@ -123,7 +123,7 @@ struct SM90_U16x2_STSM_T
:: "r"(smem_int_ptr),
"r"(src));
#else
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
}
};
@ -143,7 +143,7 @@ struct SM90_U16x4_STSM_T
:: "r"(smem_int_ptr),
"r"(src0), "r"(src1));
#else
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
}
};
@ -163,7 +163,7 @@ struct SM90_U16x8_STSM_T
:: "r"(smem_int_ptr),
"r"(src0), "r"(src1), "r"(src2), "r"(src3));
#else
CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED.");
#endif
}
};

View File

@ -43,8 +43,7 @@
#include <cute/container/alignment.hpp>
#include <cute/container/bit_field.hpp>
#include <cute/container/array.hpp>
#include <cute/numeric/int.hpp> // to_Format<[u]intX>
#include <cute/numeric/half.hpp> // to_Format<half_t>
#include <cute/numeric/numeric_types.hpp>
namespace cute
{
@ -177,8 +176,10 @@ to_CUtensorMapSwizzle(SmemSwizzleBits const& t) {
#if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__)
using TmaDescriptor = CUtensorMap;
using Im2ColTmaDescriptor = CUtensorMap;
#else
using TmaDescriptor = struct alignas(64) { char bytes[128]; };
using Im2ColTmaDescriptor = struct alignas(64) { char bytes[128]; };
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Initiates a TensorMap Prefetch
@ -197,7 +198,7 @@ prefetch_tma_descriptor(TmaDescriptor const* desc_ptr)
: "l"(gmem_int_desc)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Trying to use TMA Descriptor Prefetch without CUTE_ARCH_TMA_SM90_ENABLED.");
CUTE_INVALID_CONTROL_PATH("Trying to use TMA Descriptor Prefetch without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}
@ -208,7 +209,7 @@ prefetch_tma_descriptor(TmaDescriptor const* desc_ptr)
// Replace tensor pointer directly in GMEM
CUTE_HOST_DEVICE
void
tma_descriptor_replace_addr_in_global_mem(TmaDescriptor const* desc_ptr,
tma_descriptor_replace_addr_in_global_mem(TmaDescriptor const* desc_ptr,
void const* const new_tensor_ptr)
{
#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
@ -218,14 +219,14 @@ tma_descriptor_replace_addr_in_global_mem(TmaDescriptor const* desc_ptr,
"tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;"
:: "l"(gmem_int_desc), "l"(new_desc_addr));
#else
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
#endif
}
// Replace tensor pointer by bringing the tensormap from GMEM into the shared memory
CUTE_HOST_DEVICE
void
tma_descriptor_replace_addr_in_shared_mem(TmaDescriptor& smem_desc,
tma_descriptor_replace_addr_in_shared_mem(TmaDescriptor& smem_desc,
void const* const new_tensor_ptr)
{
#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
@ -239,7 +240,7 @@ tma_descriptor_replace_addr_in_shared_mem(TmaDescriptor& smem_desc,
"tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;"
:: "l"(smem_int64_desc), "l"(new_desc_addr));
#else
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
#endif
}
@ -273,7 +274,7 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor
"tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;"
:: "l"(smem_int64_desc), "l"(prob_stride[2] >> 4));
#else
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
#endif
}
@ -292,7 +293,7 @@ tma_descriptor_cp_fence_release(TmaDescriptor const* gmem_desc_ptr, TmaDescripto
"tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;"
:: "l"(gmem_int_desc), "r"(smem_int_desc));
#else
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
#endif
}
@ -307,7 +308,7 @@ tma_descriptor_fence_release()
#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
asm volatile ("fence.proxy.tensormap::generic.release.gpu;");
#else
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
#endif
}
@ -332,7 +333,7 @@ tma_descriptor_fence_acquire(TmaDescriptor const* desc_ptr)
: "l"(gmem_int_desc), "l"(gmem_int_desc)
: "memory");
#else
CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3");
#endif
}

File diff suppressed because it is too large

View File

@ -58,7 +58,7 @@ struct SM61_DP4A
: "=r"(d)
: "r"(a), "r"(b), "r"(c));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM61_DP4A without CUTE_ARCH_MMA_SM61_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM61_DP4A without CUTE_ARCH_MMA_SM61_ENABLED");
#endif
}
};
@ -79,7 +79,7 @@ struct SM61_DP2A
: "=r"(d)
: "r"(a), "r"(b), "r"(c));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM61_DP2A without CUTE_ARCH_MMA_SM61_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM61_DP2A without CUTE_ARCH_MMA_SM61_ENABLED");
#endif
}
};

View File

@ -74,7 +74,7 @@ struct SM70_8x8x4_F16F16F16F16_TN
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_TN without CUTE_ARCH_MMA_SM70_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_TN without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
}
};
@ -106,7 +106,7 @@ struct SM70_8x8x4_F16F16F16F16_NT
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_NT without CUTE_ARCH_MMA_SM70_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_NT without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
}
};
@ -138,7 +138,7 @@ struct SM70_8x8x4_F16F16F16F16_NN
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_NN without CUTE_ARCH_MMA_SM70_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_NN without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
}
};
@ -170,7 +170,7 @@ struct SM70_8x8x4_F16F16F16F16_TT
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_TT without CUTE_ARCH_MMA_SM70_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_TT without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
}
};
@ -210,7 +210,7 @@ struct SM70_8x8x4_F32F16F16F32_TN
"f"(c0), "f"(c1), "f"(c2), "f"(c3),
"f"(c4), "f"(c5), "f"(c6), "f"(c7));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_TN without CUTE_ARCH_MMA_SM70_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_TN without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
}
};
@ -246,7 +246,7 @@ struct SM70_8x8x4_F32F16F16F32_NT
"f"(c0), "f"(c1), "f"(c2), "f"(c3),
"f"(c4), "f"(c5), "f"(c6), "f"(c7));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_NT without CUTE_ARCH_MMA_SM70_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_NT without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
}
};
@ -282,7 +282,7 @@ struct SM70_8x8x4_F32F16F16F32_NN
"f"(c0), "f"(c1), "f"(c2), "f"(c3),
"f"(c4), "f"(c5), "f"(c6), "f"(c7));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_NN without CUTE_ARCH_MMA_SM70_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_NN without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
}
};
@ -318,7 +318,7 @@ struct SM70_8x8x4_F32F16F16F32_TT
"f"(c0), "f"(c1), "f"(c2), "f"(c3),
"f"(c4), "f"(c5), "f"(c6), "f"(c7));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_TT without CUTE_ARCH_MMA_SM70_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_TT without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
}

View File

@ -74,7 +74,7 @@ struct SM75_16x8x8_F32F16F16F32_TN
"r"(b0),
"f"(c0), "f"(c1), "f"(c2), "f"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM75_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM75_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM75_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM75_ENABLED");
#endif
}
};
@ -110,7 +110,7 @@ struct SM75_8x8x16_S32S8S8S32_TN
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM75_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM75_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM75_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM75_ENABLED");
#endif
}
};

View File

@ -33,6 +33,7 @@
#include <cute/config.hpp>
#include <cute/arch/mma.hpp>
#include <cute/numeric/complex.hpp>
// Config
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
@ -80,7 +81,7 @@ struct SM80_16x8x8_F16F16F16F16_TN
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x8_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -113,7 +114,7 @@ struct SM80_16x8x16_F16F16F16F16_TN
"r"(b0), "r"(b1),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -146,7 +147,7 @@ struct SM80_16x8x8_F32F16F16F32_TN
"r"(b0),
"f"(c0), "f"(c1), "f"(c2), "f"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -179,7 +180,7 @@ struct SM80_16x8x16_F32F16F16F32_TN
"r"(b0), "r"(b1),
"f"(c0), "f"(c1), "f"(c2), "f"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -212,7 +213,7 @@ struct SM80_16x8x8_F32BF16BF16F32_TN
"r"(b0),
"f"(c0), "f"(c1), "f"(c2), "f"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x8_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -245,7 +246,7 @@ struct SM80_16x8x16_F32BF16BF16F32_TN
"r"(b0), "r"(b1),
"f"(c0), "f"(c1), "f"(c2), "f"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -278,7 +279,7 @@ struct SM80_16x8x4_F32TF32TF32F32_TN
"r"(b0),
"f"(c0), "f"(c1), "f"(c2), "f"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x4_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x4_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -311,7 +312,7 @@ struct SM80_16x8x8_F32TF32TF32F32_TN
"r"(b0), "r"(b1),
"f"(c0), "f"(c1), "f"(c2), "f"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x8_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -344,7 +345,7 @@ struct SM80_8x8x4_F64F64F64F64_TN
"d"(b0),
"d"(c0), "d"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -385,14 +386,14 @@ struct SM80_8x8x4_C64C64C64C64_TN
// d.real() = -a.imag() * b.imag() + d.real();
SM80_8x8x4_F64F64F64F64_TN::fma(
rd0, rd1,
rd0, rd1,
-a0.imag(),
b0.imag(),
d0.real(), d1.real());
// d.imag() = a.real() * b.imag() + d.imag();
SM80_8x8x4_F64F64F64F64_TN::fma(
id0, id1,
id0, id1,
a0.real(),
b0.imag(),
d0.imag(), d1.imag());
@ -412,15 +413,15 @@ struct SM80_8x8x4_GC64C64C64GC64_TN
{
struct GaussComplex {
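// Note (added for clarity; assumes the standard Gauss 3-multiply scheme implied by the
// conversion operator below): t0 = re(a)*re(b), t1 = im(a)*im(b),
// t2 = (re(a)+im(a))*(re(b)+im(b)), so the complex value is (t0 - t1) + i*(t2 - t0 - t1).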
double t0, t1, t2;
CUTE_HOST_DEVICE //constexpr
operator complex<double>() const { return complex<double>(t0 - t1, t2 - t0 - t1); }
CUTE_HOST_DEVICE friend //constexpr
complex<double> operator*(GaussComplex const& a, complex<double> const& b) { return static_cast<complex<double>>(a) * b; }
CUTE_HOST_DEVICE friend //constexpr
complex<double> operator*(complex<double> const& a, GaussComplex const& b) { return b * a; }
CUTE_HOST_DEVICE friend //constexpr
complex<double> operator+(GaussComplex const& a, complex<double> const& b) { return static_cast<complex<double>>(a) + b; }
CUTE_HOST_DEVICE friend //constexpr
@ -481,7 +482,7 @@ struct SM80_8x8x16_S32S8S8S32_TN
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -514,7 +515,7 @@ struct SM80_8x8x16_S32S8S8S32_TN_SATURATE
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -547,7 +548,7 @@ struct SM80_16x8x16_S32S8S8S32_TN
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -580,7 +581,7 @@ struct SM80_16x8x16_S32S8S8S32_TN_SATURATE
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -613,7 +614,7 @@ struct SM80_16x8x32_S32S8S8S32_TN
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -646,7 +647,7 @@ struct SM80_16x8x32_S32S8S8S32_TN_SATURATE
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -679,7 +680,7 @@ struct SM80_8x8x16_S32S8U8S32_TN
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -712,7 +713,7 @@ struct SM80_8x8x16_S32S8U8S32_TN_SATURATE
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -745,7 +746,7 @@ struct SM80_16x8x16_S32S8U8S32_TN
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -778,7 +779,7 @@ struct SM80_16x8x16_S32S8U8S32_TN_SATURATE
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -811,7 +812,7 @@ struct SM80_16x8x32_S32S8U8S32_TN
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -844,7 +845,7 @@ struct SM80_16x8x32_S32S8U8S32_TN_SATURATE
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -877,7 +878,7 @@ struct SM80_8x8x16_S32U8S8S32_TN
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -910,7 +911,7 @@ struct SM80_8x8x16_S32U8S8S32_TN_SATURATE
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -943,7 +944,7 @@ struct SM80_16x8x16_S32U8S8S32_TN
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -976,7 +977,7 @@ struct SM80_16x8x16_S32U8S8S32_TN_SATURATE
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1009,7 +1010,7 @@ struct SM80_16x8x32_S32U8S8S32_TN
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1042,7 +1043,7 @@ struct SM80_16x8x32_S32U8S8S32_TN_SATURATE
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1075,7 +1076,7 @@ struct SM80_8x8x16_S32U8U8S32_TN
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1108,7 +1109,7 @@ struct SM80_8x8x16_S32U8U8S32_TN_SATURATE
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1141,7 +1142,7 @@ struct SM80_16x8x16_S32U8U8S32_TN
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1174,7 +1175,7 @@ struct SM80_16x8x16_S32U8U8S32_TN_SATURATE
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1207,7 +1208,7 @@ struct SM80_16x8x32_S32U8U8S32_TN
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1240,7 +1241,7 @@ struct SM80_16x8x32_S32U8U8S32_TN_SATURATE
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1273,7 +1274,7 @@ struct SM80_8x8x32_S32S4S4S32_TN
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1306,7 +1307,7 @@ struct SM80_8x8x32_S32S4S4S32_TN_SATURATE
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1339,7 +1340,7 @@ struct SM80_16x8x32_S32S4S4S32_TN
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1372,7 +1373,7 @@ struct SM80_16x8x32_S32S4S4S32_TN_SATURATE
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1405,7 +1406,7 @@ struct SM80_16x8x64_S32S4S4S32_TN
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1438,7 +1439,7 @@ struct SM80_16x8x64_S32S4S4S32_TN_SATURATE
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1471,7 +1472,7 @@ struct SM80_8x8x32_S32S4U4S32_TN
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1504,7 +1505,7 @@ struct SM80_8x8x32_S32S4U4S32_TN_SATURATE
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1537,7 +1538,7 @@ struct SM80_16x8x32_S32S4U4S32_TN
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1570,7 +1571,7 @@ struct SM80_16x8x32_S32S4U4S32_TN_SATURATE
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1603,7 +1604,7 @@ struct SM80_16x8x64_S32S4U4S32_TN
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1636,7 +1637,7 @@ struct SM80_16x8x64_S32S4U4S32_TN_SATURATE
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1669,7 +1670,7 @@ struct SM80_8x8x32_S32U4S4S32_TN
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1702,7 +1703,7 @@ struct SM80_8x8x32_S32U4S4S32_TN_SATURATE
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1735,7 +1736,7 @@ struct SM80_16x8x32_S32U4S4S32_TN
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1768,7 +1769,7 @@ struct SM80_16x8x32_S32U4S4S32_TN_SATURATE
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1801,7 +1802,7 @@ struct SM80_16x8x64_S32U4S4S32_TN
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1834,7 +1835,7 @@ struct SM80_16x8x64_S32U4S4S32_TN_SATURATE
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1867,7 +1868,7 @@ struct SM80_8x8x32_S32U4U4S32_TN
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1900,7 +1901,7 @@ struct SM80_8x8x32_S32U4U4S32_TN_SATURATE
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1933,7 +1934,7 @@ struct SM80_16x8x32_S32U4U4S32_TN
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1966,7 +1967,7 @@ struct SM80_16x8x32_S32U4U4S32_TN_SATURATE
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -1999,7 +2000,7 @@ struct SM80_16x8x64_S32U4U4S32_TN
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -2032,7 +2033,7 @@ struct SM80_16x8x64_S32U4U4S32_TN_SATURATE
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -2067,7 +2068,7 @@ struct SM80_8x8x128_S32U1U1S32_TN_XORPOPC
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -2100,7 +2101,7 @@ struct SM80_16x8x128_S32U1U1S32_TN_XORPOPC
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
@ -2133,7 +2134,7 @@ struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x256_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};

View File

@ -73,7 +73,7 @@ struct SM90_16x8x4_F64F64F64F64_TN
"d"(b0),
"d"(c0), "d"(c1), "d"(c2), "d"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
#endif
}
};
@ -106,7 +106,7 @@ struct SM90_16x8x8_F64F64F64F64_TN
"d"(b0), "d"(b1),
"d"(c0), "d"(c1), "d"(c2), "d"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
#endif
}
};
@ -141,7 +141,7 @@ struct SM90_16x8x16_F64F64F64F64_TN
"d"(b0), "d"(b1), "d"(b2), "d"(b3),
"d"(c0), "d"(c1), "d"(c2), "d"(c3));
#else
CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
#endif
}
};
@ -364,37 +364,185 @@ ss_op_selector()
// FP16 accumulator
if constexpr (is_same_v<ElementC, half_t>) {
static_assert(is_same_v<ElementA, half_t>, "Element types for AB must be half if ElementC is half.");
static_assert(is_same_v<ElementB, half_t>, "Element types for AB must be half if ElementC is half.");
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
if constexpr (is_same_v<ElementA, half_t> && is_same_v<ElementB, half_t>) {
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
// Dispatch against the Tile N mode size
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
// Dispatch against the Tile N mode size
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
}
else {
// FP8
// Input A: float_e4m3_t ; Input B: float_e4m3_t
else if constexpr (is_same_v<ElementA, float_e4m3_t> && is_same_v<ElementB, float_e4m3_t>) {
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x32_F16E4M3E4M3_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x32_F16E4M3E4M3_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x32_F16E4M3E4M3_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x32_F16E4M3E4M3_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x32_F16E4M3E4M3_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x32_F16E4M3E4M3_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x32_F16E4M3E4M3_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x32_F16E4M3E4M3_SS_TN<Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
// FP8
// Input A: float_e4m3_t ; Input B: float_e5m2_t
else if constexpr (is_same_v<ElementA, float_e4m3_t> && is_same_v<ElementB, float_e5m2_t>) {
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x32_F16E4M3E5M2_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x32_F16E4M3E5M2_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x32_F16E4M3E5M2_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x32_F16E4M3E5M2_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x32_F16E4M3E5M2_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x32_F16E4M3E5M2_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x32_F16E4M3E5M2_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x32_F16E4M3E5M2_SS_TN<Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
// FP8
// Input A: float_e5m2_t ; Input B: float_e5m2_t
else if constexpr (is_same_v<ElementA, float_e5m2_t> && is_same_v<ElementB, float_e5m2_t>) {
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x32_F16E5M2E5M2_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x32_F16E5M2E5M2_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x32_F16E5M2E5M2_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x32_F16E5M2E5M2_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x32_F16E5M2E5M2_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x32_F16E5M2E5M2_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x32_F16E5M2E5M2_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x32_F16E5M2E5M2_SS_TN<Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
// FP8
// Input A: float_e5m2_t ; Input B: float_e4m3_t
else if constexpr (is_same_v<ElementA, float_e5m2_t> && is_same_v<ElementB, float_e4m3_t>) {
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
if constexpr (Tile_N % 256 == 0) {
return SM90_64x256x32_F16E5M2E4M3_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 192 == 0) {
return SM90_64x192x32_F16E5M2E4M3_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 128 == 0) {
return SM90_64x128x32_F16E5M2E4M3_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 96 == 0) {
return SM90_64x96x32_F16E5M2E4M3_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 64 == 0) {
return SM90_64x64x32_F16E5M2E4M3_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 32 == 0) {
return SM90_64x32x32_F16E5M2E4M3_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 16 == 0) {
return SM90_64x16x32_F16E5M2E4M3_SS_TN<Args...>{};
}
else if constexpr (Tile_N % 8 == 0) {
return SM90_64x8x32_F16E5M2E4M3_SS_TN<Args...>{};
}
else {
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
}
}
else {
static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for the requested configuration.");
}
}
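// Usage sketch (illustrative only, not part of this file): resolving a smem x smem GMMA
// atom through ss_op_selector() and wrapping it in a TiledMMA. The element types and the
// 128x128x64 tile below are example values; with them, the FP16-accumulator branch above
// selects SM90_64x128x16_F16F16F16_SS.
//
//   using TileShape_MNK = cute::Shape<cute::_128, cute::_128, cute::_64>;
//   auto mma_op    = cute::GMMA::ss_op_selector<cute::half_t, cute::half_t, cute::half_t,
//                                               TileShape_MNK>();
//   auto tiled_mma = cute::make_tiled_mma(mma_op);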

File diff suppressed because it is too large

View File

@ -81,7 +81,7 @@ struct Copy_Atom<Copy_Traits<Args...>, CopyInternalType>
CUTE_HOST_DEVICE
auto
with(TraitsArgs&&... args) const {
auto traits = Traits::with(std::forward<TraitsArgs>(args)...);
auto traits = Traits::with(static_cast<TraitsArgs&&>(args)...);
return Copy_Atom<decltype(traits), CopyInternalType>{traits};
}
@ -351,7 +351,7 @@ struct ThrCopy
partition_S(STensor&& stensor) const {
//static_assert(sizeof(typename remove_cvref_t<STensor>::value_type) == sizeof(typename TiledCopy::ValType),
// "Expected ValType for tiling SrcTensor.");
auto thr_tensor = make_tensor(std::forward<STensor>(stensor).data(), TiledCopy::tidfrg_S(stensor.layout()));
auto thr_tensor = make_tensor(static_cast<STensor&&>(stensor).data(), TiledCopy::tidfrg_S(stensor.layout()));
return thr_tensor(thr_idx_, _, repeat<rank_v<STensor>>(_));
}
@ -361,7 +361,7 @@ struct ThrCopy
partition_D(DTensor&& dtensor) const {
//static_assert(sizeof(typename remove_cvref_t<DTensor>::value_type) == sizeof(typename TiledCopy::ValType),
// "Expected ValType for tiling DstTensor.");
auto thr_tensor = make_tensor(std::forward<DTensor>(dtensor).data(), TiledCopy::tidfrg_D(dtensor.layout()));
auto thr_tensor = make_tensor(static_cast<DTensor&&>(dtensor).data(), TiledCopy::tidfrg_D(dtensor.layout()));
return thr_tensor(thr_idx_, _, repeat<rank_v<DTensor>>(_));
}
@ -371,7 +371,7 @@ struct ThrCopy
retile_S(STensor&& stensor) {
// static_assert(sizeof(typename remove_cvref_t<STensor>::value_type) == sizeof(typename TiledCopy::ValType),
// "Expected ValType for tiling SrcTensor.");
return make_tensor(std::forward<STensor>(stensor).data(), TiledCopy::retile(stensor.layout()));
return make_tensor(static_cast<STensor&&>(stensor).data(), TiledCopy::retile(stensor.layout()));
}
template <class DTensor>
@ -380,7 +380,7 @@ struct ThrCopy
retile_D(DTensor&& dtensor) {
// static_assert(sizeof(typename remove_cvref_t<DTensor>::value_type) == sizeof(typename TiledCopy::ValType),
// "Expected ValType for tiling DstTensor.");
return make_tensor(std::forward<DTensor>(dtensor).data(), TiledCopy::retile(dtensor.layout()));
return make_tensor(static_cast<DTensor&&>(dtensor).data(), TiledCopy::retile(dtensor.layout()));
}
};

View File

@ -94,17 +94,55 @@ struct Copy_Traits<AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>>
namespace detail {
// Utility for exploding pointers, arrays, or tensors into Operation::copy
template <class Operation,
class PtrS, int... Is,
class PtrD, int... Id>
class PtrSrc, int... Is,
class PtrDst, int... Id>
CUTE_HOST_DEVICE constexpr
void
copy_explode(PtrS&& s, int_sequence<Is...>,
PtrD&& d, int_sequence<Id...>)
copy_explode_index(PtrSrc&& s, int_sequence<Is...>,
PtrDst&& d, int_sequence<Id...>)
{
return Operation::copy(s[Is]..., d[Id]...);
}
// Utility for exploding tuples into ::copy
template <class Operation,
class TupleArg, int... I>
CUTE_HOST_DEVICE constexpr
void
copy_explode(TupleArg&& t, int_sequence<I...>)
{
return Operation::copy(get<I>(static_cast<TupleArg&&>(t))...);
}
template <class Operation,
class TupleSrc, int... Is,
class TupleDst, int... Id>
CUTE_HOST_DEVICE constexpr
void
copy_explode(TupleSrc&& s, int_sequence<Is...>,
TupleDst&& d, int_sequence<Id...>)
{
return Operation::copy(get<Is>(static_cast<TupleSrc&&>(s))...,
get<Id>(static_cast<TupleDst&&>(d))...);
}
template <class Operation,
class TupleAux, int... Ia,
class TupleSrc, int... Is,
class TupleDst, int... Id>
CUTE_HOST_DEVICE constexpr
void
copy_explode(TupleAux&& a, int_sequence<Ia...>,
TupleSrc&& s, int_sequence<Is...>,
TupleDst&& d, int_sequence<Id...>)
{
return Operation::copy(get<Ia>(static_cast<TupleAux&&>(a))...,
get<Is>(static_cast<TupleSrc&&>(s))...,
get<Id>(static_cast<TupleDst&&>(d))...);
}
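// For example (illustration only): with a source tuple `s` of two elements and a destination
// tuple `d` of one element,
//   copy_explode<Op>(s, make_int_sequence<2>{}, d, make_int_sequence<1>{})
// expands to the single call
//   Op::copy(get<0>(s), get<1>(s), get<0>(d));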
} // end namespace detail
//
@ -139,8 +177,8 @@ copy_unpack(Copy_Traits<CopyOp,Args...> const&,
CUTE_STATIC_ASSERT_V(size(rD) == Int<RegNumDst>{},
"Copy_Traits: dst failed to vectorize into registers. Layout is incompatible with this CopyOp.");
detail::copy_explode<CopyOp>(rS, make_int_sequence<RegNumSrc>{},
rD, make_int_sequence<RegNumDst>{});
detail::copy_explode_index<CopyOp>(rS, make_int_sequence<RegNumSrc>{},
rD, make_int_sequence<RegNumDst>{});
}
//

View File

@ -0,0 +1,879 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
/*! \file
\brief im2col make_tma_copy
*/
#include "cute/arch/copy_sm90.hpp"
#include "cute/arch/copy_sm90_desc.hpp"
#include "cute/tensor.hpp"
#include "cute/algorithm/prefetch.hpp"
namespace cute
{
// Utility for unpacking TMA_LOAD_IM2COL arguments into a CopyOp
template <class CopyOp>
struct TMA_LOAD_IM2COL_Unpack
{
/// Copy from src to dst.
///
/// @param traits Copy traits created with a TMA descriptor that
/// correctly matches the input tensor and other convolution
/// parameters.
///
/// @param src Tile of the im2col-transformed coordinate tensor
/// (result of get_tma_tensor), representing the global-memory
/// tensor from which to load.
///
/// @param dst Shared memory tile, into which to load.
template <class... Args,
class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Copy_Traits<CopyOp, Args...> const& traits,
Tensor<TS,SLayout> const& src, // tile of the transformed global activation (A) tensor
Tensor<TD,DLayout> & dst) // shared memory tile
{
auto src_coord_offset = src(Int<0>{});
auto src_coord_cwhdn_offset_srt = flatten(src_coord_offset);
// Interpret the TMA IM2COL coordinate as (c, ([w,h,d]), n, ([s,r,t]))
CUTE_STATIC_ASSERT_V(rank(src_coord_offset) == _4{});
CUTE_STATIC_ASSERT_V(rank<1>(src_coord_offset) == rank<3>(src_coord_offset));
if constexpr (detail::is_prefetch<CopyOp>) {
return detail::copy_explode<CopyOp>(traits.opargs_, tuple_seq<decltype(traits.opargs_)>{},
src_coord_cwhdn_offset_srt, tuple_seq<decltype(src_coord_cwhdn_offset_srt)>{});
} else {
static_assert(is_smem<TD>::value, "SM90_TMA_LOAD_IM2COL requires the destination be shared memory.");
void* dst_ptr = cute::raw_pointer_cast(dst.data());
return detail::copy_explode<CopyOp>(traits.opargs_, tuple_seq<decltype(traits.opargs_)>{},
make_tuple(dst_ptr), seq<0>{},
src_coord_cwhdn_offset_srt, tuple_seq<decltype(src_coord_cwhdn_offset_srt)>{});
}
}
};
// Copy_Traits for SM90 im2col TMA load comes in two layers.
//
// 1. Copy_Traits<SM90_TMA_LOAD_IM2COL>
// 2. Copy_Traits<SM90_TMA_LOAD_IM2COL_OP>
//
// Copy_Traits<SM90_TMA_LOAD_IM2COL>
// is the "outer" layer. It has a TMA descriptor,
// but no barrier ("tma_mbar"), so it's "nonexecutable."
// One calls its "with" member function with a barrier,
// to get an executable "inner"-layer
// Copy_Traits<SM90_TMA_LOAD_IM2COL_OP> object.
// That object's "copy_unpack" member function
// actually invokes im2col TMA load.
struct SM90_TMA_LOAD_IM2COL_OP : SM90_TMA_LOAD_IM2COL {};
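// Usage sketch of the two-layer design (illustrative only; `tma_atom`, `tma_mbar`, and the
// partitioned tensors tAgA/tAsA are hypothetical names from a typical producer mainloop):
//
//   auto exec = tma_atom.with(tma_mbar);       // outer (nonexecutable) -> inner (executable)
//   copy(exec, tAgA(_,_,k), tAsA(_,_,k));      // copy_unpack now issues the im2col TMA load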
/// @brief Non-executable specialization of Copy_Traits for SM90
/// im2col TMA load, with TMA descriptor but no barrier.
///
/// Use `.with(memory_barrier)` to construct an executable version.
template <class NumBitsPerTMA, class TMATensor>
struct Copy_Traits<SM90_TMA_LOAD_IM2COL, NumBitsPerTMA, TMATensor>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
Im2ColTmaDescriptor tma_desc_;
TMATensor tma_tensor_;
CUTE_HOST_DEVICE constexpr
Im2ColTmaDescriptor const*
get_tma_descriptor() const
{
return &tma_desc_;
}
template <class GShape>
CUTE_HOST_DEVICE constexpr
TMATensor const
get_tma_tensor(GShape const&) const
{
return tma_tensor_;
}
/// @brief Get an executable specialization.
///
/// Copy_Traits specializations with SM90_TMA_LOAD_IM2COL are not
/// directly executable. Instead, call this "with" member function
/// to get an executable specialization. "Executable" means that
/// @c copy_unpack works.
///
/// @param tma_mbar Memory barrier for synchronization
///
/// @param multicast_mask Multicast mask (unused; only exists
/// for interface compatibility with the actual multicast Copy_Traits)
///
/// @return Executable specialization of @c Copy_Traits
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_LOAD_IM2COL_OP, NumBitsPerTMA>
with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const
{
return {{}, {&tma_desc_, &tma_mbar}};
}
// Copy_Traits specializations with SM90_TMA_LOAD_IM2COL
// are not directly executable. Instead, call .with
// to get an executable specialization.
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst) = delete;
};
/// @brief Executable specialization of Copy_Traits for SM90 im2col
/// TMA load, with TMA descriptor and barrier.
template <class NumBitsPerTMA>
struct Copy_Traits<SM90_TMA_LOAD_IM2COL_OP, NumBitsPerTMA>
: TMA_LOAD_IM2COL_Unpack<SM90_TMA_LOAD_IM2COL_OP>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM90_TMA_LOAD_IM2COL arguments
tuple<
Im2ColTmaDescriptor const*,
uint64_t* // smem mbarrier
> const opargs_;
};
template <class NumBitsPerTMA, class... Args>
struct Copy_Traits<SM90_TMA_LOAD_IM2COL::PREFETCH, NumBitsPerTMA, Args...>
: TMA_LOAD_IM2COL_Unpack<SM90_TMA_LOAD_IM2COL::PREFETCH>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM90_TMA_LOAD_IM2COL::PREFETCH arguments
tuple<Im2ColTmaDescriptor const*> const opargs_;
CUTE_HOST_DEVICE
Copy_Traits(Copy_Traits<SM90_TMA_LOAD_IM2COL, NumBitsPerTMA, Args...> const& traits)
: opargs_({&traits.tma_desc_}) {}
};
//////////////////////////////////////////////////////////////////////////////
///////////////////////////// TMA_LOAD_MULTICAST /////////////////////////////
//////////////////////////////////////////////////////////////////////////////
struct SM90_TMA_LOAD_IM2COL_MULTICAST_OP : SM90_TMA_LOAD_IM2COL_MULTICAST {};
/// @brief Non-executable specialization of Copy_Traits for SM90
/// im2col TMA load, with TMA descriptor but no barrier or multicast
/// mask.
///
/// Use `.with(memory_barrier)` to construct an executable version.
template <class NumBitsPerTMA, class TMATensor>
struct Copy_Traits<SM90_TMA_LOAD_IM2COL_MULTICAST, NumBitsPerTMA, TMATensor>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
Im2ColTmaDescriptor tma_desc_;
TMATensor tma_tensor_;
CUTE_HOST_DEVICE constexpr
Im2ColTmaDescriptor const*
get_tma_descriptor() const {
return &tma_desc_;
}
template <class GShape>
CUTE_HOST_DEVICE constexpr
TMATensor const
get_tma_tensor(GShape const&) const
{
return tma_tensor_;
}
/// @brief Get an executable specialization.
///
/// Copy_Traits specializations with SM90_TMA_LOAD_IM2COL_MULTICAST
/// are not directly executable. Instead, call this "with" member
/// function to get an executable specialization. "Executable"
/// means that @c copy_unpack works.
///
/// @param tma_mbar Memory barrier for synchronization
///
/// @param multicast_mask Multicast mask (defaults to a single CTA)
///
/// @return Executable specialization of @c Copy_Traits
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_LOAD_IM2COL_MULTICAST_OP, NumBitsPerTMA>
with(uint64_t& tma_mbar, uint16_t const& multicast_mask) const {
return {{}, {&tma_desc_, &tma_mbar, multicast_mask}};
}
// Copy_Traits specializations with SM90_TMA_LOAD_IM2COL_MULTICAST
// are not directly executable. Instead, call .with to get an
// executable specialization.
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst) = delete;
};
/// @brief Executable specialization of Copy_Traits for SM90 multicast
/// im2col TMA load, with TMA descriptor, barrier, and multicast mask.
template <class NumBitsPerTMA>
struct Copy_Traits<SM90_TMA_LOAD_IM2COL_MULTICAST_OP, NumBitsPerTMA>
: TMA_LOAD_IM2COL_Unpack<SM90_TMA_LOAD_IM2COL_MULTICAST_OP>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit.
using SrcLayout = Layout<Shape<_1, NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM90_TMA_LOAD_IM2COL_MULTICAST arguments
tuple<
Im2ColTmaDescriptor const*,
uint64_t*, // smem mbarrier
uint16_t // multicast mask
> const opargs_;
};
//////////////////////////////////////////////////////////////////////////////
///////////////////////////// TMA_STORE IM2COL////////////////////////////////
//////////////////////////////////////////////////////////////////////////////
// The executable SM90_TMA_STORE_IM2COL with tma_desc
template <class NumBitsPerTMA, class TMATensor>
struct Copy_Traits<SM90_TMA_STORE_IM2COL, NumBitsPerTMA, TMATensor>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM90_TMA_STORE_IM2COL arguments
Im2ColTmaDescriptor tma_desc_;
TMATensor tma_tensor_;
// Return TmaDescriptor/TensorMap
CUTE_HOST_DEVICE constexpr
Im2ColTmaDescriptor const*
get_tma_descriptor() const {
return &tma_desc_;
}
template <class GShape>
CUTE_HOST_DEVICE constexpr
TMATensor const
get_tma_tensor(GShape const&) const
{
return tma_tensor_;
}
// This is the copy_unpack dispatch for this Copy_Traits
// Src needs to be a smem tensor
// Dst needs to be a gmem tensor with TmaCoordIterator .data()
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst)
{
static_assert(is_smem<TS>::value, "Expected smem src for SM90_TMA_STORE_IM2COL");
void const* const desc_ptr = &(traits.tma_desc_);
void const* const src_ptr = cute::raw_pointer_cast(src.data());
auto dst_coord = flatten(take<0,3>(dst(Int<0>{})));
return detail::copy_explode<SM90_TMA_STORE_IM2COL>(make_tuple(desc_ptr, src_ptr), seq<0,1>{},
dst_coord, tuple_seq<decltype(dst_coord)>{});
}
};
namespace detail {
/// @brief Creates a TMA descriptor for im2col TMA load.
///
/// @param tensor_cwhdn Global activation tensor (A matrix of Fprop).
/// This is the original (not im2col-transformed) tensor in global
/// memory.
///
/// @param slayout Rank 2 (M,K) shared memory layout of the activation
/// tensor. Here, K is "GEMM K," not the filter tensor's mode of
/// the same name.
///
/// @param traversal_stride Traversal strides convolution parameter
///
/// Each of padding_shape, traversal_stride, and dilation_shape is a
/// tuple whose size is the number of spatial modes (e.g., 3 for a 5-D
/// convolution).
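///
/// For example (hypothetical values), a 2-D convolution over an NHWC
/// activation has two spatial modes, so the traversal stride and dilation
/// might each be passed as cute::make_tuple(1, 1).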
///
/// @return TMA descriptor for im2col TMA load
template <class EngineA, class LayoutA,
class SmemSwizzle, class TMALayout,
class LowerCornerStride,
class UpperCornerStride,
class LowerPaddingStride,
class UpperPaddingStride,
class TraversalStride,
class LowerSRTStride,
class DilationStride>
CUTE_HOST
auto
make_im2col_tma_copy_desc(
Tensor<EngineA, LayoutA> const& tensor_cwhdn, // (C,W,H,D,N)
uint32_t range_c, // TILE_C
uint32_t range_whdn, // TILE_WHDN
SmemSwizzle const& smem_swizzle, // Swizzle
TMALayout const& tma_layout_vt, // TMA layout
LowerCornerStride const& lower_corner_whd, // WHD offset of the "base pointer"
UpperCornerStride const& upper_corner_whd, // WHD upper corner
LowerPaddingStride const& lower_padding_whd, // WHD lower padding
UpperPaddingStride const& upper_padding_whd, // WHD upper padding
TraversalStride const& stride_whd, // WHD traversal stride
LowerSRTStride const& lower_srt, // SRT offset of the "base pointer"
DilationStride const& stride_srt) // SRT stride - dilation
{
static_assert(is_gmem<EngineA>::value, "Tensor must point to GPU global memory.");
using value_type = typename EngineA::value_type;
constexpr uint32_t num_total_modes = LayoutA::rank;
constexpr int num_spatial_modes = num_total_modes - 2;
// Gmem starting address
void* gmem_address = (void*) raw_pointer_cast(tensor_cwhdn.data());
// Gmem extents are just the tensor shape
cute::array<uint64_t, 5> gmem_prob_shape = {1,1,1,1,1};
for_each(make_seq<num_total_modes>{}, [&](auto i) {
gmem_prob_shape[i] = static_cast<uint64_t>(shape<i>(tensor_cwhdn));
});
// Gmem strides are byte strides of the activation tensor in CWHDN order
cute::array<uint64_t, 5> gmem_prob_stride = {0,0,0,0,0};
for_each(make_seq<num_total_modes>{}, [&](auto i) {
gmem_prob_stride[i] = sizeof(value_type) * stride<i>(tensor_cwhdn);
});
// TMA traversal strides for the spatial (WHD) modes come from the
// convolution's traversal stride parameter (stride_whd).
cute::array<uint32_t, 5> tma_traversal_strides = {1,1,1,1,1};
for_each(make_seq<num_spatial_modes>{}, [&](auto i) {
tma_traversal_strides[i+1] = static_cast<uint32_t>(get<i>(stride_whd));
});
cute::array<int32_t, num_spatial_modes> tma_lower_corner{};
for_each(make_seq<num_spatial_modes>{}, [&](auto i) {
tma_lower_corner[i] = static_cast<int32_t>(get<i>(lower_corner_whd));
});
cute::array<int32_t, num_spatial_modes> tma_upper_corner{};
for_each(make_seq<num_spatial_modes>{}, [&](auto i) {
tma_upper_corner[i] = static_cast<int32_t>(get<i>(upper_corner_whd));
});
Im2ColTmaDescriptor tma_desc;
#if (__CUDACC_VER_MAJOR__ >= 12)
CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType<value_type>();
CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE;
CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE;
CUtensorMapFloatOOBfill tma_oob_fill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
CUtensorMapSwizzle tma_swizzle = TMA::to_CUtensorMapSwizzle(detail::get_tma_swizzle_bits(smem_swizzle));
CUresult encode_result = cuTensorMapEncodeIm2col(
&tma_desc,
tma_format,
num_total_modes,
gmem_address,
gmem_prob_shape.data(),
gmem_prob_stride.data() + 1, // gmem_prob_stride[0] implicitly sizeof(value_type)
tma_lower_corner.data(),
tma_upper_corner.data(),
range_c,
range_whdn,
tma_traversal_strides.data(),
tma_interleave,
tma_swizzle,
tma_l2Promotion,
tma_oob_fill);
// The extra asserts help indicate the error's cause.
assert(encode_result != CUDA_ERROR_DEINITIALIZED);
assert(encode_result != CUDA_ERROR_NOT_INITIALIZED);
assert(encode_result != CUDA_ERROR_INVALID_CONTEXT);
assert(encode_result != CUDA_ERROR_INVALID_VALUE);
assert(encode_result == CUDA_SUCCESS);
#endif // (__CUDACC_VER_MAJOR__ >= 12)
//
// Calculate gemm shapes and linearized shapes based on tma layout tiling.
//
// Compute [w, h, d, n]
// q/p/z = (w/h/d + (upper_corner_whd - lower_corner_whd - 1)) / stride_whd + 1
auto gemm_mn_ = cute::transform(cute::make_seq<num_spatial_modes>{}, [&](auto i) {
return (shape<i+1>(tensor_cwhdn) + get<i>(upper_corner_whd) - get<i>(lower_corner_whd) - Int<1>{}) / get<i>(stride_whd) + Int<1>{};
});
auto gemm_mn = append(gemm_mn_, shape<num_spatial_modes+1>(tensor_cwhdn));
// Compute [c, s, r, t]
// stride_srt > 0 (fprop/wgrad): s/r/t = 1 + (upper_padding_whd - upper_corner_whd) / stride_srt
// stride_srt < 0 (dgrad):       s/r/t = 1 + (lower_corner_whd - lower_padding_whd) / stride_srt
auto gemm_k_ = cute::transform(cute::make_seq<num_spatial_modes>{}, [&](auto i) {
auto padding_size = conditional_return(get<i>(stride_srt) > Int<0>{},
get<i>(upper_padding_whd) - get<i>(upper_corner_whd),
get<i>(lower_corner_whd) - get<i>(lower_padding_whd));
return Int<1>{} + padding_size / get<i>(stride_srt);
});
auto gemm_k = prepend(gemm_k_, shape<0>(tensor_cwhdn));
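// Worked example (hypothetical 1-D values): with w = 8,
// lower_corner_whd = (-1), upper_corner_whd = (-1), stride_whd = (1),
// the spatial extent above gives q = (8 + (-1) - (-1) - 1)/1 + 1 = 8;
// with upper_padding_whd = (1) and stride_srt = (1), the filter extent
// gives s = 1 + (1 - (-1))/1 = 3.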
// For fprop/dgrad kernel, gemm_shapes is ((q, p, z, n), (c, s, r, t))
// For wgrad kernel, gemm_shapes is ((c, s, r, t), (q, p, z, n))
auto gemm_shapes_common = make_shape(gemm_mn, gemm_k);
auto gemm_shapes = make_shape(
basis_get(stride<0,1>(tma_layout_vt), gemm_shapes_common),
basis_get(stride<0,0>(tma_layout_vt), gemm_shapes_common));
// For fprop/dgrad kernel, linearized shapes are (whdn, (c, s, r, t))
// For wgrad kernel, linearized shapes are ((c, s, r, t), whdn)
auto linear_shapes_common = make_shape(size(gemm_mn), gemm_k);
auto linear_shapes = make_shape(
basis_get(stride<0,1>(tma_layout_vt), linear_shapes_common),
basis_get(stride<0,0>(tma_layout_vt), linear_shapes_common));
//
// Calculate gmem basis stride based on tma layout tiling.
//
auto tma_basis_scale = make_shape(Int<1>{}, stride_whd, Int<1>{}, stride_srt);
auto tma_basis = elem_scale(tma_basis_scale, make_basis_like(tma_basis_scale));
auto gbasis_strides_common = make_stride(
append(get<1>(tma_basis), get<2>(tma_basis)),
prepend(get<3>(tma_basis), get<0>(tma_basis))); // ((w,h,d,n),(c,s,r,t))
auto gbasis_strides = make_stride(
basis_get(stride<0,1>(tma_layout_vt), gbasis_strides_common),
basis_get(stride<0,0>(tma_layout_vt), gbasis_strides_common));
//
// Create tma tensor
//
auto lower_corner = make_arithmetic_tuple(Int<0>{}, lower_corner_whd, Int<0>{}, lower_srt);
auto tensor_multimode = make_tensor(ArithmeticTupleIterator(lower_corner), gemm_shapes, gbasis_strides);
auto tensor_linear = make_identity_tensor(linear_shapes);
auto tma_tensor = make_tensor(tensor_multimode.data(), composition(
tensor_multimode.layout(),
tensor_linear(Int<0>{}),
tensor_linear.layout()));
return cute::make_tuple(tma_desc, tma_tensor);
}
/// Make a TiledCopy for im2col TMA load.
///
/// @param copy_op The copy implementation: either
/// SM90_TMA_LOAD_IM2COL or SM90_TMA_LOAD_IM2COL_MULTICAST.
///
/// @param tensor_cwhdn The global tensor to use for im2col TMA loads.
/// For Fprop convolutions, this is the activation tensor. This is
/// the "original tensor that points to global memory, not the
/// coordinate (im2col-transformed) tensor.
///
/// @param slayout Layout of shared memory tile.
///
/// @param stride_whd The traversal strides convolution
/// parameter.
///
/// @return TiledCopy specialization for im2col TMA loads.
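///
/// @note The public make_im2col_tma_copy overloads below construct
/// cta_v_map as
/// make_identity_layout(product_each(shape(tensor_cwhdn))).compose(cta_tiler)
/// and cta_t_map as make_layout(multicast_size).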
template <class CopyOp,
class GEngine, class GLayout,
class SLayout,
class TShape, class TStride,
class VShape, class VStride,
class LowerCornerStride,
class UpperCornerStride,
class LowerPaddingStride,
class UpperPaddingStride,
class TraversalStride,
class LowerSRTStride,
class DilationStride>
CUTE_HOST_RTC
auto
make_tma_copy_im2col(CopyOp const& copy_op,
Tensor<GEngine,GLayout> const& gtensor,
SLayout const& slayout,
Layout<TShape,TStride> const& cta_t_map, // CTA tid -> logical TMA tid
Layout<VShape,VStride> const& cta_v_map, // CTA vid -> gmem coord
LowerCornerStride const& lower_corner_whd,
UpperCornerStride const& upper_corner_whd,
LowerPaddingStride const& lower_padding_whd,
UpperPaddingStride const& upper_padding_whd,
TraversalStride const& stride_whd, // traversal stride
LowerSRTStride const& lower_srt,
DilationStride const& stride_srt) // dilation
{
//
// TMA parameter checking
//
CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)),
"TMA requires CTA_Tile and SLayout top-level shape equivalence.");
CUTE_STATIC_ASSERT_V(size(slayout) % cosize(cta_t_map) == Int<0>{},
"Number of active CTAs in TMA must divide domain size of slayout.");
//
// TMA slayout manipulation
//
// Invert the smem to get the largest contiguous vector in the smem layout
auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout));
// trunc_smem_idx -> trunc_smem_coord
// Map from smem idx to a gmem mode
auto sidx_to_gmode = coalesce(composition(cta_v_map, inv_smem_layout));
#if 0
print("g_layout : "); print(gtensor.layout()); print("\n");
print("s_layout : "); print(slayout); print("\n");
print("cta_t_map : "); print(cta_t_map); print("\n");
print("cta_v_map : "); print(cta_v_map); print("\n");
print("inv_smem : "); print(inv_smem_layout); print("\n");
print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n");
#endif
//
// TMA gtensor manipulation
//
// Generate a TupleBasis for the gtensor
auto glayout_basis = make_identity_layout(product_each(shape(gtensor)));
// Tile the modes of gtensor with the truncated cta_v_map o inv_smem_layout_trunc
auto tma_layout_full = flatten(composition(glayout_basis, sidx_to_gmode));
// Truncate any incompatibilities -- no starting in the middle of gmodes
auto smem_rank = find_if(stride(tma_layout_full), [](auto e) {
[[maybe_unused]] auto v = basis_value(e);
return not is_constant<1,decltype(v)>{};
});
static_assert(smem_rank >= 2, "IM2COL expects at least 2 modes of the smem to vectorize with gmem.");
// IM2COL uses a maximum of 2 modes
constexpr int smem_tma_rank = cute::min(int(smem_rank), 2);
// Keep only the static-1 basis modes into gmem
auto tma_layout_trunc = take<0,smem_tma_rank>(tma_layout_full);
// Split according to the portion each multicast CTA will be responsible for
auto tma_layout_vt = logical_divide(tma_layout_trunc, shape_div(size(tma_layout_trunc), cosize(cta_t_map)));
#if 0
print("glayout_basis : "); print(glayout_basis); print("\n");
print("tma_layout_full : "); print(tma_layout_full); print("\n");
print("tma_layout_trunc: "); print(tma_layout_trunc); print("\n");
print("tma_layout_vt : "); print(tma_layout_vt); print("\n");
#endif
auto range_c = size<0,0>(tma_layout_vt);
auto range_whdn = size<0,1>(tma_layout_vt);
Tensor gtensor_cwhdn = make_tensor(gtensor.data(),
flatten(make_layout(basis_get(stride<0,0>(tma_layout_vt), gtensor.layout()),
basis_get(stride<0,1>(tma_layout_vt), gtensor.layout()))));
auto [tma_desc, tma_tensor] = make_im2col_tma_copy_desc(
gtensor_cwhdn,
range_c,
range_whdn,
detail::get_swizzle_portion(slayout),
tma_layout_vt,
lower_corner_whd,
upper_corner_whd,
lower_padding_whd,
upper_padding_whd,
stride_whd,
lower_srt,
stride_srt);
//
// Construct the Copy_Traits
//
using T = typename GEngine::value_type;
constexpr int num_bits_per_tma = decltype(size<0>(tma_layout_vt))::value * sizeof(T) * 8;
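// For example (hypothetical tile): 1024 elements of half_t issued per TMA
// instruction give num_bits_per_tma = 1024 * 2 * 8 = 16384.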
using Traits = Copy_Traits<CopyOp, cute::C<num_bits_per_tma>, decltype(tma_tensor)>;
#if 0
print("num_bits : "); print(NumBitsPerTMA{}); print("\n");
#endif
Traits tma_traits{tma_desc, tma_tensor};
//
// Construct the TiledCopy
//
auto cta_tiler = product_each(shape(cta_v_map));
// (CTA V, CTA T) -> smem_coord
auto layout_vt = composition(inv_smem_layout, make_layout(shape(tma_layout_vt)));
// Scale that up to cover all of the smem_coords
//
// The smem vector might not cover all of the tile,
// so multiply it up to cover the entire tile.
// "T" here (the parallel index) is a CTA index.
auto layout_VT = tile_to_shape(layout_vt, make_shape(size(cta_v_map)/size<1>(layout_vt), size<1>(layout_vt)));
// Flip it and change the domain of the T from logical thr to thr_idx
auto layout_TV = make_layout(composition(layout<1>(layout_VT), cta_t_map), layout<0>(layout_VT));
#if 0
print("cta_tiler : "); print(cta_tiler); print("\n");
print("layout_VT : "); print(layout_VT); print("\n");
print("layout_TV : "); print(layout_TV); print("\n");
#endif
return TiledCopy<Copy_Atom<Traits,T>, decltype(layout_TV), decltype(cta_tiler)>{tma_traits};
}
/// Make a TiledCopy for im2col TMA with no offsets.
/// E.g. im2col TMA load for C and im2col TMA store for D.
template <class CopyOp,
class GEngine, class GLayout,
class SLayout,
class TShape, class TStride,
class VShape, class VStride>
CUTE_HOST_RTC
auto
make_tma_copy_im2col(CopyOp const& copy_op,
Tensor<GEngine,GLayout> const& gtensor,
SLayout const& slayout,
Layout<TShape,TStride> const& cta_t_map, // CTA tid -> logical TMA tid
Layout<VShape,VStride> const& cta_v_map) // CTA vid -> gmem coord
{
constexpr int num_spatial_modes = rank<0>(GLayout{}) - 1;
return make_tma_copy_im2col(copy_op, gtensor, slayout, cta_t_map, cta_v_map,
append<num_spatial_modes>(Stride<_0>{}, Int<0>{}), // lower_corner_whd
append<num_spatial_modes>(Stride<_0>{}, Int<0>{}), // upper_corner_whd
append<num_spatial_modes>(Stride<_0>{}, Int<0>{}), // lower_padding_whd
append<num_spatial_modes>(Stride<_0>{}, Int<0>{}), // upper_padding_whd
append<num_spatial_modes>(Stride<_1>{}, Int<1>{}), // stride_whd
append<num_spatial_modes>(Stride<_0>{}, Int<0>{}), // lower_srt
append<num_spatial_modes>(Stride<_1>{}, Int<1>{})); // stride_srt
}
} // namespace detail
template <class CopyOp,
class Engine0, class Layout0,
class SLayout,
class CTATiler,
class MulticastSize,
class LowerCornerStride,
class UpperCornerStride,
class LowerPaddingStride,
class UpperPaddingStride,
class TraversalStride,
class LowerSRTStride,
class DilationStride>
CUTE_HOST_RTC
auto
make_im2col_tma_copy(CopyOp const& copy_op,
Tensor<Engine0, Layout0> const& tensor_cwhdn,
SLayout const& slayout,
CTATiler const& cta_tiler,
MulticastSize const& multicast_size,
LowerCornerStride const& lower_corner_whd,
UpperCornerStride const& upper_corner_whd,
LowerPaddingStride const& lower_padding_whd,
UpperPaddingStride const& upper_padding_whd,
TraversalStride const& stride_whd,
LowerSRTStride const& lower_srt,
DilationStride const& stride_srt)
{
auto cta_v_tile = make_identity_layout(product_each(shape(tensor_cwhdn))).compose(cta_tiler);
auto cta_t_tile = make_layout(multicast_size);
return detail::make_tma_copy_im2col(copy_op, tensor_cwhdn,
slayout, cta_t_tile, cta_v_tile,
lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, lower_srt, stride_srt);
}
// Explicit default for multicast_size
template <class CopyOp,
class Engine0, class Layout0,
class SLayout,
class CTATiler,
class LowerCornerStride,
class UpperCornerStride,
class LowerPaddingStride,
class UpperPaddingStride,
class TraversalStride,
class LowerSRTStride,
class DilationStride>
CUTE_HOST_RTC
auto
make_im2col_tma_copy(CopyOp const& copy_op,
Tensor<Engine0, Layout0> const& tensor_cwhdn,
SLayout const& slayout,
CTATiler const& cta_tiler,
LowerCornerStride const& lower_corner_whd,
UpperCornerStride const& upper_corner_whd,
LowerPaddingStride const& lower_padding_whd,
UpperPaddingStride const& upper_padding_whd,
TraversalStride const& stride_whd,
LowerSRTStride const& lower_srt,
DilationStride const& stride_srt)
{
return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, cta_tiler, Int<1>{},
lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, lower_srt, stride_srt);
}
// Explicit default for cta_tiler and multicast_size
template <class CopyOp,
class Engine0, class Layout0,
class SLayout,
class LowerCornerStride,
class UpperCornerStride,
class LowerPaddingStride,
class UpperPaddingStride,
class TraversalStride,
class LowerSRTStride,
class DilationStride>
CUTE_HOST_RTC
auto
make_im2col_tma_copy(CopyOp const& copy_op,
Tensor<Engine0, Layout0> const& tensor_cwhdn,
SLayout const& slayout,
LowerCornerStride const& lower_corner_whd,
UpperCornerStride const& upper_corner_whd,
LowerPaddingStride const& lower_padding_whd,
UpperPaddingStride const& upper_padding_whd,
TraversalStride const& stride_whd,
LowerSRTStride const& lower_srt,
DilationStride const& stride_srt)
{
return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, product_each(shape(slayout)), Int<1>{},
lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, lower_srt, stride_srt);
}
// No offsets copy.
template <class CopyOp,
class Engine0, class Layout0,
class SLayout,
class CTATiler,
class MulticastSize>
CUTE_HOST_RTC
auto
make_im2col_tma_copy(CopyOp const& copy_op,
Tensor<Engine0, Layout0> const& tensor_cwhdn,
SLayout const& slayout,
CTATiler const& cta_tiler,
MulticastSize const& multicast_size)
{
auto cta_v_tile = make_identity_layout(product_each(shape(tensor_cwhdn))).compose(cta_tiler);
auto cta_t_tile = make_layout(multicast_size);
return detail::make_tma_copy_im2col(copy_op, tensor_cwhdn, slayout, cta_t_tile, cta_v_tile);
}
// Explicit default for multicast_size
template <class CopyOp,
class Engine0, class Layout0,
class SLayout,
class CTATiler>
CUTE_HOST_RTC
auto
make_im2col_tma_copy(CopyOp const& copy_op,
Tensor<Engine0, Layout0> const& tensor_cwhdn,
SLayout const& slayout,
CTATiler const& cta_tiler)
{
return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, cta_tiler, Int<1>{});
}
// Explicit default for cta_tiler and multicast_size
template <class CopyOp,
class Engine0, class Layout0,
class SLayout>
CUTE_HOST_RTC
auto
make_im2col_tma_copy(CopyOp const& copy_op,
Tensor<Engine0, Layout0> const& tensor_cwhdn,
SLayout const& slayout)
{
return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, product_each(shape(slayout)), Int<1>{});
}
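//
// Illustrative host-side sketch (not part of this header; names, shapes,
// and strides below are hypothetical). Building an im2col TMA load atom
// for an activation tensor laid out as (C,W,H,N):
//
//   Tensor gA = make_tensor(make_gmem_ptr(ptr_A),
//                           make_shape(C, W, H, N), stride_A);
//   auto tma_load_a = make_im2col_tma_copy(
//       SM90_TMA_LOAD_IM2COL{}, gA, smem_layout_A,
//       product_each(shape(smem_layout_A)), Int<1>{},
//       lower_corner_whd, upper_corner_whd,
//       lower_padding_whd, upper_padding_whd,
//       traversal_stride_whd, lower_srt, dilation_srt);
//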
} // namespace cute


@ -38,6 +38,8 @@
#include <cute/atom/copy_traits.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cute/algorithm/prefetch.hpp>
#include <cute/numeric/integral_ratio.hpp>
namespace cute
@ -53,77 +55,55 @@ struct AuxTmaParams {
static_assert(is_static<TmaSwizzle>::value);
};
// Utility for unpacking TMA_LOAD arguments into a CopyOp
template <class CopyOp>
struct TMA_LOAD_Unpack
{
template <class... Args,
class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Copy_Traits<CopyOp, Args...> const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst)
{
auto src_coord = src.data().coord_;
if constexpr (detail::is_prefetch<CopyOp>) {
return detail::copy_explode<CopyOp>(traits.opargs_, tuple_seq<decltype(traits.opargs_)>{},
src_coord, tuple_seq<decltype(src_coord)>{});
} else {
static_assert(is_smem<TD>::value, "SM90_TMA_LOAD requires the destination be shared memory.");
void* dst_ptr = cute::raw_pointer_cast(dst.data());
#if 0
auto [c0,c1,c2,c3,c4] = append<5>(src_coord, 0);
printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n",
threadIdx.x, threadIdx.y, threadIdx.z,
blockIdx.x, blockIdx.y, blockIdx.z,
int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr);
#endif
return detail::copy_explode<CopyOp>(traits.opargs_, tuple_seq<decltype(traits.opargs_)>{},
make_tuple(dst_ptr), seq<0>{},
src_coord, tuple_seq<decltype(src_coord)>{});
}
}
};
//////////////////////////////////////////////////////////////////////////////
///////////////////////////// TMA_LOAD ///////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////
struct SM90_TMA_LOAD_OP : SM90_TMA_LOAD {};
// The executable SM90_TMA_LOAD with tma_desc and tma_mbar
template <class NumBitsPerTMA>
struct Copy_Traits<SM90_TMA_LOAD_OP, NumBitsPerTMA>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM90_TMA_LOAD arguments
TmaDescriptor const& tma_desc_;
uint64_t& tma_load_mbar_;
template <class Coord, int... Is>
CUTE_HOST_DEVICE constexpr
void
copy_unpack_(void const* const dst_ptr,
Coord const& src_coord, seq<Is...>) const
{
#if 0
auto [c0,c1,c2,c3,c4] = append<5>(src_coord, 0);
printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n",
threadIdx.x, threadIdx.y, threadIdx.z,
blockIdx.x, blockIdx.y, blockIdx.z,
int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr);
#endif
SM90_TMA_LOAD::copy(&tma_desc_, tma_load_mbar_,
dst_ptr, get<Is>(src_coord)...);
}
// This is the copy_unpack dispatch for this Copy_Traits
// Src needs to be a gmem tensor with TmaCoordIterator .data()
// Dst needs to be a smem tensor
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr
void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst)
{
static_assert(is_smem<TD>::value, "Expected smem dst for SM90_TMA_LOAD");
traits.copy_unpack_(cute::raw_pointer_cast(dst.data()), src.data().coord_, tuple_seq<decltype(src.data().coord_)>{});
}
};
// The non-executable SM90_TMA_LOAD with tma_desc and no tma_mbar
// Use .with(tma_mbar) to construct an executable version
template <class NumBitsPerTMA, class AuxParams_>
struct Copy_Traits<SM90_TMA_LOAD, NumBitsPerTMA, AuxParams_>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
@ -144,7 +124,7 @@ struct Copy_Traits<SM90_TMA_LOAD, NumBitsPerTMA, AuxParams_>
Copy_Traits<SM90_TMA_LOAD_OP, NumBitsPerTMA>
with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const {
// We accept multicast_mask here to keep the API for both atoms consistent
return {tma_desc_, tma_mbar};
return {{}, {&tma_desc_, &tma_mbar}};
}
// Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
@ -152,7 +132,7 @@ struct Copy_Traits<SM90_TMA_LOAD, NumBitsPerTMA, AuxParams_>
Copy_Traits<SM90_TMA_LOAD_OP, NumBitsPerTMA>
with(TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const {
// We accept multicast_mask here to keep the API for both atoms consistent
return {*new_tma_desc, tma_mbar};
return {{}, {new_tma_desc, &tma_mbar}};
}
// Generate the TMA coord tensor
@ -173,72 +153,65 @@ struct Copy_Traits<SM90_TMA_LOAD, NumBitsPerTMA, AuxParams_>
Tensor<TD,DLayout> & dst) = delete;
};
// The executable SM90_TMA_LOAD with tma_desc and tma_mbar
template <class NumBitsPerTMA>
struct Copy_Traits<SM90_TMA_LOAD_OP, NumBitsPerTMA>
: TMA_LOAD_Unpack<SM90_TMA_LOAD_OP>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM90_TMA_LOAD arguments
tuple<
TmaDescriptor const*,
uint64_t* // smem mbarrier
> const opargs_;
};
// The prefetch for SM90_TMA_LOAD with tma_desc
template <class NumBitsPerTMA, class... Args>
struct Copy_Traits<SM90_TMA_LOAD::PREFETCH, NumBitsPerTMA, Args...>
: TMA_LOAD_Unpack<SM90_TMA_LOAD::PREFETCH>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM90_TMA_LOAD::PREFETCH arguments
tuple<TmaDescriptor const*> const opargs_;
// Construct with any other Traits' TMA Desc
template <class... CopyArgs>
CUTE_HOST_DEVICE
Copy_Traits(Copy_Traits<CopyArgs...> const& traits)
: opargs_({&traits.tma_desc_}) {}
};
//////////////////////////////////////////////////////////////////////////////
///////////////////////////// TMA_LOAD_MULTICAST /////////////////////////////
//////////////////////////////////////////////////////////////////////////////
struct SM90_TMA_LOAD_MULTICAST_OP : SM90_TMA_LOAD_MULTICAST {};
template <class NumBitsPerTMA>
struct Copy_Traits<SM90_TMA_LOAD_MULTICAST_OP, NumBitsPerTMA>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM90_TMA_LOAD_MULTICAST arguments
TmaDescriptor const& tma_desc_;
uint64_t& tma_load_mbar_;
uint16_t const& multicast_mask_;
template <class Coord, int... Is>
CUTE_HOST_DEVICE constexpr
void
copy_unpack_(void const* const dst_ptr,
Coord const& src_coord, seq<Is...>) const
{
#if 0
auto [c0,c1,c2,c3,c4] = append<5>(src_coord, 0);
printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n",
threadIdx.x, threadIdx.y, threadIdx.z,
blockIdx.x, blockIdx.y, blockIdx.z,
int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr);
#endif
SM90_TMA_LOAD_MULTICAST::copy(&tma_desc_, tma_load_mbar_, multicast_mask_,
dst_ptr, get<Is>(src_coord)...);
}
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr
void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst)
{
static_assert(is_smem<TD>::value, "Expected smem dst for SM90_TMA_LOAD_MULTICAST");
traits.copy_unpack_(cute::raw_pointer_cast(dst.data()), src.data().coord_, tuple_seq<decltype(src.data().coord_)>{});
}
};
// The non-executable SM90_TMA_LOAD_MULTICAST with tma_desc and no tma_mbar
// Use .with(tma_mbar, multicast_mask) to construct an executable version
template <class NumBitsPerTMA, class AuxParams_>
struct Copy_Traits<SM90_TMA_LOAD_MULTICAST, NumBitsPerTMA, AuxParams_>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
@ -258,14 +231,14 @@ struct Copy_Traits<SM90_TMA_LOAD_MULTICAST, NumBitsPerTMA, AuxParams_>
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_LOAD_MULTICAST_OP, NumBitsPerTMA>
with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const {
return {tma_desc_, tma_load_mbar, multicast_mask};
return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask}};
}
// Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_LOAD_MULTICAST_OP, NumBitsPerTMA>
with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const {
return {*new_tma_desc, tma_load_mbar, multicast_mask};
return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask}};
}
// Generate the TMA coord tensor
@ -286,6 +259,27 @@ struct Copy_Traits<SM90_TMA_LOAD_MULTICAST, NumBitsPerTMA, AuxParams_>
Tensor<TD,DLayout> & dst) = delete;
};
// The executable SM90_TMA_LOAD_MULTICAST with tma_desc and tma_mbar and multicast_mask
template <class NumBitsPerTMA>
struct Copy_Traits<SM90_TMA_LOAD_MULTICAST_OP, NumBitsPerTMA>
: TMA_LOAD_Unpack<SM90_TMA_LOAD_MULTICAST_OP>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM90_TMA_LOAD_MULTICAST arguments
tuple<
TmaDescriptor const*,
uint64_t*, // smem mbarrier
uint16_t // multicast mask
> const opargs_;
};
//////////////////////////////////////////////////////////////////////////////
///////////////////////////// TMA_STORE //////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////
@ -293,6 +287,68 @@ struct Copy_Traits<SM90_TMA_LOAD_MULTICAST, NumBitsPerTMA, AuxParams_>
// The executable SM90_TMA_STORE with tma_desc
template <class NumBitsPerTMA, class AuxParams_>
struct Copy_Traits<SM90_TMA_STORE, NumBitsPerTMA, AuxParams_>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM90_TMA_STORE arguments
TmaDescriptor tma_desc_;
using AuxParams = AuxParams_;
AuxParams aux_params_;
// Return TmaDescriptor/TensorMap
CUTE_HOST_DEVICE constexpr
TmaDescriptor const*
get_tma_descriptor() const {
return &tma_desc_;
}
// Generate the TMA coord tensor
template <class GShape>
CUTE_HOST_DEVICE constexpr
auto
get_tma_tensor(GShape const& g_shape) const {
static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);
return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_));
}
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst)
{
static_assert(is_smem<TS>::value, "Expected smem src for SM90_TMA_STORE");
//static_assert(is_gmem<TD>::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor
void const* const desc_ptr = &(traits.tma_desc_);
void const* const src_ptr = cute::raw_pointer_cast(src.data());
auto dst_coord = dst.data().coord_;
#if 0
auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0);
printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n",
threadIdx.x, threadIdx.y, threadIdx.z,
blockIdx.x, blockIdx.y, blockIdx.z,
int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr);
#endif
return detail::copy_explode<SM90_TMA_STORE>(make_tuple(desc_ptr, src_ptr), seq<0,1>{},
dst_coord, tuple_seq<decltype(dst_coord)>{});
}
};
//////////////////////////////////////////////////////////////////////////////
///////////////////////////// TMA_REDUCE_ADD //////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////
// The executable SM90_TMA_REDUCE_ADD with tma_desc
template <class NumBitsPerTMA, class AuxParams_>
struct Copy_Traits<SM90_TMA_REDUCE_ADD, NumBitsPerTMA, AuxParams_>
{
using ThrID = Layout<_1>;
@ -304,7 +360,7 @@ struct Copy_Traits<SM90_TMA_STORE, NumBitsPerTMA, AuxParams_>
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM90_TMA_STORE arguments
// SM90_TMA_REDUCE_ADD arguments
TmaDescriptor tma_desc_;
using AuxParams = AuxParams_;
AuxParams aux_params_;
@ -339,7 +395,7 @@ struct Copy_Traits<SM90_TMA_STORE, NumBitsPerTMA, AuxParams_>
int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr);
#endif
SM90_TMA_STORE::copy(&tma_desc_,
SM90_TMA_REDUCE_ADD::copy(&tma_desc_,
src_ptr, get<Is>(dst_coord)...);
}
@ -354,8 +410,8 @@ struct Copy_Traits<SM90_TMA_STORE, NumBitsPerTMA, AuxParams_>
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst)
{
static_assert(is_smem<TS>::value, "Expected smem src for SM90_TMA_STORE");
//static_assert(is_gmem<TD>::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor
static_assert(is_smem<TS>::value, "Expected smem src for SM90_TMA_REDUCE_ADD");
//static_assert(is_gmem<TD>::value, "Expected gmem dst for SM90_TMA_REDUCE_ADD"); // TMA spoofed src tensor
traits.copy_unpack_(cute::raw_pointer_cast(src.data()), dst.data().coord_, tuple_seq<decltype(dst.data().coord_)>{});
}
@ -383,6 +439,13 @@ struct Copy_Traits<SM90_BULK_COPY_G2S, NumBitsPerTMA, OpArgs...>
// 0: uint64_t* bulk_load_memory_barrier
cute::tuple<OpArgs...> bulk_load_mbar_;
// Record the memory barrier for the instruction
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_BULK_COPY_G2S, NumBitsPerTMA, uint64_t*>
with(uint64_t& bulk_mbar) const {
return {{&bulk_mbar}};
}
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr
@ -395,15 +458,29 @@ struct Copy_Traits<SM90_BULK_COPY_G2S, NumBitsPerTMA, OpArgs...>
"Extra arguments not set. Set .with() before use.");
static_assert(is_gmem<TS>::value, "Expected gmem src for SM90_BULK_COPY_G2S");
static_assert(is_smem<TD>::value, "Expected smem dst for SM90_BULK_COPY_G2S");
SM90_BULK_COPY_G2S::copy(raw_pointer_cast(src.data()), *get<0>(traits.bulk_load_mbar_),
SM90_BULK_COPY_G2S::copy(raw_pointer_cast(src.data()), get<0>(traits.bulk_load_mbar_),
raw_pointer_cast(dst.data()), int32_t(NumBitsPerTMA::value / 8));
}
};
// Record the memory barrier for the instruction
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_BULK_COPY_G2S, NumBitsPerTMA, uint64_t*>
with(uint64_t& bulk_mbar) const {
return {{&bulk_mbar}};
template <class NumBitsPerTMA, class... Args>
struct Copy_Traits<SM90_BULK_COPY_G2S::PREFETCH, NumBitsPerTMA, Args...>
: Copy_Traits<SM90_BULK_COPY_G2S, NumBitsPerTMA>
{
template <class... CopyArgs>
CUTE_HOST_DEVICE
Copy_Traits(Copy_Traits<CopyArgs...> const& traits) {}
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr
void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst)
{
static_assert(is_gmem<TS>::value, "Expected gmem src for SM90_BULK_PREFETCH");
SM90_BULK_COPY_G2S::PREFETCH::copy(raw_pointer_cast(src.data()), int32_t(NumBitsPerTMA::value / 8));
}
};
@ -653,7 +730,7 @@ template <class GEngine, class GLayout,
CUTE_HOST_DEVICE constexpr
void
fill_tma_gmem_shape_stride(Tensor<GEngine,GLayout> const& gtensor, // Gmem Shapes and Strides, in units of TmaInternalType
TmaGmemBasisStride const& tma_gbasis_stride, // Map Tma mode idx -> Gmem mode(s)
cute::array<ShapeT, TmaRank> & gmem_prob_shape, // Tma Shapes, uint32_t or uin64_t
cute::array<uint64_t, TmaRank> & gmem_prob_stride) // Tma Strides
{
@ -663,7 +740,7 @@ fill_tma_gmem_shape_stride(Tensor<GEngine,GLayout> const& gtensor, /
using TmaInternalType = typename GEngine::value_type;
constexpr int tma_rank = decltype(rank(tma_gbasis_stride))::value;
static_assert(TmaRank >= tma_rank);
auto gmem_shape = shape(gtensor);
auto gmem_stride = stride(gtensor);
// Use the indirections in tma_gbasis_stride into gtensor to construct the tma gmem shapes/strides
@ -703,12 +780,12 @@ template <class GEngine, class GLayout,
class ShapeT, size_t TmaRank>
CUTE_HOST_DEVICE constexpr
void
fill_tma_gmem_shape_stride(Copy_Traits<Op,Bits,Aux> const& tma_traits,
Tensor<GEngine,GLayout> const& gtensor, // Gmem Shapes and Strides, value_type = TmaInternalType
cute::array<ShapeT, TmaRank> & gmem_prob_shape, // Tma Shapes, uint32_t or uin64_t
cute::array<uint64_t, TmaRank> & gmem_prob_stride) // Tma Strides
{
return fill_tma_gmem_shape_stride(gtensor, stride(typename Aux::TmaGmemBasis{}),
gmem_prob_shape, gmem_prob_stride);
}
@ -824,7 +901,7 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The origin
// Construct the descriptor
//
TmaDescriptor tma_desc = {0};
TmaDescriptor tma_desc{};
//
// TMA general info
@ -897,7 +974,7 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The origin
if constexpr (decltype(rank<j>(tma_gmem_basis_stride) == Int<1>{})::value) {
return E<j>{}; // Return TMA Coord basis -- known scale of Int<1>{}
} else {
int32_t scale = ceil_div(int32_t(di * sizeof_bits_v<TmaInternalType> / cute::max(gmem_prob_stride[j], 16)), 8);
int32_t scale = ceil_div(int32_t(di * sizeof_bits_v<TmaInternalType> / cute::max(gmem_prob_stride[j], uint64_t{16})), 8);
return E<j>{} * scale; // Return TMA Coord basis -- with a dynamic scale factor
}
}
@ -948,7 +1025,7 @@ make_tma_copy_atom(CopyOp,
// Construct the Copy_Traits
//
constexpr int num_bits_per_tma = size(tma_gbasis) * sizeof_bits<TmaInternalType>::value;
constexpr int num_bits_per_tma = size(tma_gbasis) * sizeof_bits_v<TmaInternalType>;
using Traits = Copy_Traits<CopyOp, cute::C<num_bits_per_tma>, decltype(aux_params)>;
using Atom = Copy_Atom<Traits, typename GEngine::value_type>;
@ -1105,6 +1182,14 @@ make_tma_copy(CopyOp const& copy_op,
CTA_Tiler const& cta_tiler,
Cluster_Size const& cluster_size)
{
if constexpr (cute::is_same_v<CopyOp, SM90_TMA_LOAD_IM2COL> ||
cute::is_same_v<CopyOp, SM90_TMA_STORE_IM2COL>) {
return make_im2col_tma_copy(copy_op,
gtensor,
slayout,
cta_tiler,
cluster_size);
} else {
auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler);
auto cta_t_tile = make_layout(cluster_size);
// Prefer TmaInternalType if specified. Fallback to GEngine::value_type
@ -1112,6 +1197,7 @@ make_tma_copy(CopyOp const& copy_op,
return detail::make_tma_copy_tiled<TmaType>(copy_op,
gtensor, slayout,
cta_t_tile, cta_v_tile);
}
}
// Explicit defaulting
@ -1179,9 +1265,11 @@ auto
tma_partition(Copy_Atom<Args...> const& copy_atom,
CtaCoord const& cta_coord,
Layout<TShape,TStride> const& cta_layout, // T: CTA coord -> logical multicast id
Tensor<SEngine,SLayout> const& stensor, // SMEM Tensor (TMATile, Iter)
Tensor<GEngine,GLayout> const& gtensor) // GMEM Tensor (TMATile, Iter)
Tensor<SEngine,SLayout> const& stensor, // SMEM Tensor (TMATile, Rest...)
Tensor<GEngine,GLayout> const& gtensor) // GMEM Tensor (TMATile, Rest...)
{
CUTE_STATIC_ASSERT_V(size<0>(stensor) == size<0>(gtensor));
// Invert the smem to get the largest contiguous vector in the smem layout
Layout inv_smem_layout = right_inverse(get_nonswizzle_portion(layout<0>(stensor)));
// Scale that up to cover all of the smem_coords
@ -1189,14 +1277,19 @@ tma_partition(Copy_Atom<Args...> const& copy_atom,
// Factor out the single-instrucion portion
Layout tma_layout_v = make_layout(Int<Copy_Atom<Args...>::NumValSrc>{});
Layout layout_V = logical_divide(layout_v, tma_layout_v);
auto layout_V = make_tile(logical_divide(layout_v, tma_layout_v));
// Append with _ until we cover all Rest... modes
auto glayout_V = append<rank_v<decltype(gtensor)>>(layout_V, _);
auto slayout_V = append<rank_v<decltype(stensor)>>(layout_V, _);
// Transform tile mode and coalesce
Tensor gtensor_v = coalesce(gtensor.compose(layout_V, _), Shape<Shape<_1,_1>,_1>{}); // ((TMA,TMA_Iter),Iter)
Tensor stensor_v = coalesce(stensor.compose(layout_V, _), Shape<Shape<_1,_1>,_1>{}); // ((TMA,TMA_Iter),Iter)
Tensor gtensor_v = coalesce(gtensor.compose(glayout_V), Shape<Shape<_1,_1>>{}); // ((TMA,TMA_Iter), Rest...)
Tensor stensor_v = coalesce(stensor.compose(slayout_V), Shape<Shape<_1,_1>>{}); // ((TMA,TMA_Iter), Rest...)
#if 0
if (thread0()) {
print("gtensor : "); print(gtensor); print("\n");
print("stensor : "); print(stensor); print("\n");
print("layout_V : "); print(layout_V); print("\n");
print("gtensor_v : "); print(gtensor_v); print("\n");
print("stensor_v : "); print(stensor_v); print("\n");
@ -1205,11 +1298,15 @@ tma_partition(Copy_Atom<Args...> const& copy_atom,
// Restride the cta-into-tma-instr layout
Layout tma_layout_t = composition(make_layout(Int<1>{}, shape_div(size(tma_layout_v), cosize(cta_layout))), cta_layout);
Layout tma_layout_tv = make_layout(tma_layout_t, tma_layout_v);
auto tma_layout_tv = make_tile(make_tile(make_layout(tma_layout_t, tma_layout_v), _));
// Append with _ until we cover all Rest... modes
auto gtma_layout_tv = append<rank_v<decltype(gtensor)>>(tma_layout_tv, _);
auto stma_layout_tv = append<rank_v<decltype(stensor)>>(tma_layout_tv, _);
// Transform TMA mode
Tensor gtensor_tv = gtensor_v.compose(make_tile(tma_layout_tv, _), _); // (((Thr,Frg),TMA_Iter),Iter)
Tensor stensor_tv = stensor_v.compose(make_tile(tma_layout_tv, _), _); // (((Thr,Frg),TMA_Iter),Iter)
Tensor gtensor_tv = gtensor_v.compose(gtma_layout_tv); // (((Thr,Frg),TMA_Iter), Rest...)
Tensor stensor_tv = stensor_v.compose(stma_layout_tv); // (((Thr,Frg),TMA_Iter), Rest...)
#if 0
if (thread0()) {
@ -1219,9 +1316,11 @@ tma_partition(Copy_Atom<Args...> const& copy_atom,
}
#endif
// Slice and group Frg,TMA_Iter and return
auto c = make_coord(make_coord(make_coord(cta_coord, _), _), _);
return cute::make_tuple(group_modes<0,2>(gtensor_tv(c)), group_modes<0,2>(stensor_tv(c)));
auto c = make_coord(make_coord(make_coord(cta_coord, _), _));
auto c_s = append<rank_v<decltype(stensor_tv)>>(c, _);
auto c_g = append<rank_v<decltype(gtensor_tv)>>(c, _);
return cute::make_tuple(group_modes<0,2>(gtensor_tv(c_g)), group_modes<0,2>(stensor_tv(c_s)));
}
} // end namespace cute


@ -35,7 +35,6 @@
#include <cute/arch/mma.hpp>
#include <cute/atom/mma_traits.hpp>
#include <cute/tensor.hpp>
#include <cute/util/type_traits.hpp>
@ -78,7 +77,7 @@ struct MMA_Atom<MMA_Traits<Args...>>
CUTE_HOST_DEVICE
auto
with(TraitsArgs&&... args) const {
auto traits = Traits::with(std::forward<TraitsArgs>(args)...);
auto traits = Traits::with(static_cast<TraitsArgs&&>(args)...);
return MMA_Atom<decltype(traits)>{traits};
}
@ -157,7 +156,7 @@ struct MMA_Atom<MMA_Traits<Args...>>
// If the intended FrgTypeA is a view (of the current tensor), forward the whole
static_assert(is_same<ValTypeA, typename remove_cvref_t<ATensor>::value_type>::value
, "Expecting ValTypeA type");
return make_tensor<FrgTypeA>(std::forward<ATensor>(atensor));
return make_tensor<FrgTypeA>(static_cast<ATensor&&>(atensor));
} else {
// Else, the intended FrgTypeA is a value type, construct a new tensor with a fragment layout
return make_fragment_like<FrgTypeA>(atensor);
@ -179,7 +178,7 @@ struct MMA_Atom<MMA_Traits<Args...>>
// If the intended FrgTypeB is a view (of the current tensor), forward the whole
static_assert(is_same<ValTypeB, typename remove_cvref_t<BTensor>::value_type>::value
, "Expecting ValTypeB type");
return make_tensor<FrgTypeB>(std::forward<BTensor>(btensor));
return make_tensor<FrgTypeB>(static_cast<BTensor&&>(btensor));
} else {
// Else, the intended FrgTypeB is a value type, construct a new tensor with a fragment layout
return make_fragment_like<FrgTypeB>(btensor);
@ -213,7 +212,7 @@ struct TiledMMA : MMA_Atom
static_assert( rank_v<AtomLayoutMNK> == 3, "TiledMMA requires rank-3 AtomLayoutMNK");
static_assert( rank_v<PermutationMNK> == 3, "TiledMMA requires rank-3 PermutationMNK");
static_assert( is_tile<PermutationMNK>::value, "TiledMMA requires independent permutations of MNK.");
static_assert( is_tuple<PermutationMNK>::value, "TiledMMA requires independent permutations of MNK.");
static_assert(is_static<PermutationMNK>::value, "TiledMMA requires static permutations of MNK.");
using ThrLayoutVMNK = decltype(tiled_product(AtomThrID{}, AtomLayoutMNK{}));
@ -391,7 +390,7 @@ struct TiledMMA : MMA_Atom
} else {
return cute::max(core_size, perm_size);
}
CUTE_GCC_UNREACHABLE;
}
@ -517,7 +516,7 @@ struct ThrMMA : TiledMMA
auto
partition_C(CTensor&& ctensor) const
{
auto thr_tensor = make_tensor(std::forward<CTensor>(ctensor).data(), this->thrfrg_C(ctensor.layout()));
auto thr_tensor = make_tensor(static_cast<CTensor&&>(ctensor).data(), this->thrfrg_C(ctensor.layout()));
auto thr_vmn = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<2>(thr_vmnk_)));
return thr_tensor(thr_vmn, make_coord(_, repeat<rank<1,1>(thr_tensor)>(_)));
@ -528,7 +527,7 @@ struct ThrMMA : TiledMMA
auto
partition_A(ATensor&& atensor) const
{
auto thr_tensor = make_tensor(std::forward<ATensor>(atensor).data(), this->thrfrg_A(atensor.layout()));
auto thr_tensor = make_tensor(static_cast<ATensor&&>(atensor).data(), this->thrfrg_A(atensor.layout()));
auto thr_vmk = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<3>(thr_vmnk_)));
return thr_tensor(thr_vmk, make_coord(_, repeat<rank<1,1>(thr_tensor)>(_)));
@ -539,7 +538,7 @@ struct ThrMMA : TiledMMA
auto
partition_B(BTensor&& btensor) const
{
auto thr_tensor = make_tensor(std::forward<BTensor>(btensor).data(), this->thrfrg_B(btensor.layout()));
auto thr_tensor = make_tensor(static_cast<BTensor&&>(btensor).data(), this->thrfrg_B(btensor.layout()));
auto thr_vnk = make_coord(get<0>(thr_vmnk_), make_coord(get<2>(thr_vmnk_), get<3>(thr_vmnk_)));
return thr_tensor(thr_vnk, make_coord(_, repeat<rank<1,1>(thr_tensor)>(_)));
@ -744,7 +743,15 @@ print(ThrMMA<TiledMMA, ThrVMNK> const& thr_mma)
template <class... Args>
CUTE_HOST_DEVICE
auto
void
print_latex(MMA_Atom<Args...> const& mma_atom)
{
print_latex(make_tiled_mma(mma_atom));
}
template <class... Args>
CUTE_HOST_DEVICE
void
print_latex(TiledMMA<Args...> const& mma)
{
auto layout_and_thrid_C = mma.get_layoutC_MN();
@ -764,7 +771,7 @@ print_latex(TiledMMA<Args...> const& mma)
layoutB_NK, thrID_B);
}
// MNK MMA Layout to console printer -- 8-value color coded by thread
// MNK MMA Layout to console printer
template <class LayoutC, class ThrIDC,
class LayoutA, class ThrIDA,
class LayoutB, class ThrIDB>


@ -32,12 +32,8 @@
#include <cute/arch/mma_sm80.hpp>
#include <cute/atom/mma_traits.hpp>
#include <cute/layout.hpp>
#include <cute/numeric/integer_subbyte.hpp>
#include <cutlass/numeric_types.h>
#include <cute/numeric/numeric_types.hpp>
namespace cute
{


@ -98,16 +98,16 @@
# endif
#endif
#ifdef _MSC_VER
#if defined(_MSC_VER)
// Provides support for alternative operators 'and', 'or', and 'not'
#include <iso646.h>
# include <iso646.h>
#endif // _MSC_VER
#if defined(__CUDACC_RTC__)
#define CUTE_STL_NAMESPACE cuda::std
#define CUTE_STL_NAMESPACE_IS_CUDA_STD
# define CUTE_STL_NAMESPACE cuda::std
# define CUTE_STL_NAMESPACE_IS_CUDA_STD
#else
#define CUTE_STL_NAMESPACE std
# define CUTE_STL_NAMESPACE std
#endif
//
@ -115,9 +115,9 @@
//
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
# include <cuda/std/cassert>
#else
#include <cassert>
# include <cassert>
#endif
#define CUTE_STATIC_V(x) decltype(x)::value
@ -125,10 +125,11 @@
#define CUTE_STATIC_ASSERT static_assert
#define CUTE_STATIC_ASSERT_V(x,...) static_assert(decltype(x)::value, ##__VA_ARGS__)
// Fail and print a message. Typically used for notification of a compiler misconfiguration.
#if defined(__CUDA_ARCH__)
# define CUTE_RUNTIME_ASSERT(x) __brkpt()
# define CUTE_INVALID_CONTROL_PATH(x) assert(0 && x); printf(x); __brkpt()
#else
# define CUTE_RUNTIME_ASSERT(x) assert(0 && x)
# define CUTE_INVALID_CONTROL_PATH(x) assert(0 && x); printf(x)
#endif
//
@ -136,9 +137,9 @@
//
#if !defined(__CUDACC_RTC__)
#include <cstdio>
#include <iostream>
#include <iomanip>
# include <cstdio>
# include <iostream>
# include <iomanip>
#endif
//
@ -151,13 +152,8 @@
// Basic types
//
#include <cute/numeric/int.hpp>
#include <cute/numeric/real.hpp>
#include <cute/numeric/half.hpp>
#include <cute/numeric/float8.hpp>
#include <cute/numeric/bfloat.hpp>
#include <cute/numeric/tfloat.hpp>
#include <cute/numeric/complex.hpp>
#include <cute/numeric/numeric_types.hpp>
//
// Debugging utilities
//


@ -32,7 +32,7 @@
#include <cute/config.hpp>
#include <cute/numeric/int.hpp>
#include <cute/numeric/numeric_types.hpp>
#include <cute/numeric/math.hpp>
namespace cute


@ -355,7 +355,7 @@ void clear(array<T,N>& a)
a.fill(T(0));
}
template <typename T, size_t N>
template <class T, size_t N>
CUTE_HOST_DEVICE constexpr
void fill(array<T,N>& a, T const& value)
{
@ -370,14 +370,14 @@ void swap(array<T,N>& a, array<T,N>& b)
}
/// @return A cute::array of the elements of @c t in reverse order.
template <typename T, size_t N>
CUTE_HOST_DEVICE constexpr cute::array<T, N>
reverse(cute::array<T, N> const& t) {
template <class T, size_t N>
CUTE_HOST_DEVICE constexpr
cute::array<T,N> reverse(cute::array<T,N> const& t)
{
if constexpr (N == 0u) {
return t;
}
else {
cute::array<T, N> t_r{};
} else {
cute::array<T,N> t_r{};
for (size_t k = 0; k < N; ++k) {
t_r[k] = t[N - k - 1];
}
@ -422,7 +422,7 @@ CUTE_HOST_DEVICE constexpr
T&& get(array<T,N>&& a)
{
static_assert(I < N, "Index out of range");
return std::move(a[I]);
return cute::move(a[I]);
}
} // end namespace cute
@ -442,12 +442,12 @@ struct tuple_element<I, cute::array<T,N>>
};
template <class T, size_t N>
struct tuple_size<const cute::array<T,N>>
struct tuple_size<cute::array<T,N> const>
: CUTE_STL_NAMESPACE::integral_constant<size_t, N>
{};
template <size_t I, class T, size_t N>
struct tuple_element<I, const cute::array<T,N>>
struct tuple_element<I, cute::array<T,N> const>
{
using type = T;
};
@ -462,7 +462,7 @@ namespace std
template <class... _Tp>
struct tuple_size;
template<size_t _Ip, class... _Tp>
template <size_t _Ip, class... _Tp>
struct tuple_element;
#endif
@ -478,12 +478,12 @@ struct tuple_element<I, cute::array<T,N>>
};
template <class T, size_t N>
struct tuple_size<const cute::array<T,N>>
struct tuple_size<cute::array<T,N> const>
: CUTE_STL_NAMESPACE::integral_constant<size_t, N>
{};
template <size_t I, class T, size_t N>
struct tuple_element<I, const cute::array<T,N>>
struct tuple_element<I, cute::array<T,N> const>
{
using type = T;
};


@ -37,29 +37,20 @@
#include <cute/config.hpp>
#include <cute/numeric/int.hpp> // sizeof_bits
#include <cute/numeric/numeric_types.hpp>
#include <cute/numeric/integral_constant.hpp>
namespace cute
{
template <class T>
struct is_subbyte {
static constexpr bool value = sizeof_bits_v<T> < 8;
};
template <class T>
constexpr bool is_subbyte_v = is_subbyte<T>::value;
//
// Underlying subbyte storage type
//
template <class T>
using subbyte_storage_type_t = conditional_t<(sizeof_bits_v<T> <= 8), uint8_t,
conditional_t<(sizeof_bits_v<T> <= 16), uint16_t,
conditional_t<(sizeof_bits_v<T> <= 32), uint32_t,
conditional_t<(sizeof_bits_v<T> <= 64), uint64_t,
conditional_t<(sizeof_bits_v<T> <= 128), uint128_t,
using subbyte_storage_type_t = conditional_t<(cute::sizeof_bits_v<T> <= 8), uint8_t,
conditional_t<(cute::sizeof_bits_v<T> <= 16), uint16_t,
conditional_t<(cute::sizeof_bits_v<T> <= 32), uint32_t,
conditional_t<(cute::sizeof_bits_v<T> <= 64), uint64_t,
conditional_t<(cute::sizeof_bits_v<T> <= 128), uint128_t,
T>>>>>;
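The nested `conditional_t` above picks the narrowest unsigned integer wide enough to back a (possibly sub-byte) element type. A minimal standalone sketch of the same selection pattern, written against the standard library only (the alias name here is illustrative, not CuTe's):

```cpp
#include <cstdint>
#include <type_traits>

// Pick the narrowest unsigned storage for a Bits-wide element (128-bit case omitted).
template <int Bits>
using storage_for_bits_t =
    std::conditional_t<(Bits <=  8), std::uint8_t,
    std::conditional_t<(Bits <= 16), std::uint16_t,
    std::conditional_t<(Bits <= 32), std::uint32_t,
    std::conditional_t<(Bits <= 64), std::uint64_t, void>>>>;

static_assert(std::is_same_v<storage_for_bits_t<4>,  std::uint8_t>);   // e.g. a 4-bit int
static_assert(std::is_same_v<storage_for_bits_t<16>, std::uint16_t>);  // e.g. a 16-bit float
static_assert(std::is_same_v<storage_for_bits_t<33>, std::uint64_t>);
```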
template <class T> struct subbyte_iterator;
@ -183,6 +174,11 @@ public:
operator element_type() const {
return get();
}
// Address
subbyte_iterator<T> operator&() const {
return {ptr_, idx_};
}
};
//
@ -314,7 +310,7 @@ public:
CUTE_HOST_DEVICE constexpr friend
auto recast_ptr(subbyte_iterator const& x) {
using NewT = conditional_t<(is_const_v<T>), NewT_ const, NewT_>;
if constexpr (is_subbyte<NewT>::value) { // Making subbyte_iter, preserve the subbyte idx
if constexpr (cute::is_subbyte_v<NewT>) { // Making subbyte_iter, preserve the subbyte idx
return subbyte_iterator<NewT>(x.ptr_, x.idx_);
} else { // Not subbyte, assume/assert subbyte idx 0
return reinterpret_cast<NewT*>(raw_pointer_cast(x));
@ -323,7 +319,7 @@ public:
}
CUTE_HOST_DEVICE friend void print(subbyte_iterator x) {
printf("subptr[%db](%p.%u)", int(sizeof_bits<T>::value), x.ptr_, x.idx_);
printf("subptr[%db](%p.%u)", int(sizeof_bits_v<T>), x.ptr_, x.idx_);
}
};
@ -369,8 +365,8 @@ private:
public:
CUTE_HOST_DEVICE constexpr
array_subbyte() {}
constexpr
array_subbyte() = default;
CUTE_HOST_DEVICE constexpr
array_subbyte(array_subbyte const& x) {
@ -562,7 +558,7 @@ CUTE_HOST_DEVICE constexpr
T&& get(array_subbyte<T,N>&& a)
{
static_assert(I < N, "Index out of range");
return std::move(a[I]);
return cute::move(a[I]);
}
} // end namespace cute
@ -608,7 +604,7 @@ namespace std
template <class... _Tp>
struct tuple_size;
template<size_t _Ip, class... _Tp>
template <size_t _Ip, class... _Tp>
struct tuple_element;
#endif

View File

@ -37,7 +37,7 @@
#include <cute/config.hpp>
#include <cute/numeric/int.hpp> // uint_bit_t
#include <cute/numeric/numeric_types.hpp> // uint_bit_t
namespace cute
{

View File

@ -96,11 +96,11 @@ uint32_t&& get(dim3&& a)
{
static_assert(I < 3, "Index out of range");
if constexpr (I == 0) {
return std::move(a.x);
return cute::move(a.x);
} else if constexpr (I == 1) {
return std::move(a.y);
return cute::move(a.y);
} else if constexpr (I == 2) {
return std::move(a.z);
return cute::move(a.z);
}
CUTE_GCC_UNREACHABLE;
@ -162,11 +162,11 @@ uint32_t&& get(uint3&& a)
{
static_assert(I < 3, "Index out of range");
if constexpr (I == 0) {
return std::move(a.x);
return cute::move(a.x);
} else if constexpr (I == 1) {
return std::move(a.y);
return cute::move(a.y);
} else if constexpr (I == 2) {
return std::move(a.z);
return cute::move(a.z);
}
CUTE_GCC_UNREACHABLE;

View File

@ -126,18 +126,14 @@ CUTE_HOST_DEVICE constexpr T& getv(EBO<N, T, false>& x)
template <size_t N, class T>
CUTE_HOST_DEVICE constexpr T&& getv(EBO<N, T, false>&& x)
{ return static_cast<T&&>(x.t_); }
{ return cute::move(x.t_); }
template <class IdxSeq, class... T>
struct TupleBase;
// Base class of cute::tuple.
// It inherits from EBO<i, t> for each (i, t) in (I..., T...).
// The actual storage (for nonempty t) lives in the base classes.
// index_sequence is a way to wrap up a sequence of zero or more
// compile-time integer values in a single type.
// We only ever use index_sequence<0, 1, ..., sizeof...(T)> in practice,
// as the type alias TupleBase below indicates.
// Base class of cute::tuple binds each element to an index
// by inheriting from EBO<i, t> for each (i, t) in (I..., T...).
// The storage (for nonempty t) lives in the base classes.
template <size_t... I, class... T>
struct TupleBase<index_sequence<I...>, T...>
: EBO<I,T>...
@ -169,11 +165,6 @@ struct TupleBase<index_sequence<I...>, T...>
//
// Inheriting from the above alias TupleBase
// causes MSVC 2022 build errors when assigning one tuple to another:
//
// illegal member initialization:
// 'TupleBase< /* template arguments */ >' is not a base or member
//
// Not using the alias or any kind of alias fixed the errors.
// In summary: this is verbose as a work-around for MSVC build errors.
template <class... T>
struct tuple : detail::TupleBase<make_index_sequence<sizeof...(T)>, T...>
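The comment above describes cute::tuple's storage strategy: one `EBO<i, t>` base per element, indexed by position so that `getv<I>` can recover element I through derived-to-base conversion. A simplified standalone sketch of that pattern (no empty-base specialization; all names here are illustrative, not the CuTe implementation):

```cpp
#include <cstddef>
#include <utility>

// One index-tagged base per element; the index keeps repeated element types distinct.
template <std::size_t N, class T>
struct EBO { T t_; };

template <std::size_t N, class T>
constexpr T const& getv(EBO<N, T> const& x) { return x.t_; }  // N picks the base, T is deduced

template <class IdxSeq, class... T>
struct TupleBase;

template <std::size_t... I, class... T>
struct TupleBase<std::index_sequence<I...>, T...> : EBO<I, T>... {
  constexpr TupleBase(T const&... t) : EBO<I, T>{t}... {}
};

template <class... T>
using Tuple = TupleBase<std::make_index_sequence<sizeof...(T)>, T...>;

constexpr Tuple<int, long> t{1, 2L};
static_assert(getv<0>(t) == 1 && getv<1>(t) == 2L);
```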
@ -365,10 +356,10 @@ tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4,
return cute::make_tuple(get<I0>(t0)..., get<I1>(t1)..., get<I2>(t2)..., get<I3>(t3)..., get<I4>(t4)...);
}
template<class T0, class T1>
template <class T0, class T1>
struct tuple_cat_static;
template<class... T0s, class... T1s>
template <class... T0s, class... T1s>
struct tuple_cat_static<tuple<T0s...>, tuple<T1s...>> {
using type = tuple<T0s..., T1s...>;
};
@ -630,11 +621,8 @@ template <class Tuple, size_t... Is>
CUTE_HOST_DEVICE void print_tuple(Tuple const& t,
index_sequence<Is...>, char s = '(', char e = ')')
{
using eat = int[];
using cute::print;
(void) eat {(print(s), 0),
(print(Is == 0 ? "" : ","), print(get<Is>(t)), 0)...,
(print(e), 0)};
((void(print(Is == 0 ? s : ',')), void(print(get<Is>(t)))), ...); print(e);
}
#if !defined(__CUDACC_RTC__)
@ -642,11 +630,8 @@ template <class Tuple, std::size_t... Is>
CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t,
index_sequence<Is...>, char s = '(', char e = ')')
{
using eat = int[];
(void) eat {(void(os << s), 0),
(void(os << (Is == 0 ? "" : ",") << get<Is>(t)), 0)...,
(void(os << e), 0)};
return os;
(void(os << (Is == 0 ? s : ',') << get<Is>(t)), ...);
return os << e;
}
#endif // !defined(__CUDACC_RTC__)
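The rewrite above replaces the old `int[]` pack-expansion trick with a C++17 comma fold. A standalone sketch of the same fold pattern over `std::tuple` (function and parameter names here are illustrative, not CuTe's):

```cpp
#include <iostream>
#include <tuple>
#include <utility>

template <class Tuple, std::size_t... Is>
std::ostream& print_tuple_os(std::ostream& os, Tuple const& t,
                             std::index_sequence<Is...>, char s = '(', char e = ')') {
  // Comma fold: emit the opening char before element 0 and a ',' before every other element.
  (void(os << (Is == 0 ? s : ',') << std::get<Is>(t)), ...);
  return os << e;
}

int main() {
  print_tuple_os(std::cout, std::make_tuple(1, 2.5, "x"),
                 std::make_index_sequence<3>{}) << '\n';   // prints (1,2.5,x)
}
```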
@ -707,7 +692,7 @@ namespace std
template <class... _Tp>
struct tuple_size;
template<size_t _Ip, class... _Tp>
template <size_t _Ip, class... _Tp>
struct tuple_element;
#endif

View File

@ -108,7 +108,7 @@ namespace std
template <class... _Tp>
struct tuple_size;
template<size_t _Ip, class... _Tp>
template <size_t _Ip, class... _Tp>
struct tuple_element;
#endif

View File

@ -218,7 +218,7 @@ static constexpr int depth_v = depth_t<Tuple>::value;
// product
//
// Implementation of product (see below) as a function object
// Implementation of product as a function object
struct Product
{
template <class IntTuple>
@ -232,7 +232,7 @@ struct Product
} else {
return cute::transform_apply(a, Product{}, multiplies_unary_lfold{});
}
} else {
} else if constexpr (cute::is_integral<IntTuple>::value) {
return a;
}
@ -248,7 +248,7 @@ CUTE_HOST_DEVICE constexpr
auto
product_each(Tuple const& t)
{
return transform(wrap(t), [](auto const& x) { return product(x); });
return transform(wrap(t), product);
}
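Hypothetical host-side usage of `product` and `product_each` described above, assuming the CuTe headers from this repository are on the include path; the output comments are what the shapes imply:

```cpp
#include <cute/tensor.hpp>

int main() {
  using namespace cute;
  auto shape = make_shape(make_shape(2, 3), 4);   // ((2,3),4)
  print(product(shape));      print("\n");        // 24   : product over all leaves
  print(product_each(shape)); print("\n");        // (6,4): product per top-level mode
}
```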
// Take the product of Tuple at the leaves of TupleG
@ -394,7 +394,7 @@ shape_div(IntTupleA const& a, IntTupleB const& b)
static_assert(IntTupleA::value % IntTupleB::value == 0 || IntTupleB::value % IntTupleA::value == 0, "Static shape_div failure");
return C<shape_div(IntTupleA::value, IntTupleB::value)>{};
} else { // int int
//assert(a % b == 0 || b % a == 0); // Wave dynamic assertion
//assert(a % b == 0 || b % a == 0); // Waive dynamic assertion
return a / b != 0 ? a / b : signum(a) * signum(b); // Division with rounding away from zero
}
@ -855,7 +855,10 @@ elem_geq(T const& t, U const& u) {
return !elem_less(t, u);
}
namespace detail {
/** Increment a (dynamic) coord lexicographically within a shape
* @pre is_congruent<Coord,Shape>::value
* \code
* auto shape = make_shape(1,2,make_shape(2,3),3);
*
@ -866,44 +869,26 @@ elem_geq(T const& t, U const& u) {
* assert(i == size(shape));
* \endcode
*/
template <class Coord, class Shape>
template <int I = 0, class Coord, class Shape>
CUTE_HOST_DEVICE constexpr
void
increment(Coord& coord, Shape const& shape);
namespace detail {
template <class Coord, class Shape, int I0, int... Is>
CUTE_HOST_DEVICE constexpr
void
increment(Coord& coord, Shape const& shape, seq<I0,Is...>)
increment(Coord& coord, Shape const& shape)
{
cute::increment(get<I0>(coord), get<I0>(shape));
if constexpr (sizeof...(Is) != 0) {
if (back(get<I0>(coord)) == back(get<I0>(shape))) {
back(get<I0>(coord)) = 0;
increment(coord, shape, seq<Is...>{});
if constexpr (is_integral<Coord>::value) {
++coord;
} else {
increment(get<I>(coord), get<I>(shape));
if constexpr (I+1 < tuple_size<Coord>::value) {
if (back(get<I>(coord)) == back(get<I>(shape))) {
back(get<I>(coord)) = 0;
increment<I+1>(coord, shape);
}
}
}
}
} // end namespace detail
template <class Coord, class Shape>
CUTE_HOST_DEVICE constexpr
void
increment(Coord& coord, Shape const& shape)
{
if constexpr (is_integral<Coord>::value && is_integral<Shape>::value) {
++coord;
} else if constexpr (is_tuple<Coord>::value && is_tuple<Shape>::value) {
static_assert(tuple_size<Coord>::value == tuple_size<Shape>::value, "Mismatched ranks");
detail::increment(coord, shape, tuple_seq<Coord>{});
} else {
static_assert(sizeof(Coord) == 0, "Invalid parameters");
}
}
struct ForwardCoordIteratorSentinal
{};
@ -918,7 +903,7 @@ struct ForwardCoordIterator
Coord const& operator*() const { return coord; }
CUTE_HOST_DEVICE constexpr
ForwardCoordIterator& operator++() { increment(coord, shape); return *this; }
ForwardCoordIterator& operator++() { detail::increment(coord, shape); return *this; }
// Sentinel for the end of the implied range
CUTE_HOST_DEVICE constexpr

View File

@ -56,6 +56,9 @@ using Step = cute::tuple<Strides...>;
template <class... Coords>
using Coord = cute::tuple<Coords...>;
template <class... Layouts>
using Tile = cute::tuple<Layouts...>;
template <class... Ts>
CUTE_HOST_DEVICE constexpr
Shape<Ts...>
@ -80,7 +83,17 @@ Coord<Ts...>
make_coord(Ts const&... t) {
return {t...};
}
template <class... Ts>
CUTE_HOST_DEVICE constexpr
Tile<Ts...>
make_tile(Ts const&... t)
{
return {t...};
}
//
// Layout
//
template <class Shape, class Stride = LayoutLeft::Apply<Shape> >
struct Layout
@ -366,59 +379,56 @@ make_layout(Shape const& shape, GenRowMajor)
return make_layout(shape, compact_row_major(shape));
}
// Follow the same ordering induced by the strides, but make the layout compact
//
// Advanced Layout constructions
//
// Make a compact layout with shape @a shape and strides following the order induced by @a order.
// Dynamic values in @a order are ignored, considered large, and considered ordered from left to right.
// Example:
// make_ordered_layout(Shape<_2,_2,_2,_2>{}, Step<_0,_2,_3,_1>{})
// -> (_2,_2,_2,_2):(_1,_4,_8,_2)
// make_ordered_layout(make_shape(2,3,4,5), make_step(Int<2>{}, 67, 42, Int<50>{}))
// -> (2,3,4,5):(_1,10,30,2)
template <class Shape, class Order>
CUTE_HOST_DEVICE constexpr
auto
make_ordered_layout(Shape const& shape, Order const& order)
{
static_assert(is_static<Order>::value);
return make_layout(shape, compact_order(shape, order));
}
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
make_ordered_layout(Layout<Shape,Stride> const& layout)
{
return make_ordered_layout(layout.shape(), layout.stride());
}
// Make a layout of the same shape that is either ordered or colmajor depending on staticness
// Make a compact layout with the same shape as @a layout
// and strides following the order induced by @a layout.stride().
// Static-0 strides in the input @a layout are preserved in the output.
// Example:
// make_layout_like(Layout<Shape<_2,_2,_2,_2>, Stride<_0,_2,_4,_1>>{})
// -> (_2,_2,_2,_2):(_0,_2,_4,_1)
// make_layout_like(make_layout(make_shape(2,3,4,5), make_stride(Int<0>{},42,Int<1>{},Int<0>{})))
// -> (2,3,4,5):(_0,4,_1,_0)
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
make_layout_like(Layout<Shape,Stride> const& layout)
{
auto any_zero = any_of(layout.stride(), [](auto d) { return is_constant<0, decltype(d)>{}; });
if constexpr (any_zero) {
// If there are static-0 strides, then make a col-major layout that keeps those 0s
return make_layout(layout.shape(),
compact_col_major(filter_zeros(layout.stride(), layout.shape())));
} else
if constexpr (is_static<Shape>::value && is_static<Stride>::value) {
// If the layout is fully static, then make a layout that follows the same order as the strides
// Assumes the strides are unique
return make_ordered_layout(layout.shape(), layout.stride());
} else {
return make_layout(layout.shape());
}
CUTE_GCC_UNREACHABLE;
return make_layout(layout.shape(),
compact_order(filter_zeros(layout.stride(), layout.shape()), layout.stride()));
}
//
// Make a layout of the same shape,
// with mode-0 being colmajor then following the mode order in layout
//
// Make a compact layout with the same shape as @a layout
// and strides following the order induced by @a layout.stride(),
// except mode-0 is always stride-1 and generated column-major.
// The 0th mode is commonly used for MMA_Atoms or Copy_Atoms
// so this generates the 0th mode with LayoutLeft regardless of the reference layout.
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
make_fragment_like(Layout<Shape,Stride> const& layout)
{
constexpr int R = Layout<Shape,Stride>::rank;
if constexpr (R > 1 && is_static<Shape>::value && is_static<Stride>::value) {
return tiled_product(make_layout(shape<0>(layout)), make_ordered_layout(take<1,R>(layout)));
if constexpr (R > 1 && is_static<Shape>::value) {
return tiled_product(make_layout(shape<0>(layout)),
make_ordered_layout(take<1,R>(layout.shape()), take<1,R>(layout.stride())));
} else {
return make_layout(layout.shape());
}
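Hypothetical usage of the layout helpers documented above, assuming CuTe is available; the expected outputs are the ones given in the doc comments:

```cpp
#include <cute/tensor.hpp>

int main() {
  using namespace cute;
  // Strides follow the order induced by `order`; dynamic order values are treated as large.
  auto ordered = make_ordered_layout(Shape<_2,_2,_2,_2>{}, Step<_0,_2,_3,_1>{});
  print(ordered); print("\n");   // (_2,_2,_2,_2):(_1,_4,_8,_2)

  // Static-0 strides are preserved; the remaining modes are compacted in stride order.
  auto like = make_layout_like(Layout<Shape<_2,_2,_2,_2>, Stride<_0,_2,_4,_1>>{});
  print(like); print("\n");      // (_2,_2,_2,_2):(_0,_2,_4,_1)
}
```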
@ -458,11 +468,11 @@ CUTE_HOST_DEVICE constexpr
auto
get(Layout<Shape,Stride> const& layout)
{
return make_layout(get<Is...>(layout.shape()),
return make_layout(get<Is...>(layout.shape()),
get<Is...>(layout.stride()));
}
// Return a new layout with only the modes in the range [B,E)
// Return a new layout with only the modes in the range [B,E)
template <int B, int E, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
@ -470,7 +480,7 @@ take(Layout<Shape,Stride> const& layout)
{
static_assert(B < E, "take: empty range error");
static_assert(0 <= B && E <= Layout<Shape,Stride>::rank, "take: range out of bounds");
return make_layout(take<B,E>(layout.shape()),
return make_layout(take<B,E>(layout.shape()),
take<B,E>(layout.stride()));
}
@ -490,7 +500,7 @@ CUTE_HOST_DEVICE constexpr
auto
flatten(Layout<Shape,Stride> const& layout)
{
return make_layout(flatten(layout.shape()),
return make_layout(flatten(layout.shape()),
flatten(layout.stride()));
}
@ -1376,6 +1386,23 @@ logical_divide(Layout<LShape,LStride> const& layout,
CUTE_GCC_UNREACHABLE;
}
// Generalization of ceil_div for Layout lhs
// is effectively the "rest mode" of logical_divide.
// Occurs in the calculation of gridDim, for example, for generalized tilers
// Example:
// dim3 gridDim(size(ceil_div(problem_shape_M, cta_tiler_M)),
// size(ceil_div(problem_shape_N, cta_tiler_N)));
// This does not consider compositional acceptance, so it may be the case that
// ceil_div produces a result while logical_divide (and friends) do not.
template <class Target, class TShape, class TStride>
CUTE_HOST_DEVICE constexpr
auto
ceil_div(Target const& target,
Layout<TShape,TStride> const& tiler)
{
return complement(tiler, size(target));
}
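A hypothetical grid-sizing sketch for the generalized `ceil_div` above, assuming CuTe is available; for a plain integer extent and a simple 1-D tiler it reduces to the familiar rounded-up division:

```cpp
#include <cute/tensor.hpp>

int main() {
  using namespace cute;
  auto problem_M  = Int<1024>{};
  auto cta_tile_M = make_layout(Int<128>{});      // 1-D tiler: (_128):(_1)
  auto rest       = ceil_div(problem_M, cta_tile_M);
  print(size(rest)); print("\n");                 // _8 -> 8 CTA tiles along M
}
```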
//
// Convenience operator
// that produces layouts like ((BLK_A,BLK_B,...),(a,b,...,x,y))
@ -1425,7 +1452,6 @@ flat_divide(Layout<LShape,LStride> const& layout,
// Logical product
//
// @post compatible()
template <class LShape, class LStride,
class TShape, class TStride>
CUTE_HOST_DEVICE constexpr
@ -1501,7 +1527,7 @@ flat_product(Layout<LShape,LStride> const& block,
//
// Rank-sensitive products
//
//
// blocked_product -- Reproduce a block over a tiler.
// Think of every element of "tiler" as a "block"
@ -1517,7 +1543,7 @@ blocked_product(Layout<TShape,TStride> const& block,
constexpr int R = cute::max(rank_v<TShape>, rank_v<UShape>);
auto result = logical_product(append<R>(block), append<R>(tiler));
return coalesce(zip(get<0>(result), get<1>(result)), tuple_repeat<R>(Int<1>{}));
}
@ -1545,7 +1571,7 @@ raked_product(Layout<TShape,TStride> const& block,
// @param block The layout to repeat
// @param trg_shape The target shape of the result
// @param ord_shape The order of the modes of @a trg_shape to tile @a layout with.
// Defaults to GenColMajor, so @a layout will repeat
// Defaults to GenColMajor, so @a layout will repeat
// across the first mode first, the second mode second, etc
// E.g. Step<_2,_1,_3> will cause @a layout to repeat
// across the second mode first, the first mode second, and the third mode last.
@ -1659,7 +1685,7 @@ recast_layout(Layout<Shape,Stride> const& layout)
else if constexpr (scale::num == 1) {
return downcast<scale::den>(layout);
}
else if constexpr (scale::den == 1) {
else if constexpr (scale::den == 1) {
return upcast<scale::num>(layout);
}
else {

View File

@ -178,6 +178,16 @@ struct ComposedLayout : private cute::tuple<LayoutA, Offset, LayoutB> // EBO fo
tile(Layouts const&... layouts) const {
return tiled_divide(*this, make_tile(layouts...));
}
// Equality, return a static or dynamic boolean
template <class... Args>
CUTE_HOST_DEVICE constexpr
auto
operator==(ComposedLayout<Args...> const& other) const {
return this->layout_a() == other.layout_a() &&
this->layout_b() == other.layout_b() &&
this->offset() == other.offset();
}
};
template <class A, class O, class B>

View File

@ -63,6 +63,9 @@ struct ArithmeticTuple : tuple<T...>
template <class... T>
struct is_tuple<ArithmeticTuple<T...>> : true_type {};
template <class... Ts>
struct is_flat<ArithmeticTuple<Ts...>> : is_flat<tuple<Ts...>> {};
template <class... T>
CUTE_HOST_DEVICE constexpr
auto
@ -108,16 +111,45 @@ template <class... T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator+(ArithmeticTuple<T...> const& t, tuple<U...> const& u) {
constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U)));
return transform_apply(append<R>(t,Int<0>{}), append<R>(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); });
return t + ArithmeticTuple<U...>(u);
}
template <class... T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator+(tuple<T...> const& t, ArithmeticTuple<U...> const& u) {
return ArithmeticTuple<T...>(t) + u;
}
// Subtraction
template <class... T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator-(ArithmeticTuple<T...> const& t, ArithmeticTuple<U...> const& u) {
constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U)));
return transform_apply(append<R>(t,Int<0>{}), append<R>(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); });
return transform_apply(append<R>(t,Int<0>{}), append<R>(u,Int<0>{}), minus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); });
}
template <class... T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator-(ArithmeticTuple<T...> const& t, tuple<U...> const& u) {
return t - ArithmeticTuple<U...>(u);
}
template <class... T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator-(tuple<T...> const& t, ArithmeticTuple<U...> const& u) {
return ArithmeticTuple<T...>(t) - u;
}
// Negation
template <class... T>
CUTE_HOST_DEVICE constexpr
auto
operator-(ArithmeticTuple<T...> const& t) {
return transform_apply(t, negate{}, [](auto const&... a){ return make_arithmetic_tuple(a...); });
}
//
@ -128,7 +160,7 @@ template <auto t, class... U>
CUTE_HOST_DEVICE constexpr
ArithmeticTuple<U...> const&
operator+(C<t>, ArithmeticTuple<U...> const& u) {
static_assert(t == 0, "Artihmetic tuple op+ error!");
static_assert(t == 0, "Arithmetic tuple op+ error!");
return u;
}
@ -136,7 +168,23 @@ template <class... T, auto u>
CUTE_HOST_DEVICE constexpr
ArithmeticTuple<T...> const&
operator+(ArithmeticTuple<T...> const& t, C<u>) {
static_assert(u == 0, "Artihmetic tuple op+ error!");
static_assert(u == 0, "Arithmetic tuple op+ error!");
return t;
}
template <auto t, class... U>
CUTE_HOST_DEVICE constexpr
ArithmeticTuple<U...> const&
operator-(C<t>, ArithmeticTuple<U...> const& u) {
static_assert(t == 0, "Arithmetic tuple op- error!");
return -u;
}
template <class... T, auto u>
CUTE_HOST_DEVICE constexpr
ArithmeticTuple<T...> const&
operator-(ArithmeticTuple<T...> const& t, C<u>) {
static_assert(u == 0, "Arithmetic tuple op- error!");
return t;
}
@ -531,7 +579,7 @@ namespace std
template <class... _Tp>
struct tuple_size;
template<size_t _Ip, class... _Tp>
template <size_t _Ip, class... _Tp>
struct tuple_element;
#endif

View File

@ -30,9 +30,9 @@
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cutlass/complex.h>
#include <cute/util/type_traits.hpp>
#include <cute/numeric/numeric_types.hpp>
namespace cute
{
@ -44,6 +44,9 @@ using cutlass::real;
using cutlass::imag;
using cutlass::conj;
template <class T>
static constexpr auto is_complex_v = is_complex<T>::value;
/// Fused multiply-add for complex numbers
template <class T>
CUTE_HOST_DEVICE constexpr

View File

@ -1,41 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <vector_types.h>
#include <cutlass/numeric_types.h>
namespace cute {
using cutlass::half_t;
} // end namespace cute

View File

@ -36,8 +36,7 @@
#include <cstdint>
#endif
#include <cute/numeric/integer_subbyte.hpp>
#include <cute/numeric/uint128.hpp>
#include <cutlass/numeric_types.h>
namespace cute
{
@ -46,16 +45,16 @@ namespace cute
// Signed integers
//
using int2_t = cute::int2b_t;
using int4_t = cute::int4b_t;
using int8_t = CUTE_STL_NAMESPACE::int8_t;
using int16_t = CUTE_STL_NAMESPACE::int16_t;
using int32_t = CUTE_STL_NAMESPACE::int32_t;
using int64_t = CUTE_STL_NAMESPACE::int64_t;
using int2_t = cutlass::int2b_t;
using int4_t = cutlass::int4b_t;
using CUTE_STL_NAMESPACE::int8_t;
using CUTE_STL_NAMESPACE::int16_t;
using CUTE_STL_NAMESPACE::int32_t;
using CUTE_STL_NAMESPACE::int64_t;
template <int N> struct int_bit;
template <> struct int_bit< 2> { using type = cute::int2b_t; };
template <> struct int_bit< 4> { using type = cute::int4b_t; };
template <> struct int_bit< 2> { using type = cutlass::int2b_t; };
template <> struct int_bit< 4> { using type = cutlass::int4b_t; };
template <> struct int_bit< 8> { using type = int8_t; };
template <> struct int_bit< 16> { using type = int16_t; };
template <> struct int_bit< 32> { using type = int32_t; };
@ -74,24 +73,24 @@ using int_byte_t = typename int_byte<N>::type;
// Unsigned integers
//
using uint1_t = cute::uint1b_t;
using uint2_t = cute::uint2b_t;
using uint4_t = cute::uint4b_t;
using uint8_t = CUTE_STL_NAMESPACE::uint8_t;
using uint16_t = CUTE_STL_NAMESPACE::uint16_t;
using uint32_t = CUTE_STL_NAMESPACE::uint32_t;
using uint64_t = CUTE_STL_NAMESPACE::uint64_t;
using uint128_t = cute::uint128_t;
using uint1_t = cutlass::uint1b_t;
using uint2_t = cutlass::uint2b_t;
using uint4_t = cutlass::uint4b_t;
using CUTE_STL_NAMESPACE::uint8_t;
using CUTE_STL_NAMESPACE::uint16_t;
using CUTE_STL_NAMESPACE::uint32_t;
using CUTE_STL_NAMESPACE::uint64_t;
using cutlass::uint128_t;
template <int N> struct uint_bit;
template <> struct uint_bit< 1> { using type = cute::uint1b_t; };
template <> struct uint_bit< 2> { using type = cute::uint2b_t; };
template <> struct uint_bit< 4> { using type = cute::uint4b_t; };
template <> struct uint_bit< 1> { using type = cutlass::uint1b_t; };
template <> struct uint_bit< 2> { using type = cutlass::uint2b_t; };
template <> struct uint_bit< 4> { using type = cutlass::uint4b_t; };
template <> struct uint_bit< 8> { using type = uint8_t; };
template <> struct uint_bit< 16> { using type = uint16_t; };
template <> struct uint_bit< 32> { using type = uint32_t; };
template <> struct uint_bit< 64> { using type = uint64_t; };
template <> struct uint_bit<128> { using type = cute::uint128_t; };
template <> struct uint_bit<128> { using type = cutlass::uint128_t; };
template <int N>
using uint_bit_t = typename uint_bit<N>::type;
@ -102,50 +101,4 @@ using uint_byte = uint_bit<8*N>;
template <int N>
using uint_byte_t = typename uint_byte<N>::type;
//
// sizeof_bytes
//
template <class T>
struct sizeof_bytes {
static constexpr size_t value = sizeof(T);
};
template <class T>
static constexpr int sizeof_bytes_v = sizeof_bytes<T>::value;
//
// sizeof_bits
//
template <class T>
struct sizeof_bits {
static constexpr size_t value = sizeof(T) * 8;
};
template <class T>
struct sizeof_bits<T const>: sizeof_bits<T> {};
template <>
struct sizeof_bits<void> {
static constexpr size_t value = 0;
};
template <>
struct sizeof_bits<bool> {
static constexpr size_t value = 1;
};
template <int Bits, bool Signed>
struct sizeof_bits<integer_subbyte<Bits,Signed>> {
static constexpr size_t value = Bits;
};
template <int Bits, bool Signed>
struct sizeof_bits<cutlass::integer_subbyte<Bits,Signed>> {
static constexpr size_t value = Bits;
};
template <class T>
static constexpr int sizeof_bits_v = sizeof_bits<T>::value;
} // namespace cute

View File

@ -1,235 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <cstdint>
#endif
#include <cutlass/integer_subbyte.h>
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
namespace cute {
///////////////////////////////////////////////////////////////////////////////////////////////////
template <int Bits, bool Signed = true>
struct integer_subbyte
{
/// Storage type
using Storage = uint8_t;
/// Number of bits
static_assert(Bits <= 8*sizeof(Storage), "Require a subbyte of bits in integer_subbyte");
/// External type
using xint_t = typename conditional<Signed, int, unsigned>::type;
/// Bitmask for truncation from larger integers
static constexpr Storage bits_mask_ = Storage((1 << Bits) - 1);
/// Bitmask for the sign bit
static constexpr Storage sign_mask_ = Storage((Signed ? 1 : 0) << (Bits - 1));
//
// Data members
//
Storage storage;
//
// Methods
//
/// No operation
CUTE_HOST_DEVICE constexpr
integer_subbyte() {}
/// Conversion from integer type
CUTE_HOST_DEVICE constexpr
integer_subbyte(int value) // NOTE: Sign extension?
: storage(reinterpret_cast<Storage const&>(value) & bits_mask_) {}
CUTE_HOST_DEVICE constexpr
integer_subbyte(unsigned value)
: storage(reinterpret_cast<Storage const&>(value) & bits_mask_) {}
/// Convert to int or unsigned
CUTE_HOST_DEVICE constexpr
operator xint_t() const {
if (sign_mask_ & storage) { // Sign extend
return xint_t(storage) | ~xint_t(bits_mask_);
} else {
return xint_t(storage);
}
}
/// Equality
CUTE_HOST_DEVICE constexpr
bool operator==(integer_subbyte const& rhs) const {
return storage == rhs.storage;
}
/// Inequality
CUTE_HOST_DEVICE constexpr
bool operator!=(integer_subbyte const& rhs) const {
return storage != rhs.storage;
}
/// Less than or equal
CUTE_HOST_DEVICE constexpr
bool operator<=(integer_subbyte const& rhs) const {
if (sign_mask_ & storage) {
return !(rhs.storage < storage);
} else {
return storage <= rhs.storage;
}
}
/// Less than
CUTE_HOST_DEVICE constexpr
bool operator<(integer_subbyte const& rhs) const {
if (sign_mask_ & storage) {
return !(rhs.storage <= storage);
} else {
return storage < rhs.storage;
}
}
/// Greater than or equal
CUTE_HOST_DEVICE constexpr
bool operator>=(integer_subbyte const& rhs) const {
return !(*this < rhs);
}
/// Greater than
CUTE_HOST_DEVICE constexpr
bool operator>(integer_subbyte const& rhs) const {
return !(*this <= rhs);
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// 1-bit unsigned integer type
using uint1b_t = integer_subbyte<1, false>;
/// 2-bit integer type
using int2b_t = integer_subbyte<2, true>;
/// 2-bit unsigned integer type
using uint2b_t = integer_subbyte<2, false>;
/// 4-bit integer type
using int4b_t = integer_subbyte<4, true>;
/// 4-bit unsigned integer type
using uint4b_t = integer_subbyte<4, false>;
/// 1-bit binary type
using bin1_t = bool;
} // namespace cute
///////////////////////////////////////////////////////////////////////////////////////////////////
#if !defined(__CUDACC_RTC__)
#include <limits>
namespace CUTE_STL_NAMESPACE {
template <>
struct numeric_limits<cute::uint1b_t> {
CUTE_HOST_DEVICE static constexpr
cute::uint1b_t const lowest() noexcept { return 0; }
CUTE_HOST_DEVICE static constexpr
cute::uint1b_t const min() noexcept { return 0; }
CUTE_HOST_DEVICE static constexpr
cute::uint1b_t const max() noexcept { return 1; }
static constexpr bool is_integer = true;
static constexpr bool is_signed = false;
};
template <>
struct numeric_limits<cute::int2b_t> {
CUTE_HOST_DEVICE static constexpr
cute::int2b_t lowest() noexcept { return -2; }
CUTE_HOST_DEVICE static constexpr
cute::int2b_t min() noexcept { return -2; }
CUTE_HOST_DEVICE static constexpr
cute::int2b_t max() noexcept { return 1; }
static constexpr bool is_integer = true;
static constexpr bool is_signed = true;
};
template <>
struct numeric_limits<cute::uint2b_t> {
CUTE_HOST_DEVICE static constexpr
cute::uint2b_t const lowest() noexcept { return 0; }
CUTE_HOST_DEVICE static constexpr
cute::uint2b_t const min() noexcept { return 0; }
CUTE_HOST_DEVICE static constexpr
cute::uint2b_t const max() noexcept { return 3; }
static constexpr bool is_integer = true;
static constexpr bool is_signed = false;
};
template <>
struct numeric_limits<cute::int4b_t> {
CUTE_HOST_DEVICE static constexpr
cute::int4b_t lowest() noexcept { return -8; }
CUTE_HOST_DEVICE static constexpr
cute::int4b_t min() noexcept { return -8; }
CUTE_HOST_DEVICE static constexpr
cute::int4b_t max() noexcept { return 7; }
static constexpr bool is_integer = true;
static constexpr bool is_signed = true;
};
template <>
struct numeric_limits<cute::uint4b_t> {
CUTE_HOST_DEVICE static constexpr
cute::uint4b_t const lowest() noexcept { return 0; }
CUTE_HOST_DEVICE static constexpr
cute::uint4b_t const min() noexcept { return 0; }
CUTE_HOST_DEVICE static constexpr
cute::uint4b_t const max() noexcept { return 15; }
static constexpr bool is_integer = true;
static constexpr bool is_signed = false;
};
} // namespace std
#endif // !defined(__CUDACC_RTC__)

View File

@ -443,4 +443,35 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, C<t> const&) {
}
#endif
namespace detail {
// parse_int_digits takes a variadic number of digits and converts them into an int
template <class... Ts>
constexpr uint64_t parse_int_digits(uint64_t result, int digit, Ts... digits)
{
if constexpr (sizeof...(Ts) == 0) {
return 10 * result + digit;
} else {
return parse_int_digits(10 * result + digit, digits...);
}
}
} // end namespace detail
// This user-defined literal operator allows cute::constant values to be written as literals. For example,
//
// auto var = 32_c;
//
// var has type cute::constant<int,32>.
//
template <char... digits>
constexpr cute::constant<int,detail::parse_int_digits(0, (digits - '0')...)> operator "" _c()
{
static_assert((('0' <= digits && digits <= '9') && ...),
"Expected 0 <= digit <= 9 for each digit of the integer.");
return {};
}
} // end namespace cute
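A standalone sketch of the digit-parsing literal described above, returning `std::integral_constant` instead of `cute::constant` (the namespace and the suffix `_ic` here are illustrative, not the library's `_c`):

```cpp
#include <cstdint>
#include <type_traits>

namespace demo {

// Fold the decimal digits left-to-right into a single integer at compile time.
template <class... Ts>
constexpr std::uint64_t parse_int_digits(std::uint64_t result, int digit, Ts... digits) {
  if constexpr (sizeof...(Ts) == 0) {
    return 10 * result + digit;
  } else {
    return parse_int_digits(10 * result + digit, digits...);
  }
}

// "32_ic" becomes std::integral_constant<int, 32>.
template <char... digits>
constexpr std::integral_constant<int, int(parse_int_digits(0, (digits - '0')...))>
operator""_ic() {
  static_assert((('0' <= digits && digits <= '9') && ...),
                "Expected 0 <= digit <= 9 for each digit of the integer.");
  return {};
}

} // end namespace demo

using demo::operator""_ic;
static_assert(decltype(32_ic)::value == 32);
static_assert(std::is_same_v<decltype(128_ic), std::integral_constant<int, 128>>);
```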

View File

@ -130,6 +130,10 @@ nratio(R<a,b>, R<c,d>) {
return {};
}
//
// Operators
//
template <auto a, auto b, auto x, auto y>
CUTE_HOST_DEVICE constexpr
typename R<a*x,b*y>::type
@ -227,14 +231,14 @@ abs(R<a,b>) {
template <auto a, auto b>
CUTE_HOST_DEVICE constexpr
auto
int32_t
log_2(R<a,b>) {
static_assert(R<a,b>::num > 0);
static_assert(R<a,b>::den > 0);
return log_2(static_cast<uint32_t>(R<a,b>::num)) - log_2(static_cast<uint32_t>(R<a,b>::den));
}
// @return A non-reduced ratio cute::R of the Trait0::value / Trait1::value
template <class Trait0, class Trait1>
CUTE_HOST_DEVICE constexpr
auto

View File

@ -316,11 +316,11 @@ safe_div(T const& t, U const& u) {
template <class T>
CUTE_HOST_DEVICE constexpr
auto
int32_t
log_2(T x) {
assert(x > 0);
static_assert(is_unsigned<T>::value, "Only to be used for unsigned integral types.");
return bit_width(x) - 1;
return static_cast<int32_t>(bit_width(x)) - 1;
}
} // namespace cute
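A standalone sketch of the `bit_width`-based integer log2 that the change above makes explicitly `int32_t` (C++20 `<bit>`; the function name is illustrative):

```cpp
#include <bit>
#include <cstdint>

// floor(log2(x)) for x > 0: bit_width(8) == 4, so the answer is bit_width(x) - 1.
constexpr std::int32_t ilog2(std::uint32_t x) {
  return static_cast<std::int32_t>(std::bit_width(x)) - 1;
}

static_assert(ilog2(1) == 0 && ilog2(8) == 3 && ilog2(1000) == 9);
```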

View File

@ -30,14 +30,46 @@
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <vector_types.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_size.h>
#include <cute/numeric/int.hpp>
#include <cute/numeric/real.hpp>
namespace cute {
template <typename T>
struct sizeof_bits : public cutlass::sizeof_bits<T> {};
// DO NOT change auto to int; sizeof_bits<sparse_elem> uses integral_ratio instead of int
template <class T>
static constexpr auto sizeof_bits_v = sizeof_bits<T>::value;
using cutlass::bits_to_bytes;
using cutlass::is_subbyte;
template <class T>
static constexpr auto is_subbyte_v = is_subbyte<T>::value;
using cutlass::half_t;
using cutlass::bfloat16_t;
using cutlass::tfloat32_t;
// Umbrella floating-point 8-bit data type: type_erased_dynamic_float8_t
// This umbrella datatype can be enabled when a user provides a specific
// datatype in the runtime argument list.
using cutlass::type_erased_dynamic_float8_t;
using cutlass::float_e4m3_t;
using cutlass::float_e5m2_t;
using cutlass::uint1b_t;
using cutlass::int2b_t;
using cutlass::uint2b_t;
using cutlass::int4b_t;
using cutlass::uint4b_t;
using cutlass::bin1_t;
} // end namespace cute
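Hypothetical compile-time checks against the types re-exported above, assuming the CuTe/CUTLASS headers are on the include path:

```cpp
#include <cute/numeric/numeric_types.hpp>

static_assert(cute::sizeof_bits_v<cute::half_t>  == 16);
static_assert(cute::sizeof_bits_v<cute::int4b_t> ==  4);
static_assert(cute::sizeof_bits_v<bool>          ==  1);
static_assert( cute::is_subbyte_v<cute::uint2b_t>);
static_assert(!cute::is_subbyte_v<float>);
```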

View File

@ -1,259 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <cstdint>
#include <cstdlib>
#include <cmath>
#include <type_traits>
#include <stdexcept>
#endif
#include <cute/config.hpp>
/// Optionally enable GCC's built-in type
#if defined(__x86_64) && !defined(__CUDA_ARCH__)
# if defined(__GNUC__) && 0
# define CUTE_UINT128_NATIVE
# elif defined(_MSC_VER)
# define CUTE_INT128_ARITHMETIC
# include <intrin.h>
# endif
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cute {
/////////////////////////////////////////////////////////////////////////////////////////////////
///! Unsigned 128b integer type
struct alignas(16) uint128_t
{
/// Size of one part of the uint's storage in bits
static constexpr int storage_bits_ = 64;
struct hilo
{
uint64_t lo;
uint64_t hi;
};
// Use a union to store either low and high parts or, if present, a built-in 128b integer type.
union
{
struct hilo hilo_;
#if defined(CUTE_UINT128_NATIVE)
unsigned __int128 native;
#endif // defined(CUTE_UINT128_NATIVE)
};
//
// Methods
//
/// Default ctor
CUTE_HOST_DEVICE constexpr
uint128_t() : hilo_{0, 0} {}
/// Constructor from uint64
CUTE_HOST_DEVICE constexpr
uint128_t(uint64_t lo_) : hilo_{lo_, 0} {}
/// Constructor from two 64b unsigned integers
CUTE_HOST_DEVICE constexpr
uint128_t(uint64_t lo_, uint64_t hi_) : hilo_{lo_, hi_} {}
/// Optional constructor from native value
#if defined(CUTE_UINT128_NATIVE)
uint128_t(unsigned __int128 value) : native(value) { }
#endif
/// Lossily cast to uint64
CUTE_HOST_DEVICE constexpr
explicit operator uint64_t() const
{
return hilo_.lo;
}
template <class Dummy = bool>
CUTE_HOST_DEVICE constexpr
static void exception()
{
//static_assert(sizeof(Dummy) == 0, "Not implemented exception!");
//abort();
//printf("uint128 not implemented!\n");
}
/// Add
CUTE_HOST_DEVICE constexpr
uint128_t operator+(uint128_t const& rhs) const
{
uint128_t y;
#if defined(CUTE_UINT128_NATIVE)
y.native = native + rhs.native;
#else
y.hilo_.lo = hilo_.lo + rhs.hilo_.lo;
y.hilo_.hi = hilo_.hi + rhs.hilo_.hi + (!y.hilo_.lo && (rhs.hilo_.lo));
#endif
return y;
}
/// Subtract
CUTE_HOST_DEVICE constexpr
uint128_t operator-(uint128_t const& rhs) const
{
uint128_t y;
#if defined(CUTE_UINT128_NATIVE)
y.native = native - rhs.native;
#else
y.hilo_.lo = hilo_.lo - rhs.hilo_.lo;
y.hilo_.hi = hilo_.hi - rhs.hilo_.hi - (rhs.hilo_.lo && y.hilo_.lo > hilo_.lo);
#endif
return y;
}
/// Multiply by unsigned 64b integer yielding 128b integer
CUTE_HOST_DEVICE constexpr
uint128_t operator*(uint64_t const& rhs) const
{
uint128_t y;
#if defined(CUTE_UINT128_NATIVE)
y.native = native * rhs;
#elif defined(CUTE_INT128_ARITHMETIC)
// Multiply by the low part
y.hilo_.lo = _umul128(hilo_.lo, rhs, &y.hilo_.hi);
// Add the high part and ignore the overflow
uint64_t overflow;
y.hilo_.hi += _umul128(hilo_.hi, rhs, &overflow);
#else
exception();
#endif
return y;
}
/// Divide 128b operation by 64b operation yielding a 64b quotient
CUTE_HOST_DEVICE constexpr
uint64_t operator/(uint64_t const& divisor) const
{
uint64_t quotient = 0;
#if defined(CUTE_UINT128_NATIVE)
quotient = uint64_t(native / divisor);
#elif defined(CUTE_INT128_ARITHMETIC)
// implemented using MSVC's arithmetic intrinsics
uint64_t remainder = 0;
quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder);
#else
exception();
#endif
return quotient;
}
/// Divide 128b operation by 64b operation yielding a 64b quotient
CUTE_HOST_DEVICE constexpr
uint64_t operator%(uint64_t const& divisor) const
{
uint64_t remainder = 0;
#if defined(CUTE_UINT128_NATIVE)
remainder = uint64_t(native % divisor);
#elif defined(CUTE_INT128_ARITHMETIC)
// implemented using MSVC's arithmetic intrinsics
(void)_udiv128(hilo_.hi, hilo_.lo, divisor, &remainder);
#else
exception();
#endif
return remainder;
}
/// Computes the quotient and remainder in a single method.
CUTE_HOST_DEVICE constexpr
uint64_t divmod(uint64_t &remainder, uint64_t divisor) const
{
uint64_t quotient = 0;
#if defined(CUTE_UINT128_NATIVE)
quotient = uint64_t(native / divisor);
remainder = uint64_t(native % divisor);
#elif defined(CUTE_INT128_ARITHMETIC)
// implemented using MSVC's arithmetic intrinsics
quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder);
#else
exception();
#endif
return quotient;
}
/// Left-shifts a 128b unsigned integer
CUTE_HOST_DEVICE constexpr
uint128_t operator<<(int sh) const
{
if (sh == 0) {
return *this;
}
else if (sh >= storage_bits_) {
return uint128_t(0, hilo_.lo << (sh - storage_bits_));
}
else {
return uint128_t(
(hilo_.lo << sh),
(hilo_.hi << sh) | uint64_t(hilo_.lo >> (storage_bits_ - sh))
);
}
}
/// Right-shifts a 128b unsigned integer
CUTE_HOST_DEVICE constexpr
uint128_t operator>>(int sh) const
{
if (sh == 0) {
return *this;
}
else if (sh >= storage_bits_) {
return uint128_t((hilo_.hi >> (sh - storage_bits_)), 0);
}
else {
return uint128_t(
(hilo_.lo >> sh) | (hilo_.hi << (storage_bits_ - sh)),
(hilo_.hi >> sh)
);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cute
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -33,7 +33,7 @@
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/numeric/int.hpp> // sizeof_bits
#include <cute/numeric/numeric_types.hpp> // sizeof_bits
#include <cute/numeric/math.hpp>
#include <cute/numeric/integral_constant.hpp>
@ -58,7 +58,7 @@ CUTE_HOST_DEVICE constexpr
auto
recast_ptr(void* ptr)
{
if constexpr (is_subbyte<NewT>::value) {
if constexpr (cute::is_subbyte_v<NewT>) {
return subbyte_iterator<NewT>(ptr);
} else {
return reinterpret_cast<NewT*>(ptr);
@ -71,7 +71,7 @@ CUTE_HOST_DEVICE constexpr
auto
recast_ptr(void const* ptr)
{
if constexpr (is_subbyte<NewT>::value) {
if constexpr (cute::is_subbyte_v<NewT>) {
return subbyte_iterator<NewT const>(ptr);
} else {
return reinterpret_cast<NewT const*>(ptr);

View File

@ -33,7 +33,7 @@
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/numeric/int.hpp> // sizeof_bits
#include <cute/numeric/numeric_types.hpp> // sizeof_bits
namespace cute
{

View File

@ -109,7 +109,7 @@ as_position_independent_swizzle_tensor(Tensor&& tensor)
} else {
#if !defined(NDEBUG)
{
uint32_t address = cast_smem_ptr_to_uint(raw_pointer_cast(std::forward<Tensor>(tensor).data()));
uint32_t address = cast_smem_ptr_to_uint(raw_pointer_cast(static_cast<Tensor&&>(tensor).data()));
uint32_t mask = ((uint32_t(1) << SwizzleFn::num_base) - 1) | SwizzleFn::swizzle_code;
assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle
}
@ -118,7 +118,7 @@ as_position_independent_swizzle_tensor(Tensor&& tensor)
// Recast swizzle from acting on byte-addressed pointers to elements of type-T
auto new_swizzle = recast_layout<uint8_t, T>(SwizzleFn{});
// Strip off everything and create a new smem_ptr for type-T
auto new_ptr = make_smem_ptr<T>(raw_pointer_cast(std::forward<Tensor>(tensor).data()));
auto new_ptr = make_smem_ptr<T>(raw_pointer_cast(static_cast<Tensor&&>(tensor).data()));
return make_tensor(new_ptr, composition(new_swizzle, Int<0>{}, tensor.layout()));
}
CUTE_GCC_UNREACHABLE;

Some files were not shown because too many files have changed in this diff.