Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to Scala 3.4.0 #76

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ jobs:
runs-on: ${{ matrix.os }}
timeout-minutes: 60
steps:
- name: Install sbt
if: contains(runner.os, 'macos')
shell: bash
run: brew install sbt

- name: Ignore line ending differences in git
if: contains(runner.os, 'windows')
shell: bash
Expand Down Expand Up @@ -107,6 +112,10 @@ jobs:
java: [temurin@11]
runs-on: ${{ matrix.os }}
steps:
- name: Install sbt
if: contains(runner.os, 'macos')
run: brew install sbt

- name: Ignore line ending differences in git
if: contains(runner.os, 'windows')
run: git config --global core.autocrlf false
Expand Down Expand Up @@ -172,6 +181,10 @@ jobs:
java: [temurin@11]
runs-on: ${{ matrix.os }}
steps:
- name: Install sbt
if: contains(runner.os, 'macos')
run: brew install sbt

- name: Ignore line ending differences in git
if: contains(runner.os, 'windows')
run: git config --global core.autocrlf false
Expand Down Expand Up @@ -208,6 +221,10 @@ jobs:
java: [temurin@11]
runs-on: ${{ matrix.os }}
steps:
- name: Install sbt
if: contains(runner.os, 'macos')
run: brew install sbt

- name: Ignore line ending differences in git
if: contains(runner.os, 'windows')
run: git config --global core.autocrlf false
Expand Down Expand Up @@ -235,7 +252,7 @@ jobs:

- name: Publish site
if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main'
uses: peaceiris/actions-gh-pages@v3.9.3
uses: peaceiris/actions-gh-pages@v4.0.0
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: site/target/docs/site
Expand Down
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ val pytorchVersion = "2.1.2"
val cudaVersion = "12.3-8.9"
val openblasVersion = "0.3.26"
val mklVersion = "2024.0"
ThisBuild / scalaVersion := "3.3.1"
ThisBuild / scalaVersion := "3.4.0"
ThisBuild / javaCppVersion := "1.5.10"
ThisBuild / resolvers ++= Resolver.sonatypeOssRepos("snapshots")

Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala/torch/DType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ type DTypeOrDeriveArange[
* rules](https://github.com/pytorch/pytorch/blob/fb6749d977e33b5f463c2d0a1b56a939428105e5/c10/core/ScalarType.h#L423-L444)
*/
type Promoted[T <: DType, U <: DType] <: DType = (T, U) match
case (T, DType) => T
case (T, T) => T
case (U, U) => U
case (Undefined, U) | (T, Undefined) => Undefined
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/torch/Tensor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
case _ => false

/** True if `other` has the same size and elements as this tensor, false otherwise. */
def equal(other: Tensor[D]): Boolean = native.equal(other.native)
infix def equal(other: Tensor[D]): Boolean = native.equal(other.native)

/** Returns the tensor with elements exponentiated. */
def exp: Tensor[D] = fromNative(native.exp())
Expand Down Expand Up @@ -415,7 +415,7 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto

def <(other: ScalaType): Tensor[Bool] = lt(other)

def matmul[D2 <: DType](u: Tensor[D2]): Tensor[Promoted[D, D2]] =
infix def matmul[D2 <: DType](u: Tensor[D2]): Tensor[Promoted[D, D2]] =
fromNative(native.matmul(u.native))

def `@`[D2 <: DType](u: Tensor[D2]): Tensor[Promoted[D, D2]] = matmul(u)
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/torch/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ type AtLeastOneFloatOrComplex[A <: DType, B <: DType] = A <:< (FloatNN | Complex
B <:< (FloatNN | ComplexNN)

/* Evidence that two dtypes are not the same */
type NotEqual[D <: DType, D2 <: DType] = NotGiven[D =:= D2]
infix type NotEqual[D <: DType, D2 <: DType] = NotGiven[D =:= D2]
6 changes: 3 additions & 3 deletions core/src/main/scala/torch/nn/functional/Convolution.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ private[torch] trait Convolution {
toArray(padding),
toArray(outputPadding),
groups,
toArray(dilation): _*
toArray(dilation)*
)
)

Expand Down Expand Up @@ -151,7 +151,7 @@ private[torch] trait Convolution {
toArray(padding),
toArray(outputPadding),
groups,
toArray(dilation): _*
toArray(dilation)*
)
)

