[Time-Series] fix past_observed_mask type (#22076)

added > 0.5 to `past_observed_mask`
This commit is contained in:
Eli Simhayev 2023-04-03 20:07:21 +07:00 committed by GitHub
parent 559a45d1dc
commit 9eae4aa576
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 2 deletions

View File

@ -117,7 +117,7 @@ class InformerModelTester:
past_time_features = floats_tensor([self.batch_size, _past_length, config.num_time_features])
past_values = floats_tensor([self.batch_size, _past_length])
past_observed_mask = floats_tensor([self.batch_size, _past_length])
past_observed_mask = floats_tensor([self.batch_size, _past_length]) > 0.5
# decoder inputs
future_time_features = floats_tensor([self.batch_size, config.prediction_length, config.num_time_features])

View File

@ -114,7 +114,7 @@ class TimeSeriesTransformerModelTester:
past_time_features = floats_tensor([self.batch_size, _past_length, config.num_time_features])
past_values = floats_tensor([self.batch_size, _past_length])
past_observed_mask = floats_tensor([self.batch_size, _past_length])
past_observed_mask = floats_tensor([self.batch_size, _past_length]) > 0.5
# decoder inputs
future_time_features = floats_tensor([self.batch_size, config.prediction_length, config.num_time_features])