Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: DGEMM workunits #146

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/main_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade numpy mypy==1.0.1 cmake pytest pybind11 scikit-build patchelf
python -m pip install --upgrade numpy mypy==1.0.1 cmake pytest pybind11 scikit-build patchelf scipy
- name: Install pykokkos-base
run: |
cd /tmp
Expand All @@ -39,4 +39,5 @@ jobs:
mypy pykokkos
- name: run tests
run: |
export OMP_NUM_THREADS=4
python runtests.py
168 changes: 99 additions & 69 deletions benchmarks/dgemm_compare.py
Original file line number Diff line number Diff line change
@@ -1,81 +1,111 @@
"""
Compare DGEMM performance with SciPy
(i.e., a wheel with OpenBLAS 0.3.18)
Record DGEMM performance.
"""

import os
import shutil
import time
import argparse
import socket

import pykokkos as pk
from pykokkos.linalg.l3_blas import dgemm as pk_dgemm

import numpy as np
from numpy.testing import assert_allclose
from scipy.linalg.blas import dgemm as scipy_dgemm
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm


def setup_data(mode):
rng = np.random.default_rng(18898787)
a = rng.random((square_matrix_width, square_matrix_width)).astype(float)
b = rng.random((square_matrix_width, square_matrix_width)).astype(float)
if "pykokkos" in mode:
view_a = pk.View([square_matrix_width, square_matrix_width], dtype=pk.float64)
view_b = pk.View([square_matrix_width, square_matrix_width], dtype=pk.float64)
view_a[:] = a
view_b[:] = b
return view_a, view_b
else:
return a, b


def time_dgemm(expected, mode, league_size=4, tile_width=2):
start = time.perf_counter()
if mode == "pykokkos_no_tiling":
actual = pk_dgemm(alpha, a, b, beta=0.0, view_c=None)
elif mode == "pykokkos_with_tiling":
actual = pk_dgemm(alpha, a, b, beta=0.0, view_c=None, league_size=4, tile_width=2)
elif mode == "scipy":
actual = scipy_dgemm(alpha, a, b)
else:
raise ValueError(f"Unknown timing mode: {mode}")
# include check for correctness inside the
# timer code block to prevent i.e., async GPU
# execution; just be careful to select matrix sizes
# large enough that the assertion isn't slower than the
# DGEMM
assert_allclose(actual, expected)
end = time.perf_counter()
dgemm_time_sec = end - start
return dgemm_time_sec


if __name__ == "__main__":
import timeit
num_repeats = 50
results = {"PyKokkos": {},
"SciPy": {}}
alpha, a, b, c, beta = (3.6,
np.array([[8, 7, 1, 200, 55.3],
[99.2, 1.11, 2.02, 17.7, 900.2],
[5.01, 15.21, 22.07, 1.09, 22.22],
[1, 2, 3, 4, 5]], dtype=np.float64),
np.array([[9, 0, 2, 19],
[77, 100, 4, 19],
[1, 500, 9, 19],
[226.68, 11.61, 12.12, 19],
[17.7, 200.10, 301.17, 20]], dtype=np.float64),
np.ones((4, 4)) * 3.3,
4.3)
for system_size in ["small", "medium", "large"]:
print("-" * 20)
print(f"system size: {system_size}")

if system_size == "medium":
a_new = np.tile(a, (10, 0))
b_new = np.tile(b, (0, 10))
c_new = np.ones((40, 40)) * 3.3
elif system_size == "large":
a_new = np.tile(a, (40, 0))
b_new = np.tile(b, (0, 40))
c_new = np.ones((160, 160)) * 3.3
else:
a_new = a
b_new = b
c_new = c

