Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: local var set value in the internal scope #1444

Merged
merged 2 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 35 additions & 26 deletions kclvm/compiler/src/codegen/llvm/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ pub struct Scope<'ctx> {
pub schema_scalar_idx: RefCell<usize>,
/// Scope normal variables
pub variables: RefCell<IndexMap<String, PointerValue<'ctx>>>,
/// Scope normal initialized variables
pub uninitialized: RefCell<IndexSet<String>>,
/// Scope closures referenced by internal scope.
pub closures: RefCell<IndexMap<String, PointerValue<'ctx>>>,
/// Potential arguments in the current scope, such as schema/lambda arguments.
Expand Down Expand Up @@ -1717,40 +1719,42 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
level
}

/// Append a variable or update the existed local variable.
pub fn add_or_update_local_variable(&self, name: &str, value: BasicValueEnum<'ctx>) {
/// Append a variable or update the existed closure variable within the current scope.
pub fn add_or_update_local_variable_within_scope(
&self,
name: &str,
value: Option<BasicValueEnum<'ctx>>,
) {
let current_pkgpath = self.current_pkgpath();
let mut pkg_scopes = self.pkg_scopes.borrow_mut();
let msg = format!("pkgpath {} is not found", current_pkgpath);
let scopes = pkg_scopes.get_mut(&current_pkgpath).expect(&msg);
let mut existed = false;
// Query the variable in all scopes.
for i in 0..scopes.len() {
let index = scopes.len() - i - 1;
let variables_mut = scopes[index].variables.borrow_mut();
let index = scopes.len() - 1;
if let Some(scope) = scopes.last_mut() {
let mut variables_mut = scope.variables.borrow_mut();
let mut uninitialized = scope.uninitialized.borrow_mut();
if value.is_none() {
uninitialized.insert(name.to_string());
} else {
uninitialized.remove(name);
}
match variables_mut.get(&name.to_string()) {
// If the local variable is found, store the new value for the variable.
// We cannot update rule/lambda/schema arguments because they are read-only.
Some(ptr)
if index > GLOBAL_LEVEL
&& !self.local_vars.borrow().contains(name)
&& !scopes[index].arguments.borrow().contains(name) =>
{
self.builder.build_store(*ptr, value);
existed = true;
Some(ptr) if index > GLOBAL_LEVEL => {
if let Some(value) = value {
self.builder.build_store(*ptr, value);
}
}
_ => {}
}
}
// If not found, alloc a new variable.
if !existed {
let ptr = self.builder.build_alloca(self.value_ptr_type(), name);
self.builder.build_store(ptr, value);
// Store the value for the variable and add the variable into the current scope.
if let Some(last) = scopes.last_mut() {
let mut variables = last.variables.borrow_mut();
variables.insert(name.to_string(), ptr);
}
_ => {
let ptr = self.builder.build_alloca(self.value_ptr_type(), name);
if let Some(value) = value {
self.builder.build_store(ptr, value);
}
// Store the value for the variable and add the variable into the current scope.
variables_mut.insert(name.to_string(), ptr);
}
};
}
}

Expand Down Expand Up @@ -2001,6 +2005,8 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
for i in 0..scopes_len {
let index = scopes_len - i - 1;
let variables = scopes[index].variables.borrow();
// Skip uninitialized pointer value, which may cause NPE.
let uninitialized = scopes[index].uninitialized.borrow();
if let Some(var) = variables.get(&name.to_string()) {
// Closure vars, 2 denotes the builtin scope and the global scope, here is a closure scope.
let value = if i >= 1 && i < scopes_len - 2 {
Expand Down Expand Up @@ -2065,6 +2071,9 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
],
)
} else {
if uninitialized.contains(name) {
continue;
}
self.builder.build_load(*var, name)
}
};
Expand Down
46 changes: 46 additions & 0 deletions kclvm/compiler/src/codegen/llvm/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,52 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
}
}

pub(crate) fn emit_config_if_entry_expr_vars(
&self,
config_if_entry_expr: &'ctx ast::ConfigIfEntryExpr,
) {
self.emit_config_entries_vars(&config_if_entry_expr.items);
if let Some(orelse) = &config_if_entry_expr.orelse {
// Config expr or config if entry expr.
if let ast::Expr::Config(config_expr) = &orelse.node {
self.emit_config_entries_vars(&config_expr.items);
} else if let ast::Expr::ConfigIfEntry(config_if_entry_expr) = &orelse.node {
self.emit_config_if_entry_expr_vars(config_if_entry_expr);
}
}
}

