import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import os
from PIL import Image
from matplotlib import pyplot as plt
import logging
from scipy.fft import fft2, ifft2, fftshift, ifftshift
from scipy.ndimage import rotate
# 假设这是你的 U-Net 模型定义
from nets.unet import Unet as UNet
# 导入 VGG16 模型定义
from nets.vgg import VGG16




# Load the trained segmentation model.
def load_model(model_path, device, backbone='vgg'):
    """Build a 4-class U-Net with the given backbone and load trained weights.

    Args:
        model_path: path to a ``state_dict`` file saved with ``torch.save``.
        device: device the model is moved to and the weights are mapped onto.
        backbone: encoder backbone, ``'vgg'`` or ``'resnet50'``.

    Returns:
        The model on ``device``, switched to evaluation mode.

    Raises:
        ValueError: if ``backbone`` is not one of the supported names.
    """
    # Validate up front; both supported backbones share the same constructor
    # call, differing only in the backbone name.
    if backbone not in ('vgg', 'resnet50'):
        raise ValueError('Unsupported backbone - `{}`, Use vgg, resnet50.'.format(backbone))
    model = UNet(num_classes=4, pretrained=False, backbone=backbone)
    model.to(device)
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    print(f"Loaded state_dict keys: {state_dict.keys()}")
    # Default (strict) loading: mismatched keys raise instead of being ignored.
    model.load_state_dict(state_dict)
    print(f"Model's state_dict keys: {model.state_dict().keys()}")
    model.eval()
    return model
# Rotate k-space data in the image domain.
def rotate_kspace_via_image(kspace_data, angle):
    """Rotate k-space data by rotating its reconstructed complex image.

    The image is reconstructed with the same convention as
    ``reconstruct_image``: a 2-D inverse FFT on a 256x256 grid (``ifft2``
    crops or zero-pads), followed by ``fftshift`` along axis 0 only. The
    complex image is rotated in-plane and transformed back with the exact
    inverse of that convention. The result can therefore be fed straight to
    ``reconstruct_image``; with ``angle == 0`` it reproduces the input
    k-space (resized to 256x256).

    Args:
        kspace_data: 2-D complex k-space array.
        angle: rotation angle in degrees, as understood by
            ``scipy.ndimage.rotate``.

    Returns:
        256x256 complex k-space array in the same (uncentred) layout as the input.
    """
    # Same reconstruction as reconstruct_image: shift along axis 0 only.
    image = fftshift(ifft2(kspace_data, (256, 256)), axes=0)

    # Rotate real and imaginary parts separately so the phase is preserved;
    # reshape=False keeps the 256x256 grid.
    rotated_image = (rotate(image.real, angle, reshape=False)
                     + 1j * rotate(image.imag, angle, reshape=False))

    # Exact inverse of the reconstruction above: undo the axis-0 shift, then
    # forward FFT. No fftshift on the result, because reconstruct_image applies
    # ifft2 directly (without an ifftshift) to its k-space input.
    return fft2(ifftshift(rotated_image, axes=0))


# Reconstruct a displayable image from k-space data.
def reconstruct_image(kspace_data):
    """Reconstruct an 8-bit RGB magnitude image from 2-D k-space data.

    The k-space is inverse-FFT'd on a 256x256 grid (``ifft2`` crops or
    zero-pads to that size), shifted along axis 0 only, and the magnitude is
    min-max scaled to 0..255 and replicated into three identical channels.

    Args:
        kspace_data: 2-D complex k-space array.

    Returns:
        ``uint8`` array of shape (256, 256, 3). A constant-magnitude image
        (e.g. all-zero k-space) maps to all zeros instead of NaN-derived values.
    """
    im = fftshift(ifft2(kspace_data, (256, 256)), axes=0)
    # Magnitude gives a real-valued image.
    magnitude_image = np.abs(im)

    lowest = magnitude_image.min()
    value_range = magnitude_image.max() - lowest
    if value_range > 0:
        normalized_image = (magnitude_image - lowest) / value_range
    else:
        # Avoid 0/0 -> NaN, which has no meaningful uint8 value.
        normalized_image = np.zeros_like(magnitude_image)

    uint8_image = (normalized_image * 255).astype(np.uint8)
    # Gray -> RGB by replicating the single channel (same as COLOR_GRAY2RGB).
    return np.stack([uint8_image] * 3, axis=-1)


