
sparse_transformers


Commits on branch main (all by ssirluk, committed 2 years ago):

  • acec8bfbfd5485001b9350036730b8017bd68afb: update main attack
  • 3099b17290f97b568f0c2c28a803604482733294: update nepoch attack
  • 0e9ebee0f340c5a5ba7c6fdf64af6103e2853b79: add possibility to run with same seed
  • 9ad9ec05ce0c856e03e1f87fb8348fcaf757ad3b: change results dict creation
  • 14a4220993a238e0516e19e6d82f0f4192f44006: small improvements
  • a97c886c3c76ccada56797245c163fbabc1699be: integrate task head loading for adv baseline


Diff-Pruning with Adversarial Training

This code implements various models for adversarial training and diff-pruning (Guo et al., 2020).
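The README does not say how the adversarial heads in model_heads.py push the encoder to discard protected information. A widely used mechanism for this kind of adversarial training is a gradient reversal layer: it is the identity in the forward pass, but flips (and scales) the gradient in the backward pass, so the head learns to predict the protected attribute while the encoder is updated to make that prediction harder. The sketch below assumes that mechanism; GradientReversal, AdvHead and lambd are hypothetical names, not the repository's classes.

```python
import torch
from torch import nn


class GradientReversal(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by -lambd in the backward pass."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, lambd: float) -> torch.Tensor:
        ctx.lambd = lambd
        # view_as returns a new tensor object so autograd tracks this op
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Flip the sign of the gradient flowing back into the encoder;
        # lambd is a plain float, so it receives no gradient (None).
        return -ctx.lambd * grad_output, None


class AdvHead(nn.Module):
    """Hypothetical adversarial head that predicts the protected attribute."""

    def __init__(self, hidden_size: int, num_protected: int, lambd: float = 1.0):
        super().__init__()
        self.lambd = lambd
        self.classifier = nn.Linear(hidden_size, num_protected)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        reversed_features = GradientReversal.apply(features, self.lambd)
        return self.classifier(reversed_features)


# Usage: the head minimises its loss on the protected labels, while the
# reversed gradient pushes the (stand-in) encoder features to hide them.
torch.manual_seed(0)
features = torch.randn(4, 8, requires_grad=True)  # stand-in for encoder output
protected = torch.tensor([0, 1, 0, 1])             # stand-in protected labels
head = AdvHead(hidden_size=8, num_protected=2)
loss = nn.functional.cross_entropy(head(features), protected)
loss.backward()
```

In a full model the same backward pass would also reach the encoder parameters, which is what makes the setup adversarial; lambd trades off debiasing strength against task performance.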

Installation

To run the code, make sure conda is installed and then run

conda env create -f environment.yml

Then activate the environment by running

conda activate diff_pruning

Architecture

The project structure is as follows:

📦sparse_transformers
┣ 📂analysis (jupyter notebooks for analysis of models and results)
┣ 📂scripts (additional scripts for various experiments)
┣ 📂src
┃ ┣ 📂models (directory which contains all model classes)
┃ ┃ ┣ 📜model_adv.py (baseline model for adversarial training)
┃ ┃ ┣ 📜model_base.py (contains base classes with methods that are used by all models)
┃ ┃ ┣ 📜model_diff_adv.py (model with 2 subnetworks for adversarial training)
┃ ┃ ┣ 📜model_diff_modular.py (model with 2 subnetworks for task and adv training)
┃ ┃ ┣ 📜model_diff_task.py (model with a subnetwork for task training)
┃ ┃ ┣ 📜model_functions.py (generic functions for all model classes)
┃ ┃ ┣ 📜model_heads.py (classifier and adversarial head classes)
┃ ┃ ┣ 📜model_modular.py (baseline model for modular training)
┃ ┃ ┣ 📜model_task.py (baseline model for task training)
┃ ┃ ┗ 📜weight_parametrizations.py (contains weight parametrizations for subnetwork training*)
┃ ┣ 📜adv_attack.py (contains the function to run an adversarial attack)
┃ ┣ 📜data_handler.py
┃ ┣ 📜metrics.py
┃ ┣ 📜training_logger.py
┃ ┗ 📜utils.py
┣ 📜cfg.yml (hyperparameters)
┣ 📜environment.yml (conda environment config)
┣ 📜main.py (main file to run experiments with)
┣ 📜main_attack.py (used to run an adversarial attack using only a model checkpoint)
┣ 📜main_wrapper.py (used to run multiple experiments sequentially)
┗ 📜readme.md

* Weight parametrizations are implemented as modules and use PyTorch's parametrizations functionality LINK
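To make the footnote concrete, here is a minimal sketch of a diff-pruning weight parametrization, assuming the formulation of Guo et al. (2020): the task weight is the frozen pretrained weight plus a sparse diff, W = W_pre + z ⊙ w, where the gate z follows a hard-concrete relaxation so that its expected L0 norm can be added to the loss as a sparsity penalty. DiffWeight, expected_l0 and the hyperparameter defaults (beta = 2/3, stretch interval (-0.1, 1.1), log_alpha_init = 3.0) are illustrative assumptions, not the contents of weight_parametrizations.py; only torch.nn.utils.parametrize.register_parametrization is real PyTorch API.

```python
import math

import torch
from torch import nn
from torch.nn.utils import parametrize


class DiffWeight(nn.Module):
    """Hypothetical parametrization: W = W_pretrained + z * diff (hard-concrete gate z)."""

    def __init__(self, shape, log_alpha_init: float = 3.0, beta: float = 2 / 3,
                 left: float = -0.1, right: float = 1.1):
        super().__init__()
        self.diff = nn.Parameter(torch.zeros(shape))  # dense diff vector w
        self.log_alpha = nn.Parameter(torch.full(shape, log_alpha_init))  # gate logits
        self.beta, self.left, self.right = beta, left, right

    def gate(self) -> torch.Tensor:
        if self.training:
            # Stochastic hard-concrete sample via the reparametrization trick
            u = torch.rand_like(self.log_alpha).clamp(1e-6, 1 - 1e-6)
            s = torch.sigmoid((torch.log(u) - torch.log1p(-u) + self.log_alpha) / self.beta)
        else:
            # Deterministic gate at evaluation time
            s = torch.sigmoid(self.log_alpha)
        # Stretch to (left, right), then clip to [0, 1] so exact zeros are possible
        return (s * (self.right - self.left) + self.left).clamp(0.0, 1.0)

    def forward(self, pretrained: torch.Tensor) -> torch.Tensor:
        return pretrained + self.gate() * self.diff

    def expected_l0(self) -> torch.Tensor:
        # Expected number of non-zero gates: a differentiable sparsity penalty
        shift = self.beta * math.log(-self.left / self.right)
        return torch.sigmoid(self.log_alpha - shift).sum()


# Usage on a stand-in for one pretrained encoder layer
torch.manual_seed(0)
layer = nn.Linear(4, 3)
parametrize.register_parametrization(layer, "weight", DiffWeight(layer.weight.shape))
# Freeze the pretrained weight so only the diff and the gate logits are trained
layer.parametrizations.weight.original.requires_grad_(False)
penalty = layer.parametrizations.weight[0].expected_l0()  # add lambda * penalty to the loss
```

Because the diff starts at zero, the parametrized layer initially reproduces the pretrained weights exactly; after training, only the gated entries of the diff need to be stored per task.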

cfg.yml

Contains the hyperparameter configuration:

  • data_config
    File paths to the data files
  • model_config
    Name of the pretrained model and the batch size to use
  • train_config_diff_pruning
    Hyperparameters for diff-pruning models (model_diff_adv.py and model_diff_task.py)
  • train_config_baseline
    Hyperparameters for baseline models (model_adv.py and model_task.py)
  • adv_attack
    Hyperparameters for the adversarial attack
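As an illustration of how these sections fit together, the sketch below picks the sections a single run needs, assuming cfg.yml has already been parsed into a nested dict (for example with PyYAML's yaml.safe_load; how main.py actually reads the file is not stated here). resolve_sections is a hypothetical helper, the inner keys in the usage dict are placeholders, and run_adv_attack stands for leaving out the --no_adv_attack flag described under Usage.

```python
from typing import Any, Dict


def resolve_sections(cfg: Dict[str, Any], baseline: bool,
                     run_adv_attack: bool = True) -> Dict[str, Any]:
    """Hypothetical helper: select the cfg.yml sections needed for one run."""
    # Baseline models and diff-pruning models read different training sections
    train_key = "train_config_baseline" if baseline else "train_config_diff_pruning"
    sections = {
        "data": cfg["data_config"],
        "model": cfg["model_config"],
        "train": cfg[train_key],
    }
    # The attack hyperparameters are only needed if the attack runs after training
    if run_adv_attack:
        sections["adv_attack"] = cfg["adv_attack"]
    return sections


# Usage with a dict standing in for the parsed cfg.yml (inner keys are placeholders)
cfg = {
    "data_config": {"placeholder": "file paths"},
    "model_config": {"placeholder": "model name and batch size"},
    "train_config_diff_pruning": {"placeholder": "diff-pruning hyperparameters"},
    "train_config_baseline": {"placeholder": "baseline hyperparameters"},
    "adv_attack": {"placeholder": "attack hyperparameters"},
}
sections = resolve_sections(cfg, baseline=False)
print(sorted(sections))  # ['adv_attack', 'data', 'model', 'train']
```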

Usage

python3 main.py

Optional arguments with example inputs

  • --gpu_id 0 1 2 3
    Which GPUs to run the experiment on (can be multiple)
  • --adv
    Set if you want to run adversarial training
  • --baseline
    Set if you want to run a baseline model instead of diff-pruning
  • --modular
    Run the modular architecture (overrides the --adv argument)
  • --seed=0
    Random seed
  • --ds="bios"
    Which dataset to run ("bios", "pan16", "hatespeech")
  • --cpu
    Run on CPU (even if a GPU is available)
  • --no_adv_attack
    Set if you do not want to run the adversarial attack after training
  • --cp_path="path_to_model"
    Path to a checkpoint whose weights overwrite the pre-trained encoder weights
  • --cp_is_sd
    Set if the checkpoint is a state dict
  • --cp_model_type="TaskModel"
    Model type (as a string) from which to load the encoder weights (not required if loading a state dict directly)
  • --cp_modular_biased
    If loading a checkpoint from a modular model, set the debiased state to false
  • --prot_key_idx=0
    If the protected key is a list, the index of the key to use
  • --debug
    Limits the number of batches used to 10, to verify that the code runs end to end
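The argument parsing in main.py is not shown in this README, so the following is only a minimal argparse sketch of the options above, assuming the switches are store_true flags and the example inputs indicate each option's type. build_parser, resolve_mode and limit_batches are hypothetical helpers and every default is an assumption; resolve_mode shows one reading of "--modular overrides --adv", and limit_batches shows the effect of --debug.

```python
import argparse
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical re-creation of the documented flags; all defaults are assumptions.
    p = argparse.ArgumentParser(description="Hypothetical sketch of the main.py options")
    p.add_argument("--gpu_id", type=int, nargs="+", default=None, help="GPU ids, can be multiple")
    p.add_argument("--adv", action="store_true", help="run adversarial training")
    p.add_argument("--baseline", action="store_true", help="run a baseline instead of diff-pruning")
    p.add_argument("--modular", action="store_true", help="modular architecture, overrides --adv")
    p.add_argument("--seed", type=int, default=0, help="random seed")
    p.add_argument("--ds", default="bios", choices=["bios", "pan16", "hatespeech"])
    p.add_argument("--cpu", action="store_true", help="run on CPU even if a GPU is available")
    p.add_argument("--no_adv_attack", action="store_true", help="skip the attack after training")
    p.add_argument("--cp_path", default=None, help="checkpoint that overwrites the encoder weights")
    p.add_argument("--cp_is_sd", action="store_true", help="checkpoint is a state dict")
    p.add_argument("--cp_model_type", default=None, help='model type string, e.g. "TaskModel"')
    p.add_argument("--cp_modular_biased", action="store_true")
    p.add_argument("--prot_key_idx", type=int, default=None)
    p.add_argument("--debug", action="store_true", help="limit the number of batches to 10")
    return p


def resolve_mode(args: argparse.Namespace) -> str:
    """Return '<family>/<kind>'; --modular takes precedence over --adv."""
    family = "baseline" if args.baseline else "diff_pruning"
    if args.modular:
        kind = "modular"
    elif args.adv:
        kind = "adv"
    else:
        kind = "task"
    return f"{family}/{kind}"


def limit_batches(batches: Iterable[T], debug: bool, max_batches: int = 10) -> Iterator[T]:
    """In debug mode, yield only the first max_batches batches."""
    return islice(batches, max_batches) if debug else iter(batches)


args = build_parser().parse_args(["--gpu_id", "0", "1", "--adv", "--modular", "--ds=pan16"])
print(resolve_mode(args))  # diff_pruning/modular
```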