import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import os
from PIL import Image
from matplotlib import pyplot as plt
from scipy.fft import fft2, ifft2, fftshift, ifftshift
from scipy.ndimage import rotate
# 假设这是你的 U-Net 模型定义
from nets.unet import Unet as UNet
# 导入 VGG16 模型定义
from nets.vgg import VGG16


# 加载模型
def load_model(model_path, device, backbone='vgg'):
    """Build a 2-class U-Net with the requested backbone and load trained weights.

    Args:
        model_path: path to a saved state_dict, read with ``torch.load``.
        device: device the model is moved to and the weights are mapped onto.
        backbone: ``'vgg'`` or ``'resnet50'``.

    Returns:
        The model, switched to eval mode.

    Raises:
        ValueError: if ``backbone`` is not one of the supported names.
    """
    if backbone not in ('vgg', 'resnet50'):
        raise ValueError('Unsupported backbone - `{}`, Use vgg, resnet50.'.format(backbone))

    net = UNet(num_classes=2, pretrained=False, backbone=backbone)
    net.to(device)

    # weights_only=True restricts torch.load's unpickler to plain tensor/container data.
    weights = torch.load(model_path, map_location=device, weights_only=True)
    print(f"Loaded state_dict keys: {weights.keys()}")
    net.load_state_dict(weights)
    print(f"Model's state_dict keys: {net.state_dict().keys()}")

    net.eval()
    return net


# 对 K 空间数据进行旋转
def rotate_kspace_via_image(kspace_data, angle, shape=(256, 256)):
    """Rotate k-space data by rotating its reconstructed complex image.

    The k-space is inverse-FFT'd to ``shape`` and centred with the same
    axis-0 shift that ``reconstruct_image`` uses. The real and imaginary
    parts are rotated separately (so phase is kept), and the exact inverse
    shift + forward FFT is applied. With ``angle == 0`` the result equals
    the (zero-padded/cropped) input k-space up to interpolation round-off.

    Args:
        kspace_data: 2-D complex k-space array.
        angle: rotation angle in degrees, passed to ``scipy.ndimage.rotate``.
        shape: image size passed to ``ifft2`` as ``s`` (default ``(256, 256)``).

    Returns:
        Complex k-space array of size ``shape``, in the same (unshifted)
        layout as the input.
    """
    # Reconstruct with the same centring as reconstruct_image (shift along axis 0 only).
    image = fftshift(ifft2(kspace_data, shape), axes=0)

    # Rotate real and imaginary parts separately to preserve phase information.
    rotated_image = (rotate(image.real, angle, reshape=False)
                     + 1j * rotate(image.imag, angle, reshape=False))

    # Undo exactly the shift applied above before returning to k-space; the
    # previous full ifftshift + fftshift left the output in a different layout
    # than the input, so re-reconstructing it produced a half-width wrap.
    new_kspace = fft2(ifftshift(rotated_image, axes=0))
    return new_kspace


# 从 K 空间数据重建图像
def reconstruct_image(kspace_data, shape=(256, 256)):
    """Reconstruct an 8-bit, 3-channel magnitude image from 2-D k-space data.

    The k-space is inverse-FFT'd to ``shape`` (``ifft2`` zero-pads or crops),
    shifted along axis 0 only, reduced to its magnitude, min-max normalised
    to [0, 255] and replicated into three identical channels.

    Args:
        kspace_data: 2-D complex k-space array.
        shape: image size passed to ``ifft2`` as ``s`` (default ``(256, 256)``).

    Returns:
        ``uint8`` array of shape ``(*shape, 3)`` with R == G == B.
    """
    im = fftshift(ifft2(kspace_data, shape), axes=0)  # shift along the vertical axis only
    magnitude_image = np.abs(im)

    low = magnitude_image.min()
    value_range = magnitude_image.max() - low
    if value_range > 0:
        normalized_image = (magnitude_image - low) / value_range
    else:
        # Constant image (e.g. all-zero k-space): avoid 0/0 -> NaN, whose
        # cast to uint8 is undefined; map everything to black instead.
        normalized_image = np.zeros_like(magnitude_image)

    uint8_image = (normalized_image * 255).astype(np.uint8)
    # Same result as cv2.cvtColor(..., cv2.COLOR_GRAY2RGB): gray copied into R, G and B.
    return np.repeat(uint8_image[..., np.newaxis], 3, axis=-1)


