88# ' @param strata vector of strings representing stratifying variables
99# ' @param grpName string representing variable name for treatment or
1010# ' exposure group
11+ # ' @param ratio vector of values indicating relative proportion of group
12+ # ' assignment
1113# ' @return An integer (group) ranging from 1 to length of the
1214# ' probability vector
1315# ' @seealso \code{\link{trtObserve}}
3638# ' dt5 <- trtAssign(dt, nTrt = 5, balanced = TRUE, grpName = "Group")
3739# ' dt5[, .N, keyby = .(male, Group)]
3840# ' dt5[, .N, keyby = .(Group)]
41+ # '
42+ # ' dt6 <- trtAssign(dt, nTrt = 3, ratio = c(1, 2, 2), grpName = "Group")
43+ # ' dt6[, .N, keyby = .(Group)]
3944# '
4045# ' @export
4146
4247trtAssign <- function (dtName , nTrt = 2 , balanced = TRUE ,
43- strata = NULL , grpName = " trtGrp" ) {
48+ strata = NULL , grpName = " trtGrp" , ratio = NULL ) {
4449
4550 # 'declare' vars
4651
@@ -56,6 +61,11 @@ trtAssign <- function(dtName, nTrt = 2, balanced = TRUE,
5661 if (grpName %in% names(dtName )) {
5762 stop(" Group name has previously been defined in data table" , call. = FALSE )
5863 }
64+ if (! is.null(ratio )) {
65+ if (length(ratio ) != nTrt ) {
66+ stop(" Number of treatments does not match specified ratio" , call. = FALSE )
67+ }
68+ }
5969
6070 dt <- copy(dtName )
6171
@@ -68,21 +78,28 @@ trtAssign <- function(dtName, nTrt = 2, balanced = TRUE,
6878 }
6979
7080 dt [, .n : = .N , keyby = .stratum ]
71- dtrx <- dt [, list (grpExp = .stratSamp(.n [1 ], nTrt )), keyby = .stratum ]
81+ dtrx <- dt [, list (grpExp = .stratSamp(.n [1 ], nTrt , ratio )), keyby = .stratum ]
7282 dt [, grpExp : = dtrx $ grpExp ]
7383 dt [, `:=`(.stratum = NULL , .n = NULL )]
7484
75- if (nTrt == 2 ) dt [grpExp == 2 , grpExp : = 0 ]
85+ if (nTrt == 2 ) dt [, grpExp : = grpExp - 1 ]
7686 data.table :: setnames(dt , " grpExp" , grpName )
7787 data.table :: setkeyv(dt ,key(dtName ))
7888
7989 } else { # balanced is FALSE - strata are not relevant
8090
81- if (nTrt == 2 ) {
82- formula <- .5
83- } else {
84- formula <- rep(1 / nTrt , nTrt )
91+ if (is.null(ratio )) {
92+
93+ if (nTrt == 2 ) {
94+ formula <- .5
95+ } else {
96+ formula <- rep(1 / nTrt , nTrt )
97+ }
98+
99+ } else { # ratio not null
100+ formula <- ratio / sum(ratio )
85101 }
102+
86103
87104 dt <- trtObserve(dt , formulas = formula , logit.link = FALSE , grpName )
88105
0 commit comments