import numpy as np
import matplotlib.pyplot as plt
from scipy.fftpack import dct, idct
import os
import cv2
from scipy.fft import fft2, ifft2, fftshift, ifftshift


def fft2c(img):
    """Centered 2D FFT."""
    return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(img)))


def ifft2c(kspace):
    """Centered 2D IFFT."""
    return np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(kspace)))


def dct2(a):
    """二维离散余弦变换"""
    return dct(dct(a.T, norm='ortho').T, norm='ortho')


def idct2(a):
    """二维离散余弦逆变换"""
    return idct(idct(a.T, norm='ortho').T, norm='ortho')


def generate_sampling_mask(shape, sampling_rate):
    """生成随机采样掩码"""
    num_samples = int(sampling_rate * shape[0] * shape[1])
    mask = np.zeros(shape, dtype=np.float32)
    indices = np.random.choice(shape[0] * shape[1], num_samples, replace=False)
    mask.flat[indices] = 1
    return mask
def soft_threshold(x, lam):
    """软阈值操作"""
    return np.sign(x) * np.maximum(np.abs(x) - lam, 0)


def reconstruct_image(kspace_data):
    im = fftshift(ifft2(kspace_data, (256, 256)), axes=0)
    # 取模得到实值图像
    magnitude_image = np.abs(im)
    # 归一化图像
    normalized_image = (magnitude_image - magnitude_image.min()) / (magnitude_image.max() - magnitude_image.min())
    return normalized_image




def cg_mri_reconstruction(kspace_sampled, mask, num_iters=200):
    # 对 kspace_sampled 进行去噪处理（这里使用中值滤波示例）
    #kspace_sampled = np.abs(kspace_sampled)  # 取模，因为中值滤波一般处理实值数据
    #kspace_sampled = cv2.GaussianBlur(kspace_sampled.astype(np.float32), (3, 3), 0.1)  # 减小标准差为 0.1
    #kspace_sampled = kspace_sampled.astype(np.complex128)  # 转换回复数类型

    # 初始化重建图像
    img_reconstructed = reconstruct_image(kspace_sampled)

    # 显示初始化图像
    plt.imshow(np.abs(img_reconstructed), cmap='gray')
    plt.title('Initial Reconstructed MRI Image')
    plt.show()
    plt.imsave('initial_reconstructed_image.png', np.abs(img_reconstructed), cmap='gray')

    def A(x):
        # 正向算子：从图像到采样 k 空间
        kspace = fft2(x)
        return kspace * mask

    def At(y):
        # 伴随算子：从采样 k 空间到图像
        image = ifft2(y)
        return image

    # 初始残差
    r = kspace_sampled - A(img_reconstructed)
    p = At(r)

    epsilon = 1e-10  # 一个很小的正数，用于避免除零错误

    for i in range(num_iters):
        Ap = A(p)
        denominator_alpha = np.sum(np.conj(Ap) * Ap)
        print(f"Iteration {i + 1}, denominator_alpha: {denominator_alpha}")  # 添加打印语句
        if np.abs(denominator_alpha) < epsilon:
            denominator_alpha = epsilon
        alpha = np.sum(np.conj(r) * r) / denominator_alpha
        img_reconstructed = img_reconstructed + alpha * p
        r_new = r - alpha * Ap
        denominator_beta = np.sum(np.conj(r_new) * r_new)  # 这里修正为使用新的残差计算分母
        print(f"Iteration {i + 1}, denominator_beta: {denominator_beta}")  # 添加打印语句
        if np.abs(denominator_beta) < epsilon:
            denominator_beta = epsilon
        beta = np.sum(np.conj(r_new) * r_new) / denominator_beta
        p = At(r_new) + beta * p
        r = r_new

        # 打印残差的范数，观察残差是否正常更新
        print(f"Iteration {i + 1}, residual norm: {np.linalg.norm(r)}") 

        # 可视化中间结果
        if (i + 1) % 100 == 0:
            plt.imshow(np.abs(img_reconstructed), cmap='gray')
            plt.title(f'Reconstructed MRI Image at Iteration {i + 1}')
            plt.show()

    final_img = np.abs(img_reconstructed)

    return final_img
def admm_mri_reconstruction(kspace_sampled, mask, lam, rho=1.0, num_iters=100):
    """
    使用交替方向乘子法进行 MRI 重建
    :param kspace_sampled: 采样的 k 空间数据
    :param mask: 采样掩码
    :param lam: 正则化参数
    :param rho: 惩罚参数
    :param num_iters: 迭代次数
    :return: 重建的图像
    """
    img_reconstructed = reconstruct_image(kspace_sampled)

    # 显示初始化图像
    plt.imshow(np.abs(img_reconstructed), cmap='gray')
    plt.title('Initial Reconstructed MRI Image')
    plt.show()
    plt.imsave('initial_reconstructed_image.png', np.abs(img_reconstructed), cmap='gray')
    # 初始重建（零填充）
    #img_reconstructed = ifft2c(kspace_sampled)
    z = dct2(img_reconstructed)
    u = np.zeros_like(z)

    for i in range(num_iters):
        # 更新 x
        kspace_x = kspace_sampled * mask + fft2c(ifft2c(idct2(z - u))) * (1 - mask)
        img_reconstructed = ifft2c(kspace_x)

        # 更新 z
        sparse_coeff = dct2(img_reconstructed) + u
        z = soft_threshold(sparse_coeff, lam / rho)

        # 更新 u
        u = u + (dct2(img_reconstructed) - z)

        if (i + 1) % 10 == 0 or i == 0:
            print(f"Iteration {i + 1}/{num_iters}")

    return np.abs(img_reconstructed)



