def forward(self, x, gain, bias):
    """Batch-normalize ``x`` with ``gain``/``bias`` and maintain running stats.

    Training mode: normalizes with the current batch statistics (computed by
    ``manual_bn`` with ``return_mean_var=True``) and updates the stored
    statistics, either by summing them into ``stored_mean``/``stored_var``
    (when ``self.accumulate_standing`` is set, counting batches in
    ``self.accumulation_counter``) or by an exponential moving average
    weighted by ``self.momentum``.

    Eval mode: normalizes with the stored statistics via ``fused_bn``. When
    ``self.accumulate_standing`` is set, the stored sums are first divided by
    ``self.accumulation_counter`` to turn them into averages.

    Args:
        x: Input activations. The ``view(1, -1, 1, 1)`` used in eval mode
            suggests a 4-D (N, C, H, W) tensor -- TODO confirm.
        gain: Scale, forwarded to ``manual_bn`` / ``fused_bn``.
        bias: Shift, forwarded to ``manual_bn`` / ``fused_bn``.

    Returns:
        The normalized output produced by ``manual_bn`` (training) or
        ``fused_bn`` (eval).
    """
    if self.training:
        out, mean, var = manual_bn(
            x, gain, bias, return_mean_var=True, eps=self.eps
        )
        # Running statistics are bookkeeping and must not join the autograd
        # graph. Writing a grad-requiring tensor into a buffer via ``[:] =``
        # records a CopySlices node, turning the buffer into a non-leaf whose
        # grad_fn chains across training steps. Detach once for both branches
        # (``out`` was already computed, so it is unaffected).
        # NOTE(review): presumably mean/var are per-channel, matching the
        # shape of stored_mean/stored_var -- confirm against manual_bn.
        mean = mean.detach()
        var = var.detach()

        if self.accumulate_standing:
            # Sum the batch stats; eval mode divides by the batch count.
            self.stored_mean[:] = self.stored_mean + mean
            self.stored_var[:] = self.stored_var + var
            self.accumulation_counter += 1.0
        else:
            # Exponential moving average of the batch statistics.
            self.stored_mean[:] = (
                self.stored_mean * (1 - self.momentum) + mean * self.momentum
            )
            self.stored_var[:] = (
                self.stored_var * (1 - self.momentum) + var * self.momentum
            )
        return out

    # Eval mode: reshape the stored per-channel stats so they broadcast
    # against the input.
    mean = self.stored_mean.view(1, -1, 1, 1)
    var = self.stored_var.view(1, -1, 1, 1)

    if self.accumulate_standing:
        # Convert accumulated sums into averages.
        # NOTE(review): if no batches were accumulated, accumulation_counter
        # is presumably 0 here and this division yields inf/NaN -- confirm
        # callers always accumulate before evaluating.
        mean = mean / self.accumulation_counter
        var = var / self.accumulation_counter
    return fused_bn(x, mean, var, gain, bias, self.eps)
477
478