Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial implementation of Inheritance and Polymorphic functions #2801

Merged
merged 2 commits into from
Aug 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,9 @@ RUN(NAME class_01 LABELS cpython llvm llvm_jit)
RUN(NAME class_02 LABELS cpython llvm llvm_jit)
RUN(NAME class_03 LABELS cpython llvm llvm_jit)
RUN(NAME class_04 LABELS cpython llvm llvm_jit)
RUN(NAME class_05 LABELS cpython llvm llvm_jit)
RUN(NAME class_06 LABELS cpython llvm llvm_jit)


# callback_04 is to test emulation. So just run with cpython
RUN(NAME callback_04 IMPORT_PATH .. LABELS cpython)
Expand Down
37 changes: 37 additions & 0 deletions integration_tests/class_05.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from lpython import i32

class Animal:
def __init__(self:"Animal"):
self.species: str = "Generic Animal"
self.age: i32 = 0
self.is_domestic: bool = True

class Dog(Animal):
def __init__(self:"Dog", name:str, age:i32):
super().__init__()
self.species: str = "Dog"
self.name: str = name
self.age: i32 = age

class Cat(Animal):
def __init__(self:"Cat", name: str, age: i32):
super().__init__()
self.species: str = "Cat"
self.name:str = name
self.age: i32 = age

def main():
dog: Dog = Dog("Buddy", 5)
cat: Cat = Cat("Whiskers", 3)
op1: str = str(dog.name+" is a "+str(dog.age)+"-year-old "+dog.species+".")
print(op1)
assert op1 == "Buddy is a 5-year-old Dog."
print(dog.is_domestic)
assert dog.is_domestic == True
op2: str = str(cat.name+ " is a "+ str(cat.age)+ "-year-old "+ cat.species+ ".")
print(op2)
assert op2 == "Whiskers is a 3-year-old Cat."
print(cat.is_domestic)
assert cat.is_domestic == True

main()
36 changes: 36 additions & 0 deletions integration_tests/class_06.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from lpython import i32

class Base():
def __init__(self:"Base"):
self.x : i32 = 10

def get_x(self:"Base")->i32:
print(self.x)
return self.x

#Testing polymorphic fn calls
def get_x_static(d: Base)->i32:
print(d.x)
return d.x

class Derived(Base):
def __init__(self: "Derived"):
super().__init__()
self.y : i32 = 20

def get_y(self:"Derived")->i32:
print(self.y)
return self.y


def main():
d : Derived = Derived()
x : i32 = get_x_static(d)
assert x == 10
# Testing parent method call using der obj
x = d.get_x()
assert x == 10
y: i32 = d.get_y()
assert y == 20

main()
2 changes: 1 addition & 1 deletion src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ ttype
| Array(ttype type, dimension* dims, array_physical_type physical_type)
| FunctionType(ttype* arg_types, ttype? return_var_type, abi abi, deftype deftype, string? bindc_name, bool elemental, bool pure, bool module, bool inline, bool static, symbol* restrictions, bool is_restriction)

cast_kind = RealToInteger | IntegerToReal | LogicalToReal | RealToReal | IntegerToInteger | RealToComplex | IntegerToComplex | IntegerToLogical | RealToLogical | CharacterToLogical | CharacterToInteger | CharacterToList | ComplexToLogical | ComplexToComplex | ComplexToReal | ComplexToInteger | LogicalToInteger | RealToCharacter | IntegerToCharacter | LogicalToCharacter | UnsignedIntegerToInteger | UnsignedIntegerToUnsignedInteger | UnsignedIntegerToReal | UnsignedIntegerToLogical | IntegerToUnsignedInteger | RealToUnsignedInteger | CPtrToUnsignedInteger | UnsignedIntegerToCPtr | IntegerToSymbolicExpression | ListToArray
cast_kind = RealToInteger | IntegerToReal | LogicalToReal | RealToReal | IntegerToInteger | RealToComplex | IntegerToComplex | IntegerToLogical | RealToLogical | CharacterToLogical | CharacterToInteger | CharacterToList | ComplexToLogical | ComplexToComplex | ComplexToReal | ComplexToInteger | LogicalToInteger | RealToCharacter | IntegerToCharacter | LogicalToCharacter | UnsignedIntegerToInteger | UnsignedIntegerToUnsignedInteger | UnsignedIntegerToReal | UnsignedIntegerToLogical | IntegerToUnsignedInteger | RealToUnsignedInteger | CPtrToUnsignedInteger | UnsignedIntegerToCPtr | IntegerToSymbolicExpression | ListToArray | DerivedToBase
storage_type = Default | Save | Parameter
access = Public | Private
intent = Local | In | Out | InOut | ReturnVar | Unspecified
Expand Down
3 changes: 2 additions & 1 deletion src/libasr/casting_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ namespace LCompilers::CastingUtil {
{ASR::ttypeType::Complex, ASR::cast_kindType::ComplexToComplex},
{ASR::ttypeType::Real, ASR::cast_kindType::RealToReal},
{ASR::ttypeType::Integer, ASR::cast_kindType::IntegerToInteger},
{ASR::ttypeType::UnsignedInteger, ASR::cast_kindType::UnsignedIntegerToUnsignedInteger}
{ASR::ttypeType::UnsignedInteger, ASR::cast_kindType::UnsignedIntegerToUnsignedInteger},
{ASR::ttypeType::StructType, ASR::cast_kindType::DerivedToBase}
};

