Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions wefe/metrics/ECT.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from wefe.metrics.base_metric import BaseMetric
from wefe.preprocessing import get_embeddings_from_query
from wefe.query import Query
from wefe.word_embedding_model import WordEmbeddingModel
from wefe.models.base_model import BaseModel


class ECT(BaseMetric):
Expand Down Expand Up @@ -43,7 +43,8 @@ class ECT(BaseMetric):
def run_query(
self,
query: Query,
model: WordEmbeddingModel,

word_embedding: BaseModel,
lost_vocabulary_threshold: float = 0.2,
preprocessors: List[Dict[str, Union[str, bool, Callable]]] = [{}],
strategy: str = "first",
Expand All @@ -60,7 +61,7 @@ def run_query(
A Query object that contains the target and attribute word sets to be
tested.

model : WordEmbeddingModel
word_embedding : BaseModel
A word embedding model.

preprocessors : List[Dict[str, Union[str, bool, Callable]]]
Expand Down
6 changes: 3 additions & 3 deletions wefe/metrics/MAC.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from wefe.metrics.base_metric import BaseMetric
from wefe.preprocessing import get_embeddings_from_query
from wefe.query import Query
from wefe.word_embedding_model import WordEmbeddingModel
from wefe.models.base_model import BaseModel


class MAC(BaseMetric):
Expand Down Expand Up @@ -93,7 +93,7 @@ def _calc_mac(self, T, A):
def run_query(
self,
query: Query,
model: WordEmbeddingModel,
word_embedding: BaseModel,
lost_vocabulary_threshold: float = 0.2,
preprocessors: List[Dict[str, Union[str, bool, Callable]]] = [{}],
strategy: str = "first",
Expand All @@ -110,7 +110,7 @@ def run_query(
A Query object that contains the target and attribute word sets
to be tested.

model : WordEmbeddingModel
word_embedding : BaseModel
A word embedding model.

preprocessors : List[Dict[str, Union[str, bool, Callable]]]
Expand Down
124 changes: 116 additions & 8 deletions wefe/metrics/RND.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from wefe.metrics.base_metric import BaseMetric
from wefe.preprocessing import get_embeddings_from_query
from wefe.preprocessing import get_embeddings_from_query, get_related_embeddings_from_query
from wefe.query import Query
from wefe.word_embedding_model import WordEmbeddingModel
from wefe.models.base_model import BaseModel


class RND(BaseMetric):
Expand Down Expand Up @@ -66,7 +66,7 @@ def __calc_rnd(
) - self.__calc_distance(
attribute_embedding, target_2_avg_vector, distance_type=distance_type,
)

# add the distance of the neutral word to the accumulated
# distances.
sum_of_distances += current_distance
Expand All @@ -82,10 +82,57 @@ def __calc_rnd(
mean_distance = sum_of_distances / len(distance_by_words)
return mean_distance, sorted_distances_by_word

def __calc_rnd2(
    self,
    target_1: np.ndarray,
    target_2: np.ndarray,
    attribute_t1: np.ndarray,
    attribute_t2: np.ndarray,
    attribute_words: list,
    distance_type: str,
) -> Tuple[float, Dict[str, float]]:
    """Calculate RND using a separate attribute embedding per target set.

    For every attribute word, the distance between its first embedding and
    the average of ``target_1`` is compared against the distance between its
    second embedding and the average of ``target_2``.

    Parameters
    ----------
    target_1 : np.ndarray
        Embeddings of the first target set; averaged over axis 0.
    target_2 : np.ndarray
        Embeddings of the second target set; averaged over axis 0.
    attribute_t1 : np.ndarray
        Attribute embeddings compared against the ``target_1`` average.
    attribute_t2 : np.ndarray
        Attribute embeddings compared against the ``target_2`` average.
        Expected to be aligned position by position with ``attribute_t1``.
    attribute_words : list
        Words used as keys of the per-word result, aligned with the
        attribute embeddings.
    distance_type : str
        Forwarded unchanged to ``__calc_distance``.

    Returns
    -------
    Tuple[float, Dict[str, float]]
        The mean distance difference and a dictionary mapping each attribute
        word to its distance difference, ordered by ascending value.
        ``(nan, {})`` is returned when there is nothing to compare.
    """
    import warnings

    # The three attribute inputs are consumed in lockstep. A mismatch means
    # the pairs are misaligned, so report it through the warnings machinery
    # instead of printing to stdout, and only use the common prefix.
    if not (len(attribute_t1) == len(attribute_t2) == len(attribute_words)):
        warnings.warn(
            "RND received misaligned attribute inputs "
            "(attribute_t1: {}, attribute_t2: {}, attribute_words: {}); "
            "only the common leading elements are used.".format(
                len(attribute_t1), len(attribute_t2), len(attribute_words)
            ),
            stacklevel=2,
        )

    # calculates the average wv for the group words.
    target_1_avg_vector = np.average(target_1, axis=0)
    target_2_avg_vector = np.average(target_2, axis=0)

    sum_of_distances = 0.0
    distance_by_words = {}

    for word, attribute_embedding1, attribute_embedding2 in zip(
        attribute_words, attribute_t1, attribute_t2
    ):
        current_distance = self.__calc_distance(
            attribute_embedding1,
            target_1_avg_vector,
            distance_type=distance_type,
        ) - self.__calc_distance(
            attribute_embedding2,
            target_2_avg_vector,
            distance_type=distance_type,
        )

        # add the distance of the neutral word to the accumulated
        # distances.
        sum_of_distances += current_distance
        # add the distance of the neutral word to the distances by word.
        distance_by_words[word] = current_distance

    # Nothing was compared: return NaN rather than dividing by zero.
    if not distance_by_words:
        return np.nan, {}

    sorted_distances_by_word = dict(
        sorted(distance_by_words.items(), key=lambda item: item[1])
    )

    # calculate the average of the distances and return. The divisor is the
    # number of distinct words, as in __calc_rnd.
    mean_distance = sum_of_distances / len(distance_by_words)
    return mean_distance, sorted_distances_by_word

def run_query(
self,
query: Query,
model: WordEmbeddingModel,
word_embedding: BaseModel,
distance: str = "norm",
lost_vocabulary_threshold: float = 0.2,
preprocessors: List[Dict[str, Union[str, bool, Callable]]] = [{}],
Expand All @@ -102,8 +149,8 @@ def run_query(
query : Query
A Query object that contains the target and attribute sets to be tested.

model : WordEmbeddingModel
A word embedding model.
word_embedding : BaseModel
An object containing a word embedding model.

distance : str, optional
Specifies which type of distance will be calculated. It could be:
Expand Down Expand Up @@ -233,11 +280,21 @@ def run_query(
'wedding': 0.104610026}}
"""
# check the types of the provided arguments (only the defaults).
self._check_input(query, model, locals())
self._check_input(query, word_embedding, locals())

if word_embedding.context == True and query.sentence_template != None:
return self.run_contextual_query(query,
word_embedding,
distance,
lost_vocabulary_threshold,
preprocessors,
strategy,
normalize,
warn_not_found_words)

# transform query word sets into embeddings
embeddings = get_embeddings_from_query(
model=model,
model=word_embedding,
query=query,
lost_vocabulary_threshold=lost_vocabulary_threshold,
preprocessors=preprocessors,
Expand Down Expand Up @@ -284,3 +341,54 @@ def run_query(
"rnd": rnd,
"distances_by_word": distances_by_word,
}

def run_contextual_query(
    self,
    query: Query,
    word_embedding: BaseModel,
    distance: str = "norm",
    lost_vocabulary_threshold: float = 0.2,
    preprocessors: List[Dict[str, Union[str, bool, Callable]]] = [{}],
    strategy: str = "first",
    normalize: bool = False,
    warn_not_found_words: bool = False
) -> Dict[str, Any]:
    """Calculate RND from the related embeddings of a query.

    Parameters
    ----------
    query : Query
        The query whose word sets are converted into embeddings.
    word_embedding : BaseModel
        The model handed to ``get_related_embeddings_from_query``.
    distance : str, optional
        Distance type forwarded to the RND calculation, by default "norm".
    lost_vocabulary_threshold : float, optional
        Forwarded to ``get_related_embeddings_from_query``, by default 0.2.
    preprocessors : List[Dict[str, Union[str, bool, Callable]]]
        Forwarded to ``get_related_embeddings_from_query``.
    strategy : str, optional
        Forwarded to ``get_related_embeddings_from_query``, by default
        "first".
    normalize : bool, optional
        Forwarded to ``get_related_embeddings_from_query``, by default False.
    warn_not_found_words : bool, optional
        Forwarded to ``get_related_embeddings_from_query``, by default False.

    Returns
    -------
    Dict[str, Any]
        A dictionary with the query name, the RND value (under both
        "result" and "rnd") and the distances by word. When no embeddings
        are obtained, the values are NaN and the distances are empty.
    """
    # Convert the query word sets into embeddings.
    related_embeddings = get_related_embeddings_from_query(
        model=word_embedding,
        query=query,
        lost_vocabulary_threshold=lost_vocabulary_threshold,
        preprocessors=preprocessors,
        strategy=strategy,
        normalize=normalize,
        warn_not_found_words=warn_not_found_words,
    )

    if related_embeddings is None:
        # No embeddings could be obtained for the query: fall back to the
        # default value (nan) and an empty per-word result.
        rnd_value, per_word_distances = np.nan, {}
    else:
        # NOTE(review): the meaning of the numeric accessor arguments is
        # not visible here; presumably the first selects the set index.
        # Confirm against the embeddings container.
        first_target = related_embeddings.getTargetsMean(1, 1)
        second_target = related_embeddings.getTargetsMean(2, 1)
        first_attributes = related_embeddings.getAttributesMean(1, 1)
        second_attributes = related_embeddings.getAttributesMean(2, 1)
        words = related_embeddings.getAttributeWords(1)

        # Compute the metric from the paired attribute embeddings.
        rnd_value, per_word_distances = self.__calc_rnd2(
            first_target,
            second_target,
            first_attributes,
            second_attributes,
            words,
            distance,
        )

    return {
        "query_name": query.query_name,
        "result": rnd_value,
        "rnd": rnd_value,
        "distances_by_word": per_word_distances,
    }
Loading