-
Notifications
You must be signed in to change notification settings - Fork 38
Expand file tree
/
Copy pathctc_decode_operation.cpp
More file actions
121 lines (104 loc) · 5.25 KB
/
ctc_decode_operation.cpp
File metadata and controls
121 lines (104 loc) · 5.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
// Copyright (c) 2019 Shahrzad Shirzad
// Copyright (c) 2019 Hartmut Kaiser
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
#include <phylanx/phylanx.hpp>
#include <hpx/hpx_main.hpp>
#include <hpx/util/lightweight_test.hpp>
#include <cstdint>
#include <string>
#include <utility>
#if defined(PHYLANX_HAVE_BLAZE_TENSOR)
///////////////////////////////////////////////////////////////////////////////
void test_ctc_decode_operation_1()
{
blaze::DynamicTensor<double> arg1{
{{0.2, 0.2, 0.6}, {0.4, 0.3, 0.3}}, {{0.7, 0.15, 0.15}, {0., 0., 0.}}};
blaze::DynamicVector<std::int64_t> arg2{2, 1};
phylanx::execution_tree::primitive y_pred =
phylanx::execution_tree::primitives::create_variable(
hpx::find_here(), phylanx::ir::node_data<double>(arg1));
phylanx::execution_tree::primitive input_length =
phylanx::execution_tree::primitives::create_variable(
hpx::find_here(), phylanx::ir::node_data<std::int64_t>(arg2));
phylanx::execution_tree::primitive greedy =
phylanx::execution_tree::primitives::create_variable(
hpx::find_here(), phylanx::ir::node_data<std::uint8_t>(1));
phylanx::execution_tree::primitive beam_width =
phylanx::execution_tree::primitives::create_variable(
hpx::find_here(), phylanx::ir::node_data<std::int64_t>(10));
phylanx::execution_tree::primitive top_paths =
phylanx::execution_tree::primitives::create_variable(
hpx::find_here(), phylanx::ir::node_data<std::int64_t>(10));
phylanx::execution_tree::primitive ctc_decode =
phylanx::execution_tree::primitives::create_ctc_decode_operation(
hpx::find_here(),
phylanx::execution_tree::primitive_arguments_type{std::move(y_pred),
std::move(input_length), std::move(greedy),
std::move(beam_width), std::move(top_paths)});
hpx::future<phylanx::execution_tree::primitive_argument_type> f =
ctc_decode.eval();
auto result = phylanx::execution_tree::extract_list_value(f.get());
auto it = result.begin();
auto first = phylanx::execution_tree::extract_list_value(*it);
blaze::DynamicMatrix<double> expected_decoded_dense{{0.}, {0.}};
blaze::DynamicMatrix<double> expected_log_prob{{1.42711636}, {0.35667494}};
HPX_TEST_EQ(
phylanx::ir::node_data<double>(std::move(expected_decoded_dense)),
phylanx::execution_tree::extract_numeric_value(*first.begin()));
HPX_TEST(
allclose(phylanx::ir::node_data<double>(std::move(expected_log_prob)),
phylanx::execution_tree::extract_numeric_value(*++it)));
}
void test_ctc_decode_operation_2()
{
blaze::DynamicTensor<double> arg1{
{{1., 0., 0., 0.}, {0., 0., 0.4, 0.6}, {0., 0., 0.4, 0.6},
{0., 0.9, 0.1, 0.}, {0., 0., 0., 0.}, {0., 0., 0., 0.}},
{{0.1, 0.9, 0., 0.}, {0., 0.9, 0.1, 0.}, {0., 0., 0.1, 0.9},
{0., 0.9, 0.1, 0.1}, {0.9, 0.1, 0., 0.}, {0., 0., 0., 0.}}};
blaze::DynamicVector<std::int64_t> arg2{4, 5};
phylanx::execution_tree::primitive y_pred =
phylanx::execution_tree::primitives::create_variable(
hpx::find_here(), phylanx::ir::node_data<double>(arg1));
phylanx::execution_tree::primitive input_length =
phylanx::execution_tree::primitives::create_variable(
hpx::find_here(), phylanx::ir::node_data<std::int64_t>(arg2));
phylanx::execution_tree::primitive greedy =
phylanx::execution_tree::primitives::create_variable(
hpx::find_here(), phylanx::ir::node_data<std::uint8_t>(1));
phylanx::execution_tree::primitive beam_width =
phylanx::execution_tree::primitives::create_variable(
hpx::find_here(), phylanx::ir::node_data<std::int64_t>(10));
phylanx::execution_tree::primitive top_paths =
phylanx::execution_tree::primitives::create_variable(
hpx::find_here(), phylanx::ir::node_data<std::int64_t>(10));
phylanx::execution_tree::primitive ctc_decode =
phylanx::execution_tree::primitives::create_ctc_decode_operation(
hpx::find_here(),
phylanx::execution_tree::primitive_arguments_type{std::move(y_pred),
std::move(input_length), std::move(greedy),
std::move(beam_width), std::move(top_paths)});
hpx::future<phylanx::execution_tree::primitive_argument_type> f =
ctc_decode.eval();
auto result = phylanx::execution_tree::extract_list_value(f.get());
auto it = result.begin();
auto first = phylanx::execution_tree::extract_list_value(*it);
blaze::DynamicMatrix<double> expected_decoded_dense{
{0., 1., -1.}, {1., 1., 0.}};
blaze::DynamicMatrix<double> expected_log_prob{{1.12701166}, {0.52680272}};
HPX_TEST_EQ(
phylanx::ir::node_data<double>(std::move(expected_decoded_dense)),
phylanx::execution_tree::extract_numeric_value(*first.begin()));
HPX_TEST(
allclose(phylanx::ir::node_data<double>(std::move(expected_log_prob)),
phylanx::execution_tree::extract_numeric_value(*++it)));
}
// Test driver: runs both ctc_decode test cases and returns whatever
// hpx::util::report_errors() yields as the process exit code.
int main(int argc, char* argv[])
{
    test_ctc_decode_operation_1();
    test_ctc_decode_operation_2();

    int const exit_code = hpx::util::report_errors();
    return exit_code;
}
#endif