diff --git a/examples/language-features/option.qnt b/examples/language-features/option.qnt index cbecf8b6a..1c0555aaf 100644 --- a/examples/language-features/option.qnt +++ b/examples/language-features/option.qnt @@ -1,8 +1,7 @@ module option { - // a demonstration of discriminated unions, specifying an option type. + // A demonstration of sum types, specifying an option type. // An option type for values. - // This type declaration is not required. It only defines an alias. type Vote_option = | None | Some(int) diff --git a/quint/src/cliCommands.ts b/quint/src/cliCommands.ts index a069d7292..e7b1eb2aa 100644 --- a/quint/src/cliCommands.ts +++ b/quint/src/cliCommands.ts @@ -622,7 +622,7 @@ export async function verifySpec(prev: TypecheckedStage): Promise analyzeInc(verifying, verifying.table, def)) + analyzeInc(verifying, verifying.table, extraDefs) // Flatten modules, replacing instances, imports and exports with their definitions const { flattenedModules, flattenedTable, flattenedAnalysis } = flattenModules( diff --git a/quint/src/graphics.ts b/quint/src/graphics.ts index 0de449d75..52804c81b 100644 --- a/quint/src/graphics.ts +++ b/quint/src/graphics.ts @@ -101,6 +101,20 @@ export function prettyQuintEx(ex: QuintEx): Doc { return nary(text('{'), kvs, text('}'), line()) } + case 'variant': { + const labelExpr = ex.args[0] + assert(labelExpr.kind === 'str', 'malformed variant operator') + const label = richtext(chalk.green, labelExpr.value) + + const valueExpr = ex.args[1] + const value = + valueExpr.kind === 'app' && valueExpr.opcode === 'Rec' && valueExpr.args.length === 0 + ? [] // A payload with the empty record is shown as a bare label + : [text('('), prettyQuintEx(valueExpr), text(')')] + + return group([label, ...value]) + } + default: // instead of throwing, show it in red return richtext(chalk.red, `unsupported operator: ${ex.opcode}(...)`) diff --git a/quint/src/parsing/quintParserFrontend.ts b/quint/src/parsing/quintParserFrontend.ts index 8d4729242..3b210fef5 100644 --- a/quint/src/parsing/quintParserFrontend.ts +++ b/quint/src/parsing/quintParserFrontend.ts @@ -67,10 +67,10 @@ export interface ParserPhase3 extends ParserPhase2 { export interface ParserPhase4 extends ParserPhase3 {} /** - * The result of parsing an expression or unit. + * The result of parsing an expression or collection of declarations. */ export type ExpressionOrDeclarationParseResult = - | { kind: 'declaration'; decl: QuintDeclaration } + | { kind: 'declaration'; decls: QuintDeclaration[] } | { kind: 'expr'; expr: QuintEx } | { kind: 'none' } | { kind: 'error'; errors: QuintError[] } @@ -327,8 +327,8 @@ export function parse( export function parseDefOrThrow(text: string, idGen?: IdGenerator, sourceMap?: SourceMap): QuintDef { const result = parseExpressionOrDeclaration(text, '', idGen ?? newIdGenerator(), sourceMap ?? new Map()) - if (result.kind === 'declaration' && isDef(result.decl)) { - return result.decl + if (result.kind === 'declaration' && isDef(result.decls[0])) { + return result.decls[0] } else { const msg = result.kind === 'error' ? result.errors.join('\n') : `Expected a definition, got ${result.kind}` throw new Error(`${msg}, parsing ${text}`) @@ -383,8 +383,9 @@ class ExpressionOrDeclarationListener extends ToIrListener { exitDeclarationOrExpr(ctx: p.DeclarationOrExprContext) { if (ctx.declaration()) { - const decl = this.declarationStack[this.declarationStack.length - 1] - this.result = { kind: 'declaration', decl } + const prevDecls = this.result?.kind === 'declaration' ? this.result.decls : [] + const decls = this.declarationStack + this.result = { kind: 'declaration', decls: [...prevDecls, ...decls] } } else if (ctx.expr()) { const expr = this.exprStack[this.exprStack.length - 1] this.result = { kind: 'expr', expr } diff --git a/quint/src/quintAnalyzer.ts b/quint/src/quintAnalyzer.ts index a5ad5508f..5d64b36ef 100644 --- a/quint/src/quintAnalyzer.ts +++ b/quint/src/quintAnalyzer.ts @@ -57,10 +57,10 @@ export function analyzeModules(lookupTable: LookupTable, quintModules: QuintModu export function analyzeInc( analysisOutput: AnalysisOutput, lookupTable: LookupTable, - declaration: QuintDeclaration + declarations: QuintDeclaration[] ): AnalysisResult { const analyzer = new QuintAnalyzer(lookupTable, analysisOutput) - analyzer.analyzeDeclaration(declaration) + analyzer.analyzeDeclarations(declarations) return analyzer.getResult() } @@ -94,11 +94,7 @@ class QuintAnalyzer { this.analyzeDeclarations(module.declarations) } - analyzeDeclaration(decl: QuintDeclaration): void { - this.analyzeDeclarations([decl]) - } - - private analyzeDeclarations(decls: QuintDeclaration[]): void { + analyzeDeclarations(decls: QuintDeclaration[]): void { const [typeErrMap, types] = this.typeInferrer.inferTypes(decls) const [effectErrMap, effects] = this.effectInferrer.inferEffects(decls) const updatesErrMap = this.multipleUpdatesChecker.checkEffects([...effects.values()]) diff --git a/quint/src/repl.ts b/quint/src/repl.ts index b9b0e958f..c2bc6c355 100644 --- a/quint/src/repl.ts +++ b/quint/src/repl.ts @@ -20,7 +20,7 @@ import { FlatModule, QuintDef, QuintEx } from './ir/quintIr' import { CompilationContext, CompilationState, - compileDecl, + compileDecls, compileExpr, compileFromCode, contextNameLookup, @@ -582,7 +582,7 @@ function tryEval(out: writer, state: ReplState, newInput: string): boolean { } if (parseResult.kind === 'declaration') { // compile the module and add it to history if everything worked - const context = compileDecl(state.compilationState, state.evaluationState, state.rng, parseResult.decl) + const context = compileDecls(state.compilationState, state.evaluationState, state.rng, parseResult.decls) if ( context.evaluationState.context.size === 0 || diff --git a/quint/src/runtime/compile.ts b/quint/src/runtime/compile.ts index a0b54cd0f..fbec5aeca 100644 --- a/quint/src/runtime/compile.ts +++ b/quint/src/runtime/compile.ts @@ -184,7 +184,7 @@ export function compileExpr( // Hence, we have to compile it via an auxilliary definition. const def: QuintDef = { kind: 'def', qualifier: 'action', name: inputDefName, expr, id: state.idGen.nextId() } - return compileDecl(state, evaluationState, rng, def) + return compileDecls(state, evaluationState, rng, [def]) } /** @@ -195,15 +195,15 @@ export function compileExpr( * @param state - The current compilation state * @param evaluationState - The current evaluation state * @param rng - The random number generator - * @param decl - The Quint declaration to be compiled + * @param decls - The Quint declarations to be compiled * * @returns A compilation context with the compiled definition or its errors */ -export function compileDecl( +export function compileDecls( state: CompilationState, evaluationState: EvaluationState, rng: Rng, - decl: QuintDeclaration + decls: QuintDeclaration[] ): CompilationContext { if (state.originalModules.length === 0 || state.modules.length === 0) { throw new Error('No modules in state') @@ -213,7 +213,7 @@ export function compileDecl( // ensuring the original object is not modified const originalModules = state.originalModules.map(m => { if (m.name === state.mainName) { - return { ...m, declarations: [...m.declarations, decl] } + return { ...m, declarations: [...m.declarations, ...decls] } } return m }) @@ -233,7 +233,7 @@ export function compileDecl( return errorContextFromMessage(evaluationState.listener)({ errors, sourceMap: state.sourceMap }) } - const [analysisErrors, analysisOutput] = analyzeInc(state.analysisOutput, table, decl) + const [analysisErrors, analysisOutput] = analyzeInc(state.analysisOutput, table, decls) const { flattenedModules: flatModules, diff --git a/quint/src/runtime/impl/compilerImpl.ts b/quint/src/runtime/impl/compilerImpl.ts index 00c9fd051..91f56b03c 100644 --- a/quint/src/runtime/impl/compilerImpl.ts +++ b/quint/src/runtime/impl/compilerImpl.ts @@ -32,11 +32,12 @@ import { ExecutionListener } from '../trace' import * as ir from '../../ir/quintIr' -import { RuntimeValue, rv } from './runtimeValue' +import { RuntimeValue, RuntimeValueLambda, RuntimeValueVariant, rv } from './runtimeValue' import { ErrorCode, QuintError } from '../../quintError' import { inputDefName, lastTraceName } from '../compile' import { unreachable } from '../../util' +import { chunk } from 'lodash' // Internal names in the compiler, which have special treatment. // For some reason, if we replace 'q::input' with inputDefName, everything breaks. @@ -696,6 +697,41 @@ export class CompilerVisitor implements IRVisitor { }) break + case 'variant': + // Construct a variant of a sum type. + this.applyFun(app.id, 2, (labelName, value) => just(rv.mkVariant(labelName.toStr(), value))) + break + + case 'matchVariant': + this.applyFun(app.id, app.args.length, (variantExpr, ...cases) => { + // Type checking ensures that this is a variant expression + assert(variantExpr instanceof RuntimeValueVariant, 'invalid value in match expression') + const label = variantExpr.label + const value = variantExpr.value + + // Find the eliminator marked with the variant's label + let result: Maybe | undefined + for (const [caseLabel, caseElim] of chunk(cases, 2)) { + const caseLabelStr = caseLabel.toStr() + if (caseLabelStr === '_') { + // The wilcard case ignores the value. + // NOTE: This SHOULD be a nullary lambda, but by this point the compiler + // has already converted it into a value. Confusing! + result = just(caseElim as RuntimeValueLambda) + } else if (caseLabelStr === label) { + // Type checking ensures the second item of each case is a lambda + const eliminator = caseElim as RuntimeValueLambda + result = eliminator.eval([just(value)]).map(r => r as RuntimeValue) + break + } + } + // Type checking ensures we have cases for every possible variant of a sum type. + assert(result, 'non-exhaustive match expression') + + return result + }) + break + case 'Set': // Construct a set from an array of values. this.applyFun(app.id, app.args.length, (...values: RuntimeValue[]) => just(rv.mkSet(values))) @@ -930,8 +966,6 @@ export class CompilerVisitor implements IRVisitor { break // builtin operators that are not handled by REPL - case 'variant': // TODO: https://github.com/informalsystems/quint/issues/1033 - case 'matchVariant': // TODO: https://github.com/informalsystems/quint/issues/1033 case 'orKeep': case 'mustChange': case 'weakFair': diff --git a/quint/src/runtime/impl/runtimeValue.ts b/quint/src/runtime/impl/runtimeValue.ts index 7c69bf4a6..a7ecfde25 100644 --- a/quint/src/runtime/impl/runtimeValue.ts +++ b/quint/src/runtime/impl/runtimeValue.ts @@ -137,6 +137,16 @@ export const rv = { return new RuntimeValueRecord(OrderedMap(elems).sortBy((_v, k) => k)) }, + /** + * Make a runtime value that represents a variant value of a sum type. + * + * @param label a string reperenting the variant's label + * @param value the value held by the variant + * @return a new runtime value that represents the variant + */ + mkVariant: (label: string, value: RuntimeValue): RuntimeValue => { + return new RuntimeValueVariant(label, value) + }, /** * Make a runtime value that represents a map. * @@ -582,6 +592,10 @@ abstract class RuntimeValueBase implements RuntimeValue { if (this instanceof RuntimeValueRecord && other instanceof RuntimeValueRecord) { return this.map.equals(other.map) } + if (this instanceof RuntimeValueVariant && other instanceof RuntimeValueVariant) { + return this.label === other.label && this.value.equals(other.value) + } + if (this instanceof RuntimeValueSet && other instanceof RuntimeValueSet) { return immutableIs(this.set, other.set) } @@ -811,6 +825,29 @@ class RuntimeValueRecord extends RuntimeValueBase implements RuntimeValue { } } +export class RuntimeValueVariant extends RuntimeValueBase implements RuntimeValue { + label: string + value: RuntimeValue + + constructor(label: string, value: RuntimeValue) { + super(false) // Not a "set-like" value + this.label = label + this.value = value + } + + hashCode() { + return hash(this.value) + this.value.hashCode() + } + + toQuintEx(gen: IdGenerator): QuintEx { + return { + id: gen.nextId(), + kind: 'app', + opcode: 'variant', + args: [{ id: gen.nextId(), kind: 'str', value: this.label }, this.value.toQuintEx(gen)], + } + } +} /** * A set of runtime values represented via an immutable Map. * This is an internal class. @@ -1490,7 +1527,7 @@ class RuntimeValueInfSet extends RuntimeValueBase implements RuntimeValue { * * RuntimeValueLambda cannot be compared with other values. */ -class RuntimeValueLambda extends RuntimeValueBase implements RuntimeValue, Callable { +export class RuntimeValueLambda extends RuntimeValueBase implements RuntimeValue, Callable { nparams: number callable: Callable diff --git a/quint/test/runtime/compile.test.ts b/quint/test/runtime/compile.test.ts index 4723c7579..28d9a2ebd 100644 --- a/quint/test/runtime/compile.test.ts +++ b/quint/test/runtime/compile.test.ts @@ -9,7 +9,7 @@ import { CompilationContext, CompilationState, compile, - compileDecl, + compileDecls, compileExpr, compileFromCode, contextNameLookup, @@ -29,8 +29,12 @@ const idGen = newIdGenerator() // Compile an expression, evaluate it, convert to QuintEx, then to a string, // compare the result. This is the easiest path to test the results. -function assertResultAsString(input: string, expected: string | undefined) { - const moduleText = `module __runtime { val ${inputDefName} = ${input} }` +// +// @param evalContext optional textual representation of context that may hold definitions which +// `input` depends on. This content will be wrapped in a module and imported unqualified +// before the input is evaluated. If not supplied, the context is empty. +function assertResultAsString(input: string, expected: string | undefined, evalContext: string = '') { + const moduleText = `module contextM { ${evalContext} } module __runtime { import contextM.*\n val ${inputDefName} = ${input} }` const mockLookupPath = stringSourceResolver(new Map()).lookupPath('/', './mock') const context = compileFromCode(idGen, moduleText, '__runtime', mockLookupPath, noExecutionListener, newRng().next) @@ -837,6 +841,26 @@ describe('compiling specs to runtime values', () => { }) }) + describe('compile over sum types', () => { + it('can compile construction of sum type variants', () => { + const context = 'type T = Some(int) | None' + assertResultAsString('Some(40 + 2)', 'variant("Some", 42)', context) + assertResultAsString('None', 'variant("None", Rec())', context) + }) + + it('can compile elimination of sum type variants via match', () => { + const context = 'type T = Some(int) | None' + assertResultAsString('match Some(40 + 2) { Some(x) => x | None => 0 }', '42', context) + assertResultAsString('match None { Some(x) => x | None => 0 }', '0', context) + }) + + it('can compile elimination of sum type variants via match using default', () => { + const context = 'type T = Some(int) | None' + // We can hit the fallback case + assertResultAsString('match None { Some(x) => x | _ => 3 }', '3', context) + }) + }) + describe('compile over maps', () => { it('mapBy constructor', () => { assertResultAsString('3.to(5).mapBy(i => 2 * i)', 'Map(Tup(3, 6), Tup(4, 8), Tup(5, 10))') @@ -1051,12 +1075,12 @@ describe('incremental compilation', () => { compilationState.idGen, compilationState.sourceMap ) - const def = parsed.kind === 'declaration' ? parsed.decl : undefined - const context = compileDecl(compilationState, evaluationState, dummyRng, def!) + const defs = parsed.kind === 'declaration' ? parsed.decls : undefined + const context = compileDecls(compilationState, evaluationState, dummyRng, defs!) - assert.deepEqual(context.compilationState.analysisOutput.types.get(def!.id)?.type, { kind: 'int', id: 3n }) + assert.deepEqual(context.compilationState.analysisOutput.types.get(defs![0].id)?.type, { kind: 'int', id: 3n }) - const computable = context.evaluationState?.context.get(kindName('callable', def!.id))! + const computable = context.evaluationState?.context.get(kindName('callable', defs![0].id))! assertComputableAsString(computable, '3') }) @@ -1072,8 +1096,8 @@ describe('incremental compilation', () => { compilationState.idGen, compilationState.sourceMap ) - const decl = parsed.kind === 'declaration' ? parsed.decl : undefined - const context = compileDecl(compilationState, evaluationState, dummyRng, decl!) + const decls = parsed.kind === 'declaration' ? parsed.decls : [] + const context = compileDecls(compilationState, evaluationState, dummyRng, decls) assert.sameDeepMembers(context.syntaxErrors, [ { @@ -1084,5 +1108,45 @@ describe('incremental compilation', () => { }, ]) }) + + it('can complile type alias declarations', () => { + const { compilationState, evaluationState } = compileModules('module m {}', 'm') + const parsed = parseExpressionOrDeclaration( + 'type T = int', + 'test.qnt', + compilationState.idGen, + compilationState.sourceMap + ) + const decls = parsed.kind === 'declaration' ? parsed.decls : [] + const context = compileDecls(compilationState, evaluationState, dummyRng, decls) + + const typeDecl = decls[0] + assert(typeDecl.kind === 'typedef') + assert(typeDecl.name === 'T') + assert(typeDecl.type!.kind === 'int') + + assert.sameDeepMembers(context.syntaxErrors, []) + }) + + it('can compile sum type declarations', () => { + const { compilationState, evaluationState } = compileModules('module m {}', 'm') + const parsed = parseExpressionOrDeclaration( + 'type T = A(int) | B(str) | C', + 'test.qnt', + compilationState.idGen, + compilationState.sourceMap + ) + const decls = parsed.kind === 'declaration' ? parsed.decls : [] + const context = compileDecls(compilationState, evaluationState, dummyRng, decls) + + assert(decls.find(t => t.kind === 'typedef' && t.name === 'T')) + // Sum type declarations are expanded to add an + // operator declaration for each constructor: + assert(decls.find(t => t.kind === 'def' && t.name === 'A')) + assert(decls.find(t => t.kind === 'def' && t.name === 'B')) + assert(decls.find(t => t.kind === 'def' && t.name === 'C')) + + assert.sameDeepMembers(context.syntaxErrors, []) + }) }) })