import logging
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from UNetWork.utils.data_loading import BasicDataset
from UNetWork.unet import UNet
import cv2
import matplotlib.pyplot as plt
from PIL import Image

class UNetNetwork:
    """Segment an image with a pre-trained UNet and cut out ellipse-fitted regions.

    On construction, weights are loaded from
    ``./UNetWork/checkpoints/best_epoch_weights.pth`` onto CUDA when available,
    otherwise onto the CPU. :meth:`segment` is the high-level entry point: it
    runs the network on a single-channel image, fits an ellipse to every
    predicted region and returns up to three RGBA crops bucketed by position.

    Side effects: :meth:`run_prediction` writes ``./mask.png`` and
    :meth:`draw_ellipse_on_image` writes ``./cons.png`` (when any contour is
    found) into the current working directory.
    """

    def __init__(self):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.load_model('./UNetWork/checkpoints/best_epoch_weights.pth', device)

    @staticmethod
    def _save_gray_image(array, path):
        """Render a 2-D array as a borderless grayscale PNG at ``path``."""
        fig = plt.figure(figsize=(array.shape[1] / 100, array.shape[0] / 100), dpi=130)
        try:
            plt.imshow(array, cmap='gray')
            plt.axis('off')
            plt.savefig(path, bbox_inches='tight', pad_inches=0, transparent=True)
        finally:
            # Always release the figure, even if saving fails, to avoid leaking
            # matplotlib figures across repeated calls.
            plt.close(fig)

    def draw_ellipse_on_image(self, original_img, mask):
        """Fit ellipses to the mask regions and cut matching RGBA crops from the image.

        Args:
            original_img: image convertible by ``np.array`` to an ``(H, W, 3)``
                uint8 array (``run_prediction`` passes a PIL RGB image).
            mask: 2-D array with the same height/width as ``original_img``;
                non-zero pixels are treated as foreground.

        Returns:
            Tuple ``(rgb1, rgb2, rgb3)``. Each entry is an ``(h, w, 4)`` uint8
            crop of ``original_img`` around one fitted ellipse, with alpha 255
            inside the ellipse and 0 outside, or ``None`` if no ellipse landed in
            that slot. Slots are chosen from the bounding-box centre: only
            centres with ``90 <= cy <= 200`` are kept, then ``cx < 95`` ->
            ``rgb1``, ``cx < 150`` -> ``rgb2``, otherwise ``rgb3``. If several
            ellipses map to the same slot, the last one found wins.
        """
        log = logging.getLogger(__name__)
        original_np = np.array(original_img)
        mask_np = np.array(mask, dtype=np.uint8)
        # Canvas for the fitted ellipse outlines. Sized from the mask (not a
        # fixed 256x256) so other input sizes are neither clipped nor misaligned
        # with the full-size cutting mask built below.
        canvas = np.zeros(mask_np.shape[:2], dtype=np.uint8)

        # Find the outer contour of every foreground region.
        contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        log.debug('len(contours)=%d', len(contours))
        if len(contours) == 0:
            return None, None, None

        for contour in contours:
            # cv2.fitEllipse raises for contours with fewer than 5 points
            # (tiny specks, or axis-aligned rectangles that CHAIN_APPROX_SIMPLE
            # compresses to 4 corners), so such contours are skipped.
            if len(contour) < 5:
                continue
            log.debug('contour = %d', len(contour))
            ellipse = cv2.fitEllipse(contour)
            cv2.ellipse(canvas, ellipse, (255, 255, 255), 1)

        self._save_gray_image(canvas, './cons.png')

        # Re-detect on the outline canvas: each closed ellipse outline becomes
        # one external contour, which is then used as a filled cutting mask.
        contours, _ = cv2.findContours(canvas, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return None, None, None

        # Adding the alpha plane is loop-invariant, so do it once. BGR2BGRA only
        # appends an alpha channel and keeps channel order, so it is equally
        # valid for the RGB array produced from a PIL image.
        base_rgba = cv2.cvtColor(original_np, cv2.COLOR_BGR2BGRA)

        rgb1, rgb2, rgb3 = None, None, None
        for contour in contours:
            # Upright bounding rectangle of the ellipse outline and its centre.
            x, y, w, h = cv2.boundingRect(contour)
            cx = x + w / 2
            cy = y + h / 2

            # Filled contour: opaque inside the ellipse, transparent outside.
            region = np.zeros(base_rgba.shape[:2], dtype=np.uint8)
            cv2.drawContours(region, [contour], -1, 255, thickness=cv2.FILLED)

            crop = base_rgba[y:y + h, x:x + w, :].copy()
            crop[:, :, 3] = np.where(region[y:y + h, x:x + w] > 0, 255, 0)

            # Hard-coded pixel bands; these presumably assume a 256x256 input
            # (the original fixed canvas size) -- TODO confirm for other sizes.
            if 90 <= cy <= 200:
                if cx < 95:
                    rgb1 = crop
                elif cx < 150:
                    rgb2 = crop
                else:
                    rgb3 = crop

        return rgb1, rgb2, rgb3

    def predict_img(self, net, full_img, device, scale_factor=1, out_threshold=0.5):
        """Run ``net`` on one PIL image and return a per-pixel class map.

        The image is preprocessed with ``BasicDataset.preprocess`` at
        ``scale_factor``. The network output is resized back to the image's
        ``(height, width)`` with bilinear interpolation before decoding.

        Returns:
            2-D int64 numpy array: ``argmax`` class indices when
            ``net.n_classes > 1``, otherwise 0/1 from
            ``sigmoid(output) > out_threshold``.
        """
        net.eval()
        # NOTE(review): None is passed as the first argument of preprocess,
        # presumably because it is used as a static helper -- confirm.
        batch = torch.from_numpy(BasicDataset.preprocess(None, full_img, scale_factor, is_mask=False))
        batch = batch.unsqueeze(0).to(device=device, dtype=torch.float32)

        with torch.no_grad():
            logits = net(batch).cpu()
            # PIL's size is (width, height); interpolate expects (height, width).
            logits = F.interpolate(logits, (full_img.size[1], full_img.size[0]), mode='bilinear')
            if net.n_classes > 1:
                mask = logits.argmax(dim=1)
            else:
                mask = torch.sigmoid(logits) > out_threshold

        return mask[0].long().squeeze().numpy()

    def load_model(self, model_path, device, n_classes=2, bilinear=False):
        """Build a 3-channel UNet and load its weights from ``model_path`` onto ``device``.

        Sets ``self.net`` and ``self.mask_values``. The latter is the
        checkpoint's ``'mask_values'`` entry, which is popped before loading
        the state dict and defaults to ``[0, 1]``.
        """
        self.net = UNet(n_channels=3, n_classes=n_classes, bilinear=bilinear)
        self.net.to(device=device)
        try:
            # weights_only=True silences torch's FutureWarning and refuses to
            # unpickle arbitrary objects.
            state_dict = torch.load(model_path, map_location=device, weights_only=True)
        except TypeError:
            # Older torch versions do not accept weights_only. SECURITY: this
            # fallback unpickles arbitrary objects -- only load trusted checkpoints.
            state_dict = torch.load(model_path, map_location=device)
        self.mask_values = state_dict.pop('mask_values', [0, 1])
        self.net.load_state_dict(state_dict)

    def run_prediction(self, input_images, model_path,
                       mask_threshold=0.5, scale=0.5,
                       n_classes=2, bilinear=False):
        """Segment a single-channel image and return ellipse-cropped regions.

        Args:
            input_images: array expected to be 2-D (real or complex). It is
                replicated to 3 channels and its magnitude is scaled by 255.
            model_path, n_classes, bilinear: unused -- the model is loaded in
                ``__init__``; kept for call-site compatibility.
            mask_threshold: sigmoid threshold, only used by single-class nets.
            scale: preprocessing scale factor passed to :meth:`predict_img`.

        Returns:
            ``(rgb1, rgb2, rgb3)`` as produced by :meth:`draw_ellipse_on_image`.

        Also writes the predicted mask to ``./mask.png``.
        """
        logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logging.info(f'Using device {device}')

        magnitude = np.abs(np.stack([input_images] * 3, axis=-1))

        # Deliberately unclipped: per the original author this keeps the
        # predicted masks from merging together. NOTE(review): values above 1.0
        # overflow the uint8 cast, whose result numpy leaves platform-dependent
        # -- confirm this preprocessing is really intended.
        net_input = Image.fromarray((magnitude * 255).astype(np.uint8))
        mask = self.predict_img(net=self.net, full_img=net_input, scale_factor=scale,
                                out_threshold=mask_threshold, device=device)
        # Clipped copy used for the RGBA crops, so overflow cannot cause dark spots.
        display_img = Image.fromarray((np.clip(magnitude, 0, 1) * 255).astype(np.uint8))

        self._save_gray_image(mask, './mask.png')

        return self.draw_ellipse_on_image(display_img, mask)

    def kspace2image(self, kspace, norm=None):
        """Convert centred k-space data to a centred complex image.

        Computes ``fftshift(ifft2(ifftshift(kspace)))`` over the last two axes.
        The input shift must be ``ifftshift`` so that the centred DC sample
        lands on index 0 for odd sizes too; for even sizes (e.g. 256x256) this
        is identical to shifting with ``fftshift``. Restricting the shifts to
        the last two axes keeps leading (batch) axes untouched.
        """
        axes = (-2, -1)
        centred = np.fft.ifftshift(kspace, axes=axes)
        return np.fft.fftshift(np.fft.ifft2(centred, norm=norm), axes=axes)

    def segment(self, data):
        """Segment ``data`` and return the ``[rgb1, rgb2, rgb3]`` crops.

        ``data`` is passed straight to :meth:`run_prediction` as the image; it
        is not converted from k-space here. Presumably callers holding raw
        k-space apply :meth:`kspace2image` first -- confirm.
        """
        model_path = './checkpoints/checkpoint_epoch10.pth'  # not used by run_prediction
        scale_factor = 0.5
        mask_threshold = 0.5

        rgb1, rgb2, rgb3 = self.run_prediction(data, model_path, mask_threshold, scale_factor)
        return [rgb1, rgb2, rgb3]