
tq21/atmle


atmle: Adaptive TMLE for RCT + RWD Data Fusion

License: MIT

This R package implements adaptive targeted minimum loss-based estimation (A-TMLE) of the average treatment effect (ATE). It combines randomized controlled trial (RCT) data with real-world data (RWD) to estimate the effect more efficiently than the trial data alone.
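For orientation, write W for baseline covariates, A for treatment, Y for the outcome, and S for the data source (S = 1 for trial participants, S = 0 for the real-world data). A standard way to express the ATE in the trial population is the identification formula below. It is shown as an assumption about what the example's "Avg. over S=1" result targets, not as the package's documented definition:

```latex
\psi \;=\; \mathbb{E}\Big[\, \mathbb{E}(Y \mid S = 1, W, A = 1) \;-\; \mathbb{E}(Y \mid S = 1, W, A = 0) \,\Big|\, S = 1 \Big]
```

Under randomization within the trial (plus the usual consistency assumption), this equals E[Y(1) - Y(0) | S = 1] in potential-outcome notation. A trial-only estimator uses just the S = 1 rows; A-TMLE also draws on the real-world rows to try to estimate the same quantity more precisely.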

Authors: Sky Qiu, Jens Tarp, Lars van der Laan, and Mark van der Laan


Issues

If you encounter any bugs or have any specific feature requests, please file an issue.


Example

library(atmle)
library(sl3)
library(ggplot2)

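set.seed(2026)
# example data from the repository's data/ folder (run from the repository root):
# S = data source (S = 1 for the randomized trial), A = treatment, Y = outcome
data <- readRDS("data/example_data.RDS")
# known true ATE for this example dataset, drawn as a reference line in the plot
true_ate <- -0.04429033
# every remaining column is treated as a baseline covariate
W_nodes <- names(data)[!(names(data) %in% c("Y", "A", "S"))]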

# unadjusted trial-only difference-in-means with an unpooled (Welch-type)
# standard error and a normal-approximation 95% confidence interval
rct <- data[data$S == 1, ]
diff_est <- with(rct, mean(Y[A == 1]) - mean(Y[A == 0]))
diff_se <- sqrt(with(rct,
                     var(Y[A == 1]) / sum(A == 1) + var(Y[A == 0]) / sum(A == 0)))
diff_ci <- diff_est + c(-1, 1) * qnorm(0.975) * diff_se

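# A-TMLE
# learner library shared by all nuisance fits: a GLM, MARS (earth) and a GAM;
# Lrnr_cv_selector keeps the single learner with the lowest cross-validated
# binomial log-likelihood loss (Lrnr_earth and Lrnr_gam need the earth and mgcv packages)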
lrnr <- list(
  learners = list(
    sl3::Lrnr_glm$new(),
    sl3::Lrnr_earth$new(),
    sl3::Lrnr_gam$new()
  ),
  metalearner = sl3::Lrnr_cv_selector$new(sl3::loss_loglik_binomial)
)
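# set up the estimator: S_node, W_nodes, A_node and Y_node name the columns of
# `data`, and family = "binomial" declares a binary outcome
fit <- atmle_ate_fusion$new(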
  data = data,
  S_node = "S",
  W_nodes = W_nodes,
  A_node = "A",
  Y_node = "Y",
  family = "binomial",
  n_folds = 3
)

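# fit the nuisance functions with the learner library above and run the TMLE
# targeting step; A_cate_args and S_cate_args hold tuning settings for the
# A-CATE and S-CATE working models
fit$run(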
  g_bar_method = lrnr,
  theta_method = lrnr,
  Pi_method = lrnr,
  Q_bar_method = lrnr,
  A_cate_args = list(max_degree = 3L, smoothness_orders = 1L, num_knots = 5L),
  S_cate_args = list(max_degree = 3L, smoothness_orders = 1L, num_knots = 5L),
  target_method = "tmle",
  target_gwt = FALSE,
  max_iter = 10,
  verbose = FALSE
)
# keep the A-TMLE result for the trial population (param "Avg. over S=1")
atmle_res <- fit$results[fit$results$param == "Avg. over S=1", ]
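# collect both estimates and their 95% confidence intervals for plotting
df_plot <- data.frame(Estimator = c("Unadjusted Difference-in-Means",
                                    "A-TMLE"),
                      Estimate = c(diff_est, atmle_res$psi),
                      Lower = c(diff_ci[1], atmle_res$lower),
                      Upper = c(diff_ci[2], atmle_res$upper))
# fix the display order of the two estimators
df_plot$Estimator <- factor(df_plot$Estimator,
                            levels = c("Unadjusted Difference-in-Means",
                                       "A-TMLE"))
# point estimates with CIs; the dashed red line marks the true ATE
ggplot(df_plot, aes(x = Estimator, y = Estimate, fill = Estimator)) +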
  geom_point() +
  geom_errorbar(aes(ymin = Lower, ymax = Upper), width = 0.2, linewidth = 0.8) +
  geom_hline(yintercept = true_ate, linetype = "dashed", color = "red") +
  theme_minimal(base_size = 13) +
  labs(title = "Point Estimates with 95% Confidence Intervals",
       y = "Estimate",
       x = "") +
  theme(legend.position = "none",
        plot.title = element_text(hjust = 0.5)) +
  coord_flip()
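The unadjusted comparison in the example is the trial-only difference in means with an unpooled (Welch-type) standard error and a normal-approximation 95% interval. With n_a trial participants in arm A = a, mean outcome Ybar_a, and sample variance s_a^2 (R's var, denominator n_a - 1), the code computes:

```latex
\hat\psi_{\mathrm{DM}} = \bar{Y}_1 - \bar{Y}_0, \qquad
\widehat{\mathrm{SE}} = \sqrt{\frac{s_1^2}{n_1} + \frac{s_0^2}{n_0}}, \qquad
\hat\psi_{\mathrm{DM}} \pm z_{0.975}\,\widehat{\mathrm{SE}}, \quad z_{0.975} = \Phi^{-1}(0.975) \approx 1.96
```

Comparing the width of this interval with the A-TMLE interval in the plot gives a rough sense of the precision gained by bringing in the real-world data.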

License

© 2026 Sky Qiu, Jens Tarp, Lars van der Laan, and Mark van der Laan

The contents of this repository are distributed under the MIT license.