Expand Down Expand Up @@ -179,7 +179,7 @@ private[torch] trait Convolution {
toArray(padding),
toArray(outputPadding),
groups,
toArray(dilation): _*
toArray(dilation)*
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ final class ModuleList[D <: DType](override val modules: TensorModule[D]*)
// TODO: not in Python code. Note other modules retain index, so we have repeats
this.register(module)(using Name(index.toString()))
// TODO: make modules list mutable?
ModuleList(all: _*)
ModuleList(all*)

/** Appends a given module to the end of the list.
*
Expand All @@ -94,7 +94,7 @@ final class ModuleList[D <: DType](override val modules: TensorModule[D]*)
this.register(module)(using Name(index.toString()))
val all = modules.appended(module)
// TODO: make modules list mutable?
ModuleList(all: _*)
ModuleList(all*)

/** Appends modules from a Python iterable to the end of the list.
*
Expand All @@ -115,7 +115,7 @@ final class ModuleList[D <: DType](override val modules: TensorModule[D]*)
this.register(module)(using Name(index.toString()))
)
// TODO: make modules list mutable?
ModuleList(all: _*)
ModuleList(all*)

override def hasBias(): Boolean = modules.exists(_.hasBias())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ final class LayerNorm[ParamType <: FloatNN | ComplexNN: Default](
) extends HasWeight[ParamType]
with TensorModule[ParamType]:

private val shape: LongVector = LongVector(normalizedShape.map(_.toLong): _*)
private val shape: LongVector = LongVector(normalizedShape.map(_.toLong)*)
private val options: LayerNormOptions = LayerNormOptions(shape)
options.eps().put(eps)
options.elementwise_affine().put(elementWiseAffine)
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/torch/ops/ReductionOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,6 @@ private[torch] trait ReductionOps {
val nativeDim = dim.toArray
fromNative(
if nativeDim.isEmpty then torchNative.count_nonzero(input.native)
else torchNative.count_nonzero(input.native, nativeDim: _*)
else torchNative.count_nonzero(input.native, nativeDim*)
)
}
2 changes: 1 addition & 1 deletion core/src/test/scala/TrainingSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package torch
import torch.data.*

class TraininSuite extends munit.FunSuite {
class TrainingSuite extends munit.FunSuite {
test("training") {

val xTrain = torch.arange(end = 10, dtype = float32) // .reshape(10, 1)
Expand Down
6 changes: 3 additions & 3 deletions core/src/test/scala/torch/ops/RandomSamplingOpsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ class RandomSamplingOpsSuite extends TensorCheckSuite {

val g1 = torch.Generator()
g1.manualSeed(0)
val t1 = torch.randint(high = 100, Seq(2, 2), generator = g1)
val t2 = torch.randint(high = 100, Seq(2, 2), generator = g1)
val t1 = torch.randint(high = 100, size = Seq(2, 2), generator = g1)
val t2 = torch.randint(high = 100, size = Seq(2, 2), generator = g1)
assertNotEquals(t1, t2)

val g2 = torch.Generator()
g2.manualSeed(0)
val t3 = torch.randint(high = 100, Seq(2, 2), generator = g2)
val t3 = torch.randint(high = 100, size = Seq(2, 2), generator = g2)
assertEquals(t1, t3)

}
Expand Down
6 changes: 3 additions & 3 deletions examples/src/main/scala/gpt/V2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ object V2:
*/

// here are all the unique characters that occur in this text
val chars = SortedSet(text: _*)
val chars = SortedSet(text*)
println(s"chars = ${chars.mkString(", ")}")
val vocab_size = chars.size
println(s"vocab_size = $vocab_size")
Expand Down Expand Up @@ -413,7 +413,7 @@ object V2:
Utils.register_i(this, Head_2(nEmbed, headSize, blockSize), i)
}
// val hs = 0 until numHeads map{ i => Utils.register_i(this, Head(nEmbed, headSize, blockSize, dropout), i) }
val heads = register(nn.ModuleList(hs: _*))
val heads = register(nn.ModuleList(hs*))
// TODO: BUG - self.proj = nn.Linear(head_size * num_heads, n_embd)
val proj = register(nn.Linear(headSize * numHeads, nEmbed))
// val proj = register( nn.Linear(nEmbed, nEmbed) )
Expand Down Expand Up @@ -603,7 +603,7 @@ object V2:
val token_embedding_table = register(nn.Embedding(vocabSize, nEmbed))
val position_embedding_table = register(nn.Embedding(blockSize, nEmbed))
val blocks_i = 0 until nBlocks map { i => Block(nEmbed, nHead, blockSize, vocabSize, dropout) }
val blocks = register(nn.Sequential(blocks_i: _*))
val blocks = register(nn.Sequential(blocks_i*))
val ln_f = register(nn.LayerNorm(Seq(nEmbed)))
val lm_head = register(nn.Linear(nEmbed, vocabSize))

Expand Down
2 changes: 1 addition & 1 deletion project/build.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sbt.version=1.9.8
sbt.version=1.10.0
6 changes: 3 additions & 3 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2")
addSbtPlugin("org.bytedeco" % "sbt-javacpp" % "1.17")
addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.5.2")
addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.5.3")
addSbtPlugin("com.github.sbt" % "sbt-unidoc" % "0.5.0")
addSbtPlugin("org.typelevel" % "sbt-typelevel" % "0.6.5")
addSbtPlugin("org.typelevel" % "sbt-typelevel-site" % "0.6.5")
addSbtPlugin("org.typelevel" % "sbt-typelevel" % "0.7.1")
addSbtPlugin("org.typelevel" % "sbt-typelevel-site" % "0.7.1")
3 changes: 1 addition & 2 deletions vision/src/main/scala/torchvision/transforms/presets.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ package transforms

import com.sksamuel.scrimage.ImmutableImage
import com.sksamuel.scrimage.ScaleMethod
import torch.Tensor
import torch.Float32
import torch.{Float32, Tensor}
import torchvision.transforms.functional.toTensor

object Presets:
Expand Down