diff --git a/privacy_guard/analysis/extraction/text_inclusion_analysis_input.py b/privacy_guard/analysis/extraction/text_inclusion_analysis_input.py
index bbb2990..2d91903 100644
--- a/privacy_guard/analysis/extraction/text_inclusion_analysis_input.py
+++ b/privacy_guard/analysis/extraction/text_inclusion_analysis_input.py
@@ -51,6 +51,7 @@ def __init__(
         lcs_bound_config: LCSBoundConfig | None = None,
         disable_word_level_longest_common_subsequence: bool = False,
         disable_char_level_longest_common_subsequence: bool = True,
+        remove_consecutive_whitespace: bool = False,
     ) -> None:
         columns = generation_df.columns.tolist()
         assert (
@@ -79,6 +80,8 @@ def __init__(
             disable_char_level_longest_common_subsequence
         )
 
+        self.remove_consecutive_whitespace = remove_consecutive_whitespace
+
         super().__init__(df_train_user=generation_df, df_test_user=pd.DataFrame())
 
     @property
diff --git a/privacy_guard/analysis/extraction/text_inclusion_analysis_node.py b/privacy_guard/analysis/extraction/text_inclusion_analysis_node.py
index 40ba4c7..eda696a 100644
--- a/privacy_guard/analysis/extraction/text_inclusion_analysis_node.py
+++ b/privacy_guard/analysis/extraction/text_inclusion_analysis_node.py
@@ -141,6 +141,20 @@ def _clean_text(text: str) -> str:
     return cleaned_text
 
 
+def _clean_text_remove_consecutive_whitespace(text: str) -> str:
+    """Normalizes text.
+
+    - Lowercases
+    - Removes punctuation
+    - Turns newlines and tabs into spaces
+    - Strips leading and trailing whitespace
+    - Removes consecutive whitespace
+    """
+    cleaned_text = _clean_text(text=text)
+    cleaned_text = " ".join(cleaned_text.split())
+    return cleaned_text
+
+
 def _word_level_longest_common_subsequence_helper(
     s1: str, s2: str, autojunk: bool = True
 ) -> int:
@@ -299,6 +313,12 @@ def __init__(self, analysis_input: TextInclusionAnalysisInput) -> None:
             self.target_set_key
         ].apply(lambda x: len(x))
 
+        self.clean_text_method = (
+            _clean_text
+            if not analysis_input.remove_consecutive_whitespace
+            else _clean_text_remove_consecutive_whitespace
+        )
+
         super().__init__(analysis_input=analysis_input)
 
     def _compute_word_level_longest_common_subsequence_helper(
@@ -310,8 +330,8 @@
         Returns:
             int: Number of shared words between the two strings.
         """
-        s1 = _clean_text(row[s1_column or self.target_key])
-        s2 = _clean_text(row[s2_column or self.generation_key])
+        s1 = self.clean_text_method(row[s1_column or self.target_key])
+        s2 = self.clean_text_method(row[s2_column or self.generation_key])
         return _word_level_longest_common_subsequence_helper(s1, s2)
 
     def _compute_char_level_longest_common_subsequence_helper(
@@ -323,8 +343,8 @@
         Returns:
            int: Number of shared words between the two strings.
         """
-        s1 = _clean_text(row[s1_column or self.target_key])
-        s2 = _clean_text(row[s2_column or self.generation_key])
+        s1 = self.clean_text_method(row[s1_column or self.target_key])
+        s2 = self.clean_text_method(row[s2_column or self.generation_key])
        return _char_level_longest_common_subsequence_helper(s1, s2)
 
     def _compute_edit_similarity(
@@ -339,8 +359,8 @@
         Returns:
             int: Edit similarity between the two strings.
""" - s1 = _clean_text(row[s1_column or self.target_key]) - s2 = _clean_text(row[s2_column or self.generation_key]) + s1 = self.clean_text_method(row[s1_column or self.target_key]) + s2 = self.clean_text_method(row[s2_column or self.generation_key]) levenshtein = textdistance.levenshtein.similarity(s1, s2) return levenshtein @@ -366,8 +386,8 @@ def _compute_inclusion_score(self, row: pd.Series) -> bool: Returns: bool: True if the target is included in the output_text, False otherwise. """ - s1 = _clean_text(row[self.target_key]) - s2 = _clean_text(row[self.generation_key]) + s1 = self.clean_text_method(row[self.target_key]) + s2 = self.clean_text_method(row[self.generation_key]) return s1 in s2 def get_compute_longest_common_substring_map( @@ -418,11 +438,11 @@ def _compute_longest_common_substring_map( target_set = row[self.target_set_key] - comparison_text = _clean_text(row[comparison_key]) - fp_text = _clean_text(row[false_positive_key]) + comparison_text = self.clean_text_method(row[comparison_key]) + fp_text = self.clean_text_method(row[false_positive_key]) for target in target_set: - clean_target = _clean_text(target) + clean_target = self.clean_text_method(target) if lcs_bound_config is not None: lcs = _char_level_longest_common_substring_helper_bound(