import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import os
from PIL import Image
from matplotlib import pyplot as plt
# 假设这是你的 U-Net 模型定义
from nets.unet import Unet as UNet
# 导入 VGG16 模型定义
from nets.vgg import VGG16


# 加载模型
def load_model(model_path, device, backbone='vgg'):
    """Build a 4-class U-Net for ``backbone`` and load weights from ``model_path``.

    Args:
        model_path: Path to a saved ``state_dict`` file.
        device: ``torch.device`` that both the model and the loaded tensors are placed on.
        backbone: Either ``'vgg'`` or ``'resnet50'``.

    Returns:
        The model with the weights loaded, switched to eval mode.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """
    supported_backbones = ('vgg', 'resnet50')
    if backbone not in supported_backbones:
        raise ValueError('Unsupported backbone - `{}`, Use vgg, resnet50.'.format(backbone))

    # Both supported backbones use the same constructor arguments apart from the name.
    net = UNet(num_classes=4, pretrained=False, backbone=backbone)
    net.to(device)

    # weights_only=True restricts unpickling to tensors and plain containers.
    weights = torch.load(model_path, map_location=device, weights_only=True)
    print(f"Loaded state_dict keys: {weights.keys()}")
    net.load_state_dict(weights)
    print(f"Model's state_dict keys: {net.state_dict().keys()}")

    net.eval()
    return net


