[mlir][sparse][pytaco] test cleanup
removed obsoleted TODO removed strange Fp precision for coordinates lined up meta data testing code for readability Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D119377
This commit is contained in:
parent
3915789503
commit
6195a25487
|
@ -1,12 +1,12 @@
|
|||
# See http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format
|
||||
2 9
|
||||
3 3
|
||||
1.0 1.0 100.0
|
||||
1.0 2.0 107.0
|
||||
1.0 3.0 114.0
|
||||
2.0 1.0 201.0
|
||||
2.0 2.0 216.0
|
||||
2.0 3.0 231.0
|
||||
3.0 1.0 318.0
|
||||
3.0 2.0 342.0
|
||||
3.0 3.0 366.0
|
||||
1 1 100
|
||||
1 2 107
|
||||
1 3 114
|
||||
2 1 201
|
||||
2 2 216
|
||||
2 3 231
|
||||
3 1 318
|
||||
3 2 342
|
||||
3 3 366
|
||||
|
|
|
@ -18,17 +18,13 @@ csr = pt.format([pt.dense, pt.compressed], [0, 1])
|
|||
# Read matrices A and B from file, infer size of output matrix C.
|
||||
A = pt.read(os.path.join(_SCRIPT_PATH, "data/A.mtx"), csr)
|
||||
B = pt.read(os.path.join(_SCRIPT_PATH, "data/B.mtx"), csr)
|
||||
C = pt.tensor((A.shape[0], B.shape[1]), csr)
|
||||
C = pt.tensor([A.shape[0], B.shape[1]], csr)
|
||||
|
||||
# Define the kernel.
|
||||
i, j, k = pt.get_index_vars(3)
|
||||
C[i, j] = A[i, k] * B[k, j]
|
||||
|
||||
# Force evaluation of the kernel by writing out C.
|
||||
#
|
||||
# TODO: use sparse_tensor.out for output, so that C.tns becomes
|
||||
# a file in extended FROSTT format
|
||||
#
|
||||
with tempfile.TemporaryDirectory() as test_dir:
|
||||
golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns")
|
||||
out_file = os.path.join(test_dir, "C.tns")
|
||||
|
|
|
@ -23,8 +23,8 @@ def compare_sparse_tns(expected: str, actual: str, rtol: float = 0.0001) -> bool
|
|||
_ = expected_f.readline()
|
||||
|
||||
# Compare the two lines of meta data
|
||||
if actual_f.readline() != expected_f.readline() or actual_f.readline(
|
||||
) != expected_f.readline():
|
||||
if (actual_f.readline() != expected_f.readline() or
|
||||
actual_f.readline() != expected_f.readline()):
|
||||
return FALSE
|
||||
|
||||
actual_data = np.loadtxt(actual, np.float64, skiprows=3)
|
||||
|
|
Loading…
Reference in New Issue