|
@@ -208,7 +208,9 @@ class UnitYAgentTreePipeline(UnitYPipelineMixin, TreeAgentPipeline): # type: ig
|
|
assert len(self.pipeline) > 0
|
|
assert len(self.pipeline) > 0
|
|
module_dict = {}
|
|
module_dict = {}
|
|
for module_class, children in self.pipeline.items():
|
|
for module_class, children in self.pipeline.items():
|
|
- module_dict[module_class.from_args(args, *models_and_configs)] = children
|
|
|
|
|
|
+ module_dict[module_class.from_args(args, **models_and_configs)] = children
|
|
|
|
+
|
|
|
|
+ super().__init__(module_dict, args)
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
def from_args(cls, args: Any) -> UnitYAgentPipeline:
|
|
def from_args(cls, args: Any) -> UnitYAgentPipeline:
|
|
@@ -225,16 +227,22 @@ class UnitYAgentTreePipeline(UnitYPipelineMixin, TreeAgentPipeline): # type: ig
|
|
assert len(states) == len(self.module_dict)
|
|
assert len(states) == len(self.module_dict)
|
|
first_states = states[self.source_module]
|
|
first_states = states[self.source_module]
|
|
|
|
|
|
- if not first_states.source_finished and any(
|
|
|
|
- segment.finished for segment in output_segment
|
|
|
|
- ):
|
|
|
|
|
|
+ if isinstance(output_segment, list):
|
|
|
|
+ finished = any(segment.finished for segment in output_segment)
|
|
|
|
+ else:
|
|
|
|
+ # case when output_index is used
|
|
|
|
+ finished = output_segment.finished
|
|
|
|
+ if not first_states.source_finished and finished:
|
|
# An early stop.
|
|
# An early stop.
|
|
# The temporary solution is to start over
|
|
# The temporary solution is to start over
|
|
if states is not None:
|
|
if states is not None:
|
|
maybe_reset_states(states)
|
|
maybe_reset_states(states)
|
|
else:
|
|
else:
|
|
self.reset()
|
|
self.reset()
|
|
- for segment in output_segment:
|
|
|
|
- segment.finished = False
|
|
|
|
|
|
+ if isinstance(output_segment, list):
|
|
|
|
+ for segment in output_segment:
|
|
|
|
+ segment.finished = False
|
|
|
|
+ else:
|
|
|
|
+ output_segment.finished = False
|
|
|
|
|
|
return output_segment # type: ignore[no-any-return]
|
|
return output_segment # type: ignore[no-any-return]
|