首页>>人工智能->WDNet的Paddle复现

WDNet的Paddle复现

时间:2023-11-29 本站 点击:0

实验室算力太垃圾,我想白嫖V100算力。现初次尝试使用Paddle框架进行torch框架代码的重构

如果你想白嫖V100, 这个link我的邀请码,我们可以一起获得更多额外的算力

写在前面

关于Paddle的资料不如Torch和Tensorflow那么多,在Goole上边也搜不到什么。百度上面的资料也参差不齐,因此最高效的查资料的方式是去Paddle的论坛中,找对应的API或者去论坛中,以及一些Paddle论文复现的群

Paddle 论坛

Paddle API文档

1. WDNet架构简述

WDNet主要是用于可见水印去除的,该网络由一个生成器和一个鉴别器构成。

生成器是由一个U-net作为Backbone,U-net进行编解码之后,进行了对水印掩膜的分割任务、对水印像素的重构映射、对透明度的重构映射。通过这3个结果,初步得到去除水印的结果,接着进入第二个阶段进行精细化的去除,该部分是由Resblocks构成的,得到最终的去除结果。

鉴别器是采用的CGAN的鉴别器,以【水印图像,无水印图像(label)】作为输入,或者以【预测的水印图像,无水印图像(label)】作为输入进行鉴别训练。

2. 复现

2.1 常见的API的对照表

详情参照:PyTorch-PaddlePaddle API映射表

在对论文进行重构时,以下列出我在复现时最常用映射的API,有些是在API文档直接看不到的

Torch Paddle torch.nn.Module paddle.nn.Layer torch.nn.functional paddle.nn.functional torch.cat paddle.concat torch.optim paddle.optimizer torch.Tensor paddle.to_tensor torch.size() paddle.shape torch.load paddle.load .zero_grad() .clear_gradients() torch.no_grad() paddle.no_grad() torchvision.models paddle.vision.models .requires_grad = True .stop_gradient = False

还有对各个组件层参数进行的初始化weight_init

Paddle代码

from paddleseg.cvlibs import param_initdef weight_init(Layer):    for n, m in Layer.named_children():        if isinstance(m, nn.Conv2D):            param_init.normal_init(m.weight, mean=0.0, std=0.02)        elif isinstance(m, nn.BatchNorm2D):            param_init.normal_init(m.weight, mean=1.0, std=0.02)            param_init.constant_init(m.bias, value=0)

Torch代码

def weight_init(m):    if isinstance(m, nn.Conv2d):        nn.init.normal_(m.weight.data, 0.0, 0.02)    elif isinstance(m, nn.BatchNorm2d):        nn.init.normal_(m.weight.data, 1.0, 0.02)        nn.init.constant_(m.bias.data, 0)

2.2  一些评价指标计算

衡量指标主要是PSNR和SSIM

PSNR计算 一定要考虑输入图像的像素区间是【0-255】还是【0-1】还是【-1,1】,要做对应的转换

from paddle.nn import functional as F

psnr = 10 * paddle.log10(pixelmax / F.mse_loss(pred, GT, reduction="mean")).item()

其中pixelmax是输入图像像素上限值,是255或者1F.mse_loss算的的是一个均方误差- **SSIM计算**因为手动构建代码较为复杂,因此使用相应的库```pythonfrom paddle_msssim import SSIMself.ssim = SSIM(data_range=1.0)ssim = self.ssim(pred, GT).item()```#  3. DEMO重构的代码正在训练;现成的Torch代码,可以参照[论文作者的代码](https://github.com/MRUIL/WDNet)
原文:https://juejin.cn/post/7097094763668570119


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:/AI/1175.html