Refactor feature extraction
This commit is contained in:
parent
87d8b40cad
commit
770b4b945b
2 changed files with 19 additions and 24 deletions
|
@ -4,7 +4,7 @@ Learning to rank predictor
|
|||
from pandas import DataFrame, Series
|
||||
from sklearn.base import BaseEstimator, RegressorMixin, TransformerMixin
|
||||
|
||||
from mwmbl.tinysearchengine.rank import get_match_features, get_domain_score
|
||||
from mwmbl.tinysearchengine.rank import get_match_features, get_domain_score, score_match
|
||||
|
||||
|
||||
class ThresholdPredictor(BaseEstimator, RegressorMixin):
|
||||
|
@ -26,27 +26,19 @@ class ThresholdPredictor(BaseEstimator, RegressorMixin):
|
|||
|
||||
def get_match_features_as_series(item: Series):
|
||||
terms = item['query'].lower().split()
|
||||
last_match_char_title, match_length_title, total_possible_match_length_title = get_match_features(
|
||||
terms, item['title'], True, False)
|
||||
last_match_char_extract, match_length_extract, total_possible_match_length_extract = get_match_features(
|
||||
terms, item['extract'], True, False)
|
||||
last_match_char_url, match_length_url, total_possible_match_length_url = get_match_features(
|
||||
terms, item['url'], True, False)
|
||||
domain_score = get_domain_score(item['url'])
|
||||
return Series({
|
||||
'last_match_char_title': last_match_char_title,
|
||||
'match_length_title': match_length_title,
|
||||
'total_possible_match_length_title': total_possible_match_length_title,
|
||||
'last_match_char_extract': last_match_char_extract,
|
||||
'match_length_extract': match_length_extract,
|
||||
'total_possible_match_length_extract': total_possible_match_length_extract,
|
||||
'last_match_char_url': last_match_char_url,
|
||||
'match_length_url': match_length_url,
|
||||
'total_possible_match_length_url': total_possible_match_length_url,
|
||||
'num_terms': len(terms),
|
||||
'domain_score': domain_score,
|
||||
'item_score': item['score'],
|
||||
})
|
||||
features = {}
|
||||
for part in ['title', 'extract', 'url']:
|
||||
last_match_char, match_length, total_possible_match_length = get_match_features(terms, item[part], True, False)
|
||||
features[f'last_match_char_{part}'] = last_match_char
|
||||
features[f'match_length_{part}'] = match_length
|
||||
features[f'total_possible_match_length_{part}'] = total_possible_match_length
|
||||
# features[f'score_{part}'] = score_match(last_match_char, match_length, total_possible_match_length)
|
||||
|
||||
features['num_terms'] = len(terms)
|
||||
features['num_chars'] = len(' '.join(terms))
|
||||
features['domain_score'] = get_domain_score(item['url'])
|
||||
features['item_score'] = item['score']
|
||||
return Series(features)
|
||||
|
||||
|
||||
class FeatureExtractor(BaseEstimator, TransformerMixin):
|
||||
|
|
|
@ -39,13 +39,17 @@ def _score_result(terms, result: Document, is_complete: bool, max_score: float):
|
|||
last_match_char, match_length, total_possible_match_length = get_match_features(
|
||||
terms, result_string, is_complete, False)
|
||||
|
||||
match_score = (match_length + 1. / last_match_char) / (total_possible_match_length + 1)
|
||||
match_score = score_match(last_match_char, match_length, total_possible_match_length)
|
||||
score = 0.01 * domain_score + 0.99 * match_score
|
||||
# score = (0.1 + 0.9*match_score) * (0.1 + 0.9*(result.score / max_score))
|
||||
# score = 0.01 * match_score + 0.99 * (result.score / max_score)
|
||||
return score
|
||||
|
||||
|
||||
def score_match(last_match_char, match_length, total_possible_match_length):
|
||||
return (match_length + 1. / last_match_char) / (total_possible_match_length + 1)
|
||||
|
||||
|
||||
def get_domain_score(url):
|
||||
domain = urlparse(url).netloc
|
||||
domain_score = DOMAINS.get(domain, 0.0)
|
||||
|
@ -54,7 +58,6 @@ def get_domain_score(url):
|
|||
|
||||
def get_match_features(terms, result_string, is_complete, is_url):
|
||||
query_regex = _get_query_regex(terms, is_complete, is_url)
|
||||
print("Query regex", query_regex)
|
||||
matches = list(re.finditer(query_regex, result_string, flags=re.IGNORECASE))
|
||||
match_strings = {x.group(0).lower() for x in matches}
|
||||
match_length = sum(len(x) for x in match_strings)
|
||||
|
|
Loading…
Add table
Reference in a new issue