pub(crate) fn emit_config_entries_vars(&self, items: &'ctx [ast::NodeRef<ast::ConfigEntry>]) {
for item in items {
if let ast::Expr::ConfigIfEntry(config_if_entry_expr) = &item.node.value.node {
self.emit_config_if_entry_expr_vars(config_if_entry_expr);
}
if let Some(key) = &item.node.key {
let optional_name = match &key.node {
ast::Expr::Identifier(identifier) => Some(identifier.names[0].node.clone()),
ast::Expr::StringLit(string_lit) => Some(string_lit.value.clone()),
ast::Expr::Subscript(subscript) => {
let mut name = None;
if let ast::Expr::Identifier(identifier) = &subscript.value.node {
if let Some(index_node) = &subscript.index {
if let ast::Expr::NumberLit(number) = &index_node.node {
if let ast::NumberLitValue::Int(_) = number.value {
name = Some(identifier.names[0].node.clone())
}
}
}
}
name
}
_ => None,
};
if let Some(name) = &optional_name {
self.add_or_update_local_variable_within_scope(name, None);
}
}
}
}

/// Compile AST Modules, which requires traversing three times.
/// 1. scan all possible global variables and allocate undefined values to global pointers.
/// 2. build all user-defined schema/rule types.
Expand Down
18 changes: 13 additions & 5 deletions kclvm/compiler/src/codegen/llvm/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1979,6 +1979,7 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
let else_block = self.append_block("");
let end_block = self.append_block("");
let is_truth = self.value_is_truthy(cond);
self.emit_config_if_entry_expr_vars(config_if_entry_expr);
let tpe = self.value_ptr_type();
self.cond_br(is_truth, then_block, else_block);
self.builder.position_at_end(then_block);
Expand All @@ -1992,7 +1993,13 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
self.br(end_block);
self.builder.position_at_end(else_block);
let else_value = if let Some(orelse) = &config_if_entry_expr.orelse {
self.walk_expr(orelse).expect(kcl_error::COMPILE_ERROR_MSG)
// Config expr or config if entry expr.
if let ast::Expr::Config(config_expr) = &orelse.node {
self.walk_config_entries(&config_expr.items)
.expect(kcl_error::COMPILE_ERROR_MSG)
} else {
self.walk_expr(orelse).expect(kcl_error::COMPILE_ERROR_MSG)
}
} else {
self.none_value()
};
Expand Down Expand Up @@ -2076,7 +2083,10 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {

fn walk_config_expr(&self, config_expr: &'ctx ast::ConfigExpr) -> Self::Result {
check_backtrack_stop!(self);
self.walk_config_entries(&config_expr.items)
self.enter_scope();
let result = self.walk_config_entries(&config_expr.items);
self.leave_scope();
result
}

fn walk_check_expr(&self, check_expr: &'ctx ast::CheckExpr) -> Self::Result {
Expand Down Expand Up @@ -2830,7 +2840,6 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
items: &'ctx [NodeRef<ConfigEntry>],
) -> CompileResult<'ctx> {
let config_value = self.dict_value();
self.enter_scope();
for item in items {
let value = self.walk_expr(&item.node.value)?;
if let Some(key) = &item.node.key {
Expand Down Expand Up @@ -2869,7 +2878,7 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
if let Some(name) = &optional_name {
let value =
self.dict_get(config_value, self.native_global_string(name, "").into());
self.add_or_update_local_variable(name, value);
self.add_or_update_local_variable_within_scope(name, Some(value));
}
} else {
// If the key does not exist, execute the logic of unpacking expression `**expr` here.
Expand All @@ -2879,7 +2888,6 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
);
}
}
self.leave_scope();
Ok(config_value)
}
}
16 changes: 11 additions & 5 deletions kclvm/evaluator/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,12 @@ impl<'ctx> TypedResultWalker<'ctx> for Evaluator<'ctx> {
Ok(if is_truth {
self.walk_config_entries(&config_if_entry_expr.items)?
} else if let Some(orelse) = &config_if_entry_expr.orelse {
self.walk_expr(orelse)?
// Config expr or config if entry expr.
if let ast::Expr::Config(config_expr) = &orelse.node {
self.walk_config_entries(&config_expr.items)?
} else {
self.walk_expr(orelse)?
}
} else {
self.none_value()
})
Expand Down Expand Up @@ -927,7 +932,10 @@ impl<'ctx> TypedResultWalker<'ctx> for Evaluator<'ctx> {

