Source code for papers_without_code.search

#!/usr/bin/env python

import itertools
import json
import logging
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
from pathlib import Path

import backoff
import requests
from bs4 import BeautifulSoup
from dataclasses_json import DataClassJsonMixin
from dotenv import load_dotenv
from fastcore.net import HTTP4xxClientError
from ghapi.all import GhApi
from langchain import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers import PydanticOutputParser
from langchain.schema import HumanMessage
from pydantic import BaseModel, Field
from requests.exceptions import HTTPError
from sentence_transformers import SentenceTransformer, util

from .custom_types import MinimalPaperDetails

###############################################################################

log = logging.getLogger(__name__)

###############################################################################

DEFAULT_TRANSFORMER_MODEL = "thenlper/gte-small"
DEFAULT_LOCAL_CACHE_MODEL = f"./sentence-transformers_{DEFAULT_TRANSFORMER_MODEL}"

###############################################################################


def get_paper(query: str) -> MinimalPaperDetails:
    """
    Get a paper's details from the Semantic Scholar API.

    Provide a DOI, SemanticScholarID, CorpusID, ArXivID, ACL, or URL from
    semanticscholar.org, arxiv.org, aclweb.org, acm.org, or biorxiv.org.

    DOIs can be provided as is. All other IDs should be given with their type,
    for example: `doi:10.18653/v1/2020.acl-main.447` or `CorpusID:202558505`
    or `url:https://arxiv.org/abs/2004.07180`.

    Parameters
    ----------
    query: str
        The structured paper identifier to query for.

    Returns
    -------
    paper: MinimalPaperDetails
        The paper details.

    Raises
    ------
    ValueError
        No paper was found.
    """
    log.info(f"Getting SemanticScholar paper details with query: '{query}'")
    response = requests.get(
        f"https://api.semanticscholar.org/graph/v1/paper/{query.strip()}"
        "?fields=paperId,title,authors,abstract"
    )
    response.raise_for_status()
    response_data = response.json()
    log.info(f"Found SemanticScholar paper with query: '{query}'")
    log.info(f"Response data: {response_data}")

    # Handle no paper found
    if len(response_data) == 0:
        raise ValueError(f"No paper found with query: '{query}'")

    return MinimalPaperDetails(
        url=f"https://www.semanticscholar.org/paper/{response_data['paperId']}",
        title=response_data.get("title"),
        authors=response_data.get("authors", None),
        abstract=response_data.get("abstract", None),
        keywords=None,
        other={"full_semantic_scholar_data": response_data},
    )

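# Illustrative usage sketch (not part of the original module): how `get_paper`
# might be called with the identifier formats described in its docstring.
# Network access to the Semantic Scholar API is assumed; the example function
# name below is hypothetical.
def _example_get_paper_usage() -> None:
    # A plain DOI works as-is
    paper = get_paper("10.18653/v1/2020.acl-main.447")
    # Prefixed identifiers also work, e.g. "CorpusID:202558505"
    # or "url:https://arxiv.org/abs/2004.07180".
    print(paper.title)
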
class LLMKeywordResults(BaseModel):
    keywords: list[str] = Field(
        description="Extracted keyword sequences found in the text."
    )

LLM_KEYWORD_RESULTS_PARSER = PydanticOutputParser(pydantic_object=LLMKeywordResults)

LLM_KEYWORD_PROMPT_STRING = (
    "Task: Create a list of five keywords from the following text. "
    "Keywords can range from one to four words in length. "
    "Only extracted text should be included in the list of keywords. "
    "Keywords can include acronyms and abbreviations.\n\n"
    "{{ format_instructions }}"
    "\n\n---\n\n"
    "Example Input Text:\n\n"
    "SciBERT: A Pretrained Language Model for Scientific Text "
    "Obtaining large-scale annotated data for NLP tasks in the "
    "scientific domain is challenging and expensive. We release SciBERT, "
    "a pretrained language model based on BERT (Devlin et. al., 2018) "
    "to address the lack of high-quality, large-scale labeled scientific data. "
    "SciBERT leverages unsupervised pretraining on a large multi-domain corpus of "
    "scientific publications to improve performance on downstream scientific "
    "NLP tasks. We evaluate on a suite of tasks including sequence tagging, "
    "sentence classification and dependency parsing, with datasets from a "
    "variety of scientific domains. We demonstrate statistically significant "
    "improvements over BERT and achieve new state-of-the-art results on several "
    "of these tasks. The code and pretrained models are available at "
    "https://github.com/allenai/scibert/."
    "\n\n---\n\n"
    "Example Output Text:\n\n"
    '{"keywords": ['
    '"SciBERT", '
    '"Language Model for Scientific Text", '
    '"large-scale labeled scientific data", '
    '"Scientific Text", '
    '"SciBERT leverages unsupervised pretraining"'
    "]}"
    "\n\n---\n\n"
    "Input Text:\n\n{{ text }}"
    "\n\n---\n\n"
)

LLM_KEYWORD_PROMPT_TEMPLATE = PromptTemplate.from_template(
    LLM_KEYWORD_PROMPT_STRING,
    template_format="jinja2",
)


@backoff.on_exception(backoff.expo, exception=json.JSONDecodeError, max_time=10)
def _run_keyword_get_from_llm(text: str, llm: ChatOpenAI) -> LLMKeywordResults:
    # Fill prompt to get input
    input_ = LLM_KEYWORD_PROMPT_TEMPLATE.format_prompt(
        text=text,
        format_instructions=LLM_KEYWORD_RESULTS_PARSER.get_format_instructions(),
    )

    # Generate keywords
    output = llm([HumanMessage(content=input_.text)]).content.strip()

    # Parse output
    parsed_output = LLM_KEYWORD_RESULTS_PARSER.parse(output)
    return parsed_output


def _get_keywords(text: str) -> list[str]:
    # Create connection to LLM
    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, max_tokens=1000)

    # Get keywords
    parsed_output = _run_keyword_get_from_llm(text, llm)

    # Return keywords
    return parsed_output.keywords

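# Illustrative sketch (not part of the original module): calling the keyword
# extraction helper directly. An OPENAI_API_KEY in the environment is assumed,
# since `_get_keywords` constructs a ChatOpenAI client. The example function
# name is hypothetical.
def _example_keyword_extraction() -> None:
    abstract = (
        "We release SciBERT, a pretrained language model based on BERT, "
        "to address the lack of high-quality, large-scale labeled scientific data."
    )
    keywords = _get_keywords(abstract)
    # Expect roughly five short extracted phrases, e.g. acronyms such as "SciBERT"
    print(keywords)
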
@dataclass
class SearchQueryDataTracker:
    query_str: str
    strict: bool = False

@dataclass
class SearchQueryResponse:
    query_str: str
    repo_name: str
    stars: int
    forks: int
    watchers: int
    description: str

@backoff.on_exception(backoff.expo, HTTP4xxClientError)
def _search_repos(
    query: SearchQueryDataTracker,
    api: GhApi,
) -> list[SearchQueryResponse]:
    # Make request
    if query.strict:
        response = api(
            "/search/repositories",
            "GET",
            query={
                "q": f'"{query.query_str}"',
                "per_page": 10,
            },
        )
    else:
        response = api(
            "/search/repositories",
            "GET",
            query={
                "q": f"{query.query_str}",
                "per_page": 10,
            },
        )

    # Dedupe and process
    dedupe_repos_strs = set()
    results = []

    # Unpack items
    for item in response["items"]:
        if not item["fork"]:
            if item["full_name"] not in dedupe_repos_strs:
                dedupe_repos_strs.add(item["full_name"])
                results.append(
                    SearchQueryResponse(
                        query_str=query.query_str,
                        repo_name=item["full_name"],
                        stars=item["stargazers_count"],
                        forks=item["forks"],
                        watchers=item["watchers_count"],
                        description=item["description"],
                    )
                )

    return results

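# Illustrative sketch (not part of the original module): running a single
# strict (exact-phrase) repository search. A GITHUB_TOKEN in the environment
# is assumed so that GhApi can authenticate search requests; the example
# function name and query string are only illustrative.
def _example_search_repos() -> None:
    api = GhApi()
    query = SearchQueryDataTracker(query_str="SciBERT", strict=True)
    for result in _search_repos(query, api=api):
        print(result.repo_name, result.stars)
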
@dataclass
class RepoReadmeResponse:
    repo_name: str
    search_query: str
    readme_text: str
    stars: int
    forks: int
    watchers: int
    description: str

@backoff.on_exception(backoff.expo, HTTPError, max_time=60)
def _get_repo_readme_content(
    repo_data: SearchQueryResponse,
) -> RepoReadmeResponse | None:
    # Request repo page
    response = requests.get(f"https://github.com/{repo_data.repo_name}")
    response.raise_for_status()

    # Read README content
    soup = BeautifulSoup(response.content, "html.parser")
    readme_container = soup.find(id="readme")

    # Will be filtered out after this
    if not readme_container:
        return None

    return RepoReadmeResponse(
        repo_name=repo_data.repo_name,
        search_query=repo_data.query_str,
        readme_text=readme_container.text,
        stars=repo_data.stars,
        forks=repo_data.forks,
        watchers=repo_data.watchers,
        description=repo_data.description,
    )

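# Illustrative sketch (not part of the original module): scraping the README of
# one search hit. The repository name is taken from the SciBERT example used in
# the prompt above; the placeholder counts and the example function name are
# hypothetical.
def _example_get_readme() -> None:
    hit = SearchQueryResponse(
        query_str="SciBERT",
        repo_name="allenai/scibert",
        stars=0,
        forks=0,
        watchers=0,
        description="",
    )
    readme = _get_repo_readme_content(hit)
    if readme is not None:
        print(readme.readme_text[:200])
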
@dataclass
class RepoDetails(DataClassJsonMixin):
    name: str
    link: str
    search_query: str
    similarity: float
    stars: int
    forks: int
    watchers: int
    description: str

def _semantic_sim_repos(
    all_repos_details: list[RepoReadmeResponse],
    paper: MinimalPaperDetails,
    model: SentenceTransformer | None = None,
) -> list[RepoDetails]:
    # Load model
    if not model:
        potential_cache_dir = Path(DEFAULT_LOCAL_CACHE_MODEL).resolve()
        if potential_cache_dir.exists():
            model = SentenceTransformer(str(potential_cache_dir))
        else:
            model = SentenceTransformer(DEFAULT_TRANSFORMER_MODEL)

    # Encode abstract once
    if paper.abstract:
        sem_vec_paper = model.encode(paper.abstract, convert_to_tensor=True)
    else:
        sem_vec_paper = model.encode(paper.title, convert_to_tensor=True)

    # Score each README against the paper
    complete_repo_details = []
    for repo_details in all_repos_details:
        sem_vec_readme = model.encode(repo_details.readme_text, convert_to_tensor=True)

        # Compute cosine-similarities
        score = util.cos_sim(sem_vec_readme, sem_vec_paper).item()

        complete_repo_details.append(
            RepoDetails(
                name=repo_details.repo_name,
                link=f"https://github.com/{repo_details.repo_name}",
                search_query=repo_details.search_query,
                similarity=score,
                stars=repo_details.stars,
                forks=repo_details.forks,
                watchers=repo_details.watchers,
                description=repo_details.description,
            )
        )

    return complete_repo_details

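# Illustrative sketch (not part of the original module): the cosine-similarity
# scoring used above, shown on two short strings. Downloading the
# DEFAULT_TRANSFORMER_MODEL from the Hugging Face hub is assumed; the example
# function name and strings are hypothetical.
def _example_cosine_similarity() -> None:
    model = SentenceTransformer(DEFAULT_TRANSFORMER_MODEL)
    vec_a = model.encode(
        "A pretrained language model for scientific text.", convert_to_tensor=True
    )
    vec_b = model.encode("Code and models for SciBERT.", convert_to_tensor=True)
    # Higher values mean the two texts are semantically closer
    print(util.cos_sim(vec_a, vec_b).item())
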
def get_repos(
    paper: MinimalPaperDetails,
    loaded_sent_transformer: SentenceTransformer | None = None,
) -> list[RepoDetails]:
    """
    Try to find GitHub repositories matching a provided paper.

    Parameters
    ----------
    paper: MinimalPaperDetails
        The paper to try and find similar repositories to.
    loaded_sent_transformer: Optional[SentenceTransformer]
        An optional preloaded SentenceTransformer model to use instead of
        loading a new one.
        Default: None

    Returns
    -------
    repos: list[RepoDetails]
        A list of repositories that are similar to the paper, sorted by the
        semantic similarity of each repository's README to the abstract
        (or the title if no abstract was attached to the paper details).
    """
    # Try loading dotenv
    load_dotenv()

    # Connect to API
    api = GhApi()

    # No keywords were provided, generate from abstract and title
    if not paper.keywords:
        # Assemble the text to extract keywords from
        if paper.title and paper.abstract:
            paper_content = f"{paper.title}\n\n{paper.abstract}"
        elif paper.title:
            paper_content = paper.title
        elif paper.abstract:
            paper_content = paper.abstract

        # Get keywords
        log.info("Right before keyword search...")
        keywords = _get_keywords(
            paper_content,
        )
        log.info("Right after keyword search...")

        # Create the queries
        set_queries = [
            SearchQueryDataTracker(
                query_str=keyword,
                strict=True,
            )
            for keyword in keywords
        ]

    # Paper was provided with keywords, use those
    else:
        set_queries = [
            SearchQueryDataTracker(
                query_str=keyword,
                strict=True,
            )
            for keyword in paper.keywords
        ]

    # Progress info
    log.info(
        f"Searching GitHub for Paper: '{paper.title}'. Using queries: {set_queries}"
    )

    # Create partial search func with API access already attached
    search_func = partial(_search_repos, api=api)

    # Do a bunch of threading during the search
    with ThreadPoolExecutor() as exe:
        # Find repos from GH Search
        found_repos = itertools.chain(*list(exe.map(search_func, set_queries)))

        # Combine all responses
        repos_to_parse = []
        set_repo_strs = set()
        for found_repo in found_repos:
            if found_repo.repo_name not in set_repo_strs:
                repos_to_parse.append(found_repo)
                set_repo_strs.add(found_repo.repo_name)

        # Get the README for each repo in the set
        repos_and_readmes = list(
            exe.map(
                _get_repo_readme_content,
                repos_to_parse,
            )
        )

    # Filter nones from readmes
    repos_and_readmes = [
        r_and_r for r_and_r in repos_and_readmes if r_and_r is not None
    ]

    repos = _semantic_sim_repos(repos_and_readmes, paper, model=loaded_sent_transformer)
    return sorted(repos, key=lambda x: x.similarity, reverse=True)

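# Illustrative end-to-end sketch (not part of the original module): fetch paper
# details from Semantic Scholar, then search GitHub for matching repositories.
# OPENAI_API_KEY (keyword generation) and a GitHub token (search) are assumed
# to be available in the environment or a .env file; the example function name
# is hypothetical.
def _example_end_to_end() -> None:
    paper = get_paper("10.18653/v1/2020.acl-main.447")
    repos = get_repos(paper)
    for repo in repos[:5]:
        print(f"{repo.name} (similarity: {repo.similarity:.3f})")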