| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 
 | import random
 
 import cv2
 import gradio as gr
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
 from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
 
 
 def setup_seed(seed=33):
     """Seed every random number generator this demo relies on.

     The same value is fed to PyTorch (CPU and all CUDA devices), NumPy's
     global RNG and Python's ``random`` module. cuDNN is then switched to
     deterministic kernels with autotuning disabled so that repeated runs
     with the same seed give the same results.

     Args:
         seed: Integer seed value (defaults to 33).
     """
     seeders = (
         torch.manual_seed,
         torch.cuda.manual_seed_all,
         np.random.seed,
         random.seed,
     )
     for seed_fn in seeders:
         seed_fn(seed)

     cudnn_backend = torch.backends.cudnn
     # Benchmark mode selects convolution algorithms by timing them, which can
     # differ between runs; turn it off and require deterministic kernels.
     cudnn_backend.benchmark = False
     cudnn_backend.deterministic = True
 
 
 def show_anns(anns, image):
     """Paint every SAM mask in a random color and add it onto ``image``.

     Args:
         anns: List of mask records, each holding a boolean ``"segmentation"``
             array and an ``"area"`` value (as produced by the mask generator).
         image: The image the masks were computed on; it is combined with the
             color layer via ``cv2.add``, so it must match that layer's size and
             channel count (assumed HxWx3 uint8 -- TODO confirm for all inputs).

     Returns:
         ``image`` unchanged when there are no masks, otherwise the result of
         ``cv2.add(image, color_layer)``.
     """
     if not len(anns):
         return image

     # Largest masks first, so smaller masks painted later stay visible on top.
     by_area = sorted(anns, key=lambda record: record["area"], reverse=True)
     mask_shape = by_area[0]["segmentation"].shape
     color_layer = np.zeros((mask_shape[0], mask_shape[1], 3), dtype=np.uint8)

     for record in by_area:
         # Colors come from NumPy's global RNG, so setup_seed makes them repeatable.
         color_layer[record["segmentation"]] = np.random.choice(range(256), size=3)

     return cv2.add(image, color_layer)
 
 
 # Location of the SAM checkpoint on disk; adjust for the machine running the demo.
 sam_checkpoint = "/disk1/datasets/models/sam/sam_vit_h_4b8939.pth"
 # Registry key for the model architecture; "vit_h" corresponds to the
 # sam_vit_h_* checkpoint file named above.
 model_type = "vit_h"
 # Use the GPU when one is available, otherwise fall back to CPU instead of
 # crashing in sam.to() on hosts without CUDA.
 device = "cuda" if torch.cuda.is_available() else "cpu"
 sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
 sam.to(device=device)
 
 
 def segment_anything(image, points_per_side, pred_iou_thresh, seed, sam=sam):
     """Automatically segment an image with SAM and return a visualization for Gradio.

     Args:
         image: Image from the Gradio input component, or None when nothing
             has been uploaded (presumably an HxWx3 uint8 RGB array -- Gradio's
             default numpy image type).
         points_per_side: Forwarded to SamAutomaticMaskGenerator as its
             ``points_per_side`` setting.
         pred_iou_thresh: Forwarded to SamAutomaticMaskGenerator as its
             ``pred_iou_thresh`` setting.
         seed: Passed to setup_seed; this seeds NumPy's global RNG, which
             show_anns uses to pick mask colors.
         sam: The loaded SAM model; defaults to the module-level instance.

     Returns:
         A tuple ``(overlay_image, mask_count)``.

     Raises:
         gr.Error: If no image was provided, so Gradio shows a readable message.
     """
     if image is None:
         raise gr.Error("请先上传输入图像")

     # Slider values may arrive as floats; np.random.seed rejects float seeds
     # and the point grid size is an integer count.
     setup_seed(int(seed))
     mask_generator = SamAutomaticMaskGenerator(
         sam,
         points_per_side=int(points_per_side),
         pred_iou_thresh=pred_iou_thresh,
     )
     masks = mask_generator.generate(image)
     return show_anns(masks, image), len(masks)
 
 
 # Web UI: the input image plus three sliders mapped positionally onto
 # segment_anything(image, points_per_side, pred_iou_thresh, seed).
 interface = gr.Interface(
     fn=segment_anything,
     inputs=[
         gr.components.Image(label="输入图像", height=500),
         # Defaults of 32 points per side and 0.88 IoU threshold match
         # SamAutomaticMaskGenerator's own defaults; previously the sliders
         # started at their minimums (16 and 0.0 -> no IoU filtering at all).
         gr.Slider(16, 128, value=32, step=1, label="每边采样点数 (points_per_side)"),
         gr.Slider(0, 1, value=0.88, step=0.01, label="IoU 阈值 (pred_iou_thresh)"),
         # Default seed matches setup_seed's default.
         gr.Slider(1, 999, value=33, step=1, label="随机种子 (seed)"),
     ],
     outputs=[
         gr.components.Image(label="分割结果", height=500, interactive=True),
         gr.components.Number(label="分割数"),
     ],
     examples=[
         ["./images/girl.jpg", 32, 0.86, 31],
         ["./images/zdt.png", 64, 0.86, 33],
         ["./images/green wormcopy.jpg", 64, 0.86, 33],
     ],
 # NOTE(review): concurrency_count is a Gradio 3.x queue() argument (removed in
 # 4.x in favour of default_concurrency_limit). With up to 5 concurrent requests
 # sharing one model and the process-global RNGs seeded by setup_seed, one
 # request can presumably reseed while another is still drawing mask colors --
 # consider 1 if per-seed reproducibility matters.
 ).queue(concurrency_count=5)
 
 
 
 # Start the web server on port 7860 without creating a public Gradio share
 # link, using the bundled icon as the browser-tab favicon.
 # NOTE(review): server_name "0.0.0.0" listens on every network interface, so
 # the demo is reachable from other machines on the network.
 interface.launch(
     server_name="0.0.0.0",
     server_port=7860,
     share=False,
     favicon_path="./images/icon.ico",
 )
 
 |