-
Notifications
You must be signed in to change notification settings - Fork 0
/
losses.py
33 lines (22 loc) · 950 Bytes
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
"""
Maintainers: Gabriel Dias ([email protected])
             Mateus Oliveira ([email protected])
"""
import torch
import torch.nn as nn
class RangeMAELoss(nn.Module):
    """Weighted mean absolute error that up-weights two ppm windows.

    The loss is a weighted average of three mean-absolute-error terms
    between ``x`` and ``y``:

    * the MAE over points whose ppm value lies inside ``gaba_range``,
    * the MAE over points whose ppm value lies inside ``glx_range``,
    * the MAE over every point.

    With the default arguments this equals
    ``(8 * gaba_mae + glx_mae + global_mae) / 10``.

    Args:
        gaba_range: ``(low, high)`` ppm bounds of the GABA window. Both ends
            are inclusive and the order of the two values does not matter.
        glx_range: ``(low, high)`` ppm bounds of the Glx window. Both ends
            are inclusive and the order of the two values does not matter.
        gaba_weight: Weight of the GABA-window MAE term.
        glx_weight: Weight of the Glx-window MAE term.
        global_weight: Weight of the whole-spectrum MAE term.

    Raises:
        ValueError: If the three weights do not sum to a positive number.
    """

    def __init__(self, gaba_range=(2.8, 3.2), glx_range=(3.55, 3.95),
                 gaba_weight=8.0, glx_weight=1.0, global_weight=1.0):
        super().__init__()
        self.gaba_range = (float(min(gaba_range)), float(max(gaba_range)))
        self.glx_range = (float(min(glx_range)), float(max(glx_range)))
        self.gaba_weight = float(gaba_weight)
        self.glx_weight = float(glx_weight)
        self.global_weight = float(global_weight)
        self.total_weight = (
            self.gaba_weight + self.glx_weight + self.global_weight
        )
        # `not > 0` also rejects NaN; the total is used as a divisor.
        if not self.total_weight > 0:
            raise ValueError("loss weights must sum to a positive value")

    @staticmethod
    def _window_mask(ppm, bounds, name):
        """Return a boolean mask of the ``ppm`` points inside ``bounds``.

        Raises:
            ValueError: If no point of ``ppm`` falls inside ``bounds``.
        """
        low, high = bounds
        mask = (ppm >= low) & (ppm <= high)
        if not bool(mask.any()):
            # An empty window would make its MAE a mean over zero elements
            # (NaN), silently poisoning the loss.
            raise ValueError(
                f"no ppm point falls inside the {name} window "
                f"[{low}, {high}]"
            )
        return mask

    def forward(self, x, y, ppm):
        """Compute the weighted range MAE between ``x`` and ``y``.

        Args:
            x: Predicted values; spectral points are indexed along dim 1
                (presumably shape ``(batch, points)`` — TODO confirm).
            y: Target values with the same shape as ``x``.
            ppm: ppm value of each point along dim 1 of ``x``. It may be
                ascending or descending, 1-D or batched. A batched tensor is
                reduced to its first row, because one window is applied to
                every row of ``x`` (assumes all rows share the same axis —
                TODO confirm).

        Returns:
            The weighted loss; a scalar tensor when ``x`` is 2-D.

        Raises:
            ValueError: If either window contains no ppm point.
        """
        ppm = torch.as_tensor(ppm)
        ppm = ppm.reshape(-1, ppm.shape[-1])[0]

        # Windows are chosen by ppm *value*, not by indices derived from a
        # filtered copy of the axis, so the direction of the axis is
        # irrelevant.
        gaba_mask = self._window_mask(ppm, self.gaba_range, "GABA")
        glx_mask = self._window_mask(ppm, self.glx_range, "Glx")

        abs_err = torch.abs(x - y)
        gaba_mask = gaba_mask.to(abs_err.device)
        glx_mask = glx_mask.to(abs_err.device)

        gaba_mae = abs_err[:, gaba_mask].mean(dim=1).mean(dim=0)
        glx_mae = abs_err[:, glx_mask].mean(dim=1).mean(dim=0)
        global_mae = abs_err.mean(dim=1).mean(dim=0)

        weighted = (
            gaba_mae * self.gaba_weight
            + glx_mae * self.glx_weight
            + global_mae * self.global_weight
        )
        return weighted / self.total_weight