import json
import multiprocessing as mp
import re
from collections import defaultdict
from functools import partial
from typing import Dict, List, Optional, Set, Tuple, Type

from datasets import Dataset
from datasketch import MinHash, MinHashLSH
from dpu_utils.utils.iterators import ThreadedIterator
from tqdm import tqdm


NON_ALPHA = re.compile("[^A-Za-z_0-9]")
# parameters used in DuplicationIndex
MIN_NUM_TOKENS = 10
NUM_PERM = 256


def get_min_hash(tokens: List[str]) -> Optional[MinHash]:
    """Compute the MinHash of a code snippet."""
    if len(tokens) < MIN_NUM_TOKENS:
        return None
    min_hash = MinHash(num_perm=NUM_PERM)
    for token in set(tokens):
        min_hash.update(token.encode())
    return min_hash


def get_tokens(code: str) -> Set[str]:
    """Tokenize a code snippet."""
    return {t for t in NON_ALPHA.split(code) if len(t.strip()) > 0}
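# Usage sketch (illustrative only, not part of the original script): `get_tokens`
# splits a snippet on non-alphanumeric characters and `get_min_hash` turns the
# tokens into a datasketch MinHash. Snippets with fewer than MIN_NUM_TOKENS tokens
# are dropped (get_min_hash returns None). For two snippets that both pass the
# filter, `m1.jaccard(m2)` estimates the exact token-set Jaccard similarity.
# `snippet_a` and `snippet_b` below are hypothetical code strings with >= 10 tokens.
#
#   short = "def add(a, b):\n    return a + b"
#   assert get_min_hash(list(get_tokens(short))) is None  # too few tokens
#
#   m1 = get_min_hash(list(get_tokens(snippet_a)))
#   m2 = get_min_hash(list(get_tokens(snippet_b)))
#   if m1 is not None and m2 is not None:
#       estimate = m1.jaccard(m2)  # approximates jaccard_similarity(snippet_a, snippet_b)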
class DuplicationIndex:
    def __init__(
        self,
        *,
        duplication_jaccard_threshold: float = 0.85,
    ):
        self._duplication_jaccard_threshold = duplication_jaccard_threshold
        self._num_perm = NUM_PERM
        self._index = MinHashLSH(threshold=self._duplication_jaccard_threshold, num_perm=self._num_perm)

        self._duplicate_clusters = defaultdict(set)

    def add(self, code_key: Tuple, min_hash: MinHash) -> None:
        """Add a key to _index (MinHashLSH).

        The min_hash is used to query the closest matches based on the Jaccard threshold.
        The new key is either added to an existing cluster of one close match,
        or a new cluster is created. The clusters built this way depend on the order of insertion.

        Args:
            code_key (Tuple of (index, repo_name, path)):
                Theoretically any hashable key. Here we use a tuple so the information can be retrieved later.
            min_hash: MinHash of the code_key.
        """
        close_duplicates = self._index.query(min_hash)
        if code_key in self._index.keys:
            print(f"Duplicate key {code_key}")
            return

        self._index.insert(code_key, min_hash)
        if len(close_duplicates) > 0:
            for base_duplicate in close_duplicates:
                if base_duplicate in self._duplicate_clusters:
                    self._duplicate_clusters[base_duplicate].add(code_key)
                    break
            else:
                self._duplicate_clusters[close_duplicates[0]].add(code_key)

    def get_duplicate_clusters(self) -> List[List[Dict]]:
        """Export the duplicate clusters.

        For each cluster, the first element is the base element of the cluster.
        The base element has an estimated Jaccard similarity above the threshold with all the other elements.

        Returns:
            duplicate_clusters (List[List[Dict]]):
                List of duplicate clusters.
        """
        duplicate_clusters = []
        for base, duplicates in self._duplicate_clusters.items():
            cluster = [base] + list(duplicates)
            # reformat the cluster to be a list of dicts
            cluster = [{"base_index": el[0], "repo_name": el[1], "path": el[2]} for el in cluster]
            duplicate_clusters.append(cluster)
        return duplicate_clusters

    def save(self, filepath) -> None:
        duplicate_clusters = self.get_duplicate_clusters()
        with open(filepath, "w") as f:
            json.dump(duplicate_clusters, f)
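# Usage sketch (illustrative only; `snippets` is a hypothetical list of
# (repo_name, path, content) records): keys whose MinHashes collide in the LSH
# index above the threshold end up in the same duplicate cluster.
#
#   index = DuplicationIndex(duplication_jaccard_threshold=0.85)
#   for i, (repo_name, path, content) in enumerate(snippets):
#       min_hash = get_min_hash(list(get_tokens(content)))
#       if min_hash is not None:
#           index.add((i, repo_name, path), min_hash)
#   clusters = index.get_duplicate_clusters()  # List of [{"base_index", "repo_name", "path"}, ...]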
def _compute_min_hash(element):
    index, data = element
    min_hash = get_min_hash([t for t in NON_ALPHA.split(data["content"]) if len(t.strip()) > 0])
    if min_hash is not None:
        return (index, data["repo_name"], data["path"]), min_hash


def minhash_iter(dataset_iterator: Type[Dataset]):
    with mp.Pool() as pool:
        for data in pool.imap_unordered(
            _compute_min_hash,
            ThreadedIterator(dataset_iterator, max_queue_size=10000),
            chunksize=100,
        ):
            if data is not None:
                yield data


def make_duplicate_clusters(dataset_iterator: Type[Dataset], jaccard_threshold: float):
    """Find duplicate clusters in the dataset in two steps:

    1. Compute a MinHash for each code snippet. MinHash is a tool for fast Jaccard similarity estimation.
       This step runs in an asynchronous multiprocessing pool, see minhash_iter.
    2. Find duplicate clusters. The computed MinHashes are added sequentially to the DuplicationIndex.
       This step cannot be parallelized, so the asynchronous processing in the previous step helps to speed things up.
    """
    di = DuplicationIndex(duplication_jaccard_threshold=jaccard_threshold)

    for filename, min_hash in tqdm(ThreadedIterator(minhash_iter(enumerate(dataset_iterator)), max_queue_size=100)):
        di.add(filename, min_hash)

    # Returns a List[Cluster] where each Cluster is a List[Dict] describing the duplicated files.
    return di.get_duplicate_clusters()
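# Sketch of the returned structure (illustrative values only; `ds` is a hypothetical
# datasets.Dataset with "content", "repo_name" and "path" columns). Each cluster is a
# list of dicts pointing back into the dataset, with the first entry acting as the
# base element of the cluster.
#
#   clusters = make_duplicate_clusters(ds, jaccard_threshold=0.85)
#   # clusters == [
#   #     [{"base_index": 0, "repo_name": "org/repo", "path": "a.py"},
#   #      {"base_index": 7, "repo_name": "org/fork", "path": "a.py"}],
#   #     ...
#   # ]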
def jaccard_similarity(code1: str, code2: str) -> float:
    """Compute the Jaccard similarity of two code snippets."""
    tokens1 = get_tokens(code1)
    tokens2 = get_tokens(code2)
    return len(tokens1 & tokens2) / len(tokens1 | tokens2)


_shared_dataset = None


def _find_cluster_extremes_shared(cluster, jaccard_threshold):
    """Find a reduced cluster such that each code in the original cluster is similar to at least one code in the reduced cluster.

    Two codes are similar if their Jaccard similarity is above the threshold.

    Args:
        cluster (List[dict]):
            A cluster is a list of dicts, each containing the following keys:
                - base_index
                - repo_name
                - path
            This is a typical output of DuplicationIndex.get_duplicate_clusters().
        jaccard_threshold (float):
            Threshold for Jaccard similarity.
            Two codes are similar if their Jaccard similarity is above the threshold.

    Returns:
        extremes (List[dict]):
            A reduced representation of the cluster. The field "copies" is added to each dict.
            The "copies" field indicates the number of similar codes in the cluster for an extreme.
    """
    extremes = []
    for element1 in cluster:
        code1 = _shared_dataset[element1["base_index"]]["content"]
        for element2 in extremes:
            code2 = _shared_dataset[element2["base_index"]]["content"]
            if jaccard_similarity(code1, code2) >= jaccard_threshold:
                element2["copies"] += 1
                break
        else:
            element1["copies"] = 1
            extremes.append(element1)
    return extremes
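# Toy walk-through (illustrative only): with `_shared_dataset` already set, e.g. by
# find_extremes below, the greedy loop keeps the first element of the cluster as an
# extreme and folds every later element that is similar to an existing extreme into
# that extreme's "copies" count; dissimilar elements become new extremes.
#
#   cluster = [{"base_index": 0, "repo_name": "r", "path": "a.py"},
#              {"base_index": 1, "repo_name": "r", "path": "b.py"}]
#   extremes = _find_cluster_extremes_shared(cluster, jaccard_threshold=0.85)
#   # if snippets 0 and 1 are similar: extremes == [{"base_index": 0, ..., "copies": 2}]
#   # otherwise both stay, each with "copies" == 1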
def find_extremes(cluster_list, dataset, jaccard_threshold):
    """Call the _find_cluster_extremes_shared function in a parallel fashion.

    Args:
        cluster_list (List[List[Dict]]):
            Each cluster is a list of dicts with the key base_index,
            referring to the index of the base code in the dataset.
        dataset (Type[Dataset]):
            The dataset is used to access the content of the code snippets,
            using the base_index from the cluster_list.
            The dataset is shared between all the processes via a global variable,
            otherwise multiprocessing brings no speedup.
        jaccard_threshold (float):
            The threshold for the Jaccard similarity (deduplicate_dataset uses 0.85 by default).

    Returns:
        extremes_list (List[List[Dict]]):
            Each cluster is reduced to its extremes.
            See _find_cluster_extremes_shared for the definition of extremes.
    """
    global _shared_dataset
    _shared_dataset = dataset
    extremes_list = []
    f = partial(_find_cluster_extremes_shared, jaccard_threshold=jaccard_threshold)
    with mp.Pool() as pool:
        for extremes in tqdm(
            pool.imap_unordered(
                f,
                cluster_list,
            ),
            total=len(cluster_list),
        ):
            extremes_list.append(extremes)
    return extremes_list
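# Wiring sketch (illustrative only; `ds` is a hypothetical dataset): this is roughly
# how deduplicate_dataset below combines the two stages. Note that imap_unordered does
# not preserve cluster order, so results are keyed by base_index rather than by position.
#
#   clusters = make_duplicate_clusters(ds, jaccard_threshold=0.85)
#   extremes_clusters = find_extremes(clusters, ds, jaccard_threshold=0.85)
#   extreme_dict = {e["base_index"]: e for cluster in extremes_clusters for e in cluster}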
def deduplicate_dataset(
    dataset: Type[Dataset], jaccard_threshold: float = 0.85
) -> Tuple[Type[Dataset], List[List[Dict]]]:
    """Deduplicate the dataset using MinHash and Jaccard similarity.

    This function first generates duplicate clusters, then each cluster is reduced
    to a set of extremes such that every element of the cluster is similar to at least one extreme.
    Codes are considered similar if their Jaccard similarity is greater than jaccard_threshold (0.85 by default).

    Args:
        dataset (Type[Dataset]):
            The dataset to deduplicate.
        jaccard_threshold (float, default=0.85):
            Jaccard threshold to determine if two codes are similar.

    Returns:
        ds_dedup (Type[Dataset]):
            The deduplicated dataset.
        duplicate_clusters (List[List[Dict]]):
            The list of duplicate clusters.
            Each cluster is a list of dicts with the following keys:
            - base_index : int
                The index of the code in the original dataset.
            - repo_name : str
            - path : str
            - copies : int
                The number of copies of the code in the cluster (see _find_cluster_extremes_shared).
            - is_extreme : bool
                Whether the code is an extreme in the cluster.
            All the codes in the cluster are removed from the dataset except the extremes.

    Example:
        >>> from datasets import load_dataset
        >>> from minhash_deduplication import deduplicate_dataset
        >>> ds = load_dataset("lvwerra/codeparrot-clean", split="train")
        >>> ds_dedup, duplicate_clusters = deduplicate_dataset(ds, jaccard_threshold=0.85)
    """
    duplicate_clusters = make_duplicate_clusters(dataset, jaccard_threshold)
    duplicate_indices = {x["base_index"] for cluster in duplicate_clusters for x in cluster}
    extreme_dict = {}
    extremes_clusters = find_extremes(duplicate_clusters, dataset, jaccard_threshold)
    for extremes in extremes_clusters:
        for element in extremes:
            extreme_dict[element["base_index"]] = element
    remove_indices = duplicate_indices - set(extreme_dict.keys())
    ds_filter = dataset.filter(lambda x, idx: idx not in remove_indices, with_indices=True)

    # update duplicate_clusters with the extreme information
    for cluster in duplicate_clusters:
        for element in cluster:
            element["is_extreme"] = element["base_index"] in extreme_dict
            if element["is_extreme"]:
                element["copies"] = extreme_dict[element["base_index"]]["copies"]

    print(f"Original dataset size: {len(dataset)}")
    print(f"Number of duplicate clusters: {len(duplicate_clusters)}")
    print(f"Files in duplicate clusters: {len(duplicate_indices)}")
    print(f"Unique files in duplicate clusters: {len(extreme_dict)}")
    print(f"Filtered dataset size: {len(ds_filter)}")

    return ds_filter, duplicate_clusters