import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import os
from PIL import Image
from matplotlib import pyplot as plt

# 假设这是你的 U-Net 模型定义
from nets.unet import Unet as UNet
# 导入 VGG16 模型定义
from nets.vgg import VGG16
from scipy.fft import fft2, ifft2, fftshift, ifftshift
from scipy.ndimage import rotate
from scipy.fftpack import dct, idct

def fft2c(img):
    """Return the centred 2-D FFT of ``img``.

    ``ifftshift`` first moves the array centre to index 0 so the phase is
    referenced to the middle of the image; after ``fft2`` the spectrum is
    re-centred with ``fftshift`` so the DC sample sits in the middle.
    """
    origin_at_zero = np.fft.ifftshift(img)
    spectrum = np.fft.fft2(origin_at_zero)
    return np.fft.fftshift(spectrum)
def ifft2c(kspace):
    """Return the centred 2-D inverse FFT of ``kspace`` (the inverse of ``fft2c``).

    The centred DC sample is moved back to index 0, the inverse transform is
    applied, and the resulting image origin is shifted back to the centre.
    """
    dc_at_zero = np.fft.ifftshift(kspace)
    image = np.fft.ifft2(dc_at_zero)
    return np.fft.fftshift(image)
def dct2(a):
    """Orthonormal 2-D DCT-II: 1-D DCT along the first axis, then along the last axis."""
    along_rows = dct(a, axis=0, norm='ortho')
    return dct(along_rows, axis=-1, norm='ortho')
def idct2(a):
    """Inverse of ``dct2``: orthonormal 1-D inverse DCT along the first axis, then the last axis."""
    along_rows = idct(a, axis=0, norm='ortho')
    return idct(along_rows, axis=-1, norm='ortho')
def soft_threshold(x, lam):
    """Shrink every entry of ``x`` toward zero by ``lam``, clamping at zero.

    For real input this is the proximal operator of ``lam * |x|``.
    """
    shrunk_magnitude = np.maximum(np.abs(x) - lam, 0)
    return np.sign(x) * shrunk_magnitude
def admm_mri_reconstruction(kspace_sampled, mask, lam, rho=1.0, num_iters=100):
    """
    Reconstruct an MR image from undersampled k-space with ADMM and a DCT sparsity prior.

    Each iteration performs
      1. x-update: keep the measured samples where ``mask`` is 1 and fill the
         remaining k-space locations with the k-space of the current sparse
         estimate ``idct2(z - u)``, then transform back to image space;
      2. z-update: soft-threshold the DCT coefficients of x plus the scaled
         dual ``u`` with threshold ``lam / rho``;
      3. u-update: add the residual between the DCT of x and z to ``u``.

    :param kspace_sampled: sampled k-space data; must broadcast with ``mask``
    :param mask: sampling mask, 1 where k-space was acquired and 0 elsewhere
        (presumably binary — confirm for non-binary masks)
    :param lam: regularization weight
    :param rho: ADMM penalty parameter
    :param num_iters: number of iterations
    :return: magnitude (``np.abs``) of the reconstructed image
    """
    # Zero-filled seed computed with the same centred transform (ifft2c) that
    # the iterations use, so z starts in the image domain fft2c/ifft2c assume
    # and has the same shape as kspace_sampled.
    img_reconstructed = ifft2c(kspace_sampled)
    z = dct2(img_reconstructed)
    u = np.zeros_like(z)

    for i in range(num_iters):
        # x-update: bring the current sparse estimate into k-space with fft2c.
        # (Wrapping it as fft2c(ifft2c(...)) would cancel to the identity and
        # put image-domain values into k-space.)
        kspace_x = kspace_sampled * mask + fft2c(idct2(z - u)) * (1 - mask)
        img_reconstructed = ifft2c(kspace_x)

        # z-update: DCT-domain soft thresholding. The coefficients are
        # computed once and reused by the dual update below.
        coeffs = dct2(img_reconstructed)
        z = soft_threshold(coeffs + u, lam / rho)

        # u-update (scaled dual variable).
        u = u + (coeffs - z)

        if (i + 1) % 10 == 0 or i == 0:
            print(f"Iteration {i + 1}/{num_iters}")

    return np.abs(img_reconstructed)
