I'm 小六子 (Xiaoliuzi), a front-end developer, getting ready to connect to the Matrix, the internet, the higher-dimensional universe...

@6doai What does this code compute?
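```python
import torch
from torch import autograd


def compute_KL_divergence(D, real_samples, fake_samples):
    # Despite the name, this returns a gradient penalty (as in WGAN-GP), not a KL divergence.
    # One random mixing weight per sample (assumes image batches of shape (N, C, H, W))
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=real_samples.device)
    # Random points on the straight lines between paired real and fake samples
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)
    # grad_outputs of ones: differentiates sum(D(x_hat)), i.e. each sample's own gradient
    # as long as D treats the samples in a batch independently
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,  # keep the graph so the penalty itself can be backpropagated
        retain_graph=True,
    )[0]
    # Flatten each sample's gradient and penalize (||grad||_2 - 1)^2, averaged over the batch
    gradients = gradients.reshape(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
```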
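Reading the code, it appears to compute the WGAN-GP gradient penalty rather than a KL divergence: the batch mean of (‖∇D(x̂)‖₂ − 1)², where x̂ = α·x_real + (1 − α)·x_fake and α is drawn uniformly from [0, 1) for each sample. One way to sanity-check this is a linear critic D(x) = w·x + b. Its gradient is w at every point, so the penalty should equal (‖w‖₂ − 1)² no matter which α values are drawn. The sketch below copies the function under the hypothetical name `gradient_penalty` and uses a made-up batch of 8 inputs of shape 3×4×4; both are assumptions for illustration only.

```python
import torch
from torch import autograd, nn


def gradient_penalty(D, real_samples, fake_samples):
    # Hypothetical name; same computation as the function in the question
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=real_samples.device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
    )[0]
    gradients = gradients.reshape(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()


# Linear critic D(x) = w . flatten(x) + b: its input gradient is w everywhere,
# so the penalty should equal (||w||_2 - 1)^2 regardless of alpha.
torch.manual_seed(0)
critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 4, 1))
real = torch.randn(8, 3, 4, 4)
fake = torch.randn(8, 3, 4, 4)

gp = gradient_penalty(critic, real, fake)
expected = (critic[1].weight.norm(2) - 1) ** 2
print(gp.item(), expected.item())
```

If the two printed numbers agree, the function is behaving as a gradient penalty on the critic's input gradients. In WGAN-GP training this value is typically scaled by a coefficient λ and added to the critic's loss.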