@@ -98,28 +98,23 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
9898 if (t.has_data()) {
9999 constexpr std::array<size_t, 2> block_shape{1, 32};
100100 const std::array<size_t, 2> expected{
101- DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[0]), block_alignment[0]),
102- DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[1]), block_alignment[1])
103- };
104- NVTE_CHECK(t.scale_inv.shape.size() == 2
105- && t.scale_inv.shape[0] == expected[0]
106- && t.scale_inv.shape[1] == expected[1],
107- "Tensor \"", name,
108- "\" has invalid scale_inv shape (expected ", expected,
101+ DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[0]), block_alignment[0]),
102+ DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[1]), block_alignment[1])};
103+ NVTE_CHECK(t.scale_inv.shape.size() == 2 && t.scale_inv.shape[0] == expected[0] &&
104+ t.scale_inv.shape[1] == expected[1],
105+ "Tensor \"", name, "\" has invalid scale_inv shape (expected ", expected,
109106 ", got ", t.scale_inv.shape, ")");
110107 }
111108 if (t.has_columnwise_data()) {
112109 constexpr std::array<size_t, 2> block_shape{32, 1};
113110 const std::array<size_t, 2> expected{
114- DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[0]), block_alignment[1]),
115- DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[1]), block_alignment[0])
116- };
117- NVTE_CHECK(t.columnwise_scale_inv.shape.size() == 2
118- && t.columnwise_scale_inv.shape[0] == expected[0]
119- && t.columnwise_scale_inv.shape[1] == expected[1],
120- "Tensor \"", name,
121- "\" has invalid columnwise_scale_inv shape (expected ", expected,
122- ", got ", t.scale_inv.shape, ")");
111+ DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[0]), block_alignment[1]),
112+ DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[1]), block_alignment[0])};
113+ NVTE_CHECK(t.columnwise_scale_inv.shape.size() == 2 &&
114+ t.columnwise_scale_inv.shape[0] == expected[0] &&
115+ t.columnwise_scale_inv.shape[1] == expected[1],
116+ "Tensor \"", name, "\" has invalid columnwise_scale_inv shape (expected ",
117+ expected, ", got ", t.columnwise_scale_inv.shape, ")");
123118 }
124119 } else if (t.scaling_mode == NVTE_NVFP4_1D_SCALING) {
125120 const auto [first_dim, last_dim] = t.flat_2d_dims();
@@ -128,29 +123,24 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
128123 constexpr std::array<size_t, 2> block_shape{1, 16};
129124 constexpr std::array<size_t, 2> block_alignment{128, 4};
130125 const std::array<size_t, 2> expected{
131- DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[0]), block_alignment[0]),
132- DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[1]), block_alignment[1])
133- };
134- NVTE_CHECK(t.scale_inv.shape.size() == 2
135- && t.scale_inv.shape[0] == expected[0]
136- && t.scale_inv.shape[1] == expected[1],
137- "Tensor \"", name,
138- "\" has invalid scale_inv shape (expected ", expected,
126+ DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[0]), block_alignment[0]),
127+ DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[1]), block_alignment[1])};
128+ NVTE_CHECK(t.scale_inv.shape.size() == 2 && t.scale_inv.shape[0] == expected[0] &&
129+ t.scale_inv.shape[1] == expected[1],
130+ "Tensor \"", name, "\" has invalid scale_inv shape (expected ", expected,
139131 ", got ", t.scale_inv.shape, ")");
140132 }
141133 if (t.has_columnwise_data()) {
142134 constexpr std::array<size_t, 2> block_shape{1, 16};
143135 constexpr std::array<size_t, 2> block_alignment{128, 4};
144136 const std::array<size_t, 2> expected{
145- DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[0]), block_alignment[0]),
146- DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[1]), block_alignment[1])
147- };
148- NVTE_CHECK(t.columnwise_scale_inv.shape.size() == 2
149- && t.columnwise_scale_inv.shape[0] == expected[0]
150- && t.columnwise_scale_inv.shape[1] == expected[1],
151- "Tensor \"", name,
152- "\" has invalid columnwise_scale_inv shape (expected ", expected,
153- ", got ", t.scale_inv.shape, ")");
137+ DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_shape[0]), block_alignment[0]),
138+ DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_shape[1]), block_alignment[1])};
139+ NVTE_CHECK(t.columnwise_scale_inv.shape.size() == 2 &&
140+ t.columnwise_scale_inv.shape[0] == expected[0] &&
141+ t.columnwise_scale_inv.shape[1] == expected[1],
142+ "Tensor \"", name, "\" has invalid columnwise_scale_inv shape (expected ",
143+ expected, ", got ", t.columnwise_scale_inv.shape, ")");
154144 }
155145 }
156146 }
0 commit comments