diff --git a/lib/source/pl/core/ast/ast_node_type_application.cpp b/lib/source/pl/core/ast/ast_node_type_application.cpp index a2215d30..b185ff90 100644 --- a/lib/source/pl/core/ast/ast_node_type_application.cpp +++ b/lib/source/pl/core/ast/ast_node_type_application.cpp @@ -127,9 +127,20 @@ namespace pl::core::ast { ast::ASTNode* type = this->getType().get(); if (auto typDecl = dynamic_cast(type); typDecl != nullptr) { + std::vector> templateArgs(this->m_templateArguments.size()); + for (size_t i = 0; i < this->m_templateArguments.size(); i++) { + auto &templateArgument = this->m_templateArguments[i]; + if (auto typeApp = dynamic_cast(templateArgument.get()); typeApp != nullptr) { + templateArgs[i] = templateArgument->evaluate(evaluator); + } + } + + evaluator->setCurrentTemplateArguments(std::move(templateArgs)); return typDecl->getTypeDefinition(evaluator); } else if(auto builtinType = dynamic_cast(type); builtinType != nullptr) { return builtinType; + } else if(auto typeApp = dynamic_cast(type); typeApp != nullptr) { + return typeApp->getTypeDefinition(evaluator); } return nullptr; diff --git a/lib/source/pl/core/ast/ast_node_type_decl.cpp b/lib/source/pl/core/ast/ast_node_type_decl.cpp index e037eedd..68282a5a 100644 --- a/lib/source/pl/core/ast/ast_node_type_decl.cpp +++ b/lib/source/pl/core/ast/ast_node_type_decl.cpp @@ -146,10 +146,31 @@ namespace pl::core::ast { const ASTNode* ASTNodeTypeDecl::getTypeDefinition(Evaluator *evaluator) const { if (m_type == nullptr) err::E0004.throwError(fmt::format("Cannot use incomplete type '{}' before it has been defined.", this->m_name), "Try defining this type further up in your code before trying to instantiate it.", this->getLocation()); - else if (auto typeApp = dynamic_cast(m_type.get()); typeApp != nullptr) + else if (auto typeApp = dynamic_cast(m_type.get()); typeApp != nullptr) { + evaluator->pushTypeTemplateParameters(); + ON_SCOPE_EXIT { + evaluator->popTypeTemplateParameters(); + }; + + auto& templateArguments = evaluator->getCurrentTemplateArguments(); + std::vector> templatePatterns; + for (size_t i = 0; i < this->m_templateParameters.size(); i++) { + auto &templateParameter = this->m_templateParameters[i]; + + if (i >= templateArguments.size()) { + break; + } + + if (templateParameter->isType()) { + auto& argument = templateArguments[i]; + evaluator->getTypeTemplateParameters().emplace_back(std::move(argument)); + } + } + return typeApp->getTypeDefinition(evaluator); - else + } else { return m_type.get(); + } } [[nodiscard]] const std::string ASTNodeTypeDecl::getTypeName() const { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 4e64663b..7ba8d8d6 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -35,6 +35,7 @@ set(AVAILABLE_TESTS TemplateParametersScope TypeNameOf CustomBuiltInType + Using ) diff --git a/tests/include/test_patterns/test_pattern_using.hpp b/tests/include/test_patterns/test_pattern_using.hpp new file mode 100644 index 00000000..f85458e1 --- /dev/null +++ b/tests/include/test_patterns/test_pattern_using.hpp @@ -0,0 +1,51 @@ +#pragma once + +#include "test_pattern.hpp" + +namespace pl::test { + + class TestPatternUsing : public TestPattern { + public: + TestPatternUsing(core::Evaluator *evaluator) : TestPattern(evaluator, "Using") { + } + ~TestPatternUsing() override = default; + + [[nodiscard]] std::string getSourceCode() const override { + return R"( + using US = T; + US u; + US v = 64; + u = 64; + std::assert(u == 64 && v == 64, "u,v should be 64"); + + using UST = US; + UST ust; + UST vst = 16; + ust = 16; + std::assert(ust == 16 && vst == 16, "ust,vst should be 16"); + + struct USS { + US us = 16; + }; + USS uss; + std::assert(uss.us == 16, "us should be 16"); + + USS> ussus; + std::assert(ussus.us == 16, "ussus should be 16"); + + US us2[2]; + US us3[2] @ 0; + std::assert(us3[0] == 137, "us3[0] should be 137"); + + namespace A { + using US = T; + } + A::US us4; + A::US us5 @ 0; + us4 = us5; + std::assert(us5 == 137 && us4 == 137, "us4, us5 should be 137"); + )"; + } + }; + +} \ No newline at end of file diff --git a/tests/source/tests.cpp b/tests/source/tests.cpp index 36fb4e42..c52e7e46 100644 --- a/tests/source/tests.cpp +++ b/tests/source/tests.cpp @@ -30,6 +30,7 @@ #include "test_patterns/test_pattern_template_parameters_scope.hpp" #include "test_patterns/test_pattern_typenameof.hpp" #include "test_patterns/test_pattern_custom_builtin_type.hpp" +#include "test_patterns/test_pattern_using.hpp" static pl::core::Evaluator s_evaluator; @@ -65,4 +66,5 @@ std::array Tests = { TEST(TemplateParametersScope), TEST(TypeNameOf), TEST(CustomBuiltinType), + TEST(Using), };