Skip to content

Commit

Permalink
make progress on issue #295. banking analyzer seems to properly bank …
Browse files Browse the repository at this point in the history
…for dual ported readers, chisel template gives mostly the right answer but something is slightly off
  • Loading branch information
mattfel1 committed Feb 21, 2020
1 parent 71026b9 commit e8f1769
Show file tree
Hide file tree
Showing 15 changed files with 471 additions and 81 deletions.
226 changes: 197 additions & 29 deletions fringe/src/fringe/templates/memory/MemPrimitives.scala

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions fringe/src/fringe/templates/memory/MemType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package fringe.templates.memory

// Enumeration of the on-chip memory primitive kinds; used (e.g. by NBufMem)
// to select which hardware template to instantiate.
sealed trait MemType
object BankedSRAMType extends MemType
// Added for memories whose banks expose two read ports (see SRAMDualRead).
object BankedSRAMDualReadType extends MemType
object FFType extends MemType
object FIFOType extends MemType
object LIFOType extends MemType
Expand Down
25 changes: 25 additions & 0 deletions fringe/src/fringe/templates/memory/NBuffers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,31 @@ class NBufMem(np: NBufParams) extends Module {
io.rPort(i).output.zipWithIndex.foreach{case (r, j) => r := chisel3.util.Mux1H(outSel, srams.map{f => f.io.rPort(i).output(j)})}
}

// N-buffered, dual-read banked SRAM: instantiate one BankedSRAM per buffer and
// let the NBuf controller (ctrl) steer each write/read port to the buffer it
// currently owns.
case BankedSRAMDualReadType =>
val srams = (0 until np.numBufs).map{ i =>
val x = Module(new BankedSRAM(np.p.Ds, np.p.bitWidth,
np.p.Ns, np.p.Bs, np.p.Ps,
np.p.WMapping, np.p.RMapping,
np.p.bankingMode, np.p.inits, np.p.syncMem, np.p.fracBits, np.p.numActives, "SRAM"))
// Default-drive all IO so unconnected ports do not fail elaboration.
x.io <> DontCare
x
}
// Route NBuf IO to SRAM IOs
srams.zipWithIndex.foreach{ case (f,i) =>
np.p.WMapping.zipWithIndex.foreach { case (a,j) =>
// Write port j is active on buffer i only when the controller's write state
// for this buffered port equals i; ports without a bufPort are always active.
val wMask = if (a.port.bufPort.isDefined) {ctrl.io.statesInW(ctrl.lookup(a.port.bufPort.get)) === i.U} else true.B
f.connectBufW(io.wPort(j), j, wMask)
}
np.p.RMapping.zipWithIndex.foreach { case (a,j) =>
// NOTE(review): wMask above resolves the port id through ctrl.lookup(...),
// but rMask indexes ctrl.io.statesInR with the raw bufPort id. If statesInR
// is keyed like statesInW, this asymmetry could be the "slightly off"
// behavior mentioned in the commit message — confirm against BankedSRAMType.
val rMask = if (a.port.bufPort.isDefined) {ctrl.io.statesInR(a.port.bufPort.get) === i.U} else true.B
f.connectBufR(io.rPort(j), j, rMask)
}
}

// Mux each buffered read port's output from whichever SRAM instance the
// controller says currently backs that port (one-hot select across buffers).
np.p.RMapping.zipWithIndex.collect{case (p, i) if (p.port.bufPort.isDefined) =>
val outSel = (0 until np.numBufs).map{ a => ctrl.io.statesInR(p.port.bufPort.get) === a.U }
io.rPort(i).output.zipWithIndex.foreach{case (r, j) => r := chisel3.util.Mux1H(outSel, srams.map{f => f.io.rPort(i).output(j)})}
}

