Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yihonglyu committed Nov 15, 2024
1 parent a23d2cc commit def160e
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions onnxruntime/test/optimizer/transpose_optimizer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "test/optimizer/graph_transform_test_builder.h"
#include "test/providers/internal_testing/internal_testing_execution_provider.h"
#include "test/util/include/asserts.h"
#include "test/util/include/default_providers.h"
#include "test/util/include/inference_session_wrapper.h"
#include "test/util/include/test_utils.h"

Expand Down Expand Up @@ -3800,6 +3801,50 @@ TEST(TransposeOptimizerTests, TestCast) {
/*opset_version*/ {15, 18});
}

//#ifndef DISABLE_CONTRIB_OPS

TEST(TransposeOptimizerTests, TestQLinearSoftmax) {
auto build_test_case_1 = [&](ModelTestBuilder& builder) {
auto* input0_arg = MakeInput<uint8_t>(builder, std::nullopt, {1, 384, 384, 21}, 0, 255);
auto* transpose_1_out_0 = builder.MakeIntermediate();
auto* input_x_scale = builder.MakeScalarInitializer<float>(0.5086354613304138);
auto* input_x_zero_point = builder.MakeScalarInitializer<uint8_t>(74);
auto* input_y_scale = builder.MakeScalarInitializer<float>(0.003921568859368563);
auto* input_y_zero_point = builder.MakeScalarInitializer<uint8_t>(0);
auto* qlinearsoftmax_1_out_0 = builder.MakeIntermediate();
auto* transpose_2_out_0 = builder.MakeOutput();

auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0});
transpose_1.AddAttribute("perm", std::vector<int64_t>{0, 3, 1, 2});
auto& qlinearsoftmax_1 = builder.AddNode("QLinearSoftmax",
{transpose_1_out_0, input_x_scale, input_x_zero_point, input_y_scale, input_y_zero_point},
{qlinearsoftmax_1_out_0}, kMSDomain);
qlinearsoftmax_1.AddAttribute("axis", static_cast<int64_t>(1));
qlinearsoftmax_1.AddAttribute("opset", static_cast<int64_t>(13));
auto& transpose_2 = builder.AddNode("Transpose", {qlinearsoftmax_1_out_0}, {transpose_2_out_0});
transpose_2.AddAttribute("perm", std::vector<int64_t>{0, 2, 3, 1});
};

auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) {
int transpose_cost = EstimateTransposeCost(session.GetGraph());
EXPECT_EQ(transpose_cost, 0);
};

TransformerTester(build_test_case_1,
check_optimized_graph_1,
TransformerLevel::Level2,
TransformerLevel::Level3,
/*opset_version*/ 13,
/*per_sample_tolerance*/ 0.0,
/*relative_per_sample_tolerance*/ 0.0,
/*transformer*/ nullptr,
/*add_session_options*/ {},
/*disabled_optimizers*/ {},
/*ep*/ DefaultCpuExecutionProvider());
}

//#endif // DISABLE_CONTRIB_OPS

TEST(TransposeOptimizerTests, TestBroadcastReusedInputs) {
auto build_test_case_1 = [&](ModelTestBuilder& builder) {
auto* input0_arg = MakeInput<float>(builder, {{-1, -1, 3, 4}}, {1, 2, 3, 4}, 0.0, 1.0);
Expand Down

0 comments on commit def160e

Please sign in to comment.