int get_type_priority(ASR::ttypeType type) {
Expand Down
5 changes: 5 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7725,6 +7725,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
tmp = LLVM::CreateLoad(*builder, list_api->get_pointer_to_list_data(tmp));
break;
}
case (ASR::cast_kindType::DerivedToBase) : {
this->visit_expr(*x.m_arg);
tmp = llvm_utils->create_gep(tmp, 0);
break;
}
default : throw CodeGenError("Cast kind not implemented");
}
}
Expand Down
151 changes: 121 additions & 30 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -784,9 +784,26 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
ASR::call_arg_t c_arg;
c_arg.loc = args[i].loc;
c_arg.m_value = args[i].m_value;
cast_helper(m_args[i], c_arg.m_value, true);
ASR::ttype_t* left_type = ASRUtils::expr_type(m_args[i]);
ASR::ttype_t* right_type = ASRUtils::expr_type(c_arg.m_value);
if ( ASR::is_a<ASR::StructType_t>(*left_type) && ASR::is_a<ASR::StructType_t>(*right_type) ) {
ASR::StructType_t *l_type = ASR::down_cast<ASR::StructType_t>(left_type);
ASR::StructType_t *r_type = ASR::down_cast<ASR::StructType_t>(right_type);
ASR::Struct_t *l2_type = ASR::down_cast<ASR::Struct_t>(
ASRUtils::symbol_get_past_external(
l_type->m_derived_type));
ASR::Struct_t *r2_type = ASR::down_cast<ASR::Struct_t>(
ASRUtils::symbol_get_past_external(
r_type->m_derived_type));
if ( ASRUtils::is_derived_type_similar(l2_type, r2_type) ) {
cast_helper(m_args[i], c_arg.m_value, true, true);
check_type_equality = false;
} else {
cast_helper(m_args[i], c_arg.m_value, true);
}
} else {
cast_helper(m_args[i], c_arg.m_value, true);
}
if( check_type_equality && !ASRUtils::check_equal_type(left_type, right_type) ) {
std::string ltype = ASRUtils::type_to_str_python(left_type);
std::string rtype = ASRUtils::type_to_str_python(right_type);
Expand Down Expand Up @@ -2962,9 +2979,8 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
std::string obj_name = x.m_args.m_args->m_arg;
for(size_t i = 0; i < x.n_body; i++) {
std::string var_name;
if (! AST::is_a<AST::AnnAssign_t>(*x.m_body[i]) ){
throw SemanticError("Only AnnAssign implemented in __init__ ",
x.m_body[i]->base.loc);
if ( !AST::is_a<AST::AnnAssign_t>(*x.m_body[i]) ){
continue;
}
AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
if(AST::is_a<AST::Attribute_t>(*ann_assign.m_target)){
Expand Down Expand Up @@ -3301,10 +3317,21 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
current_scope->add_symbol(x_m_name, class_type);
}
} else {
if( x.n_bases > 0 ) {
throw SemanticError("Inheritance in classes isn't supported yet.",
ASR::symbol_t* parent = nullptr;
if( x.n_bases > 1 ) {
throw SemanticError("Multiple inheritance in classes isn't supported yet.",
x.base.base.loc);
}
else if (x.n_bases == 1) {
std::string b_name = "";
if ( AST::is_a<AST::Name_t>(*x.m_bases[0]) ) {
b_name = AST::down_cast<AST::Name_t>(x.m_bases[0])->m_id;
} else {
throw SemanticError("Expected a Name here", x.base.base.loc);
}
parent = current_scope->resolve_symbol(b_name);
LCOMPILERS_ASSERT(ASR::is_a<ASR::Struct_t>(*parent));
}
SymbolTable *parent_scope = current_scope;
if( ASR::symbol_t* sym = current_scope->resolve_symbol(x_m_name) ) {
LCOMPILERS_ASSERT(ASR::is_a<ASR::Struct_t>(*sym));
Expand All @@ -3316,7 +3343,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
f = AST::down_cast<AST::FunctionDef_t>(x.m_body[i]);
init_self_type(*f, sym, x.base.base.loc);
if ( std::string(f->m_name) == std::string("__init__") ) {
this->visit_init_body(*f);
this->visit_init_body(*f, st->m_parent, x.m_body[i]->base.loc);
} else {
this->visit_stmt(*x.m_body[i]);
}
Expand Down Expand Up @@ -3344,7 +3371,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
member_names.p, member_names.size(), member_fn_names.p,
member_fn_names.size(), class_abi, ASR::accessType::Public,
false, false, member_init.p, member_init.size(),
nullptr, nullptr));
nullptr, parent));
parent_scope->add_symbol(x.m_name, class_sym);
visit_ClassMembers(x, member_names, member_fn_names,
struct_dependencies, member_init, false, class_abi, true);
Expand Down Expand Up @@ -3387,7 +3414,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
current_scope = parent_scope;
}

virtual void visit_init_body (const AST::FunctionDef_t &/*x*/) = 0;
virtual void visit_init_body (const AST::FunctionDef_t &/*x*/, ASR::symbol_t* /*parent_sym*/, const Location /*loc*/) = 0;

void add_name(const Location &loc) {
std::string var_name = "__name__";
Expand Down Expand Up @@ -4421,7 +4448,7 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
// Implement visit_Global for Symbol Table visitor.
void visit_Global(const AST::Global_t &/*x*/) {}

void visit_init_body (const AST::FunctionDef_t &/*x*/) {
void visit_init_body (const AST::FunctionDef_t &/*x*/, ASR::symbol_t* /*parent_sym*/, const Location /*loc*/) {
//Implemented in BodyVisitor
}

Expand Down Expand Up @@ -5153,7 +5180,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
tmp = asr;
}

void visit_init_body (const AST::FunctionDef_t &x) {
void visit_init_body (const AST::FunctionDef_t &x, ASR::symbol_t* parent_sym, const Location loc) {
SymbolTable *old_scope = current_scope;
ASR::symbol_t *t = current_scope->get_symbol("__init__");
if ( t==nullptr ) {
Expand All @@ -5163,31 +5190,77 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
throw SemanticError("__init__ is not a function", x.base.base.loc);
}
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
current_scope = f->m_symtab;
//Transform statements into correct format
Vec<AST::stmt_t*> new_body;
new_body.reserve(al, 1);
Vec<AST::stmt_t*> body;
body.reserve(al, 1);
ASR::stmt_t* super_call_stmt = nullptr;
for (size_t i=0; i<x.n_body; i++) {
AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
if ( ann_assign.m_value != nullptr ) {
Vec<AST::expr_t*>target;
target.reserve(al, 1);
target.push_back(al, ann_assign.m_target);
AST::ast_t* assgn_ast = AST::make_Assign_t(al, ann_assign.base.base.loc,
target.p, 1, ann_assign.m_value, nullptr);
AST::stmt_t* assgn = AST::down_cast<AST::stmt_t>(assgn_ast);
new_body.push_back(al, assgn);
if (AST::is_a<AST::AnnAssign_t>(*x.m_body[i])) {
AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
if ( ann_assign.m_value != nullptr ) {
Vec<AST::expr_t*>target;
target.reserve(al, 1);
target.push_back(al, ann_assign.m_target);
AST::ast_t* assgn_ast = AST::make_Assign_t(al, ann_assign.base.base.loc,
target.p, 1, ann_assign.m_value, nullptr);
AST::stmt_t* assgn = AST::down_cast<AST::stmt_t>(assgn_ast);
body.push_back(al, assgn);
}
} else if (AST::is_a<AST::Expr_t>(*x.m_body[i]) &&
AST::is_a<AST::Call_t>(*(AST::down_cast<AST::Expr_t>(x.m_body[i])->m_value))) {
AST::Call_t* c = AST::down_cast<AST::Call_t>(AST::down_cast<AST::Expr_t>(x.m_body[i])->m_value);

if ( !AST::is_a<AST::Attribute_t>(*(c->m_func))
|| !AST::is_a<AST::Call_t>(*(AST::down_cast<AST::Attribute_t>(c->m_func)->m_value)) ) {
body.push_back(al, x.m_body[i]);
continue;
}
AST::Call_t* super_call = AST::down_cast<AST::Call_t>(AST::down_cast<AST::Attribute_t>(c->m_func)->m_value);
std::string attr = AST::down_cast<AST::Attribute_t>(c->m_func)->m_attr;
if ( AST::is_a<AST::Name_t>(*(super_call->m_func)) &&
std::string(AST::down_cast<AST::Name_t>(super_call->m_func)->m_id)=="super" &&
attr == "__init__") {
if (parent_sym == nullptr) {
throw SemanticError("The class doesn't have a base class",loc);
}
Vec<ASR::call_arg_t> args;
args.reserve(al, 1);
parse_args(*super_call,args);
ASR::call_arg_t first_arg;
first_arg.loc = loc;
ASR::symbol_t* self_sym = current_scope->get_symbol("self");
first_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al,loc,self_sym));
ASR::ttype_t* target_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al,loc,parent_sym));
cast_helper(target_type, first_arg.m_value, x.base.base.loc, true);
Vec<ASR::call_arg_t> args_w_first; args_w_first.reserve(al,1);
args_w_first.push_back(al, first_arg);
for( size_t i = 0; i < args.size(); i++ ) {
args_w_first.push_back(al,args[i]);
}
std::string call_name = "__init__";
ASR::symbol_t* call_sym = get_struct_member(parent_sym,call_name,loc);
super_call_stmt = ASRUtils::STMT(
ASR::make_SubroutineCall_t(al, loc, call_sym, call_sym, args_w_first.p,
args_w_first.size(), nullptr));
}
} else {
body.push_back(al, x.m_body[i]);
}
}
current_scope = f->m_symtab;
tanay-man marked this conversation as resolved.
Show resolved Hide resolved
Vec<ASR::stmt_t*> body;
body.reserve(al, x.n_body);
Vec<ASR::stmt_t*> body_asr;
body_asr.reserve(al, x.n_body);
if ( super_call_stmt ) {
body_asr.push_back(al, super_call_stmt);
}
Vec<ASR::symbol_t*> rts;
rts.reserve(al, 4);
dependencies.clear(al);
transform_stmts(body, new_body.n, new_body.p);
transform_stmts(body_asr, body.n, body.p);
for (const auto &rt: rt_vec) { rts.push_back(al, rt); }
f->m_body = body.p;
f->n_body = body.size();
f->m_body = body_asr.p;
f->n_body = body_asr.size();
ASR::FunctionType_t* func_type = ASR::down_cast<ASR::FunctionType_t>(
f->m_function_signature);
func_type->m_restrictions = rts.p;
Expand Down Expand Up @@ -6239,10 +6312,14 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
for( size_t i = 0; i < der_type->n_members && !member_found; i++ ) {
member_found = std::string(der_type->m_members[i]) == member_name;
}
if( !member_found ) {
if( !member_found && !der_type->m_parent ) {
throw SemanticError("No member " + member_name +
" found in " + std::string(der_type->m_name),
loc);
} else if ( !member_found && der_type->m_parent ) {
ASR::ttype_t* parent_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al, loc,der_type->m_parent));
visit_AttributeUtil(parent_type,attr_char,t,loc);
return;
}
ASR::expr_t *val = ASR::down_cast<ASR::expr_t>(ASR::make_Var_t(al, loc, t));
ASR::symbol_t* member_sym = der_type->m_symtab->resolve_symbol(member_name);
Expand Down Expand Up @@ -8064,7 +8141,8 @@ we will have to use something else.
//TODO: Correct Class and ClassType
// call to struct member function
// modifying args to pass the object as self
ASR::symbol_t* der = ASR::down_cast<ASR::StructType_t>(var->m_type)->m_derived_type;
ASR::symbol_t* der_sym = ASR::down_cast<ASR::StructType_t>(var->m_type)->m_derived_type;
ASR::Struct_t* der = ASR::down_cast<ASR::Struct_t>(der_sym);
Vec<ASR::call_arg_t> new_args; new_args.reserve(al, args.n + 1);
ASR::call_arg_t self_arg;
self_arg.loc = args[0].loc;
Expand All @@ -8073,7 +8151,20 @@ we will have to use something else.
for (size_t i=0; i<args.n; i++) {
new_args.push_back(al, args[i]);
}
st = get_struct_member(der, call_name, loc);
if ( der->m_symtab->get_symbol(call_name) ) {
st = get_struct_member(der_sym, call_name, loc);
} else if ( der->m_parent ) {
ASR::Struct_t* parent = ASR::down_cast<ASR::Struct_t>(der->m_parent);
if ( !parent->m_symtab->get_symbol(call_name) ) {
throw SemanticError("Method not found in the class "+ std::string(der->m_name) +
" or it's parents",loc);
} else {
st = get_struct_member(der->m_parent, call_name, loc);
}
} else {
throw SemanticError("Method not found in the class "+std::string(der->m_name)+
" or it's parents",loc);
}
tmp = make_call_helper(al, st, current_scope, new_args, call_name, loc);
return;
} else {
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-structs_09-f3ffe08.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@
"stdout": null,
"stdout_hash": null,
"stderr": "asr-structs_09-f3ffe08.stderr",
"stderr_hash": "f59ab2d213f6423e0a891e43d5a19e83d4405391b1c7bf481b4b939e",
"stderr_hash": "14119a0bc6420ad242b99395d457f2092014d96d2a1ac81d376c649d",
"returncode": 2
}
Loading
Loading