Skip to content

Commit

Permalink
refactor: simplify AST structure for some enum nodes (#991)
Browse files Browse the repository at this point in the history
* refactor: simplify ast structure for some enum nodes

Signed-off-by: peefy <[email protected]>

* fix: parser and ast crate test cases

Signed-off-by: peefy <[email protected]>

* fix: ast structure test cases in tools

Signed-off-by: peefy <[email protected]>

---------

Signed-off-by: peefy <[email protected]>
  • Loading branch information
Peefy authored Jan 17, 2024
1 parent 8c60381 commit 51c9514
Show file tree
Hide file tree
Showing 139 changed files with 8,948 additions and 9,616 deletions.
36 changes: 17 additions & 19 deletions kclvm/ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ impl Module {

/// A statement
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "type", content = "data")]
#[serde(tag = "type")]
pub enum Stmt {
TypeAlias(TypeAliasStmt),
Expr(ExprStmt),
Expand Down Expand Up @@ -643,7 +643,7 @@ pub struct SchemaIndexSignature {
pub struct SchemaAttr {
pub doc: String,
pub name: NodeRef<String>,
pub op: Option<BinOrAugOp>,
pub op: Option<AugOp>,
pub value: Option<NodeRef<Expr>>,
pub is_optional: bool,
pub decorators: Vec<NodeRef<CallExpr>>,
Expand All @@ -670,7 +670,7 @@ pub struct RuleStmt {

/// A expression
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "type", content = "data")]
#[serde(tag = "type")]
pub enum Expr {
Identifier(Identifier),
Unary(UnaryExpr),
Expand Down Expand Up @@ -791,7 +791,7 @@ pub struct UnaryExpr {
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct BinaryExpr {
pub left: NodeRef<Expr>,
pub op: BinOrCmpOp,
pub op: BinOp,
pub right: NodeRef<Expr>,
}

Expand Down Expand Up @@ -1136,7 +1136,7 @@ pub struct Compare {
/// """long string literal"""
/// ```
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "type", content = "data")]
#[serde(tag = "type")]
pub enum Literal {
Number(NumberLit),
String(StringLit),
Expand Down Expand Up @@ -1200,7 +1200,7 @@ impl NumberBinarySuffix {
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "type", content = "data")]
#[serde(tag = "type", content = "value")]
pub enum NumberLitValue {
Int(i64),
Float(f64),
Expand Down Expand Up @@ -1566,20 +1566,12 @@ impl CmpOp {

/// BinOrCmpOp is the set of all binary and comparison operators in KCL.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "type", content = "data")]
#[serde(tag = "type")]
pub enum BinOrCmpOp {
Bin(BinOp),
Cmp(CmpOp),
}

/// BinOrAugOp is the set of all binary and argument operators in KCL.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "type", content = "data")]
pub enum BinOrAugOp {
Bin(BinOp),
Aug(AugOp),
}

/// ExprContext represents the location information of the AST node.
/// The left side of the assignment symbol represents `Store`,
/// and the right side represents `Load`.
Expand All @@ -1591,7 +1583,7 @@ pub enum ExprContext {

/// A expression
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "type", content = "data")]
#[serde(tag = "type", content = "value")]
pub enum Type {
Any,
Named(Identifier),
Expand Down Expand Up @@ -1634,14 +1626,20 @@ pub struct UnionType {
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "type", content = "data")]
#[serde(tag = "type", content = "value")]
pub enum LiteralType {
Bool(bool),
Int(i64, Option<NumberBinarySuffix>), // value + suffix
Int(IntLiteralType), // value + suffix
Float(f64),
Str(String),
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct IntLiteralType {
pub value: i64,
pub suffix: Option<NumberBinarySuffix>,
}

impl ToString for Type {
fn to_string(&self) -> String {
fn to_str(typ: &Type, w: &mut String) {
Expand Down Expand Up @@ -1697,7 +1695,7 @@ impl ToString for Type {
w.push_str("False");
}
}
LiteralType::Int(v, suffix) => {
LiteralType::Int(IntLiteralType { value: v, suffix }) => {
if let Some(suffix) = suffix {
w.push_str(&format!("{}{}", v, suffix.value()));
} else {
Expand Down
2 changes: 1 addition & 1 deletion kclvm/ast/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ fn get_dummy_assign_binary_ast() -> ast::Node<ast::AssignStmt> {
))],
value: Box::new(ast::Node::new(
ast::Expr::Binary(ast::BinaryExpr {
op: ast::BinOrCmpOp::Bin(ast::BinOp::Add),
op: ast::BinOp::Add,
left: Box::new(ast::Node::new(
ast::Expr::Identifier(ast::Identifier {
names: vec![Node::dummy_node(String::from("a"))],
Expand Down
10 changes: 2 additions & 8 deletions kclvm/ast_pretty/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,10 +349,7 @@ impl<'p, 'ctx> MutSelfTypedResultWalker<'ctx> for Printer<'p> {
self.write(": ");
self.write(&schema_attr.ty.node.to_string());
if let Some(op) = &schema_attr.op {
let symbol = match op {
ast::BinOrAugOp::Bin(bin_op) => bin_op.symbol(),
ast::BinOrAugOp::Aug(aug_op) => aug_op.symbol(),
};
let symbol = op.symbol();
self.write_space();
self.write(symbol);
self.write_space();
Expand Down Expand Up @@ -382,10 +379,7 @@ impl<'p, 'ctx> MutSelfTypedResultWalker<'ctx> for Printer<'p> {
}

fn walk_binary_expr(&mut self, binary_expr: &'ctx ast::BinaryExpr) -> Self::Result {
let symbol = match &binary_expr.op {
ast::BinOrCmpOp::Bin(bin_op) => bin_op.symbol(),
ast::BinOrCmpOp::Cmp(cmp_op) => cmp_op.symbol(),
};
let symbol = binary_expr.op.symbol();
self.expr(&binary_expr.left);
self.write_space();
self.write(symbol);
Expand Down
68 changes: 20 additions & 48 deletions kclvm/compiler/src/codegen/llvm/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1522,7 +1522,7 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
if let Some(op) = &schema_attr.op {
match op {
// Union
ast::BinOrAugOp::Aug(ast::AugOp::BitOr) => {
ast::AugOp::BitOr => {
let org_value = self.build_call(
&ApiFunc::kclvm_dict_get_value.name(),
&[
Expand Down Expand Up @@ -1588,7 +1588,7 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
if let Some(op) = &schema_attr.op {
match op {
// Union
ast::BinOrAugOp::Aug(ast::AugOp::BitOr) => {
ast::AugOp::BitOr => {
let org_value = self.build_call(
&ApiFunc::kclvm_dict_get_value.name(),
&[
Expand Down Expand Up @@ -1668,11 +1668,8 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {

fn walk_binary_expr(&self, binary_expr: &'ctx ast::BinaryExpr) -> Self::Result {
check_backtrack_stop!(self);
let is_logic_op = matches!(
binary_expr.op,
ast::BinOrCmpOp::Bin(ast::BinOp::And) | ast::BinOrCmpOp::Bin(ast::BinOp::Or)
);
let is_membership_as_op = matches!(binary_expr.op, ast::BinOrCmpOp::Bin(ast::BinOp::As));
let is_logic_op = matches!(binary_expr.op, ast::BinOp::And | ast::BinOp::Or);
let is_membership_as_op = matches!(binary_expr.op, ast::BinOp::As);
if !is_logic_op {
let left_value = self
.walk_expr(&binary_expr.left)
Expand All @@ -1690,50 +1687,25 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
.expect(kcl_error::COMPILE_ERROR_MSG)
};
let value = match binary_expr.op {
ast::BinOrCmpOp::Bin(ast::BinOp::Add) => self.add(left_value, right_value),
ast::BinOrCmpOp::Bin(ast::BinOp::Sub) => self.sub(left_value, right_value),
ast::BinOrCmpOp::Bin(ast::BinOp::Mul) => self.mul(left_value, right_value),
ast::BinOrCmpOp::Bin(ast::BinOp::Div) => self.div(left_value, right_value),
ast::BinOrCmpOp::Bin(ast::BinOp::FloorDiv) => {
self.floor_div(left_value, right_value)
}
ast::BinOrCmpOp::Bin(ast::BinOp::Mod) => self.r#mod(left_value, right_value),
ast::BinOrCmpOp::Bin(ast::BinOp::Pow) => self.pow(left_value, right_value),
ast::BinOrCmpOp::Bin(ast::BinOp::LShift) => {
self.bit_lshift(left_value, right_value)
}
ast::BinOrCmpOp::Bin(ast::BinOp::RShift) => {
self.bit_rshift(left_value, right_value)
}
ast::BinOrCmpOp::Bin(ast::BinOp::BitAnd) => self.bit_and(left_value, right_value),
ast::BinOrCmpOp::Bin(ast::BinOp::BitOr) => self.bit_or(left_value, right_value),
ast::BinOrCmpOp::Bin(ast::BinOp::BitXor) => self.bit_xor(left_value, right_value),
ast::BinOrCmpOp::Bin(ast::BinOp::And) => self.logic_and(left_value, right_value),
ast::BinOrCmpOp::Bin(ast::BinOp::Or) => self.logic_or(left_value, right_value),
ast::BinOrCmpOp::Bin(ast::BinOp::As) => self.r#as(left_value, right_value),
ast::BinOrCmpOp::Cmp(ast::CmpOp::Eq) => self.cmp_equal_to(left_value, right_value),
ast::BinOrCmpOp::Cmp(ast::CmpOp::NotEq) => {
self.cmp_not_equal_to(left_value, right_value)
}
ast::BinOrCmpOp::Cmp(ast::CmpOp::Gt) => {
self.cmp_greater_than(left_value, right_value)
}
ast::BinOrCmpOp::Cmp(ast::CmpOp::GtE) => {
self.cmp_greater_than_or_equal(left_value, right_value)
}
ast::BinOrCmpOp::Cmp(ast::CmpOp::Lt) => self.cmp_less_than(left_value, right_value),
ast::BinOrCmpOp::Cmp(ast::CmpOp::LtE) => {
self.cmp_less_than_or_equal(left_value, right_value)
}
ast::BinOrCmpOp::Cmp(ast::CmpOp::Is) => self.is(left_value, right_value),
ast::BinOrCmpOp::Cmp(ast::CmpOp::IsNot) => self.is_not(left_value, right_value),
ast::BinOrCmpOp::Cmp(ast::CmpOp::Not) => self.is_not(left_value, right_value),
ast::BinOrCmpOp::Cmp(ast::CmpOp::NotIn) => self.not_in(left_value, right_value),
ast::BinOrCmpOp::Cmp(ast::CmpOp::In) => self.r#in(left_value, right_value),
ast::BinOp::Add => self.add(left_value, right_value),
ast::BinOp::Sub => self.sub(left_value, right_value),
ast::BinOp::Mul => self.mul(left_value, right_value),
ast::BinOp::Div => self.div(left_value, right_value),
ast::BinOp::FloorDiv => self.floor_div(left_value, right_value),
ast::BinOp::Mod => self.r#mod(left_value, right_value),
ast::BinOp::Pow => self.pow(left_value, right_value),
ast::BinOp::LShift => self.bit_lshift(left_value, right_value),
ast::BinOp::RShift => self.bit_rshift(left_value, right_value),
ast::BinOp::BitAnd => self.bit_and(left_value, right_value),
ast::BinOp::BitOr => self.bit_or(left_value, right_value),
ast::BinOp::BitXor => self.bit_xor(left_value, right_value),
ast::BinOp::And => self.logic_and(left_value, right_value),
ast::BinOp::Or => self.logic_or(left_value, right_value),
ast::BinOp::As => self.r#as(left_value, right_value),
};
Ok(value)
} else {
let jump_if_false = matches!(binary_expr.op, ast::BinOrCmpOp::Bin(ast::BinOp::And));
let jump_if_false = matches!(binary_expr.op, ast::BinOp::And);
let start_block = self.append_block("");
let value_block = self.append_block("");
let end_block = self.append_block("");
Expand Down
52 changes: 28 additions & 24 deletions kclvm/parser/src/parser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,33 +210,37 @@ impl<'a> Parser<'a> {

let y = self.do_parse_simple_expr(oprec);

// compare: a == b == c
if let BinOrCmpOp::Cmp(cmp_op) = op.clone() {
if cmp_expr.ops.is_empty() {
cmp_expr.left = x.clone();
match op {
// compare: a == b == c
BinOrCmpOp::Cmp(cmp_op) => {
if cmp_expr.ops.is_empty() {
cmp_expr.left = x.clone();
}
cmp_expr.ops.push(cmp_op);
cmp_expr.comparators.push(y);
continue;
}
cmp_expr.ops.push(cmp_op);
cmp_expr.comparators.push(y);
continue;
}
// binary a + b
BinOrCmpOp::Bin(bin_op) => {
if !cmp_expr.ops.is_empty() {
x = Box::new(Node::node(
Expr::Compare(cmp_expr.clone()),
self.sess.struct_token_loc(token, self.prev_token),
));
cmp_expr.ops = Vec::new();
cmp_expr.comparators = Vec::new();
}

if !cmp_expr.ops.is_empty() {
x = Box::new(Node::node(
Expr::Compare(cmp_expr.clone()),
self.sess.struct_token_loc(token, self.prev_token),
));
cmp_expr.ops = Vec::new();
cmp_expr.comparators = Vec::new();
x = Box::new(Node::node(
Expr::Binary(BinaryExpr {
left: x,
op: bin_op,
right: y,
}),
self.sess.struct_token_loc(token, self.prev_token),
));
}
}

x = Box::new(Node::node(
Expr::Binary(BinaryExpr {
left: x,
op,
right: y,
}),
self.sess.struct_token_loc(token, self.prev_token),
));
}
}

Expand Down
8 changes: 4 additions & 4 deletions kclvm/parser/src/parser/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ impl<'a> Parser<'a> {
doc: "".to_string(),
name: node_ref!(target.get_name(), targets[0].pos()),
ty: ty.unwrap(),
op: Some(BinOrAugOp::Aug(aug_op)),
op: Some(aug_op),
value: Some(value),
is_optional: false,
decorators: Vec::new(),
Expand Down Expand Up @@ -1058,7 +1058,7 @@ impl<'a> Parser<'a> {
assign.targets[0].pos()
),
ty: assign.ty.unwrap(),
op: Some(BinOrAugOp::Aug(AugOp::Assign)),
op: Some(AugOp::Assign),
value: Some(assign.value),
is_optional: false,
decorators: Vec::new(),
Expand Down Expand Up @@ -1215,10 +1215,10 @@ impl<'a> Parser<'a> {

let op = if self.token.kind == TokenKind::Assign {
self.bump_token(TokenKind::Assign);
Some(BinOrAugOp::Aug(AugOp::Assign))
Some(AugOp::Assign)
} else if let TokenKind::BinOpEq(x) = self.token.kind {
self.bump_token(self.token.kind);
Some(BinOrAugOp::Aug(x.into()))
Some(x.into())
} else {
None
};
Expand Down
10 changes: 8 additions & 2 deletions kclvm/parser/src/parser/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,15 @@ impl<'a> Parser<'a> {
let v = lit.symbol.as_str().parse::<i64>().unwrap();
if let Some(suffix) = lit.suffix {
let x = ast::NumberBinarySuffix::try_from(suffix.as_str().as_str());
ast::LiteralType::Int(v, Some(x.unwrap()))
ast::LiteralType::Int(ast::IntLiteralType {
value: v,
suffix: Some(x.unwrap()),
})
} else {
ast::LiteralType::Int(v, None)
ast::LiteralType::Int(ast::IntLiteralType {
value: v,
suffix: None,
})
}
}
token::LitKind::Float => {
Expand Down
2 changes: 1 addition & 1 deletion kclvm/parser/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ pub(crate) fn parsing_type_string(src: &str) -> String {
let stream = parse_token_streams(sess, src, new_byte_pos(0));
let mut parser = Parser::new(sess, stream);
let typ = parser.parse_type_annotation();
format!("{typ:?}\n")
format!("{typ:#?}\n")
})
}

Expand Down
Loading

0 comments on commit 51c9514

Please sign in to comment.