Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mostlyai/engine/_tabular/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,7 @@ def generate(

if not enable_flexible_generation:
check_column_order(gen_column_order, trn_column_order)

_LOG.info(f"{rare_category_replacement_method=}")
rare_token_fixed_probs = fix_rare_token_probs(tgt_stats, rare_category_replacement_method)
imputation_fixed_probs = _fix_imputation_probs(tgt_stats, imputation)
Expand Down
7 changes: 3 additions & 4 deletions mostlyai/engine/_tabular/probability.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,15 +363,14 @@ def predict_proba(
)
)

# Get seed column names (needed for column order check and _generate_marginal_probs)
seed_columns = list(seed_data.columns)

# Check column order when flexible generation is disabled
if not enable_flexible_generation:
seed_columns_argn = get_argn_column_names(tgt_stats["columns"], seed_columns)
target_columns_argn = get_argn_column_names(tgt_stats["columns"], target_columns)
gen_column_order = seed_columns_argn + target_columns_argn
check_column_order(gen_column_order, all_columns)
columns_to_check = seed_columns_argn + target_columns_argn
expected_order = [col for col in all_columns if col in columns_to_check]
check_column_order(columns_to_check, expected_order)

# Encode seed data (features to condition on) - common for both single and multi-target
# seed_data should NOT include any target columns
Expand Down
42 changes: 33 additions & 9 deletions tests/end_to_end/test_tabular_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,11 +347,15 @@ def test_predict_proba_multi_target(
# Numeric binned values (may be bin labels or ranges)
assert len(col_values) >= 3 # At least some bins present

def test_predict_proba_wrong_column_order_raises(self, classification_data, tmp_path_factory):
"""Test predict_proba raises error with different column order when flexible generation is disabled."""
data = classification_data
X = data[["feature1", "feature2"]]
y = data["target"]
def test_wrong_column_order_raises(self, tmp_path_factory):
"""Test that wrong column order raises error when flexible generation is disabled."""
data = pd.DataFrame(
{
"col_a": ["x", "y", "z"] * 20,
"col_b": ["p", "q", "r"] * 20,
"col_c": ["1", "2", "3"] * 20,
}
)

argn = TabularARGN(
model="MOSTLY_AI/Small",
Expand All @@ -360,13 +364,33 @@ def test_predict_proba_wrong_column_order_raises(self, classification_data, tmp_
enable_flexible_generation=False,
workspace_dir=tmp_path_factory.mktemp("workspace"),
)
argn.fit(X=X, y=y)
argn.fit(X=data)

# Reorder columns in test data
test_X = X.head(10)[["feature2", "feature1"]]
# Wrong seed order for sample
X_wrong_seed = data.head(5)[["col_b", "col_a"]] # wrong: should be col_a, col_b
with pytest.raises(ValueError, match="(?i)column order.*does not match"):
argn.sample(n_samples=5, seed_data=X_wrong_seed)

# Wrong seed order for predict_proba
with pytest.raises(ValueError, match="(?i)column order.*does not match"):
argn.predict_proba(X_wrong_seed, target="col_c")

# Wrong seed order for predict
with pytest.raises(ValueError, match="(?i)column order.*does not match"):
argn.predict_proba(test_X, target="target")
argn.predict(X_wrong_seed, target="col_c")

# Wrong target order for predict_proba (computes joint probabilities in order)
X_seed = data.head(5)[["col_a"]]
with pytest.raises(ValueError, match="(?i)column order.*does not match"):
argn.predict_proba(X_seed, target=["col_c", "col_b"]) # wrong: should be col_b, col_c

# predict() doesn't require target order - it generates all columns and extracts targets
result = argn.predict(X_seed, target=["col_c", "col_b"])
assert list(result.columns) == ["col_c", "col_b"]

# predict() works even with targets completely out of original order
result = argn.predict(X_seed, target=["col_c", "col_b", "col_a"])
assert list(result.columns) == ["col_c", "col_b", "col_a"]


class TestTabularARGNRegression:
Expand Down