update for GPU usage
Browse files
annotator/hed/__init__.py
CHANGED
|
@@ -87,10 +87,11 @@ class SOFT_HEDdetector:
|
|
| 87 |
if not os.path.exists(modelpath):
|
| 88 |
from basicsr.utils.download_util import load_file_from_url
|
| 89 |
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
|
| 90 |
-
self.netNetwork = ControlNetHED_Apache2().float().
|
| 91 |
-
self.netNetwork.load_state_dict(torch.load(modelpath))
|
| 92 |
|
| 93 |
def __call__(self, input_image, safe=False, threshold=200):
|
|
|
|
| 94 |
assert input_image.ndim == 3
|
| 95 |
H, W, C = input_image.shape
|
| 96 |
with torch.no_grad():
|
|
|
|
| 87 |
if not os.path.exists(modelpath):
|
| 88 |
from basicsr.utils.download_util import load_file_from_url
|
| 89 |
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
|
| 90 |
+
self.netNetwork = ControlNetHED_Apache2().float().eval()
|
| 91 |
+
self.netNetwork.load_state_dict(torch.load(modelpath, map_location='cpu'))
|
| 92 |
|
| 93 |
def __call__(self, input_image, safe=False, threshold=200):
|
| 94 |
+
self.netNetwork.cuda()
|
| 95 |
assert input_image.ndim == 3
|
| 96 |
H, W, C = input_image.shape
|
| 97 |
with torch.no_grad():
|