
RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 1 #1837

Open
Jzone-solo-Legend opened this issue Jun 21, 2024 · 20 comments
Labels
bug Something isn't working

Comments

@Jzone-solo-Legend commented Jun 21, 2024

Notice: In order to resolve issues more efficiently, please raise issues following the template.

🐛 Bug

To Reproduce

Steps to reproduce the behavior (always include the command you ran):

  1. Run examples/industrial_data_pretraining/paraformer/finetune.sh
  2. See:
    File "/mnt/home/home/wxh/czl/fun-asr/funasr/models/paraformer/model.py", line 208, in forward
    loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
    File "/mnt/home/home/wxh/czl/fun-asr/funasr/models/paraformer/model.py", line 309, in _calc_att_loss
    sematic_embeds, decoder_out_1st = self.sampler(
    File "/mnt/home/home/wxh/czl/fun-asr/funasr/models/paraformer/model.py", line 350, in sampler
    decoder_outs = self.decoder(
    File "/home/wxh/anaconda3/envs/funasr/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
    File "/mnt/home/home/wxh/czl/fun-asr/funasr/models/paraformer/decoder.py", line 397, in forward
    x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
    File "/home/wxh/anaconda3/envs/funasr/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
    File "/mnt/home/home/wxh/czl/fun-asr/funasr/models/transformer/utils/repeat.py", line 32, in forward
    args = m(*args)
    File "/home/wxh/anaconda3/envs/funasr/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
    File "/mnt/home/home/wxh/czl/fun-asr/funasr/models/paraformer/decoder.py", line 106, in forward
    x, _ = self.self_attn(tgt, tgt_mask)
    File "/home/wxh/anaconda3/envs/funasr/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
    File "/mnt/home/home/wxh/czl/fun-asr/funasr/models/sanm/attention.py", line 518, in forward
    inputs = inputs * mask
    RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 1
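For illustration only, a minimal sketch of the broadcasting failure behind this traceback (the shapes are hypothetical): a tensor with 3 steps along dimension 1 multiplied by a mask with 4 steps raises exactly this error.

    import torch

    inputs = torch.randn(2, 3, 512)  # e.g. (batch, length produced by the predictor, dim)
    mask = torch.ones(2, 4, 1)       # e.g. (batch, length expected by the mask, 1)
    inputs * mask                    # RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 1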

Code sample

Expected behavior

Environment

  • Linux
  • FunASR Version: 1.0.27
  • ModelScope Version: 1.15.0
  • PyTorch Version: 2.0.0+cu118
  • How you installed funasr: pip3 install -e ./
  • Python version: 3.8.0
  • GPU: NVIDIA 3090
  • CUDA/cuDNN version: CUDA 11.8
  • Docker version: None
  • Any other relevant information:

Additional context

@Jzone-solo-Legend Jzone-solo-Legend added the bug Something isn't working label Jun 21, 2024
@dtlzhuangz (Contributor)

Could you show me your cif_v1 code? Or run git log to show your commit id?

@Jzone-solo-Legend (Author)

I just pulled the latest branch (several minutes ago), and now it works, thanks!

@dtlzhuangz (Contributor)

Could you try modifying the cif via https://github.com/modelscope/FunASR/pull/1811/files ? I found that my bugfix commit was discarded in 45d7aa9.

@LauraGPT Hello. Could it be that my bugfix commit was accidentally removed, leading to this issue? I cannot see my bug fix in the main code, and its absence will cause bugs in both training and ONNX inference.

@LauraGPT (Collaborator)

Sorry, it is my mistake. I have fixed it: 93f9a42

@FastSchnell

Linux
FunASR Version: latest main
ModelScope Version: latest
PyTorch Version: 2.3.1
Python version: 3.11.5
GPU: NVIDIA 3090
Built on Thu_Nov_18_09:45:30_PST_2021
Cuda compilation tools, release 11.5, V11.5.119
Build cuda_11.5.r11.5/compiler.30672275_0

funasr_error (screenshot)

Still getting the error.

@LauraGPT (Collaborator)

@dtlzhuangz Could you take a look at this issue? There have been quite a few reports of this error recently.

@dtlzhuangz (Contributor) commented Jun 24, 2024

Hello, the training code on the main branch currently still calls the original cif, not cif_v1. @FastSchnell, could you send me the full error message, the launch command, and your cif_predictor.py file? If possible, please also share the data sample that triggers the error.

@dtlzhuangz (Contributor)

I suggest setting a breakpoint just before https://github.com/modelscope/FunASR/blob/main/funasr/models/paraformer/cif_predictor.py#L248 and saving hidden and alphas (a sketch of such a dump follows).
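A minimal sketch of that dump; the helper name and file name are illustrative assumptions, not part of the original suggestion. Calling it just before the cif()/cif_v1() call lets the failing batch be replayed offline.

    import torch

    def dump_cif_inputs(hidden, alphas, path="cif_debug_batch.pt"):
        # Hypothetical helper: call it just before cif()/cif_v1() in
        # CifPredictor.forward to capture the predictor inputs of the failing batch.
        torch.save({"hidden": hidden.detach().cpu(), "alphas": alphas.detach().cpu()}, path)

    # Replaying later, in a separate session:
    # blob = torch.load("cif_debug_batch.pt")
    # acoustic_embeds, cif_peak = cif(blob["hidden"], blob["alphas"], threshold=1.0)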

@LauraGPT (Collaborator)

Sorry, I'm not sure how I ended up overwriting it. Would you mind submitting the PR again?

@FastSchnell

Which branch are you on? I can try your version. Or send me your repo.

@dtlzhuangz (Contributor) commented Jun 25, 2024

Try my current repo. Also, after cloning the code, did you run pip install -e . ?

@dtlzhuangz (Contributor)

OK, I'll submit it again, but in principle the original cif should not produce this error.

@FastSchnell

@dtlzhuangz
funasr/bin/train.py ++model=paraformer-zh ++train_data_set_list=train.jsonl ++valid_data_set_list=val.jsonl ++output_dir="./outputs"

#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import torch
import logging
import numpy as np

from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from torch.cuda.amp import autocast

@tables.register("predictor_classes", "CifPredictor")
class CifPredictor(torch.nn.Module):
def __init__(
self,
idim,
l_order,
r_order,
threshold=1.0,
dropout=0.1,
smooth_factor=1.0,
noise_threshold=0,
tail_threshold=0.45,
):
super().__init__()

    self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
    self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
    self.cif_output = torch.nn.Linear(idim, 1)
    self.dropout = torch.nn.Dropout(p=dropout)
    self.threshold = threshold
    self.smooth_factor = smooth_factor
    self.noise_threshold = noise_threshold
    self.tail_threshold = tail_threshold

def forward(
    self,
    hidden,
    target_label=None,
    mask=None,
    ignore_id=-1,
    mask_chunk_predictor=None,
    target_label_length=None,
):

    with autocast(False):
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        memory = self.cif_conv1d(queries)
        output = memory + context
        output = self.dropout(output)
        output = output.transpose(1, 2)
        output = torch.relu(output)
        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        if mask is not None:
            mask = mask.transpose(-1, -2).float()
            alphas = alphas * mask
        if mask_chunk_predictor is not None:
            alphas = alphas * mask_chunk_predictor
        alphas = alphas.squeeze(-1)
        mask = mask.squeeze(-1)
        if target_label_length is not None:
            target_length = target_label_length
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
        else:
            target_length = None
        token_num = alphas.sum(-1)
        if target_length is not None:
            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
        elif self.tail_threshold > 0.0:
            hidden, alphas, token_num = self.tail_process_fn(
                hidden, alphas, token_num, mask=mask
            )

        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

        if target_length is None and self.tail_threshold > 0.0:
            token_num_int = torch.max(token_num).type(torch.int32).item()
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]

    return acoustic_embeds, token_num, alphas, cif_peak

def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
    b, t, d = hidden.size()
    tail_threshold = self.tail_threshold
    if mask is not None:
        zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
        ones_t = torch.ones_like(zeros_t)
        mask_1 = torch.cat([mask, zeros_t], dim=1)
        mask_2 = torch.cat([ones_t, mask], dim=1)
        mask = mask_2 - mask_1
        tail_threshold = mask * tail_threshold
        alphas = torch.cat([alphas, zeros_t], dim=1)
        alphas = torch.add(alphas, tail_threshold)
    else:
        tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
        tail_threshold = torch.reshape(tail_threshold, (1, 1))
        alphas = torch.cat([alphas, tail_threshold], dim=1)
    zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
    hidden = torch.cat([hidden, zeros], dim=1)
    token_num = alphas.sum(dim=-1)
    token_num_floor = torch.floor(token_num)

    return hidden, alphas, token_num_floor

def gen_frame_alignments(
    self, alphas: torch.Tensor = None, encoder_sequence_length: torch.Tensor = None
):
    batch_size, maximum_length = alphas.size()
    int_type = torch.int32

    is_training = self.training
    if is_training:
        token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
    else:
        token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)

    max_token_num = torch.max(token_num).item()

    alphas_cumsum = torch.cumsum(alphas, dim=1)
    alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
    alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)

    index = torch.ones([batch_size, max_token_num], dtype=int_type)
    index = torch.cumsum(index, dim=1)
    index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)

    index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
    index_div_bool_zeros = index_div.eq(0)
    index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
    index_div_bool_zeros_count = torch.clamp(
        index_div_bool_zeros_count, 0, encoder_sequence_length.max()
    )
    token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
    index_div_bool_zeros_count *= token_num_mask

    index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(
        1, 1, maximum_length
    )
    ones = torch.ones_like(index_div_bool_zeros_count_tile)
    zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
    ones = torch.cumsum(ones, dim=2)
    cond = index_div_bool_zeros_count_tile == ones
    index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)

    index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
    index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
    index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
    index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
    predictor_mask = (
        (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max()))
        .type(int_type)
        .to(encoder_sequence_length.device)
    )
    index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask

    predictor_alignments = index_div_bool_zeros_count_tile_out
    predictor_alignments_length = predictor_alignments.sum(-1).type(
        encoder_sequence_length.dtype
    )
    return predictor_alignments.detach(), predictor_alignments_length.detach()

@tables.register("predictor_classes", "CifPredictorV2")
class CifPredictorV2(torch.nn.Module):
def __init__(
self,
idim,
l_order,
r_order,
threshold=1.0,
dropout=0.1,
smooth_factor=1.0,
noise_threshold=0,
tail_threshold=0.0,
tf2torch_tensor_name_prefix_torch="predictor",
tf2torch_tensor_name_prefix_tf="seq2seq/cif",
tail_mask=True,
):
super().__init__()

    self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
    self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
    self.cif_output = torch.nn.Linear(idim, 1)
    self.dropout = torch.nn.Dropout(p=dropout)
    self.threshold = threshold
    self.smooth_factor = smooth_factor
    self.noise_threshold = noise_threshold
    self.tail_threshold = tail_threshold
    self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
    self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
    self.tail_mask = tail_mask

def forward(
    self,
    hidden,
    target_label=None,
    mask=None,
    ignore_id=-1,
    mask_chunk_predictor=None,
    target_label_length=None,
):

    with autocast(False):
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))
        output = output.transpose(1, 2)

        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        if mask is not None:
            mask = mask.transpose(-1, -2).float()
            alphas = alphas * mask
        if mask_chunk_predictor is not None:
            alphas = alphas * mask_chunk_predictor
        alphas = alphas.squeeze(-1)
        mask = mask.squeeze(-1)
        if target_label_length is not None:
            target_length = target_label_length.squeeze(-1)
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
        else:
            target_length = None
        token_num = alphas.sum(-1)
        if target_length is not None:
            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
        elif self.tail_threshold > 0.0:
            if self.tail_mask:
                hidden, alphas, token_num = self.tail_process_fn(
                    hidden, alphas, token_num, mask=mask
                )
            else:
                hidden, alphas, token_num = self.tail_process_fn(
                    hidden, alphas, token_num, mask=None
                )

        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
        if target_length is None and self.tail_threshold > 0.0:
            token_num_int = torch.max(token_num).type(torch.int32).item()
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]

    return acoustic_embeds, token_num, alphas, cif_peak

def forward_chunk(self, hidden, cache=None, **kwargs):
    is_final = kwargs.get("is_final", False)
    batch_size, len_time, hidden_size = hidden.shape
    h = hidden
    context = h.transpose(1, 2)
    queries = self.pad(context)
    output = torch.relu(self.cif_conv1d(queries))
    output = output.transpose(1, 2)
    output = self.cif_output(output)
    alphas = torch.sigmoid(output)
    alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)

    alphas = alphas.squeeze(-1)

    token_length = []
    list_fires = []
    list_frames = []
    cache_alphas = []
    cache_hiddens = []

    if cache is not None and "chunk_size" in cache:
        alphas[:, : cache["chunk_size"][0]] = 0.0
        if not is_final:
            alphas[:, sum(cache["chunk_size"][:2]) :] = 0.0
    if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
        cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
        cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
        hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
        alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
    if cache is not None and is_final:
        tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
        tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
        tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
        hidden = torch.cat((hidden, tail_hidden), dim=1)
        alphas = torch.cat((alphas, tail_alphas), dim=1)

    len_time = alphas.shape[1]
    for b in range(batch_size):
        integrate = 0.0
        frames = torch.zeros((hidden_size), device=hidden.device)
        list_frame = []
        list_fire = []
        for t in range(len_time):
            alpha = alphas[b][t]
            if alpha + integrate < self.threshold:
                integrate += alpha
                list_fire.append(integrate)
                frames += alpha * hidden[b][t]
            else:
                frames += (self.threshold - integrate) * hidden[b][t]
                list_frame.append(frames)
                integrate += alpha
                list_fire.append(integrate)
                integrate -= self.threshold
                frames = integrate * hidden[b][t]

        cache_alphas.append(integrate)
        if integrate > 0.0:
            cache_hiddens.append(frames / integrate)
        else:
            cache_hiddens.append(frames)

        token_length.append(torch.tensor(len(list_frame), device=alphas.device))
        list_fires.append(list_fire)
        list_frames.append(list_frame)

    cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
    cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
    cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
    cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)

    max_token_len = max(token_length)
    if max_token_len == 0:
        return hidden, torch.stack(token_length, 0), None, None
    list_ls = []
    for b in range(batch_size):
        pad_frames = torch.zeros(
            (max_token_len - token_length[b], hidden_size), device=alphas.device
        )
        if token_length[b] == 0:
            list_ls.append(pad_frames)
        else:
            list_frames[b] = torch.stack(list_frames[b])
            list_ls.append(torch.cat((list_frames[b], pad_frames), dim=0))

    cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
    cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
    cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
    cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
    return torch.stack(list_ls, 0), torch.stack(token_length, 0), None, None

def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
    b, t, d = hidden.size()
    tail_threshold = self.tail_threshold
    if mask is not None:
        zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
        ones_t = torch.ones_like(zeros_t)
        mask_1 = torch.cat([mask, zeros_t], dim=1)
        mask_2 = torch.cat([ones_t, mask], dim=1)
        mask = mask_2 - mask_1
        tail_threshold = mask * tail_threshold
        alphas = torch.cat([alphas, zeros_t], dim=1)
        alphas = torch.add(alphas, tail_threshold)
    else:
        tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
        tail_threshold = torch.reshape(tail_threshold, (1, 1))
        if b > 1:
            alphas = torch.cat([alphas, tail_threshold.repeat(b, 1)], dim=1)
        else:
            alphas = torch.cat([alphas, tail_threshold], dim=1)
    zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
    hidden = torch.cat([hidden, zeros], dim=1)
    token_num = alphas.sum(dim=-1)
    token_num_floor = torch.floor(token_num)

    return hidden, alphas, token_num_floor

def gen_frame_alignments(
    self, alphas: torch.Tensor = None, encoder_sequence_length: torch.Tensor = None
):
    batch_size, maximum_length = alphas.size()
    int_type = torch.int32

    is_training = self.training
    if is_training:
        token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
    else:
        token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)

    max_token_num = torch.max(token_num).item()

    alphas_cumsum = torch.cumsum(alphas, dim=1)
    alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
    alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)

    index = torch.ones([batch_size, max_token_num], dtype=int_type)
    index = torch.cumsum(index, dim=1)
    index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)

    index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
    index_div_bool_zeros = index_div.eq(0)
    index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
    index_div_bool_zeros_count = torch.clamp(
        index_div_bool_zeros_count, 0, encoder_sequence_length.max()
    )
    token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
    index_div_bool_zeros_count *= token_num_mask

    index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(
        1, 1, maximum_length
    )
    ones = torch.ones_like(index_div_bool_zeros_count_tile)
    zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
    ones = torch.cumsum(ones, dim=2)
    cond = index_div_bool_zeros_count_tile == ones
    index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)

    index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
    index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
    index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
    index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
    predictor_mask = (
        (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max()))
        .type(int_type)
        .to(encoder_sequence_length.device)
    )
    index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask

    predictor_alignments = index_div_bool_zeros_count_tile_out
    predictor_alignments_length = predictor_alignments.sum(-1).type(
        encoder_sequence_length.dtype
    )
    return predictor_alignments.detach(), predictor_alignments_length.detach()

@tables.register("predictor_classes", "CifPredictorV2Export")
class CifPredictorV2Export(torch.nn.Module):
def __init__(self, model, **kwargs):
super().__init__()

    self.pad = model.pad
    self.cif_conv1d = model.cif_conv1d
    self.cif_output = model.cif_output
    self.threshold = model.threshold
    self.smooth_factor = model.smooth_factor
    self.noise_threshold = model.noise_threshold
    self.tail_threshold = model.tail_threshold

def forward(
    self,
    hidden: torch.Tensor,
    mask: torch.Tensor,
):
    alphas, token_num = self.forward_cnn(hidden, mask)
    mask = mask.transpose(-1, -2).float()
    mask = mask.squeeze(-1)
    hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
    acoustic_embeds, cif_peak = cif_v1_export(hidden, alphas, self.threshold)

    return acoustic_embeds, token_num, alphas, cif_peak

def forward_cnn(
    self,
    hidden: torch.Tensor,
    mask: torch.Tensor,
):
    h = hidden
    context = h.transpose(1, 2)
    queries = self.pad(context)
    output = torch.relu(self.cif_conv1d(queries))
    output = output.transpose(1, 2)

    output = self.cif_output(output)
    alphas = torch.sigmoid(output)
    alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
    mask = mask.transpose(-1, -2).float()
    alphas = alphas * mask
    alphas = alphas.squeeze(-1)
    token_num = alphas.sum(-1)

    return alphas, token_num

def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
    b, t, d = hidden.size()
    tail_threshold = self.tail_threshold

    zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
    ones_t = torch.ones_like(zeros_t)

    mask_1 = torch.cat([mask, zeros_t], dim=1)
    mask_2 = torch.cat([ones_t, mask], dim=1)
    mask = mask_2 - mask_1
    tail_threshold = mask * tail_threshold
    alphas = torch.cat([alphas, zeros_t], dim=1)
    alphas = torch.add(alphas, tail_threshold)

    zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
    hidden = torch.cat([hidden, zeros], dim=1)
    token_num = alphas.sum(dim=-1)
    token_num_floor = torch.floor(token_num)

    return hidden, alphas, token_num_floor

@torch.jit.script
def cif_v1_export(hidden, alphas, threshold: float):
device = hidden.device
dtype = hidden.dtype
batch_size, len_time, hidden_size = hidden.size()
threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)

frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)

# prefix_sum = torch.cumsum(alphas, dim=1)
prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
    torch.float32
)  # cumsum precision degradation causes wrong results in extreme cases
prefix_sum_floor = torch.floor(prefix_sum)
dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)

dislocation_prefix_sum_floor[:, 0] = 0
dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor

fire_idxs = dislocation_diff > 0
fires[fire_idxs] = 1
fires = fires + prefix_sum - prefix_sum_floor

# prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
frames = prefix_sum_hidden[fire_idxs]
shift_frames = torch.roll(frames, 1, dims=0)

batch_len = fire_idxs.sum(1)
batch_idxs = torch.cumsum(batch_len, dim=0)
shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
shift_batch_idxs[0] = 0
shift_frames[shift_batch_idxs] = 0

remains = fires - torch.floor(fires)
# remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]

shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
shift_remain_frames[shift_batch_idxs] = 0

frames = frames - shift_frames + shift_remain_frames - remain_frames

# max_label_len = batch_len.max()
max_label_len = alphas.sum(dim=-1)
max_label_len = torch.floor(max_label_len).max().to(dtype=torch.int64)

# frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
frame_fires_idxs = indices < batch_len.unsqueeze(1)
frame_fires[frame_fires_idxs] = frames
return frame_fires, fires

@torch.jit.script
def cif_export(hidden, alphas, threshold: float):
batch_size, len_time, hidden_size = hidden.size()
threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)

# loop vars
integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device)
frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
# intermediate vars along time
list_fires = []
list_frames = []

for t in range(len_time):
    alpha = alphas[:, t]
    distribution_completion = (
        torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate
    )

    integrate += alpha
    list_fires.append(integrate)

    fire_place = integrate >= threshold
    integrate = torch.where(
        fire_place,
        integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device),
        integrate,
    )
    cur = torch.where(fire_place, distribution_completion, alpha)
    remainds = alpha - cur

    frame += cur[:, None] * hidden[:, t, :]
    list_frames.append(frame)
    frame = torch.where(
        fire_place[:, None].repeat(1, hidden_size), remainds[:, None] * hidden[:, t, :], frame
    )

fires = torch.stack(list_fires, 1)
frames = torch.stack(list_frames, 1)

fire_idxs = fires >= threshold
frame_fires = torch.zeros_like(hidden)
max_label_len = frames[0, fire_idxs[0]].size(0)
for b in range(batch_size):
    frame_fire = frames[b, fire_idxs[b]]
    frame_len = frame_fire.size(0)
    frame_fires[b, :frame_len, :] = frame_fire

    if frame_len >= max_label_len:
        max_label_len = frame_len
frame_fires = frame_fires[:, :max_label_len, :]
return frame_fires, fires

class mae_loss(torch.nn.Module):

def __init__(self, normalize_length=False):
    super(mae_loss, self).__init__()
    self.normalize_length = normalize_length
    self.criterion = torch.nn.L1Loss(reduction="sum")

def forward(self, token_length, pre_token_length):
    loss_token_normalizer = token_length.size(0)
    if self.normalize_length:
        loss_token_normalizer = token_length.sum().type(torch.float32)
    loss = self.criterion(token_length, pre_token_length)
    loss = loss / loss_token_normalizer
    return loss

def cif(hidden, alphas, threshold):
batch_size, len_time, hidden_size = hidden.size()

# loop vars
integrate = torch.zeros([batch_size], device=hidden.device)
frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
# intermediate vars along time
list_fires = []
list_frames = []

for t in range(len_time):
    alpha = alphas[:, t]
    distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate

    integrate += alpha
    list_fires.append(integrate)

    fire_place = integrate >= threshold
    integrate = torch.where(
        fire_place, integrate - torch.ones([batch_size], device=hidden.device), integrate
    )
    cur = torch.where(fire_place, distribution_completion, alpha)
    remainds = alpha - cur

    frame += cur[:, None] * hidden[:, t, :]
    list_frames.append(frame)
    frame = torch.where(
        fire_place[:, None].repeat(1, hidden_size), remainds[:, None] * hidden[:, t, :], frame
    )

fires = torch.stack(list_fires, 1)
frames = torch.stack(list_frames, 1)
list_ls = []
len_labels = torch.round(alphas.sum(-1)).int()
max_label_len = len_labels.max()
for b in range(batch_size):
    fire = fires[b, :]
    l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
    pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
    list_ls.append(torch.cat([l, pad_l], 0))
return torch.stack(list_ls, 0), fires

def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
batch_size, len_time = alphas.size()
device = alphas.device
dtype = alphas.dtype

threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)

fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)

# prefix_sum = torch.cumsum(alphas, dim=1)
prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
    torch.float32
)  # cumsum precision degradation causes wrong results in extreme cases
prefix_sum_floor = torch.floor(prefix_sum)
dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)

dislocation_prefix_sum_floor[:, 0] = 0
dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor

fire_idxs = dislocation_diff > 0
fires[fire_idxs] = 1
fires = fires + prefix_sum - prefix_sum_floor
if return_fire_idxs:
    return fires, fire_idxs
return fires

def cif_v1(hidden, alphas, threshold):
fires, fire_idxs = cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=True)

device = hidden.device
dtype = hidden.dtype
batch_size, len_time, hidden_size = hidden.size()
# frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
# prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)

frames = prefix_sum_hidden[fire_idxs]
shift_frames = torch.roll(frames, 1, dims=0)

batch_len = fire_idxs.sum(1)
batch_idxs = torch.cumsum(batch_len, dim=0)
shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
shift_batch_idxs[0] = 0
shift_frames[shift_batch_idxs] = 0

remains = fires - torch.floor(fires)
# remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]

shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
shift_remain_frames[shift_batch_idxs] = 0

frames = frames - shift_frames + shift_remain_frames - remain_frames

# max_label_len = batch_len.max()
max_label_len = (
    torch.round(alphas.sum(-1)).int().max()
)  # torch.round to calculate the max length

# frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
frame_fires_idxs = indices < batch_len.unsqueeze(1)
frame_fires[frame_fires_idxs] = frames
return frame_fires, fires

def cif_wo_hidden(alphas, threshold):
batch_size, len_time = alphas.size()

# loop vars
integrate = torch.zeros([batch_size], device=alphas.device)
# intermediate vars along time
list_fires = []

for t in range(len_time):
    alpha = alphas[:, t]

    integrate += alpha
    list_fires.append(integrate)

    fire_place = integrate >= threshold
    integrate = torch.where(
        fire_place,
        integrate - torch.ones([batch_size], device=alphas.device) * threshold,
        integrate,
    )

fires = torch.stack(list_fires, 1)
return fires
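For reference, a small sanity check, assuming the file above is importable as funasr.models.paraformer.cif_predictor: run the original cif and cif_v1 on the same random inputs and compare. With threshold=1.0 the two are expected to return matching shapes and closely matching values.

    import torch

    from funasr.models.paraformer.cif_predictor import cif, cif_v1

    torch.manual_seed(0)
    hidden = torch.randn(2, 50, 8)    # (batch, frames, dim)
    alphas = torch.rand(2, 50) * 0.3  # per-frame weights, as the predictor would produce

    emb_a, fires_a = cif(hidden, alphas, threshold=1.0)
    emb_b, fires_b = cif_v1(hidden, alphas, threshold=1.0)

    print(emb_a.shape, emb_b.shape)                 # expected to match
    print(torch.allclose(emb_a, emb_b, atol=1e-3))  # expected True, up to float error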

@dtlzhuangz (Contributor)

@FastSchnell Let me know whether my repo works for you; if it does, I'll submit a PR with the fix.

@FastSchnell

https://drive.google.com/file/d/1ssUYvoCvwQZhLUhhuOBzabVeirjE2Ze3/view?usp=sharing
Still doesn't work. Could it be a problem with my dataset?
File "/root/miniconda3/envs/myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/myenv/lib/python3.11/site-packages/funasr/models/sanm/attention.py", line 518, in forward
inputs = inputs * mask
~~~~~~~^~~~~~
RuntimeError: The size of tensor a (13) must match the size of tensor b (14) at non-singleton dimension 1

@dtlzhuangz (Contributor)

The code you're running is probably not the code you cloned. After cloning, did you run pip install -e . ? (See the check sketched below.)
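A minimal sketch of that check (illustrative, not part of the original comment): print where the imported funasr package actually lives. With pip install -e . it should point into the clone; a path under site-packages means the cloned code is not the one being used.

    import funasr

    print(funasr.__file__)                            # path of the package actually imported
    print(getattr(funasr, "__version__", "unknown"))  # installed version, if exposed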

@FastSchnell

I downloaded the main branch as a zip to the server and ran it there; I also tried installing it as a package with pip install -e ., and it fails with the same error either way. I suspect it's my dataset, because the project's test data doesn't trigger the error.

@dtlzhuangz (Contributor)

Send your WeChat ID to my email and I'll add you.

@dtlzhuangz (Contributor)

@LauraGPT This has been resolved; the latest code had not actually been installed.