# Model loading
def load_model_Egg(model_path, device, backbone='vgg'):
    """Create a 3-class U-Net, load its weights from ``model_path`` and return it in eval mode.

    :param model_path: path to a state_dict readable by ``torch.load``
    :param device: device the model and the loaded weights are placed on
    :param backbone: 'vgg' or 'resnet50'
    :raises ValueError: for any other backbone name
    :return: the model, moved to ``device`` and switched to ``eval()``
    """
    supported = ('vgg', 'resnet50')
    if backbone not in supported:
        raise ValueError('Unsupported backbone - `{}`, Use vgg, resnet50.'.format(backbone))
    net = UNet(num_classes=3, pretrained=False, backbone=backbone)
    net.to(device)
    # weights_only=True restricts unpickling to tensors and plain containers.
    weights = torch.load(model_path, map_location=device, weights_only=True)
    print(f"Loaded state_dict keys: {weights.keys()}")
    net.load_state_dict(weights)
    print(f"Model's state_dict keys: {net.state_dict().keys()}")
    net.eval()
    return net
def load_model_Ger(model_path, device, backbone='vgg'):
    """Create a 2-class U-Net, load its weights from ``model_path`` and return it in eval mode.

    :param model_path: path to a state_dict readable by ``torch.load``
    :param device: device the model and the loaded weights are placed on
    :param backbone: 'vgg' or 'resnet50'
    :raises ValueError: for any other backbone name
    :return: the model, moved to ``device`` and switched to ``eval()``
    """
    if backbone not in ('vgg', 'resnet50'):
        raise ValueError('Unsupported backbone - `{}`, Use vgg, resnet50.'.format(backbone))
    net = UNet(num_classes=2, pretrained=False, backbone=backbone)
    net.to(device)
    # weights_only=True restricts unpickling to tensors and plain containers.
    weights = torch.load(model_path, map_location=device, weights_only=True)
    print(f"Loaded state_dict keys: {weights.keys()}")
    net.load_state_dict(weights)
    print(f"Model's state_dict keys: {net.state_dict().keys()}")
    net.eval()
    return net
# Rotate k-space data
def rotate_kspace_via_image(kspace_data, angle):
    """Rotate k-space by rotating its image-space reconstruction and transforming back.

    The image is reconstructed on a 256x256 grid (``ifft2`` pads or crops to
    that size) with ``fftshift`` along axis 1 only. Its real and imaginary
    parts are each rotated by ``angle`` degrees with ``scipy.ndimage.rotate``
    (output shape unchanged), which keeps the phase. The result is
    transformed back with a centred 2-D FFT.

    NOTE(review): the inverse step shifts only axis 1, while the forward step
    is a fully centred FFT, so ``angle=0`` does not return the input k-space.
    Confirm which k-space layout is intended.
    """
    spatial = fftshift(ifft2(kspace_data, (256, 256)), axes=1)

    def _spin(plane):
        return rotate(plane, angle, reshape=False)

    turned = _spin(spatial.real) + 1j * _spin(spatial.imag)
    return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(turned)))

# Image reconstruction from k-space
def reconstruct_image(kspace_data):
    """Reconstruct a displayable RGB uint8 image from k-space data.

    The inverse FFT is evaluated on a 256x256 grid (``ifft2`` pads or crops to
    that size, acting on the last two axes) with ``fftshift`` along axis 0
    only. The magnitude is min-max normalized to 0..255 and replicated to
    three channels.

    :param kspace_data: complex k-space array (2-D for the RGB conversion to work)
    :return: (256, 256, 3) uint8 RGB image; all zeros if the magnitude is constant
    """
    im = fftshift(ifft2(kspace_data, (256, 256)), axes=0)
    magnitude_image = np.abs(im)
    lo = magnitude_image.min()
    value_range = magnitude_image.max() - lo
    if value_range > 0:
        normalized_image = (magnitude_image - lo) / value_range
    else:
        # Constant magnitude (e.g. all-zero k-space): avoid 0/0 -> NaN, whose
        # cast to uint8 is undefined.
        normalized_image = np.zeros_like(magnitude_image)
    uint8_image = (normalized_image * 255).astype(np.uint8)
    return cv2.cvtColor(uint8_image, cv2.COLOR_GRAY2RGB)

