Skip to content

Commit

Permalink
Port a test from LFortran (#2782)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanay-man authored Jul 23, 2024
1 parent 542300f commit c5be7c7
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 7 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,7 @@ RUN(NAME lambda_01 LABELS cpython llvm llvm_jit)
RUN(NAME c_mangling LABELS cpython llvm llvm_jit c)
RUN(NAME class_01 LABELS cpython llvm llvm_jit)
RUN(NAME class_02 LABELS cpython llvm llvm_jit)
RUN(NAME class_03 LABELS cpython llvm llvm_jit)

# callback_04 is to test emulation. So just run with cpython
RUN(NAME callback_04 IMPORT_PATH .. LABELS cpython)
Expand Down
3 changes: 1 addition & 2 deletions integration_tests/class_02.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def __init__(self:"Character", name:str, health:i32, attack_power:i32):
self.attack_power : i32 = attack_power
self.is_immortal : bool = False

def attack(self:"Character", other:"Character") -> str:
def attack(self:"Character", other:"Character")->str:
other.health -= self.attack_power
return self.name+" attacks "+ other.name+" for "+str(self.attack_power)+" damage."

Expand Down Expand Up @@ -41,4 +41,3 @@ def main():
assert hero.is_alive() == True

main()

24 changes: 24 additions & 0 deletions integration_tests/class_03.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from lpython import f64
from math import pi

class Circle:
def __init__(self:"Circle", radius:f64):
self.radius :f64 = radius

def circle_area(self:"Circle")->f64:
return pi * self.radius ** 2.0

def circle_print(self:"Circle"):
area : f64 = self.circle_area()
print("Circle: r = ",str(self.radius)," area = ",str(area))

def main():
c : Circle = Circle(1.0)
c.circle_print()
assert abs(c.circle_area() - 3.141593) <= 1e-6
c.radius = 1.5
c.circle_print()
assert abs(c.circle_area() - 7.068583) < 1e-6

if __name__ == "__main__":
main()
38 changes: 33 additions & 5 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1931,7 +1931,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
return ASRUtils::TYPE(ASR::make_Union_t(al, attr_annotation->base.base.loc, import_struct_member));
} else if ( AST::is_a<AST::ConstantStr_t>(annotation) ) {
AST::ConstantStr_t *n = AST::down_cast<AST::ConstantStr_t>(&annotation);
ASR::symbol_t *sym = current_scope->parent->parent->resolve_symbol(n->m_value);
ASR::symbol_t *sym = current_scope->resolve_symbol(n->m_value);
if ( sym == nullptr || !ASR::is_a<ASR::Struct_t>(*sym) ) {
throw SemanticError("Only Struct implemented for constant"
" str annotation", loc);
Expand Down Expand Up @@ -3300,6 +3300,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
if ( AST::is_a<AST::FunctionDef_t>(*x.m_body[i]) ) {
AST::FunctionDef_t*
f = AST::down_cast<AST::FunctionDef_t>(x.m_body[i]);
init_self_type(*f, sym, x.base.base.loc);
if ( std::string(f->m_name) == std::string("__init__") ) {
this->visit_init_body(*f);
} else {
Expand Down Expand Up @@ -3348,6 +3349,30 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
}
}

void init_self_type (const AST::FunctionDef_t &x,
ASR::symbol_t* class_sym, Location loc) {
SymbolTable* parent_scope = current_scope;
ASR::symbol_t *t = current_scope->get_symbol(x.m_name);
if (t==nullptr) {
throw SemanticError("Function not found in current symbol table",
x.base.base.loc);
}
if ( !ASR::is_a<ASR::Function_t>(*t) ) {
throw SemanticError("Only functions implemented in classes",
x.base.base.loc);
}
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
current_scope = f->m_symtab;
if ( f->n_args==0 ) {
return;
}
std::string self_name = x.m_args.m_args[0].m_arg;
ASR::symbol_t* sym = current_scope->get_symbol(self_name);
ASR::Variable_t* self_var = ASR::down_cast<ASR::Variable_t>(sym);
self_var->m_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al,loc, class_sym));
current_scope = parent_scope;
}

virtual void visit_init_body (const AST::FunctionDef_t &/*x*/) = 0;

void add_name(const Location &loc) {
Expand Down Expand Up @@ -5139,7 +5164,8 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
for (const auto &rt: rt_vec) { rts.push_back(al, rt); }
f->m_body = body.p;
f->n_body = body.size();
ASR::FunctionType_t* func_type = ASR::down_cast<ASR::FunctionType_t>(f->m_function_signature);
ASR::FunctionType_t* func_type = ASR::down_cast<ASR::FunctionType_t>(
f->m_function_signature);
func_type->m_restrictions = rts.p;
func_type->n_restrictions = rts.size();
f->m_dependencies = dependencies.p;
Expand Down Expand Up @@ -8033,10 +8059,12 @@ we will have to use something else.
st = get_struct_member(st, call_name, loc);
} else if ( ASR::is_a<ASR::Variable_t>(*st)) {
ASR::Variable_t* var = ASR::down_cast<ASR::Variable_t>(st);
if (ASR::is_a<ASR::StructType_t>(*var->m_type)) {
if (ASR::is_a<ASR::StructType_t>(*var->m_type) ||
ASR::is_a<ASR::Class_t>(*var->m_type) ) {
//TODO: Correct Class and ClassType
// call to struct member function
// modifying args to pass the object as self
ASR::StructType_t* var_struct = ASR::down_cast<ASR::StructType_t>(var->m_type);
ASR::symbol_t* der = ASR::down_cast<ASR::StructType_t>(var->m_type)->m_derived_type;
Vec<ASR::call_arg_t> new_args; new_args.reserve(al, args.n + 1);
ASR::call_arg_t self_arg;
self_arg.loc = args[0].loc;
Expand All @@ -8045,7 +8073,7 @@ we will have to use something else.
for (size_t i=0; i<args.n; i++) {
new_args.push_back(al, args[i]);
}
st = get_struct_member(var_struct->m_derived_type, call_name, loc);
st = get_struct_member(der, call_name, loc);
tmp = make_call_helper(al, st, current_scope, new_args, call_name, loc);
return;
} else {
Expand Down

0 comments on commit c5be7c7

Please sign in to comment.