import numpy as np
import matplotlib.pyplot as plt
from scipy.fft import fft2, ifft2, fftshift, ifftshift
import os
import pywt
from scipy.fftpack import dct, idct


def fft2c(img):
    """Return the 2-D FFT of ``img`` with the zero frequency at the array centre.

    The input is un-centred with ``np.fft.ifftshift`` (all axes), transformed
    with ``np.fft.fft2``, and the spectrum is re-centred with ``np.fft.fftshift``.
    """
    uncentred = np.fft.ifftshift(img)
    spectrum = np.fft.fft2(uncentred)
    return np.fft.fftshift(spectrum)


def ifft2c(kspace):
    """Return the 2-D inverse FFT of centred ``kspace`` as a centred image.

    Mirror of ``fft2c``: ``np.fft.ifftshift`` (all axes), ``np.fft.ifft2``,
    then ``np.fft.fftshift``.
    """
    uncentred = np.fft.ifftshift(kspace)
    image = np.fft.ifft2(uncentred)
    return np.fft.fftshift(image)


def generate_sampling_mask(shape, sampling_rate):
    """Generate a random binary sampling mask.

    ``int(sampling_rate * shape[0] * shape[1])`` distinct flat positions are
    drawn without replacement from the first ``shape[0] * shape[1]`` entries
    (via the global ``np.random`` state) and set to 1; everything else is 0.

    Returns a ``float32`` array of the given ``shape``.
    """
    rows, cols = shape[0], shape[1]
    n_keep = int(sampling_rate * rows * cols)
    picked = np.random.choice(rows * cols, n_keep, replace=False)
    sample_mask = np.zeros(shape, dtype=np.float32)
    np.put(sample_mask, picked, 1)  # flat indexing, same as .flat assignment
    return sample_mask


def dct2(a):
    """2-D orthonormal discrete cosine transform (DCT-II).

    The 1-D transform is applied along axis 0 first, then along the last axis.
    """
    along_first = dct(a, axis=0, norm='ortho')
    return dct(along_first, axis=-1, norm='ortho')


def soft_threshold(x, lam):
    """Soft-thresholding (shrinkage) operator.

    Moves every value toward zero by ``lam`` while keeping its sign; values
    whose magnitude is at most ``lam`` become 0.
    """
    shrunk_magnitude = np.maximum(np.abs(x) - lam, 0)
    return shrunk_magnitude * np.sign(x)


def idct2(a):
    """2-D orthonormal inverse discrete cosine transform.

    The 1-D inverse is applied along axis 0 first, then along the last axis.
    """
    along_first = idct(a, axis=0, norm='ortho')
    return idct(along_first, axis=-1, norm='ortho')


def reconstruct_image(kspace_data, shape=(256, 256)):
    """Zero-filled reconstruction: inverse 2-D FFT, magnitude, min-max normalization.

    Parameters
    ----------
    kspace_data : array_like
        2-D k-space samples (complex or real).
    shape : tuple of int or None, optional
        Passed to ``ifft2`` as ``s``: the k-space is zero-padded (at the end
        of each axis) or cropped to this size before the inverse transform.
        ``None`` keeps the input's own size. Defaults to ``(256, 256)``, the
        size this function previously hard-coded.

    Returns
    -------
    numpy.ndarray
        Real image scaled to [0, 1]. If the magnitude image is constant
        (e.g. all-zero k-space) an all-zero image is returned instead of NaNs.
    """
    s = None if shape is None else tuple(shape)
    # Only axis 0 is re-centred; this is the orientation convention of this script.
    im = fftshift(ifft2(kspace_data, s=s), axes=0)
    # Take the modulus to obtain a real-valued image.
    magnitude_image = np.abs(im)
    lo = magnitude_image.min()
    span = magnitude_image.max() - lo
    if span == 0:
        # Constant image: min-max scaling would be 0/0.
        return np.zeros_like(magnitude_image)
    return (magnitude_image - lo) / span