case FFType =>
val ffs = (0 until np.numBufs).map{ i =>
Expand Down
77 changes: 76 additions & 1 deletion fringe/src/fringe/templates/memory/SRAM.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,34 @@ class SRAMVerilogIO[T<:Data](t: T, d: Int) extends Bundle {

}

/** Port bundle for the `SRAMVerilogDualRead` black box: one write port plus two
  * independent read ports (address/enable/data per port). Data is carried as
  * raw UInt of the element's bit width, matching the Verilog interface.
  *
  * @param t element type (only its bit width is used here)
  * @param d depth of the memory in words
  */
class SRAMVerilogDualReadIO[T<:Data](t: T, d: Int) extends Bundle {
val clk = Input(Clock())
// Address width is clamped to at least 1 bit so a depth-1 memory still elaborates.
val raddr0 = Input(UInt({1 max log2Ceil(d)}.W))
val raddr1 = Input(UInt({1 max log2Ceil(d)}.W))
val waddr = Input(UInt({1 max log2Ceil(d)}.W))
val raddrEn0 = Input(Bool())
val raddrEn1 = Input(Bool())
val waddrEn = Input(Bool())
val wen = Input(Bool())
val backpressure = Input(Bool())
val wdata = Input(UInt(t.getWidth.W))
val rdata0 = Output(UInt(t.getWidth.W))
val rdata1 = Output(UInt(t.getWidth.W))

// Manual cloneType is required so Chisel can copy this parameterized Bundle.
override def cloneType = (new SRAMVerilogDualReadIO(t, d)).asInstanceOf[this.type] // See chisel3 bug 358

}

/** Common supertype for single-read SRAM black boxes, exposing the shared
  * `SRAMVerilogIO` interface. (The dual-read black box extends BlackBox
  * directly since its IO bundle differs.) */
abstract class SRAMBlackBox[T<:Data](params: Map[String,Param]) extends BlackBox(params) {
val io: SRAMVerilogIO[T]
}


/** Black box binding to the dual-read-port SRAM Verilog template. The
  * DWIDTH/WORDS/AWIDTH Verilog parameters are derived from the element's bit
  * width and the requested depth (AWIDTH clamped to >= 1 bit). */
class SRAMVerilogDualRead[T<:Data](val t: T, val d: Int) extends BlackBox(
Map("DWIDTH" -> IntParam(t.getWidth), "WORDS" -> IntParam(d), "AWIDTH" -> IntParam({1 max log2Ceil(d)})))
{
override val io = IO(new SRAMVerilogDualReadIO(t, d))
}

class SRAMVerilogSim[T<:Data](val t: T, val d: Int) extends BlackBox(
Map("DWIDTH" -> IntParam(t.getWidth), "WORDS" -> IntParam(d), "AWIDTH" -> IntParam({1 max log2Ceil(d)})))
Expand Down Expand Up @@ -76,11 +99,30 @@ class GenericRAMIO[T<:Data](t: T, d: Int) extends Bundle {
}
}

/** Chisel-level IO for a RAM with one write port and two read ports. Unlike the
  * black-box bundle, the data ports here carry the element type T directly;
  * conversion to/from raw bits happens in the wrapper module.
  *
  * @param t element type for wdata/rdata ports
  * @param d depth of the memory in words
  */
class GenericRAMDualReadIO[T<:Data](t: T, d: Int) extends Bundle {
val addrWidth = {1 max log2Ceil(d)}
val raddr0 = Input(UInt(addrWidth.W))
val raddr1 = Input(UInt(addrWidth.W))
val wen = Input(Bool())
val waddr = Input(UInt(addrWidth.W))
val wdata = Input(t.cloneType)
val rdata0 = Output(t.cloneType)
val rdata1 = Output(t.cloneType)
val backpressure = Input(Bool())

// Manual cloneType for this parameterized Bundle (see chisel3 bug 358).
override def cloneType: this.type = {
new GenericRAMDualReadIO(t, d).asInstanceOf[this.type]
}
}

/** Base class for single-read RAM implementations; concrete subclasses supply
  * the `GenericRAMIO` instance. */
abstract class GenericRAM[T<:Data](val t: T, val d: Int) extends Module {
val addrWidth = {1 max log2Ceil(d)}
val io: GenericRAMIO[T]
}

/** Base class for dual-read RAM implementations; mirrors `GenericRAM` but with
  * the two-read-port IO bundle. */
