Torch Hook


Extracting activations from a layer
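A forward hook is any callable with the signature hook(module, input, output); PyTorch calls it every time the module's forward pass runs, and register_forward_hook returns a handle whose remove() detaches it. A minimal sketch of the mechanics, using a made-up two-layer model:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))  # hypothetical toy model

def print_shape(module, input, output):
    # `input` is a tuple of the positional arguments; `output` is the return value
    print(type(module).__name__, tuple(output.shape))

handle = net[0].register_forward_hook(print_shape)
net(torch.randn(1, 4))  # prints: Linear (1, 8)
handle.remove()         # the hook no longer fires after removal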

ResNet Example

from PIL import Image
import torch
import torch.nn as nn
from torchvision.models import resnet18
from torchvision import transforms as T
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

image = Image.open('cat.jpeg')
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
X = transform(image).unsqueeze(dim=0).to(device)

model = resnet18(pretrained=True)  # 'pretrained' is deprecated in newer torchvision; prefer resnet18(weights='IMAGENET1K_V1')
model = model.to(device).eval()  # eval() so BatchNorm uses running statistics during feature extraction

out = model(X)
model
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

...

model.layer3[0].downsample[1]
BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
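The same submodule can also be fetched by its dotted name as reported by named_modules(); a small sketch (get_submodule is available in recent PyTorch versions):

# equivalent lookup by dotted path
bn = model.get_submodule('layer3.0.downsample.1')
print(bn is model.layer3[0].downsample[1])  # True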
features = {}
def get_features(name):
    # closure: each registered hook remembers its own `name` key
    def hook(model, input, output):
        features[name] = output.detach()  # detach so stored activations don't keep the autograd graph alive
    return hook
h1 = model.avgpool.register_forward_hook(get_features('avgpool'))
h2 = model.maxpool.register_forward_hook(get_features('maxpool'))
h3 = model.layer3[0].downsample[1].register_forward_hook(get_features('comp'))
out = model(X)  # the hooks fire during this forward pass and populate `features`
print(features.keys())
dict_keys(['maxpool', 'comp', 'avgpool'])
features
{'maxpool': tensor([[[[0.1353, 0.1667, 0.1667,  ..., 0.1608, 0.1212, 0.1595],
          [0.2054, 0.2170, 0.2137,  ..., 0.2532, 0.2609, 0.2438],
          [0.1948, 0.2130, 0.2137,  ..., 0.2781, 0.2662, 0.2738],
          ...,
          [0.4701, 0.4701, 0.3039,  ..., 0.8598, 0.8067, 0.8267],
          [0.3148, 0.3475, 0.3475,  ..., 0.7769, 0.7109, 0.7109],
          [0.3148, 0.3475, 0.3475,  ..., 0.7769, 0.7109, 0.7109]],

         [[0.6138, 0.4425, 0.4108,  ..., 0.3728, 0.3662, 0.2074],
          [0.6186, 0.4509, 0.4509,  ..., 0.4669, 0.4669, 0.3600],
          [0.6540, 0.4388, 0.4542,  ..., 0.5525, 0.5525, 0.4178],
          ...,
          [1.3034, 1.2000, 0.9978,  ..., 1.0912, 1.0912, 0.5319],
          [1.1350, 0.9732, 0.6896,  ..., 0.5049, 0.5121, 0.5468],
          [1.1350, 0.9732, 0.6051,  ..., 0.5049, 0.3075, 0.1783]],
...
h1.remove()
h2.remove()
h3.remove()
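The handles are removed once the activations are no longer needed; the features dict itself survives removal. As a quick sanity check on what was captured, the expected shapes for the 1x3x224x224 input above (a minimal sketch):

for name, t in features.items():
    print(name, tuple(t.shape))
# maxpool (1, 64, 56, 56)   -- after conv1 (stride 2) and max-pooling (stride 2)
# comp    (1, 256, 14, 14)  -- output of the layer3 downsample BatchNorm
# avgpool (1, 512, 1, 1)    -- global average pool, right before the classifier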

Stable Diffusion Example

# !pip install diffusers
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to('cuda')
features = {}

def get_features(name):
    def hook(model, input, output):
        # rebuild this Linear in fp32; load_state_dict casts the fp16 weights
        # up to the fresh module's fp32 parameters, giving a full-precision copy
        linear = nn.Linear(model.in_features, model.out_features, model.bias is not None)
        linear.load_state_dict(model.state_dict())
        linear.to("cuda")

        # store (input, fp16 output, fp32 recomputation) for every call of this
        # module; the UNet runs once per denoising step, so each list grows by one per step
        features.setdefault(name, []).append((
            input[0].detach().cpu(),
            output.detach().cpu(),
            linear(input[0].detach().float()).detach().cpu(),
        ))

    return hook

handles = []
for name, module in pipe.unet.named_modules():
    if isinstance(module, nn.Linear):
        handles.append(module.register_forward_hook(get_features(name)))  # keep handles so the hooks can be removed later
out = pipe(prompt="a cute cat", num_inference_steps=5)
out.images[0]
for name in features:
    print(name)  # one key per nn.Linear in the UNet
print(features['up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_q'])
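Since each stored tuple is (input, fp16 output, fp32 recomputation), the hook doubles as a numerical check on the half-precision pipeline. A sketch of that comparison; the tolerance is a loose assumption for fp16-vs-fp32 drift:

key = 'up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_q'
for step, (inp, out_fp16, out_fp32) in enumerate(features[key]):
    err = (out_fp16.float() - out_fp32).abs().max().item()
    print(f"step {step}: max |fp16 - fp32| = {err:.4e}")
    assert torch.allclose(out_fp16.float(), out_fp32, atol=1e-2)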