# Image preprocessing
def preprocess_image(image, scale_factor=0.5):
    """Resize a PIL image and turn it into a (1, C, H, W) float32 tensor scaled to [0, 1].

    The image is scaled by ``scale_factor`` and each side is rounded down to a
    multiple of 32, as U-Net style encoders usually require. Intermediate
    sizes, shapes and the value range are printed. The preprocessed image is
    shown with matplotlib; ``plt.show()`` may block until the window closes.

    :param image: PIL image; the HWC->CHW transpose assumes it converts to a
        3-D (H, W, C) array, so a single-channel image would fail — confirm
    :param scale_factor: resize factor applied before rounding to 32
    :return: float32 tensor of shape (1, C, H, W)
    """
    print(f"Original image size: {image.size}")
    target_w = int(image.width * scale_factor) // 32 * 32
    target_h = int(image.height * scale_factor) // 32 * 32
    resized = image.resize((target_w, target_h))
    print(f"Resized image size: {resized.size}")

    pixels = np.array(resized)
    print(f"Image shape after converting to numpy: {pixels.shape}")
    chw = pixels.transpose((2, 0, 1))
    print(f"Image shape after transpose: {chw.shape}")
    chw = chw / 255.0
    print(f"Image data range after normalization: [{chw.min()}, {chw.max()}]")

    plt.imshow(chw.transpose((1, 2, 0)))
    plt.title("Preprocessed Image")
    plt.show()
    return torch.from_numpy(chw).unsqueeze(0).float()

# Segmentation-mask prediction
def predict_mask(model, image, device, original_size):
    """Run ``model`` on ``image`` and return a label mask resized to ``original_size``.

    Multi-class outputs are reduced with ``argmax`` over dim 1. Single-channel
    outputs are thresholded at 0.5 after a sigmoid. Shape, range and value
    counts of the mask are printed.

    :param model: segmentation network, called as ``model(image)``
    :param image: input tensor, presumably (1, C, H, W); the ``squeeze`` below
        assumes a batch of one — confirm
    :param device: torch device to run on
    :param original_size: (width, height) of the returned mask, in ``cv2.resize`` order
    :return: uint8 NumPy label mask of shape (height, width)
    """
    image = image.to(device)
    with torch.no_grad():
        output = model(image)
        # The project's Unet may not expose ``n_classes``; fall back to the
        # number of output channels, which normally equals the class count.
        num_classes = getattr(model, 'n_classes', None)
        if num_classes is None:
            num_classes = output.shape[1]
        if num_classes > 1:
            mask = output.argmax(dim=1)
        else:
            mask = torch.sigmoid(output) > 0.5
    print(f"Mask shape before resize: {mask.shape}")
    print(f"Mask data range: [{mask.min()}, {mask.max()}]")
    mask_np = mask.cpu().squeeze().numpy()
    print(f"Mask value counts: {np.unique(mask_np, return_counts=True)}")
    # INTER_NEAREST keeps every resized pixel a valid label (no blending).
    return cv2.resize(mask_np.astype(np.uint8), original_size, interpolation=cv2.INTER_NEAREST)

