diff --git a/datasets/nfts/nfts/dataset.py b/datasets/nfts/nfts/dataset.py index 19cc53b6..112def74 100644 --- a/datasets/nfts/nfts/dataset.py +++ b/datasets/nfts/nfts/dataset.py @@ -2,7 +2,7 @@ Functions to access various data in the NFTs dataset. """ import sqlite3 -from typing import Union +from typing import Dict import pandas as pd @@ -92,27 +92,44 @@ def explain() -> None: The Moonstream NFTs dataset =========================== +To load the NFTs dataset from a SQLite file, run: +>>> ds = nfts.dataset.FromSQLite() + This dataset consists of the following dataframes:""" print(preamble) for name, explanation in AVAILABLE_DATAFRAMES.items(): print(f"\nDataframe: {name}") print( - f"Load using:\n\t{name}_df = nfts.dataset.load_dataframe(, {name})" + f'Load using:\n\t{name}_df = ds.load_dataframe(, "{name}")' ) print("") print(explanation) print("- - -") -def load_dataframe(db: Union[str, sqlite3.Connection], name: str) -> pd.DataFrame: - """ - Loads one of the available dataframes. To learn more about the available dataframes, run: - >>> nfts.dataset.explain() - """ - if name not in AVAILABLE_DATAFRAMES: - raise ValueError( - f"Invalid dataframe: {name}. Please choose from one of the available dataframes: {','.join(AVAILABLE_DATAFRAMES)}." - ) - df = pd.read_sql_table(name, db) - return df +class FromSQLite: + def __init__(self, datafile: str) -> None: + """ + Initialize an NFTs dataset instance by connecting it to a SQLite database containing the data. + """ + self.conn = sqlite3.connect(datafile) + + def load_dataframe(self, name: str) -> pd.DataFrame: + """ + Loads one of the available dataframes. To learn more about the available dataframes, run: + >>> nfts.dataset.explain() + """ + if name not in AVAILABLE_DATAFRAMES: + raise ValueError( + f"Invalid dataframe: {name}. Please choose from one of the available dataframes: {','.join(AVAILABLE_DATAFRAMES)}." + ) + df = pd.read_sql_query(f"SELECT * FROM {name};", self.conn) + return df + + def load_all(self) -> Dict[str, pd.DataFrame]: + """ + Load all the datasets and return them in a dictionary with the keys being the dataframe names. + """ + dfs = {f"{name}_df": self.load_dataframe(name) for name in AVAILABLE_DATAFRAMES} + return dfs