import cv2
import numpy as np
import os

# 全局变量
drawing = False  # 是否正在绘制
points = []      # 存储多边形的顶点
germ_mask = None  # 存储手动标记的胚盘掩码

def fit_ellipse_to_mask(mask, min_area=1000):
    """对掩码进行椭圆拟合，返回拟合后的掩码"""
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours: return mask
    largest_contour = max(contours, key=lambda c: cv2.contourArea(c))
    if cv2.contourArea(largest_contour) < min_area: return mask
    if len(largest_contour) >= 5:
        ellipse = cv2.fitEllipse(largest_contour)
        fitted_mask = np.zeros_like(mask)
        cv2.ellipse(fitted_mask, ellipse, 255, -1)
        return fitted_mask
    else:
        return mask

def filter_small_regions(mask, min_area=570):
    """过滤掩码中面积小于指定阈值的区域"""
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    filtered_mask = np.zeros_like(mask)
    
    for contour in contours:
        area = cv2.contourArea(contour)
        if area >= min_area:
            cv2.drawContours(filtered_mask, [contour], -1, 255, -1)
    
    return filtered_mask

def keep_center_region(mask, min_area=1800):
    """保留掩码中最靠近中心的区域"""
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours: return mask
    h, w = mask.shape
    img_center = (w//2, h//2)
    best_contour = None
    min_distance = float('inf')
    for cnt in contours:
        if cv2.contourArea(cnt) < min_area: continue
        M = cv2.moments(cnt)
        if M["m00"] == 0: continue
        cx, cy = int(M["m10"]/M["m00"]), int(M["m01"]/M["m00"])
        distance = np.sqrt((cx-img_center[0])**2 + (cy-img_center[1])**2)
        if distance < min_distance:
            min_distance = distance
            best_contour = cnt
    if best_contour is None: return mask
    filtered_mask = np.zeros_like(mask)
    cv2.drawContours(filtered_mask, [best_contour], -1, 255, -1)
    return filtered_mask

def detect_egg_region(frame, min_area=10000, max_area=100000, visualize=True):
    """选择最大轮廓并进行椭圆拟合，将椭圆外像素置为30并可视化"""
    # 图像预处理（保持原有逻辑不变）
    lab = cv2.cvtColor(frame, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
    cl = clahe.apply(l)
    limg = cv2.merge((cl,a,b))
    enhanced_frame = cv2.cvtColor(limg, cv2.COLOR_LAB2BGR)
    
    gray = cv2.cvtColor(enhanced_frame, cv2.COLOR_BGR2GRAY)
    thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, 
                                 cv2.THRESH_BINARY_INV, 11, 2)
    blur = cv2.GaussianBlur(thresh, (5, 5), 0)
    edges = cv2.Canny(blur, 20, 80)
    
    # 形态学操作闭合边缘
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    closed_edges = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, kernel)
    
    # 寻找轮廓并筛选最大轮廓
    contours, _ = cv2.findContours(closed_edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return np.full_like(frame, 37), None, edges  # 返回全30图像
    
    egg_contour = max(contours, key=lambda c: cv2.contourArea(c))
    area = cv2.contourArea(egg_contour)
    
    if not (min_area < area < max_area):
        return np.full_like(frame, 37), None, edges  # 返回全30图像
    
    # 对最大轮廓进行椭圆拟合（需至少5个顶点）
    if len(egg_contour) >= 5:
        ellipse = cv2.fitEllipse(egg_contour)  # 拟合椭圆
        # 创建椭圆掩码
        egg_mask = np.zeros_like(gray)
        cv2.ellipse(egg_mask, ellipse, 255, -1)  # -1 表示填充椭圆内部
    else:
        # 轮廓顶点不足时，使用原轮廓创建掩码
        egg_mask = np.zeros_like(gray)
        cv2.drawContours(egg_mask, [egg_contour], -1, 255, -1)
    
    # 形态学闭操作优化掩码
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    egg_mask = cv2.morphologyEx(egg_mask, cv2.MORPH_CLOSE, kernel)
    
    # 将椭圆外的像素置为30
    masked_frame = frame.copy()
    masked_frame[egg_mask == 0] = 37
    
    # 可视化：显示椭圆拟合结果
    if visualize:
        cv2.imshow("Egg Ellipse Mask", egg_mask)  # 显示椭圆掩码
        cv2.imshow("Masked Frame with Ellipse", masked_frame)
    
    return masked_frame

def manual_germ_marking(frame, frame_num):
    """手动绘制闭合多边形作为胚盘掩码"""
    global drawing, points, germ_mask
    
    drawing = False
    points = []
    frame_copy = frame.copy()
    temp_img = frame.copy()
    germ_mask = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
    
    def mouse_callback(event, x, y, flags, param):
        global drawing, points, germ_mask
        
        if event == cv2.EVENT_LBUTTONDOWN:
            drawing = True
            points.append((x, y))
            print(f"添加顶点: ({x}, {y})")
            
        elif event == cv2.EVENT_MOUSEMOVE and drawing:
            # 更新临时图像，显示当前线段
            temp_img = frame_copy.copy()
            if len(points) > 0:
                cv2.line(temp_img, points[-1], (x, y), (0, 255, 255), 2)
                # 显示所有已绘制的线段
                for i in range(len(points) - 1):
                    cv2.line(temp_img, points[i], points[i+1], (0, 255, 255), 2)
            cv2.imshow(param, temp_img)
            
        elif event == cv2.EVENT_LBUTTONUP:
            drawing = False
            # 绘制最后一条线段
            temp_img = frame_copy.copy()
            if len(points) > 0:
                cv2.line(temp_img, points[-1], (x, y), (0, 255, 255), 2)
                # 显示所有已绘制的线段
                for i in range(len(points) - 1):
                    cv2.line(temp_img, points[i], points[i+1], (0, 255, 255), 2)
            cv2.imshow(param, temp_img)
    
    window_name = f'Manual Germ Marking - Frame {frame_num}'
    cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
    cv2.setMouseCallback(window_name, mouse_callback, param=window_name)
    
    print(f"\n=== 手动标记第 {frame_num} 帧 ===")
    print("操作说明：")
    print("1. 点击鼠标左键添加多边形顶点")
    print("2. 按 'c' 键闭合多边形并生成掩码")
    print("3. 按 'r' 键重置绘制")
    print("4. 按 'q' 键取消标记并使用自动检测结果")
    
    while True:
        # 显示当前绘制状态
        display_img = temp_img.copy()
        if len(points) > 1:
            # 显示所有已绘制的线段
            for i in range(len(points) - 1):
                cv2.line(display_img, points[i], points[i+1], (0, 255, 255), 2)
            # 显示闭合线段（临时连接最后一点和第一点）
            cv2.line(display_img, points[-1], points[0], (0, 255, 0), 1)
        
        cv2.imshow(window_name, display_img)
        key = cv2.waitKey(1) & 0xFF
        
        if key == ord('c'):  # 闭合多边形并生成掩码
            if len(points) >= 3:  # 至少需要3个点形成闭合多边形
                # 将点列表转换为numpy数组
                contour = np.array(points, np.int32)
                contour = contour.reshape((-1, 1, 2))
                
                # 绘制并填充多边形
                cv2.drawContours(germ_mask, [contour], -1, 255, -1)
                
                # 在原图上显示最终掩码效果
                result = frame.copy()
                result[germ_mask > 0] = (0, 255, 255)  # 黄色显示胚盘区域
                cv2.imshow(window_name, result)
                print("掩码生成完成，按其他继续...")
                cv2.waitKey(0)
                cv2.destroyWindow(window_name)
                return germ_mask                
            else:
                print("至少需要3个顶点才能闭合多边形！")

        elif key == ord('r'):  # 重置绘制
            print("重置绘制...")
            points = []
            temp_img = frame.copy()
            germ_mask = np.zeros_like(germ_mask)


        elif key == ord('q') or key == 27:  # 取消标记
            print("取消标记，使用自动检测结果")
            cv2.destroyWindow(window_name)
            return None

def process_frame(frame, 
                  background_thresh=40,  
                  egg_white_thresh=80,   
                  yolk_thresh=200,      
                  germ_thresh=180,      
                  alpha=0.2,
                  visualize=True,
                  use_manual_germ=False,
                  manual_germ_mask=None):
    """使用固定阈值处理有蛋清的图像"""
    # 转换为灰度图
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    
    # 预处理：高斯模糊减少噪点
    gray = cv2.GaussianBlur(gray, (5, 5), 0)
    
    # 增强对比度
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    gray = clahe.apply(gray)
    
    # 创建背景掩码（最暗区域）
    mask_background = cv2.threshold(gray, background_thresh, 255, cv2.THRESH_BINARY_INV)[1]
    
    # 固定阈值分割（严格按照：背景 < 蛋清 < 蛋黄 < 胚盘）
    mask_egg_white = cv2.inRange(gray, background_thresh + 1, egg_white_thresh)
    mask_yolk = cv2.inRange(gray, egg_white_thresh + 1, yolk_thresh)
    mask_germ = cv2.threshold(gray, germ_thresh, 255, cv2.THRESH_BINARY)[1]
    mask_germ = filter_small_regions(mask_germ)
    
    # 形态学优化
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    mask_egg_white = cv2.morphologyEx(mask_egg_white, cv2.MORPH_CLOSE, kernel, iterations=2)
    mask_yolk = cv2.morphologyEx(mask_yolk, cv2.MORPH_CLOSE, kernel)
    
    # 确保掩码之间没有重叠区域
    mask_yolk = cv2.bitwise_and(mask_yolk, cv2.bitwise_not(mask_egg_white))
    
    # 只保留蛋清中心区域
    mask_egg_white_center = keep_center_region(mask_egg_white, min_area=5000)
    
    # 椭圆拟合
    mask_egg_white_fit = fit_ellipse_to_mask(mask_egg_white_center, min_area=10000)
    mask_yolk_fit = fit_ellipse_to_mask(mask_yolk, min_area=4000)
    #mask_yolk_fit = mask_yolk
    # 胚盘处理：如果使用手动标记，则优先使用手动标记的掩码
    if use_manual_germ and manual_germ_mask is not None:
        mask_germ_filtered = manual_germ_mask
    else:
        mask_germ_filtered = keep_center_region(mask_germ, min_area=300)
    
    # 创建彩色掩码
    color_mask = np.zeros_like(frame)
    color_mask[mask_egg_white_fit > 0] = (0, 0, 255)  # 红色蛋清
    color_mask[mask_yolk_fit > 0] = (0, 255, 0)      # 绿色蛋黄
    color_mask[mask_germ_filtered > 0] = (0, 255, 255)  # 黄色胚盘
    
    # 确保背景区域完全为黑色
    color_mask[mask_background > 0] = (0, 0, 0)
    
    # 应用透明度
    result = cv2.addWeighted(frame, 1 - alpha, color_mask, alpha, 0)
    
    # 可视化
    if visualize:
        cv2.imshow("Processed Frame", result)
    
    return result

def main():
    input_path = "no\Y-2-竖.gif"
    cap = cv2.VideoCapture(input_path)
    
    if not cap.isOpened():
        print(f"Error: 无法打开文件 {input_path}")
        return
    
    # 获取原始文件名和扩展名
    base_name = os.path.basename(input_path)
    file_name, file_ext = os.path.splitext(base_name)
    
    # 构建输出文件名
    output_name = f"processed_{file_name}.mp4"
    
    fps = cap.get(cv2.CAP_PROP_FPS) or 10
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    
    out = cv2.VideoWriter(output_name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    
    frame_count = 0
    
    print("\n=== 处理视频 ===")
    print("1. 按 'm' 键为当前帧添加手动标记")
    print("2. 按 'a' 键自动处理所有剩余帧")
    print("3. 按 'q' 键退出程序")
    
    # 自动处理模式（默认为False，允许用户交互）
    auto_process = False
    
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        cv2.imshow("Original Frame", frame)
        new_frame = detect_egg_region(frame, min_area=10000, max_area=100000, visualize=True)
        
        # 初始化手动标记相关变量
        use_manual_germ = False
        manual_germ_mask = None
        
        # 如果不是自动模式，允许用户交互
        if not auto_process:
            print(f"\n处理第 {frame_count} 帧")
            print("按 'm' 键为当前帧添加手动标记，按 'a' 键自动处理所有剩余帧")
            
            # 等待用户按键
            key = cv2.waitKey(0) & 0xFF
            
            if key == ord('m'):  # 手动标记当前帧
                use_manual_germ = True
                manual_germ_mask = manual_germ_marking(new_frame, frame_count)
                
            elif key == ord('a'):  # 自动处理所有剩余帧
                print("切换到自动处理模式")
                auto_process = True
                
            elif key == ord('q'):  # 退出程序
                print("程序已退出")
                break
        
        processed_frame = process_frame(new_frame, 
                                        background_thresh=40,
                                        egg_white_thresh=80,
                                        yolk_thresh=200,
                                        germ_thresh=172,
                                        alpha=0.2,
                                        visualize=True,
                                        use_manual_germ=use_manual_germ,
                                        manual_germ_mask=manual_germ_mask)
        
        out.write(processed_frame)
        
        frame_count += 1
        if frame_count % 10 == 0:
            print(f"已处理 {frame_count} 帧")
    
    cap.release()
    out.release()
    cv2.destroyAllWindows()
    print(f"处理完成，输出文件: {output_name}")

if __name__ == "__main__":
    main()