# Preprocess a PIL image into a model-ready tensor.
def preprocess_image(image, scale_factor=0.5):
    """Scale a PIL image, snap its sides to multiples of 32, return a tensor.

    Args:
        image: PIL image; ``np.array`` of it is assumed to be (H, W, C) —
            TODO confirm for grayscale/RGBA inputs.
        scale_factor: resize factor applied before snapping each side down
            to a multiple of 32.

    Returns:
        float tensor of shape (1, C, H', W') with values in [0, 1].

    Also shows the preprocessed image with matplotlib (blocking).
    """
    print(f"Original image size: {image.size}")  # PIL size is (width, height)
    # U-Net inputs are typically required to have sides divisible by 32.
    target_size = tuple(
        int(side * scale_factor) // 32 * 32 for side in (image.width, image.height)
    )
    image = image.resize(target_size)
    print(f"Resized image size: {image.size}")

    pixels = np.array(image)
    print(f"Image shape after converting to numpy: {pixels.shape}")
    chw = pixels.transpose((2, 0, 1))
    print(f"Image shape after transpose: {chw.shape}")
    chw = chw / 255.0
    print(f"Image data range after normalization: [{chw.min()}, {chw.max()}]")

    plt.imshow(chw.transpose((1, 2, 0)))
    plt.title("Preprocessed Image")
    plt.show()

    return torch.from_numpy(chw).unsqueeze(0).float()


# Predict a segmentation mask.
def predict_mask(model, image, device, original_size):
    """Run the segmentation model on one input and return a resized label mask.

    Args:
        model: segmentation network; class scores are read along dim 1 of its
            output (presumably (batch, classes, H, W) — TODO confirm).
        image: input tensor for the model (``main`` passes a batch of one).
        device: device to run inference on.
        original_size: ``(width, height)`` target size, in ``cv2.resize`` order.

    Returns:
        ``uint8`` numpy mask resized with nearest-neighbour interpolation:
        argmax class indices for a multi-class model, 0/1 for a single-channel one.

    Also prints mask statistics and shows the mask with matplotlib (blocking).
    """
    image = image.to(device)
    with torch.no_grad():
        output = model(image)
        # Prefer the model's declared class count; fall back to the output's
        # channel dimension for models that do not expose ``n_classes``.
        n_classes = getattr(model, 'n_classes', output.shape[1])
        if n_classes > 1:
            mask = output.argmax(dim=1)
        else:
            mask = torch.sigmoid(output) > 0.5
    print(f"Mask shape before resize: {mask.shape}")
    print(f"Mask data range: [{mask.min()}, {mask.max()}]")
    # Convert to numpy once and reuse it for the stats, the plot and the resize.
    mask = mask.cpu().squeeze().numpy()
    print(f"Mask value counts: {np.unique(mask, return_counts=True)}")
    plt.imshow(mask, cmap='gray')
    plt.title("Predicted Mask (Before Resize)")
    plt.show()
    # Nearest-neighbour keeps labels discrete (no blended class values).
    return cv2.resize(mask.astype(np.uint8), original_size, interpolation=cv2.INTER_NEAREST)


