diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py
index 612da260..d914ebf9 100644
--- a/TTS/tts/layers/xtts/gpt.py
+++ b/TTS/tts/layers/xtts/gpt.py
@@ -426,15 +426,6 @@ class GPT(nn.Module):
         if max_mel_len > audio_codes.shape[-1]:
             audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1]))
 
-        silence = True
-        for idx, l in enumerate(code_lengths):
-            length = l.item()
-            while silence:
-                if audio_codes[idx, length - 1] != 83:
-                    break
-                length -= 1
-            code_lengths[idx] = length
-
         # 💖 Lovely assertions
         assert (
             max_mel_len <= audio_codes.shape[-1]
@@ -450,7 +441,7 @@ class GPT(nn.Module):
         audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token)
 
         # Pad mel codes with stop_audio_token
-        audio_codes = self.set_mel_padding(audio_codes, code_lengths)
+        audio_codes = self.set_mel_padding(audio_codes, code_lengths - 3) # -3 to get the real code lengths without consider start and stop tokens that was not added yet
 
         # Build input and target tensors
         # Prepend start token to inputs and append stop token to targets
diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py
index 211d0a93..7726d829 100644
--- a/TTS/tts/layers/xtts/tokenizer.py
+++ b/TTS/tts/layers/xtts/tokenizer.py
@@ -115,7 +115,7 @@ _abbreviations = {
             # There are not many common abbreviations in Arabic as in English.
         ]
     ],
-    "zh": [
+    "zh-cn": [
         (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
         for x in [
             # Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
@@ -280,7 +280,7 @@ _symbols_multilingual = {
             ("°", " درجة "),
         ]
     ],
-    "zh": [
+    "zh-cn": [
         # Chinese
         (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
         for x in [
@@ -571,7 +571,7 @@ class VoiceBpeTokenizer:
             )
 
     def preprocess_text(self, txt, lang):
-        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "zh-cn"}:
+        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "zh-cn"}:
             txt = multilingual_cleaners(txt, lang)
             if lang in {"zh", "zh-cn"}:
                 txt = chinese_transliterate(txt)
@@ -682,8 +682,8 @@ def test_expand_numbers_multilingual():
         ("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"),
         ("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"),
         # Chinese (Simplified)
-        ("在12.5秒内", "在十二点五秒内", "zh"),
-        ("有50名士兵", "有五十名士兵", "zh"),
+        ("在12.5秒内", "在十二点五秒内", "zh-cn"),
+        ("有50名士兵", "有五十名士兵", "zh-cn"),
         # ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work
         # ("那将是20€先生", '那将是二十欧元先生', 'zh'),
         # Turkish
@@ -764,7 +764,7 @@ def test_symbols_multilingual():
         ("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"),
         ("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"),
         ("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"),
-        ("我的电量为 14%", "我的电量为 14 百分之", "zh"),
+        ("我的电量为 14%", "我的电量为 14 百分之", "zh-cn"),
         ("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"),
         ("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"),
         ("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"),
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 91985912..f37f0844 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -7,7 +7,6 @@ import torch.nn.functional as F
 import torchaudio
 from coqpit import Coqpit
 
-from TTS.tts.layers.tortoise.audio_utils import wav_to_univnet_mel
 from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
@@ -308,26 +307,6 @@ class Xtts(BaseTTS):
             cond_latent = self.gpt.get_style_emb(mel.to(self.device))
         return cond_latent.transpose(1, 2)
 
-    @torch.inference_mode()
-    def get_diffusion_cond_latents(self, audio, sr):
-        from math import ceil
-
-        diffusion_conds = []
-        CHUNK_SIZE = 102400
-        audio_24k = torchaudio.functional.resample(audio, sr, 24000)
-        for chunk in range(ceil(audio_24k.shape[1] / CHUNK_SIZE)):
-            current_sample = audio_24k[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
-            current_sample = pad_or_truncate(current_sample, CHUNK_SIZE)
-            cond_mel = wav_to_univnet_mel(
-                current_sample.to(self.device),
-                do_normalization=False,
-                device=self.device,
-            )
-            diffusion_conds.append(cond_mel)
-        diffusion_conds = torch.stack(diffusion_conds, dim=1)
-        diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
-        return diffusion_latent
-
     @torch.inference_mode()
     def get_speaker_embedding(self, audio, sr):
         audio_16k = torchaudio.functional.resample(audio, sr, 16000)
@@ -575,16 +554,6 @@ class Xtts(BaseTTS):
                 return_attentions=False,
                 return_latent=True,
             )
-            silence_token = 83
-            ctokens = 0
-            for k in range(gpt_codes.shape[-1]):
-                if gpt_codes[0, k] == silence_token:
-                    ctokens += 1
-                else:
-                    ctokens = 0
-                if ctokens > 8:
-                    gpt_latents = gpt_latents[:, :k]
-                    break
 
             if length_scale != 1.0:
                 gpt_latents = F.interpolate(