2017-11-04 104 views
How do I use a ByteTensor in a contrastive cosine loss function?

I am trying to implement the loss function from http://anthology.aclweb.org/W16-1617 in PyTorch. It looks like this:

[Image: the loss from the paper, L = (1 − Y) · ¼(1 − cos_sim)² + Y · (cos_sim² if cos_sim < margin, else 0)]

I implemented the loss as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CosineContrastiveLoss(nn.Module):
    """
    Cosine contrastive loss function.
    Based on: http://anthology.aclweb.org/W16-1617
    Maintain 0 for match, 1 for not match.
    If they match, loss is 1/4(1-cos_sim)^2.
    If they don't, it's cos_sim^2 if cos_sim < margin or 0 otherwise.
    Margin in the paper is ~0.4.
    """

    def __init__(self, margin=0.4):
        super(CosineContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        cos_sim = F.cosine_similarity(output1, output2)
        loss_cos_con = torch.mean((1 - label) * torch.div(torch.pow((1.0 - cos_sim), 2), 4) +
                                  (label) * torch.pow(cos_sim * torch.lt(cos_sim, self.margin), 2))
        return loss_cos_con

However, I get an error saying:

TypeError: mul received an invalid combination of arguments - got (torch.cuda.ByteTensor), but expected one of:
 * (float value)
      didn't match because some of the arguments have invalid types: (torch.cuda.ByteTensor)
 * (torch.cuda.FloatTensor other)
      didn't match because some of the arguments have invalid types: (torch.cuda.ByteTensor)

I know that torch.lt() returns a ByteTensor, but if I try to coerce it to a float tensor with torch.Tensor.float() I get AttributeError: module 'torch.autograd.variable' has no attribute 'FloatTensor'.
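A minimal check (hypothetical values, run on CPU; on recent PyTorch versions the comparison comes back as a bool tensor rather than a ByteTensor, but either way it is not a float tensor) confirms the dtype mismatch:

```python
import torch

x = torch.tensor([0.2, 0.5, 0.9])
mask = torch.lt(x, 0.4)  # element-wise "less than" comparison
# mask is a byte/bool mask, not a FloatTensor, which is why
# multiplying it directly with a float tensor fails in PyTorch 0.3
print(mask.dtype)
```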

I'm really not sure where to go from here. I would have thought an element-wise multiplication between the cosine similarity tensor and a tensor of 0s and 1s based on the less-than rule would be the logical approach.

Answer


Maybe you can just try the float() method directly? Variable(torch.zeros(5)).float() works for me, for example.
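For completeness, here is a sketch of the loss with the comparison mask cast to float before the multiplication. It is written against a current PyTorch API (no Variable wrappers, plain .float() on the mask), so treat it as an illustration rather than a drop-in for the 0.3-era code in the question:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CosineContrastiveLoss(nn.Module):
    """Cosine contrastive loss: label 0 = match, 1 = no match, margin ~0.4."""

    def __init__(self, margin=0.4):
        super(CosineContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        cos_sim = F.cosine_similarity(output1, output2)
        # Cast the byte/bool mask to float so it can multiply a float tensor
        mask = (cos_sim < self.margin).float()
        match_loss = torch.pow(1.0 - cos_sim, 2) / 4.0
        mismatch_loss = torch.pow(cos_sim * mask, 2)
        return torch.mean((1 - label) * match_loss + label * mismatch_loss)
```

The only change from the question's code is the explicit .float() on the result of the less-than comparison; the arithmetic is otherwise the same.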