-
Notifications
You must be signed in to change notification settings - Fork 23
/
ia_sam_manager.py
117 lines (102 loc) · 4.75 KB
/
ia_sam_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import platform
import torch
from fast_sam import FastSamAutomaticMaskGenerator, fast_sam_model_registry
from ia_check_versions import ia_check_versions
from ia_config import IAConfig
from ia_devices import devices
from ia_logging import ia_logging
from mobile_sam import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorMobile
from mobile_sam import SamPredictor as SamPredictorMobile
from mobile_sam import sam_model_registry as sam_model_registry_mobile
from segment_anything_fb import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from segment_anything_hq import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorHQ
from segment_anything_hq import SamPredictor as SamPredictorHQ
from segment_anything_hq import sam_model_registry as sam_model_registry_hq
def get_sam_mask_generator(sam_checkpoint, anime_style_chk=False):
    """Build an automatic mask generator for a SAM checkpoint.

    The SAM variant (HQ, FastSAM, MobileSAM or the default registry) is picked
    from substrings of the checkpoint file name, checked in that order.

    Args:
        sam_checkpoint (str): SAM checkpoint path
        anime_style_chk (bool): when truthy, use the lower thresholds
            (pred_iou_thresh 0.83, stability_score_thresh 0.9) instead of
            0.88 and 0.95

    Returns:
        SamAutomaticMaskGenerator or None: mask generator of the selected
        variant, or None when ``sam_checkpoint`` is not an existing file
    """
    checkpoint_name = os.path.basename(sam_checkpoint)

    # Choose registry, generator class, registry key and batch size by file name.
    if "_hq_" in checkpoint_name:
        # The registry key is sliced from fixed positions of the file name;
        # presumably names look like "sam_hq_vit_h.pth" -> "vit_h" (verify).
        model_type, batch_points = checkpoint_name[7:12], 32
        registry, generator_cls = sam_model_registry_hq, SamAutomaticMaskGeneratorHQ
    elif "FastSAM" in checkpoint_name:
        # FastSAM is keyed by the file name without its extension.
        model_type, batch_points = os.path.splitext(checkpoint_name)[0], None
        registry, generator_cls = fast_sam_model_registry, FastSamAutomaticMaskGenerator
    elif "mobile_sam" in checkpoint_name:
        model_type, batch_points = "vit_t", 64
        registry, generator_cls = sam_model_registry_mobile, SamAutomaticMaskGeneratorMobile
    else:
        # Presumably names look like "sam_vit_h_4b8939.pth" -> "vit_h" (verify).
        model_type, batch_points = checkpoint_name[4:9], 64
        registry, generator_cls = sam_model_registry, SamAutomaticMaskGenerator

    if anime_style_chk:
        iou_thresh, stability_thresh = 0.83, 0.9
    else:
        iou_thresh, stability_thresh = 0.88, 0.95

    if not os.path.isfile(sam_checkpoint):
        return None

    sam = registry[model_type](checkpoint=sam_checkpoint)

    # Pick the target device, then move the model once.
    if platform.system() != "Darwin":
        if IAConfig.global_args.get("sam_cpu", False):
            ia_logging.info("SAM is running on CPU... (the option has been selected)")
            target_device = devices.cpu
        else:
            target_device = devices.device
    elif "FastSAM" not in checkpoint_name and ia_check_versions.torch_mps_is_available:
        target_device = torch.device("mps")
    else:
        # On macOS, FastSAM always goes to CPU, as does any model when MPS is unavailable.
        target_device = torch.device("cpu")
    sam.to(device=target_device)

    return generator_cls(
        model=sam,
        points_per_batch=batch_points,
        pred_iou_thresh=iou_thresh,
        stability_score_thresh=stability_thresh,
    )
def get_sam_predictor(sam_checkpoint):
    """Get SAM predictor.

    The SAM variant (HQ, MobileSAM or the default registry) is picked from
    substrings of the checkpoint file name. FastSAM checkpoints are rejected
    because no predictor exists for them.

    Args:
        sam_checkpoint (str): SAM checkpoint path

    Returns:
        SamPredictor or None: predictor of the selected variant, or None when
        ``sam_checkpoint`` is not an existing file

    Raises:
        NotImplementedError: if the file name contains "FastSAM" (and not
            "_hq_"); raised whether or not the file exists
    """
    checkpoint_name = os.path.basename(sam_checkpoint)
    is_hq = "_hq_" in checkpoint_name

    # The "_hq_" match takes precedence over "FastSAM", and the FastSAM
    # rejection happens before the file-existence check.
    if not is_hq and "FastSAM" in checkpoint_name:
        raise NotImplementedError("FastSAM predictor is not implemented yet.")

    if not os.path.isfile(sam_checkpoint):
        return None

    if is_hq:
        # The registry key is sliced from fixed positions of the file name;
        # presumably names look like "sam_hq_vit_h.pth" -> "vit_h" (verify).
        model_type = checkpoint_name[7:12]
        registry = sam_model_registry_hq
        predictor_cls = SamPredictorHQ
    elif "mobile_sam" in checkpoint_name:
        model_type = "vit_t"
        registry = sam_model_registry_mobile
        predictor_cls = SamPredictorMobile
    else:
        # Presumably names look like "sam_vit_h_4b8939.pth" -> "vit_h" (verify).
        model_type = checkpoint_name[4:9]
        registry = sam_model_registry
        predictor_cls = SamPredictor

    sam = registry[model_type](checkpoint=sam_checkpoint)

    if platform.system() == "Darwin":
        # No FastSAM special case is needed here: FastSAM checkpoints were
        # already rejected above, so only MPS availability decides the device.
        if ia_check_versions.torch_mps_is_available:
            sam.to(device=torch.device("mps"))
        else:
            sam.to(device=torch.device("cpu"))
    elif IAConfig.global_args.get("sam_cpu", False):
        ia_logging.info("SAM is running on CPU... (the option has been selected)")
        sam.to(device=devices.cpu)
    else:
        sam.to(device=devices.device)

    return predictor_cls(sam)