[Time-Series] fix past_observed_mask type (#22076)
added > 0.5 to `past_observed_mask`
This commit is contained in:
parent
559a45d1dc
commit
9eae4aa576
|
@ -117,7 +117,7 @@ class InformerModelTester:
|
|||
|
||||
past_time_features = floats_tensor([self.batch_size, _past_length, config.num_time_features])
|
||||
past_values = floats_tensor([self.batch_size, _past_length])
|
||||
past_observed_mask = floats_tensor([self.batch_size, _past_length])
|
||||
past_observed_mask = floats_tensor([self.batch_size, _past_length]) > 0.5
|
||||
|
||||
# decoder inputs
|
||||
future_time_features = floats_tensor([self.batch_size, config.prediction_length, config.num_time_features])
|
||||
|
|
|
@ -114,7 +114,7 @@ class TimeSeriesTransformerModelTester:
|
|||
|
||||
past_time_features = floats_tensor([self.batch_size, _past_length, config.num_time_features])
|
||||
past_values = floats_tensor([self.batch_size, _past_length])
|
||||
past_observed_mask = floats_tensor([self.batch_size, _past_length])
|
||||
past_observed_mask = floats_tensor([self.batch_size, _past_length]) > 0.5
|
||||
|
||||
# decoder inputs
|
||||
future_time_features = floats_tensor([self.batch_size, config.prediction_length, config.num_time_features])
|
||||
|
|
Loading…
Reference in New Issue