def ista_mri_reconstruction(kspace_sampled, mask, lam, num_iters=100, step_size=1.0, wavelet='db4', level=3):
    """Iterative wavelet-shrinkage reconstruction of undersampled 2-D k-space.

    Each iteration does the following:

    1. Soft-thresholds the wavelet coefficients of the current complex image;
       the real and imaginary parts are processed separately.
    2. Transforms the result back to k-space.
    3. Re-imposes the measured samples where ``mask`` is set (data consistency).
    4. Returns to the image domain.

    The k-space <-> image transforms used here are an exact inverse pair that
    follows the orientation convention of ``reconstruct_image``: an inverse
    FFT followed by ``fftshift`` along axis 0. The working image is kept
    complex and is not re-normalized inside the loop. That keeps the k-space
    fed back in the data-consistency step on the same footing as the measured
    data.

    Parameters
    ----------
    kspace_sampled : array_like
        2-D undersampled k-space. It is multiplied by ``mask``, so entries
        outside the mask are ignored.
    mask : array_like
        Sampling mask broadcastable to ``kspace_sampled.shape``. Expected to
        be binary: 1 where a sample was measured, 0 elsewhere. Bool masks are
        accepted.
    lam : float
        Regularization weight. The threshold is ``lam * step_size``. It is
        applied to an image scaled so the zero-filled reconstruction peaks at
        magnitude 1, which keeps ``lam`` on a [0, 1] intensity scale.
    num_iters : int
        Number of iterations.
    step_size : float
        Multiplies ``lam`` to form the threshold.
    wavelet : str
        Wavelet name understood by ``pywt``.
    level : int
        Decomposition level for ``pywt.wavedec2`` (default 3).

    Returns
    -------
    numpy.ndarray
        Real magnitude image, min-max normalized to [0, 1], with the same
        shape as ``kspace_sampled``.
    """
    kspace_sampled = np.asarray(kspace_sampled)
    shape = kspace_sampled.shape
    # Float copy: `1 - mask` is not defined for bool arrays.
    mask = np.broadcast_to(np.asarray(mask, dtype=float), shape)
    threshold = lam * step_size

    def to_image(kspace):
        # Same convention as reconstruct_image: inverse FFT, re-centre axis 0.
        return fftshift(ifft2(kspace), axes=0)

    def to_kspace(image):
        # Exact inverse of to_image.
        return fft2(ifftshift(image, axes=0))

    def shrink(channel):
        # Soft-threshold every wavelet coefficient of a real 2-D array.
        coeffs = pywt.wavedec2(channel, wavelet, level=level)
        new_coeffs = []
        for c in coeffs:
            if isinstance(c, tuple):
                new_coeffs.append(tuple(soft_threshold(x, threshold) for x in c))
            else:
                new_coeffs.append(soft_threshold(c, threshold))
        # waverec2 may return one extra row/column for odd sizes; crop back.
        return pywt.waverec2(new_coeffs, wavelet)[:shape[0], :shape[1]]

    measured = kspace_sampled * mask
    # Initial reconstruction (zero filling) at the data's own size.
    img = to_image(measured)

    # Scale so the zero-filled image peaks at 1 (see `lam` above).
    scale = np.abs(img).max()
    if scale == 0:
        return np.zeros(shape)
    measured = measured / scale
    img = img / scale

    for i in range(num_iters):
        # Sparsity step in the wavelet domain (real and imaginary parts).
        img_temp = shrink(img.real) + 1j * shrink(img.imag)
        # Data-consistency step: keep measured samples, fill the rest.
        kspace_est = to_kspace(img_temp)
        kspace_est = measured + kspace_est * (1 - mask)
        img = to_image(kspace_est)

        if (i + 1) % 10 == 0 or i == 0:
            print(f"Iteration {i + 1}/{num_iters}")

    magnitude = np.abs(img)
    lo = magnitude.min()
    span = magnitude.max() - lo
    if span == 0:
        return np.zeros_like(magnitude)
    return (magnitude - lo) / span


def main(kspace_path, mask_path):
    """Load k-space and mask, run ISTA for each wavelet, save and show results.

    Both arrays are read with ``np.load`` and cropped to their common extent
    in the first two dimensions. If their shapes still differ, a warning is
    printed and the function returns. For every wavelet in ``wavelets`` the
    reconstruction is saved as
    ``<mask file stem>_<wavelet>_ista_reconstructed.png`` in the current
    working directory and then displayed.
    """
    # ISTA parameters
    lam = 0.05  # regularization weight; adjust as needed
    num_iters = 500  # number of iterations
    step_size = 0.05  # step size; the shrinkage threshold is lam * step_size
    wavelets = ['haar']

    kspace_full = np.load(kspace_path)
    mask = np.load(mask_path)

    # Crop both arrays to their common extent in the first two dimensions.
    min_shape_0 = min(kspace_full.shape[0], mask.shape[0])
    min_shape_1 = min(kspace_full.shape[1], mask.shape[1])
    kspace_full = kspace_full[:min_shape_0, :min_shape_1]
    mask = mask[:min_shape_0, :min_shape_1]

    # Any remaining mismatch comes from the trailing dimensions (e.g. a
    # different number of dimensions), which cropping does not touch.
    if kspace_full.shape != mask.shape:
        print(f"警告: 文件 {mask_path} 的形状与k空间数据不匹配,程序退出。")
        return

    # Undersample the k-space.
    kspace_sampled = kspace_full * mask

    # Output naming does not depend on the wavelet's reconstruction; compute once.
    base_name = os.path.splitext(os.path.basename(mask_path))[0]
    output_dir = os.getcwd()

    for wavelet in wavelets:
        # Pass the loop's wavelet through; omitting it silently used 'db4'.
        img_reconstructed = ista_mri_reconstruction(
            kspace_sampled, mask, lam, num_iters, step_size, wavelet=wavelet
        )

        # File name includes the wavelet name.
        png_filename = f"{base_name}_{wavelet}_ista_reconstructed.png"
        png_path = os.path.join(output_dir, png_filename)
        # Save before the (possibly blocking) interactive display.
        plt.imsave(png_path, np.abs(img_reconstructed), cmap='gray')
        plt.imshow(img_reconstructed, cmap='gray')  # grayscale display
        plt.title(f'Reconstructed MRI Image with {wavelet} wavelet')
        plt.show()


if __name__ == "__main__":
    import sys

    # Default input locations; override with:
    #   python <script> KSPACE.npy MASK.npy
    default_kspace_path = r'D:\Users\Desktop\egg\mrd_data1.npy'
    default_mask_path = r'D:\Users\Desktop\egg\line_128.npy'
    kspace_path = sys.argv[1] if len(sys.argv) > 1 else default_kspace_path
    mask_path = sys.argv[2] if len(sys.argv) > 2 else default_mask_path
    main(kspace_path, mask_path)
