Coverage for transformer_lens/utilities/gpu_utils.py: 80%

5 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""gpu_utils. 

2 

3This module contains varied utility functions related to GPUs. 

4""" 

5 

6from __future__ import annotations 

7 

8import numpy as np 

9import torch 

10 

11 

def print_gpu_mem(step_name: str = "") -> None:
    """Print how much GPU memory is currently allocated, in GiB.

    The figure is read from ``torch.cuda.memory_allocated()`` (a byte count),
    converted to GiB and rounded to two decimal places before printing.

    Args:
        step_name: Label printed in front of the figure, e.g. the name of the
            step that has just run. Defaults to an empty string.
    """
    # One GiB is 2**30 bytes. The float literal 2e30 is 2 * 10**30, which would
    # make every realistic reading round down to 0.0.
    bytes_per_gib = 2**30
    allocated_gib = torch.cuda.memory_allocated() / bytes_per_gib
    print(f"{step_name} ~ {np.round(allocated_gib, 2)} GiB allocated on GPU.")