Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| """ | |
| Copyright 2019 Brian Thompson | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| https://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License. | |
| """ | |
| import argparse | |
| import sys | |
| from collections import defaultdict | |
| import numpy as np | |
| from dp_utils import read_alignments | |
| """ | |
| Faster implementation of lax and strict precision and recall, based on | |
| https://www.aclweb.org/anthology/W11-4624/. | |
| """ | |
def _precision(goldalign, testalign):
    """Count strict and lax true/false positives of ``testalign`` against ``goldalign``.

    Both arguments are iterables of ``(src_ids, tgt_ids)`` pairs, where each
    side is a sequence of sentence indices. Duplicate pairs are collapsed and
    pairs that are empty on both sides are ignored.

    A test pair is a strict true positive if it appears verbatim in the gold
    alignment. Otherwise it is a strict false positive. Such a pair is also a
    lax true positive if some gold pair shares at least one source index AND
    at least one target index with it; if not, it is a lax false positive.

    Returns:
        np.ndarray of dtype int32: ``[tpstrict, fpstrict, tplax, fplax]``.
    """
    # Normalise to hashable tuples and drop pairs empty on both sides.
    testalign = {(tuple(x), tuple(y)) for x, y in testalign if len(x) or len(y)}
    goldalign = {(tuple(x), tuple(y)) for x, y in goldalign if len(x) or len(y)}

    # For each source index, the union of target indices it is aligned to in
    # some gold pair. A test pair overlaps a gold pair on both sides iff one of
    # its target ids is in this set for one of its source ids.
    src_id_to_gold_tgt_ids = defaultdict(set)
    for gold_src, gold_tgt in goldalign:
        for gold_src_id in gold_src:
            src_id_to_gold_tgt_ids[gold_src_id].update(gold_tgt)

    tpstrict = fpstrict = tplax = fplax = 0
    for test_src, test_tgt in testalign:
        if (test_src, test_tgt) in goldalign:
            # An exact match counts as both a strict and a lax hit.
            tpstrict += 1
            tplax += 1
            continue

        fpstrict += 1
        gold_tgt_ids = set()
        for src_id in test_src:
            # .get() so lookups do not insert empty entries into the defaultdict.
            gold_tgt_ids.update(src_id_to_gold_tgt_ids.get(src_id, ()))
        if gold_tgt_ids.isdisjoint(test_tgt):
            fplax += 1
        else:
            tplax += 1

    return np.array([tpstrict, fpstrict, tplax, fplax], dtype=np.int32)
def score_multiple(gold_list, test_list, value_for_div_by_0=0.0):
    """Compute corpus-level strict/lax precision, recall and F1.

    Raw counts are summed over all (gold, test) document pairs before any
    ratio is taken, so larger documents carry more weight.

    Args:
        gold_list: sequence of gold alignments, one per document; each is an
            iterable of ``(src_ids, tgt_ids)`` pairs.
        test_list: sequence of test alignments, paired positionally with
            ``gold_list``. Pairing uses ``zip``, so surplus entries in the
            longer list are ignored.
        value_for_div_by_0: value reported for any precision, recall or F1
            whose denominator is zero.

    Returns:
        dict with keys ``recall_strict``, ``recall_lax``,
        ``precision_strict``, ``precision_lax``, ``f1_strict``, ``f1_lax``.
    """
    def _ratio(numerator, denominator):
        # Guarded division: fall back to the caller-chosen value on a zero denominator.
        if denominator == 0:
            return value_for_div_by_0
        return numerator / float(denominator)

    # Running totals of _precision's [tp_strict, fp_strict, tp_lax, fp_lax].
    # int64 keeps numpy's fixed-width integers from silently wrapping on very
    # large corpora; _precision's int32 rows upcast safely on +=.
    pcounts = np.zeros(4, dtype=np.int64)
    rcounts = np.zeros(4, dtype=np.int64)
    for goldalign, testalign in zip(gold_list, test_list):
        pcounts += _precision(goldalign=goldalign, testalign=testalign)
        # Recall is precision with gold/test swapped, after dropping
        # insertions/deletions (pairs that are empty on either side).
        test_no_del = [(x, y) for x, y in testalign if len(x) and len(y)]
        gold_no_del = [(x, y) for x, y in goldalign if len(x) and len(y)]
        rcounts += _precision(goldalign=test_no_del, testalign=gold_no_del)

    # pcounts: tpstrict, fpstrict, tplax, fplax
    # rcounts: tpstrict, fnstrict, tplax, fnlax -- in the swapped call the
    #          "false positives" are gold pairs missing from the test output.
    pstrict = _ratio(pcounts[0], pcounts[0] + pcounts[1])
    plax = _ratio(pcounts[2], pcounts[2] + pcounts[3])
    rstrict = _ratio(rcounts[0], rcounts[0] + rcounts[1])
    rlax = _ratio(rcounts[2], rcounts[2] + rcounts[3])

    # F1 is the harmonic mean of precision and recall.
    fstrict = _ratio(2 * (pstrict * rstrict), pstrict + rstrict)
    flax = _ratio(2 * (plax * rlax), plax + rlax)

    return dict(recall_strict=rstrict,
                recall_lax=rlax,
                precision_strict=pstrict,
                precision_lax=plax,
                f1_strict=fstrict,
                f1_lax=flax)
def log_final_scores(res):
    """Print a Strict/Lax table of precision, recall and F1 to stderr.

    ``res`` must provide the keys precision_strict, precision_lax,
    recall_strict, recall_lax, f1_strict and f1_lax; each value is
    rendered with three decimal places.
    """
    border = ' ---------------------------------'
    row_templates = (
        '| Precision | {precision_strict:.3f} | {precision_lax:.3f} |',
        '| Recall | {recall_strict:.3f} | {recall_lax:.3f} |',
        '| F1 | {f1_strict:.3f} | {f1_lax:.3f} |',
    )

    print(border, file=sys.stderr)
    print('| | Strict | Lax |', file=sys.stderr)
    # Rows are formatted one at a time, just before each is printed.
    for template in row_templates:
        print(template.format(**res), file=sys.stderr)
    print(border, file=sys.stderr)
def main():
    """Command-line entry point: score test alignment files against gold files.

    Test and gold files are paired positionally (first ``-t`` with first
    ``-g``, and so on). Every file is read with ``read_alignments``, all pairs
    are scored together with ``score_multiple``, and the resulting table is
    reported by ``log_final_scores``.

    Raises:
        ValueError: if the number of test files differs from the number of
            gold files. ValueError subclasses Exception, so existing
            ``except Exception`` handlers still catch it.
    """
    parser = argparse.ArgumentParser(
        # Must be description=: ArgumentParser's first positional parameter is
        # prog, which would put this sentence in the "usage:" line.
        description='Compute strict/lax precision and recall for one or more '
                    'pairs of gold/test alignments',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-t', '--test', type=str, nargs='+', required=True,
                        help='one or more test alignment files')
    parser.add_argument('-g', '--gold', type=str, nargs='+', required=True,
                        help='one or more gold alignment files')
    args = parser.parse_args()

    if len(args.test) != len(args.gold):
        raise ValueError('number of gold/test files must be the same')

    gold_list = [read_alignments(x) for x in args.gold]
    test_list = [read_alignments(x) for x in args.test]
    res = score_multiple(gold_list=gold_list, test_list=test_list)
    log_final_scores(res)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()