# 使用分水岭算法分割粘连目标
def watershed_segmentation(mask, original_image):
    """Split touching foreground objects in ``mask`` using the watershed algorithm.

    Args:
        mask: single-channel 8-bit mask (``segment_image`` passes ``mask * 255``),
            as required by ``cv2.distanceTransform``.
        original_image: image handed to ``cv2.watershed`` after a ``uint8`` cast.

    Returns:
        The marker array returned by ``cv2.watershed``. Before flooding, label 1
        is certain background, labels >= 2 are individual seeds, and 0 marks the
        uncertain band.
    """
    kernel = np.ones((3, 3), np.uint8)

    # Opening (2 iterations) removes small speckles; dilating the result gives
    # a region whose complement is treated as certain background.
    opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=2)
    background_area = cv2.dilate(opened, kernel, iterations=3)

    # Pixels whose distance to the nearest zero pixel is at least 70% of the
    # maximum distance become certain foreground.
    distance = cv2.distanceTransform(opened, cv2.DIST_L2, 5)
    _, foreground = cv2.threshold(distance, 0.7 * distance.max(), 255, 0)
    foreground = np.uint8(foreground)
    uncertain = cv2.subtract(background_area, foreground)

    # One label per sure-foreground component (shifted to >= 2); the rest
    # becomes 1, and the uncertain band is reset to 0 for watershed to flood.
    _, markers = cv2.connectedComponents(foreground)
    markers = markers + 1
    markers[uncertain == 255] = 0

    return cv2.watershed(np.uint8(original_image), markers)


# 从原图中分割出预测信息，处理粘连目标并分别显示
#def segment_image(original_image, mask):
#    print(f"Original image shape: {np.array(original_image).shape}")
#    print(f"Mask shape: {mask.shape}")
#    original_image = np.array(original_image)

    # 处理粘连目标
#    markers = watershed_segmentation(mask * 255, original_image)

#    unique_markers = np.unique(markers)
#    unique_markers = unique_markers[unique_markers > 1]  # 排除背景和边界

#    segmented_images = []
#    for marker in unique_markers:
#        target_mask = (markers == marker).astype(np.uint8)
#        rows, cols = np.where(target_mask == 1)
#        if len(rows) == 0 or len(cols) == 0:
#            continue
#        y1, y2 = np.min(rows), np.max(rows)
#        x1, x2 = np.min(cols), np.max(cols)
#        # 裁剪出目标区域
#        cropped_image = original_image[y1:y2 + 1, x1:x2 + 1]
#        segmented_images.append(cropped_image)

#    return segmented_images
def segment_image(original_image, mask):
    """Cut each watershed region out of the image as a separate RGBA crop.

    Args:
        original_image: RGB image (anything ``np.array`` accepts) with the same
            height and width as ``mask``.
        mask: 2-D 0/1 mask; it is scaled by 255 before the watershed step.

    Returns:
        List of RGBA crops, one per region label > 1, each cropped to the
        region's bounding box. The alpha channel is 255 inside the region and
        0 elsewhere in the box.
    """
    print(f"Original image shape: {np.array(original_image).shape}")
    print(f"Mask shape: {mask.shape}")
    image = np.array(original_image)

    # Separate touching objects.
    markers = watershed_segmentation(mask * 255, image)

    # Skip the background label (1) and boundary labels (<= 0).
    region_labels = [label for label in np.unique(markers) if label > 1]

    crops = []
    for label in region_labels:
        region = (markers == label).astype(np.uint8)
        ys, xs = np.nonzero(region)
        if ys.size == 0 or xs.size == 0:
            continue
        top, bottom = ys.min(), ys.max() + 1
        left, right = xs.min(), xs.max() + 1

        # Add an alpha channel so only the region's own pixels stay opaque.
        rgba = cv2.cvtColor(image[top:bottom, left:right], cv2.COLOR_RGB2RGBA)
        rgba[:, :, 3] = region[top:bottom, left:right] * 255
        crops.append(rgba)

    return crops
