def forward(self, x, gain, bias):
    """Batch-normalize ``x`` and apply ``gain`` and ``bias``.

    Training mode: the output is produced by ``manual_bn`` from the
    statistics of the current batch, and the stored statistics are
    updated as a side effect:

    * if ``self.accumulate_standing`` is set, the batch mean/variance are
      summed into ``self.stored_mean`` / ``self.stored_var`` and
      ``self.accumulation_counter`` is incremented;
    * otherwise they are blended in as an exponential moving average
      weighted by ``self.momentum``.

    Evaluation mode: the stored statistics are reshaped to
    ``(1, -1, 1, 1)`` and passed to ``fused_bn``; when
    ``self.accumulate_standing`` is set they are first divided by
    ``self.accumulation_counter`` to turn the running sums into averages.

    Args:
        x: Input tensor (presumably N x C x H x W, given the
            ``(1, -1, 1, 1)`` reshape of the statistics -- TODO confirm).
        gain: Scale forwarded to ``manual_bn`` / ``fused_bn``.
        bias: Shift forwarded to ``manual_bn`` / ``fused_bn``.

    Returns:
        The normalized tensor returned by ``manual_bn`` (training) or
        ``fused_bn`` (evaluation).
    """
    if self.training:
        out, mean, var = manual_bn(
            x, gain, bias, return_mean_var=True, eps=self.eps
        )

        # The batch statistics are only used for bookkeeping from here on.
        # Detach them so that writing into the stored buffers never ties
        # those buffers to the autograd graph of this step (which would
        # otherwise chain every training step's graph onto the next one).
        mean = mean.detach()
        var = var.detach()

        if self.accumulate_standing:
            # Standing statistics: keep running sums plus a sample count.
            self.stored_mean[:] = self.stored_mean + mean
            self.stored_var[:] = self.stored_var + var
            self.accumulation_counter += 1.0
        else:
            # Exponential moving average of the batch statistics.
            keep = 1 - self.momentum
            self.stored_mean[:] = self.stored_mean * keep + mean * self.momentum
            self.stored_var[:] = self.stored_var * keep + var * self.momentum
        return out

    # Evaluation: normalize with the stored statistics, reshaped so they
    # broadcast against a 4-D input along dimension 1.
    mean = self.stored_mean.view(1, -1, 1, 1)
    var = self.stored_var.view(1, -1, 1, 1)

    if self.accumulate_standing:
        # Stored values are sums here; divide by the count to average.
        # NOTE(review): assumes at least one accumulation step has run so
        # the counter is non-zero -- confirm against callers.
        mean = mean / self.accumulation_counter
        var = var / self.accumulation_counter
    return fused_bn(x, mean, var, gain, bias, self.eps)
458
459