def ista_mri_reconstruction(kspace_sampled, mask, lam, num_iters=100, step_size=1.0):
    """
    使用迭代软阈值算法ISTA进行MRI重建
    :param kspace_sampled: 采样的k空间数据
    :param mask: 采样掩码
    :param lam: 正则化参数
    :param num_iters: 迭代次数
    :param step_size: 步长
    :return: 重建的图像
    """
    # 初始重建（零填充）
    # img_reconstructed = ifft2c(kspace_sampled)
    # img_reconstructed = np.abs(img_reconstructed)
    img_reconstructed = reconstruct_image(kspace_sampled)
    # 显示初始化图像
    plt.imshow(np.abs(img_reconstructed), cmap='gray')
    plt.title('Initial Reconstructed MRI Image')
    plt.show()
    plt.imsave('initial_reconstructed_image.png', np.abs(img_reconstructed), cmap='gray')

    for i in range(num_iters):
        # 将当前图像转换到稀疏域
        sparse_coeff = dct2(img_reconstructed)

        # 软阈值操作
        sparse_coeff = np.sign(sparse_coeff) * np.maximum(np.abs(sparse_coeff) - lam * step_size, 0)

        # 将稀疏系数转换回图像域
        img_temp = idct2(sparse_coeff)

        # 数据一致性步骤
        kspace_reconstructed = fft2c(img_temp)
        kspace_reconstructed = kspace_sampled * mask + kspace_reconstructed * (1 - mask)
        #img_reconstructed = np.abs(ifft2c(kspace_reconstructed))
        img_reconstructed = reconstruct_image(kspace_reconstructed)
        if (i + 1) % 10 == 0 or i == 0:
            print(f"Iteration {i + 1}/{num_iters}")

    return img_reconstructed


def main():
    # 设置当前工作目录
    current_dir = os.getcwd()

    # 加载完整的k空间数据
    kspace_path = os.path.join(current_dir, 'fid_20250418_163142.npy')
    if not os.path.exists(kspace_path):
        print(f"错误: 找不到文件 {kspace_path}")
        return

    try:
        kspace_full = np.load(kspace_path)
        print(f"成功加载k空间数据: {kspace_path}")
    except Exception as e:
        print(f"错误: 加载k空间数据时出错: {e}")
        return
    # 设置 ISTA 参数
    lam = 0.1  # 正则化参数，可以根据需要调整
    num_iters = 1000  # 迭代次数
    step_size = 0.1  # 步长，可以根据需要调整
    # 设置采样率
    sampling_rate = 0.33 # 可以根据需要调整

    # 生成采样掩码
    mask1 = generate_sampling_mask(kspace_full.shape, sampling_rate)

    # 采样k空间

    # 共轭梯度法重建图像

    # 遍历当前目录下的所有 .npy 文件，排除 'mrd_data.npy'
    for filename in os.listdir(current_dir):
        if filename.lower().endswith('.npy') and filename != 'fid_20250418_163142.npy':
            mask_path = os.path.join(current_dir, filename)
            try:
                mask = np.load(mask_path)
                print(f"处理文件: {mask_path}")
            except Exception as e:
                print(f"错误: 加载文件 {mask_path} 时出错: {e}")
                continue

            # 检查 kspace_full 和 mask 的形状是否匹配
            if kspace_full.shape != mask.shape:
                print(f"警告: 文件 {filename} 的形状与 k 空间数据不匹配，跳过。")
                continue
            kspace_sampled_test = kspace_full * mask1
            img_reconstructed_test = reconstruct_image(kspace_sampled_test)
            # 显示初始化图像
            plt.imshow(np.abs(img_reconstructed_test), cmap='gray')
            plt.title('test Reconstructed MRI Image')
            plt.show()
            plt.imsave('test_reconstructed_image.png', np.abs(img_reconstructed_test), cmap='gray')
            # 采样 k 空间
            kspace_sampled = kspace_full * mask

            # 重建图像
            try:
                #img_reconstructed = admm_mri_reconstruction(kspace_sampled, mask, lam, rho=1.0, num_iters=200)
                img_reconstructed =ista_mri_reconstruction(kspace_sampled, mask, lam, num_iters, step_size)
                print(f"成功重建图像: {filename}")
            except Exception as e:
                print(f"错误: 重建文件 {filename} 时出错: {e}")
                continue

            # 构造保存的 PNG 文件名
            base_name = os.path.splitext(filename)[0]
            #png_filename = f"{base_name}_ista_reconstructed_fid_20250417_103329.png"
            png_filename = f"{base_name}_ista_reconstructed_fid_20250417_103329.png"
            png_path = os.path.join(current_dir, png_filename)
            plt.imshow(img_reconstructed, cmap='gray')  # 设置为灰度图显示
            plt.title('Reconstructed MRI Image')
            #plt.axis('off')  # 取消坐标轴
           # plt.tight_layout(pad=0)
            plt.show()
            #plt.savefig(png_path, dpi=300)  # 保存为 PNG 文件
            plt.imsave(png_path, np.abs(img_reconstructed), cmap='gray')

if __name__ == "__main__":
    main()
