Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: resolve assert_stmt and config_entry correctly #924

Merged
merged 1 commit into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 81 additions & 47 deletions kclvm/sema/src/advanced_resolver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,47 +396,12 @@ mod tests {
],
),
(
"src/advanced_resolver/test_data/pkg/pkg.k"
"src/advanced_resolver/test_data/import_test/d.k"
Peefy marked this conversation as resolved.
Show resolved Hide resolved
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
vec![
(1, 7, 1, 11, "Name".to_string(), SymbolKind::Schema),
(2, 4, 2, 8, "name".to_string(), SymbolKind::Attribute),
(4, 7, 4, 13, "Person".to_string(), SymbolKind::Schema),
(5, 4, 5, 8, "name".to_string(), SymbolKind::Attribute),
(5, 10, 5, 14, "Name".to_string(), SymbolKind::Unresolved),
(
1,
7,
1,
11,
"src/advanced_resolver/test_data/pkg/pkg.k"
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
SymbolKind::Schema,
),
(5, 17, 5, 21, "Name".to_string(), SymbolKind::Unresolved),
(
1,
7,
1,
11,
"src/advanced_resolver/test_data/pkg/pkg.k"
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
SymbolKind::Schema,
),
(5, 23, 5, 27, "name".to_string(), SymbolKind::Unresolved),
(
2,
4,
2,
8,
"src/advanced_resolver/test_data/pkg/pkg.k"
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
SymbolKind::Attribute,
),
(1, 7, 1, 13, "Parent".to_string(), SymbolKind::Schema),
(2, 4, 2, 8, "age1".to_string(), SymbolKind::Attribute),
],
),
(
Expand Down Expand Up @@ -1068,6 +1033,40 @@ mod tests {
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
SymbolKind::Schema,
),
(34, 4, 34, 8, "name".to_string(), SymbolKind::Unresolved),
(
5,
4,
5,
8,
"src/advanced_resolver/test_data/pkg/pkg.k"
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
SymbolKind::Attribute,
),
(34, 9, 34, 13, "name".to_string(), SymbolKind::Unresolved),
(
2,
4,
2,
8,
"src/advanced_resolver/test_data/pkg/pkg.k"
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
SymbolKind::Attribute,
),
(37, 0, 37, 1, "x".to_string(), SymbolKind::Value),
(38, 16, 38, 17, "x".to_string(), SymbolKind::Unresolved),
(
37,
0,
37,
1,
"src/advanced_resolver/test_data/schema_symbols.k"
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
SymbolKind::Value,
),
],
),
(
Expand All @@ -1089,12 +1088,47 @@ mod tests {
],
),
(
"src/advanced_resolver/test_data/import_test/d.k"
"src/advanced_resolver/test_data/pkg/pkg.k"
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
vec![
(1, 7, 1, 13, "Parent".to_string(), SymbolKind::Schema),
(2, 4, 2, 8, "age1".to_string(), SymbolKind::Attribute),
(1, 7, 1, 11, "Name".to_string(), SymbolKind::Schema),
(2, 4, 2, 8, "name".to_string(), SymbolKind::Attribute),
(4, 7, 4, 13, "Person".to_string(), SymbolKind::Schema),
(5, 4, 5, 8, "name".to_string(), SymbolKind::Attribute),
(5, 10, 5, 14, "Name".to_string(), SymbolKind::Unresolved),
(
1,
7,
1,
11,
"src/advanced_resolver/test_data/pkg/pkg.k"
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
SymbolKind::Schema,
),
(5, 17, 5, 21, "Name".to_string(), SymbolKind::Unresolved),
(
1,
7,
1,
11,
"src/advanced_resolver/test_data/pkg/pkg.k"
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
SymbolKind::Schema,
),
(5, 23, 5, 27, "name".to_string(), SymbolKind::Unresolved),
(
2,
4,
2,
8,
"src/advanced_resolver/test_data/pkg/pkg.k"
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
SymbolKind::Attribute,
),
],
),
(
Expand Down Expand Up @@ -1202,7 +1236,7 @@ mod tests {
"src/advanced_resolver/test_data/schema_symbols.k"
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
34_u64,
35_u64,
5_u64,
Some((33, 13, 33, 19, "Person".to_string(), SymbolKind::Unresolved)),
),
Expand Down Expand Up @@ -1264,7 +1298,7 @@ mod tests {
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
17_u64,
26_u64,
9_usize,
10_usize,
),
// __main__.Main schema expr scope
(
Expand All @@ -1273,7 +1307,7 @@ mod tests {
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
30,
41,
9,
10,
),
// pkg.Person schema expr scope
(
Expand All @@ -1282,16 +1316,16 @@ mod tests {
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
33,
21,
5,
6,
),
// __main__ package scope
(
"src/advanced_resolver/test_data/schema_symbols.k"
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
34,
36,
31,
4,
5,
),
// import_test.a.Person expr scope
(
Expand Down
80 changes: 50 additions & 30 deletions kclvm/sema/src/advanced_resolver/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for AdvancedResolver<'ctx> {
if let Some(if_cond) = &assert_stmt.if_cond {
self.expr(if_cond);
}
if let Some(msg) = &assert_stmt.msg {
self.expr(msg);
}
None
}

Expand Down Expand Up @@ -762,38 +765,38 @@ impl<'ctx> AdvancedResolver<'ctx> {
.get_scopes_mut()
.add_ref_to_scope(cur_scope, first_unresolved_ref);
}
if names.len() > 1 {
let mut parent_ty = self.ctx.node_ty_map.get(&first_name.id)?;

let mut parent_ty = self.ctx.node_ty_map.get(&first_name.id)?;

for index in 1..names.len() {
let name = names.get(index).unwrap();
let def_symbol_ref = self.gs.get_symbols().get_type_attribute(
&parent_ty,
&name.node,
self.get_current_module_info(),
)?;

let (start_pos, end_pos): Range = name.get_span_pos();
let ast_id = name.id.clone();
let mut unresolved =
UnresolvedSymbol::new(name.node.clone(), start_pos, end_pos, None);
unresolved.def = Some(def_symbol_ref);
let unresolved_ref = self
.gs
.get_symbols_mut()
.alloc_unresolved_symbol(unresolved, &ast_id);

let cur_scope = *self.ctx.scopes.last().unwrap();
self.gs
.get_scopes_mut()
.add_ref_to_scope(cur_scope, unresolved_ref);
for index in 1..names.len() {
let name = names.get(index).unwrap();
let def_symbol_ref = self.gs.get_symbols().get_type_attribute(
&parent_ty,
&name.node,
self.get_current_module_info(),
)?;

parent_ty = self.ctx.node_ty_map.get(&name.id)?;
if index == names.len() - 1 {
return Some(unresolved_ref);
let (start_pos, end_pos): Range = name.get_span_pos();
let ast_id = name.id.clone();
let mut unresolved =
UnresolvedSymbol::new(name.node.clone(), start_pos, end_pos, None);
unresolved.def = Some(def_symbol_ref);
let unresolved_ref = self
.gs
.get_symbols_mut()
.alloc_unresolved_symbol(unresolved, &ast_id);

let cur_scope = *self.ctx.scopes.last().unwrap();
self.gs
.get_scopes_mut()
.add_ref_to_scope(cur_scope, unresolved_ref);

parent_ty = self.ctx.node_ty_map.get(&name.id)?;
if index == names.len() - 1 {
return Some(unresolved_ref);
}
}
}

Some(symbol_ref)
}
None => {
Expand Down Expand Up @@ -975,17 +978,34 @@ impl<'ctx> AdvancedResolver<'ctx> {
let cur_scope = self.ctx.scopes.last().unwrap();
self.gs
.get_scopes_mut()
.set_owner_to_scope(*cur_scope, owner)
.set_owner_to_scope(*cur_scope, owner);
}

for entry in entries.iter() {
if let Some(key) = &entry.node.key {
self.ctx.maybe_def = true;
self.expr(key);
if let Some(key_symbol_ref) = self.expr(key) {
self.set_config_scope_owner(key_symbol_ref);
}
self.ctx.maybe_def = false;
}
self.expr(&entry.node.value);
}
self.leave_scope()
}

pub(crate) fn set_config_scope_owner(&mut self, key_symbol_ref: SymbolRef) {
let symbols = self.gs.get_symbols();

if let Some(def_symbol_ref) = symbols.get_symbol(key_symbol_ref).unwrap().get_definition() {
if let Some(def_ast_id) = symbols.symbols_info.symbol_ref_map.get(&def_symbol_ref) {
if let Some(def_ty) = self.ctx.node_ty_map.get(def_ast_id) {
if def_ty.is_schema() {
self.ctx.current_schema_symbol =
self.gs.get_symbols().get_type_symbol(&def_ty, None);
}
}
}
}
}
}
7 changes: 6 additions & 1 deletion kclvm/sema/src/advanced_resolver/test_data/schema_symbols.k
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,9 @@ p = Main{
age = b._b + a._person?.age
}

person = pkg.Person {}
person = pkg.Person {
name.name = ""
}

x = "123"
assert True, "${x}456"
25 changes: 19 additions & 6 deletions kclvm/sema/src/core/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub struct SymbolDB {
pub(crate) fully_qualified_name_map: IndexMap<String, SymbolRef>,
pub(crate) schema_builtin_symbols: IndexMap<SymbolRef, IndexMap<String, SymbolRef>>,
pub(crate) ast_id_map: IndexMap<AstIndex, SymbolRef>,
pub(crate) symbol_ty_map: IndexMap<SymbolRef, Arc<Type>>,
pub(crate) symbol_ref_map: IndexMap<SymbolRef, AstIndex>,
}

impl KCLSymbolData {
Expand Down Expand Up @@ -319,11 +319,6 @@ impl KCLSymbolData {
}
}

pub fn add_symbol_info(&mut self, symbol: SymbolRef, ty: Arc<Type>, ast_id: AstIndex) {
self.symbols_info.ast_id_map.insert(ast_id, symbol);
self.symbols_info.symbol_ty_map.insert(symbol, ty);
}

pub fn get_symbol_by_ast_index(&self, id: &AstIndex) -> Option<SymbolRef> {
self.symbols_info.ast_id_map.get(id).cloned()
}
Expand Down Expand Up @@ -434,6 +429,9 @@ impl KCLSymbolData {
self.symbols_info
.ast_id_map
.insert(ast_id.clone(), symbol_ref);
self.symbols_info
.symbol_ref_map
.insert(symbol_ref, ast_id.clone());
self.schemas.get_mut(symbol_id).unwrap().id = Some(symbol_ref);
symbol_ref
}
Expand All @@ -451,6 +449,9 @@ impl KCLSymbolData {
self.symbols_info
.ast_id_map
.insert(ast_id.clone(), symbol_ref);
self.symbols_info
.symbol_ref_map
.insert(symbol_ref, ast_id.clone());
self.unresolved.get_mut(symbol_id).unwrap().id = Some(symbol_ref);
symbol_ref
}
Expand All @@ -468,6 +469,9 @@ impl KCLSymbolData {
self.symbols_info
.ast_id_map
.insert(ast_id.clone(), symbol_ref);
self.symbols_info
.symbol_ref_map
.insert(symbol_ref, ast_id.clone());
self.type_aliases.get_mut(symbol_id).unwrap().id = Some(symbol_ref);
symbol_ref
}
Expand All @@ -481,6 +485,9 @@ impl KCLSymbolData {
self.symbols_info
.ast_id_map
.insert(ast_id.clone(), symbol_ref);
self.symbols_info
.symbol_ref_map
.insert(symbol_ref, ast_id.clone());
self.rules.get_mut(symbol_id).unwrap().id = Some(symbol_ref);
symbol_ref
}
Expand All @@ -498,6 +505,9 @@ impl KCLSymbolData {
self.symbols_info
.ast_id_map
.insert(ast_id.clone(), symbol_ref);
self.symbols_info
.symbol_ref_map
.insert(symbol_ref, ast_id.clone());
self.attributes.get_mut(symbol_id).unwrap().id = Some(symbol_ref);
symbol_ref
}
Expand All @@ -511,6 +521,9 @@ impl KCLSymbolData {
self.symbols_info
.ast_id_map
.insert(ast_id.clone(), symbol_ref);
self.symbols_info
.symbol_ref_map
.insert(symbol_ref, ast_id.clone());
self.values.get_mut(symbol_id).unwrap().id = Some(symbol_ref);
symbol_ref
}
Expand Down
Loading
Loading