Displaying or Tracking Memory Usage in PyTorch


When running models with PyTorch it is easy to hit an out-of-memory error, so it helps to check exactly how much memory is available. Simple version:

# Return the currently available system RAM or GPU memory for the given device.
import torch
import pynvml
import psutil

def get_avaliable_memory(device):
    if device == torch.device('cuda:0'):
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # index 0 is the first GPU
        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        ava_mem = round(meminfo.free / 1024**2)
        print('current available video memory is : ' + str(ava_mem) + ' MiB')
    elif device == torch.device('cpu'):
        mem = psutil.virtual_memory()
        ava_mem = round(mem.available / 1024**2)  # mem.available, not mem.used, is the free RAM
        print('current available memory is : ' + str(ava_mem) + ' MiB')
    return ava_mem
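A quick usage sketch (the threshold and device strings below are only examples; the helper above recognizes exactly 'cuda:0' and 'cpu'):

import torch

free_gpu_mib = get_avaliable_memory(torch.device('cuda:0'))  # free GPU memory, in MiB
free_cpu_mib = get_avaliable_memory(torch.device('cpu'))     # available system RAM, in MiB

if free_gpu_mib < 2000:  # arbitrary example threshold
    print('less than ~2 GiB of free GPU memory left; consider a smaller batch size')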

Complex version: there is an open-source project on GitHub, https://github.com/Oldpan/Pytorch-Memory-Utils, whose main code is shown below:

import gc
import datetime
import pynvml
import torch
import numpy as np


class MemTracker(object):
    """
    Class used to track pytorch memory usage
    Arguments:
        frame: a frame to detect current py-file runtime
        detail(bool, default True): whether the function shows the detail gpu memory usage
        path(str): where to save log file
        verbose(bool, default False): whether show the trivial exception
        device(int): GPU number, default is 0
    """
    def __init__(self, frame, detail=True, path='', verbose=False, device=0):
        self.frame = frame
        self.print_detail = detail
        self.last_tensor_sizes = set()
        self.gpu_profile_fn = path + f'{datetime.datetime.now():%d-%b-%y-%H:%M:%S}-gpu_mem_track.txt'
        self.verbose = verbose
        self.begin = True
        self.device = device

        self.func_name = frame.f_code.co_name
        self.filename = frame.f_globals["__file__"]
        if (self.filename.endswith(".pyc") or
                self.filename.endswith(".pyo")):
            self.filename = self.filename[:-1]
        self.module_name = self.frame.f_globals["__name__"]
        self.curr_line = self.frame.f_lineno

    def get_tensors(self):
        # Walk the garbage collector and yield every live CUDA tensor.
        for obj in gc.get_objects():
            try:
                if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                    tensor = obj
                else:
                    continue
                if tensor.is_cuda:
                    yield tensor
            except Exception as e:
                if self.verbose:
                    print('A trivial exception occurred: {}'.format(e))

    def track(self):
        """
        Track the GPU memory usage
        """
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(self.device)
        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        self.curr_line = self.frame.f_lineno
        where_str = self.module_name + ' ' + self.func_name + ':' + ' line ' + str(self.curr_line)

        with open(self.gpu_profile_fn, 'a+') as f:

            if self.begin:
                f.write(f"GPU Memory Track | {datetime.datetime.now():%d-%b-%y-%H:%M:%S} |"
                        f" Total Used Memory:{meminfo.used/1000**2:<7.1f}Mb\n\n")
                self.begin = False

            if self.print_detail is True:
                # Diff the set of live CUDA tensors against the previous snapshot.
                ts_list = [tensor.size() for tensor in self.get_tensors()]
                new_tensor_sizes = {(type(x), tuple(x.size()), ts_list.count(x.size()),
                                     np.prod(np.array(x.size()))*4/1000**2)
                                    for x in self.get_tensors()}
                for t, s, n, m in new_tensor_sizes - self.last_tensor_sizes:
                    f.write(f'+ | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20}\n')
                for t, s, n, m in self.last_tensor_sizes - new_tensor_sizes:
                    f.write(f'- | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20} \n')
                self.last_tensor_sizes = new_tensor_sizes

            f.write(f"\nAt {where_str:<50}"
                    f"Total Used Memory:{meminfo.used/1000**2:<7.1f}Mb\n\n")

        pynvml.nvmlShutdown()
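The key idea in get_tensors() is to scan Python's garbage collector for live CUDA tensors; track() then diffs that set between calls and prices each tensor at 4 bytes per element, i.e. it assumes float32. A self-contained sketch of the same idea, using element_size() so that other dtypes are counted correctly:

import gc
import torch

def cuda_tensor_memory_mb():
    # Sum the bytes held by all live CUDA tensors found via the garbage collector.
    total_bytes = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                total_bytes += obj.element_size() * obj.nelement()
        except Exception:
            pass  # some tracked objects raise on attribute access; skip them
    return total_bytes / 1000**2

print(f'live CUDA tensors currently hold about {cuda_tensor_memory_mb():.1f} MB')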

Usage: after constructing the tracker, insert track() calls between the blocks of code you care about; each run produces a txt file recording how GPU memory usage changes from call to call.

import torch
import inspect

from torchvision import models
from gpu_mem_track import MemTracker

device = torch.device('cuda:0')

frame = inspect.currentframe()   # define a frame to track
gpu_tracker = MemTracker(frame)  # define a GPU tracker

gpu_tracker.track()  # snapshot GPU memory before the model is loaded
cnn = models.vgg19(pretrained=True).features.to(device).eval()
gpu_tracker.track()  # snapshot GPU memory after the model is loaded

dummy_tensor_1 = torch.randn(30, 3, 512, 512).float().to(device)  # 30*3*512*512*4/1000/1000 = 94.37M
dummy_tensor_2 = torch.randn(40, 3, 512, 512).float().to(device)  # 40*3*512*512*4/1000/1000 = 125.82M
dummy_tensor_3 = torch.randn(60, 3, 512, 512).float().to(device)  # 60*3*512*512*4/1000/1000 = 188.74M

gpu_tracker.track()  # snapshot GPU memory after the dummy tensors are allocated
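Note that NVML reports memory used by the whole device, which includes the CUDA context and whatever PyTorch's caching allocator has reserved but not handed out. If you only need PyTorch's own view, its built-in counters are a lighter-weight complement to the tracker above (a minimal sketch, not part of the project):

import torch

device = torch.device('cuda:0')
torch.cuda.reset_peak_memory_stats(device)

x = torch.randn(30, 3, 512, 512, device=device)

print(f'allocated: {torch.cuda.memory_allocated(device) / 1000**2:.1f} MB')   # bytes held by live tensors
print(f'reserved:  {torch.cuda.memory_reserved(device) / 1000**2:.1f} MB')    # cached by the allocator
print(f'peak:      {torch.cuda.max_memory_allocated(device) / 1000**2:.1f} MB')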