49 lines
1.2 KiB
Python
49 lines
1.2 KiB
Python
|
from torch import nn
|
||
|
|
||
|
|
||
|
class CTCHead(nn.Module):
    """Linear projection head producing per-position class scores.

    Depending on ``mid_channels`` the head is either a single
    ``nn.Linear`` (``in_channels -> out_channels``) or a stack of two
    ``nn.Linear`` layers (``in_channels -> mid_channels -> out_channels``)
    applied back to back with no activation in between.

    Args:
        in_channels: Size of the last dimension of the input features.
        out_channels: Size of the last dimension of the predictions.
        fc_decay: Accepted for configuration compatibility; it is not used
            anywhere in this class.
        mid_channels: When ``None`` a single linear layer (``self.fc``) is
            built; otherwise two layers (``self.fc1``, ``self.fc2``) are
            built with this intermediate width.
        return_feats: When true, ``forward`` returns a dict holding both the
            predictions and the features that fed the last linear layer.
        **kwargs: Extra configuration entries; accepted and ignored.
    """

    def __init__(self,
                 in_channels,
                 out_channels=6625,
                 fc_decay=0.0004,
                 mid_channels=None,
                 return_feats=False,
                 **kwargs):
        super().__init__()

        # Submodule names (fc / fc1 / fc2) are kept as-is because they
        # determine the parameter keys in the module's state dict.
        if mid_channels is None:
            self.fc = nn.Linear(in_channels, out_channels, bias=True)
        else:
            self.fc1 = nn.Linear(in_channels, mid_channels, bias=True)
            self.fc2 = nn.Linear(mid_channels, out_channels, bias=True)

        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.return_feats = return_feats

    def forward(self, x, labels=None):
        """Project ``x`` to class scores.

        Args:
            x: Input tensor whose last dimension equals ``in_channels``
                (presumably ``(batch, seq_len, in_channels)`` -- TODO confirm
                against the caller).
            labels: Unused here; accepted so the head can be called with the
                same arguments as heads that do consume labels.

        Returns:
            The prediction tensor when ``return_feats`` is false. Otherwise a
            dict with keys ``'ctc'`` (the predictions) and ``'ctc_neck'``
            (the output of ``fc1`` in the two-layer configuration, or the
            unmodified input ``x`` in the single-layer configuration).
        """
        if self.mid_channels is None:
            feats = x
            predicts = self.fc(feats)
        else:
            feats = self.fc1(x)
            predicts = self.fc2(feats)

        if not self.return_feats:
            return predicts

        return {'ctc': predicts, 'ctc_neck': feats}
|