| 
					
				 | 
			
			
				@@ -338,8 +338,13 @@ class UnitYGenerator: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # Convert to speech units. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         units = self.unit_decoder(unit_seqs) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if ngram_filtering: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            arr = remove_consecutive_repeated_ngrams(units.cpu().numpy().tolist()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            units = torch.tensor(arr) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # ngram-filtering doesn't apply to NAR unit decoding. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if ngram_filtering and isinstance(self.model.t2u_model, UnitYT2UModel): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if units.size(0) > 1: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                raise NotImplementedError( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "unit ngram_filtering is not implemented for batch_size > 1." 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            arr = remove_consecutive_repeated_ngrams(units[0].tolist()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            units = torch.tensor(arr).to(units).unsqueeze(0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return texts, units 
			 |