mirror of https://github.com/Qiskit/qiskit.git
Add dedicated functions for memory marginalization (#8051)
* Add dedicated functions for memory marginalization This commit adds dedicated functions for memory marginalization. Previously, the marginal_counts() function had support for marginalizing memory in a Results object, but this can be inefficient especially if your memory list is outside a Results object. The new functions added in this commit are implemented in Rust and multithreaded. Additionally the marginal_counts() function is updated to use the same inner Rust functions. * Fix rustfmt * Add missing test file * Fix typos Co-authored-by: Daniel J. Egger <38065505+eggerdj@users.noreply.github.com> * Remove unused import * Increate default parallel_threshold to 1000 * Add support for different measurement levels * Update docstring * Add release note * Expand unit tests * Fix rustfmt * Apply suggestions from code review Co-authored-by: Kevin Hartman <kevin@hart.mn> * Use a lookup table instead of a match statement Co-authored-by: Daniel J. Egger <38065505+eggerdj@users.noreply.github.com> Co-authored-by: Kevin Hartman <kevin@hart.mn> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
parent
5a6ec94699
commit
8ee4ac80ec
|
@ -182,6 +182,17 @@ dependencies = [
|
|||
"rayon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-bigint"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.4.2"
|
||||
|
@ -288,6 +299,7 @@ dependencies = [
|
|||
"hashbrown",
|
||||
"indoc",
|
||||
"libc",
|
||||
"num-bigint",
|
||||
"num-complex",
|
||||
"parking_lot",
|
||||
"pyo3-build-config",
|
||||
|
@ -346,7 +358,9 @@ dependencies = [
|
|||
"ahash",
|
||||
"hashbrown",
|
||||
"indexmap",
|
||||
"lazy_static",
|
||||
"ndarray",
|
||||
"num-bigint",
|
||||
"num-complex",
|
||||
"numpy",
|
||||
"pyo3",
|
||||
|
|
|
@ -18,10 +18,12 @@ rand_distr = "0.4.3"
|
|||
indexmap = "1.9.1"
|
||||
ahash = "0.7.6"
|
||||
num-complex = "0.4"
|
||||
num-bigint = "0.4"
|
||||
lazy_static = "1.4.0"
|
||||
|
||||
[dependencies.pyo3]
|
||||
version = "0.16.5"
|
||||
features = ["extension-module", "hashbrown", "num-complex"]
|
||||
features = ["extension-module", "hashbrown", "num-complex", "num-bigint"]
|
||||
|
||||
[dependencies.ndarray]
|
||||
version = "^0.15.0"
|
||||
|
|
|
@ -25,6 +25,7 @@ Experiment Results (:mod:`qiskit.result`)
|
|||
Counts
|
||||
marginal_counts
|
||||
marginal_distribution
|
||||
marginal_memory
|
||||
|
||||
Distributions
|
||||
=============
|
||||
|
@ -50,6 +51,7 @@ from .result import Result
|
|||
from .exceptions import ResultError
|
||||
from .utils import marginal_counts
|
||||
from .utils import marginal_distribution
|
||||
from .utils import marginal_memory
|
||||
from .counts import Counts
|
||||
|
||||
from .distributions.probability import ProbDistribution
|
||||
|
|
|
@ -18,13 +18,15 @@ from typing import List, Union, Optional, Dict
|
|||
from collections import Counter
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qiskit.exceptions import QiskitError
|
||||
from qiskit.result.result import Result
|
||||
from qiskit.result.counts import Counts
|
||||
from qiskit.result.distributions.probability import ProbDistribution
|
||||
from qiskit.result.distributions.quasi import QuasiDistribution
|
||||
|
||||
from qiskit.result.postprocess import _bin_to_hex, _hex_to_bin
|
||||
from qiskit.result.postprocess import _bin_to_hex
|
||||
|
||||
# pylint: disable=import-error, no-name-in-module
|
||||
from qiskit._accelerate import results as results_rs
|
||||
|
@ -88,12 +90,9 @@ def marginal_counts(
|
|||
sorted_indices = sorted(
|
||||
indices, reverse=True
|
||||
) # same convention as for the counts
|
||||
bit_strings = [_hex_to_bin(s) for s in experiment_result.data.memory]
|
||||
marginal_bit_strings = [
|
||||
"".join([s[-idx - 1] for idx in sorted_indices if idx < len(s)]) or "0"
|
||||
for s in bit_strings
|
||||
]
|
||||
experiment_result.data.memory = [_bin_to_hex(s) for s in marginal_bit_strings]
|
||||
experiment_result.data.memory = results_rs.marginal_memory(
|
||||
experiment_result.data.memory, sorted_indices, return_hex=True
|
||||
)
|
||||
return result
|
||||
else:
|
||||
marg_counts = _marginalize(result, indices)
|
||||
|
@ -128,14 +127,85 @@ def _adjust_creg_sizes(creg_sizes, indices):
|
|||
return new_creg_sizes
|
||||
|
||||
|
||||
def marginal_memory(
|
||||
memory: Union[List[str], np.ndarray],
|
||||
indices: Optional[List[int]] = None,
|
||||
int_return: bool = False,
|
||||
hex_return: bool = False,
|
||||
avg_data: bool = False,
|
||||
parallel_threshold: int = 1000,
|
||||
) -> Union[List[str], np.ndarray]:
|
||||
"""Marginalize shot memory
|
||||
|
||||
This function is multithreaded and will launch a thread pool with threads equal to the number
|
||||
of CPUs by default. You can tune the number of threads with the ``RAYON_NUM_THREADS``
|
||||
environment variable. For example, setting ``RAYON_NUM_THREADS=4`` would limit the thread pool
|
||||
to 4 threads.
|
||||
|
||||
Args:
|
||||
memory: The input memory list, this is either a list of hexadecimal strings to be marginalized
|
||||
representing measure level 2 memory or a numpy array representing level 0 measurement
|
||||
memory (single or avg) or level 1 measurement memory (single or avg).
|
||||
indices: The bit positions of interest to marginalize over. If
|
||||
``None`` (default), do not marginalize at all.
|
||||
int_return: If set to ``True`` the output will be a list of integers.
|
||||
By default the return type is a bit string. This and ``hex_return``
|
||||
are mutually exclusive and can not be specified at the same time. This option only has an
|
||||
effect with memory level 2.
|
||||
hex_return: If set to ``True`` the output will be a list of hexadecimal
|
||||
strings. By default the return type is a bit string. This and
|
||||
``int_return`` are mutually exclusive and can not be specified
|
||||
at the same time. This option only has an effect with memory level 2.
|
||||
avg_data: If a 2 dimensional numpy array is passed in for ``memory`` this can be set to
|
||||
``True`` to indicate it's a avg level 0 data instead of level 1
|
||||
single data.
|
||||
parallel_threshold: The number of elements in ``memory`` to start running in multiple
|
||||
threads. If ``len(memory)`` is >= this value, the function will run in multiple
|
||||
threads. By default this is set to 1000.
|
||||
|
||||
Returns:
|
||||
marginal_memory: The list of marginalized memory
|
||||
|
||||
Raises:
|
||||
ValueError: if both ``int_return`` and ``hex_return`` are set to ``True``
|
||||
"""
|
||||
if int_return and hex_return:
|
||||
raise ValueError("Either int_return or hex_return can be specified but not both")
|
||||
|
||||
if isinstance(memory, np.ndarray):
|
||||
if int_return:
|
||||
raise ValueError("int_return option only works with memory list input")
|
||||
if hex_return:
|
||||
raise ValueError("hex_return option only works with memory list input")
|
||||
if indices is None:
|
||||
return memory.copy()
|
||||
if memory.ndim == 1:
|
||||
return results_rs.marginal_measure_level_1_avg(memory, indices)
|
||||
if memory.ndim == 2:
|
||||
if avg_data:
|
||||
return results_rs.marginal_measure_level_0_avg(memory, indices)
|
||||
else:
|
||||
return results_rs.marginal_measure_level_1(memory, indices)
|
||||
if memory.ndim == 3:
|
||||
return results_rs.marginal_measure_level_0(memory, indices)
|
||||
raise ValueError("Invalid input memory array")
|
||||
return results_rs.marginal_memory(
|
||||
memory,
|
||||
indices,
|
||||
return_int=int_return,
|
||||
return_hex=hex_return,
|
||||
parallel_threshold=parallel_threshold,
|
||||
)
|
||||
|
||||
|
||||
def marginal_distribution(
|
||||
counts: dict, indices: Optional[List[int]] = None, format_marginal: bool = False
|
||||
) -> Dict[str, int]:
|
||||
"""Marginalize counts from an experiment over some indices of interest.
|
||||
|
||||
Unlike :func:`~.marginal_counts` this function respects the order of
|
||||
the input ``indices``. If the input ``indices`` list is specified, the order
|
||||
the bit indices will be the output order of the bitstrings
|
||||
the input ``indices``. If the input ``indices`` list is specified then the order
|
||||
the bit indices are specified will be the output order of the bitstrings
|
||||
in the marginalized output.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
---
|
||||
features:
|
||||
- |
|
||||
Added a new function :func:`~.marginal_memory` which is used to marginalize
|
||||
shot memory arrays. Provided with the shot memory array and the indices
|
||||
of interest the function will return a maginized shot memory array. This
|
||||
function differs from the memory support in the :func:`~.marginal_counts`
|
||||
method which only works on the ``memory`` field in a :class:`~.Results`
|
||||
object.
|
|
@ -10,6 +10,9 @@
|
|||
// copyright notice, and modified files need to carry a notice indicating
|
||||
// that they have been altered from the originals.
|
||||
|
||||
#[macro_use]
|
||||
extern crate lazy_static;
|
||||
|
||||
use std::env;
|
||||
|
||||
use pyo3::prelude::*;
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
// This code is part of Qiskit.
|
||||
//
|
||||
// (C) Copyright IBM 2022
|
||||
//
|
||||
// This code is licensed under the Apache License, Version 2.0. You may
|
||||
// obtain a copy of this license in the LICENSE.txt file in the root directory
|
||||
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
|
||||
//
|
||||
// Any modifications or derivative works of this code must retain this
|
||||
// copyright notice, and modified files need to carry a notice indicating
|
||||
// that they have been altered from the originals.
|
||||
|
||||
lazy_static! {
|
||||
static ref HEX_TO_BIN_LUT: [&'static str; 256] = {
|
||||
let mut lookup = [""; 256];
|
||||
lookup[b'0' as usize] = "0000";
|
||||
lookup[b'1' as usize] = "0001";
|
||||
lookup[b'2' as usize] = "0010";
|
||||
lookup[b'3' as usize] = "0011";
|
||||
lookup[b'4' as usize] = "0100";
|
||||
lookup[b'5' as usize] = "0101";
|
||||
lookup[b'6' as usize] = "0110";
|
||||
lookup[b'7' as usize] = "0111";
|
||||
lookup[b'8' as usize] = "1000";
|
||||
lookup[b'9' as usize] = "1001";
|
||||
lookup[b'A' as usize] = "1010";
|
||||
lookup[b'B' as usize] = "1011";
|
||||
lookup[b'C' as usize] = "1100";
|
||||
lookup[b'D' as usize] = "1101";
|
||||
lookup[b'E' as usize] = "1110";
|
||||
lookup[b'F' as usize] = "1111";
|
||||
lookup[b'a' as usize] = "1010";
|
||||
lookup[b'b' as usize] = "1011";
|
||||
lookup[b'c' as usize] = "1100";
|
||||
lookup[b'd' as usize] = "1101";
|
||||
lookup[b'e' as usize] = "1110";
|
||||
lookup[b'f' as usize] = "1111";
|
||||
lookup
|
||||
};
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn hex_to_bin(hex: &str) -> String {
|
||||
hex[2..]
|
||||
.chars()
|
||||
.map(|c| HEX_TO_BIN_LUT[c as usize])
|
||||
.collect()
|
||||
}
|
|
@ -10,8 +10,16 @@
|
|||
// copyright notice, and modified files need to carry a notice indicating
|
||||
// that they have been altered from the originals.
|
||||
|
||||
use super::converters::hex_to_bin;
|
||||
use crate::getenv_use_multiple_threads;
|
||||
use hashbrown::HashMap;
|
||||
use ndarray::prelude::*;
|
||||
use num_bigint::BigUint;
|
||||
use num_complex::Complex64;
|
||||
use numpy::IntoPyArray;
|
||||
use numpy::{PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3};
|
||||
use pyo3::prelude::*;
|
||||
use rayon::prelude::*;
|
||||
|
||||
fn marginalize<T: std::ops::AddAssign + Copy>(
|
||||
counts: HashMap<String, T>,
|
||||
|
@ -73,3 +81,140 @@ pub fn marginal_distribution(
|
|||
) -> HashMap<String, f64> {
|
||||
marginalize(counts, indices)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn map_memory(
|
||||
hexstring: &str,
|
||||
indices: &Option<Vec<usize>>,
|
||||
clbit_size: usize,
|
||||
return_hex: bool,
|
||||
) -> String {
|
||||
let out = match indices {
|
||||
Some(indices) => {
|
||||
let bitstring = hex_to_bin(hexstring);
|
||||
let bit_array = bitstring.as_bytes();
|
||||
indices
|
||||
.iter()
|
||||
.map(|bit| {
|
||||
let index = clbit_size - *bit - 1;
|
||||
match bit_array.get(index) {
|
||||
Some(bit) => *bit as char,
|
||||
None => '0',
|
||||
}
|
||||
})
|
||||
.rev()
|
||||
.collect()
|
||||
}
|
||||
None => hex_to_bin(hexstring),
|
||||
};
|
||||
if return_hex {
|
||||
format!("0x{:x}", BigUint::parse_bytes(out.as_bytes(), 2).unwrap())
|
||||
} else {
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction(
|
||||
return_int = "false",
|
||||
return_hex = "false",
|
||||
parallel_threshold = "1000"
|
||||
)]
|
||||
pub fn marginal_memory(
|
||||
py: Python,
|
||||
memory: Vec<String>,
|
||||
indices: Option<Vec<usize>>,
|
||||
return_int: bool,
|
||||
return_hex: bool,
|
||||
parallel_threshold: usize,
|
||||
) -> PyResult<PyObject> {
|
||||
let run_in_parallel = getenv_use_multiple_threads();
|
||||
let first_elem = memory.get(0);
|
||||
if first_elem.is_none() {
|
||||
let res: Vec<String> = Vec::new();
|
||||
return Ok(res.to_object(py));
|
||||
}
|
||||
|
||||
let clbit_size = hex_to_bin(first_elem.unwrap()).len();
|
||||
|
||||
let out_mem: Vec<String> = if memory.len() < parallel_threshold || !run_in_parallel {
|
||||
memory
|
||||
.iter()
|
||||
.map(|x| map_memory(x, &indices, clbit_size, return_hex))
|
||||
.collect()
|
||||
} else {
|
||||
memory
|
||||
.par_iter()
|
||||
.map(|x| map_memory(x, &indices, clbit_size, return_hex))
|
||||
.collect()
|
||||
};
|
||||
if return_int {
|
||||
if out_mem.len() < parallel_threshold || !run_in_parallel {
|
||||
Ok(out_mem
|
||||
.iter()
|
||||
.map(|x| BigUint::parse_bytes(x.as_bytes(), 2).unwrap())
|
||||
.collect::<Vec<BigUint>>()
|
||||
.to_object(py))
|
||||
} else {
|
||||
Ok(out_mem
|
||||
.par_iter()
|
||||
.map(|x| BigUint::parse_bytes(x.as_bytes(), 2).unwrap())
|
||||
.collect::<Vec<BigUint>>()
|
||||
.to_object(py))
|
||||
}
|
||||
} else {
|
||||
Ok(out_mem.to_object(py))
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
pub fn marginal_measure_level_0(
|
||||
py: Python,
|
||||
memory: PyReadonlyArray3<Complex64>,
|
||||
indices: Vec<usize>,
|
||||
) -> PyObject {
|
||||
let mem_arr: ArrayView3<Complex64> = memory.as_array();
|
||||
let input_shape = mem_arr.shape();
|
||||
let new_shape = [input_shape[0], indices.len(), input_shape[2]];
|
||||
let out_arr: Array3<Complex64> =
|
||||
Array3::from_shape_fn(new_shape, |(i, j, k)| mem_arr[[i, indices[j], k]]);
|
||||
out_arr.into_pyarray(py).into()
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
pub fn marginal_measure_level_0_avg(
|
||||
py: Python,
|
||||
memory: PyReadonlyArray2<Complex64>,
|
||||
indices: Vec<usize>,
|
||||
) -> PyObject {
|
||||
let mem_arr: ArrayView2<Complex64> = memory.as_array();
|
||||
let input_shape = mem_arr.shape();
|
||||
let new_shape = [indices.len(), input_shape[1]];
|
||||
let out_arr: Array2<Complex64> =
|
||||
Array2::from_shape_fn(new_shape, |(i, j)| mem_arr[[indices[i], j]]);
|
||||
out_arr.into_pyarray(py).into()
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
pub fn marginal_measure_level_1(
|
||||
py: Python,
|
||||
memory: PyReadonlyArray2<Complex64>,
|
||||
indices: Vec<usize>,
|
||||
) -> PyObject {
|
||||
let mem_arr: ArrayView2<Complex64> = memory.as_array();
|
||||
let input_shape = mem_arr.shape();
|
||||
let new_shape = [input_shape[0], indices.len()];
|
||||
let out_arr: Array2<Complex64> =
|
||||
Array2::from_shape_fn(new_shape, |(i, j)| mem_arr[[i, indices[j]]]);
|
||||
out_arr.into_pyarray(py).into()
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
pub fn marginal_measure_level_1_avg(
|
||||
py: Python,
|
||||
memory: PyReadonlyArray1<Complex64>,
|
||||
indices: Vec<usize>,
|
||||
) -> PyResult<PyObject> {
|
||||
let mem_arr: &[Complex64] = memory.as_slice()?;
|
||||
let out_arr: Vec<Complex64> = indices.into_iter().map(|idx| mem_arr[idx]).collect();
|
||||
Ok(out_arr.into_pyarray(py).into())
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
// copyright notice, and modified files need to carry a notice indicating
|
||||
// that they have been altered from the originals.
|
||||
|
||||
pub mod converters;
|
||||
pub mod marginalization;
|
||||
|
||||
use pyo3::prelude::*;
|
||||
|
@ -19,5 +20,14 @@ use pyo3::wrap_pyfunction;
|
|||
pub fn results(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
m.add_wrapped(wrap_pyfunction!(marginalization::marginal_counts))?;
|
||||
m.add_wrapped(wrap_pyfunction!(marginalization::marginal_distribution))?;
|
||||
m.add_wrapped(wrap_pyfunction!(marginalization::marginal_memory))?;
|
||||
m.add_wrapped(wrap_pyfunction!(marginalization::marginal_measure_level_0))?;
|
||||
m.add_wrapped(wrap_pyfunction!(
|
||||
marginalization::marginal_measure_level_0_avg
|
||||
))?;
|
||||
m.add_wrapped(wrap_pyfunction!(marginalization::marginal_measure_level_1))?;
|
||||
m.add_wrapped(wrap_pyfunction!(
|
||||
marginalization::marginal_measure_level_1_avg
|
||||
))?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
# This code is part of Qiskit.
|
||||
#
|
||||
# (C) Copyright IBM 2017, 2019.
|
||||
#
|
||||
# This code is licensed under the Apache License, Version 2.0. You may
|
||||
# obtain a copy of this license in the LICENSE.txt file in the root directory
|
||||
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
|
||||
#
|
||||
# Any modifications or derivative works of this code must retain this
|
||||
# copyright notice, and modified files need to carry a notice indicating
|
||||
# that they have been altered from the originals.
|
||||
|
||||
"""Test marginal_memory() function."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qiskit.test import QiskitTestCase
|
||||
from qiskit.result import marginal_memory
|
||||
|
||||
|
||||
class TestMarginalMemory(QiskitTestCase):
|
||||
"""Result operations methods."""
|
||||
|
||||
def test_marginalize_memory(self):
|
||||
"""Test that memory marginalizes correctly."""
|
||||
memory = [hex(ii) for ii in range(8)]
|
||||
res = marginal_memory(memory, indices=[0])
|
||||
self.assertEqual(res, [bin(ii % 2)[2:] for ii in range(8)])
|
||||
|
||||
def test_marginalize_memory_int(self):
|
||||
"""Test that memory marginalizes correctly int output."""
|
||||
memory = [hex(ii) for ii in range(8)]
|
||||
res = marginal_memory(memory, indices=[0], int_return=True)
|
||||
self.assertEqual(res, [ii % 2 for ii in range(8)])
|
||||
|
||||
def test_marginalize_memory_hex(self):
|
||||
"""Test that memory marginalizes correctly hex output."""
|
||||
memory = [hex(ii) for ii in range(8)]
|
||||
res = marginal_memory(memory, indices=[0], hex_return=True)
|
||||
self.assertEqual(res, [hex(ii % 2) for ii in range(8)])
|
||||
|
||||
def test_marginal_counts_result_memory_indices_None(self):
|
||||
"""Test that a memory marginalizes correctly with indices=None."""
|
||||
memory = [hex(ii) for ii in range(8)]
|
||||
res = marginal_memory(memory, hex_return=True)
|
||||
self.assertEqual(res, memory)
|
||||
|
||||
def test_marginalize_memory_in_parallel(self):
|
||||
"""Test that memory marginalizes correctly multithreaded."""
|
||||
memory = [hex(ii) for ii in range(15)]
|
||||
res = marginal_memory(memory, indices=[0], parallel_threshold=1)
|
||||
self.assertEqual(res, [bin(ii % 2)[2:] for ii in range(15)])
|
||||
|
||||
def test_error_on_multiple_return_types(self):
|
||||
"""Test that ValueError raised if multiple return types are requested."""
|
||||
with self.assertRaises(ValueError):
|
||||
marginal_memory([], int_return=True, hex_return=True)
|
||||
|
||||
def test_memory_level_0(self):
|
||||
"""Test that a single level 0 measurement data is correctly marginalized."""
|
||||
memory = np.asarray(
|
||||
[
|
||||
# qubit 0 qubit 1 qubit 2
|
||||
[
|
||||
[-12974255.0, -28106672.0],
|
||||
[15848939.0, -53271096.0],
|
||||
[-18731048.0, -56490604.0],
|
||||
], # shot 1
|
||||
[
|
||||
[-18346508.0, -26587824.0],
|
||||
[-12065728.0, -44948360.0],
|
||||
[14035275.0, -65373000.0],
|
||||
], # shot 2
|
||||
[
|
||||
[12802274.0, -20436864.0],
|
||||
[-15967512.0, -37575556.0],
|
||||
[15201290.0, -65182832.0],
|
||||
], # ...
|
||||
[[-9187660.0, -22197716.0], [-17028016.0, -49578552.0], [13526576.0, -61017756.0]],
|
||||
[[7006214.0, -32555228.0], [16144743.0, -33563124.0], [-23524160.0, -66919196.0]],
|
||||
],
|
||||
dtype=complex,
|
||||
)
|
||||
result = marginal_memory(memory, [0, 2])
|
||||
expected = np.asarray(
|
||||
[
|
||||
[[-12974255.0, -28106672.0], [-18731048.0, -56490604.0]], # shot 1
|
||||
[[-18346508.0, -26587824.0], [14035275.0, -65373000.0]], # shot 2
|
||||
[[12802274.0, -20436864.0], [15201290.0, -65182832.0]], # ...
|
||||
[[-9187660.0, -22197716.0], [13526576.0, -61017756.0]],
|
||||
[[7006214.0, -32555228.0], [-23524160.0, -66919196.0]],
|
||||
],
|
||||
dtype=complex,
|
||||
)
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
|
||||
def test_memory_level_0_avg(self):
|
||||
"""Test that avg level 0 measurement data is correctly marginalized."""
|
||||
memory = np.asarray(
|
||||
[[-1059254.375, -26266612.0], [-9012669.0, -41877468.0], [6027076.0, -54875060.0]],
|
||||
dtype=complex,
|
||||
)
|
||||
result = marginal_memory(memory, [0, 2], avg_data=True)
|
||||
expected = np.asarray(
|
||||
[[-1059254.375, -26266612.0], [6027076.0, -54875060.0]], dtype=complex
|
||||
)
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
|
||||
def test_memory_level_1(self):
|
||||
"""Test that a memory level 1 single data is correctly marginalized."""
|
||||
memory = np.array([[1.0j, 1.0, 0.5 + 0.5j], [0.5 + 0.5j, 1.0, 1.0j]], dtype=complex)
|
||||
result = marginal_memory(memory, [0, 2])
|
||||
expected = np.array([[1.0j, 0.5 + 0.5j], [0.5 + 0.5j, 1.0j]], dtype=complex)
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
|
||||
def test_memory_level_1_avg(self):
|
||||
"""Test that avg memory level 1 data is correctly marginalized."""
|
||||
memory = np.array([1.0j, 1.0, 0.5 + 0.5j], dtype=complex)
|
||||
result = marginal_memory(memory, [0, 1])
|
||||
expected = np.array([1.0j, 1.0], dtype=complex)
|
||||
np.testing.assert_array_equal(result, expected)
|
Loading…
Reference in New Issue