Skip to content

Commit

Permalink
disable bf16 and half
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq committed Oct 3, 2024
1 parent 94b5701 commit e3aa2bc
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions example/microbm/microbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
random.seed(42)

instructions = ["fneg", "fadd", "fsub", "fmul", "fdiv", "fcmp", "fptrunc", "fpext"]
functions = ["sin", "cos", "tan", "exp", "log", "sqrt", "expm1", "log1p", "cbrt", "pow", "fabs", "hypot", "fmuladd"]
functions = ["sin", "cos", "tan", "exp", "log", "sqrt", "expm1", "log1p", "cbrt", "pow", "fabs", "hypot", "fma"]

precisions = ["bf16", "half", "float", "double", "fp80", "fp128"]
precisions = ["float", "double", "fp80", "fp128"]
# precisions = ["bf16", "half", "float", "double", "fp80", "fp128"]
iterations = 1000000000
unrolled = 32
unrolled = 8

precision_to_llvm_type = {
"double": "double",
Expand Down Expand Up @@ -517,13 +518,13 @@ def compile_and_run(ll_filename, executable):
print(f"An error occurred during compilation or execution: {e}")
raise

finally:
if os.path.exists(asm_filename):
os.remove(asm_filename)
print(f"Cleaned up assembly file: {asm_filename}")
# finally:
# if os.path.exists(asm_filename):
# os.remove(asm_filename)
# print(f"Cleaned up assembly file: {asm_filename}")


csv_file = "cm.csv"
csv_file = "results.csv"

with open(csv_file, "w", newline="") as csvfile:
fieldnames = ["instruction", "precision", "cost"]
Expand Down Expand Up @@ -554,9 +555,13 @@ def compile_and_run(ll_filename, executable):
if src_rank is None:
continue
if instr == "fptrunc":
dst_precisions = [p for p in precisions_ordered if precision_ranks[p] < src_rank]
dst_precisions = [
p for p in precisions_ordered if p in precisions and precision_ranks[p] < src_rank
]
else:
dst_precisions = [p for p in precisions_ordered if precision_ranks[p] > src_rank]
dst_precisions = [
p for p in precisions_ordered if p in precisions and precision_ranks[p] > src_rank
]
for dst_precision in dst_precisions:
if (src_precision == "half" and dst_precision == "bf16") or (
src_precision == "bf16" and dst_precision == "half"
Expand Down

0 comments on commit e3aa2bc

Please sign in to comment.