From 36fee9bf1b45d6c111e4faac28ace68876aab4f7 Mon Sep 17 00:00:00 2001 From: mammo0 Date: Thu, 11 Mar 2021 10:52:38 +0100 Subject: [PATCH] allow batch processing for language detection --- app/language.py | 42 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/app/language.py b/app/language.py index 2a1b71c..17d1822 100644 --- a/app/language.py +++ b/app/language.py @@ -9,12 +9,25 @@ __lang_codes = [l.code for l in languages] def detect_languages(text): - f = Detector(text).languages + # detect batch processing + if isinstance(text, list): + is_batch = True + else: + is_batch = False + text = [text] # get the candidates - candidate_langs = list(filter(lambda l: l.read_bytes != 0 and l.code in __lang_codes, f)) + candidates = [] + for t in text: + candidates.extend(Detector(t).languages) - # this happens if no language can be detected + # total read bytes of the provided text + read_bytes_total = sum(c.read_bytes for c in candidates) + + # only use candidates that are supported by argostranslate + candidate_langs = list(filter(lambda l: l.read_bytes != 0 and l.code in __lang_codes, candidates)) + + # this happens if no language could be detected if not candidate_langs: # use language "en" by default but with zero confidence return [ @@ -24,8 +37,29 @@ def detect_languages(text): } ] + # for multiple occurrences of the same language (can happen on batch detection) + # calculate the average confidence for each language + if is_batch: + temp_average_list = [] + for lang_code in __lang_codes: + # get all candidates for a specific language + lc = list(filter(lambda l: l.code == lang_code, candidate_langs)) + if len(lc) > 1: + # if more than one is present, calculate the average confidence + lang = lc[0] + lang.confidence = sum(l.confidence for l in lc) / len(lc) + lang.read_bytes = sum(l.read_bytes for l in lc) + temp_average_list.append(lang) + elif lc: + # otherwise just add it to the temporary list + temp_average_list.append(lc[0]) + + if temp_average_list: + # replace the list + candidate_langs = temp_average_list + # sort the candidates descending based on the detected confidence - candidate_langs.sort(key=lambda l: l.confidence, reverse=True) + candidate_langs.sort(key=lambda l: (l.confidence * l.read_bytes) / read_bytes_total, reverse=True) return [ {