# Crop every predicted object out of the image separately and fit an ellipse to each
def segment_image(original_image, mask):
    """Cut each connected object of ``mask`` out of ``original_image`` as a padded RGBA crop.

    ``cv2.connectedComponents`` labels the non-zero pixels of ``mask``, so
    touching pixels with different class values form a single object. An
    object is kept only if its bounding-box extent (max - min coordinate) is
    at least 10 pixels in both directions. For each kept object:
      - the box is grown by 20 % of that extent on each side, clipped to the image;
      - the crop is converted to RGBA, with alpha 255 on the object and 0 elsewhere;
      - when the object's outer contour has at least 5 points, a fitted ellipse
        is drawn on the crop (opaque, 1 px).

    :param original_image: image to crop; converted with ``np.array`` and fed
        to ``cv2.COLOR_RGB2RGBA``, so presumably (H, W, 3) uint8 — confirm
    :param mask: 2-D uint8 label mask with the image's height and width
    :return: list of RGBA crops, one per kept object, in label order
    """
    print(f"Original image shape: {np.array(original_image).shape}")
    print(f"Mask shape: {mask.shape}")
    original_image = np.array(original_image)

    segmented_images = []
    num_labels, labels = cv2.connectedComponents(mask)
    for label in range(1, num_labels):  # label 0 is the background
        target_mask = (labels == label).astype(np.uint8)
        rows, cols = np.where(target_mask == 1)
        if len(rows) == 0 or len(cols) == 0:
            continue

        # Bounding box of the object.
        y1, y2 = np.min(rows), np.max(rows)
        x1, x2 = np.min(cols), np.max(cols)

        # Extent (one less than the pixel height/width); skip tiny objects.
        height = y2 - y1
        width = x2 - x1
        if height < 10 or width < 10:
            continue

        # Pad the crop by 20 % of the extent, clipped to the image borders.
        padding_y = int(height * 0.2)
        padding_x = int(width * 0.2)
        new_y1 = max(0, y1 - padding_y)
        new_y2 = min(original_image.shape[0], y2 + padding_y)
        new_x1 = max(0, x1 - padding_x)
        new_x2 = min(original_image.shape[1], x2 + padding_x)

        print(f"Cropping coordinates for object {label}: y1={new_y1}, y2={new_y2}, x1={new_x1}, x2={new_x2}")
        cropped_image = original_image[new_y1:new_y2, new_x1:new_x2]
        crop_mask = target_mask[new_y1:new_y2, new_x1:new_x2]

        # Transparent background: alpha is opaque only on this object.
        alpha = np.ones((cropped_image.shape[0], cropped_image.shape[1]), dtype=np.uint8) * 255
        alpha[crop_mask == 0] = 0
        cropped_image = cv2.cvtColor(cropped_image, cv2.COLOR_RGB2RGBA)
        cropped_image[:, :, 3] = alpha

        # Ellipse fit. cv2.fitEllipse needs >= 5 points; a CHAIN_APPROX_SIMPLE
        # contour of a rectangular blob can have only 4. The colour carries an
        # explicit alpha of 255: a 3-tuple on a 4-channel image would write
        # alpha 0 and make the outline transparent.
        contours, _ = cv2.findContours(crop_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if contours:
            contour = max(contours, key=cv2.contourArea)
            if len(contour) >= 5:
                ellipse = cv2.fitEllipse(contour)
                cv2.ellipse(cropped_image, ellipse, (255, 0, 0, 255), 1)

        segmented_images.append(cropped_image)

    return segmented_images
# Pseudo-colour conversion
def cvtColor(image):
    """Map an 8-bit grayscale image to a red/yellow pseudo-colour BGR image.

    A 256-entry BGR lookup table is applied with ``cv2.LUT``:
      - levels 0-31: pure red (red = 255, green = blue = 0);
      - levels 32-255: red = 255 and green = ``((level - 128) * 2) % 256``.

    The ``% 256`` reproduces exactly the wrap-around NumPy 1.x applied when the
    original loop assigned negative values into the uint8 table. NumPy 2
    raises OverflowError on that assignment instead.

    NOTE(review): because of the wrap the green channel is a sawtooth
    (64->254 over levels 32-127, then 0->254 over 128-255), not a smooth
    red-to-yellow ramp. It is kept because main() feeds this output to the
    germinal-disc model, and changing the mapping may change its predictions.
    Confirm which mapping the model was trained on.

    :param image: single-channel uint8 image (expanded with ``COLOR_GRAY2BGR``)
    :return: 3-channel uint8 BGR image
    """
    levels = np.arange(256)
    lut = np.zeros((256, 1, 3), dtype=np.uint8)
    lut[:, 0, 2] = 255  # red channel (BGR index 2) is always saturated
    lut[32:, 0, 1] = ((levels[32:] - 128) * 2) % 256  # green channel; blue stays 0
    return cv2.LUT(cv2.cvtColor(image, cv2.COLOR_GRAY2BGR), lut)
# Mask fusion
def creatmask(mask_Egg, mask_Ger):
    """Write label 3 into ``mask_Egg`` wherever ``mask_Ger`` equals 1.

    ``mask_Egg`` is modified in place and also returned; every other pixel
    keeps its original label.

    :raises ValueError: if the two masks differ in height or width
    """
    same_hw = mask_Egg.shape[:2] == mask_Ger.shape[:2]
    if not same_hw:
        raise ValueError("两个掩码的高度和宽度必须一致。")
    mask_Egg[mask_Ger == 1] = 3
    return mask_Egg
# Draw the mask on the image
def draw_mask_on_image(original_image, mask):
    """Blend a colour for each labelled pixel of ``mask`` into a copy of ``original_image``.

    Labels 1, 2 and 3 are tinted white, yellow and red with 50 % opacity; the
    triples are in BGR order ([0, 0, 255] is red), as cv2.imshow expects.
    Other labels are left untouched. The input image is not modified; the
    blend is cast back to the image's dtype on assignment.
    """
    palette = {
        1: [255, 255, 255],  # white
        2: [0, 255, 255],    # yellow
        3: [0, 0, 255],      # red
    }
    alpha = 0.5  # 0.0 = fully transparent, 1.0 = fully opaque
    blended = original_image.copy()
    present = [v for v in np.unique(mask) if v in palette]
    for label in present:
        where = np.where(mask == label)
        blended[where] = alpha * np.array(palette[label]) + (1 - alpha) * blended[where]
    return blended

# Helper for main(): convert the ADMM magnitude output into the RGB uint8 image the pipeline expects
def _magnitude_to_rgb_uint8(magnitude):
    """Min-max normalize a real image to 0..255 and replicate it into an (H, W, 3) uint8 RGB image.

    This is the same normalization reconstruct_image applies. A constant image
    maps to all zeros instead of dividing 0 by 0.
    """
    magnitude = np.asarray(magnitude, dtype=np.float64)
    lo = magnitude.min()
    value_range = magnitude.max() - lo
    if value_range > 0:
        normalized = (magnitude - lo) / value_range
    else:
        normalized = np.zeros_like(magnitude)
    return cv2.cvtColor((normalized * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB)


# Main entry point
def main(model_path_Egg, model_path_Ger, input_directory, backbone='vgg',
         mask_path=r'D:\Users\Desktop\item\unet-pytorch\radial_mask.npy'):
    """Reconstruct, segment and export every ``.npy`` k-space file in ``input_directory``.

    For each file:
      - multiply the k-space by the sampling mask and reconstruct it with ADMM;
      - convert the result to an RGB uint8 image;
      - segment it with the two U-Nets;
      - merge the masks, save the merged mask to ``./mask.png`` and overlay it;
      - write every detected object (plain and overlaid) to ``outim/``.

    OpenCV and matplotlib windows are shown, and ``cv2.waitKey(0)`` blocks
    for key presses.

    :param model_path_Egg: weights of the 3-class model (loaded by load_model_Egg)
    :param model_path_Ger: weights of the 2-class model (loaded by load_model_Ger)
    :param input_directory: directory scanned for ``.npy`` k-space files
    :param backbone: U-Net backbone, 'vgg' or 'resnet50'
    :param mask_path: ``.npy`` k-space sampling mask; defaults to the
        previously hard-coded location
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_Egg = load_model_Egg(model_path_Egg, device, backbone)
    model_Ger = load_model_Ger(model_path_Ger, device, backbone)
    mask = np.load(mask_path)

    # ADMM parameters
    lam = 0.1  # regularization weight; tune as needed
    rho = 1.0  # penalty parameter
    num_iters = 100  # number of iterations

    output_folder = 'outim'
    scale_factor = 1  # enlargement factor for the saved crops

    # Only .npy k-space files are processed.
    file_names = [f for f in os.listdir(input_directory) if f.endswith('.npy')]

    for file_name in file_names:
        # NOTE(review): allow_pickle=True lets a crafted .npy file execute
        # arbitrary code on load; only point this at trusted data.
        ksp = np.load(os.path.join(input_directory, file_name), allow_pickle=True)
        new_ksp = ksp * mask

        # admm_mri_reconstruction returns a 2-D float magnitude image. The
        # code below (RGB2GRAY, the HWC->CHW transpose, the RGBA crops) needs
        # an RGB uint8 image, so normalize and convert it first.
        magnitude = admm_mri_reconstruction(new_ksp, mask, lam, rho, num_iters)
        original_image = _magnitude_to_rgb_uint8(magnitude)

        original_image_gray = cv2.cvtColor(original_image, cv2.COLOR_RGB2GRAY)
        cv2.imshow('Reconstructed Image', original_image_gray)
        cv2.waitKey(0)

        # cv2.resize expects (width, height).
        original_size = (original_image.shape[1], original_image.shape[0])

        # The second model is fed the pseudo-colour version of the image.
        blue_image = cvtColor(original_image_gray)
        egg_input = torch.from_numpy(original_image.transpose((2, 0, 1)) / 255.0).unsqueeze(0).float()
        ger_input = torch.from_numpy(blue_image.transpose((2, 0, 1)) / 255.0).unsqueeze(0).float()
        new_mask_Egg = predict_mask(model_Egg, egg_input, device, original_size)
        new_mask_Ger = predict_mask(model_Ger, ger_input, device, original_size)

        new_mask = creatmask(new_mask_Egg, new_mask_Ger)
        fig = plt.figure(figsize=(new_mask.shape[1] / 100, new_mask.shape[0] / 100), dpi=130)
        plt.imshow(new_mask, cmap='gray')
        plt.axis('off')
        plt.savefig('./mask.png', bbox_inches='tight', pad_inches=0, transparent=True)
        plt.show()
        plt.close(fig)  # otherwise one open figure accumulates per file

        # Overlay the merged mask on the reconstructed image.
        image_with_mask = draw_mask_on_image(original_image, new_mask)
        cv2.imshow('output_image_with_alpha', image_with_mask)

        # Crop each predicted object from the plain and the overlaid image.
        segmented_images = segment_image(original_image, new_mask)
        segmented_image_color = segment_image(image_with_mask, new_mask)

        os.makedirs(output_folder, exist_ok=True)
        stem = os.path.splitext(file_name)[0]
        for crops, tag in ((segmented_images, 'segmented_image'),
                           (segmented_image_color, 'segmented_image_color')):
            for i, crop in enumerate(crops):
                print(f"Shape of segmented image {i + 1}: {crop.shape}")
                new_width = int(crop.shape[1] * scale_factor)
                new_height = int(crop.shape[0] * scale_factor)
                resized = cv2.resize(crop, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
                output_path = os.path.join(output_folder, f'{stem}_{tag}_{i + 1}.png')
                cv2.imwrite(output_path, resized)

        # Show the crops.
        for i, crop in enumerate(segmented_images):
            cv2.imshow(f'Segmented Image {i + 1}', cv2.cvtColor(crop, cv2.COLOR_RGB2BGR))
        for i, crop in enumerate(segmented_image_color):
            cv2.imshow(f'Segmented Image_color {i + 1}', crop)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

if __name__ == "__main__":
    # Replace the paths below with your own locations.
    model_path_Egg = r'D:\Users\Desktop\item\unet-pytorch\model_data\Egg_white and yolk.pth'  # 3-class model weights
    # The file name previously ended in a stray space ('Germinal_disc.pth '),
    # which only resolves on Windows because Win32 strips trailing spaces.
    model_path_Ger = r'D:\Users\Desktop\item\unet-pytorch\model_data\Germinal_disc.pth'  # 2-class model weights
    input_directory = r'D:\Users\Desktop\item\unet-pytorch\fid_data\fid'  # directory of .npy k-space files
    backbone = 'vgg'  # or 'resnet50'
    main(model_path_Egg, model_path_Ger, input_directory, backbone)
