CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
tensor_ref.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 
31 #include "cutlass/cutlass.h"
32 #include "cutlass/coord.h"
35 
36 namespace cutlass {
37 
39 
44 template <int Rank>
46 public:
48  static int const kRank = Rank;
49 
51  static int const kStrideRank = Rank;
52 
54  using Index = int32_t;
55 
57  using LongIndex = int64_t;
58 
61 
64 
65 private:
66 
67  //
68  // Data members
69  //
70 
72  Stride stride_;
73 
74 public:
75 
76  //
77  // Methods
78  //
79 
81  IdentityTensorLayout(Stride const &stride = Stride()): stride_(stride) { }
82 
85  LongIndex operator()(Coord<Rank> const &coord) const {
86  return coord.dot(stride_);
87  }
88 
91  Stride stride() const {
92  return stride_;
93  }
94 
97  Stride & stride() {
98  return stride_;
99  }
100 
103  LongIndex capacity(TensorCoord const &size) const {
104  int idx = stride_.max_dim_index();
105  return stride_[idx] * size[idx];
106  }
107 };
108 
110 
111 /* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank
112  and layout within memory. A TensorRef combines a pointer and a Layout concept
113 
114  Examples:
115 
116  (These examples use helpers for matrix layouts defined in cutlass/layout/matrix.h)
117 
118  1. Column-major matrix may be represented as a rank=2 tensor:
119 
120  TensorRef<float, layout::ColumnMajor> A(ptr_A, ldm);
121 
122  2. Row-major matrix may be represented as a rank=2 tensor:
123 
124  TensorRef<float, layout::RowMajor> B(ptr_B, ldm);
125 
126  3. An interleaved matrix may be represented as a rank=2 tensor:
127 
128  TensorRef<int8_t, layout::ColumnMajorInterleaved<32> > C;
129 
130  4. A helper exists to define a TensorRef for a contiguous matrix whose layout
131  is not known at compile time.
132 
133  int ldm; // leading dimension
134  layout::Matrix kind; // Could be layout::Matrix::kRowMajor or layout::Matrix::kColumnMajor
135 
136 
137  TensorRef<int, layout::ContiguousMatrix> E(ptr_E, {ldm, kind});
138 
139 */
/// TensorRef pairs a pointer to the first element of a tensor (ptr_) with a Layout
/// object (layout_). Element addresses are formed by adding layout_(coord) to ptr_.
///
/// \tparam Element_  element type; may be const-qualified (remove_const is applied
///                   where a mutable element type is needed).
/// \tparam Layout_   layout type. It must provide kRank, Index, LongIndex,
///                   TensorCoord, Stride, a default constructor,
///                   operator()(TensorCoord) and stride() (const and non-const).
///
/// NOTE(review): this listing comes from rendered documentation, and several
/// original source lines are absent here. They include the CUTLASS_HOST_DEVICE
/// qualifiers, some method signatures, and parts of the Reference,
/// ConstTensorRef and NonConstTensorRef definitions. Comments below name a
/// missing signature wherever the surviving lines identify it.
140 template <
142  typename Element_,
144  typename Layout_
145 >
146 class TensorRef {
147  public:
/// Data type of individual elements.
149  using Element = Element_;
150 
/// Mapping from a logical coordinate to a linear offset (see offset()).
152  using Layout = Layout_;
153 
/// Type returned by the element accessors (data(idx), at(), operator[]).
/// Per this file's generated documentation, it is Element & when
/// sizeof_bits<Element>::value >= 8 and SubbyteReference<Element> otherwise.
/// The condition and the sub-byte alternative are missing from this view.
155  using Reference = typename platform::conditional<
157  Element &,
159  >::type;
160 
/// Logical rank of the tensor, taken from the layout.
162  static int const kRank = Layout::kRank;
163 
/// Index type used for coordinates (from the layout).
165  using Index = typename Layout::Index;
166 
/// Long index type used for linear offsets (from the layout).
168  using LongIndex = typename Layout::LongIndex;
169 
/// Coordinate in the layout's logical tensor space.
171  using TensorCoord = typename Layout::TensorCoord;
172 
/// Stride vector type of the layout.
174  using Stride = typename Layout::Stride;
175 
/// TensorRef to constant data (returned by const_ref()).
/// NOTE(review): several parts are missing from this view: the template
/// arguments of ConstTensorRef, the opening of the NonConstTensorRef alias, and
/// that alias's closing arguments. The remove_const Element argument below
/// belongs to NonConstTensorRef.
177  using ConstTensorRef = TensorRef<
180 
183  typename platform::remove_const<Element>::type,
185 
/// Rejects rank-0 layouts at compile time.
189  static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");
190 
191  private:
192 
/// Pointer to the origin (first element) of the tensor.
194  Element* ptr_;
195 
/// Layout object mapping coordinates to offsets from ptr_.
197  Layout layout_;
198 
199  public:
200 
201  //
202  // Methods
202  // Methods
203  //
204 
/// Constructor TensorRef(Element *ptr, Layout const &layout); its opening line
/// is missing from this view. Stores a pointer and a layout. Both arguments have
/// defaults, so a default-constructed TensorRef is null and good() returns false.
208  Element *ptr = nullptr,
209  Layout const &layout = Layout()
210  ):
211  ptr_(ptr), layout_(layout) {
212 
213  }
214 
/// Converting constructor TensorRef(NonConstTensorRef const &ref); its opening
/// line is missing from this view. Copies the pointer and layout from a TensorRef
/// to non-constant data.
218  NonConstTensorRef const &ref
219  ):
220  ptr_(ref.data()), layout_(ref.layout()) { }
221 
/// ConstTensorRef const_ref() const (signature missing from this view).
/// Returns a TensorRef to constant data over the same pointer and layout.
225  return ConstTensorRef(ptr_, layout_);
226  }
227 
/// NonConstTensorRef non_const_ref() const (signature missing from this view).
/// Returns a TensorRef over the same pointer and layout with const removed from
/// the pointer type. Writing through the result is only valid if the underlying
/// data is not actually const.
230  return NonConstTensorRef(const_cast<typename platform::remove_const<Element>::type *>(ptr_), layout_);
231  }
232 
/// Replaces only the pointer; the layout is unchanged.
235  void reset(Element* ptr = nullptr) {
236  ptr_ = ptr;
237  }
238 
/// Replaces both the pointer and the layout.
241  void reset(Element* ptr, Layout const &layout) {
242  ptr_ = ptr;
243  layout_ = layout;
244  }
245 
/// Returns true if the pointer is non-null.
248  bool good() const {
249  return ptr_ != nullptr;
250  }
251 
/// Returns the raw pointer to the first element.
254  Element * data() const { return ptr_; }
255 
/// Returns a Reference to the element at linear offset idx from the pointer.
/// The accessor is selected at compile time by whether Element is narrower than
/// 8 bits. The first line of the return statement is missing from this view.
258  Reference data(LongIndex idx) const {
260  (sizeof_bits<Element>::value < 8)>::get(ptr_, idx);
261  }
262 
/// Layout & layout() (signature missing from this view).
/// Returns the layout object by reference so it can be modified in place.
266  return layout_;
267  }
268 
/// Returns a copy of the layout object.
271  Layout layout() const {
272  return layout_;
273  }
274 
/// Returns a copy of the layout's stride vector.
277  Stride stride() const {
278  return layout_.stride();
279  }
280 
/// Stride & stride() (signature missing from this view).
/// Returns the layout's stride vector by reference.
284  return layout_.stride();
285  }
286 
/// Returns the stride in dimension dim.
289  Index stride(int dim) const {
290  return layout_.stride().at(dim);
291  }
292 
/// Returns a mutable reference to the stride in dimension dim.
295  Index & stride(int dim) {
296  return layout_.stride().at(dim);
297  }
298 
/// Computes the linear offset of coord from the origin by applying the layout.
301  LongIndex offset(TensorCoord const& coord) const {
302  return layout_(coord);
303  }
304 
/// Returns a Reference to the element at coord.
307  Reference at(TensorCoord const& coord) const {
308  return data(offset(coord));
309  }
310 
/// Returns a Reference to the element at coord (same as at(coord)).
313  Reference operator[](TensorCoord const& coord) const {
314  return data(offset(coord));
315  }
316 
/// TensorRef & add_pointer_offset(LongIndex offset_) (signature missing from
/// this view). Advances ptr_ in place by offset_ (pointer arithmetic on
/// Element *) and returns *this.
320  ptr_ += offset_;
321  return *this;
322  }
323 
/// TensorRef & add_coord_offset(TensorCoord const &coord) (signature missing
/// from this view). Advances the pointer in place by offset(coord) and returns *this.
327  add_pointer_offset(offset(coord));
328  return *this;
329  }
330 
/// Returns a copy of this TensorRef advanced by the offset of b.
333  TensorRef operator+(TensorCoord const& b) const {
334  TensorRef result(*this);
335  result.add_coord_offset(b);
336  return result;
337  }
338 
/// TensorRef & operator+=(TensorCoord const &b) (signature missing from this
/// view). Advances this TensorRef in place by the offset of b and returns *this.
342  add_coord_offset(b);
343  return *this;
344  }
345 
/// Returns a copy of this TensorRef whose pointer is adjusted by -offset(b).
348  TensorRef operator-(TensorCoord const& b) const {
349  TensorRef result(*this);
350  result.add_pointer_offset(-offset(b));
351  return result;
352  }
353 
/// TensorRef & operator-=(TensorCoord const &b) (signature missing from this
/// view). Adjusts the pointer in place by -offset(b) and returns *this.
357  add_pointer_offset(-offset(b));
358  return *this;
359  }
360 };
361 
363 template <
364  typename Element,
365  typename Layout
366 >
368 TensorRef<Element, Layout> make_TensorRef(Element *ptr, Layout const &layout) {
369  return TensorRef<Element, Layout>(ptr, layout);
370 }
371 
373 //
374 // Partial specializations to handle degenerate and sub-byte cases.
375 //
377 
378 template <
379  typename Element,
380  typename Layout
381 >
382 bool TensorRef_aligned(TensorRef<Element, Layout> const &ref, int alignment) {
383 
384  int const kStrideRank = Layout::kStrideRank;
385 
386  if (reinterpret_cast<uintptr_t>(ref.data()) % alignment) {
387  return false;
388  }
389 
391  for (int i = 0; i < kStrideRank; ++i) {
392  if (ref.stride(i) % alignment) {
393  return false;
394  }
395  }
396 
397  return true;
398 }
399 
401 
402 } // namespace cutlass
CUTLASS_HOST_DEVICE Stride stride() const
Returns the stride of the layout.
Definition: tensor_ref.h:91
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE LongIndex operator()(Coord< Rank > const &coord) const
Returns the offset of a coordinate in linear memory.
Definition: tensor_ref.h:85
CUTLASS_HOST_DEVICE Index stride(int dim) const
Returns the layout object's stride in a given physical dimension.
Definition: tensor_ref.h:289
typename Layout::Stride Stride
Layout's stride vector.
Definition: tensor_ref.h:174
CUTLASS_HOST_DEVICE Index & stride(int dim)
Returns the layout object's stride in a given physical dimension.
Definition: tensor_ref.h:295
CUTLASS_HOST_DEVICE Stride & stride()
Returns the layout object's stride vector.
Definition: tensor_ref.h:283
T type
Definition: platform.h:351
Definition: tensor_ref.h:45
CUTLASS_HOST_DEVICE Element * data() const
Returns the pointer to referenced data.
Definition: tensor_ref.h:254
Coord< kStrideRank, Index > Stride
Stride vector.
Definition: tensor_ref.h:63
CUTLASS_HOST_DEVICE ConstTensorRef const_ref() const
Returns a reference to constant-valued tensor.
Definition: tensor_ref.h:224
int32_t Index
Index type used for coordinates.
Definition: tensor_ref.h:54
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Layout Layout
Mapping function from logical coordinate to linear memory.
Definition: tensor_ref.h:152
CUTLASS_HOST_DEVICE Stride & stride()
Returns the stride of the layout.
Definition: tensor_ref.h:97
CUTLASS_HOST_DEVICE void reset(Element *ptr, Layout const &layout)
Updates the pointer and layout object.
Definition: tensor_ref.h:241
CUTLASS_HOST_DEVICE TensorRef & operator-=(TensorCoord const &b)
Moves this TensorRef back in place by the offset of a given coordinate and returns a reference to it.
Definition: tensor_ref.h:356
CUTLASS_HOST_DEVICE Reference operator[](TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:313
C++ features that may be otherwise unimplemented for CUDA device functions.
CUTLASS_HOST_DEVICE TensorRef(Element *ptr=nullptr, Layout const &layout=Layout())
Constructs a TensorRef with a pointer and layout object.
Definition: tensor_ref.h:207
CUTLASS_HOST_DEVICE Layout layout() const
Returns the layout object.
Definition: tensor_ref.h:271
CUTLASS_HOST_DEVICE TensorRef & add_coord_offset(TensorCoord const &coord)
Adds the linear offset of a logical coordinate to the pointer.
Definition: tensor_ref.h:326
Element Element
Data type of individual access.
Definition: tensor_ref.h:149
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_HOST_DEVICE Reference data(LongIndex idx) const
Returns a reference to the element at a given linear index.
Definition: tensor_ref.h:258
CUTLASS_HOST_DEVICE TensorRef operator-(TensorCoord const &b) const
Returns a TensorRef offset by a given amount.
Definition: tensor_ref.h:348
CUTLASS_HOST_DEVICE TensorRef(NonConstTensorRef const &ref)
Converting constructor from TensorRef to non-constant data.
Definition: tensor_ref.h:217
CUTLASS_HOST_DEVICE Stride stride() const
Returns the layout object's stride vector.
Definition: tensor_ref.h:277
typename Layout::TensorCoord TensorCoord
Coordinate in logical tensor space.
Definition: tensor_ref.h:171
CUTLASS_HOST_DEVICE bool good() const
Returns true if the TensorRef is non-null.
Definition: tensor_ref.h:248
Defines the size of an element in bits.
Definition: numeric_types.h:42
CUTLASS_HOST_DEVICE void reset(Element *ptr=nullptr)
Updates only the pointer.
Definition: tensor_ref.h:235
Definition: subbyte_reference.h:557
Definition: tensor_ref.h:146
typename platform::conditional< sizeof_bits< Element >::value >=8, Element &, SubbyteReference< Element > >::type Reference
Reference type to an element.
Definition: tensor_ref.h:159
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
CUTLASS_HOST_DEVICE LongIndex offset(TensorCoord const &coord) const
Computes the offset of an index from the origin of the tensor.
Definition: tensor_ref.h:301
std::conditional (true specialization)
Definition: platform.h:325
#define static_assert(__e, __m)
Definition: platform.h:153
static int const kRank
Logical rank of tensor.
Definition: tensor_ref.h:48
CUTLASS_HOST_DEVICE NonConstTensorRef non_const_ref() const
Definition: tensor_ref.h:229
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:43
int64_t LongIndex
Long index type used for offsets.
Definition: tensor_ref.h:57
CUTLASS_HOST_DEVICE TensorRef< Element, Layout > make_TensorRef(Element *ptr, Layout const &layout)
Constructs a TensorRef, deducing types from arguments.
Definition: tensor_ref.h:368
typename Layout::Index Index
Index type.
Definition: tensor_ref.h:165
CUTLASS_HOST_DEVICE IdentityTensorLayout(Stride const &stride=Stride())
Definition: tensor_ref.h:81
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:307
Definition: subbyte_reference.h:294
CUTLASS_HOST_DEVICE Layout & layout()
Returns the layout object.
Definition: tensor_ref.h:265
bool TensorRef_aligned(TensorRef< Element, Layout > const &ref, int alignment)
Definition: tensor_ref.h:382
CUTLASS_HOST_DEVICE TensorRef operator+(TensorCoord const &b) const
Returns a TensorRef offset by a given amount.
Definition: tensor_ref.h:333
CUTLASS_HOST_DEVICE TensorRef & operator+=(TensorCoord const &b)
Advances this TensorRef in place by the offset of a given coordinate and returns a reference to it.
Definition: tensor_ref.h:341
CUTLASS_HOST_DEVICE TensorRef & add_pointer_offset(LongIndex offset_)
Adds a linear offset to the pointer.
Definition: tensor_ref.h:319
Provides a mechanism for packing and unpacking elements smaller than one byte.
static int const kStrideRank
Rank of stride vector.
Definition: tensor_ref.h:51
CUTLASS_HOST_DEVICE LongIndex dot(Coord const &b, LongIndex sum=LongIndex(0)) const
Computes the dot product with another Coord object.
Definition: coord.h:246
CUTLASS_HOST_DEVICE int max_dim_index() const
Returns the index of the dimension with greatest value.
Definition: coord.h:130
Basic include for CUTLASS.
typename Layout::LongIndex LongIndex
Long index used for pointer offsets.
Definition: tensor_ref.h:168
CUTLASS_HOST_DEVICE LongIndex capacity(TensorCoord const &size) const
Compute the number of contiguous elements needed to store a tensor with the given size...
Definition: tensor_ref.h:103