[mlir][sparse][pytaco] test cleanup

removed obsoleted TODO
removed strange Fp precision for coordinates
lined up meta data testing code for readability

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D119377
This commit is contained in:
Aart Bik 2022-02-09 14:23:22 -08:00
parent 3915789503
commit 6195a25487
3 changed files with 12 additions and 16 deletions

View File

@ -1,12 +1,12 @@
# See http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format # See http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format
2 9 2 9
3 3 3 3
1.0 1.0 100.0 1 1 100
1.0 2.0 107.0 1 2 107
1.0 3.0 114.0 1 3 114
2.0 1.0 201.0 2 1 201
2.0 2.0 216.0 2 2 216
2.0 3.0 231.0 2 3 231
3.0 1.0 318.0 3 1 318
3.0 2.0 342.0 3 2 342
3.0 3.0 366.0 3 3 366

View File

@ -18,17 +18,13 @@ csr = pt.format([pt.dense, pt.compressed], [0, 1])
# Read matrices A and B from file, infer size of output matrix C. # Read matrices A and B from file, infer size of output matrix C.
A = pt.read(os.path.join(_SCRIPT_PATH, "data/A.mtx"), csr) A = pt.read(os.path.join(_SCRIPT_PATH, "data/A.mtx"), csr)
B = pt.read(os.path.join(_SCRIPT_PATH, "data/B.mtx"), csr) B = pt.read(os.path.join(_SCRIPT_PATH, "data/B.mtx"), csr)
C = pt.tensor((A.shape[0], B.shape[1]), csr) C = pt.tensor([A.shape[0], B.shape[1]], csr)
# Define the kernel. # Define the kernel.
i, j, k = pt.get_index_vars(3) i, j, k = pt.get_index_vars(3)
C[i, j] = A[i, k] * B[k, j] C[i, j] = A[i, k] * B[k, j]
# Force evaluation of the kernel by writing out C. # Force evaluation of the kernel by writing out C.
#
# TODO: use sparse_tensor.out for output, so that C.tns becomes
# a file in extended FROSTT format
#
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns") golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns")
out_file = os.path.join(test_dir, "C.tns") out_file = os.path.join(test_dir, "C.tns")

View File

@ -23,8 +23,8 @@ def compare_sparse_tns(expected: str, actual: str, rtol: float = 0.0001) -> bool
_ = expected_f.readline() _ = expected_f.readline()
# Compare the two lines of meta data # Compare the two lines of meta data
if actual_f.readline() != expected_f.readline() or actual_f.readline( if (actual_f.readline() != expected_f.readline() or
) != expected_f.readline(): actual_f.readline() != expected_f.readline()):
return FALSE return FALSE
actual_data = np.loadtxt(actual, np.float64, skiprows=3) actual_data = np.loadtxt(actual, np.float64, skiprows=3)