view_a = pk.from_numpy(a_new)
view_b = pk.from_numpy(b_new)
view_c = pk.from_numpy(c_new)
pk_dgemm_time_sec = timeit.timeit("pk_dgemm(alpha, view_a, view_b, beta, view_c)",
globals=globals(),
number=num_repeats)
results["PyKokkos"][system_size] = pk_dgemm_time_sec
print(f"PyKokkos DGEMM execution time (s) for {num_repeats} repeats: {pk_dgemm_time_sec}")
scipy_dgemm_time_sec = timeit.timeit("scipy_dgemm(alpha, a_new, b_new, beta, c_new)",
globals=globals(),
number=num_repeats)
results["SciPy"][system_size] = scipy_dgemm_time_sec
print(f"SciPy DGEMM execution time (s) for {num_repeats} repeats: {scipy_dgemm_time_sec}")
ratio = pk_dgemm_time_sec / scipy_dgemm_time_sec
if ratio == 1:
print("PyKokkos DGEMM timing is identical to SciPy")
elif ratio > 1:
print(f"PyKokkos DGEMM timing is slower than SciPy with ratio: {ratio:.2f} fold")
else:
print(f"PyKokkos DGEMM timing is faster than SciPy with ratio: {ratio:.2f} fold")
print("-" * 20)
df = pd.DataFrame.from_dict(results)
print("df:\n", df)
fig, ax = plt.subplots()
df.plot.bar(ax=ax,
rot=0,
logy=True,
xlabel="Problem Size",
ylabel=f"log of time (s) for {num_repeats} repeats",
title="DGEMM Performance Comparison with timeit")
fig.savefig("DGEMM_perf_compare.png", dpi=300)
parser = argparse.ArgumentParser()
parser.add_argument('-n', '--num-global-repeats', default=5)
parser.add_argument('-m', '--mode', default="scipy")
parser.add_argument('-p', '--power-of-two', default=10)
parser.add_argument('-w', '--tile-width', default=2)
parser.add_argument('-l', '--league-size', default=4)
parser.add_argument('-s', '--space', default="OpenMP")
args = parser.parse_args()
hostname = socket.gethostname()

if args.space == "OpenMP":
space = pk.ExecutionSpace.OpenMP
elif args.space == "Cuda":
space = pk.ExecutionSpace.Cuda
else:
raise ValueError(f"Invalid execution space specified: {args.space}")
pk.set_default_space(space)


num_global_repeats = int(args.num_global_repeats)
square_matrix_width = 2 ** int(args.power_of_two)


num_threads = os.environ.get("OMP_NUM_THREADS")
if num_threads is None:
raise ValueError("must set OMP_NUM_THREADS for benchmarks!")

space_name = str(space).split(".")[1]
scenario_name = f"{hostname}_dgemm_{args.mode}_{num_threads}_OMP_threads_{space_name}_execution_space_{square_matrix_width}_square_matrix_width_{args.league_size}_league_size"

cwd = os.getcwd()
shutil.rmtree(os.path.join(cwd, "pk_cpp"),
ignore_errors=True)

df = pd.DataFrame(np.full(shape=(num_global_repeats, 2), fill_value=np.nan),
columns=["scenario", "time (s)"])
df["scenario"] = df["scenario"].astype(str)
print("df before trials:\n", df)

alpha = 1.0
a, b = setup_data(mode=args.mode)
expected = scipy_dgemm(alpha, a, b)
counter = 0
for global_repeat in tqdm(range(1, num_global_repeats + 1)):
dgemm_time_sec = time_dgemm(expected, mode=args.mode, league_size=args.league_size, tile_width=args.tile_width)
df.iloc[counter, 0] = f"{scenario_name}"
df.iloc[counter, 1] = dgemm_time_sec
counter += 1

print("df after trials:\n", df)

filename = f"{scenario_name}.parquet.gzip"
df.to_parquet(filename,
engine="pyarrow",
compression="gzip")
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,6 @@ ignore_errors = True

[mypy-pykokkos.lib.ufunc_workunits]
ignore_errors = True

[mypy-pykokkos.linalg.workunits]
ignore_errors = True
55 changes: 47 additions & 8 deletions pykokkos/linalg/l3_blas.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from typing import Optional

import pykokkos as pk
from pykokkos.linalg import workunits

# Level 3 BLAS functions