#[inline]
fn walk_config_expr(&self, config_expr: &'ctx ast::ConfigExpr) -> Self::Result {
self.walk_config_entries(&config_expr.items)
self.enter_scope();
let result = self.walk_config_entries(&config_expr.items);
self.leave_scope();
result
}

fn walk_check_expr(&self, check_expr: &'ctx ast::CheckExpr) -> Self::Result {
Expand Down Expand Up @@ -1509,7 +1517,6 @@ impl<'ctx> Evaluator<'ctx> {

pub(crate) fn walk_config_entries(&self, items: &'ctx [NodeRef<ConfigEntry>]) -> EvalResult {
let mut config_value = self.dict_value();
self.enter_scope();
for item in items {
let value = self.walk_expr(&item.node.value)?;
if let Some(key) = &item.node.key {
Expand Down Expand Up @@ -1547,14 +1554,13 @@ impl<'ctx> Evaluator<'ctx> {
);
if let Some(name) = &optional_name {
let value = self.dict_get_value(&config_value, name);
self.add_or_update_local_variable(name, value);
self.add_or_update_local_variable_within_scope(name, value);
}
} else {
// If the key does not exist, execute the logic of unpacking expression `**expr` here.
config_value.dict_insert_unpack(&mut self.runtime_ctx.borrow_mut(), &value)
}
}
self.leave_scope();
Ok(config_value)
}
}
31 changes: 7 additions & 24 deletions kclvm/evaluator/src/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,34 +301,17 @@ impl<'ctx> Evaluator<'ctx> {
level
}

/// Append a variable or update the existed local variable.
pub fn add_or_update_local_variable(&self, name: &str, value: ValueRef) {
/// Append a variable or update the existed local variable within the current scope.
pub(crate) fn add_or_update_local_variable_within_scope(&self, name: &str, value: ValueRef) {
let current_pkgpath = self.current_pkgpath();
let is_local_var = self.is_local_var(name);
let pkg_scopes = &mut self.pkg_scopes.borrow_mut();
let msg = format!("pkgpath {} is not found", current_pkgpath);
let scopes = pkg_scopes.get_mut(&current_pkgpath).expect(&msg);
let mut existed = false;
// Query the variable in all scopes.
for i in 0..scopes.len() {
let index = scopes.len() - i - 1;
let is_argument = scopes[index].arguments.contains(name);
let variables_mut = &mut scopes[index].variables;
match variables_mut.get(&name.to_string()) {
// If the local variable is found, store the new value for the variable.
// We cannot update rule/lambda/schema arguments because they are read-only.
Some(_) if index > GLOBAL_LEVEL && !is_local_var && !is_argument => {
variables_mut.insert(name.to_string(), value.clone());
existed = true;
}
_ => {}
}
}
// If not found, alloc a new variable.
if !existed {
// Store the value for the variable and add the variable into the current scope.
if let Some(last) = scopes.last_mut() {
last.variables.insert(name.to_string(), value);
let index = scopes.len() - 1;
if let Some(scope) = scopes.last_mut() {
let variables_mut = &mut scope.variables;
if index > GLOBAL_LEVEL {
variables_mut.insert(name.to_string(), value.clone());
}
}
}
Expand Down
24 changes: 24 additions & 0 deletions test/grammar/datatype/dict/mutual_ref_15/main.k
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
render = lambda {
a = {
foo: "bar"
}
b = {
foo2: "bar2"
a: {
b: "c"
}
}
c = [a, b]
}

out = render()
a = {
foo: "bar"
}
b = {
foo2: "bar2"
a: {
b: "c"
}
}
c = [a, b]
16 changes: 16 additions & 0 deletions test/grammar/datatype/dict/mutual_ref_15/stdout.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
out:
- foo: bar
- foo2: bar2
a:
b: c
a:
foo: bar
b:
foo2: bar2
a:
b: c
c:
- foo: bar
- foo2: bar2
a:
b: c
9 changes: 9 additions & 0 deletions test/grammar/datatype/dict/mutual_ref_16/main.k
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
level0 = {
name = "apple"
level1_a = {
name = "orange"
}
level1_b = {
name = "pine" + name
}
}
6 changes: 6 additions & 0 deletions test/grammar/datatype/dict/mutual_ref_16/stdout.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
level0:
name: apple
level1_a:
name: orange
level1_b:
name: pineapple
Loading