# Source code for mpinterfaces.firetasks

# coding: utf-8
# Copyright (c) Henniggroup.
# Distributed under the terms of the MIT License.

from __future__ import division, print_function, unicode_literals, \
    absolute_import

"""
Defines FireWorks firetasks for calibration, measurement, and database submission
"""

from pymatgen.apps.borg.queen import BorgQueen

from monty.json import MontyDecoder

from fireworks.core.firework import FireTaskBase, FWAction
from fireworks.utilities.fw_serializers import FWSerializable
from fireworks.utilities.fw_utilities import explicit_serialize

from mpinterfaces.database import MPINTVaspToDbTaskDrone
from mpinterfaces.default_logger import get_default_logger

logger = get_default_logger(__name__)


def load_class(mod, name):
    """
    Load the class named ``name`` from the module ``mod``.

    This function is adapted from materialsproject.

    Args:
        mod (str): dotted module path, e.g. "mpinterfaces.calibrate".
        name (str): name of an attribute (typically a class) defined in
            that module.

    Returns:
        The object bound to ``name`` in module ``mod``.

    Raises:
        ImportError: if ``mod`` cannot be imported.
        AttributeError: if ``mod`` has no attribute ``name``.
    """
    # importlib.import_module is the documented replacement for calling
    # __import__ directly; it returns the leaf module for dotted paths.
    import importlib

    module = importlib.import_module(mod)
    return getattr(module, name)
def get_cal_obj(d):
    """
    Rebuild a calibration object from its dictionary form.

    Within this module, ``d`` is either a firetask's own parameters or an
    entry of ``fw_spec['cal_objs']`` produced by a calibration object's
    ``as_dict()``.

    Args:
        d (dict): serialized calibration, decoded with monty's
            ``MontyDecoder.process_decoded``.

    Returns:
        The decoded calibration object.
    """
    decoder = MontyDecoder()
    calibration = decoder.process_decoded(d)
    return calibration
@explicit_serialize
class MPINTCalibrateTask(FireTaskBase):
    """
    Calibration Task

    The task's own parameters are decoded into a calibration object, whose
    jobs are then set up and run.

    Optional params:
        que_params: attached to the serialized calibration that is handed
            on to child fireworks.
    """
    optional_params = ["que_params"]

    def run_task(self, fw_spec):
        """
        Launch the calibration jobs to the queue.

        Returns:
            FWAction: a ``_push`` mod_spec that adds this calibration's
            ``as_dict()`` (tagged with ``que_params``) under ``cal_objs``.
        """
        calibration = get_cal_obj(self)
        calibration.setup()
        calibration.run()
        serialized = calibration.as_dict()
        serialized['que_params'] = self.get('que_params')
        push = {'_push': {'cal_objs': serialized}}
        return FWAction(mod_spec=[push])
@explicit_serialize
class MPINTMeasurementTask(FireTaskBase, FWSerializable):
    """
    Measurement Task

    Required params:
        measurement (str): name of a class in ``mpinterfaces.measurement``
            that is constructed from the calibration objects.

    Optional params:
        que_params: stored on every calibration dict this task decodes or
            emits.
        job_cmd: forwarded to the measurement's ``run(job_cmd=...)``;
            ``None`` when absent.
        other_params (dict): extra keyword arguments for the measurement
            class constructor.
        fw_id: only reported in the log when calibrations are not done.
    """
    required_params = ["measurement"]
    optional_params = ["que_params", "job_cmd", "other_params", "fw_id"]

    def run_task(self, fw_spec):
        """
        Set up a measurement task using the prior calibration jobs and run it.

        The serialized calibrations are read from ``fw_spec['cal_objs']``.
        If ``Calibrate.check_calcs`` reports they are not finished, the
        children of this firework are defused. Otherwise the measurement
        class named by ``self['measurement']`` is built from the calibration
        objects, set up, and run.

        Returns:
            FWAction: ``defuse_children=True`` when calibration is not done;
            otherwise an ``update_spec`` carrying the measurement's
            calibration objects as dicts tagged with ``que_params``.

        Raises:
            KeyError: if ``fw_spec`` has no ``'cal_objs'`` entry.
        """
        que_params = self.get('que_params')
        cal_dicts = fw_spec['cal_objs']
        logger.info(
            'The measurement task will be constructed from {} calibration objects'
            .format(len(cal_dicts)))

        cal_objs = []
        for calparams in cal_dicts:
            # Copy before tagging so the dicts in the incoming fw_spec are
            # not mutated in place.
            params = dict(calparams)
            params['que_params'] = que_params
            cal_objs.append(get_cal_obj(params))

        calibrate_cls = load_class("mpinterfaces.calibrate", "Calibrate")
        if not calibrate_cls.check_calcs(cal_objs):
            logger.info('Calibration not done yet. Try again later')
            logger.info('All subsequent fireworks will be defused')
            logger.info("""Try re-running this firework again later. Re-running this firework will activate all the subsequent foreworks too""")
            logger.info('This fireworks id = {}'.format(self.get("fw_id")))
            # NOTE: an earlier (commented-out) variant re-queued a copy of
            # this task as a detour instead of defusing the children.
            return FWAction(defuse_children=True)

        measurement_cls = load_class("mpinterfaces.measurement",
                                     self['measurement'])
        measure = measurement_cls(cal_objs, **self.get("other_params", {}))
        measure.setup()
        measure.run(job_cmd=self.get("job_cmd"))

        cal_list = []
        for cal in measure.cal_objs:
            d = cal.as_dict()
            d['que_params'] = que_params
            cal_list.append(d)
        return FWAction(update_spec={'cal_objs': cal_list})
@explicit_serialize
class MPINTDatabaseTask(FireTaskBase, FWSerializable):
    """
    Firetask that submits measurement data to the database.

    Required params:
        measure_dir: path handed to ``BorgQueen.serial_assimilate``.

    Optional params:
        dbase_params (dict): keyword arguments for
            ``MPINTVaspToDbTaskDrone``.
    """
    required_params = ["measure_dir"]
    optional_params = ["dbase_params"]

    def run_task(self, fw_spec):
        """
        Walk the measurement job directories and insert the jobs into the
        database via an ``MPINTVaspToDbTaskDrone`` driven by a ``BorgQueen``.

        Returns:
            FWAction: an empty action.
        """
        drone_kwargs = self.get("dbase_params", {})
        # The queen is created without number_of_drones, and assimilation
        # is done serially.
        queen = BorgQueen(MPINTVaspToDbTaskDrone(**drone_kwargs))
        queen.serial_assimilate(self["measure_dir"])
        return FWAction()