diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 612fc32cc3..f397765d18 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -835,6 +835,7 @@ RUN(NAME lambda_01 LABELS cpython llvm llvm_jit) RUN(NAME c_mangling LABELS cpython llvm llvm_jit c) RUN(NAME class_01 LABELS cpython llvm llvm_jit) RUN(NAME class_02 LABELS cpython llvm llvm_jit) +RUN(NAME class_03 LABELS cpython llvm llvm_jit) # callback_04 is to test emulation. So just run with cpython RUN(NAME callback_04 IMPORT_PATH .. LABELS cpython) diff --git a/integration_tests/class_02.py b/integration_tests/class_02.py index 11325b1c06..94d92a9ec6 100644 --- a/integration_tests/class_02.py +++ b/integration_tests/class_02.py @@ -6,7 +6,7 @@ def __init__(self:"Character", name:str, health:i32, attack_power:i32): self.attack_power : i32 = attack_power self.is_immortal : bool = False - def attack(self:"Character", other:"Character") -> str: + def attack(self:"Character", other:"Character")->str: other.health -= self.attack_power return self.name+" attacks "+ other.name+" for "+str(self.attack_power)+" damage." @@ -41,4 +41,3 @@ def main(): assert hero.is_alive() == True main() - \ No newline at end of file diff --git a/integration_tests/class_03.py b/integration_tests/class_03.py new file mode 100644 index 0000000000..8e4d9eded6 --- /dev/null +++ b/integration_tests/class_03.py @@ -0,0 +1,24 @@ +from lpython import f64 +from math import pi + +class Circle: + def __init__(self:"Circle", radius:f64): + self.radius :f64 = radius + + def circle_area(self:"Circle")->f64: + return pi * self.radius ** 2.0 + + def circle_print(self:"Circle"): + area : f64 = self.circle_area() + print("Circle: r = ",str(self.radius)," area = ",str(area)) + +def main(): + c : Circle = Circle(1.0) + c.circle_print() + assert abs(c.circle_area() - 3.141593) <= 1e-6 + c.radius = 1.5 + c.circle_print() + assert abs(c.circle_area() - 7.068583) < 1e-6 + +if __name__ == "__main__": + main() diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 97a9b38fbd..4715cf8857 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -1931,7 +1931,7 @@ class CommonVisitor : public AST::BaseVisitor { return ASRUtils::TYPE(ASR::make_Union_t(al, attr_annotation->base.base.loc, import_struct_member)); } else if ( AST::is_a(annotation) ) { AST::ConstantStr_t *n = AST::down_cast(&annotation); - ASR::symbol_t *sym = current_scope->parent->parent->resolve_symbol(n->m_value); + ASR::symbol_t *sym = current_scope->resolve_symbol(n->m_value); if ( sym == nullptr || !ASR::is_a(*sym) ) { throw SemanticError("Only Struct implemented for constant" " str annotation", loc); @@ -3300,6 +3300,7 @@ class CommonVisitor : public AST::BaseVisitor { if ( AST::is_a(*x.m_body[i]) ) { AST::FunctionDef_t* f = AST::down_cast(x.m_body[i]); + init_self_type(*f, sym, x.base.base.loc); if ( std::string(f->m_name) == std::string("__init__") ) { this->visit_init_body(*f); } else { @@ -3348,6 +3349,30 @@ class CommonVisitor : public AST::BaseVisitor { } } + void init_self_type (const AST::FunctionDef_t &x, + ASR::symbol_t* class_sym, Location loc) { + SymbolTable* parent_scope = current_scope; + ASR::symbol_t *t = current_scope->get_symbol(x.m_name); + if (t==nullptr) { + throw SemanticError("Function not found in current symbol table", + x.base.base.loc); + } + if ( !ASR::is_a(*t) ) { + throw SemanticError("Only functions implemented in classes", + x.base.base.loc); + } + ASR::Function_t *f = ASR::down_cast(t); + current_scope = f->m_symtab; + if ( f->n_args==0 ) { + return; + } + std::string self_name = x.m_args.m_args[0].m_arg; + ASR::symbol_t* sym = current_scope->get_symbol(self_name); + ASR::Variable_t* self_var = ASR::down_cast(sym); + self_var->m_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al,loc, class_sym)); + current_scope = parent_scope; + } + virtual void visit_init_body (const AST::FunctionDef_t &/*x*/) = 0; void add_name(const Location &loc) { @@ -5139,7 +5164,8 @@ class BodyVisitor : public CommonVisitor { for (const auto &rt: rt_vec) { rts.push_back(al, rt); } f->m_body = body.p; f->n_body = body.size(); - ASR::FunctionType_t* func_type = ASR::down_cast(f->m_function_signature); + ASR::FunctionType_t* func_type = ASR::down_cast( + f->m_function_signature); func_type->m_restrictions = rts.p; func_type->n_restrictions = rts.size(); f->m_dependencies = dependencies.p; @@ -8033,10 +8059,12 @@ we will have to use something else. st = get_struct_member(st, call_name, loc); } else if ( ASR::is_a(*st)) { ASR::Variable_t* var = ASR::down_cast(st); - if (ASR::is_a(*var->m_type)) { + if (ASR::is_a(*var->m_type) || + ASR::is_a(*var->m_type) ) { + //TODO: Correct Class and ClassType // call to struct member function // modifying args to pass the object as self - ASR::StructType_t* var_struct = ASR::down_cast(var->m_type); + ASR::symbol_t* der = ASR::down_cast(var->m_type)->m_derived_type; Vec new_args; new_args.reserve(al, args.n + 1); ASR::call_arg_t self_arg; self_arg.loc = args[0].loc; @@ -8045,7 +8073,7 @@ we will have to use something else. for (size_t i=0; im_derived_type, call_name, loc); + st = get_struct_member(der, call_name, loc); tmp = make_call_helper(al, st, current_scope, new_args, call_name, loc); return; } else {