diff --git a/riscemu/core/registers.py b/riscemu/core/registers.py index 106f7bc..c0ebb42 100644 --- a/riscemu/core/registers.py +++ b/riscemu/core/registers.py @@ -5,11 +5,11 @@ """ from collections import defaultdict -from typing import Union, Type +from typing import Type from ..helpers import * -from . import Int32, BaseFloat, Float32, Float64 +from . import Int32, BaseFloat class Registers: @@ -222,14 +222,12 @@ def get(self, reg: str, mark_read: bool = True) -> Int32: self.last_read = reg return self.vals[reg] - def get_f(self, reg: str, mark_read: bool = True) -> BaseFloat: + def get_f(self, reg: str) -> BaseFloat: if not self.infinite_regs and reg not in self.float_regs: raise RuntimeError("Invalid float register: {}".format(reg)) - if mark_read: - self.last_read = reg return self.float_vals[reg] - def set_f(self, reg: str, val: Union[float, BaseFloat]): + def set_f(self, reg: str, val: BaseFloat): if not self.infinite_regs and reg not in self.float_regs: raise RuntimeError("Invalid float register: {}".format(reg)) self.float_vals[reg] = self._float_type.bitcast(val) diff --git a/riscemu/core/usermode_cpu.py b/riscemu/core/usermode_cpu.py index 8f1c83f..b792d9e 100644 --- a/riscemu/core/usermode_cpu.py +++ b/riscemu/core/usermode_cpu.py @@ -64,9 +64,7 @@ def step(self, verbose: bool = False): self.cycle += 1 ins = self.mmu.read_ins(self.pc) if verbose: - print( - FMT_CPU + " Running 0x{:08X}:{} {}".format(self.pc, FMT_NONE, ins) - ) + print(FMT_CPU + " 0x{:08X}:{} {}".format(self.pc, FMT_NONE, ins)) self.pc += self.INS_XLEN self.run_instruction(ins) except RiscemuBaseException as ex: diff --git a/snitch/__main__.py b/snitch/__main__.py index b0f28d5..4324b88 100644 --- a/snitch/__main__.py +++ b/snitch/__main__.py @@ -7,19 +7,19 @@ """ import sys -from .regs import StreamingRegs -from .xssr import RV32_Xssr_pseudo +from .xssr import Xssr_pseudo +from .frep import FrepEnabledCpu, Xfrep from riscemu.riscemu_main import RiscemuMain class 
SnitchMain(RiscemuMain): - def configure_cpu(self): - super().configure_cpu() - self.cpu.regs = StreamingRegs(self.cpu.mmu) + def instantiate_cpu(self): + self.cpu = FrepEnabledCpu(self.selected_ins_sets, self.cfg) + self.configure_cpu() def register_all_isas(self): super().register_all_isas() - self.available_ins_sets.update({"Xssr": RV32_Xssr_pseudo}) + self.available_ins_sets.update({"Xssr": Xssr_pseudo, "Xfrep": Xfrep}) if __name__ == "__main__": diff --git a/snitch/frep.py b/snitch/frep.py new file mode 100644 index 0000000..c8167b3 --- /dev/null +++ b/snitch/frep.py @@ -0,0 +1,103 @@ +from typing import List, Type, Union, Set, Literal + +from riscemu.colors import FMT_CPU, FMT_NONE +from riscemu.config import RunConfig +from riscemu.core import UserModeCPU +from riscemu.instructions import InstructionSet, Instruction, RV32F, RV32D + +from dataclasses import dataclass + +from snitch.regs import StreamingRegs + + +@dataclass(frozen=True) +class FrepState: + rep_count: int + ins_count: int + mode: Literal["inner", "outer"] + + +class FrepEnabledCpu(UserModeCPU): + repeats: Union[FrepState, None] + allowed_ins: Set[str] + + def __init__(self, instruction_sets: List[Type["InstructionSet"]], conf: RunConfig): + super().__init__(instruction_sets, conf) + self.regs = StreamingRegs( + mem=self.mmu, infinite_regs=conf.unlimited_registers, flen=conf.flen + ) + self.repeats = None + # only floating point instructions are allowed inside an frep!
+ self.allowed_ins = set(x for x, y in RV32F(self).get_instructions()) + if conf.flen > 32: + self.allowed_ins.update(x for x, y in RV32D(self).get_instructions()) + + def step(self, verbose: bool = False): + if self.repeats is None: + super().step(verbose=verbose) + return + # get the spec + spec: FrepState = self.repeats + self.repeats = None + + instructions = [ + self.mmu.read_ins(self.pc + i * self.INS_XLEN) + for i in range(spec.ins_count) + ] + + for ins in instructions: + if ins.name not in self.allowed_ins: + # TODO: wrap in a nicer error type + raise RuntimeError( + "Forbidden instruction inside frep loop: {}".format(ins) + ) + + if verbose: + print( + FMT_CPU + + "┌────── floating point repetition ({}) {} times".format( + spec.mode, spec.rep_count + 1 + ) + ) + for i, ins in enumerate(instructions): + print( + FMT_CPU + + "│ 0x{:08X}:{} {}".format( + self.pc + i * self.INS_XLEN, FMT_NONE, ins + ) + ) + print(FMT_CPU + "└────── end of floating point repetition" + FMT_NONE) + + pc = self.pc + if spec.mode == "outer": + for _ in range(spec.rep_count + 1): + for ins in instructions: + self.run_instruction(ins) + elif spec.mode == "inner": + for ins in instructions: + for _ in range(spec.rep_count + 1): + self.run_instruction(ins) + else: + raise RuntimeError(f"Unknown frep mode: {spec.mode}") + self.cycle += (spec.rep_count + 1) * spec.ins_count + self.pc = pc + (spec.ins_count * self.INS_XLEN) + + +class Xfrep(InstructionSet): + def instruction_frep_o(self, ins: Instruction): + self.frep(ins, "outer") + + def instruction_frep_i(self, ins: Instruction): + self.frep(ins, "inner") + + def frep(self, ins: Instruction, mode: Literal["inner", "outer"]): + assert isinstance(self.cpu, FrepEnabledCpu) + assert len(ins.args) == 4 + assert ins.get_imm(2).abs_value.value == 0, "staggering not supported yet" + assert ins.get_imm(3).abs_value.value == 0, "staggering not supported yet" + + self.cpu.repeats = FrepState( +
rep_count=self.regs.get(ins.get_reg(0)).unsigned_value, + ins_count=ins.get_imm(1).abs_value.value, + mode=mode, + ) diff --git a/snitch/regs.py b/snitch/regs.py index e8311c6..ec2cb0e 100644 --- a/snitch/regs.py +++ b/snitch/regs.py @@ -52,6 +52,7 @@ def __init__( mem: MMU, xssr_regs: Tuple[str] = ("ft0", "ft1", "ft2"), infinite_regs: bool = False, + flen: int = 64, ): self.mem = mem self.enabled = False @@ -61,11 +62,11 @@ def __init__( stream_def = StreamDef() self.dm_by_id.append(stream_def) self.streams[reg] = stream_def - super().__init__(infinite_regs) + super().__init__(infinite_regs=infinite_regs, flen=flen) - def get_f(self, reg, mark_read=True) -> "BaseFloat": + def get_f(self, reg) -> "BaseFloat": if not self.enabled or reg not in self.streams: - return super().get_f(reg, mark_read) + return super().get_f(reg) # do the streaming stuff: stream = self.streams[reg] @@ -74,26 +75,21 @@ def get_f(self, reg, mark_read=True) -> "BaseFloat": # TODO: Check overflow # TODO: repetition addr = stream.base + (stream.pos * stream.stride) - val = self.mem.read_float(addr) + val = self._float_type(self.mem.read(addr, self.flen // 8)) # increment pos - print( - "stream: got val {} from addr 0x{:x}, stream {}".format(val, addr, stream) - ) stream.pos += 1 return val - def set_f(self, reg, val: "BaseFloat", mark_set=True) -> bool: + def set_f(self, reg, val: "BaseFloat") -> bool: if not self.enabled or reg not in self.streams: - return super().set_f(reg, mark_set) + return super().set_f(reg, val) stream = self.streams[reg] assert stream.mode is StreamMode.WRITE addr = stream.base + (stream.pos * stream.stride) - self.mem.write(addr, 4, bytearray(val.bytes)) + data = val.bytes + self.mem.write(addr + (self.flen // 8) - len(data), len(data), bytearray(data)) - print( - "stream: wrote val {} into addr 0x{:x}, stream {}".format(val, addr, stream) - ) stream.pos += 1 return True diff --git a/snitch/xssr.py b/snitch/xssr.py index 1c90d1a..982af51 100644 --- a/snitch/xssr.py +++ 
b/snitch/xssr.py @@ -3,7 +3,7 @@ from .regs import StreamingRegs, StreamDef, StreamMode -class RV32_Xssr_pseudo(InstructionSet): +class Xssr_pseudo(InstructionSet): def instruction_ssr_enable(self, ins: Instruction): self._stream.enabled = True diff --git a/test/filecheck/snitch/frep_only.asm b/test/filecheck/snitch/frep_only.asm new file mode 100644 index 0000000..106af7f --- /dev/null +++ b/test/filecheck/snitch/frep_only.asm @@ -0,0 +1,24 @@ +// RUN: python3 -m snitch %s -o libc -v | filecheck %s + +.text +.globl main +main: + // load constants + li t0, 0 + fcvt.s.w ft0, t0 + li t0, 1 + fcvt.s.w ft1, t0 + + // repeat 100 times + li t0, 99 + frep.i t0, 1, 0, 0 + fadd.s ft0, ft0, ft1 // add one + + // print result to stdout + printf "100 * 1 = {:f32}", ft0 +// CHECK: 100 * 1 = 100.0 + // return 0 + li a0, 0 + ret + +// CHECK-NEXT: [CPU] Program exited with code 0 diff --git a/test/filecheck/snitch/ssr_frep.asm b/test/filecheck/snitch/ssr_frep.asm new file mode 100644 index 0000000..d266c5b --- /dev/null +++ b/test/filecheck/snitch/ssr_frep.asm @@ -0,0 +1,75 @@ +// RUN: python3 -m snitch %s -o libc -v --flen 32 | filecheck %s + +.data + +vec0: +.word 0x0, 0x3f800000, 0x40000000, 0x40400000, 0x40800000, 0x40a00000, 0x40c00000, 0x40e00000, 0x41000000, 0x41100000 +vec1: +.word 0x0, 0x3f800000, 0x40000000, 0x40400000, 0x40800000, 0x40a00000, 0x40c00000, 0x40e00000, 0x41000000, 0x41100000 +dest: +.space 40 +expected: +.word 0x0, 0x3e800000, 0x3f800000, 0x40100000, 0x40800000, 0x40c80000, 0x41100000, 0x41440000, 0x41800000, 0x41a20000 + +.text +.globl main + +main: + // ssr config + ssr.configure 0, 10, 4 + ssr.configure 1, 10, 4 + ssr.configure 2, 10, 4 + + // ft0 streams from vec0 + la a0, vec0 + ssr.read a0, 0, 0 + + // ft1 streams from vec1 + la a0, vec1 + ssr.read a0, 1, 0 + + // ft2 streams to dest + la a0, dest + ssr.write a0, 2, 0 + + li a0, 9 + // some constant to divide by + li t0, 4 + fcvt.s.w ft3, t0 + ssr.enable + + frep.o a0, 2, 0, 0 + fmul.s ft4, ft0, ft1 
// ft4 = vec0[i] * vec1[i] + fdiv.s ft2, ft4, ft3 // dest[i] = ft4 / 4 + + // stop ssr + ssr.disable + + // check values were written correctly: + la t0, dest + la t1, expected + li a0, 36 +loop: + add s0, t0, a0 + add s1, t1, a0 + + // load dest and expected elements + flw ft0, 0(s0) + flw ft1, 0(s1) + + // assert ft0 == ft1 (dest[i] == expected[i]) + feq.s s0, ft0, ft1 + beq zero, s0, fail + + addi a0, a0, -4 + bge a0, zero loop + + li a0, 0 + ret + +fail: + printf "Assertion failure: {} != {} (at {})", ft0, ft1, a0 + li a0, -1 + ret + +// CHECK: [CPU] Program exited with code 0 diff --git a/test/filecheck/snitch/ssr_only.asm b/test/filecheck/snitch/ssr_only.asm new file mode 100644 index 0000000..d10d762 --- /dev/null +++ b/test/filecheck/snitch/ssr_only.asm @@ -0,0 +1,73 @@ +// RUN: python3 -m snitch %s -o libc -v --flen 32 | filecheck %s + +.data + +vec0: +.word 0x0, 0x3f800000, 0x40000000, 0x40400000, 0x40800000, 0x40a00000, 0x40c00000, 0x40e00000, 0x41000000, 0x41100000 +vec1: +.word 0x0, 0x3f800000, 0x40000000, 0x40400000, 0x40800000, 0x40a00000, 0x40c00000, 0x40e00000, 0x41000000, 0x41100000 +dest: +.space 40 + +.text +.globl main + +main: + // ssr config + ssr.configure 0, 10, 4 + ssr.configure 1, 10, 4 + ssr.configure 2, 10, 4 + + la a0, vec0 + ssr.read a0, 0, 0 + + la a0, vec1 + ssr.read a0, 1, 0 + + la a0, dest + ssr.write a0, 2, 0 + + ssr.enable + + // set up loop + li a0, 10 +loop: + fadd.s ft2, ft0, ft1 + + addi a0, a0, -1 + bne a0, zero, loop + + // end of loop: + ssr.disable + + // check values were written correctly: + la t0, vec0 + la t1, vec1 + la t2, dest + li a0, 36 +loop2: + add s0, t0, a0 + add s1, t1, a0 + add s2, t2, a0 + + // load vec0, vec1 and dest elements + flw ft0, 0(s0) + flw ft1, 0(s1) + flw ft2, 0(s2) + + // assert ft2 == ft0 + ft1 + fadd.s ft3, ft1, ft0 + feq.s s0, ft2, ft3 + beq zero, s0, fail + + addi a0, a0, -4 + bne a0, zero, loop2 + + ret + +fail: + printf "failed {} + {} != {} (at {})", ft0, ft1, ft2, a0 + li a0, -1 +
ret + +// CHECK: [CPU] Program exited with code 0 diff --git a/test/filecheck/snitch_simple.asm b/test/filecheck/snitch_simple.asm deleted file mode 100644 index 3cdab19..0000000 --- a/test/filecheck/snitch_simple.asm +++ /dev/null @@ -1,74 +0,0 @@ -// RUN: python3 -m snitch %s -o libc | filecheck %s - -.data - -vec0: -.word 0x0, 0x3f800000, 0x40000000, 0x40400000, 0x40800000, 0x40a00000, 0x40c00000, 0x40e00000, 0x41000000, 0x41100000 -vec1: -.word 0x0, 0x3f800000, 0x40000000, 0x40400000, 0x40800000, 0x40a00000, 0x40c00000, 0x40e00000, 0x41000000, 0x41100000 -dest: -.space 40 - -.text -.globl main - -main: - // ssr config - ssr.configure 0, 10, 4 - ssr.configure 1, 10, 4 - ssr.configure 2, 10, 4 - - la a0, vec0 - ssr.read a0, 0, 0 - - la a0, vec1 - ssr.read a0, 1, 0 - - la a0, dest - ssr.write a0, 2, 0 - - ssr.enable - - // set up loop - li a0, 100 -loop: - fadd.s ft2, ft0, ft1 - - addi a0, a0, -1 - bne a0, zero, loop - - // end of loop: - ssr.disable - - ret - -//CHECK: stream: got val 0.0 from addr 0x80148, stream StreamDef(base=524616, bound=10, stride=4, mode=, dim=0, pos=0) -//CHECK_NEXT: stream: got val 0.0 from addr 0x80170, stream StreamDef(base=524656, bound=10, stride=4, mode=, dim=0, pos=0) -//CHECK_NEXT: stream: wrote val 0.0 into addr 0x80198, stream StreamDef(base=524696, bound=10, stride=4, mode=, dim=0, pos=0) -//CHECK_NEXT: stream: got val 1.0 from addr 0x8014c, stream StreamDef(base=524616, bound=10, stride=4, mode=, dim=0, pos=1) -//CHECK_NEXT: stream: got val 1.0 from addr 0x80174, stream StreamDef(base=524656, bound=10, stride=4, mode=, dim=0, pos=1) -//CHECK_NEXT: stream: wrote val 2.0 into addr 0x8019c, stream StreamDef(base=524696, bound=10, stride=4, mode=, dim=0, pos=1) -//CHECK_NEXT: stream: got val 2.0 from addr 0x80150, stream StreamDef(base=524616, bound=10, stride=4, mode=, dim=0, pos=2) -//CHECK_NEXT: stream: got val 2.0 from addr 0x80178, stream StreamDef(base=524656, bound=10, stride=4, mode=, dim=0, pos=2) -//CHECK_NEXT: stream: 
wrote val 4.0 into addr 0x801a0, stream StreamDef(base=524696, bound=10, stride=4, mode=, dim=0, pos=2) -//CHECK_NEXT: stream: got val 3.0 from addr 0x80154, stream StreamDef(base=524616, bound=10, stride=4, mode=, dim=0, pos=3) -//CHECK_NEXT: stream: got val 3.0 from addr 0x8017c, stream StreamDef(base=524656, bound=10, stride=4, mode=, dim=0, pos=3) -//CHECK_NEXT: stream: wrote val 6.0 into addr 0x801a4, stream StreamDef(base=524696, bound=10, stride=4, mode=, dim=0, pos=3) -//CHECK_NEXT: stream: got val 4.0 from addr 0x80158, stream StreamDef(base=524616, bound=10, stride=4, mode=, dim=0, pos=4) -//CHECK_NEXT: stream: got val 4.0 from addr 0x80180, stream StreamDef(base=524656, bound=10, stride=4, mode=, dim=0, pos=4) -//CHECK_NEXT: stream: wrote val 8.0 into addr 0x801a8, stream StreamDef(base=524696, bound=10, stride=4, mode=, dim=0, pos=4) -//CHECK_NEXT: stream: got val 5.0 from addr 0x8015c, stream StreamDef(base=524616, bound=10, stride=4, mode=, dim=0, pos=5) -//CHECK_NEXT: stream: got val 5.0 from addr 0x80184, stream StreamDef(base=524656, bound=10, stride=4, mode=, dim=0, pos=5) -//CHECK_NEXT: stream: wrote val 10.0 into addr 0x801ac, stream StreamDef(base=524696, bound=10, stride=4, mode=, dim=0, pos=5) -//CHECK_NEXT: stream: got val 6.0 from addr 0x80160, stream StreamDef(base=524616, bound=10, stride=4, mode=, dim=0, pos=6) -//CHECK_NEXT: stream: got val 6.0 from addr 0x80188, stream StreamDef(base=524656, bound=10, stride=4, mode=, dim=0, pos=6) -//CHECK_NEXT: stream: wrote val 12.0 into addr 0x801b0, stream StreamDef(base=524696, bound=10, stride=4, mode=, dim=0, pos=6) -//CHECK_NEXT: stream: got val 7.0 from addr 0x80164, stream StreamDef(base=524616, bound=10, stride=4, mode=, dim=0, pos=7) -//CHECK_NEXT: stream: got val 7.0 from addr 0x8018c, stream StreamDef(base=524656, bound=10, stride=4, mode=, dim=0, pos=7) -//CHECK_NEXT: stream: wrote val 14.0 into addr 0x801b4, stream StreamDef(base=524696, bound=10, stride=4, mode=, dim=0, pos=7) 
-//CHECK_NEXT: stream: got val 8.0 from addr 0x80168, stream StreamDef(base=524616, bound=10, stride=4, mode=, dim=0, pos=8) -//CHECK_NEXT: stream: got val 8.0 from addr 0x80190, stream StreamDef(base=524656, bound=10, stride=4, mode=, dim=0, pos=8) -//CHECK_NEXT: stream: wrote val 16.0 into addr 0x801b8, stream StreamDef(base=524696, bound=10, stride=4, mode=, dim=0, pos=8) -//CHECK_NEXT: stream: got val 9.0 from addr 0x8016c, stream StreamDef(base=524616, bound=10, stride=4, mode=, dim=0, pos=9) -//CHECK_NEXT: stream: got val 9.0 from addr 0x80194, stream StreamDef(base=524656, bound=10, stride=4, mode=, dim=0, pos=9) -//CHECK_NEXT: stream: wrote val 18.0 into addr 0x801bc, stream StreamDef(base=524696, bound=10, stride=4, mode=, dim=0, pos=9)