-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
64 lines (45 loc) · 1.88 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
import heapq
import pathlib

import pandas as pd
from fastapi import FastAPI, Path, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from uvicorn import run

from search import SearchEngine
# Resolve asset directories relative to this file rather than the current
# working directory, so the app finds them wherever it is launched from.
script_dir = pathlib.Path(__file__).resolve().parent
templates_path = script_dir / "templates"
static_path = script_dir / "static"

# Application object and the single search engine instance shared by the
# route handlers; the engine is filled with documents in main().
app = FastAPI()
engine = SearchEngine()

# Jinja2 template loader and the /static mount for files under static_path.
templates = Jinja2Templates(directory=str(templates_path))
app.mount("/static", StaticFiles(directory=str(static_path)), name="static")
def get_top_urls(scores_dict: dict, n: int) -> dict:
    """Return the ``n`` highest-scoring entries of ``scores_dict``.

    Args:
        scores_dict: Mapping whose values are mutually comparable scores.
        n: Maximum number of entries to keep.

    Returns:
        A new dict with at most ``n`` items, ordered from highest to lowest
        score. Entries with equal scores keep their original relative
        order. The result is empty when ``n`` is zero or negative.
    """
    if n <= 0:
        # A negative bound used as a slice would drop entries from the tail
        # instead of yielding nothing, which is never what "top n" means.
        return {}
    # nlargest is documented as equivalent to
    # sorted(items, key=..., reverse=True)[:n] but does not need to sort the
    # whole mapping when only a few entries are requested.
    best = heapq.nlargest(n, scores_dict.items(), key=lambda item: item[1])
    return dict(best)
@app.get("/", response_class=HTMLResponse)
async def search(request: Request):
    """Render the search landing page.

    The ``search.html`` template receives the incoming request together
    with ``engine.posts``.
    """
    context = {"request": request, "posts": engine.posts}
    return templates.TemplateResponse("search.html", context)
@app.get("/results/{query}", response_class=HTMLResponse)
async def search_results(request: Request, query: str = Path(...)):
    """Render the results page for the ``query`` path parameter.

    The query is run through the search engine, the five best-scoring
    entries are kept, and they are handed to the ``results.html`` template
    along with the original query string.
    """
    top_hits = get_top_urls(engine.search(query), n=5)
    return templates.TemplateResponse(
        "results.html",
        {"request": request, "results": top_hits, "query": query},
    )
# response_class matches the other HTML routes so the generated OpenAPI docs
# describe this endpoint as text/html; the handler already returns a
# Response object directly, so the runtime response is unchanged.
@app.get("/about", response_class=HTMLResponse)
def read_about(request: Request):
    """Render the "about" page from the ``about.html`` template."""
    return templates.TemplateResponse("about.html", {"request": request})
def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            callers that pass nothing behave exactly as before.

    Returns:
        The ``argparse.Namespace`` with a ``data_path`` string attribute.

    Raises:
        SystemExit: Raised by argparse when ``--data-path`` is missing or
            the arguments are otherwise invalid.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-path",
        required=True,
        type=str,
        help="Path to the parquet data file",
    )
    return parser.parse_args(argv)
def main():
    """Index the parquet dataset and start the web server.

    Reads the file named by ``--data-path``, pairs each value of its
    ``URL`` column with the matching ``content`` value, bulk-indexes those
    pairs in the search engine and then serves the app on 0.0.0.0:80.
    """
    data_path = parse_args().data_path
    frame = pd.read_parquet(data_path)
    urls = frame["URL"].values
    bodies = frame["content"].values
    engine.bulk_index(list(zip(urls, bodies)))
    run(app, host="0.0.0.0", port=80)
# Start indexing and serving only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()  # pragma: no cover