Torch Hook
10 Mar 2023 — Table of Contents
Extracting activations from a layer
ResNet Example
from PIL import Image
import torch
import torch.nn as nn
from torchvision.models import resnet18
from torchvision import transforms as T

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load one image and preprocess it into a (1, 3, 224, 224) batch tensor.
image = Image.open('cat.jpeg')
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
X = transform(image).unsqueeze(dim=0).to(device)

# Pretrained ResNet-18 used purely for inference:
# - eval() puts BatchNorm in running-stats mode (a batch of 1 would give
#   meaningless batch statistics in train mode),
# - no_grad() skips building the autograd graph we never backprop through.
model = resnet18(pretrained=True)
model = model.to(device)
model.eval()
with torch.no_grad():
    out = model(X)
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
...
model.layer3[0].downsample[1]
BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# Activations captured by the forward hooks, keyed by a user-chosen name.
features = {}

def get_features(name):
    """Build a forward hook that stores a layer's output under ``name``.

    The returned callable has the ``register_forward_hook`` signature
    (module, input, output); it records a detached view of the output so
    no autograd graph is kept alive.
    """
    def hook(module, inputs, output):
        features[name] = output.detach()
    return hook
# Attach hooks to three layers of interest, keeping the handles so the
# hooks can be removed later.
h1 = model.avgpool.register_forward_hook(get_features('avgpool'))
h2 = model.maxpool.register_forward_hook(get_features('maxpool'))
h3 = model.layer3[0].downsample[1].register_forward_hook(get_features('comp'))

# A single forward pass fires every hook and fills ``features``.
out = model(X)

# Only the keys are needed — iterate the dict directly instead of items().
for k in features:
    print(k)
dict_keys(['maxpool', 'comp', 'avgpool'])
{'maxpool': tensor([[[[0.1353, 0.1667, 0.1667, ..., 0.1608, 0.1212, 0.1595],
[0.2054, 0.2170, 0.2137, ..., 0.2532, 0.2609, 0.2438],
[0.1948, 0.2130, 0.2137, ..., 0.2781, 0.2662, 0.2738],
...,
[0.4701, 0.4701, 0.3039, ..., 0.8598, 0.8067, 0.8267],
[0.3148, 0.3475, 0.3475, ..., 0.7769, 0.7109, 0.7109],
[0.3148, 0.3475, 0.3475, ..., 0.7769, 0.7109, 0.7109]],
[[0.6138, 0.4425, 0.4108, ..., 0.3728, 0.3662, 0.2074],
[0.6186, 0.4509, 0.4509, ..., 0.4669, 0.4669, 0.3600],
[0.6540, 0.4388, 0.4542, ..., 0.5525, 0.5525, 0.4178],
...,
[1.3034, 1.2000, 0.9978, ..., 1.0912, 1.0912, 0.5319],
[1.1350, 0.9732, 0.6896, ..., 0.5049, 0.5121, 0.5468],
[1.1350, 0.9732, 0.6051, ..., 0.5049, 0.3075, 0.1783]],
...
# Detach every hook so later forward passes stop recording activations.
for handle in (h1, h2, h3):
    handle.remove()
Stable Diffusion Example
# !pip install diffusers
from diffusers import StableDiffusionPipeline

# Load Stable Diffusion v1.5 in half precision and move it to the GPU.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe = pipe.to('cuda')
# Per-layer records; each entry is a list of (input, hook output,
# fp32 recomputation) tuples, one per forward invocation of the layer.
features = {}

def get_features(name):
    """Build a forward hook that records every call to a Linear layer.

    For each invocation the hook stores, on CPU:
      - the (detached) input tensor,
      - the layer's own (fp16) output,
      - the output recomputed through an fp32 clone of the layer, as a
        numerical cross-check against the half-precision result.
    """
    def hook(module, input, output):
        # Clone the Linear in default (fp32) precision and copy its weights;
        # place it on the input's device so this also works on CPU.
        clone = nn.Linear(module.in_features, module.out_features,
                          module.bias is not None)
        clone.load_state_dict(module.state_dict())
        clone.to(input[0].device)
        entry = (
            input[0].detach().cpu(),
            output.detach().cpu(),
            clone(input[0].detach().float()).detach().cpu(),
        )
        # setdefault replaces the duplicated get()/append branches.
        features.setdefault(name, []).append(entry)
    return hook
# Register a hook on every Linear layer in the UNet. Keep the handles —
# the original discarded them, which makes the hooks impossible to remove.
hook_handles = []
for name, module in pipe.unet.named_modules():
    if isinstance(module, nn.Linear):
        hook_handles.append(module.register_forward_hook(get_features(name)))

# Running the pipeline triggers every hook at each denoising step.
out = pipe(prompt="a cute cat", num_inference_steps=5)
out.images[0]

# Only the keys are needed — iterate the dict directly instead of items().
for k in features:
    print(k)

print(features['up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_q'])