From 4ec9a752932c9c0db900e64ca78b2440229506bb Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 7 Dec 2021 22:52:13 -0800 Subject: [PATCH] add a secret GET argument svm_c that changes the C value in the SVM. no UI for this yet :) --- serve.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/serve.py b/serve.py index 13f5460..3bad8c4 100644 --- a/serve.py +++ b/serve.py @@ -114,7 +114,7 @@ def time_rank(): scores = [(tnow - v['_time'])/60/60/24 for k, v in ms] # time delta in days return pids, scores -def svm_rank(tags: str = '', pid: str = ''): +def svm_rank(tags: str = '', pid: str = '', svm_c: str = ''): # tag can be one tag or a few comma-separated tags or 'all' for all tags we have in db # pid can be a specific paper id to set as positive for a kind of nearest neighbor search @@ -145,7 +145,13 @@ def svm_rank(tags: str = '', pid: str = ''): return [], [] # there are no positives? # classify - clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-6, C=0.1) + C = 0.1 + if svm_c: # if a desired C is provided attempt to use it as a float + try: + C = float(svm_c) + except ValueError: + C = 1.0 + clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-6, C=C) clf.fit(x, y) s = clf.decision_function(x) sortix = np.argsort(-s) @@ -195,6 +201,7 @@ def main(): opt_pid = request.args.get('pid', '') # pid to find nearest neighbors to opt_time_filter = request.args.get('time_filter', '') # number of days to filter by opt_skip_have = request.args.get('skip_have', 'no') # hide papers we already have? + opt_svm_c = request.args.get('svm_c', '') # svm C parameter # if a query is given, override rank to be of type "search" # this allows the user to simply hit ENTER in the search field and have the correct thing happen @@ -205,9 +212,9 @@ def main(): if opt_rank == 'search': pids, scores = search_rank(q=opt_q) elif opt_rank == 'tags': - pids, scores = svm_rank(tags=opt_tags) + pids, scores = svm_rank(tags=opt_tags, svm_c=opt_svm_c) elif opt_rank == 'pid': - pids, scores = svm_rank(pid=opt_pid) + pids, scores = svm_rank(pid=opt_pid, svm_c=opt_svm_c) elif opt_rank == 'time': pids, scores = time_rank() elif opt_rank == 'random':