Skip to content

Commit

Permalink
compiler: Support record spread
Browse files Browse the repository at this point in the history
Thsi feels even messier than in the interpreter but I am not really sure
how else to do this. Maybe with some better static analysis that
computes record layouts at compile-time.
  • Loading branch information
tekknolagi committed Jun 5, 2024
1 parent 3f5223a commit 8328c54
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 3 deletions.
12 changes: 11 additions & 1 deletion compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,19 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En
if isinstance(pattern, Record):
self._emit(f"if (!is_record({arg})) {{ goto {fallthrough}; }}")
updates = {}
seen_key_indices: list[int] = []
for key, pattern_value in pattern.data.items():
assert not isinstance(pattern_value, Spread), "record spread not yet supported"
if isinstance(pattern_value, Spread):
use_spread = True
if pattern_value.name:
num_seen_keys = len(seen_key_indices)
self._emit(
f"size_t seen_keys[{num_seen_keys}] = {{ {', '.join(map(str, seen_key_indices))} }};"
)
updates[pattern_value.name] = self._mktemp(f"record_rest({arg}, seen_keys, {num_seen_keys})")
break
key_idx = self.record_key(key)
seen_key_indices.append(key_idx)
record_value = self._mktemp(f"record_get({arg}, {key_idx})")
self._emit(f"if ({record_value} == NULL) {{ goto {fallthrough}; }}")
updates.update(self.try_match(env, record_value, pattern_value, fallthrough))
Expand Down
9 changes: 7 additions & 2 deletions compiler_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,17 @@ def test_match_list(self) -> None:
def test_match_list_spread(self) -> None:
self.assertEqual(self._run("f [4, 5] . f = | [_, ...xs] -> xs"), "[5]\n")

def test_match_list_spread_empty(self) -> None:
self.assertEqual(self._run("f [4] . f = | [_, ...xs] -> xs"), "[]\n")

def test_match_record(self) -> None:
self.assertEqual(self._run("f {a = 4, b = 5} . f = | {a = 1, b = 2} -> 3 | {a = 4, b = 5} -> 6"), "6\n")

@unittest.skip("TODO")
def test_match_record_spread(self) -> None:
self.assertEqual(self._run("f {a=1, b=2, c=3} . f = | {a=1, ...rest} -> rest"), "[5]\n")
self.assertEqual(self._run("f {a=1, b=2, c=3} . f = | {a=1, ...rest} -> rest"), "{b = 2, c = 3}\n")

def test_match_record_spread_empty(self) -> None:
self.assertEqual(self._run("f {a=1} . f = | {a=1, ...rest} -> rest"), "{}\n")

def test_match_hole(self) -> None:
self.assertEqual(self._run("f () . f = | 1 -> 3 | () -> 4"), "4\n")
Expand Down
32 changes: 32 additions & 0 deletions runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,10 @@ struct object* record_get(struct object* record, size_t key) {
return NULL;
}

size_t record_num_fields(struct object* record) {
return as_record(record)->size;
}

bool is_string(struct object* obj) {
if (is_small_string(obj)) {
return true;
Expand Down Expand Up @@ -603,6 +607,34 @@ struct object* list_cons(struct object* item, struct object* list) {
return result;
}

bool array_contains(size_t* haystack, size_t size, size_t needle) {
for (size_t i = 0; i < size; i++) {
if (haystack[i] == needle) {
return true;
}
}
return false;
}

struct object* record_rest(struct object* record, size_t* exclude,
size_t num_excluded) {
// NB: This is used in a match expression so it is assumed that all of the
// key indices in the exclude array are present in the record and that there
// are no duplicates in either the record or the exclude array.
HANDLES();
GC_PROTECT(record);
size_t num_keys = record_num_fields(record);
size_t num_result_keys = num_keys - num_excluded;
struct object* result = mkrecord(heap, num_result_keys);
for (size_t src = 0, dst = 0; dst < num_result_keys; src++) {
struct record_field field = as_record(record)->fields[src];
if (!array_contains(exclude, num_excluded, field.key)) {
record_set(result, dst++, field);
}
}
return result;
}

struct object* heap_string_concat(struct object* a, struct object* b) {
uword a_size = string_length(a);
uword b_size = string_length(b);
Expand Down

0 comments on commit 8328c54

Please sign in to comment.