import numpy as np
import matplotlib.pyplot as plt
from scipy.fftpack import dct, idct
import os
from scipy.optimize import minimize


def fft2c(img):
    """Centered 2D FFT."""
    return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(img)))


def ifft2c(kspace):
    """Centered 2D IFFT."""
    return np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(kspace)))


def generate_sampling_mask(shape, sampling_rate):
    """生成随机采样掩码"""
    num_samples = int(sampling_rate * shape[0] * shape[1])
    mask = np.zeros(shape, dtype=np.float32)
    indices = np.random.choice(shape[0] * shape[1], num_samples, replace=False)
    mask.flat[indices] = 1
    return mask


def dct2(a):
    """二维离散余弦变换"""
    return dct(dct(a.T, norm='ortho').T, norm='ortho')


def idct2(a):
    """二维离散余弦逆变换"""
    return idct(idct(a.T, norm='ortho').T, norm='ortho')


def bpdn_mri_reconstruction(kspace_sampled, mask, lam, num_iters=100):
    """
    使用基追踪去噪算法进行 MRI 重建
    :param kspace_sampled: 采样的 k 空间数据
    :param mask: 采样掩码
    :param lam: 正则化参数
    :param num_iters: 最大迭代次数
    :return: 重建的图像
    """
    # 初始猜测
    img_initial = ifft2c(kspace_sampled)

    def objective(x):
        """目标函数：数据拟合项 + 正则化项"""
        img = x.reshape(mask.shape)
        kspace_reconstructed = fft2c(img)
        data_fidelity = np.linalg.norm(kspace_sampled * mask - kspace_reconstructed * mask) ** 2
        sparse_coeff = dct2(img)
        regularization = lam * np.linalg.norm(sparse_coeff.flatten(), 1)
        return data_fidelity + regularization

    # 优化求解
    result = minimize(objective, img_initial.flatten(), method='L-BFGS-B', options={'maxiter': num_iters})
    img_reconstructed = result.x.reshape(mask.shape)
    return np.abs(img_reconstructed)


# 下采样函数
def downsample(data, factor):
    return data[::factor, ::factor]


def main():
    # 设置当前工作目录
    current_dir = os.getcwd()

    # 加载完整的 k 空间数据
    kspace_path = os.path.join(current_dir, 'mrd_data1.npy')
    if not os.path.exists(kspace_path):
        print(f"错误: 找不到文件 {kspace_path}")
        return

    try:
        kspace_full = np.load(kspace_path)
        # 进一步增大下采样因子，可根据实际情况调整
        downsample_factor = 4
        kspace_full = downsample(kspace_full, downsample_factor)
        print(f"成功加载并下采样 k 空间数据: {kspace_path}")
    except Exception as e:
        print(f"错误: 加载 k 空间数据时出错: {e}")
        return

    # 设置 BPDN 参数
    lam = 0.1  # 正则化参数，可以根据需要调整
    num_iters = 100  # 最大迭代次数

    # 遍历当前目录下的所有 .npy 文件，排除 'mrd_data.npy'
    for filename in os.listdir(current_dir):
        if filename.lower().endswith('.npy') and filename != 'mrd_data.npy':
            mask_path = os.path.join(current_dir, filename)
            try:
                mask = np.load(mask_path)
                # 对掩码也进行下采样
                mask = downsample(mask, downsample_factor)
                print(f"处理文件: {mask_path}")
            except Exception as e:
                print(f"错误: 加载文件 {mask_path} 时出错: {e}")
                continue

            # 检查 kspace_full 和 mask 的形状是否匹配
            if kspace_full.shape != mask.shape:
                print(f"警告: 文件 {filename} 的形状与 k 空间数据不匹配，跳过。")
                continue

            # 采样 k 空间
            kspace_sampled = kspace_full * mask

            # 重建图像
            try:
                img_reconstructed = bpdn_mri_reconstruction(kspace_sampled, mask, lam, num_iters)
                print(f"成功重建图像: {filename}")
            except Exception as e:
                print(f"错误: 使用 BPDN 重建文件 {filename} 时出错: {e}")
                continue

            # 构造保存的 PNG 文件名
            base_name = os.path.splitext(filename)[0]
            png_filename = f"{base_name}_BPDN_reconstructed.png"
            png_path = os.path.join(current_dir, png_filename)
            plt.imshow(img_reconstructed)
            plt.axis('off')  # 取消坐标轴
            plt.tight_layout(pad=0)
            plt.savefig(png_path, dpi=300)  # 保存为 PNG 文件


if __name__ == "__main__":
    main()
    
    