Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions src/rail/estimation/algos/tpz_lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def run(self):
if self._parallel == MPI_PARALLEL:
self._comm.Barrier()
if self._rank == 0:
print(
self.log.info(
f"self._parallel is {self._parallel}, number of processors we will use is {self._size}"
)

Expand Down Expand Up @@ -196,9 +196,9 @@ def run(self):
f"value of {self.config.tree_strategy} not valid! Valid values for tree_strategy are 'native' or 'sklearn'"
)
if self.config.tree_strategy == "sklearn" and self._rank == 0:
print("using sklearn decision trees")
self.log.info("using sklearn decision trees")
if self.config.tree_strategy == "native" and self._rank == 0:
print("using native TPZ decision trees")
self.log.info("using native TPZ decision trees")

# TPZ expects a param called `keyatt` that is just the redshift column, copy redshift_col
self.config.keyatt = self.config.redshift_col
Expand All @@ -213,8 +213,8 @@ def run(self):
# npdata = np.array(list(training_data.values()))
trainkeys = self.config.bands + self.config.err_bands
trainkeys.append(self.config.redshift_col)
print(trainkeys)
print("STOP")
self.log.info(trainkeys)
self.log.info("STOP")
ncols = len(trainkeys)
nvals = len(training_data[self.config.redshift_col])
npdata = np.zeros([ncols, nvals])
Expand Down Expand Up @@ -244,7 +244,7 @@ def run(self):
# not how I would have done things, but we're keeping it to try to duplicate MLZ's code exactly.
if self.config.n_random > 1:
if self._rank == 0:
print(f"creating {self.config.n_random} random realizations...")
self.log.info(f"creating {self.config.n_random} random realizations...")
traindata.make_random(ntimes=int(self.config.n_random))
temprandos = traindata.BigRan
else: # pragma: no cover
Expand All @@ -264,7 +264,7 @@ def run(self):

ntot = int(self.config.n_random * self.config.n_trees)
if self._rank == 0:
print(
self.log.info(
f"making a total of {ntot} trees for {self.config.n_random} random realizations * {self.config.n_trees} bootstraps"
)

Expand All @@ -278,16 +278,16 @@ def run(self):
for i in range(Nproc):
Xs_0, Xs_1 = utils_mlz.get_limits(ntot, Nproc, i)
if Xs_0 == Xs_1: # pragma: no cover
print(f"idle... -------------> to core {i}")
self.log.info(f"idle... -------------> to core {i}")
else:
print(f"{Xs_0} - {Xs_1} -------------> to core {i}")
self.log.info(f"{Xs_0} - {Xs_1} -------------> to core {i}")

treedict = {}
if self._parallel == MPI_PARALLEL:
self._comm.Barrier()
# copy some stuff from the runMLZ script:
for kss in range(s0, s1):
print(f"making {kss + 1} of {ntot}...")
self.log.info(f"making {kss + 1} of {ntot}...")
if self.config.n_random > 1:
ir = kss // int(self.config.n_trees)
if ir != 0:
Expand Down Expand Up @@ -320,7 +320,7 @@ def run(self):
if self._parallel == MPI_PARALLEL:
if self._rank == 0:
for i in range(1, self._size, 1):
print(f"receiving data from rank {i}")
self.log.info(f"receiving data from rank {i}")
xdata = self._comm.recv(source=i, tag=11)
for key in xdata:
treedict[key] = xdata[key]
Expand Down Expand Up @@ -442,7 +442,7 @@ def _process_chunk(self, start, end, inputdata, first):

# Load trees
alltreedict = self.model["treedict"]
print(f"loading {ntot} total trees from model")
self.log.info(f"loading {ntot} total trees from model")
for k in range(ntot):

S = alltreedict[f"tree_{k}"]
Expand Down
Loading