# Crop each predicted object out of the image and annotate it with a fitted ellipse.
def segment_image(original_image, mask):
    """Crop every connected object in ``mask`` out of ``original_image``.

    Every connected component of non-zero mask pixels (all classes together)
    becomes its own crop:
    - the crop gets a 20% margin around the object's bounding box;
    - an alpha channel makes only the object's own pixels opaque;
    - the outline of an ellipse fitted to the object is drawn on it.

    Args:
        original_image: RGB image (anything ``np.array`` accepts), (H, W, 3).
        mask: ``uint8`` label mask of shape (H, W); 0 is background.

    Returns:
        List of RGBA ``uint8`` crops. Objects whose extreme pixels are fewer
        than 10 pixels apart in either direction are skipped.
    """
    print(f"Original image shape: {np.array(original_image).shape}")
    print(f"Mask shape: {mask.shape}")
    original_image = np.array(original_image)
    img_h, img_w = original_image.shape[:2]

    segmented_images = []
    num_labels, labels = cv2.connectedComponents(mask)
    for label in range(1, num_labels):  # label 0 is the background
        target_mask = (labels == label).astype(np.uint8)
        rows, cols = np.where(target_mask == 1)
        if rows.size == 0:
            continue

        # Bounding box of the object (inclusive pixel coordinates).
        y1, y2 = rows.min(), rows.max()
        x1, x2 = cols.min(), cols.max()

        # Span between extreme pixels; drives both the size filter and the margin.
        height = y2 - y1
        width = x2 - x1
        if height < 10 or width < 10:
            continue
        padding_y = int(height * 0.2)
        padding_x = int(width * 0.2)

        # Slice ends are exclusive: add 1 so the last object row/column is kept
        # and the bottom/right margins match the top/left ones.
        new_y1 = max(0, y1 - padding_y)
        new_y2 = min(img_h, y2 + 1 + padding_y)
        new_x1 = max(0, x1 - padding_x)
        new_x2 = min(img_w, x2 + 1 + padding_x)

        print(f"Cropping coordinates for object {label}: y1={new_y1}, y2={new_y2}, x1={new_x1}, x2={new_x2}")
        cropped_image = original_image[new_y1:new_y2, new_x1:new_x2]
        crop_mask = target_mask[new_y1:new_y2, new_x1:new_x2]

        # Transparent everywhere except on this object's pixels.
        alpha = np.where(crop_mask == 0, 0, 255).astype(np.uint8)
        cropped_image = cv2.cvtColor(cropped_image, cv2.COLOR_RGB2RGBA)
        cropped_image[:, :, 3] = alpha

        contours, _ = cv2.findContours(crop_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if contours:
            contour = max(contours, key=cv2.contourArea)
            # fitEllipse needs at least 5 points; CHAIN_APPROX_SIMPLE can reduce
            # a rectangular blob to its 4 corners.
            if len(contour) >= 5:
                ellipse = cv2.fitEllipse(contour)
                # 4-component colour: a 3-tuple leaves alpha at 0, making the
                # outline transparent on this RGBA crop.
                cv2.ellipse(cropped_image, ellipse, (255, 0, 0, 255), 1)

        segmented_images.append(cropped_image)

    return segmented_images


# Entry point: reconstruct, segment and export every k-space file in a directory.
def main(model_path, input_directory, backbone='vgg', output_folder='outim'):
    """Process every ``.npy`` k-space file in ``input_directory``.

    For each file (in sorted order):
    - reconstruct the magnitude image and run the U-Net once;
    - save the mask figure to ``./mask.png`` (overwritten on every file);
    - crop each detected object and write the crops as PNGs to ``output_folder``;
    - display the images, blocking on key presses.

    Args:
        model_path: path to the trained weights (see ``load_model``).
        input_directory: directory scanned for ``.npy`` files.
        backbone: ``'vgg'`` or ``'resnet50'``.
        output_folder: directory for the cropped PNGs (created if missing).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = load_model(model_path, device, backbone)

    os.makedirs(output_folder, exist_ok=True)

    # Only .npy k-space files; sorted so the processing order is deterministic.
    file_names = sorted(f for f in os.listdir(input_directory) if f.endswith('.npy'))

    for file_name in file_names:
        # NOTE(review): allow_pickle=True lets a crafted .npy run code on load;
        # only use trusted files, or drop it if the data are plain numeric arrays.
        ksp = np.load(os.path.join(input_directory, file_name), allow_pickle=True)

        # To rotate the data first, pass rotate_kspace_via_image(ksp, angle)
        # to reconstruct_image instead of ksp.
        original_image = reconstruct_image(ksp)
        cv2.imshow('Reconstructed Image', cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR))
        cv2.waitKey(0)

        # HWC uint8 RGB -> 1xCxHxW float tensor in [0, 1].
        input_image = torch.from_numpy(original_image.transpose((2, 0, 1)) / 255.0).unsqueeze(0).float()
        original_size = (original_image.shape[1], original_image.shape[0])  # (width, height) for cv2
        # Run the model once; this mask is used for both saving and cropping.
        mask = predict_mask(model, input_image, device, original_size)

        # Save the mask as a borderless figure.
        fig = plt.figure(figsize=(mask.shape[1] / 100, mask.shape[0] / 100), dpi=130)
        plt.imshow(mask, cmap='gray')
        plt.axis('off')
        plt.savefig('./mask.png', bbox_inches='tight', pad_inches=0, transparent=True)
        plt.close(fig)

        segmented_images = segment_image(original_image, mask)

        # Save every object crop; scale_factor enlarges them if set above 1.
        scale_factor = 1
        stem = os.path.splitext(file_name)[0]
        for i, segment in enumerate(segmented_images, start=1):
            print(f"Shape of segmented image {i}: {segment.shape}")
            new_width = int(segment.shape[1] * scale_factor)
            new_height = int(segment.shape[0] * scale_factor)
            resized = cv2.resize(segment, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
            output_path = os.path.join(output_folder, f'{stem}_segmented_image_{i}.png')
            # Crops are RGBA, but cv2.imwrite expects BGRA channel order.
            cv2.imwrite(output_path, cv2.cvtColor(resized, cv2.COLOR_RGBA2BGRA))

        cv2.imshow('Original Image', cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR))
        for i, segment in enumerate(segmented_images, start=1):
            cv2.imshow(f'Segmented Image {i}', cv2.cvtColor(segment, cv2.COLOR_RGB2BGR))

        cv2.waitKey(0)
        cv2.destroyAllWindows()


if __name__ == "__main__":
    # Adjust these settings for your environment.
    settings = dict(
        # Trained model weights.
        model_path=r'D:\Users\Desktop\item\unet-pytorch\model_data\EGG_WHITE.pth',
        # Folder containing the .npy k-space files.
        input_directory=r'D:\Users\Desktop\item\unet-pytorch\fid_data\fid',
        # Encoder backbone: 'vgg' or 'resnet50'.
        backbone='vgg',
    )
    main(**settings)