def dgemm(alpha: float,
view_a,
view_b,
beta: float = 0.0,
view_c = None):
view_c = None,
# TODO: league_size support is pretty limited/confusing
# at the moment...
league_size: int = 4,
tile_width: Optional[int] = None):
"""
Double precision floating point genernal matrix multiplication (GEMM).

Expand All @@ -20,6 +27,8 @@ def dgemm(alpha: float,
Shape (k, n)
beta: float, optional
view_c: pykokkos view of type double, optional
tile_width: int, optional
Number of elements along a dimension of the square tiles.

Returns
-------
Expand All @@ -45,12 +54,42 @@ def dgemm(alpha: float,

C = pk.View([view_a.shape[0], view_b.shape[1]], dtype=pk.double)

for m in range(view_a.shape[0]):
for n in range(view_b.shape[1]):
for k in range(k_a):
subresult = view_a[m, k] * view_b[k, n] * alpha
C[m, n] += float(subresult) # type: ignore
if view_c is not None:
C[m, n] += (view_c[m, n] * beta) # type: ignore
if not tile_width:
if view_c is None:
pk.parallel_for(view_a.shape[0],
workunits.dgemm_impl_no_view_c,
k_a=k_a,
alpha=alpha,
view_a=view_a,
view_b=view_b,
out=C)
else:
pk.parallel_for(view_a.shape[0],
workunits.dgemm_impl_view_c,
k_a=k_a,
alpha=alpha,
beta=beta,
view_a=view_a,
view_b=view_b,
view_c=view_c,
out=C)
else:
# limited tiling support--only (some) convenient powers of two
# allowed for now...
# limited league size support for now as well...
if league_size == 1:
slide_factor = 0
else:
slide_factor = int(league_size / 4)

pk.parallel_for("tiled_matmul",
pk.TeamPolicy(league_size=league_size,
team_size=tile_width ** 2),
workunits.dgemm_impl_tiled_no_view_c,
k_a=k_a,
alpha=alpha,
view_a=view_a,
view_b=view_b,
out=C,
slide_factor=slide_factor)
return C
91 changes: 91 additions & 0 deletions pykokkos/linalg/workunits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import pykokkos as pk


@pk.workunit
def dgemm_impl_view_c(tid: int,
k_a: int,
alpha: float,
beta: float,
view_a: pk.View2D[pk.double],
view_b: pk.View2D[pk.double],
view_c: pk.View2D[pk.double],
out: pk.View2D[pk.double]):
for n in range(view_b.extent(1)):
for k in range(k_a):
out[tid][n] += float(view_a[tid][k] * view_b[k][n] * alpha)
out[tid][n] += (view_c[tid][n] * beta)


@pk.workunit
def dgemm_impl_no_view_c(tid: int,
k_a: int,
alpha: float,
view_a: pk.View2D[pk.double],
view_b: pk.View2D[pk.double],
out: pk.View2D[pk.double]):
for n in range(view_b.extent(1)):
for k in range(k_a):
out[tid][n] += float(view_a[tid][k] * view_b[k][n] * alpha)


@pk.workunit
def dgemm_impl_tiled_no_view_c(team_member: pk.TeamMember,
k_a: int,
alpha: float,
view_a: pk.View2D[pk.double],
view_b: pk.View2D[pk.double],
out: pk.View2D[pk.double],
slide_factor: int):
# early attempt at tiled matrix multiplication in PyKokkos

# for now, let's assume a 2x2 tiling arrangement and
# that `view_a`, `view_b`, and `out` views are all 4 x 4 matrices
width: int = out.extent(1)

# start off by getting a global thread id
global_tid: int = team_member.league_rank() * team_member.team_size() + team_member.team_rank()

# TODO: I have no idea how to get 2D scratch memory views?
scratch_mem_a: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), team_member.team_size())
scratch_mem_b: pk.ScratchView1D[float] = pk.ScratchView1D(team_member.team_scratch(0), team_member.team_size())
# in a 4 x 4 matrix with 2 x 2 tiling the leagues
# and teams have matching row/col assignment approaches
bx: int = team_member.league_rank() / 2
by: int = 0
if team_member.league_rank() % 2 != 0:
by = 1
tx: int = team_member.team_rank() / 2
ty: int = 0
if team_member.team_rank() % 2 != 0:
ty = 1
tmp: float = 0
col: int = by * 2 + ty
row: int = bx * 2 + tx

# these variables are a bit silly--can we not get
# 2D scratch memory indexing?
a_index: int = 0
b_index: int = 0

# TODO: league size support is limited for now, probably
# only some convenient factors of the total matrix size
slide_size: int = 0
if slide_factor == 0:
slide_size = 2
else:
slide_size = 4 * slide_factor
for row_factor in range(0, width, slide_size):
for col_factor in range(0, width, slide_size):
tmp = 0
for i in range(width / 2):
scratch_mem_a[team_member.team_rank()] = view_a[row + row_factor][i * 2 + ty]
scratch_mem_b[team_member.team_rank()] = view_b[i * 2 + tx][col + col_factor]
team_member.team_barrier()

for k in range(2):
a_index = k + ((team_member.team_rank() // 2) * 2)
b_index = ty + (k * 2)
tmp += scratch_mem_a[a_index] * scratch_mem_b[b_index]
team_member.team_barrier()

out[row + row_factor][col + col_factor] = tmp
Loading