Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor and improve type checking. #28481

Draft
wants to merge 1 commit into
base: mainnet
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/ast/src/functions/core_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
use leo_span::{Symbol, sym};

/// A core instruction that maps directly to an AVM bytecode instruction.
#[derive(Clone, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CoreFunction {
BHP256CommitToAddress,
BHP256CommitToField,
Expand Down
2 changes: 1 addition & 1 deletion compiler/ast/src/struct/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pub struct Composite {

impl PartialEq for Composite {
fn eq(&self, other: &Self) -> bool {
self.identifier == other.identifier
self.identifier == other.identifier && self.external == other.external
}
}

Expand Down
6 changes: 5 additions & 1 deletion compiler/ast/src/types/struct_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ pub struct CompositeType {

impl fmt::Display for CompositeType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{id}", id = self.id)
if let Some(program) = self.program {
write!(f, "{}.aleo/{}", program, self.id)
} else {
write!(f, "{}", self.id)
}
}
}
3 changes: 2 additions & 1 deletion compiler/ast/src/types/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

use crate::Type;

use itertools::Itertools as _;
use serde::{Deserialize, Serialize};
use std::fmt;

Expand Down Expand Up @@ -44,6 +45,6 @@ impl TupleType {

impl fmt::Display for TupleType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "({})", self.elements.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(","))
write!(f, "({})", self.elements.iter().format(", "))
}
}
51 changes: 3 additions & 48 deletions compiler/ast/src/types/type_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use snarkvm::prelude::{
use std::fmt;

/// Explicit type used for defining a variable or expression type
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Type {
/// The `address` type.
Address,
Expand Down Expand Up @@ -61,56 +61,11 @@ pub enum Type {
Unit,
/// Placeholder for a type that could not be resolved or was not well-formed.
/// Will eventually lead to a compile error.
#[default]
Err,
}

impl Type {
///
/// Returns `true` if the self `Type` is equal to the other `Type`.
///
/// Flattens array syntax: `[[u8; 1]; 2] == [u8; (2, 1)] == true`
///
pub fn eq_flat(&self, other: &Self) -> bool {
match (self, other) {
(Type::Address, Type::Address)
| (Type::Boolean, Type::Boolean)
| (Type::Field, Type::Field)
| (Type::Group, Type::Group)
| (Type::Scalar, Type::Scalar)
| (Type::Signature, Type::Signature)
| (Type::String, Type::String)
| (Type::Unit, Type::Unit) => true,
(Type::Array(left), Type::Array(right)) => {
left.element_type().eq_flat(right.element_type()) && left.length() == right.length()
}
(Type::Identifier(left), Type::Identifier(right)) => left.name == right.name,
(Type::Integer(left), Type::Integer(right)) => left.eq(right),
(Type::Mapping(left), Type::Mapping(right)) => {
left.key.eq_flat(&right.key) && left.value.eq_flat(&right.value)
}
(Type::Tuple(left), Type::Tuple(right)) if left.length() == right.length() => left
.elements()
.iter()
.zip_eq(right.elements().iter())
.all(|(left_type, right_type)| left_type.eq_flat(right_type)),
(Type::Composite(left), Type::Composite(right)) => {
left.id.name == right.id.name && left.program == right.program
}
(Type::Future(left), Type::Future(right))
if left.inputs.len() == right.inputs.len() && left.location.is_some() && right.location.is_some() =>
{
left.location == right.location
&& left
.inputs()
.iter()
.zip_eq(right.inputs().iter())
.all(|(left_type, right_type)| left_type.eq_flat(right_type))
}
_ => false,
}
}

///
/// Returns `true` if the self `Type` is equal to the other `Type` in all aspects besides composite program of origin.
///
/// In the case of futures, it also makes sure that if both are not explicit, they are equal.
Expand Down Expand Up @@ -194,7 +149,7 @@ impl fmt::Display for Type {
Type::Scalar => write!(f, "scalar"),
Type::Signature => write!(f, "signature"),
Type::String => write!(f, "string"),
Type::Composite(ref struct_type) => write!(f, "{}", struct_type.id.name),
Type::Composite(ref struct_type) => write!(f, "{struct_type}"),
Type::Tuple(ref tuple) => write!(f, "{tuple}"),
Type::Unit => write!(f, "()"),
Type::Err => write!(f, "error"),
Expand Down
2 changes: 1 addition & 1 deletion compiler/parser/src/parser/type_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ impl<N: Network> ParserContext<'_, N> {
}
}

Ok((Type::Composite(CompositeType { id: ident, program: self.program_name }), ident.span))
Ok((Type::Composite(CompositeType { id: ident, program: None }), ident.span))
d0cd marked this conversation as resolved.
Show resolved Hide resolved
} else if self.token.token == Token::LeftSquare {
// Parse the left bracket.
self.expect(&Token::LeftSquare)?;
Expand Down
4 changes: 4 additions & 0 deletions compiler/passes/src/common/symbol_table/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ impl SymbolTable {
pub fn lookup_struct(&self, location: Location, main_program: Option<Symbol>) -> Option<&Composite> {
if let Some(struct_) = self.structs.get(&location) {
return Some(struct_);
} else if location.program.is_none() {
if let Some(struct_) = self.structs.get(&Location::new(main_program, location.name)) {
return Some(struct_);
}
} else if location.program == main_program {
if let Some(struct_) = self.structs.get(&Location::new(None, location.name)) {
return Some(struct_);
Expand Down
3 changes: 2 additions & 1 deletion compiler/passes/src/symbol_table_creation/creator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ impl<'a> ProgramVisitor<'a> for SymbolTableCreator<'a> {
if !input.is_record && !self.structs.insert(input.name()) {
return self.handler.emit_err::<LeoError>(AstError::shadowed_struct(input.name(), input.span).into());
}
if let Err(err) = self.symbol_table.insert_struct(Location::new(input.external, input.name()), input) {
let program_name = input.external.or(self.program_name);
if let Err(err) = self.symbol_table.insert_struct(Location::new(program_name, input.name()), input) {
self.handler.emit_err(err);
}
}
Expand Down
Loading
Loading