@@ -1,11 +1,17 @@
 import os
 import pandas as pd
+from nltk.tokenize import sent_tokenize
 from transformers import AutoTokenizer, AutoModel
 import torch
 import torch.nn.functional as F
 
-INPUT_FILENAME = "./data/city_wikipedia_summaries.csv"
-EXPORT_FILENAME = "./data/city_wikipedia_summaries_with_embeddings.parquet"
+BASE_DIR = os.path.abspath(os.path.join(os.getcwd(), "feature_repo"))
+DATA_DIR = os.path.join(BASE_DIR, "data")
+INPUT_FILENAME = os.path.join(DATA_DIR, "city_wikipedia_summaries.csv")
+CHUNKED_FILENAME = os.path.join(DATA_DIR, "city_wikipedia_summaries_chunked.csv")
+EXPORT_FILENAME = os.path.join(
+    DATA_DIR, "city_wikipedia_summaries_with_embeddings.parquet"
+)
 TOKENIZER = "sentence-transformers/all-MiniLM-L6-v2"
 MODEL = "sentence-transformers/all-MiniLM-L6-v2"
 
@@ -36,23 +42,33 @@ def run_model(sentences, tokenizer, model):
 
 
 def score_data() -> None:
-    if EXPORT_FILENAME not in os.listdir():
-        print("scored data not found...generating embeddings...")
-        df = pd.read_csv(INPUT_FILENAME)
+    os.makedirs(DATA_DIR, exist_ok=True)
+
+    if not os.path.exists(EXPORT_FILENAME):
+        print("Scored data not found... generating embeddings...")
+
+        if not os.path.exists(CHUNKED_FILENAME):
+            print("Chunked data not found... generating chunked data...")
+            df = pd.read_csv(INPUT_FILENAME)
+            df["Sentence Chunks"] = df["Wiki Summary"].apply(lambda x: sent_tokenize(x))
+            chunked_df = df.explode("Sentence Chunks")
+            chunked_df.to_csv(CHUNKED_FILENAME, index=False)
+            df = chunked_df
+        else:
+            df = pd.read_csv(CHUNKED_FILENAME)
+
         tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
         model = AutoModel.from_pretrained(MODEL)
         embeddings = run_model(df["Wiki Summary"].tolist(), tokenizer, model)
-        print(embeddings)
-        print("shape = ", df.shape)
-        df["Embeddings"] = list(embeddings.detach().cpu().numpy())
         print("embeddings generated...")
+        df["Embeddings"] = list(embeddings.detach().cpu().numpy())
         df["event_timestamp"] = pd.to_datetime("today")
         df["item_id"] = df.index
-        print(df.head())
+
         df.to_parquet(EXPORT_FILENAME, index=False)
-        print("...data exported. job complete")
+        print("...data exported. Job complete")
     else:
-        print("scored data found...skipping generating embeddings.")
+        print("Scored data found... skipping generating embeddings.")
 
 
 if __name__ == "__main__":
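
Note on the new dependency: sent_tokenize relies on NLTK's Punkt sentence tokenizer data, which ships separately from the nltk package itself. A minimal one-time setup sketch (assumes nltk is already installed; newer NLTK releases may also ask for the punkt_tab resource):

import nltk

# Download the Punkt models that sent_tokenize loads at runtime.
nltk.download("punkt")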
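
For reference, run_model appears in the hunk header but its body is unchanged and not shown in this diff. A sketch of the usual pattern for this MiniLM checkpoint (mean pooling over token embeddings, then L2 normalization); this is an assumption about the implementation, not a quote from the file:

def run_model(sentences, tokenizer, model):
    # Tokenize the batch; pad/truncate so all sequences share one length.
    encoded = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        output = model(**encoded)
    # Mean-pool token embeddings, masking out padding positions.
    mask = encoded["attention_mask"].unsqueeze(-1).float()
    pooled = (output.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
    # L2-normalize so dot products act like cosine similarity.
    return F.normalize(pooled, p=2, dim=1)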