refactor and clean up the code in preparation for some consolidation, coming up

This commit is contained in:
Andrej Karpathy 2021-11-26 09:10:35 -08:00
parent f565eba1c7
commit 470bd00563
2 changed files with 36 additions and 33 deletions

View File

@ -22,15 +22,17 @@ from aslite.db import get_papers_db, get_metas_db, get_tags_db
from aslite.db import load_features
# -----------------------------------------------------------------------------
# TODO: user accounts / password login are necessary...
# inits and globals
app = Flask(__name__)
RET_NUM = 100 # number of papers to return per page
# -----------------------------------------------------------------------------
# globals that manage the (lazy) loading of various state for a request
def get_tags():
if not hasattr(g, '_tags'):
user = 'root' # root for now, the only default user
print("reading tags for user %s" % (user, ))
with get_tags_db() as tags_db:
tags_dict = tags_db[user] if user in tags_db else {}
g._tags = tags_dict
@ -46,6 +48,17 @@ def get_metas():
g._mdb = get_metas_db()
return g._mdb
@app.teardown_request
def close_connection(error=None):
# close any opened database connections
if hasattr(g, '_pdb'):
g._pdb.close()
if hasattr(g, '_mdb'):
g._mdb.close()
# -----------------------------------------------------------------------------
# ranking utilities for completing the search/rank/filter requests
def render_pids(pids):
pdb = get_papers()
@ -123,32 +136,8 @@ def svm_rank(tags: str = '', pid: str = ''):
return pids, scores
def default_context(papers, **kwargs):
context = {}
# insert the papers
context['papers'] = papers
# fetch and insert the available tags
tags = get_tags()
context['tags'] = [{'name':t, 'n':len(pids)} for t, pids in tags.items()] + [{'name': 'all'}]
# various other globals
gvars = {}
gvars['search_query'] = ''
gvars['time_filter'] = ''
gvars['message'] = ''
context['gvars'] = gvars
return context
# -----------------------------------------------------------------------------
@app.teardown_request
def close_connection(error=None):
# close any opened database connections
if hasattr(g, '_pdb'):
g._pdb.close()
if hasattr(g, '_mdb'):
g._mdb.close()
# -----------------------------------------------------------------------------
# primary application endpoints
@app.route('/', methods=['GET'])
def main():
@ -193,15 +182,20 @@ def main():
for i, p in enumerate(papers):
p['weight'] = float(scores[i])
context = default_context(papers)
# build the page context information and render
tags = get_tags()
context = {}
context['papers'] = papers
context['tags'] = [{'name':t, 'n':len(pids)} for t, pids in tags.items()] + [{'name': 'all'}]
context['gvars'] = {}
context['gvars']['rank'] = opt_rank
context['gvars']['tags'] = opt_tags
context['gvars']['pid'] = opt_pid
context['gvars']['time_filter'] = opt_time_filter
context['gvars']['skip_have'] = opt_skip_have
context['gvars']['search_query'] = ''
return render_template('index.html', **context)
@app.route("/search", methods=['GET'])
def search():
q = request.args.get('q', '') # get the search request
@ -228,7 +222,16 @@ def search():
for i, p in enumerate(papers):
p['weight'] = pairs[i][0]
context = default_context(papers)
tags = get_tags()
context = {}
context['papers'] = papers
context['tags'] = [{'name':t, 'n':len(pids)} for t, pids in tags.items()] + [{'name': 'all'}]
context['gvars'] = {}
context['gvars']['rank'] = ''
context['gvars']['tags'] = ''
context['gvars']['pid'] = ''
context['gvars']['time_filter'] = ''
context['gvars']['skip_have'] = ''
context['gvars']['search_query'] = q
return render_template('index.html', **context)
@ -265,6 +268,9 @@ def inspect():
)
return render_template('inspect.html', **context)
# -----------------------------------------------------------------------------
# tag related endpoints: add, delete tags for any paper
@app.route('/add/<pid>/<tag>')
def add(pid=None, tag=None):
user = 'root'

View File

@ -66,9 +66,6 @@ var gvars = {{ gvars | tojson }};
<div>
</div>
<div id="message">
{{gvars.message}}
</div>
</div>
<div id="tagwrap">