-
Notifications
You must be signed in to change notification settings - Fork 7
Open
Description
The `Adapter` class accepts a `Z` argument in its constructor:
class Adapter(transformers.PreTrainedModel):
    """Container of per-task classifier heads plus a ``Z`` tensor.

    ``adapt_model_to_task`` attaches the head for one task to a model.

    NOTE(review): ``Z`` is stored on the adapter, but ``adapt_model_to_task``
    never reads it; only the classifier head is transferred. Confirm whether
    the task's row of ``Z`` should also be applied to the adapted model.
    """

    config_class = transformers.PretrainedConfig

    def __init__(self, config, classifiers=None, Z=None, labels_list=None):
        """Build the adapter from ``config``.

        Args:
            config: transformers config. This code reads
                ``config.classifiers_size`` (one output size per task) and
                ``config.hidden_size`` when building default components.
            classifiers: optional pre-built module list of heads. When
                omitted, one ``torch.nn.Linear(hidden_size, size)`` is created
                per entry of ``config.classifiers_size``.
            Z: optional pre-built tensor. When omitted, it is the weight of a
                fresh ``torch.nn.Embedding(len(classifiers_size), hidden_size)``,
                i.e. one row per task.
            labels_list: stored in the config under ``'labels_list'``.
                Defaults to a fresh empty list.
        """
        super().__init__(config)
        # `is None` instead of `== None`: when a tensor/array is passed,
        # `==` may compare elementwise and the `if` becomes ambiguous.
        if Z is None:
            # NOTE(review): `max_norm` is only enforced by Embedding.forward
            # lookups; taking `.weight` directly does not renormalize rows.
            Z = torch.nn.Embedding(
                len(config.classifiers_size), config.hidden_size, max_norm=1.0
            ).weight
        self.Z = Z
        if classifiers is None:
            classifiers = torch.nn.ModuleList(
                [
                    torch.nn.Linear(config.hidden_size, size)
                    for size in config.classifiers_size
                ]
            )
        self.classifiers = classifiers
        # Fresh list per instance rather than a shared mutable default that
        # would be aliased into every config built with the default.
        if labels_list is None:
            labels_list = []
        self.config = self.config.from_dict(
            {**self.config.to_dict(), 'labels_list': labels_list}
        )

    def adapt_model_to_task(self, model, task_name):
        """Attach the classifier head for ``task_name`` to ``model``.

        The head's position in ``self.classifiers`` is the index of
        ``task_name`` in ``self.config.tasks``. ``self.config.tasks`` is
        presumably a list set elsewhere, so an unknown task would raise
        ``ValueError`` via ``.index``; confirm. ``model.classifier`` is
        reassigned in place and the same ``model`` object is returned.

        NOTE(review): ``self.Z`` is not used here.
        """
        task_index = self.config.tasks.index(task_name)
        model.classifier = self.classifiers[task_index]
        return model

    def _init_weights(self, *args):
        """No-op override of ``PreTrainedModel._init_weights``.

        NOTE(review): presumably disabled so that transformers does not
        re-initialize the heads/``Z`` passed in or loaded; confirm.
        """
but `Z` is never used when adapting a model to a task (`adapt_model_to_task` only swaps in the classifier head). Is that intentional?
Metadata
Metadata
Assignees
Labels
No labels