Skip to content

Semantic

SemanticSplitter

Bases: Node

Semantic-similarity splitter.

Splits text where consecutive sentence-group embeddings diverge above a configurable threshold. Re-uses any dynamiq `TextEmbedder` node to produce embeddings.

Source code in dynamiq/nodes/splitters/semantic.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
class SemanticSplitter(Node):
    """Semantic-similarity splitter.

    Splits text where consecutive sentence-group embeddings diverge above a
    configurable threshold. Re-uses any dynamiq :class:`TextEmbedder` node to
    produce embeddings.

    The splitting itself is delegated to ``SemanticSplitterComponent``; this
    node supplies it with an embedding function (:meth:`_embed_batch`) backed
    by the configured ``embedder`` node.
    """

    group: Literal[NodeGroup.SPLITTERS] = NodeGroup.SPLITTERS
    name: str = "SemanticSplitter"
    description: str = "Splits text on semantic-similarity breakpoints."

    # Required at init time: init_components raises ValueError when this is None.
    embedder: TextEmbedder | None = Field(default=None, description="Text embedder node used to embed sentence groups.")
    breakpoint_threshold_type: BreakpointThresholdType = Field(
        default=BreakpointThresholdType.PERCENTILE,
        description="Statistical method to compute breakpoint threshold.",
    )
    # None lets the underlying component pick a default for the chosen threshold type.
    breakpoint_threshold_amount: float | None = Field(
        default=None,
        description="Threshold amount; defaults depend on threshold type.",
    )
    number_of_chunks: int | None = Field(
        default=None,
        description="If set, picks top-N largest distances as breakpoints (overrides threshold).",
    )
    buffer_size: int = Field(
        default=1,
        ge=0,
        description="Number of neighbour sentences in each embedding group.",
    )
    sentence_split_regex: str = Field(
        default=r"(?<=[.?!])\s+",
        description="Regex used to split text into sentences before grouping.",
    )
    min_chunk_size: int = Field(
        default=0,
        ge=0,
        description="Merge tail chunks shorter than this back into the previous chunk.",
    )

    # Lazily-built SemanticSplitterComponent; created once in init_components.
    splitter: Any | None = None
    input_schema: ClassVar[type[BaseModel]] = SemanticSplitterInputSchema

    @property
    def to_dict_exclude_params(self) -> dict[str, Any]:
        """Extend the parent's exclusion map with the splitter component and embedder.

        The embedder is excluded here so that :meth:`to_dict` can serialize it
        explicitly through the embedder node's own ``to_dict``.
        """
        return super().to_dict_exclude_params | {"splitter": True, "embedder": True}

    def to_dict(
        self,
        include_secure_params: bool = False,
        for_tracing: bool = False,
        **kwargs,
    ) -> dict[str, Any]:
        """Serialize the node, nesting the embedder's own serialized form.

        Args:
            include_secure_params: Forwarded to both this node's and the
                embedder's ``to_dict``.
            for_tracing: Forwarded to both this node's and the embedder's
                ``to_dict``.
            **kwargs: Extra options forwarded unchanged to both calls.

        Returns:
            The parent serialization with an ``"embedder"`` entry added when an
            embedder is configured.
        """
        data = super().to_dict(
            include_secure_params=include_secure_params,
            for_tracing=for_tracing,
            **kwargs,
        )
        if self.embedder is not None:
            data["embedder"] = self.embedder.to_dict(
                include_secure_params=include_secure_params,
                for_tracing=for_tracing,
                **kwargs,
            )
        return data

    def init_components(self, connection_manager: ConnectionManager | None = None) -> None:
        """Initialize the embedder node and build the splitting component.

        Args:
            connection_manager: Shared connection manager; a fresh
                ``ConnectionManager`` is created when omitted. The same
                manager is passed to the embedder node.

        Raises:
            ValueError: If no ``embedder`` node is configured.
        """
        connection_manager = connection_manager or ConnectionManager()
        super().init_components(connection_manager)
        if self.embedder is None:
            raise ValueError("SemanticSplitter requires an `embedder` (TextEmbedder) node.")
        self.embedder.init_components(connection_manager)
        # Build the component only once so repeated init calls keep the same instance.
        if self.splitter is None:
            self.splitter = SemanticSplitterComponent(
                embed_fn=self._embed_batch,
                breakpoint_threshold_type=self.breakpoint_threshold_type,
                breakpoint_threshold_amount=self.breakpoint_threshold_amount,
                number_of_chunks=self.number_of_chunks,
                buffer_size=self.buffer_size,
                sentence_split_regex=self.sentence_split_regex,
                min_chunk_size=self.min_chunk_size,
            )

    def _embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` using the embedder node's underlying component.

        Each text is wrapped with the component's ``prefix``/``suffix``, has
        newlines replaced by spaces, and is passed through the component's
        truncation before being embedded in batches of ``component.batch_size``.

        Args:
            texts: Sentence-group strings to embed.

        Returns:
            One embedding vector per input text, in input order. An empty input
            yields an empty list without contacting the embedding backend.

        Raises:
            ValueError: If the embedder is not initialized, or if the model
                returns an embedding that fails validation.
        """
        if self.embedder is None or self.embedder.text_embedder is None:
            raise ValueError("SemanticSplitter requires an initialized `embedder`.")
        # Nothing to embed: skip preprocessing and the backend batch call entirely.
        if not texts:
            return []

        component = self.embedder.text_embedder
        texts_to_embed = [
            component._apply_text_truncation(f"{component.prefix}{text}{component.suffix}".replace("\n", " "))
            for text in texts
        ]
        # The second element returned alongside the embeddings is not needed here.
        embeddings, _ = component._embed_texts_batch(
            texts_to_embed=texts_to_embed,
            batch_size=component.batch_size,
        )
        try:
            for embedding in embeddings:
                BaseEmbedder.validate_embedding(embedding)
        except InvalidEmbeddingError as e:
            logger.error(f"Invalid embedding returned by model {component.model}: {str(e)}")
            # Chain the original error so the validation cause stays in the traceback.
            raise ValueError(f"Invalid embedding returned by the model: {str(e)}") from e
        return embeddings

    def execute(
        self,
        input_data: SemanticSplitterInputSchema,
        config: RunnableConfig | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """Split the input documents on semantic breakpoints.

        Args:
            input_data: Validated input carrying the ``documents`` to split.
            config: Optional runnable configuration; normalized via
                ``ensure_config`` and used for its callbacks.
            **kwargs: Forwarded to the node-execute callbacks.

        Returns:
            A dict with a single ``"documents"`` key holding the split output
            of the underlying splitter component.
        """
        config = ensure_config(config)
        self.run_on_node_execute_run(config.callbacks, **kwargs)
        documents = input_data.documents
        logger.debug(f"SemanticSplitter: splitting {len(documents)} documents.")
        output = self.splitter.run(documents=documents)
        return {"documents": output["documents"]}