import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import nnls
from scipy.sparse import csr_matrix


def read_mrd(file_path, header_size=512, data_shape=(384, 448), data_type=np.complex64):
    """
    读取 .mrd 文件并解析为复数浮点型数组。

    :param file_path: .mrd 文件路径
    :param header_size: 文件头的字节大小
    :param data_shape: 数据的维度（例如 K 空间或图像空间的形状）
    :param data_type: 数据类型（默认为 np.complex64,即单精度复数）
    :return: 解析后的复数数组
    """
    # 将二进制数据解析为实数类型（float32 或 float64）
    if data_type == np.complex64:
        float_type = np.float32  # 单精度浮点数
    elif data_type == np.complex128:
        float_type = np.float64  # 双精度浮点数
    else:
        raise ValueError("Unsupported complex data type. Use np.complex64 or np.complex128.")
    # 确定单个复数占用的字节数
    bytes_per_sample = np.dtype(data_type).itemsize  # complex64 -> 8 bytes (4 实部 + 4 虚部)
    float_type_size = np.dtype(float_type).itemsize  # 每个浮点数的字节大小

    # 计算数据部分的字节数
    total_data_points = np.prod(data_shape) * 2  # 每个复数由两个浮点数组成
    data_bytes = total_data_points * float_type_size  # 数据部分的总字节数

    with open(file_path, 'rb') as f:
        # 跳过头部
        f.seek(header_size)

        # 读取数据部分
        raw_data = f.read(data_bytes)
        if len(raw_data) != data_bytes:
            raise ValueError(f"Error: Expected {data_bytes} bytes of data, but got {len(raw_data)} bytes.")

        # 解析为一维数组（实部和虚部交替存储）
        real_imag_array = np.frombuffer(raw_data, dtype=float_type)

        # 将实部和虚部合成为复数
        complex_data = real_imag_array[::2] + 1j * real_imag_array[1::2]

        # 重塑为指定维度
        complex_data = complex_data.reshape(data_shape)
    return complex_data


def omp(A, y, sparsity_level):
    """
    正交匹配追踪算法实现
    :param A: 测量矩阵（采样矩阵与稀疏矩阵的乘积）
    :param y: 采样得到的测量信号
    :param sparsity_level: 信号的稀疏度
    :return: 恢复的稀疏信号
    """
    m, n = A.shape
    r = y.copy()
    omega = []
    x = np.zeros(n)
    for _ in range(sparsity_level):
        correlations = np.abs(A.T @ r)
        lambda_new = np.argmax(correlations)
        omega.append(lambda_new)
        A_omega = A[:, omega]
        A_omega_dense = A_omega.toarray()
        x_omega, _ = nnls(A_omega_dense, np.real(y))
        x[omega] = x_omega
        r = y - A @ x
    return x


# 示例调用
file_path = r"D:\Users\Desktop\egg\Scan2.mrd"
data_shape = (384, 448)  # 替换为实际的 K 空间形状
header_size = 512  # 假设头部大小为 512 字节
data_type = np.complex64  # 假设是单精度复数（complex64）

# 读取复数数据
mrd_data = read_mrd(file_path, header_size=header_size, data_shape=data_shape, data_type=data_type)
np.save('mrd_data1.npy', mrd_data)

# 计算复数数据的幅度（模）
magnitude = np.abs(mrd_data)

# 可视化原始 K 空间图像
plt.imshow(magnitude, cmap='gray', aspect='auto')
plt.title("Original K-Space Image")
plt.colorbar(label="Normalized Intensity")
plt.savefig("pic/originalK.png", dpi=300)  # 保存为 PNG 文件
plt.show()

# 预定义采样矩阵（手动生成稀疏随机矩阵）
sampling_rate = 0.2  # 采样率
n = data_shape[0] * data_shape[1]
m = int(sampling_rate * n)
density = 0.1  # 稀疏矩阵的非零元素密度
nnz = int(m * n * density)  # 非零元素数量

# 手动生成随机索引
rows = np.random.randint(0, m, nnz)
cols = np.random.randint(0, n, nnz)
data = np.random.randn(nnz)

# 创建稀疏矩阵
Phi = csr_matrix((data, (rows, cols)), shape=(m, n))

# 逐列计算稀疏矩阵的范数
col_norms = np.sqrt(np.asarray(Phi.power(2).sum(axis=0))).flatten()
# 避免除零错误
col_norms[col_norms == 0] = 1
# 逐列归一化稀疏矩阵
Phi = csr_matrix((Phi.data / col_norms[Phi.indices], Phi.indices, Phi.indptr), shape=Phi.shape)

# 完整的 k 空间数据向量化
kspace_full = mrd_data.flatten()

# 定义傅里叶变换操作函数
def fft_operation(x):
    x_reshaped = x.reshape(data_shape)
    return np.fft.fft2(x_reshaped) / np.sqrt(n)

# 压缩感知采样
def phi_psi_product(x):
    fft_x = fft_operation(x)
    fft_x_flat = fft_x.flatten()
    return Phi @ fft_x_flat

y = phi_psi_product(kspace_full)

# 稀疏度
sparsity_level = 50

# 定义 A 矩阵的作用（避免显式构建）
def A_operator(x):
    return phi_psi_product(x)

# 通过迭代算法（OMP）恢复稀疏信号
def modified_omp(A_operator, y, sparsity_level, n):
    r = y.copy()
    omega = []
    x = np.zeros(n)
    for _ in range(sparsity_level):
        def correlation_operator(i):
            basis = np.zeros(n)
            basis[i] = 1
            A_basis = A_operator(basis)
            return np.abs(np.dot(A_basis.conj(), r))
        correlations = np.array([correlation_operator(i) for i in range(n)])
        lambda_new = np.argmax(correlations)
        omega.append(lambda_new)
        A_omega = np.array([A_operator(np.eye(n)[i]) for i in omega]).T
        x_omega, _ = nnls(A_omega, np.real(y))
        x[omega] = x_omega
        Ax = np.array([A_operator(x[i] * np.eye(n)[i]) for i in range(n)]).sum(axis=0)
        r = y - Ax
    return x

recovered_sparse_signal = modified_omp(A_operator, y, sparsity_level, n)

# 从恢复的稀疏信号得到 k 空间信息
recovered_k_space_vector = recovered_sparse_signal
recovered_k_space = recovered_k_space_vector.reshape(data_shape)

# 从 k 空间重建图像
reconstructed_image = np.abs(np.fft.ifft2(recovered_k_space))

# 可视化重建图像
plt.imshow(reconstructed_image, cmap='gray')
plt.colorbar()
plt.title("Reconstructed Image (Compressed Sensing)")
plt.savefig("pic/reconstructed_cs.png", dpi=300)  # 保存为 PNG 文件
plt.show()