abstract class GenericRAMDualRead[T<:Data](val t: T, val d: Int) extends Module {
val addrWidth = {1 max log2Ceil(d)}
val io: GenericRAMDualReadIO[T]
}
class FFRAM[T<:Data](override val t: T, override val d: Int) extends GenericRAM(t, d) {
class FFRAMIO[T<:Data](t: T, d: Int) extends GenericRAMIO(t, d) {
class Bank[T<:Data](t: T, d: Int) extends Bundle {
Expand Down Expand Up @@ -154,3 +196,36 @@ class SRAM[T<:Data](override val t: T, override val d: Int, val resourceType: St
}
}


/** Dual-read-port SRAM wrapper: instantiates the `SRAMVerilogDualRead` black
  * box and adds write-first forwarding so a read of the address written in the
  * previous cycle observes the new data.
  *
  * NOTE(review): `resourceType` is accepted for interface parity with `SRAM`
  * but is not used in this implementation — confirm intended.
  */
class SRAMDualRead[T<:Data](override val t: T, override val d: Int, val resourceType: String) extends GenericRAMDualRead(t, d) {
val io = IO(new GenericRAMDualReadIO(t, d))

// Customize SRAM here
// TODO: Still needs some cleanup
globals.target match {
// case _:AWS_F1 | _:Zynq | _:ZCU | _:Arria10 | _:KCU1500 | _:CXP =>
// All targets currently share the generic black-box path.
case _ =>
val mem = Module(new SRAMVerilogDualRead(t, d))

mem.io.clk := clock
mem.io.raddr0 := io.raddr0
mem.io.raddr1 := io.raddr1
mem.io.wen := io.wen
mem.io.waddr := io.waddr
mem.io.wdata := io.wdata.asUInt()
mem.io.backpressure := io.backpressure
// Read/write address enables are tied high; enable gating is left to the
// black box's backpressure input.
mem.io.raddrEn0 := true.B
mem.io.raddrEn1 := true.B
mem.io.waddrEn := true.B

// Implement WRITE_FIRST logic here
// equality register: records (per read port) whether last cycle's write hit
// the address being read; if so, forward the registered write data instead
// of the (stale) SRAM output.
// NOTE(review): the bypass registers are not gated by io.backpressure, while
// the black box receives it — if the SRAM holds its output during a stall
// but these registers keep updating, the forwarded value could diverge for
// one cycle. Worth checking against the "slightly off" symptom.
val equalReg0 = RegNext(io.wen & (io.raddr0 === io.waddr), false.B)
val equalReg1 = RegNext(io.wen & (io.raddr1 === io.waddr), false.B)
val wdataReg = RegNext(io.wdata.asUInt, 0.U)
io.rdata0 := Mux(equalReg0, wdataReg.asUInt, mem.io.rdata0).asTypeOf(t)
io.rdata1 := Mux(equalReg1, wdataReg.asUInt, mem.io.rdata1).asTypeOf(t)

}
}

20 changes: 10 additions & 10 deletions poly/src/poly/SparseMatrix.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,28 @@ package poly

import utils.implicits.collections._

case class SparseMatrix[K](rows: Seq[SparseVector[K]]) {
case class SparseMatrix[K](rows: Seq[SparseVector[K]], isReader: Boolean = false) {
def keys: Set[K] = rows.map(_.cols.keySet).fold(Set.empty)(_++_)
def replaceKeys(keySwap: Map[K,(K,Int)]): SparseMatrix[K] = {
val rows2 = rows.map{r =>
val cols2 = r.cols.map{case (k,v) => (keySwap.getOrElse(k,(k,0))._1 -> v)}
val offset = r.cols.collect{case (k,_) if (keySwap.contains(k)) => keySwap(k)._2}.sum
SparseVector[K](cols2, r.c + offset, r.allIters)
}
SparseMatrix[K](rows2)
SparseMatrix[K](rows2, isReader)
}

def prependBlankRow: SparseMatrix[K] = {
val rows2 = Seq(SparseVector[K](Map[K,Int](), 0, Map[K,Seq[K]]())) ++ rows
SparseMatrix[K](rows2)
SparseMatrix[K](rows2, isReader)
}

def sliceDims(dims: Seq[Int]): SparseMatrix[K] = SparseMatrix[K](dims.map{i => rows(i) })
def sliceDims(dims: Seq[Int]): SparseMatrix[K] = SparseMatrix[K](dims.map{i => rows(i) }, isReader)

def map(f: SparseVector[K] => SparseVector[K]): SparseMatrix[K] = SparseMatrix[K](rows.map(f))
def map(f: SparseVector[K] => SparseVector[K]): SparseMatrix[K] = SparseMatrix[K](rows.map(f), isReader)
def zip(that: SparseMatrix[K])(func: (Int,Int) => Int): SparseMatrix[K] = {
val rows2 = this.rows.zip(that.rows).map{case (v1,v2) => v1.zip(v2)(func) }
SparseMatrix[K](rows2)
SparseMatrix[K](rows2, isReader)
}
def unary_-(): SparseMatrix[K] = this.map{row => -row}
def +(that: SparseMatrix[K]): SparseMatrix[K] = this.zip(that){_+_}
Expand All @@ -35,13 +35,13 @@ case class SparseMatrix[K](rows: Seq[SparseVector[K]]) {
val stepsize = r.cols.collect{case (k,v) if (k == key) => v}.headOption.getOrElse(0)
SparseVector[K](r.cols, r.c + value * stepsize, r.allIters)
}
SparseMatrix[K](rows2)
SparseMatrix[K](rows2, isReader)
}
def incrementConst(value: Int): SparseMatrix[K] = {
val rows2 = this.rows.map{r =>
SparseVector[K](r.cols, r.c + value, r.allIters)
}
SparseMatrix[K](rows2)
SparseMatrix[K](rows2, isReader)
}
private def combs(lol: List[List[SparseVector[K]]]): List[List[SparseVector[K]]] = lol match {
case Nil => List(Nil)
Expand All @@ -67,7 +67,7 @@ case class SparseMatrix[K](rows: Seq[SparseVector[K]]) {
}
}

combs(rowOptions.toList).map{sm => SparseMatrix[K](sm)}
combs(rowOptions.toList).map{sm => SparseMatrix[K](sm, isReader)}
}
def asConstraintEqlZero = ConstraintMatrix(rows.map(_.asConstraintEqlZero).toSet)
def asConstraintGeqZero = ConstraintMatrix(rows.map(_.asConstraintGeqZero).toSet)
Expand All @@ -81,7 +81,7 @@ case class SparseMatrix[K](rows: Seq[SparseVector[K]]) {
val rowStrs = rows.map{row => header.map{k => row(k).toString } :+ row.c.toString :+ row.mod.toString}
val entries = (header.map(_.toString) :+ "c" :+ "mod") +: rowStrs
val maxCol = entries.flatMap(_.map(_.length)).maxOrElse(0)
entries.map{row => row.map{x => " "*(maxCol - x.length + 1) + x }.mkString(" ") }.mkString("\n")
entries.map{row => row.map{x => " "*(maxCol - x.length + 1) + x }.mkString(" ") }.mkString("\n") + {if (isReader) "rd" else "wr"}
}
}
object SparseMatrix {
Expand Down
5 changes: 3 additions & 2 deletions src/spatial/codegen/chiselgen/ChiselGenMem.scala
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,14 @@ trait ChiselGenMem extends ChiselGenCommon {

if (!mem.isNBuffered && name == "LineBuffer") throw new Exception(s"Cannot create non-buffered line buffer! Make sure $mem has readers and writers bound by a pipelined LCA, or else turn it into an SRAM")

val templateName = if (!mem.isNBuffered && name != "LineBuffer") s"$name("
val dualsfx = if (mem.isDualPortedRead) "DualRead" else ""
val templateName = if (!mem.isNBuffered && name != "LineBuffer") s"$name$dualsfx("
else {
if (name == "SRAM") appPropertyStats += HasNBufSRAM
mem.swappers.zipWithIndex.foreach{case (node, port) =>
bufMapping += (node -> {bufMapping.getOrElse(node, List[BufMapping]()) ++ List(BufMapping(mem, port))})
}
s"NBufMem(${name}Type, "
s"NBufMem(${name}${dualsfx}Type, "
}
if (mem.broadcastsAnyRead) appPropertyStats += HasBroadcastRead

Expand Down
3 changes: 3 additions & 0 deletions src/spatial/lang/SRAM.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ abstract class SRAM[A:Bits,C[T]](implicit val evMem: C[A] <:< SRAM[A,C]) extends
*/
def mustmerge: C[A] = { this.isMustMerge = true; me }

/** Marks this memory so the compiler may assume it has dual read ports (sets the DualPortedRead flag). */
def dualportedread: C[A] = { this.isDualPortedRead = true; me}
/** Dual write ports are not implemented; fail loudly at staging time rather than miscompile. */
def dualportedwrite: C[A] = { throw new Exception(s"Memories with Dual Write Ports are currently not supported. They can be implemented pretty easily, but we have not needed them yet.")}

def nohierarchical: C[A] = {throw new Exception(s".nohierarchical has been deprecated. Please use .flat instead")}
def noflat: C[A] = {throw new Exception(s".noflat has been deprecated. Please use .hierarchical instead")}
def nobank: C[A] = {throw new Exception(s".nobank has been deprecated. Please use .fullfission instead")}
Expand Down
9 changes: 5 additions & 4 deletions src/spatial/metadata/access/AffineData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ import emul.ResidualGenerator._
case class AccessMatrix(
access: Sym[_],
matrix: SparseMatrix[Idx],
unroll: Seq[Int]
unroll: Seq[Int],
isReader: scala.Boolean = false
) {
def keys: Set[Idx] = matrix.keys
@stateful def parent: Ctrl = access.parent
def substituteKeys(keySwap: Map[Idx,(Idx,Int)]): AccessMatrix = {
keySwap.foreach{case (old, (swp,_)) => swp.domain = old.domain.replaceKeys(keySwap)}
val matrix2 = matrix.replaceKeys(keySwap)
AccessMatrix(access, matrix2, unroll)
AccessMatrix(access, matrix2, unroll, this.isReader)
}

/** True if there exists a reachable multi-dimensional index I such that a(I) = b(I).
Expand All @@ -47,7 +48,7 @@ case class AccessMatrix(
def intersects(b: AccessMatrix)(implicit isl: ISL): Boolean = isl.intersects(this.matrix, b.matrix)

override def toString: String = {
stm(access) + " {" + unroll.mkString(",") + "}\n" + matrix.toString
stm(access) + " " + {if (isReader) "rd" else "wr"} + "{" + unroll.mkString(",") + "}\n" + matrix.toString
}

def short: String = s"$access {${unroll.mkString(",")}}"
Expand All @@ -65,7 +66,7 @@ case class AccessMatrix(
val c2 = r.c * alpha
SparseVector[Idx](vec, c2, r.allIters)
}
SparseMatrix[Idx](rows2)
SparseMatrix[Idx](rows2, b.isReader)
}
}

Expand Down
8 changes: 8 additions & 0 deletions src/spatial/metadata/memory/BankingData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,14 @@ case class NoFlatBank(flag: Boolean) extends Data[NoFlatBank](SetBy.User)
*/
case class MustMerge(flag: Boolean) extends Data[MustMerge](SetBy.User)

/** Flag set by the user to let compiler assume memory has dual read ports.
  * Set via the `dualportedread` directive on a memory.
  *
  * Getter: sym.isDualPortedRead
  * Setter: sym.isDualPortedRead = (true | false)
  * Default: false
  */
case class DualPortedRead(flag: Boolean) extends Data[DualPortedRead](SetBy.User)

/** Flag set by the user to ensure an SRAM will merge the buffers, in cases
where you have metapipelined access such as pre-load, accumulate, store.
*
Expand Down
3 changes: 3 additions & 0 deletions src/spatial/metadata/memory/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ package object memory {
def isMustMerge: Boolean = metadata[MustMerge](s).exists(_.flag)
def isMustMerge_=(flag: Boolean): Unit = metadata.add(s, MustMerge(flag))

/** Accessors for the user-set [[DualPortedRead]] flag; absent metadata reads as false. */
def isDualPortedRead: Boolean = metadata[DualPortedRead](s).exists(_.flag)
def isDualPortedRead_=(flag: Boolean): Unit = metadata.add(s, DualPortedRead(flag))

def isFullFission: Boolean = metadata[OnlyDuplicate](s).exists(_.flag)
def isFullFission_=(flag: Boolean): Unit = metadata.add(s, OnlyDuplicate(flag))

Expand Down
4 changes: 2 additions & 2 deletions src/spatial/traversal/AccessAnalyzer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ case class AccessAnalyzer(IR: State) extends Traversal with AccessExpansion {
dbgs(s" Access pattern: ")
pattern.zipWithIndex.foreach{case (p,d) => dbgs(s" [$d] $p") }

val matrices = getUnrolledMatrices(mem,access,addr,pattern,Nil)
val matrices = getUnrolledMatrices(mem,access,addr,pattern,Nil,access.isReader)
access.accessPattern = pattern
access.affineMatrices = matrices

Expand Down Expand Up @@ -252,7 +252,7 @@ case class AccessAnalyzer(IR: State) extends Traversal with AccessExpansion {
dbgs(s" Access pattern: ")
pattern.zipWithIndex.foreach{case (p,d) => dbgs(s" [$d] $p") }

val matrices = pattern.flatMap{p => getUnrolledMatrices(mem, access, Nil, Seq(p), Nil)}
val matrices = pattern.flatMap{p => getUnrolledMatrices(mem, access, Nil, Seq(p), Nil, access.isReader)}
access.accessPattern = pattern
access.affineMatrices = matrices

Expand Down
11 changes: 6 additions & 5 deletions src/spatial/traversal/AccessExpansion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ trait AccessExpansion {
def domain(x: Idx): ConstraintMatrix[Idx] = getOrAddDomain(x)


def getAccessCompactMatrix(access: Sym[_], addr: Seq[Idx], pattern: Seq[AddressPattern]): SparseMatrix[Idx] = {
def getAccessCompactMatrix(access: Sym[_], addr: Seq[Idx], pattern: Seq[AddressPattern], isReader: Boolean = false): SparseMatrix[Idx] = {
val rows = pattern.zipWithIndex.map{case (ap,d) =>
ap.toSparseVector{() => addr.indexOrElse(d, nextRand()) }
}
val matrix = SparseMatrix[Idx](rows)
val matrix = SparseMatrix[Idx](rows, isReader)
matrix.keys.foreach{x => getOrAddDomain(x) }
matrix
}
Expand All @@ -68,7 +68,8 @@ trait AccessExpansion {
access: Sym[_],
addr: Seq[Idx],
pattern: Seq[AddressPattern],
vecID: Seq[Int] = Nil
vecID: Seq[Int] = Nil,
isReader: scala.Boolean = false
): Seq[AccessMatrix] = {
val is = accessIterators(access, mem)
val ps = is.map(_.ctrParOr1)
Expand All @@ -77,7 +78,7 @@ trait AccessExpansion {
dbgs(" Iterators: " + is.indices.map{i => s"${is(i)} (par: ${ps(i)}, start: ${starts(i)})"}.mkString(", "))

val iMap = is.zipWithIndex.toMap
val matrix = getAccessCompactMatrix(access, addr, pattern)
val matrix = getAccessCompactMatrix(access, addr, pattern, isReader)

multiLoop(ps).map{uid: Seq[Int] =>
val mat = matrix.map{vec: SparseVector[Idx] =>
Expand Down Expand Up @@ -119,7 +120,7 @@ trait AccessExpansion {
val uI = components.collectAsMap{case (_,x,_,i) if !i.contains(x) => (x, i): (Idx, Seq[Idx]) }
SparseVector[Idx](xs.zip(as).toMap, c, uI)
}
val amat = AccessMatrix(access, mat, uid ++ vecID)
val amat = AccessMatrix(access, mat, uid ++ vecID, isReader)
amat.keys.foreach{x => getOrAddDomain(x) }
amat
}.toSeq
Expand Down
Loading

0 comments on commit e8f1769

Please sign in to comment.