Skip to content

Commit 6fbc509

Browse files
authored
fix issue with max_sample_size_embeddings argument (#6)
1 parent 54cea05 commit 6fbc509

File tree

3 files changed

+15
-11
lines changed

3 files changed

+15
-11
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
---
2-
exclude: '^(src/mostlyai/qa/report_assets)/'
2+
exclude: '^(src/mostlyai/qa/assets)/'
33
repos:
44
- repo: local
55
hooks:

README.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ report_path, metrics = qa.report(
4949
)
5050
```
5151

52+
Note, that due to the calculation of embeddings the function call might take a while. Embedding 10k samples on a Mac M2 take for example about 40secs. Limit the size of the passed DataFrames, or use the `max_sample_size_embeddings` parameter to speed up the report.
53+
5254
## Function signature
5355

5456
```python
@@ -126,21 +128,21 @@ Three sets of metrics are calculated to compare synthetic data with the original
126128
### Accuracy
127129

128130
The L1 distances between the discretized marginal distributions of the synthetic and the original training data are being calculated for all columns.
129-
The reported accuracy is expressed as 100% minus the total variational distance (TVD), which is half the L1 distance between the two distributions.
131+
The reported accuracy is expressed as 100% minus the total variational distance (TVD), which is half the L1 distance between the two distributions.
130132
These accuracies are then averaged to produce a single accuracy score, where higher scores indicate better synthetic data.
131133

132-
1. **Univariate Accuracy**: The accuracy of the univariate distributions for all target columns is measured.
133-
2. **Bivariate Accuracy**: The accuracy of all pair-wise distributions for target columns, as well as for target columns with respect to the context columns, is measured.
134-
3. **Coherence Accuracy**: The accuracy of the auto-correlation for all target columns is measured. This is applicable only for sequential data.
134+
1. **Univariate Accuracy**: The accuracy of the univariate distributions for all target columns is measured.
135+
2. **Bivariate Accuracy**: The accuracy of all pair-wise distributions for target columns, as well as for target columns with respect to the context columns, is measured.
136+
3. **Coherence Accuracy**: The accuracy of the auto-correlation for all target columns is measured. This is applicable only for sequential data.
135137

136138
An overall accuracy score is calculated as the average of these aggregate-level scores.
137139

138140
### Similarity
139141

140142
All records are embedded into an embedding space to calculate two metrics:
141143

142-
1. **Cosine Similarity**: The cosine similarity between the centroids of the synthetic and the original training data is calculated and compared to the cosine similarity between the centroids of the original training and holdout data. Higher scores indicate better synthetic data.
143-
2. **Discriminator AUC**: A binary classifier is trained to determine whether synthetic and original training data can be distinguished based on their embeddings. This score is compared to the same metric for the original training and holdout data. A score close to 50% indicates that synthetic samples are indistinguishable from original samples.
144+
1. **Cosine Similarity**: The cosine similarity between the centroids of the synthetic and the original training data is calculated and compared to the cosine similarity between the centroids of the original training and holdout data. Higher scores indicate better synthetic data.
145+
2. **Discriminator AUC**: A binary classifier is trained to determine whether synthetic and original training data can be distinguished based on their embeddings. This score is compared to the same metric for the original training and holdout data. A score close to 50% indicates that synthetic samples are indistinguishable from original samples.
144146

145147
### Distances
146148

src/mostlyai/qa/report.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,12 @@ def report(
224224
on_progress(current=30, total=100)
225225

226226
# ensure that embeddings are all of equal size for a fair 3-way comparison
227-
max_sample_size_embeddings = min(syn_sample_size, trn_sample_size)
228-
if hol_sample_size != 0:
229-
max_sample_size_embeddings = min(max_sample_size_embeddings, hol_sample_size)
230-
227+
max_sample_size_embeddings = min(
228+
max_sample_size_embeddings or float("inf"),
229+
syn_sample_size,
230+
trn_sample_size,
231+
hol_sample_size or float("inf"),
232+
)
231233
# calculate embeddings
232234
syn_embeds = calculate_embeddings(
233235
pull_data_for_embeddings(

0 commit comments

Comments
 (0)