-
Notifications
You must be signed in to change notification settings - Fork 2
/
axon.exs
34 lines (25 loc) · 1.55 KB
/
axon.exs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Install the dependencies for this standalone script.
#
# - nx_iree comes from a local path ("."), so this presumably expects to be run
#   from the nx_iree project root (NOTE(review): confirm the intended cwd).
# - Nx and EXLA are pulled from the elixir-nx/nx GitHub repo (sparse checkouts)
#   with override: true, so they take precedence over the versions Axon asks for.
#
# Mix.install/2 documents :system_env values as binaries, so the flag is given
# as the string "false" rather than the atom false.
Mix.install(
  [
    {:axon, github: "elixir-nx/axon", branch: "main"},
    {:nx_iree, path: "."},
    {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
    {:exla, github: "elixir-nx/nx", sparse: "exla", override: true}
  ],
  system_env: %{"NX_IREE_PREFER_PRECOMPILED" => "false"}
)
# Print the IREE drivers available on this machine (diagnostics only). The list
# itself is not used further, so it is bound to an underscored name to avoid an
# unused-variable warning.
{:ok, _drivers} = NxIREE.list_drivers() |> IO.inspect(label: "drivers")

# Pick the first CUDA device reported by NxIREE, failing with a readable
# message (instead of a bare MatchError) when none is available.
dev =
  case NxIREE.list_devices("cuda") do
    {:ok, [first_device | _]} ->
      first_device

    other ->
      raise "no CUDA device available for NxIREE, list_devices(\"cuda\") returned: #{inspect(other)}"
  end

# IREE compiler flags: target the CUDA HAL backend, declare the input dialect
# as stablehlo_xla, and select the async-internal execution model.
flags = [
  "--iree-hal-target-backends=cuda",
  "--iree-input-type=stablehlo_xla",
  "--iree-execution-model=async-internal"
]

# Make NxIREE the default Nx.Defn compiler for the rest of the script, running
# on the CUDA device selected above.
Nx.Defn.default_options(
  compiler: NxIREE.Compiler,
  iree_compiler_flags: flags,
  iree_runtime_options: [device: dev]
)
# A small two-layer dense network on input "x": 3 features per row (the nil
# leading dimension leaves the batch size open), a hidden layer of 8 units,
# and a single output unit, both with ReLU activation.
model =
  Axon.input("x", shape: {nil, 3})
  |> Axon.dense(8, activation: :relu)
  |> Axon.dense(1, activation: :relu)

# The NxIREE default options were already set right after device discovery, so
# the identical call is not repeated here. To compare against EXLA instead,
# swap the compiler with:
# Nx.Defn.default_options(compiler: EXLA, iree_compiler_flags: flags, iree_runtime_options: [device: dev])
# Input template: a {10, 3} f32 tensor fed to the "x" input.
template = %{"x" => Nx.template({10, 3}, :f32)}

# Split the model into its initialisation and prediction functions.
{init_fn, predict_fn} = model |> Axon.build([])

# JIT-compile the initialiser with the default compiler and run it against the
# template, starting from an empty model state.
empty_state = Axon.ModelState.new(Axon.ModelState.empty())
init_params = Nx.Defn.jit(init_fn).(template, empty_state)
# Prints a section marker around each phase so compiler/runtime logs are easy
# to tell apart in the output.
banner = fn label -> IO.puts("\n\n\n======= #{label} =======\n\n\n") end

# Compile predict_fn ahead of time for the given parameters and input template.
banner.("BEGIN predict_compiled_fn")
predict_compiled_fn = Nx.Defn.compile(predict_fn, [init_params, template])
banner.("END predict_compiled_fn")

# A compiled function must be called with arguments that match the templates
# it was compiled for, so the input is wrapped in the same %{"x" => ...} map
# structure as `template` rather than passed as a bare tensor.
banner.("BEGIN predict_compiled_fn CALL")
predict_compiled_fn.(init_params, %{"x" => Nx.iota({10, 3}, type: :f32)}) |> dbg()
banner.("END predict_compiled_fn CALL")