Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
lcs_bound_config: LCSBoundConfig | None = None,
disable_word_level_longest_common_subsequence: bool = False,
disable_char_level_longest_common_subsequence: bool = True,
remove_consecutive_whitespace: bool = False,
) -> None:
columns = generation_df.columns.tolist()
assert (
Expand Down Expand Up @@ -79,6 +80,8 @@ def __init__(
disable_char_level_longest_common_subsequence
)

self.remove_consecutive_whitespace = remove_consecutive_whitespace

super().__init__(df_train_user=generation_df, df_test_user=pd.DataFrame())

@property
Expand Down
42 changes: 31 additions & 11 deletions privacy_guard/analysis/extraction/text_inclusion_analysis_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,20 @@ def _clean_text(text: str) -> str:
return cleaned_text


def _clean_text_remove_consecutive_whitespace(text: str) -> str:
"""Normalizes text.

- Lowercases
- Removes punctuation
- Turn newlines and tabs into spaces
- Strips leading and trailing whitespace
- Removes consecutive whitespace
"""
cleaned_text = _clean_text(text=text)
cleaned_text = " ".join(cleaned_text.split(" "))
return cleaned_text


def _word_level_longest_common_subsequence_helper(
s1: str, s2: str, autojunk: bool = True
) -> int:
Expand Down Expand Up @@ -299,6 +313,12 @@ def __init__(self, analysis_input: TextInclusionAnalysisInput) -> None:
self.target_set_key
].apply(lambda x: len(x))

self.clean_text_method = (
_clean_text
if not analysis_input.remove_consecutive_whitespace
else _clean_text_remove_consecutive_whitespace
)

super().__init__(analysis_input=analysis_input)

def _compute_word_level_longest_common_subsequence_helper(
Expand All @@ -310,8 +330,8 @@ def _compute_word_level_longest_common_subsequence_helper(
Returns:
int: Number of shared words between the two strings.
"""
s1 = _clean_text(row[s1_column or self.target_key])
s2 = _clean_text(row[s2_column or self.generation_key])
s1 = self.clean_text_method(row[s1_column or self.target_key])
s2 = self.clean_text_method(row[s2_column or self.generation_key])
return _word_level_longest_common_subsequence_helper(s1, s2)

def _compute_char_level_longest_common_subsequence_helper(
Expand All @@ -323,8 +343,8 @@ def _compute_char_level_longest_common_subsequence_helper(
Returns:
int: Number of shared words between the two strings.
"""
s1 = _clean_text(row[s1_column or self.target_key])
s2 = _clean_text(row[s2_column or self.generation_key])
s1 = self.clean_text_method(row[s1_column or self.target_key])
s2 = self.clean_text_method(row[s2_column or self.generation_key])
return _char_level_longest_common_subsequence_helper(s1, s2)

def _compute_edit_similarity(
Expand All @@ -339,8 +359,8 @@ def _compute_edit_similarity(
Returns:
int: Edit similarity between the two strings.
"""
s1 = _clean_text(row[s1_column or self.target_key])
s2 = _clean_text(row[s2_column or self.generation_key])
s1 = self.clean_text_method(row[s1_column or self.target_key])
s2 = self.clean_text_method(row[s2_column or self.generation_key])
levenshtein = textdistance.levenshtein.similarity(s1, s2)
return levenshtein

Expand All @@ -366,8 +386,8 @@ def _compute_inclusion_score(self, row: pd.Series) -> bool:
Returns:
bool: True if the target is included in the output_text, False otherwise.
"""
s1 = _clean_text(row[self.target_key])
s2 = _clean_text(row[self.generation_key])
s1 = self.clean_text_method(row[self.target_key])
s2 = self.clean_text_method(row[self.generation_key])
return s1 in s2

def get_compute_longest_common_substring_map(
Expand Down Expand Up @@ -418,11 +438,11 @@ def _compute_longest_common_substring_map(

target_set = row[self.target_set_key]

comparison_text = _clean_text(row[comparison_key])
fp_text = _clean_text(row[false_positive_key])
comparison_text = self.clean_text_method(row[comparison_key])
fp_text = self.clean_text_method(row[false_positive_key])

for target in target_set:
clean_target = _clean_text(target)
clean_target = self.clean_text_method(target)

if lcs_bound_config is not None:
lcs = _char_level_longest_common_substring_helper_bound(
Expand Down
Loading