# 预处理图像
def preprocess_image(image, scale_factor=0.5, *, show=True):
    """Resize a PIL image and convert it to a normalised (1, 3, H, W) float tensor.

    The image is scaled by ``scale_factor``. Each side is then rounded down to a
    multiple of 32, because U-Net typically needs sizes divisible by 32. Each side is
    kept at least 32 px so very small inputs do not collapse to a zero size.

    Args:
        image: PIL image. Non-RGB modes (e.g. 'L', 'RGBA') are converted to RGB
            first so the tensor always has 3 channels.
        scale_factor: Multiplier applied to width and height before rounding.
        show: If True (the default, matching the previous behaviour), display the
            preprocessed image with matplotlib. ``plt.show()`` may block until the
            window is closed.

    Returns:
        ``torch.float32`` tensor of shape (1, 3, H, W) with values in [0, 1].
    """
    print(f"Original image size: {image.size}")
    if image.mode != 'RGB':
        # Grayscale would break the (2, 0, 1) transpose below; RGBA would give 4 channels.
        image = image.convert('RGB')

    # Round each side down to a multiple of 32, but never below 32 (0 would make resize fail).
    width = max(32, int(image.width * scale_factor) // 32 * 32)
    height = max(32, int(image.height * scale_factor) // 32 * 32)
    image = image.resize((width, height))
    print(f"Resized image size: {image.size}")

    array = np.array(image)
    print(f"Image shape after converting to numpy: {array.shape}")
    array = array.transpose((2, 0, 1))  # HWC -> CHW
    print(f"Image shape after transpose: {array.shape}")
    array = array / 255.0
    print(f"Image data range after normalization: [{array.min()}, {array.max()}]")

    if show:
        plt.imshow(array.transpose((1, 2, 0)))
        plt.title("Preprocessed Image")
        plt.show()

    return torch.from_numpy(array).unsqueeze(0).float()


# 预测分割掩码
def predict_mask(model, image, device, original_size):
    """Run ``model`` on a preprocessed image and return a label mask at ``original_size``.

    Args:
        model: Segmentation network, called as ``model(image)``. Its output is assumed
            to be logits shaped (batch, channels, H, W); only the first batch item is used.
        image: Input tensor, e.g. the (1, 3, H, W) tensor from ``preprocess_image``.
        device: Device to run inference on.
        original_size: ``(width, height)`` to resize the mask to (cv2.resize ordering).

    Returns:
        ``uint8`` array of shape (height, width). With more than one output channel,
        each pixel holds the argmax class index. With a single channel, it holds 0/1
        from a 0.5 sigmoid threshold.
    """
    image = image.to(device)
    with torch.no_grad():
        output = model(image)

    # Take the class count from the logits rather than a model attribute:
    # not every U-Net implementation exposes ``n_classes``.
    num_classes = output.shape[1]
    if num_classes > 1:
        mask = output.argmax(dim=1)
    else:
        mask = torch.sigmoid(output[:, 0]) > 0.5
    mask = mask.to(torch.uint8)

    print(f"Mask shape before resize: {mask.shape}")
    print(f"Mask data range: [{mask.min()}, {mask.max()}]")
    # Index the batch explicitly; squeeze() would also drop size-1 spatial dimensions.
    mask_np = mask[0].cpu().numpy()
    print(f"Mask value counts: {np.unique(mask_np, return_counts=True)}")

    plt.imshow(mask_np, cmap='gray')
    plt.title("Predicted Mask (Before Resize)")
    plt.show()

    # Nearest-neighbour keeps the class ids intact (no interpolated in-between labels).
    return cv2.resize(mask_np, original_size, interpolation=cv2.INTER_NEAREST)


# 从原图中分割出预测信息，将所有目标分开显示
def segment_image(original_image, mask):
    """Cut every connected foreground region of ``mask`` out of ``original_image``.

    Each region is cropped to its bounding box plus a margin of 10% of the region's
    width/height on each side, clamped to the image. It is returned as an RGBA array
    whose alpha is 255 on the region's own pixels and 0 elsewhere. As a result, other
    objects that fall inside the margin are transparent.

    Args:
        original_image: PIL image or array of shape (H, W, 3), assumed RGB.
        mask: 2-D array with the same H x W. Any non-zero pixel counts as foreground;
            different class ids are not distinguished.

    Returns:
        List of ``uint8`` RGBA arrays, one per connected component, in label order.

    Raises:
        ValueError: If the mask and image spatial sizes differ.
    """
    print(f"Original image shape: {np.array(original_image).shape}")
    print(f"Mask shape: {mask.shape}")
    original_image = np.array(original_image)
    img_h, img_w = original_image.shape[:2]

    # connectedComponentsWithStats needs a single-channel 8-bit image. Normalise so any
    # non-zero value (class ids, or a bool mask) is foreground.
    binary = (np.asarray(mask) != 0).astype(np.uint8)
    if binary.shape != (img_h, img_w):
        raise ValueError(f"Mask shape {binary.shape} does not match image size {(img_h, img_w)}")

    # Bounding boxes come from the stats table, so no full-image scan per label is needed.
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary)

    segmented_images = []
    for label in range(1, num_labels):  # label 0 is the background
        x1 = int(stats[label, cv2.CC_STAT_LEFT])
        y1 = int(stats[label, cv2.CC_STAT_TOP])
        width = int(stats[label, cv2.CC_STAT_WIDTH])
        height = int(stats[label, cv2.CC_STAT_HEIGHT])

        # 10% margin around the object.
        padding_y = int(height * 0.1)
        padding_x = int(width * 0.1)

        # Exclusive slice ends. The object's last row is y1 + height - 1, so
        # y1 + height keeps it inside the crop; 1-px objects no longer give empty crops.
        new_y1 = max(0, y1 - padding_y)
        new_y2 = min(img_h, y1 + height + padding_y)
        new_x1 = max(0, x1 - padding_x)
        new_x2 = min(img_w, x1 + width + padding_x)

        print(f"Cropping coordinates for object {label}: y1={new_y1}, y2={new_y2}, x1={new_x1}, x2={new_x2}")
        cropped_image = original_image[new_y1:new_y2, new_x1:new_x2]

        # Opaque only where this component's own pixels are; everything else transparent.
        alpha = np.where(labels[new_y1:new_y2, new_x1:new_x2] == label, 255, 0).astype(np.uint8)
        rgba = cv2.cvtColor(cropped_image, cv2.COLOR_RGB2RGBA)
        rgba[:, :, 3] = alpha

        segmented_images.append(rgba)

    return segmented_images


# 主函数
def main(model_path, image_path, backbone='vgg'):
    """Segment ``image_path`` with the U-Net at ``model_path``, then save and show each object.

    Every connected object is written as a transparent-background PNG to ``out_img/``.
    The original image and every object are then shown in OpenCV windows, which stay
    open until a key is pressed.

    Args:
        model_path: Path to the saved model ``state_dict``.
        image_path: Path to the input image.
        backbone: ``'vgg'`` or ``'resnet50'``, forwarded to ``load_model``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = load_model(model_path, device, backbone)

    # convert() returns a new image, so the file handle can be closed immediately.
    with Image.open(image_path) as opened:
        original_image = opened.convert('RGB')
    original_size = (original_image.width, original_image.height)  # (w, h), as cv2.resize expects

    input_image = preprocess_image(original_image)
    mask = predict_mask(model, input_image, device, original_size)
    segmented_images = segment_image(original_image, mask)

    output_folder = 'out_img'
    os.makedirs(output_folder, exist_ok=True)

    # Enlarge the saved crops by this factor (1 = keep original size).
    scale_factor = 1
    for i, img in enumerate(segmented_images, start=1):
        print(f"Shape of segmented image {i}: {img.shape}")
        new_width = max(1, int(img.shape[1] * scale_factor))
        new_height = max(1, int(img.shape[0] * scale_factor))
        resized_img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_LINEAR)

        # PNG keeps the alpha channel. The crops are RGBA (PIL channel order), but
        # cv2.imwrite treats 4-channel data as BGRA, so swap before writing.
        output_path = os.path.join(output_folder, f'segmented_image_{i}.png')
        if not cv2.imwrite(output_path, cv2.cvtColor(resized_img, cv2.COLOR_RGBA2BGRA)):
            print(f"Failed to write {output_path}")

    cv2.imshow('Original Image', cv2.cvtColor(np.array(original_image), cv2.COLOR_RGB2BGR))
    for i, img in enumerate(segmented_images, start=1):
        cv2.imshow(f'Segmented Image {i}', cv2.cvtColor(img, cv2.COLOR_RGBA2BGR))

    cv2.waitKey(0)
    cv2.destroyAllWindows()


if __name__ == "__main__":
    import argparse

    # The defaults reproduce the original hard-coded setup. Override them on the
    # command line instead of editing this file.
    parser = argparse.ArgumentParser(description="Segment objects from an image with a trained U-Net.")
    parser.add_argument('--model-path',
                        default=r'D:\Users\Desktop\item\unet-pytorch\model_data\Egg white_300.pth',
                        help='path to the trained model state_dict')
    parser.add_argument('--image-path',
                        default=r'D:\Users\Desktop\item\unet-pytorch\img\rgb_1744715221.jpg',
                        help='path to the input image')
    parser.add_argument('--backbone', default='vgg', choices=('vgg', 'resnet50'),
                        help='encoder backbone the model was trained with')
    args = parser.parse_args()
    main(args.model_path, args.image_path, args.backbone)
