Skip to content

Commit 9468261

Browse files
committed
fix
1 parent 38f8cdf commit 9468261

5 files changed

Lines changed: 8 additions & 4 deletions

File tree

backends/gcu/common/gcu_funcs.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
namespace custom_kernel {
2424
// using Tensor = phi::DenseTensor;
2525
// using Context = phi::CustomContext;
26+
// using DataType = phi::DataType;
27+
// using DataLayout = phi::DataLayout;
2628
/**
2729
* CPU -> GCU
2830
* GCU -> CPU

backends/gcu/custom_engine/custom_engine_interface.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ C_Status CustomEngineOpLower(C_CustomEngineLowerParams* lower_param) {
240240
pir::Block* op_block = &(region.front());
241241

242242
// process subgraph block
243-
paddle::dialect::ProcessBlock(
243+
pir::ProcessBlock(
244244
*place, sub_graph_block, op_block, ctx, map_op_pair, map_value_pair);
245245

246246
if (VLOG_IS_ON(3)) {

backends/gcu/tests/unittests/test_cross.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ def prepare_data(self):
6969
def forward_with_dtype(self, dtype):
7070
x = paddle.to_tensor(self.data_x, dtype=dtype)
7171
y = paddle.to_tensor(self.data_y, dtype=dtype)
72+
if self.axis is None:
73+
return paddle.cross(x, y)
7274
return paddle.cross(x, y, self.axis)
7375

7476
def forward(self):

backends/gcu/tests/unittests/test_layer_norm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ def test_check_output(
9999
self.with_weight = with_weight
100100
self.with_bias = with_bias
101101
self.weight_bias_shape = [np.prod(self.normalized_shape)]
102-
rtol = 1e-5
103-
atol = 1e-5
102+
rtol = 1e-4
103+
atol = 1e-4
104104
if dtype == np.float16:
105105
rtol = 1e-3
106106
atol = 1e-3

backends/gcu/tests/unittests/test_rnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def forward(self):
8484
self.input_size,
8585
self.hidden_size,
8686
self.num_layers,
87-
self.direction,
87+
direction=self.direction,
8888
time_major=False,
8989
weight_ih_attr=weight_ih_attr,
9090
weight_hh_attr=weight_hh_attr,

0 commit comments

Comments
 (0)