[mlir][sparse][pytaco] test cleanup
removed obsoleted TODO removed strange Fp precision for coordinates lined up meta data testing code for readability Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D119377
This commit is contained in:
parent
3915789503
commit
6195a25487
|
@ -1,12 +1,12 @@
|
||||||
# See http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format
|
# See http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format
|
||||||
2 9
|
2 9
|
||||||
3 3
|
3 3
|
||||||
1.0 1.0 100.0
|
1 1 100
|
||||||
1.0 2.0 107.0
|
1 2 107
|
||||||
1.0 3.0 114.0
|
1 3 114
|
||||||
2.0 1.0 201.0
|
2 1 201
|
||||||
2.0 2.0 216.0
|
2 2 216
|
||||||
2.0 3.0 231.0
|
2 3 231
|
||||||
3.0 1.0 318.0
|
3 1 318
|
||||||
3.0 2.0 342.0
|
3 2 342
|
||||||
3.0 3.0 366.0
|
3 3 366
|
||||||
|
|
|
@ -18,17 +18,13 @@ csr = pt.format([pt.dense, pt.compressed], [0, 1])
|
||||||
# Read matrices A and B from file, infer size of output matrix C.
|
# Read matrices A and B from file, infer size of output matrix C.
|
||||||
A = pt.read(os.path.join(_SCRIPT_PATH, "data/A.mtx"), csr)
|
A = pt.read(os.path.join(_SCRIPT_PATH, "data/A.mtx"), csr)
|
||||||
B = pt.read(os.path.join(_SCRIPT_PATH, "data/B.mtx"), csr)
|
B = pt.read(os.path.join(_SCRIPT_PATH, "data/B.mtx"), csr)
|
||||||
C = pt.tensor((A.shape[0], B.shape[1]), csr)
|
C = pt.tensor([A.shape[0], B.shape[1]], csr)
|
||||||
|
|
||||||
# Define the kernel.
|
# Define the kernel.
|
||||||
i, j, k = pt.get_index_vars(3)
|
i, j, k = pt.get_index_vars(3)
|
||||||
C[i, j] = A[i, k] * B[k, j]
|
C[i, j] = A[i, k] * B[k, j]
|
||||||
|
|
||||||
# Force evaluation of the kernel by writing out C.
|
# Force evaluation of the kernel by writing out C.
|
||||||
#
|
|
||||||
# TODO: use sparse_tensor.out for output, so that C.tns becomes
|
|
||||||
# a file in extended FROSTT format
|
|
||||||
#
|
|
||||||
with tempfile.TemporaryDirectory() as test_dir:
|
with tempfile.TemporaryDirectory() as test_dir:
|
||||||
golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns")
|
golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns")
|
||||||
out_file = os.path.join(test_dir, "C.tns")
|
out_file = os.path.join(test_dir, "C.tns")
|
||||||
|
|
|
@ -23,8 +23,8 @@ def compare_sparse_tns(expected: str, actual: str, rtol: float = 0.0001) -> bool
|
||||||
_ = expected_f.readline()
|
_ = expected_f.readline()
|
||||||
|
|
||||||
# Compare the two lines of meta data
|
# Compare the two lines of meta data
|
||||||
if actual_f.readline() != expected_f.readline() or actual_f.readline(
|
if (actual_f.readline() != expected_f.readline() or
|
||||||
) != expected_f.readline():
|
actual_f.readline() != expected_f.readline()):
|
||||||
return FALSE
|
return FALSE
|
||||||
|
|
||||||
actual_data = np.loadtxt(actual, np.float64, skiprows=3)
|
actual_data = np.loadtxt(actual, np.float64, skiprows=3)
|
||||||
|
|
Loading…
Reference in New Issue