def get_edges(self
, t
):
torchvision
.transforms
.ToPILImage
()(t
[0].cpu
()).show
()
edge
= self
.ByteTensor
(t
.size
()).zero_
()
edge
[:, :, :, 1:] = edge
[:, :, :, 1:] | (t
[:, :, :, 1:] != t
[:, :, :, :-1])
edge
[:, :, :, :-1] = edge
[:, :, :, :-1] | (t
[:, :, :, 1:] != t
[:, :, :, :-1])
edge
[:, :, 1:, :] = edge
[:, :, 1:, :] | (t
[:, :, 1:, :] != t
[:, :, :-1, :])
edge
[:, :, :-1, :] = edge
[:, :, :-1, :] | (t
[:, :, 1:, :] != t
[:, :, :-1, :])
torchvision
.transforms
.ToPILImage
()(edge
.float()[0].cpu
()).show
()
return edge
.float()
转载请注明原文地址: https://lol.8miu.com/read-33035.html