# 预测分割掩码
def predict_mask(model, image, device, original_size, show_plot=True):
    """Run the segmentation model and return a mask resized to ``original_size``.

    Args:
        model: segmentation network, called as ``model(image)``. Its output is
            indexed on dim 1 as the class/channel dimension.
        image: input tensor, moved to ``device`` before inference; presumably
            ``(1, C, H, W)`` — TODO confirm with callers.
        device: torch device to run on.
        original_size: ``(width, height)`` target size, in the order
            ``cv2.resize`` expects.
        show_plot: if True (the default, as before), show the raw mask with
            matplotlib before resizing. This blocks until the window is closed.

    Returns:
        ``uint8`` mask resized with nearest-neighbour interpolation. It holds
        argmax class indices for multi-channel outputs, or 0/1 for
        single-channel outputs.
    """
    image = image.to(device)
    with torch.no_grad():
        output = model(image)

    # Decide from the output's channel count instead of a ``model.n_classes``
    # attribute, which the UNet built in load_model (``num_classes=...``) is
    # not shown to define.
    if output.shape[1] > 1:
        mask = output.argmax(dim=1)
    else:
        mask = torch.sigmoid(output) > 0.5

    mask_np = mask.cpu().squeeze().numpy()
    print(f"Mask shape before resize: {mask.shape}")
    print(f"Mask data range: [{mask.min()}, {mask.max()}]")
    print(f"Mask value counts: {np.unique(mask_np, return_counts=True)}")

    if show_plot:
        plt.imshow(mask_np, cmap='gray')
        plt.title("Predicted Mask (Before Resize)")
        plt.show()

    return cv2.resize(mask_np.astype(np.uint8), original_size, interpolation=cv2.INTER_NEAREST)

# 主函数
def main(model_path, input_directory, backbone='vgg'):
    """Segment every ``.npy`` k-space file in ``input_directory``.

    For each file (in sorted order):
    1. Reconstruct an image.
    2. Predict a mask with the U-Net.
    3. Split the mask into watershed regions.
    4. Save each region, upscaled, as a transparent PNG in ``out_img/``.
    5. Display the original image and the regions.

    Args:
        model_path: path to the trained weights passed to ``load_model``.
        input_directory: folder holding the k-space ``.npy`` files.
        backbone: ``'vgg'`` or ``'resnet50'``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = load_model(model_path, device, backbone)

    # Output location and upscale factor for the saved crops (adjust as needed).
    output_folder = 'out_img'
    os.makedirs(output_folder, exist_ok=True)
    scale_factor = 1.2

    # os.listdir order is arbitrary; sort for a reproducible processing order.
    file_names = sorted(f for f in os.listdir(input_directory) if f.endswith('.npy'))

    for file_name in file_names:
        # NOTE(review): allow_pickle=True lets a crafted .npy run arbitrary code.
        # Only load trusted files; drop the flag if the k-space is a plain numeric array.
        ksp = np.load(os.path.join(input_directory, file_name), allow_pickle=True)

        original_image = reconstruct_image(ksp)

        # HWC uint8 -> 1xCxHxW float in [0, 1].
        input_image = torch.from_numpy(original_image.transpose((2, 0, 1)) / 255.0).unsqueeze(0).float()

        original_size = (original_image.shape[1], original_image.shape[0])  # (width, height) for cv2
        mask = predict_mask(model, input_image, device, original_size)

        segmented_images = segment_image(original_image, mask)

        stem = os.path.splitext(file_name)[0]
        for i, img in enumerate(segmented_images, start=1):
            new_size = (int(img.shape[1] * scale_factor), int(img.shape[0] * scale_factor))
            resized_img = cv2.resize(img, new_size, interpolation=cv2.INTER_LINEAR)
            # The crops are RGBA. JPEG cannot store alpha and RGB2BGR discards it,
            # so save as PNG in BGRA to keep the transparency from segment_image.
            output_path = os.path.join(output_folder, f'{stem}_segmented_image_{i}.png')
            cv2.imwrite(output_path, cv2.cvtColor(resized_img, cv2.COLOR_RGBA2BGRA))

        cv2.imshow('Original Image', cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR))
        for i, img in enumerate(segmented_images, start=1):
            cv2.imshow(f'Segmented Image {i}', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

        cv2.waitKey(0)
        cv2.destroyAllWindows()




if __name__ == "__main__":
    # Replace these with your own locations:
    #   model_path      - trained U-Net weights (.pth)
    #   input_directory - folder containing the k-space .npy files
    #   backbone        - 'vgg' or 'resnet50'
    model_path = r'D:\Users\Desktop\item\unet-pytorch\model_data\best_epoch_weights.pth'
    input_directory = r'D:\Users\Desktop\item\unet-pytorch\fid_data'
    backbone = 'vgg'
    main(model_path, input_directory, backbone=backbone)