Skip to content

Commit 3ad07cd

Browse files
committed
test: add script to generate demo data
1 parent ae5f631 commit 3ad07cd

File tree

4 files changed

+97
-6
lines changed

4 files changed

+97
-6
lines changed

.github/workflows/ci.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,15 @@ jobs:
3131
- name: Install dependencies
3232
run: npm install
3333

34+
- uses: actions/cache@v4
35+
id: cache-demo-data
36+
with:
37+
path: packages/docs/demo-data
38+
key: ${{ runner.os }}-${{ hashFiles('packages/docs/generate_demo_data.py') }}
39+
40+
- name: Generate demo data
41+
if: steps.cache-demo-data.outputs.cache-hit != 'true'
42+
run: cd packages/docs && uv run generate_demo_data.py
43+
3444
- name: Run custom build script
3545
run: ./scripts/build.sh

packages/docs/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
.vitepress/dist
33
public/upload
44
public/demo
5+
demo-data

packages/docs/generate_assets.sh

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
set -euxo pipefail
44

5+
# Create the upload page
56
rm -rf public/upload
67
cp -r ../viewer/dist public/upload
78
python -c "fn='public/upload/index.html';c=open(fn).read().replace('viewer','upload');open(fn,'w').write(c);"
89

9-
DEMO_DATA_FOLDER=../../../embedding-atlas-demo/data
10-
11-
rm -rf public/demo
12-
if [ -d "$DEMO_DATA_FOLDER" ]; then
13-
cp -r ../viewer/dist public/demo
14-
cp -r "$DEMO_DATA_FOLDER" public/demo/data
10+
# Create the demo page
11+
if [ -d "demo-data" ]; then
12+
rm -rf public/demo
13+
cp -r ../viewer/dist public/demo
14+
cp -r demo-data public/demo/data
1515
fi
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# /// script
2+
# requires-python = ">=3.11"
3+
# dependencies = ["click", "datasets", "pandas", "sentence-transformers", "umap-learn"]
4+
# ///
5+
6+
import json
7+
import os
8+
import shutil
9+
10+
import click
11+
import pandas as pd
12+
from datasets import load_dataset
13+
from sentence_transformers import SentenceTransformer
14+
from umap import UMAP
15+
from umap.umap_ import nearest_neighbors
16+
17+
18+
def add_embedding_projection(df: pd.DataFrame, text: str):
19+
texts = list(df[text])
20+
21+
transformer = SentenceTransformer("all-MiniLM-L6-v2")
22+
hidden_vectors = transformer.encode(texts)
23+
24+
knn = nearest_neighbors(
25+
hidden_vectors,
26+
n_neighbors=15,
27+
metric="cosine",
28+
metric_kwds=None,
29+
angular=False,
30+
random_state=None,
31+
)
32+
33+
proj = UMAP(metric="cosine", precomputed_knn=knn).fit_transform(hidden_vectors)
34+
35+
df["projection_x"] = proj[:, 0] # type: ignore
36+
df["projection_y"] = proj[:, 1] # type: ignore
37+
df["__neighbors"] = [{"distances": b, "ids": a} for a, b in zip(knn[0], knn[1])]
38+
39+
40+
@click.command()
41+
@click.option("--output", default="demo-data")
42+
def main(output: str):
43+
shutil.rmtree(output, ignore_errors=True)
44+
os.makedirs(output, exist_ok=True)
45+
46+
name = "spawn99/wine-reviews"
47+
columns = [
48+
"country",
49+
"province",
50+
"description",
51+
"points",
52+
"price",
53+
"variety",
54+
"designation",
55+
]
56+
57+
ds = load_dataset(name, split="train")
58+
df = ds.to_pandas().sample(100)[columns] # type: ignore
59+
60+
add_embedding_projection(df, text="description")
61+
62+
df.to_parquet(os.path.join(output, "dataset.parquet"), index=False)
63+
64+
metadata = {
65+
"columns": {
66+
"id": "_row_index",
67+
"text": "description",
68+
"embedding": {"x": "projection_x", "y": "projection_y"},
69+
"neighbors": "__neighbors",
70+
},
71+
"is_static": True,
72+
"database": {"type": "wasm", "load": True},
73+
}
74+
75+
with open(os.path.join(output, "metadata.json"), "wb") as f:
76+
f.write(json.dumps(metadata).encode("utf-8"))
77+
78+
79+
if __name__ == "__main__":
80+
main()

0 commit comments

Comments
 (0)